Source code for flax_extra.layer._trainable_positional_padding
r"""Trainable positional padding."""
from typing import List
from jax import numpy as jnp
from flax import linen as nn
from flax_extra.layer._trainable_positional_encoding import TrainablePositionalEncoding
# Shorthand for the JAX array type used in the annotations below.
Array = jnp.ndarray
# A list of integer positions; not referenced by the code visible in this module.
Positions = List[int]
class TrainablePositionalPadding(nn.Module):
    r"""Trainable positional padding.

    Learns a padding vector and appends it to the vector at each position
    (i.e. time step) of the sequence, padding the vector's dimension up
    to the desired value.

    .. math::

        \begin{aligned}
            & \textrm{TrainablePositionalPadding}( \\
            & \quad x \in \sR^{\nBatchSize \times T \times d} \\
            & \quad \_ \\
            & \quad w \in \sR^{1 \times d^{\prime}} \\
            & ) \\
            & \rightarrow \sR^{\nBatchSize \times T \times (d + d^{\prime})}
        \end{aligned}

    where :math:`d^{\prime}` is ``d_max`` minus the feature dimension
    :math:`d` of the inputs.

    Args:
        inputs: a sequence.

    Returns:
        a sequence with padding along feature dimension.

    Raises:
        ValueError: if the feature dimension of ``inputs`` is greater
            than ``d_max``.
    """

    d_max: int
    r"""a desired dimension for vectors in the sequence."""

    @nn.compact
    def __call__(self, inputs: Array) -> Array:  # type: ignore[override] # pylint: disable=arguments-differ
        batch_size = inputs.shape[0]
        d_input = inputs.shape[-1]
        d_padding = self.d_max - d_input

        # Shapes are static Python integers, so this check is safe under
        # tracing. Without it a negative dimension would be forwarded to
        # the encoding module and fail later with an unrelated shape error.
        if d_padding < 0:
            raise ValueError(
                f"Input feature dimension ({d_input}) exceeds "
                f"d_max ({self.d_max}); cannot pad to a smaller dimension."
            )

        # A single learned position (seqlen=1) that is shared by every
        # position of the sequence through the broadcast below.
        padding = TrainablePositionalEncoding(
            seqlen=1,
            dimension=d_padding,
        )

        # Same shape as the inputs, except the feature axis holds only the
        # padding features.
        pad_tokens_shape = list(inputs.shape)
        pad_tokens_shape[-1] = d_padding

        pad_tokens = jnp.broadcast_to(
            padding(batch_size=batch_size),
            shape=pad_tokens_shape,
        )

        # Append the learned padding to every vector along the feature axis.
        padded_inputs = jnp.concatenate(
            [inputs, pad_tokens],
            axis=-1,
        )
        return padded_inputs  # type: ignore