Source code for flax_extra.model.perceiver._io

r"""The Perceiver IO model."""

from typing import cast, Any, List, Optional, Union
from jax import numpy as jnp
from flax import linen as nn
from flax_extra import combinator as cb
from flax_extra.layer._attention import (
    SelfAttention,
    SelfAttentionCt,
    KVQAttention,
    KVQAttentionCt,
)
from flax_extra.layer._feedforward import (
    FeedForward,
    FeedForwardCt,
)
from flax_extra.layer._multimodal_encoding import (
    MultimodalEncodingCt,
)
from flax_extra.layer._multimodal_positional_encoding import (
    MultimodalPositionalEncodingCt,
)
from flax_extra.layer._multimodal_decoding import (
    MultimodalDecodingCt,
)
from flax_extra.model.perceiver._encoder import Encoder
from flax_extra.model.perceiver._processor import Processor
from flax_extra.model.perceiver._decoder import Decoder

Array = jnp.ndarray
Precision = Any
Positions = List[int]


[docs]class PerceiverIO(nn.Module): r"""Maps arbitrary input tensors to arbitrary output tensors in a domain agnostic process. The bulk of the computation happens in a latent space whose size is typically smaller than the inputs and outputs, which makes the process computationally tractable even for very large inputs and outputs. .. math:: \begin{aligned} & \textrm{PerceiverIO}( \\ & \quad x \in n_{mod_{x}} \times \sR^{m \times T_{x} \times d_{x}} \\ & \quad mask_{x} \in \sR^{\nBatchSize \times \nSeqLen_{x}} \\ & \quad y \in n_{mod_{y}} \times \sR^{m \times T_{y} \times d_{y}} \\ & \quad mask_{y} \in \sR^{\nBatchSize \times \nSeqLen_{y}} \\ & \quad t_{y} \in n_{mod_{y}} \times \sN^{\nSeqLen^{\prime}_{y}} \\ & \quad \_ \\ & \quad \theta \gets MultimodalEncoding() \\ & \quad \theta \gets MultimodalPositionalEncoding() \\ & \quad \theta \gets Encoder() \\ & \quad \theta \gets Processor() \\ & \quad \theta \gets MultimodalEncoding() \\ & \quad \theta \gets MultimodalPositionalEncoding() \mid MultimodalEncoding() \\ & \quad \theta \gets Decoder() \\ & \quad \theta \gets MultimodalDecoding() \\ & ) \\ & \rightarrow n_{mod_{y}} \times \sR^{\nBatchSize \times \dots d^{\prime}_{y}} \end{aligned} Args: inputs: a single or multimodal input sequence(s). input_mask: a padding mask indicating at which positions values of the inputs are valid. targets: a single or multimodal target sequence(s). If provided, `decoder_query_encoding` must be of :class:`flax_extra.layer.MultimodalEncoding` type. Otherwise :class:`flax_extra.layer.MultimodalPositionalEncoding` type must be used. target_mask: a padding mask indicating at which positions values of the targets are valid. output_positions: a subset of positions (i.e. time steps) within each modality encoding will be calculated. Returns: a single or multimodal output sequence(s). """ input_encoding: MultimodalEncodingCt r"""preprocessing and encoding of a single or multimodal input. Use :meth:`flax_extra.layer.io.input_encoding` to define the module type.""" encoder_query_encoding: MultimodalPositionalEncodingCt r"""an encoding for encoder's query. Use :meth:`flax_extra.layer.io.query_encoding` to define the module type.""" decoder_query_encoding: Union[MultimodalPositionalEncodingCt, MultimodalEncodingCt] r"""an encoding for decoder's query. Use :meth:`flax_extra.layer.io.query_encoding` to define the module type.""" output_decoding: MultimodalDecodingCt r"""output decoding and postrocessing. Use :meth:`flax_extra.layer.io.output_decoding` to define the module type.""" n_processor_shards: int = 8 r"""a number of shards. See :class:`flax_extra.model.perceiver.Processor`.""" n_processor_blocks: int = 6 r"""a number of self-attention blocks building up a single shard. See :class:`flax_extra.model.perceiver.Processor`.""" processor_attention: SelfAttentionCt = SelfAttention r"""a type of the self-attention for processor. See :class:`flax_extra.layer.SelfAttention`.""" processor_feed_forward: FeedForwardCt = FeedForward r"""a type of the feed-forward for processor. See :class:`flax_extra.layer.FeedForward`.""" encoder_attention: KVQAttentionCt = KVQAttention r"""a type of the cross-attention for encoder. See :class:`flax_extra.layer.KVQAttention`.""" encoder_feed_forward: FeedForwardCt = FeedForward r"""a type of the feed-forward for encoder. See :class:`flax_extra.layer.FeedForward`.""" use_encoder_q_residual: bool = True r"""whether to include a residual to the query. Consider omitting the residual if the semantics of encoder's query and latent features are different (e.g. if queries are positions and latents are pixels).""" decoder_attention: KVQAttentionCt = KVQAttention r"""a type of the cross-attention for decoder. See :class:`flax_extra.layer.KVQAttention`.""" decoder_feed_forward: FeedForwardCt = FeedForward r"""a type of the feed-forward for decoder. See :class:`flax_extra.layer.FeedForward`.""" use_decoder_q_residual: bool = False r"""whether to include a residual to the query. Consider omitting the residual if the semantics of decoder's query and outputs are different (e.g. if queries are positions and outputs are pixels).""" deterministic: bool = True r"""whether to perform deterministically or not.""" precision: Optional[Precision] = None r"""numerical precision of the computation. See :attr:`jax.lax.Precision` for details.""" @nn.compact def __call__( # type: ignore[override] # pylint: disable=arguments-differ, disable=too-many-arguments), disable=too-many-locals self, inputs: Union[Array, List[Array]], input_mask: Optional[Array] = None, targets: Union[None, Array, List[Array]] = None, target_mask: Optional[Array] = None, output_positions: Union[None, Positions, List[Optional[Positions]]] = None, ) -> Union[Array, List[Array]]: encoder_query_encoding = self.encoder_query_encoding( name="EncodedQueryEncoding" ) input_encoding = self.input_encoding( name="InputEncoding", ) decoder_query_encoding = self.decoder_query_encoding( name="DecoderQueryEncoding" ) io_block = cb.serial( Encoder( attention=self.encoder_attention, feed_forward=self.encoder_feed_forward, use_q_residual=self.use_encoder_q_residual, deterministic=self.deterministic, precision=self.precision, name="Encoder", ), Processor( n_shards=self.n_processor_shards, n_blocks=self.n_processor_blocks, attention=self.processor_attention, feed_forward=self.processor_feed_forward, deterministic=self.deterministic, precision=self.precision, name="Processor", ), Decoder( attention=self.decoder_attention, feed_forward=self.decoder_feed_forward, use_q_residual=self.use_decoder_q_residual, deterministic=self.deterministic, precision=self.precision, name="Decoder", ), ) output_decoding = self.output_decoding( name="OutputDecoding", ) # Cast inputs to a list of a single or # multiple sequences per each input modality. if input_encoding.n_modalities == 1: inputs = cast(List[Array], [inputs]) inputs = cast(List[Array], inputs) # Cast targets and output positions to a list of a single or # multiple sequences per each output modality. if output_decoding.n_modalities == 1: if targets is not None: targets = cast(List[Array], [targets]) if output_positions is not None: output_positions = cast(List[Optional[Positions]], [output_positions]) targets = cast(Optional[List[Array]], targets) output_positions = cast(Optional[List[Optional[Positions]]], output_positions) batch_size = inputs[0].shape[0] encoded_inputs, seqlen_inputs = input_encoding( multimodal_inputs=inputs, multimodal_output_positions=None, ) del seqlen_inputs encoder_query, seqlen_latents = encoder_query_encoding(batch_size, None) del seqlen_latents decoder_query, seqlen_outputs = decoder_query_encoding( batch_size if targets is None else targets, # type: ignore output_positions, ) outputs = io_block( # type: ignore encoded_inputs, # type: ignore encoder_query, input_mask, decoder_query, target_mask, ) return output_decoding(outputs, seqlen_outputs) # type: ignore