# Source code for flax_extra.model.rnn._lm
r"""RNN language model."""
from typing import Callable
from flax import linen as nn
from flax.linen.recurrent import Array
from flax_extra import combinator as cb
from flax_extra import layer as xn
from flax_extra import operator as xp
# Type of an RNN layer factory: any callable that returns an `nn.Module`.
RNNT = Callable[..., nn.Module]
class RNNLM(nn.Module):
    r"""RNN language model.

    The model is a serial chain of: a right shift of the inputs along
    axis 1, a token embedding, ``n_layers`` recurrent layers, dropout and
    a dense projection with ``vocab_size`` output features.
    """

    vocab_size: int
    r"""size of the vocabulary."""

    d_model: int = 512
    r"""depth of the model."""

    n_layers: int = 2
    r"""a number of RNN layers."""

    dropout_rate: float = 0.1
    r"""probabilistic rate for dropout."""

    rnn_type: RNNT = xn.LSTM
    r"""a type of the RNN."""

    deterministic: bool = True
    r"""whether to perform deterministically or not."""

    @nn.compact
    def __call__(self, inputs: Array) -> Array:  # type: ignore[override] # pylint: disable=arguments-differ
        r"""Run the language model on ``inputs``.

        Args:
            inputs: array fed to the first layer of the chain. Presumably
                integer token ids with the sequence on axis 1, since they
                are shifted along that axis and then embedded -- confirm
                against the callers.

        Returns:
            The output of the final dense layer, which has ``vocab_size``
            features.
        """
        # NOTE: the layers are constructed in the order they are applied.
        # Keep this order stable: Flax auto-names submodules created in a
        # compact method, so reordering may rename the parameters.
        return cb.serial(
            # NOTE(review): presumably shifts the sequence by one position so
            # each step only sees preceding tokens -- confirm in xp.ShiftRight.
            xp.ShiftRight(axis=1),
            nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.d_model,
                embedding_init=nn.initializers.normal(stddev=1.0),
            ),
            # One RNN layer per requested depth, each with `d_model` hidden units.
            [self.rnn_type(d_hidden=self.d_model) for _ in range(self.n_layers)],
            nn.Dropout(
                rate=self.dropout_rate,
                deterministic=self.deterministic,
            ),
            # Project back to one output feature per vocabulary entry.
            nn.Dense(features=self.vocab_size),
        )(inputs)