flax_extra.model package
Flax models.
- class flax_extra.model.PerceiverIO(input_encoding: Callable[[...], flax_extra.layer._multimodal_encoding.MultimodalEncoding], encoder_query_encoding: Callable[[...], flax_extra.layer._multimodal_positional_encoding.MultimodalPositionalEncoding], decoder_query_encoding: Union[Callable[[...], flax_extra.layer._multimodal_positional_encoding.MultimodalPositionalEncoding], Callable[[...], flax_extra.layer._multimodal_encoding.MultimodalEncoding]], output_decoding: Callable[[...], flax_extra.layer._multimodal_decoding.MultimodalDecoding], n_processor_shards: int = 8, n_processor_blocks: int = 6, processor_attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._attention.SelfAttention'>, processor_feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, encoder_attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0)), encoder_feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, use_encoder_q_residual: bool = True, decoder_attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0)), decoder_feed_forward: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = <class 'flax_extra.layer._feedforward.FeedForward'>, use_decoder_q_residual: bool = False, deterministic: bool = True, precision: Optional[Any] = None, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)
Bases: flax.linen.module.Module

Maps arbitrary input tensors to arbitrary output tensors through a domain-agnostic process.
The bulk of the computation happens in a latent space whose size is typically smaller than that of the inputs and outputs, which keeps the process computationally tractable even for very large inputs and outputs.
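The effect of the latent bottleneck on cost can be illustrated with a single cross-attention step in which a small latent array queries a long input sequence. This is a minimal NumPy sketch, not flax_extra's Attention module. It assumes one head, no learned projections, and arbitrary illustration sizes, and it shows only output shapes and the size of the attention-score matrix. The same array calls are available in jax.numpy.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(q: np.ndarray, kv: np.ndarray) -> np.ndarray:
    """Single-head dot-product attention with no learned projections.

    q:  [batch, n_queries, d]  (e.g. the latent array)
    kv: [batch, n_keys, d]     (e.g. the encoded inputs)
    Returns [batch, n_queries, d]: one output row per query, whatever n_keys is.
    """
    logits = q @ np.swapaxes(kv, -1, -2) / np.sqrt(q.shape[-1])  # [batch, n_queries, n_keys]
    weights = softmax(logits, axis=-1)
    return weights @ kv


rng = np.random.default_rng(0)
batch, n_inputs, n_latents, d = 2, 4096, 256, 32
inputs = rng.standard_normal((batch, n_inputs, d))
latents = rng.standard_normal((batch, n_latents, d))

encoded = cross_attention(latents, inputs)  # the latents attend to the long input
print(encoded.shape)  # (2, 256, 32)

# Entries in one attention-score matrix per batch element:
print(n_latents * n_inputs)  # 1048576   latent cross-attention (linear in n_inputs)
print(n_inputs * n_inputs)   # 16777216  self-attention over the raw inputs (quadratic)
```

The output length follows the number of latents, not the input length, so any further layers that operate on the latents do not see the input length at all.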
\[\begin{aligned}
& \textrm{PerceiverIO}( \\
& \quad x \in n_{mod_{x}} \times \mathbb{R}^{m \times T_{x} \times d_{x}} \\
& \quad mask_{x} \in \mathbb{R}^{m \times T_{x}} \\
& \quad y \in n_{mod_{y}} \times \mathbb{R}^{m \times T_{y} \times d_{y}} \\
& \quad mask_{y} \in \mathbb{R}^{m \times T_{y}} \\
& \quad t_{y} \in n_{mod_{y}} \times \mathbb{N}^{T^{\prime}_{y}} \\
& \quad \_ \\
& \quad \theta \gets MultimodalEncoding() \\
& \quad \theta \gets MultimodalPositionalEncoding() \\
& \quad \theta \gets Encoder() \\
& \quad \theta \gets Processor() \\
& \quad \theta \gets MultimodalEncoding() \\
& \quad \theta \gets MultimodalPositionalEncoding() \mid MultimodalEncoding() \\
& \quad \theta \gets Decoder() \\
& \quad \theta \gets MultimodalDecoding() \\
& ) \\
& \rightarrow n_{mod_{y}} \times \mathbb{R}^{m \times \dots \times d^{\prime}_{y}}
\end{aligned}\]
- Parameters
inputs – a single or multimodal input sequence(s).
input_mask – a padding mask indicating at which positions the input values are valid.
targets – a single or multimodal target sequence(s). If targets are provided, decoder_query_encoding must be of the flax_extra.layer.MultimodalEncoding type; otherwise, the flax_extra.layer.MultimodalPositionalEncoding type must be used.
target_mask – a padding mask indicating at which positions the target values are valid.
output_positions – a subset of positions (i.e. time steps) within each modality for which the encoding will be calculated.
- Returns
a single or multimodal output sequence(s).
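The roles of input_mask and output_positions can be illustrated with a short sketch. It assumes a single modality, a mask convention of 1 for valid and 0 for padded positions, and that output_positions index the time axis of the decoder query. The helper masked_attention is hypothetical, written in NumPy, and does not reproduce flax_extra's internals.

```python
import numpy as np


def masked_attention(q, kv, kv_mask):
    """Hypothetical dot-product attention that ignores padded key positions.

    q:       [batch, n_q, d]
    kv:      [batch, n_k, d]
    kv_mask: [batch, n_k], 1 where a value is valid and 0 where it is padding
    """
    logits = q @ np.swapaxes(kv, -1, -2) / np.sqrt(q.shape[-1])  # [batch, n_q, n_k]
    # Broadcast the mask over the query axis; padded keys get a huge negative logit.
    logits = np.where(kv_mask[:, None, :] > 0, logits, -1e9)
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ kv, weights


rng = np.random.default_rng(0)
batch, t_x, n_latents, d = 2, 6, 4, 8
x = rng.standard_normal((batch, t_x, d))
latents = rng.standard_normal((batch, n_latents, d))

# The second sequence has only 3 valid time steps; the rest is padding.
input_mask = np.array([[1, 1, 1, 1, 1, 1],
                       [1, 1, 1, 0, 0, 0]])
encoded, weights = masked_attention(latents, x, input_mask)
print(encoded.shape)            # (2, 4, 8)
print(weights[1, :, 3:].max())  # 0.0: padded steps receive no attention weight

# output_positions: compute outputs only at a subset of the target time steps.
t_y = 10
decoder_query = rng.standard_normal((batch, t_y, d))
output_positions = np.array([0, 3, 7])  # T'_y = 3 selected steps, shared across the batch
selected_query = decoder_query[:, output_positions, :]
print(selected_query.shape)     # (2, 3, 8)
```

Selecting query positions before decoding means the decoder cross-attention produces only T'_y output rows, which matches the output shape in the signature above when T'_y < T_y.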
- decoder_attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0))
the type of cross-attention used by the decoder. See flax_extra.layer.KVQAttention.
- decoder_feed_forward
- decoder_query_encoding: Union[Callable[[...], flax_extra.layer._multimodal_positional_encoding.MultimodalPositionalEncoding], Callable[[...], flax_extra.layer._multimodal_encoding.MultimodalEncoding]]
an encoding for the decoder’s query. Use flax_extra.layer.io.query_encoding() to define the module type.
- deterministic: bool = True
whether to perform the computation deterministically.
- encoder_attention: Callable[[...], Callable[[jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]] = functools.partial(<class 'flax_extra.layer._attention.Attention'>, d_qk=functools.partial(<function _like_shape>, index=1), d_output=functools.partial(<function _like_shape>, index=0))
the type of cross-attention used by the encoder. See flax_extra.layer.KVQAttention.
- encoder_feed_forward
- encoder_query_encoding: Callable[[...], flax_extra.layer._multimodal_positional_encoding.MultimodalPositionalEncoding]
an encoding for the encoder’s query. Use flax_extra.layer.io.query_encoding() to define the module type.
- input_encoding: Callable[[...], flax_extra.layer._multimodal_encoding.MultimodalEncoding]
preprocessing and encoding of a single or multimodal input. Use flax_extra.layer.io.input_encoding() to define the module type.
- n_processor_blocks: int = 6
the number of self-attention blocks that make up a single shard. See flax_extra.model.perceiver.Processor.
- n_processor_shards: int = 8
the number of shards. See flax_extra.model.perceiver.Processor.
- name: str = None
- output_decoding: Callable[[...], flax_extra.layer._multimodal_decoding.MultimodalDecoding]
output decoding and postprocessing. Use flax_extra.layer.io.output_decoding() to define the module type.
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
- precision: Optional[Any] = None
numerical precision of the computation. See jax.lax.Precision for details.
- processor_attention
- processor_feed_forward
- scope = None
- use_decoder_q_residual: bool = False
whether to add the query as a residual to the decoder’s cross-attention output. Consider omitting the residual if the decoder’s query and outputs have different semantics (e.g. if the queries are positions and the outputs are pixels).
- use_encoder_q_residual: bool = True
whether to add the query as a residual to the encoder’s cross-attention output. Consider omitting the residual if the encoder’s query and the latent features have different semantics (e.g. if the queries are positions and the latents are pixels).
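What use_encoder_q_residual and use_decoder_q_residual toggle can be sketched as follows. The helper cross_attention_block is hypothetical and written in NumPy with a single head and no learned projections; it only shows where an optional query residual enters a cross-attention step, not the full block used by flax_extra (which also includes a feed-forward part).

```python
import numpy as np


def attend(q: np.ndarray, kv: np.ndarray) -> np.ndarray:
    """Single-head dot-product attention; q: [n_q, d], kv: [n_kv, d]."""
    logits = q @ kv.T / np.sqrt(q.shape[-1])
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ kv  # [n_q, d]


def cross_attention_block(q: np.ndarray, kv: np.ndarray, use_q_residual: bool) -> np.ndarray:
    """Hypothetical cross-attention step with an optional query residual."""
    out = attend(q, kv)
    # With the residual, the query is added back to the attention output, so the
    # output is expressed "on top of" the query and must share its width.
    return out + q if use_q_residual else out


rng = np.random.default_rng(0)
position_query = rng.standard_normal((5, 8))  # e.g. an encoding of 5 output positions
latents = rng.standard_normal((16, 8))        # e.g. the processed latent array

with_residual = cross_attention_block(position_query, latents, use_q_residual=True)
without_residual = cross_attention_block(position_query, latents, use_q_residual=False)

# The only difference is the query itself, added to every output row.
print(np.allclose(with_residual - without_residual, position_query))  # True
```

When the query encodes positions and the desired outputs are pixel values, the residual adds positional features directly into the predicted pixels; this is the situation in which the descriptions above suggest omitting it.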
- class flax_extra.model.RNNLM(vocab_size: int, d_model: int = 512, n_layers: int = 2, dropout_rate: float = 0.1, rnn_type: Callable[[...], flax.linen.module.Module] = <class 'flax_extra.layer._lstm.LSTM'>, deterministic: bool = True, parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>, name: str = None)
Bases: flax.linen.module.Module

RNN language model.
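To make the hyperparameters concrete, here is a hypothetical forward pass of an RNN language model that uses the same names (vocab_size, d_model, n_layers). It is a NumPy sketch under stated assumptions: the layer order (embedding, stacked LSTM layers, dense projection to vocabulary logits), the gate layout, and the initialization are illustrative choices, not flax_extra's RNNLM. Dropout is omitted, which corresponds to running with deterministic=True.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_layer(x, w, b):
    """Run one LSTM layer left to right over time.

    x: [batch, time, d]; w: [2 * d, 4 * d]; b: [4 * d]. Returns [batch, time, d].
    """
    batch, time, d = x.shape
    h = np.zeros((batch, d))
    c = np.zeros((batch, d))
    outputs = []
    for t in range(time):
        z = np.concatenate([x[:, t], h], axis=-1) @ w + b  # [batch, 4 * d]
        i, f, g, o = np.split(z, 4, axis=-1)  # input, forget, candidate, output
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs, axis=1)


def init_params(rng, vocab_size, d_model, n_layers, scale=0.1):
    """Random parameters for the hypothetical model below."""
    return {
        "embedding": scale * rng.standard_normal((vocab_size, d_model)),
        "lstm": [
            (scale * rng.standard_normal((2 * d_model, 4 * d_model)), np.zeros(4 * d_model))
            for _ in range(n_layers)
        ],
        "w_out": scale * rng.standard_normal((d_model, vocab_size)),
        "b_out": np.zeros(vocab_size),
    }


def rnnlm_logits(tokens, params):
    """Hypothetical forward pass: embed -> stacked LSTM layers -> vocabulary logits."""
    x = params["embedding"][tokens]  # [batch, time, d_model]
    for w, b in params["lstm"]:
        # A training-mode variant might apply dropout here (placement is an assumption).
        x = lstm_layer(x, w, b)
    return x @ params["w_out"] + params["b_out"]  # [batch, time, vocab_size]


rng = np.random.default_rng(0)
vocab_size, d_model, n_layers = 100, 16, 2
params = init_params(rng, vocab_size, d_model, n_layers)
tokens = rng.integers(0, vocab_size, size=(3, 7))  # [batch, time]
logits = rnnlm_logits(tokens, params)
print(logits.shape)  # (3, 7, 100)
```

Because each LSTM layer reads the sequence left to right, the logits at step t depend only on the tokens up to and including step t.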
- d_model: int = 512
depth (feature dimension) of the model.
- deterministic: bool = True
whether to perform the computation deterministically.
- dropout_rate: float = 0.1
the dropout rate, i.e. the probability of dropping an element.
- n_layers: int = 2
the number of RNN layers.
- name: str = None
- parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
- rnn_type
alias of flax_extra.layer._lstm.LSTM
- scope = None
- vocab_size: int
the size of the vocabulary.
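In Flax, deterministic=True conventionally turns dropout into the identity. The sketch below shows inverted dropout, the convention flax.linen.Dropout follows: in training mode each element is zeroed with probability dropout_rate and the survivors are rescaled by 1 / (1 - dropout_rate). It is a standalone NumPy illustration; where RNNLM applies dropout internally is not specified in this reference.

```python
import numpy as np


def dropout(x, rate, deterministic, rng=None):
    """Inverted dropout.

    In training mode, each element is zeroed with probability `rate` and the
    survivors are scaled by 1 / (1 - rate), so the expected value of every
    element is unchanged. With deterministic=True the input passes through as is.
    """
    if deterministic or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate  # True with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)


x = np.ones((4, 1000))
print(np.array_equal(dropout(x, 0.1, deterministic=True), x))  # True

y = dropout(x, 0.1, deterministic=False, rng=np.random.default_rng(0))
print(float((y == 0).mean()))  # fraction of zeroed elements, close to 0.1
print(float(y.mean()))         # close to 1.0, the mean of x
```

Rescaling at training time is what lets the deterministic path skip any correction at inference time.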