# Source code for flax_extra.model.perceiver._encoder

r"""The Perceiver encoder."""

from typing import Any, Optional
import jax.numpy as jnp
from flax import linen as nn
from flax_extra.layer._feedforward import FeedForward, FeedForwardCt
from flax_extra.layer._attention import (
    KVQAttention,
    KVQAttentionCt,
    cross_attention_mask,
)
from flax_extra.model.perceiver._cross_attention_block import CrossAttentionBlock

# Type aliases used in this module's annotations: JAX arrays, and values
# for `jax.lax.Precision` (typed loosely as `Any`).
Array, Precision = jnp.ndarray, Any


class Encoder(nn.Module):
    r"""A single cross-attention block that projects high-dimensional input
    vectors onto a fixed-size array of latent feature vectors.

    .. math::

        \begin{aligned}
            & \textrm{Encoder}( \\
            & \quad x \in \sR^{\nBatchSize \times \nSeqLen_{x} \times d_{x}} \\
            & \quad z \in \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \\
            & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{z} \times \nSeqLen_{x}} \\
            & \quad \_ \\
            & \quad \theta \gets CrossAttentionBlock() \\
            & ) \\
            & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}}
        \end{aligned}

    Args:
        inputs: a high-dimensional input array, used as the keys/values of
            the cross-attention.
        latents: the encoder's query inputs, of the desired output dimension.
        input_mask: an optional padding mask indicating at which positions
            the values of the inputs are valid. It is combined with an
            all-ones query mask of shape ``latents.shape[:2]`` through
            ``cross_attention_mask``. It therefore presumably has shape
            (batch, seq_len_x); the :math:`mask` above denotes the resulting
            cross-attention mask. When ``None`` (the default), no mask is
            passed to the cross-attention block.

    Returns:
        latent features with the same dimension as the encoder's query inputs.
    """

    attention: KVQAttentionCt = KVQAttention
    r"""a type of the cross-attention."""

    feed_forward: FeedForwardCt = FeedForward
    r"""a type of the feed-forward."""

    dropout_rate: float = 0.0
    r"""probabilistic rate for dropout."""

    deterministic: bool = True
    r"""whether to perform deterministically or not."""

    use_q_residual: bool = True
    r"""whether to include a residual to the query.
    Consider omitting the residual if the semantics of encoder's query and
    latent features are different (e.g. if queries are positions and
    latents are pixels)."""

    precision: Optional[Precision] = None
    r"""numerical precision of the computation.
    See :attr:`jax.lax.Precision` for details."""

    @nn.compact
    def __call__(  # type: ignore[override] # pylint: disable=arguments-differ
        self,
        inputs: Array,
        latents: Array,
        input_mask: Optional[Array] = None,
    ) -> Array:
        attention_mask = None
        if input_mask is not None:
            # Every latent (query) position is treated as valid; only the
            # inputs (keys/values) carry padding information.
            attention_mask = cross_attention_mask(
                mask_q=jnp.ones(latents.shape[:2], dtype=jnp.int32),
                mask_kv=input_mask,
            )
        # Latents act as queries and attend over the inputs.
        return CrossAttentionBlock(
            attention=self.attention,
            feed_forward=self.feed_forward,
            dropout_rate=self.dropout_rate,
            deterministic=self.deterministic,
            use_q_residual=self.use_q_residual,
            precision=self.precision,
        )(inputs_q=latents, inputs_kv=inputs, mask=attention_mask)