flax_extra.layer.io module

Input and output encoding utility functions.

flax_extra.layer.io.input_encoding(*modalities: Callable[..., flax_extra.layer._encoding.Encoding], **kvargs: Any) → Callable[..., flax_extra.layer._multimodal_encoding.MultimodalEncoding]

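Defines an encoding for single-modality or multimodal input. Each positional argument is a factory that returns one modality's Encoding; the function returns a factory for a MultimodalEncoding.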
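The signature describes a factory of factories: each modality is passed as a callable that builds an Encoding, and what comes back is a callable that builds a MultimodalEncoding. The sketch below mirrors that shape with functools.partial. Encoding and MultimodalEncoding here are hypothetical stand-ins, not flax_extra's classes, and the way **kvargs is forwarded (and the dropout_rate option) is an assumption for illustration only.

```python
from functools import partial
from typing import Any, Callable, Sequence


class Encoding:
    """Hypothetical stand-in for one modality's encoding module."""

    def __init__(self, name: str, dim: int) -> None:
        self.name = name
        self.dim = dim


class MultimodalEncoding:
    """Hypothetical stand-in that owns one Encoding per modality."""

    def __init__(self, encodings: Sequence[Callable[..., Encoding]], **kvargs: Any) -> None:
        # Build each modality's encoding from its zero-argument factory.
        self.encodings = [make() for make in encodings]
        # Keep any extra keyword options (their meaning is assumed here).
        self.options = dict(kvargs)


def input_encoding(*modalities: Callable[..., Encoding], **kvargs: Any) -> Callable[..., MultimodalEncoding]:
    # Nothing is built yet: return a factory that will construct the
    # multimodal encoding from the per-modality factories on demand.
    return partial(MultimodalEncoding, encodings=modalities, **kvargs)


# Usage: one factory per modality, plus a hypothetical keyword option.
make_encoding = input_encoding(
    partial(Encoding, name="image", dim=64),
    partial(Encoding, name="audio", dim=32),
    dropout_rate=0.1,
)
encoding = make_encoding()
print([e.name for e in encoding.encodings])  # ['image', 'audio']
```

Returning a factory rather than an instance lets the caller decide when, and how many times, the module is built, a common pattern when configuring Flax modules. Judging by their signatures, output_decoding and target_encoding follow the same shape with Decoding/MultimodalDecoding and Encoding/MultimodalEncoding respectively.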

flax_extra.layer.io.output_decoding(*modalities: Callable[..., flax_extra.layer._decoding.Decoding], **kvargs: Any) → Callable[..., flax_extra.layer._multimodal_decoding.MultimodalDecoding]

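Defines a decoding for single-modality or multimodal output. Each positional argument is a factory that returns one modality's Decoding; the function returns a factory for a MultimodalDecoding.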

flax_extra.layer.io.query_encoding(*modalities: Callable[..., Callable[[int, Optional[List[int]]], jax._src.numpy.lax_numpy.ndarray]], **kvargs: Any) → Callable[..., flax_extra.layer._multimodal_positional_encoding.MultimodalPositionalEncoding]

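Defines a query for a decoder or an encoder. Each positional argument is a factory that returns a positional-encoding function taking an int and an optional list of ints and returning an array; the function returns a factory for a MultimodalPositionalEncoding.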
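The reference does not say what the int and the optional list of ints mean. The sketch below assumes they are a batch size and optional index dimensions (for example an image's height and width), and builds a sinusoidal position table of the kind used in Transformers, with sine and cosine halves concatenated. sinusoidal_query and default_index_dims are hypothetical names, not part of flax_extra. The sketch uses NumPy so it runs without JAX; jax.numpy provides the same calls.

```python
from typing import Callable, List, Optional

import numpy as np


def sinusoidal_query(
    dim: int, default_index_dims: List[int]
) -> Callable[[int, Optional[List[int]]], np.ndarray]:
    """Hypothetical factory for a positional query function."""
    if dim % 2 != 0:
        raise ValueError("dim must be even (half sine, half cosine)")

    def query(batch_size: int, index_dims: Optional[List[int]]) -> np.ndarray:
        # Fall back to the factory's shape when no index dims are given.
        dims = default_index_dims if index_dims is None else index_dims
        length = int(np.prod(dims))  # flatten e.g. [height, width] into one axis
        positions = np.arange(length, dtype=np.float32)[:, None]  # (length, 1)
        # Geometric frequencies 10000^(-2i/dim), one per sine/cosine pair.
        freqs = 10000.0 ** (-np.arange(0, dim, 2, dtype=np.float32) / dim)
        angles = positions * freqs[None, :]  # (length, dim // 2)
        table = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
        # Every batch element gets the same table: (batch_size, length, dim).
        return np.broadcast_to(table, (batch_size, length, dim))

    return query


make_query = sinusoidal_query(dim=8, default_index_dims=[4, 4])
print(make_query(2, None).shape)  # (2, 16, 8)
```

Under these assumptions, sinusoidal_query(...) plays the role of the inner callable type in the signature, and sinusoidal_query itself (or a functools.partial of it) would be the factory passed as one modality.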

flax_extra.layer.io.target_encoding(*modalities: Callable[..., flax_extra.layer._encoding.Encoding], **kvargs: Any) → Callable[..., flax_extra.layer._multimodal_encoding.MultimodalEncoding]

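Defines an encoding for single-modality or multimodal targets. Its signature matches input_encoding: each positional argument is a factory that returns one modality's Encoding, and the function returns a factory for a MultimodalEncoding.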