Source code for flax_extra.layer._trainable_positional_masking
r"""Trainable positional masking."""
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax_extra.layer._trainable_positional_encoding import TrainablePositionalEncoding
# Short alias for the JAX array type, used in annotations in this module.
Array = jax.numpy.ndarray
class TrainablePositionalMasking(nn.Module):
    r"""Learns a mask vector and substitutes it for the input vector at
    randomly chosen positions (i.e. time steps) of a sequence.

    Each position is masked independently with probability :attr:`rate`.

    .. math::
        \begin{aligned}
        & \textrm{TrainablePositionalMasking}( \\
        & \quad x \in \sR^{\nBatchSize \times T \times d} \\
        & \quad \_ \\
        & \quad w \in \sR^{1 \times d} \\
        & ) \\
        & \rightarrow \sR^{\nBatchSize \times T \times d}
        \end{aligned}

    Args:
        inputs: a sequence of shape ``(batch, length, dimension)``.
        deterministic: if ``True``, ``inputs`` are returned unchanged and no
            random key is drawn. The learned mask vector is still created,
            so the parameter tree is the same in both modes.

    Returns:
        a masked sequence. Masked positions hold the learned mask vector;
        all other positions hold the corresponding input vector.

    Raises:
        ValueError: if ``inputs`` is not a rank-3 array.
    """

    rate: float
    r"""a probability for a vector at each position being masked."""

    rng_collection: str = "params"
    r"""name of the PRNG stream used to sample the masked positions."""

    @nn.compact
    def __call__(  # type: ignore[override] # pylint: disable=arguments-differ
        self,
        inputs: Array,
        *,
        deterministic: bool = False,
    ) -> Array:
        if inputs.ndim != 3:
            raise ValueError(
                "TrainablePositionalMasking expects inputs of shape "
                f"(batch, length, dimension), got shape {inputs.shape}."
            )
        batch_size, seqlen_input, d_input = inputs.shape

        # The learned mask vector is built on every call, including
        # deterministic ones, so that initialization always creates its
        # parameter. Its exact shape comes from TrainablePositionalEncoding;
        # presumably (batch, 1, dimension), broadcast over positions below.
        mask_token = TrainablePositionalEncoding(
            seqlen=1,
            dimension=d_input,
        )(batch_size=batch_size)

        if deterministic:
            return inputs

        # One Bernoulli(rate) draw per (batch, position). The trailing axis
        # lets the boolean mask broadcast over the feature dimension.
        mask = jax.random.bernoulli(
            self.make_rng(self.rng_collection),
            self.rate,
            shape=(batch_size, seqlen_input),
        )[:, :, jnp.newaxis]

        # Select instead of blending arithmetically. `(1 - m) * x + m * w`
        # turns non-finite inputs at masked positions into NaN (0 * inf = nan)
        # and relies on bool-to-int arithmetic.
        return jnp.where(mask, mask_token, inputs)  # type: ignore