RNN LM with Flax and Redex combinators

The Redex library makes it possible to design Flax layers and models using combinators.

[2]:
import logging
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax.linen.recurrent import Array
from flax_extra import combinator as cb
from flax_extra import operator as op
from flax_extra import random

Combinators provide a simple and concise way of composing functional code. They can compose, be used by, or be mixed with Flax linen modules, other combinators, and standard Python functions.

[3]:
# Carry threaded through the recurrence: a pair of arrays.
LSTMState = tuple[Array, Array]

class LSTMCell(nn.LSTMCell):
    """`nn.LSTMCell` whose `__call__` carries explicit type annotations.

    The computation is delegated unchanged to the parent class; only the
    annotated signature is new.

    NOTE(review): presumably the combinators read these annotations to work
    out how many values the cell consumes and produces -- confirm before
    altering or removing them.
    """

    def __call__(self, carry: LSTMState, inputs: Array) -> tuple[LSTMState, Array]:
        """Run one step of the parent cell and return `(new carry, outputs)`."""
        next_carry, outputs = super().__call__(carry, inputs)
        return next_carry, outputs

class LSTM(nn.Module):
    """Single LSTM layer that scans an `LSTMCell` along axis 1 of its input.

    The starting carry is created inside the layer and the final carry is
    dropped, so callers pass one array in and get one array back.
    """

    # Size argument handed to `nn.LSTMCell.initialize_carry`.
    d_hidden: int

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Apply the layer to `inputs` and return the scanned outputs."""
        # Place a freshly initialised carry in front of the untouched inputs.
        with_carry = cb.branch(self.initial_state, cb.identity(n_in=1))
        # Scan the cell over axis 1; "params" are broadcast (shared) across
        # steps and the "params" RNG is not split per step.
        scanned_cell = nn.scan(
            LSTMCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=1,
            out_axes=1,
        )()
        # Discard the two leading values (the carry pair), keeping the outputs.
        without_carry = cb.drop(n_in=2)
        return cb.serial(with_carry, scanned_cell, without_carry)(inputs)

    def initial_state(self, inputs: Array) -> LSTMState:
        """Build the initial carry for a batch of size `inputs.shape[0]`.

        Draws its key from the "carry" RNG stream, so callers must supply
        a "carry" entry in `rngs`.
        """
        return nn.LSTMCell.initialize_carry(
            self.make_rng("carry"),
            (inputs.shape[0],),
            self.d_hidden,
        )

class RNNLM(nn.Module):
    """Recurrent language model.

    Pipeline: right shift -> embedding -> stacked LSTM layers -> dropout ->
    dense projection of width `vocab_size`.
    """

    # Number of embedding rows and width of the final Dense layer.
    vocab_size: int
    # Embedding width; also used as the hidden size of every LSTM layer.
    d_model: int = 512
    # Number of LSTM layers in the stack.
    n_layers: int = 2
    # Rate forwarded to `nn.Dropout`.
    dropout_rate: float = 0.1
    # Forwarded to `nn.Dropout` as its `deterministic` flag.
    deterministic: bool = True

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Run the model on `inputs` and return the final Dense layer's output."""
        # NOTE(review): presumably shifts the sequence one step along axis 1
        # so each position only sees earlier tokens -- confirm in flax_extra.
        shift = op.ShiftRight(axis=1)
        embed = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.d_model,
            embedding_init=nn.initializers.normal(stddev=1.0),
        )
        # Kept as a list, exactly as the combinator received it originally.
        recurrent_stack = [LSTM(d_hidden=self.d_model) for _ in range(self.n_layers)]
        dropout = nn.Dropout(
            rate=self.dropout_rate,
            deterministic=self.deterministic,
        )
        project = nn.Dense(features=self.vocab_size)
        return cb.serial(shift, embed, recurrent_stack, dropout, project)(inputs)
