# Source code for flax_extra.layer._encoding

r"""Encoding."""

from typing import cast, Callable, List, Optional
from jax import numpy as jnp
from flax import linen as nn

from flax_extra import combinator as cb

# Basic value aliases shared by the encoding layer.
Array = jnp.ndarray
Positions = List[int]
# Combines two arrays into one (e.g. concatenation or addition).
BinaryOperator = Callable[[Array, Array], Array]

# Preprocessing: the ``...Ct`` alias is a constructor that builds the
# function/module mapping raw inputs to preprocessed inputs.
PreprocessingFn = Callable[[Array], Array]
PreprocessingCt = Callable[..., PreprocessingFn]

# Positional encoding: the built function is called with the batch size and
# an optional subset of output positions, and returns the encodings.
PositionalEncodingFn = Callable[[int, Optional[Positions]], Array]
PositionalEncodingCt = Callable[..., PositionalEncodingFn]


class Encoding(nn.Module):
    r"""Encodes a vector at each position (i.e. time step) of a sequence
    using an optional preprocessing step and an optional positional encoding.

    The preprocessed inputs are reshaped to
    :math:`\sR^{\nBatchSize \times T \times d}`. The batch size comes from
    the raw inputs and :math:`d` is the last axis of the preprocessed
    inputs. All remaining axes are merged into a single sequence axis. If a
    positional encoding is configured, its output is then combined with that
    sequence by :attr:`aggregation` (concatenation by default, or e.g.
    addition).

    .. math::

        \begin{aligned}
        & \textrm{Encoding}( \\
        & \quad x \in \sR^{\nBatchSize \times T \times d} \\
        & \quad \_ \\
        & \quad \theta \gets Preprocessing() \\
        & \quad \theta \gets PositionalEncoding() \\
        & ) \\
        & \rightarrow \sR^{\nBatchSize \times T \times d^{\prime}}
        \end{aligned}

        \begin{aligned}
        & \textrm{Encoding}( \\
        & \quad x \in \sR^{\nBatchSize \times T \times d} \\
        & \quad t \in \sN^{\nSeqLen^{\prime}} \\
        & \quad \_ \\
        & \quad \theta \gets Preprocessing() \\
        & \quad \theta \gets PositionalEncoding() \\
        & ) \\
        & \rightarrow \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}}
        \end{aligned}

    Args:
        inputs: an input sequence.
        output_positions: a subset of positions (i.e. time steps) within the
            sequence for which the encoding will be calculated. It is only
            forwarded to the positional encoding, so it has no effect when
            :attr:`positional_encoding` is ``None``.

    Returns:
        encoded vectors.
    """

    preprocessing: PreprocessingCt = cast(
        PreprocessingCt,
        lambda: cb.identity(n_in=1),
    )
    r"""a constructor for a module or function that takes the inputs as its
    only argument and returns the preprocessed inputs as a single output
    (identity by default)."""

    positional_encoding: Optional[PositionalEncodingCt] = None
    r"""a type of the positional encoding (e.g.
    :class:`flax_extra.layer.TrainablePositionalEncoding` or
    :class:`flax_extra.layer.FourierPositionEncoding`), or ``None`` to skip
    positional information."""

    aggregation: BinaryOperator = cast(BinaryOperator, cb.concatenate())
    r"""a binary operation that merges the sequence with its positional
    encodings (e.g. :class:`flax_extra.combinator.concatenate` or
    :class:`flax_extra.combinator.add`)."""

    @nn.compact
    def __call__(  # type: ignore[override] # pylint: disable=arguments-differ
        self,
        inputs: Array,
        output_positions: Optional[Positions],
    ) -> Array:
        # The batch size is taken from the raw inputs, before preprocessing.
        n_batch = inputs.shape[0]

        features = self.preprocessing()(inputs)
        # Keep the batch axis and the last (feature) axis; merge everything
        # in between into a single sequence axis.
        sequence = jnp.reshape(features, (n_batch, -1, features.shape[-1]))

        # Guard clause: without a positional encoding the flattened sequence
        # is the result and ``output_positions`` is unused.
        if self.positional_encoding is None:
            return sequence  # type: ignore

        # Enhance with positional information.
        position_codes = self.positional_encoding()(  # pylint: disable=not-callable
            n_batch,
            output_positions,
        )
        return self.aggregation(sequence, position_codes)  # type: ignore
# Constructor type for :class:`Encoding`: any callable taking arbitrary
# arguments that returns an ``Encoding`` module (e.g. the class itself).
EncodingCt = Callable[..., Encoding]