Source code for flax_extra.model.perceiver._processor

r"""The Perceiver processor."""

from typing import Any, Optional
from functools import partial
import jax.numpy as jnp
from flax import linen as nn
from flax_extra import combinator as cb
from flax_extra.layer._feedforward import FeedForward, FeedForwardCt
from flax_extra.layer._attention import SelfAttention, SelfAttentionCt
from flax_extra.model.perceiver._self_attention_block import SelfAttentionBlock

# Type aliases used by the annotations in this module.
# Shorthand for a JAX array.
Array = jnp.ndarray
# Deliberately loose (`Any`); annotates the `precision` module attribute.
Precision = Any


class Processor(nn.Module):
    r"""Latent-space Transformer: a stack of self-attention blocks.

    The stack consists of ``n_shards`` shards, each chaining ``n_blocks``
    self-attention blocks. Only ``n_blocks`` block instances are created;
    the very same instances are chained again in every shard, so weights
    of the blocks are shared across shards.

    .. math::

        \begin{aligned}
            & \textrm{Processor}( \\
            & \quad z \in \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \\
            & \quad \_ \\
            & \quad \theta \gets n_{blocks} \times SelfAttentionBlock() \\
            & ) \\
            & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}}
        \end{aligned}

    Args:
        latents: an array of latent features.

    Returns:
        an array of latent features.
    """

    n_shards: int = 8
    r"""how many times the group of blocks is repeated."""

    n_blocks: int = 6
    r"""how many self-attention blocks make up a single shard."""

    attention: SelfAttentionCt = SelfAttention
    r"""the self-attention type passed to every block."""

    feed_forward: FeedForwardCt = FeedForward
    r"""the feed-forward type passed to every block."""

    dropout_rate: float = 0.0
    r"""dropout rate passed to every block."""

    deterministic: bool = True
    r"""whether to perform deterministically or not."""

    precision: Optional[Precision] = None
    r"""numerical precision of the computation.
    See :attr:`jax.lax.Precision` for details."""

    @nn.compact
    def __call__(self, latents: Array) -> Array:  # type: ignore[override] # pylint: disable=arguments-differ
        # One shard: `n_blocks` distinct blocks, each always called with
        # `mask=None`.
        shard = []
        for _ in range(self.n_blocks):
            block = SelfAttentionBlock(
                attention=self.attention,
                feed_forward=self.feed_forward,
                dropout_rate=self.dropout_rate,
                deterministic=self.deterministic,
                precision=self.precision,
            )
            shard.append(partial(block, mask=None))

        # Processor block (latent-space Transformer): the same shard is
        # chained `n_shards` times, reusing the block instances.
        stack = cb.serial(shard * self.n_shards)
        return stack(latents)  # type: ignore