[4]:
# Deliberately tiny sizes so the demo initialises and runs instantly.
BATCH_SIZE = 1
MAX_LENGTH = 4
VOCAB_SIZE = 3
D_MODEL = 2

model = RNNLM(d_model=D_MODEL, vocab_size=VOCAB_SIZE)

# JIT-compiled entry points for initialising and applying the model.
model_init, model_apply = jax.jit(model.init), jax.jit(model.apply)

# Seeded key generator; each `next(rnkeyg)` yields the next key in the sequence.
rnkeyg = random.sequence(seed=1356)
# Dummy batch of token ids (all ones) with shape (BATCH_SIZE, MAX_LENGTH).
sample = jnp.ones(shape=(BATCH_SIZE, MAX_LENGTH), dtype=int)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

With debug logging enabled, we can inspect how data flows through the combinators.

[5]:
# Raise the root logger to DEBUG only while the model is initialised, so the
# combinators' call trace is printed for this one step.
logging.getLogger().setLevel("DEBUG")
try:
    initial_variables = model_init(
        rngs=random.into_collection(
            key=next(rnkeyg),
            labels=["params", "carry", "dropout"],
        ),
        inputs=sample,
    )
finally:
    # Always lower verbosity again, even if initialisation raises; otherwise
    # every later cell in the kernel session would keep logging at DEBUG.
    logging.getLogger().setLevel("INFO")
DEBUG:absl:Compiling _threefry_split (7090507712) for args (ShapedArray(uint32[2]),).
DEBUG:absl:Compiling _threefry_split (7090529152) for args (ShapedArray(uint32[2]),).
DEBUG:root:constrained_call :: ShiftRight           stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Embed                stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: LSTM                 stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Serial               stack_size=1  signature=Signature(n_in=1, n_out=3, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Select               stack_size=1  signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Parallel             stack_size=2  signature=Signature(n_in=2, n_out=3, start_index=0, in_shape=((), ()))
DEBUG:root:constrained_call :: initial_state        stack_size=1  signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Identity             stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=1, in_shape=((),))
DEBUG:root:constrained_call :: ScanLSTMCell         stack_size=3  signature=Signature(n_in=3, n_out=3, start_index=0, in_shape=(((), ()), ()))
DEBUG:root:constrained_call :: Drop                 stack_size=3  signature=Signature(n_in=2, n_out=0, start_index=0, in_shape=((), ()))
DEBUG:root:constrained_call :: LSTM                 stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Serial               stack_size=1  signature=Signature(n_in=1, n_out=3, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Select               stack_size=1  signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Parallel             stack_size=2  signature=Signature(n_in=2, n_out=3, start_index=0, in_shape=((), ()))
DEBUG:root:constrained_call :: initial_state        stack_size=1  signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Identity             stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=1, in_shape=((),))
DEBUG:root:constrained_call :: ScanLSTMCell         stack_size=3  signature=Signature(n_in=3, n_out=3, start_index=0, in_shape=(((), ()), ()))
DEBUG:root:constrained_call :: Drop                 stack_size=3  signature=Signature(n_in=2, n_out=0, start_index=0, in_shape=((), ()))
DEBUG:root:constrained_call :: Dropout              stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Dense                stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:absl:Compiling init (7090507328) for args (ShapedArray(int32[1,4]), ShapedArray(uint32[2]), ShapedArray(uint32[2]), ShapedArray(uint32[2])).
[6]:
# Forward pass on the dummy batch using the freshly initialised variables.
# A new key is drawn from the generator for the RNG collections of this call.
model_apply(
    initial_variables,
    rngs=random.into_collection(
        labels=["params", "carry", "dropout"],
        key=next(rnkeyg),
    ),
    inputs=sample,
)
[6]:
DeviceArray([[[-0.00180463, -0.00309571, -0.00328913],
              [-0.00401801, -0.00492572, -0.0054962 ],
              [-0.00640121, -0.00592831, -0.0069736 ],
              [-0.00915737, -0.00677909, -0.00839544]]], dtype=float32)