# Source code for flax_extra.layer._multimodal_decoding

r"""Multimodal decoding."""

from typing import cast, Callable, List, Union
from functools import reduce
import jax.numpy as jnp
from flax import linen as nn
from flax_extra.layer._decoding import EmbeddingDecodingCt, DecodingCt
from flax_extra.layer._multimodal_positional_encoding import _verify_modalities
from flax_extra import combinator as cb

# Shorthand for the JAX array type used in this module's annotations.
Array = jnp.ndarray


def _split_modalities(inputs: Array, seqlens: List[int]) -> List[Array]:
    r"""Cuts a multimodal tensor into one tensor per modality.

    Consecutive slices are taken along the second axis of ``inputs``;
    the i-th slice spans ``seqlens[i]`` positions and starts where the
    previous one ended.

    Args:
        inputs: the multimodal tensor to partition.
        seqlens: the length of each modality's slice, in order.

    Returns:
        a list with one slice per entry of ``seqlens`` (empty if
        ``seqlens`` is empty).
    """
    chunks: List[Array] = []
    start = 0
    for length in seqlens:
        stop = start + length
        chunks.append(inputs[:, start:stop])
        # The next modality begins right after the current one.
        start = stop
    return chunks


# A list of tensors, one per modality.
_DecodeInitializer = List[Array]
# A pair of a tensor and a constructor of a decoding module.
_DecodeData = tuple[Array, DecodingCt]


class MultimodalDecoding(nn.Module):
    r"""Given original sequence lengths, partitions a multimodal tensor
    into tensors for each modality and decodes each of them.

    .. math::

        \begin{aligned}
            & \textrm{MultimodalDecoding}( \\
            & \quad o \in \sR^{\nBatchSize \times \nSeqLen \times d} \\
            & \quad l_{x} \in n_{mod} \times \sN \\
            & \quad \_ \\
            & \quad \theta \gets MultimodalEmbeddingDecoding() \\
            & \quad \theta \gets n_{mod} \times Decoding() \\
            & ) \\
            & \rightarrow n_{mod} \times \sR^{\nBatchSize \times \dots d^{\prime}}
        \end{aligned}

    Args:
        outputs: multimodal model's outputs.
        seqlen_outputs: sequence lengths of original modalities.

    Returns:
        the decoded outputs: a single result when there is exactly one
        modality, otherwise a list with one decoded result per modality.

    Raises:
        ValueError: if ``_verify_modalities`` returns ``None`` for
            ``seqlen_outputs`` and the configured number of modalities.
    """

    modalities: List[DecodingCt]
    r"""a list of constructors for a decoding module
    (e.g. :class:`flax_extra.layer.Decoding`) for each modality."""

    multimodal_embedding_decoding: EmbeddingDecodingCt = cast(
        EmbeddingDecodingCt,
        lambda: cb.identity(n_in=1),
    )
    r"""a type of the embedding decoder
    (e.g. :class:`flax_extra.layer.EmbedDecoding`
    or :class:`flax.linen.Dense`)."""

    @property
    def n_modalities(self) -> int:
        r"""a number of modalities."""
        return len(self.modalities)

    @nn.compact
    def __call__(  # type: ignore[override] # pylint: disable=arguments-differ
        self,
        outputs: Array,
        seqlen_outputs: List[int],
    ) -> Union[Array, List[Array]]:
        if _verify_modalities(seqlen_outputs, self.n_modalities) is None:
            # Instances carry no ``__name__``; the class name lives on the
            # type, so look it up there to keep this a ValueError.
            raise ValueError(
                f"The number of {type(self).__name__} modalities {self.n_modalities} "
                "doesn't match the number of output sequences."
            )

        # Decode multimodal embedding.
        outputs = self.multimodal_embedding_decoding()(outputs)

        # Single modality: decode the whole tensor, no partitioning needed.
        if self.n_modalities == 1:
            decoding = self.modalities[0]
            return decoding()(outputs)

        # Multimodal: slice per modality, then decode each slice with its
        # own decoding module, preserving the modality order.
        multimodal_outputs = _split_modalities(outputs, seqlen_outputs)
        return [
            decoding()(modality_outputs)
            for modality_outputs, decoding in zip(multimodal_outputs, self.modalities)
        ]
# Type of a callable (constructor) that returns a MultimodalDecoding module.
MultimodalDecodingCt = Callable[..., MultimodalDecoding]