RNN LM with Flax and Redex combinators
The Redex library makes it possible to design Flax layers and models using combinators.
[2]:
import logging
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax.linen.recurrent import Array
from flax_extra import combinator as cb
from flax_extra import operator as op
from flax_extra import random
Combinators provide a simple and concise way of composing functional code. They may compose, be used by, or be mixed with Flax Linen modules, other combinators, or standard Python functions.
[3]:
# Recurrent carry passed between steps: a pair of arrays.
LSTMState = tuple[Array, Array]


class LSTMCell(nn.LSTMCell):
    """`nn.LSTMCell` whose `__call__` carries explicit type annotations.

    No behaviour is added; the call is forwarded unchanged to the parent class.

    NOTE(review): the annotations are presumably what lets the combinators
    work out how many inputs and outputs the cell has -- confirm against the
    flax_extra documentation.
    """

    def __call__(self, carry: LSTMState, inputs: Array) -> tuple[LSTMState, Array]:
        """Run one step of the parent cell.

        Args:
            carry: the recurrent state handed over from the previous step.
            inputs: the input for the current step.

        Returns:
            Whatever `nn.LSTMCell.__call__` returns, annotated here as a
            `(new_carry, output)` pair.
        """
        step_result = super().__call__(carry, inputs)
        return step_result
class LSTM(nn.Module):
    """LSTM layer assembled from combinators.

    The layer builds an initial carry from the inputs, scans an `LSTMCell`
    along axis 1 of the inputs, and then drops two leading values from the
    result (presumably the final carry pair, leaving only the per-step
    outputs -- confirm against `cb.drop` semantics).
    """

    # Size of the hidden state passed to `initialize_carry`.
    d_hidden: int

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Apply the scanned LSTM cell to `inputs`.

        Args:
            inputs: input array; axis 0 is used as the batch size and axis 1
                is the axis the cell is scanned over.

        Returns:
            The output of the combinator pipeline described on the class.
        """
        # Lift the cell over axis 1; parameters are broadcast (shared) across
        # steps and the "params" RNG stream is not split per step.
        scanned_cell_cls = nn.scan(
            LSTMCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=1,
            out_axes=1,
        )
        pipeline = cb.serial(
            # Produce the initial carry next to the untouched inputs.
            cb.branch(self.initial_state, cb.identity(n_in=1)),
            scanned_cell_cls(),
            cb.drop(n_in=2),
        )
        return pipeline(inputs)

    def initial_state(self, inputs: Array) -> LSTMState:
        """Create the initial carry for a batch.

        Args:
            inputs: array whose leading dimension is taken as the batch size.

        Returns:
            The carry produced by `nn.LSTMCell.initialize_carry`, using a key
            drawn from the "carry" RNG stream.
        """
        batch_dims = (inputs.shape[0],)
        carry_rng = self.make_rng("carry")
        return nn.LSTMCell.initialize_carry(carry_rng, batch_dims, self.d_hidden)
class RNNLM(nn.Module):
    """Recurrent language model built as a serial chain of layers.

    The chain is: `op.ShiftRight` along axis 1, a token embedding, `n_layers`
    `LSTM` layers, dropout, and a dense projection to `vocab_size` features.
    """

    # Number of embeddings and width of the final dense projection.
    vocab_size: int
    # Embedding width; also used as the hidden size of every LSTM layer.
    d_model: int = 512
    # Number of stacked LSTM layers.
    n_layers: int = 2
    # Rate handed to `nn.Dropout`.
    dropout_rate: float = 0.1
    # Forwarded to `nn.Dropout(deterministic=...)`.
    deterministic: bool = True

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Run the model on `inputs`.

        Args:
            inputs: array fed to the first layer of the chain (integer token
                ids, given that it goes through `nn.Embed`).

        Returns:
            The output of the final dense layer.
        """
        shift = op.ShiftRight(axis=1)
        embed = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.d_model,
            embedding_init=nn.initializers.normal(stddev=1.0),
        )
        # Kept as a nested list: it is passed to `cb.serial` as one argument.
        lstm_stack = [LSTM(d_hidden=self.d_model) for _ in range(self.n_layers)]
        dropout = nn.Dropout(
            rate=self.dropout_rate,
            deterministic=self.deterministic,
        )
        to_vocab = nn.Dense(features=self.vocab_size)

        model = cb.serial(shift, embed, lstm_stack, dropout, to_vocab)
        return model(inputs)
