flax_extra.model.perceiver package

The Perceiver model.

class flax_extra.model.perceiver.CrossAttentionBlock(attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0)), feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, dropout_rate: float = 0.0, deterministic: bool = True, use_q_residual: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: flax.linen.module.Module

A block consisting of a cross-attention module followed by a feed-forward module.

\[\begin{split}\begin{aligned} & \textrm{CrossAttentionBlock}( \\ & \quad x_{q} \in \sR^{\nBatchSize \times \nSeqLen_{x_{q}} \times d_{x_{q}}} \\ & \quad x_{kv} \in \sR^{\nBatchSize \times \nSeqLen_{x_{kv}} \times d_{x_{kv}}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{x_{q}} \times \nSeqLen_{x_{kv}}} \\ & \quad \_ \\ & \quad \theta \gets 2 \times LayerNorm() \\ & \quad \theta \gets Attention() \\ & \quad \theta \gets LayerNorm() \\ & \quad \theta \gets FeedForward() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{x_{q}} \times d_{x_{q}}} \end{aligned}\end{split}\]
Parameters
  • inputs_q – inputs for the query.

  • inputs_kv – inputs for the keys and values.

  • mask – a mask tensor with boolean values indicating whether a particular query attends to a particular key.

Returns

latent features.
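The data flow of this block can be sketched functionally in JAX. This is a minimal single-head sketch, not the flax_extra implementation: the parameter-free `layer_norm`, the `init_params` helper, the weight names in `params`, and the GELU feed-forward are hypothetical stand-ins for the configurable `attention` and `feed_forward` modules, and the pre-LayerNorm residual layout is assumed from the order of the sub-modules in the equation.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Hypothetical parameter-free LayerNorm over the feature axis.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def init_params(key, d_q, d_kv, d_qk, d_hidden):
    # Hypothetical weight layout for a single-head block.
    shapes = {"wq": (d_q, d_qk), "wk": (d_kv, d_qk), "wv": (d_kv, d_qk),
              "wo": (d_qk, d_q), "w1": (d_q, d_hidden), "w2": (d_hidden, d_q)}
    keys = jax.random.split(key, len(shapes))
    return {name: 0.1 * jax.random.normal(k, shape)
            for k, (name, shape) in zip(keys, shapes.items())}


def cross_attention_block(params, inputs_q, inputs_kv, mask=None, use_q_residual=True):
    # inputs_q: (B, Lq, Dq); inputs_kv: (B, Lkv, Dkv); mask: (B, Lq, Lkv) booleans.
    q = layer_norm(inputs_q) @ params["wq"]      # one LayerNorm for the query ...
    kv = layer_norm(inputs_kv)                    # ... and one for the keys/values
    k, v = kv @ params["wk"], kv @ params["wv"]
    logits = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(q.shape[-1])
    if mask is not None:
        logits = jnp.where(mask, logits, -1e9)    # masked keys get ~zero weight
    weights = jax.nn.softmax(logits, axis=-1)
    x = jnp.einsum("bqk,bkd->bqd", weights, v) @ params["wo"]  # back to Dq features
    if use_q_residual:
        x = x + inputs_q                          # the residual needs matching Dq
    ff = jax.nn.gelu(layer_norm(x) @ params["w1"]) @ params["w2"]
    return x + ff                                 # (B, Lq, Dq)


params = init_params(jax.random.PRNGKey(0), d_q=8, d_kv=3, d_qk=16, d_hidden=32)
inputs_q = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 8))
inputs_kv = jax.random.normal(jax.random.PRNGKey(2), (2, 100, 3))
out = cross_attention_block(params, inputs_q, inputs_kv)
print(out.shape)  # (2, 4, 8)
```

In this sketch, passing `use_q_residual=False` drops the `x + inputs_q` term, so the query only steers the attention weights and is not copied into the output.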

attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0))

the constructor of the cross-attention module.

deterministic: bool = True

whether to run deterministically (i.e. with dropout disabled).

dropout_rate: float = 0.0

dropout rate, i.e. the probability of zeroing an element.
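For reference, the interplay of `dropout_rate` and `deterministic` can be sketched as inverted dropout in JAX, in the spirit of `flax.linen.Dropout`. The `dropout` helper below is hypothetical; it assumes these modules apply dropout in the usual Flax way and does not reproduce their code.

```python
import jax
import jax.numpy as jnp


def dropout(x, rate, deterministic, rng=None):
    # Hypothetical helper: inverted dropout controlled by `rate` and `deterministic`.
    if deterministic or rate == 0.0:
        return x                                  # identity, e.g. at inference time
    if rate >= 1.0:
        return jnp.zeros_like(x)                  # everything is dropped
    keep_prob = 1.0 - rate
    keep = jax.random.bernoulli(rng, p=keep_prob, shape=x.shape)
    return jnp.where(keep, x / keep_prob, 0.0)    # rescale kept values to preserve the mean


x = jnp.ones((2, 3))
print(dropout(x, rate=0.5, deterministic=True))   # unchanged
y = dropout(x, rate=0.5, deterministic=False, rng=jax.random.PRNGKey(0))
# every entry of y is either 0.0 or 2.0
```

In Flax, the random key used by dropout is typically supplied through the `rngs={"dropout": key}` argument of `Module.apply`.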

feed_forward

alias of flax_extra.layer._feedforward.FeedForward

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
precision: Optional[Any] = None

numerical precision of the computation. See jax.lax.Precision for details.
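As a brief illustration of what `jax.lax.Precision` controls, the same matrix product can be computed with the backend default or with the highest available precision. On CPU the two results typically agree to float32 rounding, while on TPU the default may use lower-precision passes. That the module's `precision` attribute accepts these same values and forwards them to its matrix products is an assumption here.

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
b = jax.random.normal(jax.random.PRNGKey(1), (8, 2))

default = jnp.matmul(a, b)                                       # precision=None: backend default
highest = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)  # most accurate, possibly slower
print(jnp.max(jnp.abs(default - highest)))
```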

scope = None
use_q_residual: bool = True

whether to add a residual connection from the query inputs. Consider omitting the residual if the query and the output have different semantics, e.g. if the queries are positions and the outputs are pixels.

class flax_extra.model.perceiver.Decoder(attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0)), feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, dropout_rate: float = 0.0, deterministic: bool = True, use_q_residual: int = False, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: flax.linen.module.Module

A single cross-attention block that projects latent features to outputs of the desired dimension.

\[\begin{split}\begin{aligned} & \textrm{Decoder}( \\ & \quad z \in \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \\ & \quad o \in \sR^{\nBatchSize \times \nSeqLen_{o} \times d_{o}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{o} \times \nSeqLen_{z}} \\ & \quad \_ \\ & \quad \theta \gets CrossAttentionBlock() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{o} \times d_{o}} \end{aligned}\end{split}\]
Parameters
  • latents – latent features.

  • targets – inputs for the decoder’s query, with the desired output dimension.

  • target_mask – a padding mask indicating at which positions values of the targets are valid.

Returns

outputs with the same dimension as the decoder’s query inputs.
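The shape contract of the Decoder can be illustrated with a stripped-down, single-head cross-attention in JAX, where the queries come from `targets` and the keys and values from `latents`. This sketch is hypothetical: it omits the LayerNorm and feed-forward sub-layers and `target_mask`, uses made-up weight names, and adds no query residual, mirroring the Decoder's default `use_q_residual=False`.

```python
import jax
import jax.numpy as jnp


def decode(latents, targets, w_q, w_k, w_v):
    # latents: (B, N, D_z); targets: (B, M, d_o). Queries come from the targets.
    q = targets @ w_q                                # (B, M, D_qk)
    k = latents @ w_k                                # (B, N, D_qk)
    v = latents @ w_v                                # (B, N, d_o)
    logits = jnp.einsum("bmd,bnd->bmn", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(logits, axis=-1)        # each target attends over all latents
    return jnp.einsum("bmn,bnd->bmd", weights, v)    # (B, M, d_o), no query residual


keys = jax.random.split(jax.random.PRNGKey(0), 5)
latents = jax.random.normal(keys[0], (2, 32, 64))    # B=2, N=32 latents of size 64
targets = jax.random.normal(keys[1], (2, 100, 3))    # M=100 queries of size d_o=3
w_q = jax.random.normal(keys[2], (3, 16))
w_k = jax.random.normal(keys[3], (64, 16))
w_v = jax.random.normal(keys[4], (64, 3))
outputs = decode(latents, targets, w_q, w_k, w_v)
print(outputs.shape)  # (2, 100, 3)
```

Because every target row attends to the latents independently, the number of outputs can be chosen freely at decode time by choosing how many targets to pass.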

attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0))

the constructor of the cross-attention module.

deterministic: bool = True

whether to run deterministically (i.e. with dropout disabled).

dropout_rate: float = 0.0

dropout rate, i.e. the probability of zeroing an element.

feed_forward

alias of flax_extra.layer._feedforward.FeedForward

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
precision: Optional[Any] = None

numerical precision of the computation. See jax.lax.Precision for details.

scope = None
use_q_residual: int = False

whether to add a residual connection from the decoder’s query inputs. Consider omitting the residual if the decoder’s queries and outputs have different semantics (e.g. if the queries are positions and the outputs are pixels).

class flax_extra.model.perceiver.Encoder(attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0)), feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, dropout_rate: float = 0.0, deterministic: bool = True, use_q_residual: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: flax.linen.module.Module

A single cross-attention block that projects high-dimensional input vectors to a fixed-size array of latent feature vectors.

\[\begin{split}\begin{aligned} & \textrm{Encoder}( \\ & \quad x \in \sR^{\nBatchSize \times \nSeqLen_{x} \times d_{x}} \\ & \quad z \in \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{z} \times \nSeqLen_{x}} \\ & \quad \_ \\ & \quad \theta \gets CrossAttentionBlock() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \end{aligned}\end{split}\]
Parameters
  • inputs – a high-dimensional input array.

  • latents – inputs for the encoder’s query, with the desired latent dimension.

  • input_mask – a padding mask indicating at which positions values of the inputs are valid.

Returns

latent features with the same dimension as the encoder’s query inputs.
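The encoding step, including how a `(B, M)` padding mask over the inputs becomes a `(B, N, M)` query-by-key attention mask, can be sketched in JAX as below. This is a minimal single-head sketch with hypothetical weight names; it omits LayerNorm and the feed-forward sub-layer, and the broadcast `latent_array` stands in for whatever latent initialization the caller uses.

```python
import jax
import jax.numpy as jnp


def encode(inputs, latents, w_q, w_k, w_v, input_mask=None):
    # inputs: (B, M, C) with a large M; latents: (B, N, D) with N much smaller than M.
    q = latents @ w_q                                   # (B, N, D_qk)
    k = inputs @ w_k                                    # (B, M, D_qk)
    v = inputs @ w_v                                    # (B, M, D)
    logits = jnp.einsum("bnd,bmd->bnm", q, k) / jnp.sqrt(q.shape[-1])  # (B, N, M): linear in M
    if input_mask is not None:
        # Broadcast the (B, M) padding mask to (B, N, M): no latent attends to padding.
        attn_mask = jnp.broadcast_to(input_mask[:, None, :], logits.shape)
        logits = jnp.where(attn_mask, logits, -1e9)
    weights = jax.nn.softmax(logits, axis=-1)
    return latents + jnp.einsum("bnm,bmd->bnd", weights, v)  # query residual: (B, N, D)


B, M, C, N, D = 2, 1000, 3, 32, 16
keys = jax.random.split(jax.random.PRNGKey(0), 5)
inputs = jax.random.normal(keys[0], (B, M, C))
latent_array = jax.random.normal(keys[1], (N, D))      # shared across the batch
latents = jnp.broadcast_to(latent_array, (B, N, D))
w_q = jax.random.normal(keys[2], (D, 8))
w_k = jax.random.normal(keys[3], (C, 8))
w_v = jax.random.normal(keys[4], (C, D))
lengths = jnp.array([1000, 600])                        # the second example is padded
input_mask = jnp.arange(M)[None, :] < lengths[:, None]  # (B, M) booleans
z = encode(inputs, latents, w_q, w_k, w_v, input_mask)
print(z.shape)  # (2, 32, 16)
```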

attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0))

the constructor of the cross-attention module.

deterministic: bool = True

whether to run deterministically (i.e. with dropout disabled).

dropout_rate: float = 0.0

dropout rate, i.e. the probability of zeroing an element.

feed_forward

alias of flax_extra.layer._feedforward.FeedForward

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
precision: Optional[Any] = None

numerical precision of the computation. See jax.lax.Precision for details.

scope = None
use_q_residual: bool = True

whether to add a residual connection from the encoder’s query inputs. Consider omitting the residual if the encoder’s queries and the latent features have different semantics (e.g. if the queries are positions and the latents are pixels).

class flax_extra.model.perceiver.Processor(n_shards: int = 8, n_blocks: int = 6, attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._attention.SelfAttention'>, feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, dropout_rate: float = 0.0, deterministic: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: flax.linen.module.Module

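Applies self-attention to latent features.

The self-attention blocks are grouped into shards, and the block weights are shared across shards: the same n_blocks blocks are reused in every shard, giving n_shards × n_blocks block applications with only n_blocks sets of parameters.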

\[\begin{split}\begin{aligned} & \textrm{Processor}( \\ & \quad z \in \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \\ & \quad \_ \\ & \quad \theta \gets n_{blocks} \times SelfAttentionBlock() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \end{aligned}\end{split}\]
Parameters

latents – an array of latent features.

Returns

an array of latent features.
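The weight-sharing scheme can be sketched as two nested loops: `n_blocks` parameter sets applied in order within a shard, with every shard reusing the same sets. With the defaults (`n_shards=8`, `n_blocks=6`) this gives 48 block applications but only 6 sets of block parameters. The sketch assumes the shards run one after another, and `toy_block` is a hypothetical stand-in for SelfAttentionBlock.

```python
import jax
import jax.numpy as jnp


def toy_block(w, z):
    # Hypothetical stand-in for a SelfAttentionBlock: a residual update that keeps z's shape.
    return z + jnp.tanh(z @ w)


def process(block_weights, latents, n_shards):
    # block_weights holds n_blocks parameter sets; every shard reuses all of them.
    for _ in range(n_shards):
        for w in block_weights:
            latents = toy_block(w, latents)
    return latents


n_shards, n_blocks, d_z = 8, 6, 4
keys = jax.random.split(jax.random.PRNGKey(0), n_blocks)
block_weights = [0.1 * jax.random.normal(k, (d_z, d_z)) for k in keys]
latents = jnp.ones((2, 3, d_z))
out = process(block_weights, latents, n_shards)
print(out.shape, len(block_weights) * n_shards)  # (2, 3, 4) 48
```

Because the parameters are shared, running 4 shards twice gives the same result as running 8 shards once.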

attention

alias of flax_extra.layer._attention.SelfAttention

deterministic: bool = True

whether to run deterministically (i.e. with dropout disabled).

dropout_rate: float = 0.0

dropout rate, i.e. the probability of zeroing an element.

feed_forward

alias of flax_extra.layer._feedforward.FeedForward

n_blocks: int = 6

the number of self-attention blocks in a single shard.

n_shards: int = 8

the number of shards.

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
precision: Optional[Any] = None

numerical precision of the computation. See jax.lax.Precision for details.

scope = None

class flax_extra.model.perceiver.SelfAttentionBlock(attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._attention.SelfAttention'>, feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, dropout_rate: float = 0.0, deterministic: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: flax.linen.module.Module

A block consisting of a self-attention module followed by a feed-forward module.

\[\begin{split}\begin{aligned} & \textrm{SelfAttentionBlock}( \\ & \quad x \in \sR^{\nBatchSize \times \nSeqLen_{x} \times d_{x}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{x} \times \nSeqLen_{x}} \\ & \quad \_ \\ & \quad \theta \gets LayerNorm() \\ & \quad \theta \gets SelfAttention() \\ & \quad \theta \gets LayerNorm() \\ & \quad \theta \gets FeedForward() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{x} \times d_{x}} \end{aligned}\end{split}\]
Parameters
  • latents – latent features.

  • mask – a mask tensor with boolean values indicating whether a particular query attends to a particular key.

Returns

latent features.
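A functional JAX sketch of this block follows, assuming the pre-LayerNorm residual arrangement suggested by the order of the sub-modules in the equation. It is single-head, and the parameter-free `layer_norm`, the weight dictionary `p` and its key names, and the GELU feed-forward are hypothetical stand-ins for the configurable `attention` and `feed_forward` modules.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Hypothetical parameter-free LayerNorm over the feature axis.
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(x.var(axis=-1, keepdims=True) + eps)


def self_attention_block(p, latents, mask=None):
    # latents: (B, L, D); mask: (B, L, L) booleans, True where query i may attend to key j.
    h = layer_norm(latents)
    q, k, v = h @ p["wq"], h @ p["wk"], h @ p["wv"]
    logits = jnp.einsum("bid,bjd->bij", q, k) / jnp.sqrt(q.shape[-1])
    if mask is not None:
        logits = jnp.where(mask, logits, -1e9)
    attn = jnp.einsum("bij,bjd->bid", jax.nn.softmax(logits, axis=-1), v)
    x = latents + attn @ p["wo"]                        # residual around attention
    ff = jax.nn.gelu(layer_norm(x) @ p["w1"]) @ p["w2"]
    return x + ff                                       # residual around feed-forward


d, d_hidden = 8, 32
shapes = {"wq": (d, d), "wk": (d, d), "wv": (d, d), "wo": (d, d),
          "w1": (d, d_hidden), "w2": (d_hidden, d)}
keys = jax.random.split(jax.random.PRNGKey(0), len(shapes))
p = {name: 0.1 * jax.random.normal(k, s) for k, (name, s) in zip(keys, shapes.items())}
latents = jax.random.normal(jax.random.PRNGKey(1), (2, 5, d))
out = self_attention_block(p, latents)
print(out.shape)  # (2, 5, 8)
```

The output keeps the input shape, which is what allows the Processor to stack these blocks freely.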

attention

alias of flax_extra.layer._attention.SelfAttention

deterministic: bool = True

whether to run deterministically (i.e. with dropout disabled).

dropout_rate: float = 0.0

dropout rate, i.e. the probability of zeroing an element.

feed_forward

alias of flax_extra.layer._feedforward.FeedForward

name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
precision: Optional[Any] = None

numerical precision of the computation. See jax.lax.Precision for details.

scope = None