# Source code for flax_extra.model.perceiver._self_attention_block
r"""Self attention block of the Perceiver model."""
from typing import Any, Optional
from functools import partial
import jax.numpy as jnp
from flax import linen as nn
from flax_extra import combinator as cb
from flax_extra.layer._feedforward import FeedForward, FeedForwardCt
from flax_extra.layer._attention import SelfAttention, SelfAttentionCt
# Shorthand for the JAX array type used in the annotations of this module.
Array = jnp.ndarray
# Type of the numerical-precision setting; left unconstrained (``Any``) here.
Precision = Any
class SelfAttentionBlock(nn.Module):
    r"""A block of a self-attention module and following
    feed-forward module.

    Each of the two sub-layers is built as layer normalization, the
    transformation itself, and dropout, wrapped in ``cb.residual``.

    .. math::

        \begin{aligned}
        & \textrm{SelfAttentionBlock}( \\
        & \quad x \in \sR^{\nBatchSize \times \nSeqLen_{x} \times d_{x}} \\
        & \quad \_ \\
        & \quad \theta \gets LayerNorm() \\
        & \quad \theta \gets SelfAttention() \\
        & \quad \theta \gets LayerNorm() \\
        & \quad \theta \gets FeedForward() \\
        & ) \\
        & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{x} \times d_{x}}
        \end{aligned}

    Args:
        latents: latent features.
        mask: a mask tensor with boolean values indicating whether
            a particular query attends to a particular key.

    Returns:
        latent features.
    """

    attention: SelfAttentionCt = SelfAttention
    r"""a type of the self-attention."""

    feed_forward: FeedForwardCt = FeedForward
    r"""a type of the feed-forward."""

    dropout_rate: float = 0.0
    r"""probabilistic rate for dropout."""

    deterministic: bool = True
    r"""whether to perform deterministically or not."""

    precision: Optional[Precision] = None
    r"""numerical precision of the computation.
    See :attr:`jax.lax.Precision` for details."""

    @nn.compact
    def __call__(  # type: ignore[override] # pylint: disable=arguments-differ
        self,
        latents: Array,
        mask: Optional[Array],
    ) -> Array:
        # NOTE(review): `cb.residual` is presumed to add each chain's input
        # back to its output, and `cb.serial` to thread `(latents, mask)`
        # through both stages -- confirm against `flax_extra.combinator`.
        block = cb.serial(
            # Stage 1: normalize, self-attend, drop out.
            cb.residual(
                nn.LayerNorm(epsilon=1e-5),
                # Bind the block-level settings onto the configurable
                # attention constructor, then instantiate it.
                partial(
                    self.attention,
                    deterministic=self.deterministic,
                    precision=self.precision,
                )(),
                nn.Dropout(
                    rate=self.dropout_rate,
                    deterministic=self.deterministic,
                ),
            ),
            # Stage 2: normalize, feed forward, drop out.
            cb.residual(
                nn.LayerNorm(epsilon=1e-5),
                self.feed_forward(),
                nn.Dropout(
                    rate=self.dropout_rate,
                    deterministic=self.deterministic,
                ),
            ),
        )
        return block(latents, mask)  # type: ignore