# Source code for flax_extra.layer._decoding
r"""Trainable decoding."""
from typing import cast, Callable
from jax import numpy as jnp
from flax import linen as nn
from flax_extra import combinator as cb
# JAX array type used throughout the decoding layer.
Array = jnp.ndarray

# Signatures of the decoding steps: one array in, one array out.
EmbeddingDecodingFn = Callable[[Array], Array]
PostprocessingFn = Callable[[Array], Array]

# Constructors (called with arbitrary arguments) producing the steps above.
EmbeddingDecodingCt = Callable[..., EmbeddingDecodingFn]
PostprocessingCt = Callable[..., PostprocessingFn]
class Decoding(nn.Module):
    r"""Decodes a vector at each position (i.e. time step)
    of the sequence using an optional embedding decoding
    and postprocessing step.

    The postprocessing step is intended to reshape the decoded
    sequence into the desired output shape.

    .. math::

        \begin{aligned}
        & \textrm{Decoding}( \\
        & \quad y \in \sR^{\nBatchSize \times \nSeqLen \times d} \\
        & \quad \_ \\
        & \quad \theta \gets EmbeddingDecoding() \\
        & \quad \theta \gets Postprocessing() \\
        & ) \\
        & \rightarrow \sR^{\nBatchSize \times \dots d^{\prime}}
        \end{aligned}

    Args:
        outputs: model's outputs.

    Returns:
        decoded outputs.
    """

    # `cast` only tells the type checker that the default lambda satisfies
    # the constructor alias; it has no effect at runtime.
    embedding_decoding: EmbeddingDecodingCt = cast(
        EmbeddingDecodingCt,
        lambda: cb.identity(n_in=1),
    )
    r"""a zero-argument constructor of the embedding decoder
    (e.g. :class:`flax_extra.layer.EmbedDecoding`, or
    :class:`flax.linen.Dense` with its arguments bound via
    :func:`functools.partial`). Defaults to ``cb.identity(n_in=1)``."""

    postprocessing: PostprocessingCt = cast(
        PostprocessingCt,
        lambda: cb.identity(n_in=1),
    )
    r"""a zero-argument constructor of a module or function that accepts
    the decoded outputs as its single argument and returns the
    postprocessed outputs as a single output.
    Defaults to ``cb.identity(n_in=1)``."""

    @nn.compact
    def __call__(self, outputs: Array) -> Array:  # type: ignore[override] # pylint: disable=arguments-differ
        """Applies the embedding decoding followed by the postprocessing.

        Args:
            outputs: model's outputs.

        Returns:
            the decoded and postprocessed outputs.
        """
        # Both constructors are invoked inside the compact scope (decoder
        # first, then postprocessor), so any Flax modules they return are
        # attached to this module as submodules.
        decoder = self.embedding_decoding()
        decoded_outputs = decoder(outputs)
        postprocessor = self.postprocessing()
        return postprocessor(decoded_outputs)
# Type of a constructor producing a `Decoding` module (e.g. the class itself,
# or a `functools.partial` of it with some fields pre-bound).
DecodingCt = Callable[..., Decoding]