# Source code for flax_extra.layer._embedding_decoding
r"""Embedding decoding."""
from jax import numpy as jnp
from flax import linen as nn
Array = jnp.ndarray
class EmbedDecoding(nn.Module):
    r"""Projects encoded vectors at each position (i.e. time step)
    of the sequence to its original representation using the projection
    tensor of the embedding encoder.

    .. math::

        \begin{aligned}
            & \textrm{Decoding}( \\
            & \quad y \in \sR^{\nBatchSize \times \nSeqLen \times d} \\
            & \quad \_ \\
            & \quad b_{x} \in \sR^{d_{x}} \\
            & ) \\
            & \rightarrow \sR^{\nBatchSize \times \nSeqLen \times d_{x}}
        \end{aligned}

    Args:
        embeddings: model's outputs. The last axis must have the same
            size as the second axis of ``embedding``; any number of
            leading axes (typically batch and sequence) is accepted.

    Returns:
        decoded outputs with the same leading axes as ``embeddings``
        and a last axis whose size equals the first axis of
        ``embedding``.
    """

    embedding: Array
    r"""embedding projection tensor of the embedding encoder
    (e.g. :class:`flax.linen.Embed`)."""

    @nn.compact
    def __call__(self, embeddings: Array) -> Array:  # type: ignore[override] # pylint: disable=arguments-differ
        vocab_size, d_model = self.embedding.shape
        # Keep every leading axis so the result can be restored to the
        # caller's layout regardless of the input rank.
        leading_shape = tuple(embeddings.shape[:-1])

        # Collapse the leading axes into one so the projection is a single
        # 2-D matmul against the transposed embedding table.
        flat = jnp.reshape(embeddings, (-1, d_model))
        output = jnp.matmul(flat, jnp.transpose(self.embedding))

        # Learnable bias with one entry per row of the embedding table,
        # initialised to zeros.
        bias = self.param(
            "bias",
            nn.zeros,
            (vocab_size,),
            jnp.float32,
        )
        output = output + bias

        return jnp.reshape(output, leading_shape + (vocab_size,))  # type: ignore