flax_extra.layer package

Flax linen layers.

class flax_extra.layer.Attention(n_heads: int = 8, d_qk: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_shape>, index=0), d_v: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_value>, index=2), d_output: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_value>, index=3), kernel_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function variance_scaling.<locals>.init>, bias_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function zeros>, use_bias: bool = True, dropout_rate: float = 0.0, deterministic: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Learns an attention vector for each position (i.e. time step) of the query sequence to all unmasked positions of the key sequence.

\[\begin{split}\begin{aligned} & \textrm{Attention}( \\ & \quad x_{q} \in \sR^{\nBatchSize \times \nSeqLen_{q} \times d_{x_{q}}} \\ & \quad x_{kv} \in \sR^{\nBatchSize \times \nSeqLen_{kv} \times d_{x_{kv}}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{q} \times \nSeqLen_{kv}} \\ & \quad \_ \\ & \quad w_{q} \in \sR^{d_{x_{q}} \times n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad w_{k} \in \sR^{d_{x_{kv}} \times n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad w_{v} \in \sR^{d_{x_{kv}} \times n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad w_{o} \in \sR^{d_{v} \times d_{o} \gets d_{v}} \\ & \quad b_{q} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad b_{k} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad b_{v} \in \sR^{n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad b_{o} \in \sR^{d_{o} \gets d_{v}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{q} \times d_{o} \gets d_{v}} \end{aligned}\end{split}\]
Parameters
  • inputs_q – inputs for the query.

  • inputs_kv – inputs for key and value.

  • mask – a mask tensor with boolean values indicating whether a particular query attends to a particular key.

Returns

attention vectors.
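
The parameter shapes in the signature above can be illustrated with a small einsum. This is a minimal sketch in plain jax.numpy, not the library's implementation; the helper name `project_heads` and all example sizes are assumptions chosen for illustration.

```python
import jax.numpy as jnp


def project_heads(x, w, b):
    """Projects (batch, seqlen, d_x) inputs to (batch, seqlen, n_heads, d_head).

    Hypothetical helper: w has shape (d_x, n_heads, d_head) and b has shape
    (n_heads, d_head), mirroring w_q / b_q in the signature above.
    """
    return jnp.einsum("btd,dhf->bthf", x, w) + b


# Example sizes (assumptions). d_qk follows the query input dimension and
# d_v follows d_qk, which matches the documented defaults.
batch, len_q, len_kv = 2, 3, 5
d_x_q, d_x_kv, n_heads = 16, 12, 8
d_qk, d_v = d_x_q, d_x_q

x_q = jnp.ones((batch, len_q, d_x_q))
x_kv = jnp.ones((batch, len_kv, d_x_kv))

query = project_heads(x_q, jnp.ones((d_x_q, n_heads, d_qk)), jnp.zeros((n_heads, d_qk)))
key = project_heads(x_kv, jnp.ones((d_x_kv, n_heads, d_qk)), jnp.zeros((n_heads, d_qk)))
value = project_heads(x_kv, jnp.ones((d_x_kv, n_heads, d_v)), jnp.zeros((n_heads, d_v)))
```

The query keeps the length of inputs_q, while the key and value keep the length of inputs_kv; only the last axis is replaced by the head and per-head feature axes.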

bias_init(shape, dtype=<class 'jax._src.numpy.lax_numpy.float64'>)

an initializer for bias.

d_output(*, index: int = 3) Any

an output dimension. Defaults to the value’s dimension.

d_qk(*, index: int = 0) Any

a dimension of the query array (the key has the same dimension). Defaults to the dimension of the query’s inputs.

d_v(*, index: int = 2) Any

a dimension of the value array. Defaults to the same dimension as the key’s dimension.

deterministic: bool = True

whether to perform deterministically or not.

dropout_rate: float = 0.0

probabilistic rate for attention dropout.

kernel_init(shape, dtype=<class 'jax._src.numpy.lax_numpy.float64'>)

an initializer for weights.

n_heads: int = 8

a number of attention heads.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
precision: Optional[Any] = None

numerical precision of the computation. See jax.lax.Precision for details.

scope = None
use_bias: bool = True

whether to use bias.

class flax_extra.layer.Decoding(embedding_decoding: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <function Decoding.<lambda>>, postprocessing: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <function Decoding.<lambda>>, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Decodes a vector at each position (i.e. time step) of the sequence using an optional embedding decoding and postprocessing step.

The postprocessing step is intended to reshape the sequence to the desired shape.

\[\begin{split}\begin{aligned} & \textrm{Decoding}( \\ & \quad y \in \sR^{\nBatchSize \times \nSeqLen \times d} \\ & \quad \_ \\ & \quad \theta \gets EmbeddingDecoding() \\ & \quad \theta \gets Postprocessing() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \dots d^{\prime}} \end{aligned}\end{split}\]
Parameters

outputs – model’s outputs.

Returns

decoded outputs.

embedding_decoding()

a type of the embedding decoder (e.g. flax_extra.layer.EmbedDecoding or flax.linen.Dense).

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
postprocessing()

a constructor for a module or function that accepts decoded outputs as its single argument and returns the postprocessed outputs as a single output.

scope = None
class flax_extra.layer.EmbedDecoding(embedding: jax._src.numpy.lax_numpy.ndarray, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Projects encoded vectors at each position (i.e. time step) of the sequence to their original representation using the projection tensor of the embedding encoder.

\[\begin{split}\begin{aligned} & \textrm{EmbedDecoding}( \\ & \quad y \in \sR^{\nBatchSize \times \nSeqLen \times d} \\ & \quad \_ \\ & \quad b_{x} \in \sR^{d_{x}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen \times d_{x}} \end{aligned}\end{split}\]
Parameters

outputs – model’s outputs.

Returns

decoded outputs.
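
One way to read the projection above is as a product with the embedding table, followed by a bias. The following is a minimal sketch under that assumption, not the library's code; `embed_decoding_sketch` is a hypothetical name, and the (d_x, d) layout of `embedding` is assumed to match the table of an embedding encoder such as flax.linen.Embed.

```python
import jax.numpy as jnp


def embed_decoding_sketch(outputs, embedding, bias):
    """Maps (batch, seqlen, d) vectors to (batch, seqlen, d_x) scores.

    Assumes `embedding` has shape (d_x, d) and `bias` has shape (d_x,).
    """
    # Dot product of every output vector with every row of the table.
    return jnp.einsum("btd,xd->btx", outputs, embedding) + bias
```

With a vocabulary of d_x entries, each output position receives one score per entry.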

embedding: jax._src.numpy.lax_numpy.ndarray

embedding projection tensor of the embedding encoder (e.g. flax.linen.Embed).

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
class flax_extra.layer.Encoding(preprocessing: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <function Encoding.<lambda>>, positional_encoding: Optional[Callable[[...], Callable[[int, Optional[List[int]]], jax._src.numpy.lax_numpy.ndarray]]] = None, aggregation: Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray] = Concatenate(signature=Signature(n_in=2, n_out=1, start_index=0, in_shape=((), ())), axis=-1), parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Encodes a vector at each position (i.e. time step) of the sequence using an optional preprocessing step and a positional encoding.

Preprocessed inputs get reshaped to \(\sR^{\nBatchSize \times T \times d}\). Then, optionally, positional encodings get concatenated or added to the sequence.

\[ \begin{align}\begin{aligned}\begin{split}\begin{aligned} & \textrm{Encoding}( \\ & \quad x \in \sR^{\nBatchSize \times T \times d} \\ & \quad \_ \\ & \quad \theta \gets Preprocessing() \\ & \quad \theta \gets PositionalEncoding() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times d^{\prime}} \end{aligned}\end{split}\\\begin{split}\begin{aligned} & \textrm{Encoding}( \\ & \quad x \in \sR^{\nBatchSize \times T \times d} \\ & \quad t \in \sN^{\nSeqLen^{\prime}} \\ & \quad \_ \\ & \quad \theta \gets Preprocessing() \\ & \quad \theta \gets PositionalEncoding() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \end{aligned}\end{split}\end{aligned}\end{align} \]
Parameters
  • inputs – an input sequence.

  • output_positions – a subset of positions (i.e. time steps) within the sequence for which the encoding will be calculated.

Returns

encoded vectors.
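
The effect of the aggregation on the feature dimension can be sketched with plain jax.numpy operations standing in for the combinators. The array sizes below are assumptions for illustration only.

```python
import jax.numpy as jnp

x = jnp.ones((2, 10, 32))    # preprocessed inputs: (batch, T, d)
pos = jnp.zeros((2, 10, 8))  # positional encodings: (batch, T, d_pos)

# Concatenation along the feature axis widens each vector to d + d_pos.
concatenated = jnp.concatenate([x, pos], axis=-1)

# Addition requires matching feature dimensions and keeps d unchanged.
pos_same_d = jnp.zeros((2, 10, 32))
added = x + pos_same_d
```

Concatenation leaves the input features untouched at the cost of a wider vector, while addition keeps the width but mixes positional and input features.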

aggregation: Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray] = Concatenate(signature=Signature(n_in=2, n_out=1, start_index=0, in_shape=((), ())), axis=-1)

a binary operation (e.g. flax_extra.combinator.concatenate or flax_extra.combinator.add).

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
positional_encoding: Optional[Callable[[...], Callable[[int, Optional[List[int]]], jax._src.numpy.lax_numpy.ndarray]]] = None

a type of the positional encoding (e.g. flax_extra.layer.TrainablePositionalEncoding or flax_extra.layer.FourierPositionEncoding).

preprocessing()

a constructor for a module or function that accepts inputs as its single argument and returns the preprocessed inputs as a single output.

scope = None
class flax_extra.layer.FeedForward(widening_factor: int = 1, activation: Callable[[...], jax._src.numpy.lax_numpy.ndarray] = <function gelu>, kernel_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function variance_scaling.<locals>.init>, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Learns a dense vector representation.

\[\begin{split}\begin{aligned} & \textrm{FeedForward}( \\ & \quad x \in \sR^{\nBatchSize \times \dots d} \\ & \quad \_ \\ & \quad w_{h} \in \sR^{d \times d \cdot \text{factor}} \\ & \quad w_{o} \in \sR^{d \cdot \text{factor} \times d} \\ & \quad b_{h} \in \sR^{d \cdot \text{factor}} \\ & \quad b_{o} \in \sR^{d} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \dots d} \end{aligned}\end{split}\]
Parameters

inputs – a tensor.

Returns

a tensor of the same shape as the input tensor.
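
The signature above corresponds to two dense maps with a nonlinearity in between. A minimal sketch, assuming GELU as the activation and explicit parameter arrays (the function name and the example sizes are hypothetical):

```python
import jax
import jax.numpy as jnp


def feed_forward_sketch(x, w_h, b_h, w_o, b_o):
    """Widens the last axis by `factor`, applies GELU, then projects back."""
    hidden = jax.nn.gelu(x @ w_h + b_h)  # (..., d * factor)
    return hidden @ w_o + b_o            # (..., d)


d, factor = 4, 2
k_x, k_h, k_o = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (2, 3, d))
w_h = jax.random.normal(k_h, (d, d * factor))
w_o = jax.random.normal(k_o, (d * factor, d))
y = feed_forward_sketch(x, w_h, jnp.zeros(d * factor), w_o, jnp.zeros(d))
```

Only the hidden width depends on the widening factor; the output has the shape of the input.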

activation(approximate: bool = True) Any

a nonlinear function.

kernel_init(shape, dtype=<class 'jax._src.numpy.lax_numpy.float64'>)

an initializer for weights.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
widening_factor: int = 1

determines the hidden dimension of the layer as the product of the input dimension and the factor value.

class flax_extra.layer.FourierPositionEncoding(seqshape: tuple[int, ...], n_bands: int, spatial_coordinate_bounds: tuple[float, float] = (-1.0, 1.0), use_spatial_coordinates: bool = True, use_sine_only: bool = False, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Computes a vector of Fourier (or sinusoidal) features that may be seen as a multi-dimensional representation of the position across all its spatial dimensions (or similarly, as a single encoding of different time axes).

\[ \begin{align}\begin{aligned}\begin{split}\begin{aligned} & \textrm{FourierPositionEncoding}( \\ & \quad m \in \sN \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times d} \end{aligned}\end{split}\\\begin{split}\begin{aligned} & \textrm{FourierPositionEncoding}( \\ & \quad m \in \sN \\ & \quad t \in \sN^{\nSeqLen^{\prime}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times T^{\prime} \times d} \end{aligned}\end{split}\end{aligned}\end{align} \]

Output dimension may vary depending on parameters of the module:

  • For a single phase (i.e. sine only, use_sine_only=True):
    • With spatial coordinates concatenated:

      \(d = n_{bands} \cdot len(seqshape) + len(seqshape)\)

    • Otherwise:

      \(d = n_{bands} \cdot len(seqshape)\)

  • For two phases (i.e. sine and cosine) for each frequency band:
    • With spatial coordinates concatenated:

      \(d = n_{bands} \cdot len(seqshape) \cdot 2 + len(seqshape)\)

    • Otherwise:

      \(d = n_{bands} \cdot len(seqshape) \cdot 2\)
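
The four cases above reduce to one small computation. The helper below is a hypothetical illustration of the formulas (it is not part of the package) and reuses the module's parameter names.

```python
def fourier_feature_dim(seqshape, n_bands,
                        use_spatial_coordinates=True, use_sine_only=False):
    """Returns the output dimension d given by the formulas above."""
    n_axes = len(seqshape)
    # One phase (sine) or two phases (sine and cosine) per frequency band.
    n_phases = 1 if use_sine_only else 2
    d = n_bands * n_axes * n_phases
    if use_spatial_coordinates:
        # The raw coordinates add one feature per spatial axis.
        d += n_axes
    return d
```

For example, a two-dimensional sequence shape with 64 bands, both phases, and concatenated coordinates gives 64 · 2 · 2 + 2 = 258 features.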

Parameters
  • batch_size – a batch size of the sequence.

  • output_positions – a subset of positions (i.e. time steps) within the sequence for which the positional encoding will be calculated.

Returns

positional encoding vectors.

n_bands: int

a number of frequency bands to use.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
seqshape: tuple[int, ...]

a shape of the sequence.

spatial_coordinate_bounds: tuple[float, float] = (-1.0, 1.0)

lower and upper bound for spatial coordinates.

use_sine_only: bool = False

whether to use a single phase (i.e. sine only) or two phases (i.e. sine and cosine) for each frequency band.

use_spatial_coordinates: bool = True

whether to concatenate the spatial coordinates to the Fourier features.

class flax_extra.layer.LSTM(d_hidden: int, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

LSTM running on time axis.

The layer scans over each time step of the input and returns its hidden output state for the last time step (hidden cell state is dropped).
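
Scanning over the time axis and keeping only the last output state can be sketched with jax.lax.scan. This is an illustrative stand-in rather than the layer's code: the parameter dictionary, the gate ordering, and the function names are assumptions.

```python
import jax
import jax.numpy as jnp


def lstm_step(params, state, x):
    """One LSTM step; returns the new (output, cell) state and the output."""
    h, c = state
    z = x @ params["w_x"] + h @ params["w_h"] + params["b"]
    # Assumed gate order: input, forget, candidate, output.
    i, f, g, o = jnp.split(z, 4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (h, c), h


def last_output_sketch(params, inputs, d_hidden):
    """Runs the cell over (batch, time, d_in) inputs; returns (batch, d_hidden)."""
    batch = inputs.shape[0]
    state = (jnp.zeros((batch, d_hidden)), jnp.zeros((batch, d_hidden)))
    # jax.lax.scan iterates over the leading axis, so move time to the front.
    xs = jnp.swapaxes(inputs, 0, 1)
    (h, _), _ = jax.lax.scan(lambda s, x: lstm_step(params, s, x), state, xs)
    # The cell state is dropped; only the last output state is returned.
    return h
```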

d_hidden: int

depth of a hidden state. An LSTM has two hidden states (an output state and a cell state).

initial_state(inputs: Any) tuple[typing.Any, typing.Any][source]

Creates an LSTM state.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
class flax_extra.layer.LSTMCell(gate_fn: Callable[[...], Any] = <function sigmoid>, activation_fn: Callable[[...], Any] = <CompiledFunction object>, kernel_init: Callable[[Any, Tuple[int], Any], Any] = <function variance_scaling.<locals>.init>, recurrent_kernel_init: Callable[[Any, Tuple[int], Any], Any] = <function orthogonal.<locals>.init>, bias_init: Callable[[Any, Tuple[int], Any], Any] = <function zeros>, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.recurrent.LSTMCell

LSTM cell.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
class flax_extra.layer.MultimodalDecoding(modalities: List[Callable[[...], flax_extra.layer._decoding.Decoding]], multimodal_embedding_decoding: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <function MultimodalDecoding.<lambda>>, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Given original sequence lengths, partitions a multimodal tensor into tensors for each modality and decodes each of them.

\[\begin{split}\begin{aligned} & \textrm{MultimodalDecoding}( \\ & \quad o \in \sR^{\nBatchSize \times \nSeqLen \times d} \\ & \quad l_{x} \in n_{mod} \times \sN \\ & \quad \_ \\ & \quad \theta \gets MultimodalEmbeddingDecoding() \\ & \quad \theta \gets n_{mod} \times Decoding() \\ & ) \\ & \rightarrow n_{mod} \times \sR^{\nBatchSize \times \dots d^{\prime}} \end{aligned}\end{split}\]
Parameters
  • outputs – multimodal model’s outputs.

  • seqlen_outputs – sequence lengths of original modalities.

Returns

partitioned modalities.
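
The partitioning step can be sketched as a split along the sequence axis at the cumulative sequence lengths. This is a minimal illustration, not the module's implementation; `partition_sketch` is a hypothetical name and the per-modality decoding is omitted.

```python
import itertools

import jax.numpy as jnp


def partition_sketch(outputs, seqlen_outputs):
    """Splits (batch, sum(seqlens), d) outputs into one tensor per modality."""
    # Split points are the running totals, excluding the final one.
    boundaries = list(itertools.accumulate(seqlen_outputs))[:-1]
    return jnp.split(outputs, boundaries, axis=1)
```

For sequence lengths [3, 5], the first tensor holds positions 0 to 2 and the second holds positions 3 to 7.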

modalities: List[Callable[[...], flax_extra.layer._decoding.Decoding]]

a list of constructors for a decoding module (e.g. flax_extra.layer.Decoding) for each modality.

multimodal_embedding_decoding()

a type of the embedding decoder (e.g. flax_extra.layer.EmbedDecoding or flax.linen.Dense).

property n_modalities: int

a number of modalities.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
class flax_extra.layer.MultimodalEncoding(modalities: List[Callable[[...], flax_extra.layer._encoding.Encoding]], mask_rates: Optional[List[int]] = None, d_reserved: int = 1, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Encodes multiple modalities.

Each modality is padded to the same feature dimension and gets concatenated to form a single sequence.

\[ \begin{align}\begin{aligned}\begin{split}\begin{aligned} & \textrm{MultimodalEncoding}( \\ & \quad x \in n_{mod} \times \sR^{m \times T \times d} \\ & \quad \_ \\ & \quad \theta \gets PositionalEncoding() \\ & \quad \theta \gets TrainablePositionalPadding() \\ & \quad \theta \gets TrainablePositionalMasking() \\ & ) \\ & \rightarrow \\ & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\ & \quad l_{x} \in n_{mod} \times \sN \end{aligned}\end{split}\\\begin{split}\begin{aligned} & \textrm{MultimodalEncoding}( \\ & \quad x \in n_{mod} \times \sR^{m \times T \times d} \\ & \quad t \in n_{mod} \times \sN^{\nSeqLen^{\prime}} \\ & \quad \_ \\ & \quad \theta \gets PositionalEncoding() \\ & \quad \theta \gets TrainablePositionalPadding() \\ & \quad \theta \gets TrainablePositionalMasking() \\ & ) \\ & \rightarrow \\ & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\ & \quad l_{x} \in n_{mod} \times \sN \end{aligned}\end{split}\end{aligned}\end{align} \]
Parameters
  • multimodal_inputs – input sequences, one for each modality.

  • multimodal_output_positions – a subset of positions (i.e. time steps) within each modality for which the encoding will be calculated.

Returns

encoded vectors combined in a single sequence and sequence lengths of original modalities.
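
How modalities of different widths end up in one sequence can be sketched as follows. The sketch pads with zeros as a stand-in for the trainable padding and omits positional encoding, masking, and reserved dimensions, so it only illustrates the shapes; the function name is hypothetical.

```python
import jax.numpy as jnp


def combine_modalities_sketch(modalities):
    """Pads each (batch, T_i, d_i) tensor to a common width and joins them.

    Returns the combined (batch, sum(T_i), d_max) tensor and the original
    sequence lengths.
    """
    d_max = max(x.shape[-1] for x in modalities)
    padded = [
        # Zero padding on the feature axis only (stand-in for learned padding).
        jnp.pad(x, ((0, 0), (0, 0), (0, d_max - x.shape[-1])))
        for x in modalities
    ]
    seqlens = [x.shape[1] for x in modalities]
    return jnp.concatenate(padded, axis=1), seqlens
```

The returned sequence lengths are what a later decoding step needs to separate the modalities again.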

d_reserved: int = 1

a number of reserved feature dimensions.

mask_rates: Optional[List[int]] = None

a probability for each modality being masked.

modalities: List[Callable[[...], flax_extra.layer._encoding.Encoding]]

a list of constructors for an encoding module (e.g. flax_extra.layer.Encoding) for each modality.

property n_modalities: int

a number of modalities.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
class flax_extra.layer.MultimodalPositionalEncoding(modalities: List[Callable[[...], Callable[[int, Optional[List[int]]], jax._src.numpy.lax_numpy.ndarray]]], d_reserved: int = 1, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Computes a positional encoding for each modality.

Each modality is padded to the same feature dimension and gets concatenated to form a single sequence.

\[ \begin{align}\begin{aligned}\begin{split}\begin{aligned} & \textrm{MultimodalPositionalEncoding}( \\ & \quad m \in \sN \\ & \quad \_ \\ & \quad \theta \gets PositionalEncoding() \\ & \quad \theta \gets TrainablePositionalPadding() \\ & ) \\ & \rightarrow \\ & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\ & \quad l_{x} \in n_{mod} \times \sN \end{aligned}\end{split}\\\begin{split}\begin{aligned} & \textrm{MultimodalPositionalEncoding}( \\ & \quad m \in \sN \\ & \quad t \in n_{mod} \times \sN^{\nSeqLen^{\prime}} \\ & \quad \_ \\ & \quad \theta \gets PositionalEncoding() \\ & \quad \theta \gets TrainablePositionalPadding() \\ & ) \\ & \rightarrow \\ & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\ & \quad l_{x} \in n_{mod} \times \sN \end{aligned}\end{split}\end{aligned}\end{align} \]
Parameters
  • batch_size – a batch size of the sequence.

  • multimodal_output_positions – a subset of positions (i.e. time steps) within each modality for which the positional encoding will be calculated.

Returns

positional encoding vectors combined in a single sequence and sequence lengths of original modalities.

d_reserved: int = 1

a number of reserved feature dimensions.

modalities: List[Callable[[...], Callable[[int, Optional[List[int]]], jax._src.numpy.lax_numpy.ndarray]]]

a list of constructors for a positional encoding module (e.g. flax_extra.layer.TrainablePositionalEncoding or flax_extra.layer.FourierPositionEncoding) for each modality.

property n_modalities: int

a number of modalities.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
flax_extra.layer.QKVAttention

A variation of multi-head cross-attention where key’s dimension defaults to the dimension of query’s inputs and output dimension defaults to the value’s dimension.

\[\begin{split}\begin{aligned} & \textrm{Attention}( \\ & \quad x_{q} \in \sR^{\nBatchSize \times \nSeqLen_{q} \times d_{x_{q}}} \\ & \quad x_{kv} \in \sR^{\nBatchSize \times \nSeqLen_{kv} \times d_{x_{kv}}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{q} \times \nSeqLen_{kv}} \\ & \quad \_ \\ & \quad w_{q} \in \sR^{d_{x_{q}} \times n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad w_{k} \in \sR^{d_{x_{kv}} \times n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad w_{v} \in \sR^{d_{x_{kv}} \times n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad w_{o} \in \sR^{d_{v} \times d_{o} \gets d_{v}} \\ & \quad b_{q} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad b_{k} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{q}}} \\ & \quad b_{v} \in \sR^{n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad b_{o} \in \sR^{d_{o} \gets d_{v}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{q} \times d_{o} \gets d_{v}} \end{aligned}\end{split}\]
class flax_extra.layer.SelfAttention(n_heads: int = 8, d_qk: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_shape>, index=0), d_v: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_value>, index=2), d_output: Union[int, Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], Any]] = functools.partial(<function _like_value>, index=3), kernel_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function variance_scaling.<locals>.init>, bias_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function zeros>, use_bias: bool = True, dropout_rate: float = 0.0, deterministic: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax_extra.layer._attention.Attention

Learns an attention of each position (i.e. time step) of the sequence to all unmasked positions of the same sequence.

\[\begin{split}\begin{aligned} & \textrm{SelfAttention}( \\ & \quad x_{qkv} \in \sR^{\nBatchSize \times \nSeqLen_{qkv} \times d_{x_{qkv}}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{qkv} \times \nSeqLen_{qkv}} \\ & \quad \_ \\ & \quad w_{q} \in \sR^{d_{x_{qkv}} \times n_{heads} \times d_{qk} \gets d_{x_{qkv}}} \\ & \quad w_{k} \in \sR^{d_{x_{qkv}} \times n_{heads} \times d_{qk} \gets d_{x_{qkv}}} \\ & \quad w_{v} \in \sR^{d_{x_{qkv}} \times n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad w_{o} \in \sR^{d_{v} \times d_{o} \gets d_{v}} \\ & \quad b_{q} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{qkv}}} \\ & \quad b_{k} \in \sR^{n_{heads} \times d_{qk} \gets d_{x_{qkv}}} \\ & \quad b_{v} \in \sR^{n_{heads} \times d_{v} \gets d_{qk}} \\ & \quad b_{o} \in \sR^{d_{o} \gets d_{v}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{qkv} \times d_{o} \gets d_{v}} \end{aligned}\end{split}\]
Parameters
  • inputs_qkv – inputs for the query, key, and value.

  • mask – a mask tensor with boolean values indicating whether a particular query attends to a particular key.

Returns

attention vectors.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
class flax_extra.layer.TrainablePositionalEncoding(seqlen: int, dimension: int, kernel_init: Callable[[jax._src.prng.PRNGKeyArray, Optional[Union[Sequence[int], jax.core.NamedShape]], Any], jax._src.numpy.lax_numpy.ndarray] = <function _truncated_normal.<locals>.init>, dtype: Any = <class 'jax._src.numpy.lax_numpy.float32'>, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Learns a positional encoding vector for each position (i.e. time step) of the sequence.

\[\begin{split}\begin{aligned} & \textrm{TrainablePositionalEncoding}( \\ & \quad m \in \sN \\ & \quad \_ \\ & \quad w \in \sR^{\nSeqLen \times d} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times d} \end{aligned}\end{split}\]
Parameters
  • batch_size – a batch size of the sequence.

  • output_positions – cannot be used with this kind of positional encoding and must be omitted.

Returns

positional encoding vectors.

dimension: int

a desired dimension of the positional encoding.

dtype

alias of jax._src.numpy.lax_numpy.float32

kernel_init(shape: Optional[Union[Sequence[int], jax.core.NamedShape]], dtype: Any = <class 'jax._src.numpy.lax_numpy.float32'>) jax._src.numpy.lax_numpy.ndarray

an initializer for weights.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
seqlen: int

a length of the sequence.

class flax_extra.layer.TrainablePositionalMasking(rate: float, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Learns a mask vector and applies it to a vector at each position (i.e. time step) of the sequence according to the specified probabilistic rate.

\[\begin{split}\begin{aligned} & \textrm{TrainablePositionalMasking}( \\ & \quad x \in \sR^{\nBatchSize \times T \times d} \\ & \quad \_ \\ & \quad w \in \sR^{1 \times d} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times d} \end{aligned}\end{split}\]
Parameters

inputs – a sequence.

Returns

a masked sequence.
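
A minimal sketch of position-wise masking, assuming that a masked position is replaced entirely by the learned vector and that positions are drawn independently. The function name and the explicit `mask_vector` argument are hypothetical; in the module the vector is a trainable parameter.

```python
import jax
import jax.numpy as jnp


def positional_masking_sketch(rngkey, inputs, mask_vector, rate):
    """Replaces vectors at randomly chosen positions with `mask_vector`.

    inputs: (batch, T, d); mask_vector: (1, d); rate: masking probability.
    """
    batch, seqlen, _ = inputs.shape
    # One independent draw per position of every sequence in the batch.
    is_masked = jax.random.bernoulli(rngkey, p=rate, shape=(batch, seqlen))
    return jnp.where(is_masked[:, :, None], mask_vector, inputs)
```

A rate of 0.0 leaves the sequence unchanged, while a rate of 1.0 replaces every position.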

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
rate: float

a probability for a vector at each position being masked.

scope = None
class flax_extra.layer.TrainablePositionalPadding(d_max: int, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)[source]

Bases: flax.linen.module.Module

Trainable positional padding.

Learns a padding vector and applies it to a vector at each position (i.e. time step) of the sequence to pad the vector’s dimension up to the desired value.

\[\begin{split}\begin{aligned} & \textrm{TrainablePositionalPadding}( \\ & \quad x \in \sR^{\nBatchSize \times T \times d} \\ & \quad \_ \\ & \quad w \in \sR^{1 \times d^{\prime}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times d + d^{\prime}} \end{aligned}\end{split}\]
Parameters

inputs – a sequence.

Returns

a sequence with padding along feature dimension.
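
The padding can be sketched as a broadcast followed by a concatenation along the feature axis. The sketch assumes that the padding width is d_max minus the input dimension; the function name and the explicit `padding` argument are hypothetical, since the module learns this vector itself.

```python
import jax.numpy as jnp


def positional_padding_sketch(inputs, padding, d_max):
    """Pads (batch, T, d) inputs to (batch, T, d_max) with a shared vector.

    padding: (1, d_max - d), the same vector for every position.
    """
    batch, seqlen, d = inputs.shape
    n_pad = d_max - d
    assert padding.shape == (1, n_pad)
    # Repeat the padding vector for every position of every sequence.
    tiled = jnp.broadcast_to(padding, (batch, seqlen, n_pad))
    return jnp.concatenate([inputs, tiled], axis=-1)
```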

d_max: int

a desired dimension for vectors in the sequence.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
scope = None
flax_extra.layer.attend(query: jax._src.numpy.lax_numpy.ndarray, key: jax._src.numpy.lax_numpy.ndarray, value: jax._src.numpy.lax_numpy.ndarray, mask: Optional[jax._src.numpy.lax_numpy.ndarray] = None, dropout_rngkey: Optional[jax._src.prng.PRNGKeyArray] = None, dropout_rate: float = 0.0, deterministic: bool = False, precision: Optional[Any] = None) jax._src.numpy.lax_numpy.ndarray[source]

Computes an attention vector for each position (i.e. time step) of the query sequence to positions of the key sequence.

\[\begin{split}\begin{aligned} & \textrm{attend}( \\ & \quad q \in \sR^{\nBatchSize \times \nSeqLen_{q} \times n_{heads} \times d_{qk}} \\ & \quad k \in \sR^{\nBatchSize \times \nSeqLen_{kv} \times n_{heads} \times d_{qk}} \\ & \quad v \in \sR^{\nBatchSize \times \nSeqLen_{kv} \times n_{heads} \times d_{v}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{q} \times \nSeqLen_{kv}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{q} \times d_{v}} \end{aligned}\end{split}\]
Parameters
  • query – a query vector.

  • key – a key vector.

  • value – a value vector.

  • mask – a mask tensor with boolean values indicating whether a particular query attends to a particular key.

  • dropout_rngkey – random number generator key for dropout.

  • dropout_rate – probabilistic rate for attention dropout.

  • deterministic – whether to perform deterministically or not.

  • precision – numerical precision of the computation. See jax.lax.Precision for details.

Returns

attention vectors.
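
For orientation, scaled dot-product attention over the shapes above can be sketched as follows. This is not the function's source: dropout and precision handling are omitted, the scaling by the square root of d_qk is the standard choice and assumed here, and the sketch keeps the head axis in its output.

```python
import jax
import jax.numpy as jnp


def attend_sketch(query, key, value, mask=None):
    """Scaled dot-product attention over multiple heads (illustrative only).

    query: (batch, len_q, n_heads, d_qk)
    key:   (batch, len_kv, n_heads, d_qk)
    value: (batch, len_kv, n_heads, d_v)
    mask:  (batch, len_q, len_kv) booleans, True where attention is allowed.
    Returns per-head outputs of shape (batch, len_q, n_heads, d_v).
    """
    d_qk = query.shape[-1]
    # Similarity of every query position to every key position, per head.
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) / jnp.sqrt(d_qk)
    if mask is not None:
        # Broadcast the mask over the head axis and suppress masked pairs.
        logits = jnp.where(mask[:, None, :, :], logits, -1e9)
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum of the value vectors for each query position.
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)
```

With identical queries and keys the weights are uniform, so each output is the mean of the unmasked value vectors.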

flax_extra.layer.cross_attention_mask(mask_q: jax._src.numpy.lax_numpy.ndarray, mask_kv: jax._src.numpy.lax_numpy.ndarray) jax._src.numpy.lax_numpy.ndarray[source]

Creates a cross-attention mask tensor with boolean values indicating whether a particular query attends to a particular key.

\[\begin{split}\begin{aligned} & \textrm{cross_attention_mask}( \\ & \quad mask_{q} \in \sR^{\nBatchSize \times \nSeqLen_{q}} \\ & \quad mask_{kv} \in \sR^{\nBatchSize \times \nSeqLen_{kv}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{q} \times \nSeqLen_{kv}} \end{aligned}\end{split}\]
Parameters
  • mask_q – a padding mask indicating at which positions values of the query are valid.

  • mask_kv – a padding mask indicating at which positions values of the key are valid.

Returns

a boolean matrix indicating which attention values are valid.
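
One way to obtain a mask of the shape above is a logical AND over broadcast padding masks. The sketch below assumes that construction and uses a hypothetical function name; it is not taken from the package source.

```python
import jax.numpy as jnp


def cross_attention_mask_sketch(mask_q, mask_kv):
    """Combines (batch, len_q) and (batch, len_kv) padding masks.

    Entry [b, i, j] is True only when query position i and key position j
    are both valid in batch element b.
    """
    # (batch, len_q, 1) AND (batch, 1, len_kv) -> (batch, len_q, len_kv)
    return jnp.logical_and(mask_q[:, :, None], mask_kv[:, None, :])
```

A padded query position yields an all-False row, and a padded key position yields an all-False column.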

Submodules