[4]:
# Small dimensions for the demonstration.
BATCH_SIZE = 1
MAX_LENGTH = 4
VOCAB_SIZE = 3
D_MODEL = 2

model = RNNLM(vocab_size=VOCAB_SIZE, d_model=D_MODEL)

# JIT-compiled entry points for initializing and applying the model.
model_init, model_apply = jax.jit(model.init), jax.jit(model.apply)

# Seeded key source; later cells pull one key per call with `next(rnkeyg)`.
rnkeyg = random.sequence(seed=1356)

# Dummy batch of all-ones integers with shape (BATCH_SIZE, MAX_LENGTH).
sample = jnp.ones((BATCH_SIZE, MAX_LENGTH), dtype=int)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
With debug logging enabled, we can inspect how data flows within the combinators.
[5]:
# Switch the root logger to DEBUG only while the model is being initialized,
# so that the combinators' `constrained_call` trace is emitted for this call.
root_logger = logging.getLogger()
root_logger.setLevel("DEBUG")
try:
    initial_variables = model_init(
        rngs=random.into_collection(
            key=next(rnkeyg),
            labels=["params", "carry", "dropout"],
        ),
        inputs=sample,
    )
finally:
    # Always go back to INFO, even when initialization raises; otherwise a
    # failed run would leave DEBUG logging switched on for every later cell.
    root_logger.setLevel("INFO")
DEBUG:absl:Compiling _threefry_split (7090507712) for args (ShapedArray(uint32[2]),).
DEBUG:absl:Compiling _threefry_split (7090529152) for args (ShapedArray(uint32[2]),).
DEBUG:root:constrained_call :: ShiftRight stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Embed stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: LSTM stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Serial stack_size=1 signature=Signature(n_in=1, n_out=3, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Select stack_size=1 signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Parallel stack_size=2 signature=Signature(n_in=2, n_out=3, start_index=0, in_shape=((), ()))
DEBUG:root:constrained_call :: initial_state stack_size=1 signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Identity stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=1, in_shape=((),))
DEBUG:root:constrained_call :: ScanLSTMCell stack_size=3 signature=Signature(n_in=3, n_out=3, start_index=0, in_shape=(((), ()), ()))
DEBUG:root:constrained_call :: Drop stack_size=3 signature=Signature(n_in=2, n_out=0, start_index=0, in_shape=((), ()))
DEBUG:root:constrained_call :: LSTM stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Serial stack_size=1 signature=Signature(n_in=1, n_out=3, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Select stack_size=1 signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Parallel stack_size=2 signature=Signature(n_in=2, n_out=3, start_index=0, in_shape=((), ()))
DEBUG:root:constrained_call :: initial_state stack_size=1 signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Identity stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=1, in_shape=((),))
DEBUG:root:constrained_call :: ScanLSTMCell stack_size=3 signature=Signature(n_in=3, n_out=3, start_index=0, in_shape=(((), ()), ()))
DEBUG:root:constrained_call :: Drop stack_size=3 signature=Signature(n_in=2, n_out=0, start_index=0, in_shape=((), ()))
DEBUG:root:constrained_call :: Dropout stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Dense stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:absl:Compiling init (7090507328) for args (ShapedArray(int32[1,4]), ShapedArray(uint32[2]), ShapedArray(uint32[2]), ShapedArray(uint32[2])).
[6]:
# Draw a fresh key and turn it into one RNG per labelled collection.
apply_rngs = random.into_collection(
    key=next(rnkeyg),
    labels=["params", "carry", "dropout"],
)
# Forward pass on the sample batch; left as the last expression so the
# notebook displays the result.
model_apply(initial_variables, inputs=sample, rngs=apply_rngs)
[6]:
DeviceArray([[[-0.00180463, -0.00309571, -0.00328913],
[-0.00401801, -0.00492572, -0.0054962 ],
[-0.00640121, -0.00592831, -0.0069736 ],
[-0.00915737, -0.00677909, -0.00839544]]], dtype=float32)