flax_extra.layer package¶
Flax linen layers.
- class flax_extra.layer.Attention(n_heads: int = 8, d_qk: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_shape>, index=0), d_v: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_value>, index=2), d_output: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_value>, index=3), kernel_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function variance_scaling.<locals>.init>, bias_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function zeros>, use_bias: bool = True, dropout_rate: float = 0.0, deterministic: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Learns an attention vector for each position (i.e. time step) of the query sequence to all unmasked positions of the key sequence.
\[\begin{split}\begin{aligned} & \textrm{Attention}( \\ & \quad x_{q} \in \sR^{\nBatchSize \times \nSeqLen_{q} \times d_{x_{q}}} \\ & \quad x_{kv} \in \sR^{\nBatchSize \times \nSeqLen_{kv} \times d_{x_{kv}}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{q} \times \nSeqLen_{kv}} \\ & \quad \_ \\ & \quad w_{q} \in \sR^{d_{x_{q}} \times n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad w_{k} \in \sR^{d_{x_{kv}} \times n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad w_{v} \in \sR^{d_{x_{kv}} \times n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad w_{o} \in \sR^{d_{v} \times d_{o} \gets d_{v}} \\ & \quad b_{q} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad b_{k} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad b_{v} \in \sR^{n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad b_{o} \in \sR^{d_{o} \gets d_{v}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{q} \times d_{o} \gets d_{v}} \end{aligned}\end{split}\]
- Parameters
inputs_q – inputs for the query.
inputs_kv – inputs for key and value.
mask – a mask tensor with boolean values indicating whether a particular query attends to a particular key.
- Returns
attention vectors.
- bias_init(shape, dtype=<class 'jax._src.numpy.lax_numpy.float64'>)¶
an initializer for bias.
- d_output(*, index: int = 3) Any¶
an output dimension. Defaults to the value's dimension.
- d_qk(*, index: int = 0) Any¶
a dimension of the query array (the key has the same dimension). Defaults to the dimension of the query's inputs.
- d_v(*, index: int = 2) Any¶
a dimension of the value array. Defaults to the key's dimension.
- deterministic: bool = True¶
whether to perform deterministically or not.
- dropout_rate: float = 0.0¶
probabilistic rate for attention dropout.
- kernel_init(shape, dtype=<class 'jax._src.numpy.lax_numpy.float64'>)¶
an initializer for weights.
- n_heads: int = 8¶
a number of attention heads.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- precision: Optional[Any] = None¶
numerical precision of the computation. See jax.lax.Precision for details.
- scope = None¶
- use_bias: bool = True¶
whether to use bias.
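The chain of defaults documented for d_qk, d_v and d_output can be sketched in plain Python. This is only an illustration of the documented defaults, not the library's code: `resolve_attention_dims` is a hypothetical helper, and `None` stands in for the callable defaults shown in the class signature.

```python
from typing import Optional, Tuple


def resolve_attention_dims(
    d_x_q: int,
    d_qk: Optional[int] = None,
    d_v: Optional[int] = None,
    d_output: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Hypothetical helper: resolve (d_qk, d_v, d_output) from the documented defaults."""
    # d_qk defaults to the dimension of the query's inputs.
    d_qk = d_x_q if d_qk is None else d_qk
    # d_v defaults to the key's dimension (same as d_qk).
    d_v = d_qk if d_v is None else d_v
    # d_output defaults to the value's dimension.
    d_output = d_v if d_output is None else d_output
    return d_qk, d_v, d_output


print(resolve_attention_dims(64))          # (64, 64, 64)
print(resolve_attention_dims(64, d_v=32))  # (64, 32, 32)
```

Overriding one dimension propagates to every dimension that defaults to it, which is why setting only d_v also changes d_output.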
- class flax_extra.layer.Decoding(embedding_decoding: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <function Decoding.<lambda>>, postprocessing: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <function Decoding.<lambda>>, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Decodes a vector at each position (i.e. time step) of the sequence using an optional embedding decoding and postprocessing step.
The postprocessing step is intended to reshape the sequence into the desired shape.
\[\begin{split}\begin{aligned} & \textrm{Decoding}( \\ & \quad y \in \sR^{\nBatchSize \times \nSeqLen \times d} \\ & \quad \_ \\ & \quad \theta \gets EmbeddingDecoding() \\ & \quad \theta \gets Postprocessing() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \dots d^{\prime}} \end{aligned}\end{split}\]
- Parameters
outputs – model’s outputs.
- Returns
decoded outputs.
- embedding_decoding()¶
a type of the embedding decoder (e.g. flax_extra.layer.EmbedDecoding or flax.linen.Dense).
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- postprocessing()¶
a constructor for a module or function that accepts the decoded outputs as its single argument and returns the postprocessed outputs as a single output.
- scope = None¶
- class flax_extra.layer.EmbedDecoding(embedding: jax._src.numpy.lax_numpy.ndarray, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Projects encoded vectors at each position (i.e. time step) of the sequence back to their original representation using the projection tensor of the embedding encoder.
\[\begin{split}\begin{aligned} & \textrm{EmbedDecoding}( \\ & \quad y \in \sR^{\nBatchSize \times \nSeqLen \times d} \\ & \quad \_ \\ & \quad b_{x} \in \sR^{d_{x}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen \times d_{x}} \end{aligned}\end{split}\]
- Parameters
outputs – model’s outputs.
- Returns
decoded outputs.
- embedding: jax._src.numpy.lax_numpy.ndarray¶
embedding projection tensor of the embedding encoder (e.g. flax.linen.Embed).
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
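The projection performed by EmbedDecoding can be pictured with a NumPy sketch. It assumes the embedding tensor has shape [d_x, d] (as in flax.linen.Embed) and that decoding multiplies by its transpose and adds the bias b_x; `embed_decode_sketch` is a hypothetical name, and the actual module may differ in details.

```python
import numpy as np


def embed_decode_sketch(outputs, embedding, bias):
    """Hypothetical sketch: project [batch, seqlen, d] back to [batch, seqlen, d_x]."""
    # Contract the feature axis d of the outputs with the feature axis of the embedding.
    return np.einsum("btd,vd->btv", outputs, embedding) + bias


rng = np.random.default_rng(0)
embedding = rng.normal(size=(10, 4))   # assumed shape [d_x, d]
outputs = rng.normal(size=(2, 3, 4))   # [batch, seqlen, d]
bias = np.zeros(10)                    # b_x in the formula above

logits = embed_decode_sketch(outputs, embedding, bias)
print(logits.shape)  # (2, 3, 10)
```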
- class flax_extra.layer.Encoding(preprocessing: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <function Encoding.<lambda>>, positional_encoding: Optional[Callable[[...], Callable[[int, Optional[List[int]]], jax._src.numpy.lax_numpy.ndarray]]] = None, aggregation: Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray] = Concatenate(signature=Signature(n_in=2, n_out=1, start_index=0, in_shape=((), ())), axis=-1), parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Encodes a vector at each position (i.e. time step) of the sequence using an optional preprocessing step and a positional encoding.
Preprocessed inputs get reshaped to \(\sR^{\nBatchSize \times T \times d}\). Then, optionally, positional encodings get concatenated or added to the sequence.
\[ \begin{align}\begin{aligned}\begin{split}\begin{aligned} & \textrm{Encoding}( \\ & \quad x \in \sR^{\nBatchSize \times T \times d} \\ & \quad \_ \\ & \quad \theta \gets Preprocessing() \\ & \quad \theta \gets PositionalEncoding() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times d^{\prime}} \end{aligned}\end{split}\\\begin{split}\begin{aligned} & \textrm{Encoding}( \\ & \quad x \in \sR^{\nBatchSize \times T \times d} \\ & \quad t \in \sN^{\nSeqLen^{\prime}} \\ & \quad \_ \\ & \quad \theta \gets Preprocessing() \\ & \quad \theta \gets PositionalEncoding() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \end{aligned}\end{split}\end{aligned}\end{align} \]
- Parameters
inputs – an input sequence.
output_positions – a subset of positions (i.e. time steps) within the sequence for which the encoding will be calculated.
- Returns
encoded vectors.
- aggregation: Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray] = Concatenate(signature=Signature(n_in=2, n_out=1, start_index=0, in_shape=((), ())), axis=-1)¶
a binary operation (e.g. flax_extra.combinator.concatenate or flax_extra.combinator.add).
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- positional_encoding: Optional[Callable[[...], Callable[[int, Optional[List[int]]], jax._src.numpy.lax_numpy.ndarray]]] = None¶
a type of the positional encoding (e.g. flax_extra.layer.TrainablePositionalEncoding or flax_extra.layer.FourierPositionEncoding).
- preprocessing()¶
a constructor for a module or function that accepts inputs as its single argument and returns the preprocessed inputs as a single output.
- scope = None¶
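The effect of the aggregation step on the feature dimension can be illustrated with NumPy. This is a sketch under the assumption that aggregation is either concatenation along the last axis (the default shown in the signature) or element-wise addition; `aggregate_sketch` and its `mode` argument are hypothetical and do not exist in flax_extra.

```python
import numpy as np


def aggregate_sketch(inputs, positions, mode="concatenate"):
    """Hypothetical sketch: combine inputs with positional encodings."""
    if mode == "concatenate":
        # Feature dimension grows: d' = d + d_pos.
        return np.concatenate([inputs, positions], axis=-1)
    if mode == "add":
        # Feature dimension is unchanged; both arrays must share it.
        return inputs + positions
    raise ValueError(f"unknown mode: {mode}")


inputs = np.ones((2, 5, 8))      # [batch, T, d]
pos_concat = np.ones((2, 5, 3))  # [batch, T, d_pos]
pos_add = np.ones((2, 5, 8))     # [batch, T, d]

print(aggregate_sketch(inputs, pos_concat).shape)          # (2, 5, 11)
print(aggregate_sketch(inputs, pos_add, mode="add").shape)  # (2, 5, 8)
```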
- class flax_extra.layer.FeedForward(widening_factor: int = 1, activation: Callable[[...], jax._src.numpy.lax_numpy.ndarray] = <function gelu>, kernel_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function variance_scaling.<locals>.init>, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Learns a dense vector representation.
\[\begin{split}\begin{aligned} & \textrm{FeedForward}( \\ & \quad x \in \sR^{\nBatchSize \times \dots d} \\ & \quad \_ \\ & \quad w_{h} \in \sR^{d \times d \cdot \text{factor}} \\ & \quad w_{o} \in \sR^{d \cdot \text{factor} \times d} \\ & \quad b_{h} \in \sR^{d \cdot \text{factor}} \\ & \quad b_{o} \in \sR^{d} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \dots d} \end{aligned}\end{split}\]
- Parameters
inputs – a tensor.
- Returns
a tensor of the same shape as the input tensor.
- activation(approximate: bool = True) Any¶
a nonlinear function.
- kernel_init(shape, dtype=<class 'jax._src.numpy.lax_numpy.float64'>)¶
an initializer for weights.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
- widening_factor: int = 1¶
determines the hidden dimension of the layer as the product of the input dimension and the factor.
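The parameter shapes listed for FeedForward suggest a two-layer computation, sketched here in NumPy. The order of operations (dense, activation, dense) and the tanh approximation of GELU are assumptions based on the signature defaults; `feed_forward_sketch` and `gelu_tanh` are hypothetical names.

```python
import numpy as np


def gelu_tanh(x):
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def feed_forward_sketch(x, w_h, b_h, w_o, b_o):
    """Hypothetical sketch: widen to d * factor, apply the activation, project back to d."""
    hidden = gelu_tanh(x @ w_h + b_h)
    return hidden @ w_o + b_o


d, factor = 8, 4
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, d))
w_h, b_h = rng.normal(size=(d, d * factor)), np.zeros(d * factor)
w_o, b_o = rng.normal(size=(d * factor, d)), np.zeros(d)

out = feed_forward_sketch(x, w_h, b_h, w_o, b_o)
print(out.shape)  # (2, 5, 8)
```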
- class flax_extra.layer.FourierPositionEncoding(seqshape: tuple[int, ...], n_bands: int, spatial_coordinate_bounds: tuple[float, float] = (-1.0, 1.0), use_spatial_coordinates: bool = True, use_sine_only: bool = False, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Computes a vector of Fourier (or sinusoidal) features that may be seen as a multi-dimensional representation of the position across all its spatial dimensions (or similarly, as a single encoding of different time axes).
\[ \begin{align}\begin{aligned}\begin{split}\begin{aligned} & \textrm{FourierPositionEncoding}( \\ & \quad m \in \sN \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times d} \end{aligned}\end{split}\\\begin{split}\begin{aligned} & \textrm{FourierPositionEncoding}( \\ & \quad m \in \sN \\ & \quad t \in \sN^{\nSeqLen^{\prime}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times T^{\prime} \times d} \end{aligned}\end{split}\end{aligned}\end{align} \]
Output dimension may vary depending on parameters of the module:
- For a single phase (i.e. sine only, use_sine_only=True):
  - With spatial coordinates concatenated:
    \(d = n_{bands} \cdot len(seqshape) + len(seqshape)\)
  - Otherwise:
    \(d = n_{bands} \cdot len(seqshape)\)
- For two phases (i.e. sine and cosine) for each frequency band:
  - With spatial coordinates concatenated:
    \(d = n_{bands} \cdot len(seqshape) \cdot 2 + len(seqshape)\)
  - Otherwise:
    \(d = n_{bands} \cdot len(seqshape) \cdot 2\)
- Parameters
batch_size – a batch size of the sequence.
output_positions – a subset of positions (i.e. time steps) within the sequence for which the positional encoding will be calculated.
- Returns
positional encoding vectors.
- n_bands: int¶
a number of frequency bands to use.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
- seqshape: tuple[int, ...]¶
a shape of the sequence.
- spatial_coordinate_bounds: tuple[float, float] = (-1.0, 1.0)¶
lower and upper bound for spatial coordinates.
- use_sine_only: bool = False¶
whether to use a single phase (i.e. sine only) or two phase (i.e. sine and cosine) for each frequency band.
- use_spatial_coordinates: bool = True¶
whether to concatenate the spatial coordinates to the Fourier features.
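The output-dimension rules of FourierPositionEncoding can be checked with a small helper. The sketch only restates the four documented formulas; `fourier_encoding_dim` is a hypothetical name and not part of flax_extra.

```python
from typing import Tuple


def fourier_encoding_dim(
    seqshape: Tuple[int, ...],
    n_bands: int,
    use_spatial_coordinates: bool = True,
    use_sine_only: bool = False,
) -> int:
    """Hypothetical helper: feature dimension d of the Fourier positional encoding."""
    n_axes = len(seqshape)
    # One phase (sine) or two phases (sine and cosine) per frequency band.
    n_phases = 1 if use_sine_only else 2
    d = n_bands * n_axes * n_phases
    if use_spatial_coordinates:
        # The raw coordinates add one feature per axis.
        d += n_axes
    return d


print(fourier_encoding_dim((224, 224), n_bands=64))                      # 258
print(fourier_encoding_dim((224, 224), n_bands=64, use_sine_only=True))  # 130
```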
- class flax_extra.layer.LSTM(d_hidden: int, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
LSTM running on time axis.
The layer scans over each time step of the input and returns its hidden output state for the last time step (hidden cell state is dropped).
- d_hidden: int¶
depth of a hidden state. LSTM has two states (output state and cell state).
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
- class flax_extra.layer.LSTMCell(gate_fn: Callable[[...], Any] = <function sigmoid>, activation_fn: Callable[[...], Any] = <CompiledFunction object>, kernel_init: Callable[[Any, Tuple[int], Any], Any] = <function variance_scaling.<locals>.init>, recurrent_kernel_init: Callable[[Any, Tuple[int], Any], Any] = <function orthogonal.<locals>.init>, bias_init: Callable[[Any, Tuple[int], Any], Any] = <function zeros>, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.recurrent.LSTMCell
LSTM cell.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
- class flax_extra.layer.MultimodalDecoding(modalities: List[Callable[[...], flax_extra.layer._decoding.Decoding]], multimodal_embedding_decoding: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <function MultimodalDecoding.<lambda>>, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Given original sequence lengths, partitions a multimodal tensor into tensors for each modality and decodes each of them.
\[\begin{split}\begin{aligned} & \textrm{MultimodalDecoding}( \\ & \quad o \in \sR^{\nBatchSize \times \nSeqLen \times d} \\ & \quad l_{x} \in n_{mod} \times \sN \\ & \quad \_ \\ & \quad \theta \gets MultimodalEmbeddingDecoding() \\ & \quad \theta \gets n_{mod} \times Decoding() \\ & ) \\ & \rightarrow n_{mod} \times \sR^{\nBatchSize \times \dots d^{\prime}} \end{aligned}\end{split}\]
- Parameters
outputs – multimodal model’s outputs.
seqlen_outputs – sequence lengths of original modalities.
- Returns
partitioned modalities.
- modalities: List[Callable[[...], flax_extra.layer._decoding.Decoding]]¶
a list of constructors for a decoding module (e.g. flax_extra.layer.Decoding) for each modality.
- multimodal_embedding_decoding()¶
a type of the embedding decoder (e.g. flax_extra.layer.EmbedDecoding or flax.linen.Dense).
- property n_modalities: int¶
a number of modalities.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
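The partitioning step of MultimodalDecoding can be sketched with NumPy. The sketch assumes the modalities were concatenated along the sequence axis (axis 1) and omits the per-modality decoding; `partition_modalities` is a hypothetical name.

```python
import numpy as np


def partition_modalities(outputs, seqlen_outputs):
    """Hypothetical sketch: split [batch, seqlen, d] into one tensor per modality."""
    # Split points are the cumulative lengths, excluding the final total.
    boundaries = np.cumsum(seqlen_outputs)[:-1]
    return np.split(outputs, boundaries, axis=1)


outputs = np.arange(2 * 7 * 3).reshape(2, 7, 3)
parts = partition_modalities(outputs, [4, 3])
print([p.shape for p in parts])  # [(2, 4, 3), (2, 3, 3)]
```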
- class flax_extra.layer.MultimodalEncoding(modalities: List[Callable[[...], flax_extra.layer._encoding.Encoding]], mask_rates: Optional[List[int]] = None, d_reserved: int = 1, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Encodes multiple modalities.
Each modality is padded to the same feature dimension and gets concatenated to form a single sequence.
\[ \begin{align}\begin{aligned}\begin{split}\begin{aligned} & \textrm{MultimodalEncoding}( \\ & \quad x \in n_{mod} \times \sR^{\nBatchSize \times T \times d} \\ & \quad \_ \\ & \quad \theta \gets PositionalEncoding() \\ & \quad \theta \gets TrainablePositionalPadding() \\ & \quad \theta \gets TrainablePositionalMasking() \\ & ) \\ & \rightarrow \\ & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\ & \quad l_{x} \in n_{mod} \times \sN \end{aligned}\end{split}\\\begin{split}\begin{aligned} & \textrm{MultimodalEncoding}( \\ & \quad x \in n_{mod} \times \sR^{\nBatchSize \times T \times d} \\ & \quad t \in n_{mod} \times \sN^{\nSeqLen^{\prime}} \\ & \quad \_ \\ & \quad \theta \gets PositionalEncoding() \\ & \quad \theta \gets TrainablePositionalPadding() \\ & \quad \theta \gets TrainablePositionalMasking() \\ & ) \\ & \rightarrow \\ & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\ & \quad l_{x} \in n_{mod} \times \sN \end{aligned}\end{split}\end{aligned}\end{align} \]
- Parameters
multimodal_inputs – input sequences, one for each modality.
multimodal_output_positions – a subset of positions (i.e. time steps) within each modality for which the encoding will be calculated.
- Returns
encoded vectors combined in a single sequence and sequence lengths of original modalities.
- d_reserved: int = 1¶
a number of reserved feature dimensions.
- mask_rates: Optional[List[int]] = None¶
a probability for each modality being masked.
- modalities: List[Callable[[...], flax_extra.layer._encoding.Encoding]]¶
a list of constructors for an encoding module (e.g. flax_extra.layer.Encoding) for each modality.
- property n_modalities: int¶
a number of modalities.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
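The pad-then-concatenate behaviour described for MultimodalEncoding can be sketched with NumPy. Zero padding stands in for the trainable padding, masking and d_reserved are not modelled, and `combine_modalities` is a hypothetical name.

```python
import numpy as np


def combine_modalities(modalities):
    """Hypothetical sketch: pad features to a common size, then join along the sequence axis."""
    d_max = max(x.shape[-1] for x in modalities)
    padded = [
        # Pad only the feature axis, on the right, with zeros.
        np.pad(x, ((0, 0), (0, 0), (0, d_max - x.shape[-1])))
        for x in modalities
    ]
    seqlens = [x.shape[1] for x in modalities]
    return np.concatenate(padded, axis=1), seqlens


image = np.ones((2, 4, 5))  # [batch, T_1, d_1]
audio = np.ones((2, 3, 2))  # [batch, T_2, d_2]
combined, seqlens = combine_modalities([image, audio])
print(combined.shape, seqlens)  # (2, 7, 5) [4, 3]
```

The returned sequence lengths are what a decoder would need later to split the combined sequence back into modalities.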
- class flax_extra.layer.MultimodalPositionalEncoding(modalities: List[Callable[[...], Callable[[int, Optional[List[int]]], jax._src.numpy.lax_numpy.ndarray]]], d_reserved: int = 1, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Computes a positional encoding for each modality.
Each modality is padded to the same feature dimension and gets concatenated to form a single sequence.
\[ \begin{align}\begin{aligned}\begin{split}\begin{aligned} & \textrm{MultimodalPositionalEncoding}( \\ & \quad m \in \sN \\ & \quad \_ \\ & \quad \theta \gets PositionalEncoding() \\ & \quad \theta \gets TrainablePositionalPadding() \\ & ) \\ & \rightarrow \\ & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\ & \quad l_{x} \in n_{mod} \times \sN \end{aligned}\end{split}\\\begin{split}\begin{aligned} & \textrm{MultimodalPositionalEncoding}( \\ & \quad m \in \sN \\ & \quad t \in n_{mod} \times \sN^{\nSeqLen^{\prime}} \\ & \quad \_ \\ & \quad \theta \gets PositionalEncoding() \\ & \quad \theta \gets TrainablePositionalPadding() \\ & ) \\ & \rightarrow \\ & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\ & \quad l_{x} \in n_{mod} \times \sN \end{aligned}\end{split}\end{aligned}\end{align} \]
- Parameters
batch_size – a batch size of the sequence.
multimodal_output_positions – a subset of positions (i.e. time steps) within each modality for which the positional encoding will be calculated.
- Returns
positional encoding vectors combined in a single sequence and sequence lengths of original modalities.
- d_reserved: int = 1¶
a number of reserved feature dimensions.
- modalities: List[Callable[[...], Callable[[int, Optional[List[int]]], jax._src.numpy.lax_numpy.ndarray]]]¶
a list of constructors for a positional encoding module (e.g. flax_extra.layer.TrainablePositionalEncoding or flax_extra.layer.FourierPositionEncoding) for each modality.
- property n_modalities: int¶
a number of modalities.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
- flax_extra.layer.QKVAttention¶
A variation of multi-head cross-attention where the key's dimension defaults to the dimension of the query's inputs and the output dimension defaults to the value's dimension.
\[\begin{split}\begin{aligned} & \textrm{Attention}( \\ & \quad x_{q} \in \sR^{\nBatchSize \times \nSeqLen_{q} \times d_{x_{q}}} \\ & \quad x_{kv} \in \sR^{\nBatchSize \times \nSeqLen_{kv} \times d_{x_{kv}}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{q} \times \nSeqLen_{kv}} \\ & \quad \_ \\ & \quad w_{q} \in \sR^{d_{x_{q}} \times n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad w_{k} \in \sR^{d_{x_{kv}} \times n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad w_{v} \in \sR^{d_{x_{kv}} \times n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad w_{o} \in \sR^{d_{v} \times d_{o} \gets d_{v}} \\ & \quad b_{q} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad b_{k} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad b_{v} \in \sR^{n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad b_{o} \in \sR^{d_{o} \gets d_{v}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{q} \times d_{o} \gets d_{v}} \end{aligned}\end{split}\]
- class flax_extra.layer.SelfAttention(n_heads: int = 8, d_qk: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_shape>, index=0), d_v: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_value>, index=2), d_output: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_value>, index=3), kernel_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function variance_scaling.<locals>.init>, bias_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function zeros>, use_bias: bool = True, dropout_rate: float = 0.0, deterministic: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax_extra.layer._attention.Attention
Learns the attention of each position (i.e. time step) of the sequence to all unmasked positions of the same sequence.
\[\begin{split}\begin{aligned} & \textrm{SelfAttention}( \\ & \quad x_{qkv} \in \sR^{\nBatchSize \times \nSeqLen_{qkv} \times d_{x_{qkv}}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{qkv} \times \nSeqLen_{qkv}} \\ & \quad \_ \\ & \quad w_{q} \in \sR^{d_{x_{qkv}} \times n_{heads} \times d_{qk} \gets d_{x_{qkv}}} \\ & \quad w_{k} \in \sR^{d_{x_{qkv}} \times n_{heads} \times d_{qk} \gets d_{x_{qkv}}} \\ & \quad w_{v} \in \sR^{d_{x_{qkv}} \times n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad w_{o} \in \sR^{d_{v} \times d_{o} \gets d_{v}} \\ & \quad b_{q} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{qkv}}} \\ & \quad b_{k} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{qkv}}} \\ & \quad b_{v} \in \sR^{n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad b_{o} \in \sR^{d_{o} \gets d_{v}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{qkv} \times d_{o} \gets d_{v}} \end{aligned}\end{split}\]
- Parameters
inputs_qkv – inputs for the query, key, and value.
mask – a mask tensor with boolean values indicating whether a particular query attends to a particular key.
- Returns
attention vectors.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
- class flax_extra.layer.TrainablePositionalEncoding(seqlen: int, dimension: int, kernel_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function _truncated_normal.<locals>.init>, dtype: Any = <class 'jax._src.numpy.lax_numpy.float32'>, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Learns a positional encoding vector for each position (i.e. time step) of the sequence.
\[\begin{split}\begin{aligned} & \textrm{TrainablePositionalEncoding}( \\ & \quad m \in \sN \\ & \quad \_ \\ & \quad w \in \sR^{\nSeqLen \times d} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times d} \end{aligned}\end{split}\]
- Parameters
batch_size – a batch size of the sequence.
output_positions – cannot be used with this kind of positional encoding and must be omitted.
- Returns
positional encoding vectors.
- dimension: int¶
a desired dimension of the positional encoding.
- dtype¶
alias of jax._src.numpy.lax_numpy.float32
- kernel_init(shape: Optional[Union[Sequence[int], jax.core.NamedShape]], dtype: Any = <class 'jax._src.numpy.lax_numpy.float32'>) jax._src.numpy.lax_numpy.ndarray¶
an initializer for weights.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
- seqlen: int¶
a length of the sequence.
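The relation between the learned table w of shape [seqlen, dimension] and the batched output can be sketched with NumPy. The sketch assumes the same table is shared across the batch; `positional_encoding_sketch` is a hypothetical name, and a random array stands in for the trained weights.

```python
import numpy as np


def positional_encoding_sketch(table, batch_size):
    """Hypothetical sketch: repeat a [seqlen, dimension] table for every batch element."""
    return np.broadcast_to(table, (batch_size,) + table.shape)


table = np.random.default_rng(0).normal(size=(6, 4))  # stands in for the learned w
encodings = positional_encoding_sketch(table, batch_size=3)
print(encodings.shape)  # (3, 6, 4)
```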
- class flax_extra.layer.TrainablePositionalMasking(rate: float, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Learns a mask vector and applies it to a vector at each position (i.e. time step) of the sequence according to the specified probabilistic rate.
\[\begin{split}\begin{aligned} & \textrm{TrainablePositionalMasking}( \\ & \quad x \in \sR^{\nBatchSize \times T \times d} \\ & \quad \_ \\ & \quad w \in \sR^{1 \times d} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times d} \end{aligned}\end{split}\]
- Parameters
inputs – a sequence.
- Returns
a masked sequence.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- rate: float¶
a probability for a vector at each position being masked.
- scope = None¶
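One plausible reading of positional masking is sketched below with NumPy. It assumes that each position is masked independently with probability `rate` and that a masked position is replaced by the mask vector w; both are assumptions, `mask_positions` is a hypothetical name, and the module itself may apply the mask differently.

```python
import numpy as np


def mask_positions(inputs, mask_vector, rate, rng):
    """Hypothetical sketch: replace randomly chosen positions with the mask vector."""
    batch, seqlen, _ = inputs.shape
    # One independent draw per (batch, position), shared across features.
    is_masked = rng.random((batch, seqlen, 1)) < rate
    return np.where(is_masked, mask_vector, inputs)


rng = np.random.default_rng(0)
inputs = np.ones((2, 5, 3))
mask_vector = np.full((1, 3), -1.0)  # stands in for the learned w

unmasked = mask_positions(inputs, mask_vector, rate=0.0, rng=rng)
fully_masked = mask_positions(inputs, mask_vector, rate=1.0, rng=rng)
print(unmasked.shape, fully_masked.shape)  # (2, 5, 3) (2, 5, 3)
```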
- class flax_extra.layer.TrainablePositionalPadding(d_max: int, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]¶
Bases: flax.linen.module.Module
Trainable positional padding.
Learns a padding vector and applies it to a vector at each position (i.e. time step) of the sequence to pad the vector's dimension up to the desired value.
\[\begin{split}\begin{aligned} & \textrm{TrainablePositionalPadding}( \\ & \quad x \in \sR^{\nBatchSize \times T \times d} \\ & \quad \_ \\ & \quad w \in \sR^{1 \times d^{\prime}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times (d + d^{\prime})} \end{aligned}\end{split}\]
- Parameters
inputs – a sequence.
- Returns
a sequence with padding along feature dimension.
- d_max: int¶
a desired dimension for vectors in the sequence.
- name: str = None¶
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>¶
- scope = None¶
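The shape arithmetic of positional padding can be sketched with NumPy. The sketch assumes the padding vector w has d' = d_max - d features and is appended to every position; `pad_features` is a hypothetical name, and a constant array stands in for the trained weights.

```python
import numpy as np


def pad_features(inputs, padding):
    """Hypothetical sketch: append the same [1, d'] padding vector to every position."""
    batch, seqlen, _ = inputs.shape
    tiled = np.broadcast_to(padding, (batch, seqlen, padding.shape[-1]))
    return np.concatenate([inputs, tiled], axis=-1)


d, d_max = 3, 8
inputs = np.ones((2, 5, d))
padding = np.full((1, d_max - d), 0.5)  # stands in for the learned w
padded = pad_features(inputs, padding)
print(padded.shape)  # (2, 5, 8)
```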
- flax_extra.layer.attend(query: jax._src.numpy.lax_numpy.ndarray, key: jax._src.numpy.lax_numpy.ndarray, value: jax._src.numpy.lax_numpy.ndarray, mask: Optional[jax._src.numpy.lax_numpy.ndarray] = None, dropout_rngkey: Optional[jax._src.prng.PRNGKeyArray] = None, dropout_rate: float = 0.0, deterministic: bool = False, precision: Optional[Any] = None) jax._src.numpy.lax_numpy.ndarray[source]¶
Computes an attention vector for each position (i.e. time step) of the query sequence to positions of the key sequence.
\[\begin{split}\begin{aligned} & \textrm{attend}( \\ & \quad q \in \sR^{\nBatchSize \times \nSeqLen_{q} \times n_{heads} \times d_{qk}} \\ & \quad k \in \sR^{\nBatchSize \times \nSeqLen_{kv} \times n_{heads} \times d_{qk}} \\ & \quad v \in \sR^{\nBatchSize \times \nSeqLen_{kv} \times n_{heads} \times d_{v}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{q} \times \nSeqLen_{kv}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{q} \times d_{v}} \end{aligned}\end{split}\]
- Parameters
query – a query vector.
key – a key vector.
value – a value vector.
mask – a mask tensor with boolean values indicating whether a particular query attends to a particular key.
dropout_rngkey – random number generator key for dropout.
dropout_rate – probabilistic rate for attention dropout.
deterministic – whether to perform deterministically or not.
precision – numerical precision of the computation. See jax.lax.Precision for details.
- Returns
attention vectors.
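The computation behind attend can be illustrated with a NumPy sketch of masked dot-product attention over the documented input shapes. Scaling by the square root of d_qk is an assumption (the standard scaled dot-product form), dropout is omitted, and the sketch keeps the head axis in its output; how the library merges heads is not shown here. `attend_sketch` is a hypothetical name.

```python
import numpy as np


def attend_sketch(query, key, value, mask=None):
    """Hypothetical sketch: masked scaled dot-product attention, one result per head."""
    d_qk = query.shape[-1]
    # logits[b, h, q, k]: similarity of each query position to each key position.
    logits = np.einsum("bqhd,bkhd->bhqk", query, key) / np.sqrt(d_qk)
    if mask is not None:
        # mask[b, q, k] is shared by all heads; disallowed pairs get a large negative logit.
        logits = np.where(mask[:, None, :, :], logits, -1e9)
    # Numerically stable softmax over the key axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", weights, value)


rng = np.random.default_rng(0)
q = rng.normal(size=(2, 3, 4, 8))  # [batch, seqlen_q, n_heads, d_qk]
k = rng.normal(size=(2, 6, 4, 8))  # [batch, seqlen_kv, n_heads, d_qk]
v = rng.normal(size=(2, 6, 4, 5))  # [batch, seqlen_kv, n_heads, d_v]
print(attend_sketch(q, k, v).shape)  # (2, 3, 4, 5)
```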
- flax_extra.layer.cross_attention_mask(mask_q: jax._src.numpy.lax_numpy.ndarray, mask_kv: jax._src.numpy.lax_numpy.ndarray) jax._src.numpy.lax_numpy.ndarray[source]¶
Creates a cross-attention mask tensor with boolean values indicating whether a particular query attends to a particular key.
\[\begin{split}\begin{aligned} & \textrm{cross\_attention\_mask}( \\ & \quad mask_{q} \in \sR^{\nBatchSize \times \nSeqLen_{q}} \\ & \quad mask_{kv} \in \sR^{\nBatchSize \times \nSeqLen_{kv}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{q} \times \nSeqLen_{kv}} \end{aligned}\end{split}\]
- Parameters
mask_q – a padding mask indicating at which positions values of the query are valid.
mask_kv – a padding mask indicating at which positions values of the key are valid.
- Returns
a boolean matrix indicating which attention values are valid.
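The documented shapes suggest that the cross-attention mask is an outer combination of the two padding masks. A NumPy sketch, assuming a query-key pair is valid only when both positions are valid (logical AND); `cross_attention_mask_sketch` is a hypothetical name.

```python
import numpy as np


def cross_attention_mask_sketch(mask_q, mask_kv):
    """Hypothetical sketch: combine [batch, seqlen_q] and [batch, seqlen_kv] padding masks."""
    # Broadcast to [batch, seqlen_q, seqlen_kv]; a pair is valid if both sides are valid.
    return mask_q[:, :, None] & mask_kv[:, None, :]


mask_q = np.array([[True, True, False]])  # last query position is padding
mask_kv = np.array([[True, False]])       # last key position is padding
mask = cross_attention_mask_sketch(mask_q, mask_kv)
print(mask.shape)  # (1, 3, 2)
```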