flax_extra.model package

Flax models.

class flax_extra.model.PerceiverIO(
    input_encoding: Callable[[...], flax_extra.layer._multimodal_encoding.MultimodalEncoding],
    encoder_query_encoding: Callable[[...], flax_extra.layer._multimodal_positional_encoding.MultimodalPositionalEncoding],
    decoder_query_encoding: Union[Callable[[...], flax_extra.layer._multimodal_positional_encoding.MultimodalPositionalEncoding], Callable[[...], flax_extra.layer._multimodal_encoding.MultimodalEncoding]],
    output_decoding: Callable[[...], flax_extra.layer._multimodal_decoding.MultimodalDecoding],
    n_processor_shards: int = 8,
    n_processor_blocks: int = 6,
    processor_attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._attention.SelfAttention'>,
    processor_feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>,
    encoder_attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0)),
    encoder_feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>,
    use_encoder_q_residual: bool = True,
    decoder_attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0)),
    decoder_feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>,
    use_decoder_q_residual: bool = False,
    deterministic: bool = True,
    precision: Optional[Any] = None,
    parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>,
    name: str = None
)

Bases: flax.linen.module.Module

Maps arbitrary input tensors to arbitrary output tensors in a domain agnostic process.

The bulk of the computation happens in a latent space whose size is typically smaller than the inputs and outputs, which makes the process computationally tractable even for very large inputs and outputs.
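To make the tractability claim concrete, the sketch below counts query-key pairs as a rough proxy for attention cost. The sizes and the function names are illustrative assumptions, not values or helpers taken from flax_extra.

```python
def self_attention_pairs(seq_len: int, n_layers: int) -> int:
    """Query-key pairs scored when every layer attends over the full sequence."""
    return n_layers * seq_len * seq_len


def latent_attention_pairs(n_inputs: int, n_outputs: int, n_latents: int, n_layers: int) -> int:
    """Query-key pairs scored when the deep stack runs in a latent space."""
    encode = n_latents * n_inputs                # latents attend to the inputs once
    process = n_layers * n_latents * n_latents   # self-attention among the latents
    decode = n_outputs * n_latents               # output queries attend to the latents
    return encode + process + decode


# Illustrative sizes (assumptions, not library defaults).
n_inputs = n_outputs = 50_000
n_latents = 512
n_layers = 48

direct = self_attention_pairs(n_inputs, n_layers)
latent = latent_attention_pairs(n_inputs, n_outputs, n_latents, n_layers)
print(direct, latent, direct // latent)
```

With these sizes the latent variant scores well over a thousand times fewer pairs, because the quadratic term depends on the number of latents rather than on the input length.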

\[\begin{split}\begin{aligned}
& \textrm{PerceiverIO}( \\
& \quad x \in n_{mod_{x}} \times \mathbb{R}^{m \times T_{x} \times d_{x}} \\
& \quad mask_{x} \in \mathbb{R}^{m \times T_{x}} \\
& \quad y \in n_{mod_{y}} \times \mathbb{R}^{m \times T_{y} \times d_{y}} \\
& \quad mask_{y} \in \mathbb{R}^{m \times T_{y}} \\
& \quad t_{y} \in n_{mod_{y}} \times \mathbb{N}^{T^{\prime}_{y}} \\
& \quad \_ \\
& \quad \theta \gets MultimodalEncoding() \\
& \quad \theta \gets MultimodalPositionalEncoding() \\
& \quad \theta \gets Encoder() \\
& \quad \theta \gets Processor() \\
& \quad \theta \gets MultimodalEncoding() \\
& \quad \theta \gets MultimodalPositionalEncoding() \mid MultimodalEncoding() \\
& \quad \theta \gets Decoder() \\
& \quad \theta \gets MultimodalDecoding() \\
& ) \\
& \rightarrow n_{mod_{y}} \times \mathbb{R}^{m \times \dots \times d^{\prime}_{y}}
\end{aligned}\end{split}\]
Parameters
  • inputs – a single input sequence or multimodal input sequences.

  • input_mask – a padding mask indicating the positions at which the input values are valid.

  • targets – a single target sequence or multimodal target sequences. If provided, decoder_query_encoding must be of type flax_extra.layer.MultimodalEncoding; otherwise it must be of type flax_extra.layer.MultimodalPositionalEncoding.

  • target_mask – a padding mask indicating the positions at which the target values are valid.

  • output_positions – the subset of positions (i.e. time steps) within each modality for which the encoding will be calculated.
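
As an illustration of the mask parameters above, a padding mask can be derived from per-example sequence lengths. This is a minimal sketch in plain Python; the helper name padding_mask and the convention that 1 marks a valid position are assumptions, so check which convention flax_extra expects before relying on it.

```python
from typing import List


def padding_mask(lengths: List[int], max_len: int) -> List[List[int]]:
    """Build a (batch, max_len) mask: 1 for real positions, 0 for padding."""
    return [[1 if t < n else 0 for t in range(max_len)] for n in lengths]


# Two sequences of lengths 3 and 5, padded to a common length of 5.
input_mask = padding_mask([3, 5], max_len=5)
for row in input_mask:
    print(row)
```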

Returns

a single output sequence or multimodal output sequences.

decoder_attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0))

the type of cross-attention used by the decoder. See flax_extra.layer.KVQAttention.

decoder_feed_forward

alias of flax_extra.layer._feedforward.FeedForward

decoder_query_encoding: Union[Callable[[...], flax_extra.layer._multimodal_positional_encoding.MultimodalPositionalEncoding], Callable[[...], flax_extra.layer._multimodal_encoding.MultimodalEncoding]]

the encoding of the decoder’s query. Use flax_extra.layer.io.query_encoding() to define the module type.

deterministic: bool = True

whether to run deterministically; in Flax this conventionally disables stochastic layers such as dropout.

encoder_attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0))

the type of cross-attention used by the encoder. See flax_extra.layer.KVQAttention.
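
The default shown above is a functools.partial wrapping a class, i.e. a factory for the module rather than a module instance. The sketch below shows that pattern with a hypothetical stand-in class; ToyAttention and its fields are illustrative only and do not mirror flax_extra's Attention.

```python
import functools
from dataclasses import dataclass


@dataclass
class ToyAttention:
    """Hypothetical stand-in for an attention module type."""
    n_heads: int = 8
    dropout_rate: float = 0.0


# Pre-bind some arguments; the result is still a callable that builds a module.
attention_type = functools.partial(ToyAttention, n_heads=4)

# Whoever owns the field decides when to instantiate and may pass more arguments.
layer = attention_type(dropout_rate=0.1)
print(layer)
```

Passing a factory instead of an instance lets the enclosing model supply arguments that are only known at construction time.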

encoder_feed_forward

alias of flax_extra.layer._feedforward.FeedForward

encoder_query_encoding: Callable[[...], flax_extra.layer._multimodal_positional_encoding.MultimodalPositionalEncoding]

the encoding of the encoder’s query. Use flax_extra.layer.io.query_encoding() to define the module type.

input_encoding: Callable[[...], flax_extra.layer._multimodal_encoding.MultimodalEncoding]

preprocessing and encoding of a single or multimodal input. Use flax_extra.layer.io.input_encoding() to define the module type.

n_processor_blocks: int = 6

the number of self-attention blocks that make up a single shard. See flax_extra.model.perceiver.Processor.

n_processor_shards: int = 8

the number of shards. See flax_extra.model.perceiver.Processor.
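
Taken together, the two settings determine how many self-attention blocks the processor applies. The arithmetic below assumes that shards are simply applied one after another, which this page does not state; see flax_extra.model.perceiver.Processor for the actual layout.

```python
n_processor_shards = 8  # default from the class signature
n_processor_blocks = 6  # default from the class signature

# Assumption: every shard holds n_processor_blocks blocks, applied in sequence.
total_blocks = n_processor_shards * n_processor_blocks
print(total_blocks)
```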

name: str = None

output_decoding: Callable[[...], flax_extra.layer._multimodal_decoding.MultimodalDecoding]

output decoding and postprocessing. Use flax_extra.layer.io.output_decoding() to define the module type.

parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>

precision: Optional[Any] = None

numerical precision of the computation. See jax.lax.Precision for details.

processor_attention

alias of flax_extra.layer._attention.SelfAttention

processor_feed_forward

alias of flax_extra.layer._feedforward.FeedForward

scope = None

use_decoder_q_residual: bool = False

whether to add the query as a residual to the attention output. Consider omitting the residual if the semantics of the decoder’s query and outputs differ (e.g. if queries are positions and outputs are pixels).

use_encoder_q_residual: bool = True

whether to add the query as a residual to the attention output. Consider omitting the residual if the semantics of the encoder’s query and latent features differ (e.g. if queries are positions and latents are pixels).
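
The effect of a query residual can be sketched with a toy cross-attention in plain Python. This is an illustration under simplifying assumptions (one head, no learned projections, keys and values taken from the same matrix); the function names are hypothetical and this is not flax_extra's implementation.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(q: Matrix, kv: Matrix, use_q_residual: bool) -> Matrix:
    """Toy single-head attention where keys and values are both `kv`."""
    scale = 1.0 / math.sqrt(len(q[0]))
    outputs = []
    for query in q:
        scores = [scale * sum(a * b for a, b in zip(query, key)) for key in kv]
        weights = softmax(scores)
        attended = [sum(w * row[j] for w, row in zip(weights, kv)) for j in range(len(kv[0]))]
        if use_q_residual:
            # Add the query back; this only makes sense when the query and
            # the attended values live in the same feature space.
            attended = [a + x for a, x in zip(attended, query)]
        outputs.append(attended)
    return outputs


q = [[1.0, 0.0]]
kv = [[0.5, 0.5], [0.5, 0.5]]
print(cross_attention(q, kv, use_q_residual=False))
print(cross_attention(q, kv, use_q_residual=True))
```

With the residual enabled, the query itself is mixed into the result, which is why it is worth disabling when queries and outputs mean different things.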

class flax_extra.model.RNNLM(
    vocab_size: int,
    d_model: int = 512,
    n_layers: int = 2,
    dropout_rate: float = 0.1,
    rnn_type: Callable[[...], flax.linen.module.Module] = <class 'flax_extra.layer._lstm.LSTM'>,
    deterministic: bool = True,
    parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>,
    name: str = None
)

Bases: flax.linen.module.Module

RNN language model.

d_model: int = 512

depth of the model, i.e. the size of its feature dimension.

deterministic: bool = True

whether to run deterministically (in Flax, this conventionally disables dropout).

dropout_rate: float = 0.1

dropout rate, i.e. the probability of dropping a unit.
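
How a dropout rate and a deterministic flag typically interact can be sketched in plain Python. The helper below is hypothetical and uses the common inverted-dropout formulation; RNNLM's internals may differ.

```python
import random
from typing import List


def dropout(xs: List[float], rate: float, deterministic: bool, rng: random.Random) -> List[float]:
    """Inverted dropout: zero each value with probability `rate`, rescale the rest."""
    if deterministic or rate == 0.0:
        return list(xs)  # evaluation mode: inputs pass through unchanged
    keep = 1.0 - rate
    return [x / keep if rng.random() < keep else 0.0 for x in xs]


rng = random.Random(0)
xs = [1.0] * 10
print(dropout(xs, rate=0.1, deterministic=True, rng=rng))   # identical to xs
print(dropout(xs, rate=0.1, deterministic=False, rng=rng))  # each entry is 0.0 or 1 / 0.9
```

Rescaling the kept values by 1 / (1 - rate) keeps the expected activation the same in training and evaluation.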

n_layers: int = 2

the number of RNN layers.

name: str = None

parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>

rnn_type

alias of flax_extra.layer._lstm.LSTM

scope = None

vocab_size: int

size of the vocabulary.

Subpackages