flax_extra.model.perceiver package
The Perceiver model.
- class flax_extra.model.perceiver.CrossAttentionBlock(attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0)), feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, dropout_rate: float = 0.0, deterministic: bool = True, use_q_residual: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)
Bases: flax.linen.module.Module
A block consisting of a cross-attention module followed by a feed-forward module.
\[\begin{split}\begin{aligned} & \textrm{CrossAttentionBlock}( \\ & \quad x_{q} \in \sR^{\nBatchSize \times \nSeqLen_{x_{q}} \times d_{x_{q}}} \\ & \quad x_{kv} \in \sR^{\nBatchSize \times \nSeqLen_{x_{kv}} \times d_{x_{kv}}} \\ & \quad \_ \\ & \quad \theta \gets 2 \times LayerNorm() \\ & \quad \theta \gets CrossAttentionBlock() \\ & \quad \theta \gets LayerNorm() \\ & \quad \theta \gets FeedForward() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{x_{q}} \times d_{x_{q}}} \end{aligned}\end{split}\]
- Parameters
inputs_q – inputs for the query.
inputs_kv – inputs for key and value.
mask – a mask tensor with boolean values indicating whether a particular query attends to a particular key.
- Returns
latent features.
- attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0))
the type of cross-attention module to use.
- deterministic: bool = True
whether to run deterministically, i.e. without stochastic dropout.
- dropout_rate: float = 0.0
dropout probability.
- feed_forward
- name: str = None
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
- precision: Optional[Any] = None
numerical precision of the computation. See jax.lax.Precision for details.
- scope = None
- use_q_residual: bool = True
whether to include a residual to the query. Consider omitting the residual if the semantics of query and output are different, e.g. if queries are positions and outputs are pixels.
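The data flow described for CrossAttentionBlock can be illustrated with a minimal NumPy sketch. It is not the library's implementation: it assumes a single attention head, a ReLU feed-forward layer, layer normalisation without learned parameters, and reads the default `d_qk` and `d_output` as "query/key width follows `inputs_kv`, output width follows `inputs_q`". The function name and the parameter dictionary are hypothetical.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalise over the feature axis (no learned scale or offset in this sketch).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention_block(x_q, x_kv, params, use_q_residual=True):
    """Hypothetical single-head cross-attention block followed by feed-forward."""
    q_in, kv_in = layer_norm(x_q), layer_norm(x_kv)  # the "2 x LayerNorm" step
    q = q_in @ params["w_q"]   # (batch, len_q, d_qk)
    k = kv_in @ params["w_k"]  # (batch, len_kv, d_qk)
    v = kv_in @ params["w_v"]  # (batch, len_kv, d_q)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    attended = softmax(scores) @ v  # (batch, len_q, d_q)
    # The query residual is optional, mirroring the use_q_residual flag.
    h = x_q + attended if use_q_residual else attended
    ff = np.maximum(layer_norm(h) @ params["w_1"], 0.0) @ params["w_2"]
    return h + ff


rng = np.random.default_rng(0)
batch, len_q, len_kv, d_q, d_kv = 2, 4, 10, 8, 6
params = {
    "w_q": rng.normal(size=(d_q, d_kv)),
    "w_k": rng.normal(size=(d_kv, d_kv)),
    "w_v": rng.normal(size=(d_kv, d_q)),
    "w_1": rng.normal(size=(d_q, 4 * d_q)),
    "w_2": rng.normal(size=(4 * d_q, d_q)),
}
x_q = rng.normal(size=(batch, len_q, d_q))
x_kv = rng.normal(size=(batch, len_kv, d_kv))

out = cross_attention_block(x_q, x_kv, params)
out_no_residual = cross_attention_block(x_q, x_kv, params, use_q_residual=False)
print(out.shape, out_no_residual.shape)
```

In this sketch the output keeps the sequence length and feature width of the query inputs whether or not the query residual is used; the flag only changes what is added before the feed-forward step.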
- class flax_extra.model.perceiver.Decoder(attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0)), feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, dropout_rate: float = 0.0, deterministic: bool = True, use_q_residual: int = False, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)
Bases: flax.linen.module.Module
A single cross-attention block that projects latent features to outputs of the desired dimension.
\[\begin{split}\begin{aligned} & \textrm{Decoder}( \\ & \quad z \in \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \\ & \quad o \in \sR^{\nBatchSize \times \nSeqLen_{o} \times d_{o}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{z} \times \nSeqLen_{o}} \\ & \quad \_ \\ & \quad \theta \gets CrossAttentionBlock() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{o} \times d_{o}} \end{aligned}\end{split}\]
- Parameters
latents – latent features.
targets – inputs for decoder’s query of desired dimension.
target_mask – a padding mask indicating at which positions values of the targets are valid.
- Returns
outputs with the same dimension as inputs for decoder’s query.
- attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0))
the type of cross-attention module to use.
- deterministic: bool = True
whether to run deterministically, i.e. without stochastic dropout.
- dropout_rate: float = 0.0
dropout probability.
- feed_forward
- name: str = None
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
- precision: Optional[Any] = None
numerical precision of the computation. See jax.lax.Precision for details.
- scope = None
- use_q_residual: int = False
whether to include a residual to the query. Consider omitting the residual if the semantics of decoder’s query and outputs are different (e.g. if queries are positions and outputs are pixels).
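The way the Decoder's output shape follows its query can be shown with a small NumPy sketch. It is an illustration under assumptions, not the library's code: a single attention head, no layer normalisation or feed-forward step, and no query residual (matching the `use_q_residual` default of False). The function `decode` and the weight names are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def decode(latents, targets, w_q, w_k, w_v):
    """Hypothetical decoder step: targets act as queries over the latents."""
    q = targets @ w_q  # (batch, len_o, d_qk)
    k = latents @ w_k  # (batch, len_z, d_qk)
    v = latents @ w_v  # (batch, len_z, d_o)
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1]))
    # No residual to the query here: targets (e.g. positions) and outputs
    # (e.g. pixels) may have different semantics.
    return weights @ v  # (batch, len_o, d_o)


rng = np.random.default_rng(0)
batch, len_z, d_z, d_o = 2, 16, 32, 5
latents = rng.normal(size=(batch, len_z, d_z))
w_q = rng.normal(size=(d_o, d_z))
w_k = rng.normal(size=(d_z, d_z))
w_v = rng.normal(size=(d_z, d_o))

# The same latents decoded with two different numbers of target queries.
few = decode(latents, rng.normal(size=(batch, 3, d_o)), w_q, w_k, w_v)
many = decode(latents, rng.normal(size=(batch, 50, d_o)), w_q, w_k, w_v)
print(few.shape, many.shape)
```

The number of output positions is set by the targets alone, so the same latent array can be decoded into outputs of any length.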
- class flax_extra.model.perceiver.Encoder(attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0)), feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, dropout_rate: float = 0.0, deterministic: bool = True, use_q_residual: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)
Bases: flax.linen.module.Module
A single cross-attention block that projects high-dimensional input vectors to fixed-dimensional vectors of latent features.
\[\begin{split}\begin{aligned} & \textrm{Encoder}( \\ & \quad x \in \sR^{\nBatchSize \times \nSeqLen_{x} \times d_{x}} \\ & \quad z \in \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \\ & \quad mask \in \sR^{\nBatchSize \times \nSeqLen_{z} \times \nSeqLen_{x}} \\ & \quad \_ \\ & \quad \theta \gets CrossAttentionBlock() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \end{aligned}\end{split}\]
- Parameters
inputs – a high-dimensional input array.
latents – inputs for encoder’s query of desired dimension.
input_mask – a padding mask indicating at which positions values of the inputs are valid.
- Returns
latent features with the same dimension as inputs for encoder’s query.
- attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0))
the type of cross-attention module to use.
- deterministic: bool = True
whether to run deterministically, i.e. without stochastic dropout.
- dropout_rate: float = 0.0
dropout probability.
- feed_forward
- name: str = None
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
- precision: Optional[Any] = None
numerical precision of the computation. See jax.lax.Precision for details.
- scope = None
- use_q_residual: bool = True
whether to include a residual to the query. Consider omitting the residual if the semantics of encoder’s query and latent features are different (e.g. if queries are positions and latents are pixels).
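The effect of an input padding mask on the Encoder can be sketched in NumPy. This is a simplified illustration, not the library's implementation: it assumes a single attention head, omits layer normalisation and the feed-forward step, and assumes the padding mask has shape (batch, len_x) and is broadcast over the latent axis to (batch, len_z, len_x). The function `encode` and the weight names are hypothetical.

```python
import numpy as np


def encode(inputs, latents, input_mask, w_q, w_k, w_v):
    """Hypothetical encoder step: latents query the (padded) inputs."""
    q = latents @ w_q  # (batch, len_z, d_qk)
    k = inputs @ w_k   # (batch, len_x, d_qk)
    v = inputs @ w_v   # (batch, len_x, d_z)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    # Broadcast the padding mask (batch, len_x) to (batch, len_z, len_x)
    # so that every latent query ignores the padded input positions.
    attend = np.broadcast_to(input_mask[:, None, :], scores.shape)
    scores = np.where(attend, scores, -1e30)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return latents + weights @ v  # query residual, as with use_q_residual=True


rng = np.random.default_rng(0)
batch, len_x, d_x, len_z, d_z = 2, 12, 3, 4, 8
inputs = rng.normal(size=(batch, len_x, d_x))
latents = rng.normal(size=(batch, len_z, d_z))
w_q = rng.normal(size=(d_z, d_x))
w_k = rng.normal(size=(d_x, d_x))
w_v = rng.normal(size=(d_x, d_z))

# Only the first 7 input positions are valid; the rest are padding.
input_mask = np.zeros((batch, len_x), dtype=bool)
input_mask[:, :7] = True

z_a = encode(inputs, latents, input_mask, w_q, w_k, w_v)

# Overwrite the padded positions with arbitrary values and encode again.
inputs_b = inputs.copy()
inputs_b[:, 7:, :] = 1000.0
z_b = encode(inputs_b, latents, input_mask, w_q, w_k, w_v)
print(z_a.shape, np.allclose(z_a, z_b))
```

Because masked positions receive zero attention weight, the values stored in the padded part of the input do not influence the resulting latents, and the output keeps the shape of the latent query.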
- class flax_extra.model.perceiver.Processor(n_shards: int = 8, n_blocks: int = 6, attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._attention.SelfAttention'>, feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, dropout_rate: float = 0.0, deterministic: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)
Bases: flax.linen.module.Module
Self-attends latent features.
Self-attention blocks are grouped into shards. The weights of the blocks are shared across shards.
\[\begin{split}\begin{aligned} & \textrm{Processor}( \\ & \quad z \in \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \\ & \quad \_ \\ & \quad \theta \gets n_{blocks} \times SelfAttentionBlock() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{z} \times d_{z}} \end{aligned}\end{split}\]
- Parameters
latents – an array of latent features.
- Returns
an array of latent features.
- attention
- deterministic: bool = True
whether to run deterministically, i.e. without stochastic dropout.
- dropout_rate: float = 0.0
dropout probability.
- feed_forward
- n_blocks: int = 6
the number of self-attention blocks that make up a single shard.
- n_shards: int = 8
the number of shards.
- name: str = None
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
- precision: Optional[Any] = None
numerical precision of the computation. See jax.lax.Precision for details.
- scope = None
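The weight sharing described for the Processor can be pictured with a short NumPy sketch. It rests on one reading of the description, namely that `n_blocks` parameter sets are created once and the same sets are reused in every one of the `n_shards` passes. The simplified block (single head, no feed-forward step) and all function names are hypothetical and stand in for the library's modules.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def self_attention_block(z, w):
    """Simplified stand-in for a self-attention block (single head, residual)."""
    y = layer_norm(z)
    scores = (y @ w["q"]) @ (y @ w["k"]).transpose(0, 2, 1) / np.sqrt(z.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return z + weights @ (y @ w["v"])


def processor(z, block_params, n_shards):
    """Apply the same list of block parameters once per shard."""
    applications = 0
    for _ in range(n_shards):
        for w in block_params:  # the n_blocks parameter sets are reused here
            z = self_attention_block(z, w)
            applications += 1
    return z, applications


rng = np.random.default_rng(0)
n_shards, n_blocks = 8, 6  # the defaults listed for Processor
batch, len_z, d_z = 2, 4, 8
block_params = [
    {name: 0.1 * rng.normal(size=(d_z, d_z)) for name in ("q", "k", "v")}
    for _ in range(n_blocks)
]
latents = rng.normal(size=(batch, len_z, d_z))

out, applications = processor(latents, block_params, n_shards)
print(out.shape, len(block_params), applications)
```

Under this reading, the default configuration holds 6 sets of block weights but applies a block 48 times, and the latent array keeps its shape throughout.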
- class flax_extra.model.perceiver.SelfAttentionBlock(attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._attention.SelfAttention'>, feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, dropout_rate: float = 0.0, deterministic: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)
Bases: flax.linen.module.Module
A block consisting of a self-attention module followed by a feed-forward module.
\[\begin{split}\begin{aligned} & \textrm{SelfAttentionBlock}( \\ & \quad x \in \sR^{\nBatchSize \times \nSeqLen_{x} \times d_{x}} \\ & \quad \_ \\ & \quad \theta \gets LayerNorm() \\ & \quad \theta \gets SelfAttentionBlock() \\ & \quad \theta \gets LayerNorm() \\ & \quad \theta \gets FeedForward() \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times \nSeqLen_{x} \times d_{x}} \end{aligned}\end{split}\]
- Parameters
latents – latent features.
mask – a mask tensor with boolean values indicating whether a particular query attends to a particular key.
- Returns
latent features.
- attention
- deterministic: bool = True
whether to run deterministically, i.e. without stochastic dropout.
- dropout_rate: float = 0.0
dropout probability.
- feed_forward
- name: str = None
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
- precision: Optional[Any] = None
numerical precision of the computation. See jax.lax.Precision for details.
- scope = None
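The boolean mask of the SelfAttentionBlock, which states whether a particular query attends to a particular key, can be illustrated with a NumPy sketch. It is not the library's code: it assumes a single attention head, a ReLU feed-forward layer, layer normalisation without learned parameters, and a mask of shape (batch, len_x, len_x). The lower-triangular mask is only an illustrative choice, and the function and weight names are hypothetical.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def self_attention_block(x, mask, p):
    """Hypothetical pre-norm self-attention block followed by feed-forward."""
    y = layer_norm(x)
    scores = (y @ p["w_q"]) @ (y @ p["w_k"]).transpose(0, 2, 1) / np.sqrt(x.shape[-1])
    # mask[b, i, j] is True when query i may attend to key j.
    scores = np.where(mask, scores, -1e30)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    h = x + weights @ (y @ p["w_v"])
    return h + np.maximum(layer_norm(h) @ p["w_1"], 0.0) @ p["w_2"]


rng = np.random.default_rng(0)
batch, len_x, d_x = 2, 5, 8
p = {
    "w_q": rng.normal(size=(d_x, d_x)),
    "w_k": rng.normal(size=(d_x, d_x)),
    "w_v": rng.normal(size=(d_x, d_x)),
    "w_1": rng.normal(size=(d_x, 4 * d_x)),
    "w_2": rng.normal(size=(4 * d_x, d_x)),
}
x = rng.normal(size=(batch, len_x, d_x))

# Illustrative mask: each position attends to itself and earlier positions.
mask = np.broadcast_to(np.tril(np.ones((len_x, len_x), dtype=bool)), (batch, len_x, len_x))

out = self_attention_block(x, mask, p)

# Perturb only the last position and run the block again.
x_perturbed = x.copy()
x_perturbed[:, -1, :] += 1.0
out_perturbed = self_attention_block(x_perturbed, mask, p)
print(out.shape, np.allclose(out[:, 0], out_perturbed[:, 0]))
```

With this mask the first position never attends to the last one, so perturbing the last input leaves the first output row unchanged, while the output keeps the input shape.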