Source code for flax_extra.layer._multimodal_encoding

r"""Multimodal encoding."""

from typing import Callable, List, Optional
from functools import reduce
import jax.numpy as jnp
from flax import linen as nn
from flax_extra.layer._trainable_positional_padding import TrainablePositionalPadding
from flax_extra.layer._trainable_positional_masking import TrainablePositionalMasking
from flax_extra.layer._multimodal_positional_encoding import (
    _normalize_modalities,
    _verify_modalities,
)
from flax_extra.layer._encoding import EncodingCt

# Alias for JAX arrays used in annotations throughout this module.
Array = jnp.ndarray
# A list of integer positions (i.e. time steps) within a modality's sequence.
Positions = List[int]

# Accumulator of the per-modality encoding fold: (largest encoded feature
# dimension seen so far, encoded sequence length of each modality so far,
# encoded array of each modality so far).
_EncodeInitializer = tuple[int, List[int], List[Array]]
# One item consumed by that fold: (modality inputs, optional output
# positions, constructor of the modality's encoding module).
_EncodeData = tuple[Array, Optional[Positions], EncodingCt]


class MultimodalEncoding(nn.Module):
    r"""Encodes multiple modalities.

    Each modality is encoded by its own encoding module, padded to the same
    feature dimension and concatenated along the sequence axis to form a
    single sequence.

    .. math::

        \begin{aligned}
        & \textrm{MultimodalEncoding}( \\
        & \quad x \in n_{mod} \times \sR^{m \times T \times d} \\
        & \quad \_ \\
        & \quad \theta \gets PositionalEncoding() \\
        & \quad \theta \gets TrainablePositionalPadding() \\
        & \quad \theta \gets TrainablePositionalMasking() \\
        & ) \\
        & \rightarrow \\
        & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\
        & \quad l_{x} \in n_{mod} \times \sN
        \end{aligned}

        \begin{aligned}
        & \textrm{MultimodalEncoding}( \\
        & \quad x \in n_{mod} \times \sR^{m \times T \times d} \\
        & \quad t \in n_{mod} \times \sN^{\nSeqLen^{\prime}} \\
        & \quad \_ \\
        & \quad \theta \gets PositionalEncoding() \\
        & \quad \theta \gets TrainablePositionalPadding() \\
        & \quad \theta \gets TrainablePositionalMasking() \\
        & ) \\
        & \rightarrow \\
        & \quad h \in \sR^{\nBatchSize \times T^{\prime} \times d^{\prime}} \\
        & \quad l_{x} \in n_{mod} \times \sN
        \end{aligned}

    Args:
        multimodal_inputs: a list with one input array per modality.
        multimodal_output_positions: for each modality, an optional subset of
            positions (i.e. time steps) at which its encoding will be
            calculated. Each entry is passed as-is to that modality's
            encoding module.

    Returns:
        the encoded modalities concatenated along axis 1 into a single
        sequence, and the length (size of axis 1) of each modality's encoded
        sequence.

    Raises:
        ValueError: if the number of inputs, output positions or mask rates
            does not match the number of modalities.
    """

    modalities: List[EncodingCt]
    r"""a list of constructors for an encoding module (e.g.
    :class:`flax_extra.layer.Encoding`) for each modality."""

    mask_rates: Optional[List[Optional[float]]] = None
    r"""a probability for each modality being masked; a ``None`` entry skips
    masking for that modality."""

    d_reserved: int = 1
    r"""a number of reserved feature dimensions added on top of the largest
    encoded feature dimension."""

    @property
    def n_modalities(self) -> int:
        r"""a number of modalities."""
        return len(self.modalities)

    @nn.compact
    def __call__(  # type: ignore[override] # pylint: disable=arguments-differ, disable=too-many-locals
        self,
        multimodal_inputs: List[Array],
        multimodal_output_positions: Optional[List[Optional[Positions]]],
    ) -> tuple[Array, List[int]]:
        # Module *instances* have no ``__name__``; reading ``self.__name__``
        # raised AttributeError instead of the intended ValueError.
        module_name = type(self).__name__

        if _verify_modalities(multimodal_inputs, self.n_modalities) is None:
            raise ValueError(
                f"The number of {module_name} modalities {self.n_modalities} "
                f"doesn't match the number of its arguments."
            )
        multimodal_output_positions = _normalize_modalities(
            multimodal_output_positions,
            self.n_modalities,
        )
        if multimodal_output_positions is None:
            raise ValueError(
                f"The number of {module_name} modalities {self.n_modalities} "
                f"doesn't match the number of provided output positions."
            )
        mask_rates = _normalize_modalities(self.mask_rates, self.n_modalities)
        if mask_rates is None:
            raise ValueError(
                f"The number of {module_name} modalities {self.n_modalities} "
                f"doesn't match the number of provided mask rates."
            )

        # Single modality.
        # NOTE(review): this path applies neither padding (``d_reserved``) nor
        # masking, so ``mask_rates`` is ignored here — confirm this is intended.
        if self.n_modalities == 1:
            encoding = self.modalities[0]
            encoded_inputs = encoding()(
                multimodal_inputs[0],
                multimodal_output_positions[0],
            )
            seqlen_encoded_input = encoded_inputs.shape[1]
            return encoded_inputs, [seqlen_encoded_input]

        # Multimodal.
        # NOTE(review): submodules are created inline, so the order below
        # (all encodings first, then padding/masking per modality) presumably
        # determines their auto-generated names; keep it stable.

        # Encode modalities, count input lengths and max input dimension.
        def encode(acc: _EncodeInitializer, data: _EncodeData) -> _EncodeInitializer:
            inputs, output_positions, encoding = data
            d_input_max, seqlen_inputs, multimodal_encoded_inputs = acc
            encoded_inputs = encoding()(inputs, output_positions)
            d_input = encoded_inputs.shape[-1]
            seqlen_input = encoded_inputs.shape[1]
            return (
                max(d_input_max, d_input),
                seqlen_inputs + [seqlen_input],
                multimodal_encoded_inputs + [encoded_inputs],
            )

        encode_initializer: _EncodeInitializer = (0, [], [])
        d_input_max, seqlen_inputs, multimodal_encoded_inputs = reduce(
            encode,
            zip(multimodal_inputs, multimodal_output_positions, self.modalities),
            encode_initializer,
        )
        # Every modality is padded to the widest encoding plus reserved dims.
        d_input_max += self.d_reserved

        # Pad inputs along their dimension axis, then apply masking.
        def padmask(
            acc: List[Array], data: tuple[Optional[float], Array]
        ) -> List[Array]:
            mask_rate, inputs = data
            padded_inputs = TrainablePositionalPadding(d_max=d_input_max)(inputs)
            if mask_rate is not None:
                padded_inputs = TrainablePositionalMasking(rate=mask_rate)(
                    padded_inputs
                )
            return acc + [padded_inputs]

        padmask_initializer: List[Array] = []
        multimodal_padded_inputs = reduce(
            padmask,
            zip(mask_rates, multimodal_encoded_inputs),
            padmask_initializer,
        )
        # Join modalities along the sequence axis (axis 1).
        return jnp.concatenate(multimodal_padded_inputs, axis=1), seqlen_inputs
# Type of a constructor: any callable that returns a MultimodalEncoding module.
MultimodalEncodingCt = Callable[..., MultimodalEncoding]