# Source code for flax_extra.layer._multimodal_positional_encoding

r"""Multimodal positional encoding."""

from typing import Any, Callable, List, Optional
from functools import reduce
import jax.numpy as jnp
from flax import linen as nn
from flax_extra.layer._trainable_positional_padding import TrainablePositionalPadding
from flax_extra.layer._encoding import PositionalEncodingCt

# Alias for the JAX array type used in annotations throughout this module.
Array = jnp.ndarray
# Integer positions (i.e. time steps) within a single modality's sequence.
Positions = List[int]


def _normalize_modalities(
    values: Optional[List[Any]],
    n_modalities: int,
) -> Optional[List[Any]]:
    """Expand a missing per-modality list and check its length.

    Args:
        values: one entry per modality, or ``None`` to stand for
            "``None`` for every modality".
        n_modalities: the expected number of entries.

    Returns:
        The (possibly expanded) list when it has exactly ``n_modalities``
        entries, otherwise ``None``.
    """
    per_modality = [None] * n_modalities if values is None else values
    return _verify_modalities(per_modality, n_modalities)


def _verify_modalities(
    values: List[Any],
    n_modalities: int,
) -> Optional[List[Any]]:
    """Return ``values`` if it holds one entry per modality, else ``None``.

    Args:
        values: per-modality entries to check.
        n_modalities: the expected number of entries.

    Returns:
        The very same ``values`` list when ``len(values)`` equals
        ``n_modalities``; ``None`` on a length mismatch.
    """
    return values if len(values) == n_modalities else None


# Accumulator threaded through the per-modality encoding fold:
# (largest feature dimension seen so far, sequence lengths, encoded inputs).
_EncodeInitializer = tuple[int, List[int], List[Array]]
# One fold step's input: a positional-encoding constructor paired with the
# optional output positions requested for that modality.
_EncodeData = tuple[PositionalEncodingCt, Optional[Positions]]


class MultimodalPositionalEncoding(nn.Module):
    r"""Encodes positional encoding for each modality.

    Each modality is padded to the same feature dimension and gets
    concatenated to form a single sequence.

    .. math::

        \begin{aligned}
            & \textrm{MultimodalPositionalEncoding}( \\
            & \quad m \in \sN \\
            & \quad \_ \\
            & \quad \theta \gets PositionalEncoding() \\
            & \quad \theta \gets TrainablePositionalPadding() \\
            & ) \\
            & \rightarrow \\
            & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\
            & \quad l_{x} \in n_{modalities} \times \sN
        \end{aligned}

        \begin{aligned}
            & \textrm{MultimodalPositionalEncoding}( \\
            & \quad m \in \sN \\
            & \quad t \in n_{mod} \times \sN^{\nSeqLen^{\prime}} \\
            & \quad \_ \\
            & \quad \theta \gets PositionalEncoding() \\
            & \quad \theta \gets TrainablePositionalPadding() \\
            & ) \\
            & \rightarrow \\
            & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\
            & \quad l_{x} \in n_{modalities} \times \sN
        \end{aligned}

    Args:
        batch_size: a batch size of the sequence.
        multimodal_output_positions: a subset of positions
            (i.e. time steps) within each modality positional encoding
            will be calculated.

    Returns:
        positional encoding vectors combined in a single sequence and
        sequence lengths of original modalities.
    """

    modalities: List[PositionalEncodingCt]
    r"""a list of constructors for a positional encoding module (e.g.
    :class:`flax_extra.layer.TrainablePositionalEncoding` or
    :class:`flax_extra.layer.FourierPositionEncoding`) for each modality."""

    d_reserved: int = 1
    r"""a number of reserved feature dimensions."""

    @property
    def n_modalities(self) -> int:
        r"""a number of modalities."""
        return len(self.modalities)

    @nn.compact
    def __call__(  # type: ignore[override] # pylint: disable=arguments-differ
        self,
        batch_size: int,
        multimodal_output_positions: Optional[List[Optional[Positions]]],
    ) -> tuple[Array, List[int]]:
        r"""Encode every modality and merge the results into one sequence.

        Args:
            batch_size: forwarded to each modality's positional encoding.
            multimodal_output_positions: one (optional) list of positions
                per modality, or ``None`` to pass ``None`` to every modality.

        Returns:
            A pair of the encodings concatenated along axis 1 and the list
            of per-modality lengths taken from axis 1 of each encoding.

        Raises:
            ValueError: if the number of provided position lists differs
                from the number of modalities.
        """
        multimodal_output_positions = _normalize_modalities(
            multimodal_output_positions,
            self.n_modalities,
        )
        if multimodal_output_positions is None:
            # The class name is read from the type: module instances have no
            # `__name__` attribute, so `self.__name__` would raise
            # AttributeError here and hide the intended ValueError.
            raise ValueError(
                f"The number of {type(self).__name__} modalities "
                f"{self.n_modalities} "
                "doesn't match the number of provided subsamples."
            )

        # Single modality: no padding or concatenation is needed.
        if self.n_modalities == 1:
            positional_encoding = self.modalities[0]
            encoded_inputs = positional_encoding()(
                batch_size,
                multimodal_output_positions[0],
            )
            seqlen_encoded_input = encoded_inputs.shape[1]
            return encoded_inputs, [seqlen_encoded_input]

        # Multimodal.
        # Encode modalities, count input lengths and max input dimension.
        def encode(acc: _EncodeInitializer, data: _EncodeData) -> _EncodeInitializer:
            positional_encoding, output_positions = data
            d_input_max, seqlen_inputs, multimodal_encoded_inputs = acc
            # Each constructor is invoked here, in modality order.
            encoded_inputs = positional_encoding()(
                batch_size,
                output_positions,
            )
            d_input = encoded_inputs.shape[-1]
            seqlen_input = encoded_inputs.shape[1]
            return (
                max(d_input_max, d_input),
                seqlen_inputs + [seqlen_input],
                multimodal_encoded_inputs + [encoded_inputs],
            )

        encode_initializer: _EncodeInitializer = (0, [], [])
        d_input_max, seqlen_inputs, multimodal_encoded_inputs = reduce(
            encode,
            zip(self.modalities, multimodal_output_positions),
            encode_initializer,
        )
        # Target feature size: the widest modality plus the reserved dimensions.
        d_input_max += self.d_reserved

        # Pad inputs along their feature dimension axis.
        def pad(acc: List[Array], inputs: Array) -> List[Array]:
            padded_inputs = TrainablePositionalPadding(d_max=d_input_max)(inputs)
            return acc + [padded_inputs]

        pad_initializer: List[Array] = []
        multimodal_padded_inputs = reduce(
            pad,
            multimodal_encoded_inputs,
            pad_initializer,
        )
        return jnp.concatenate(multimodal_padded_inputs, axis=1), seqlen_inputs


# Type of a callable (constructor) that returns a MultimodalPositionalEncoding.
MultimodalPositionalEncodingCt = Callable[..., MultimodalPositionalEncoding]