{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# RNN LM with Flax and Redex combinators"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "The [Redex](https://github.com/manifest/redex) library makes it possible to design [Flax](https://github.com/google/flax) layers and models using combinators."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "import logging\n",
    "import jax\n",
    "from jax import numpy as jnp\n",
    "from flax import linen as nn\n",
    "from flax.linen.recurrent import Array\n",
    "from flax_extra import combinator as cb\n",
    "from flax_extra import operator as op\n",
    "from flax_extra import random"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Combinators provide a simple and concise way of composing functional code. They can compose, be used by, or be mixed with Flax linen modules, other combinators, or standard Python functions."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "LSTMState = tuple[Array, Array]\n",
    "\n",
    "\n",
    "class LSTMCell(nn.LSTMCell):\n",
    "    \"\"\"`nn.LSTMCell` whose `__call__` carries explicit type annotations.\"\"\"\n",
    "\n",
    "    def __call__(self, carry: LSTMState, inputs: Array) -> tuple[LSTMState, Array]:\n",
    "        # The computation is delegated unchanged to the parent cell.\n",
    "        # NOTE(review): presumably the combinators read these annotations to\n",
    "        # work out how many inputs and outputs the cell has -- confirm.\n",
    "        return super().__call__(carry, inputs)\n",
    "\n",
    "\n",
    "class LSTM(nn.Module):\n",
    "    \"\"\"Recurrent layer that scans an `LSTMCell` along axis 1 of its input.\n",
    "\n",
    "    Attributes:\n",
    "        d_hidden: size handed to `nn.LSTMCell.initialize_carry` when the\n",
    "            initial carry is created.\n",
    "    \"\"\"\n",
    "\n",
    "    d_hidden: int\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, inputs: Array) -> Array:\n",
    "        \"\"\"Apply the layer to `inputs` and return what is left after the carry is dropped.\"\"\"\n",
    "        # Produce the initial carry alongside the untouched inputs.\n",
    "        with_carry = cb.branch(self.initial_state, cb.identity(n_in=1))\n",
    "        # Scan the cell over axis 1; the parameters are broadcast (shared)\n",
    "        # across steps and the \"params\" RNG is not split.\n",
    "        recurrence = nn.scan(\n",
    "            LSTMCell,\n",
    "            variable_broadcast=\"params\",\n",
    "            split_rngs={\"params\": False},\n",
    "            in_axes=1,\n",
    "            out_axes=1,\n",
    "        )()\n",
    "        # Drop two leading values (the carry pair), keeping the rest.\n",
    "        without_carry = cb.drop(n_in=2)\n",
    "        return cb.serial(with_carry, recurrence, without_carry)(inputs)\n",
    "\n",
    "    def initial_state(self, inputs: Array) -> LSTMState:\n",
    "        \"\"\"Create the initial carry for a batch of `inputs.shape[0]` items.\"\"\"\n",
    "        batch_dims = (inputs.shape[0],)\n",
    "        return nn.LSTMCell.initialize_carry(\n",
    "            self.make_rng(\"carry\"),\n",
    "            batch_dims,\n",
    "            self.d_hidden,\n",
    "        )\n",
    "\n",
    "\n",
    "class RNNLM(nn.Module):\n",
    "    \"\"\"Recurrent language model expressed as a serial chain of layers.\n",
    "\n",
    "    Attributes:\n",
    "        vocab_size: number of embeddings and number of output features.\n",
    "        d_model: embedding width, also used as the hidden size of every LSTM.\n",
    "        n_layers: number of stacked `LSTM` layers.\n",
    "        dropout_rate: rate passed to `nn.Dropout`.\n",
    "        deterministic: flag passed to `nn.Dropout`.\n",
    "    \"\"\"\n",
    "\n",
    "    vocab_size: int\n",
    "    d_model: int = 512\n",
    "    n_layers: int = 2\n",
    "    dropout_rate: float = 0.1\n",
    "    deterministic: bool = True\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, inputs: Array) -> Array:\n",
    "        \"\"\"Feed `inputs` through the chain and return the final dense layer's output.\"\"\"\n",
    "        # The layers are created in the same order in which they are chained.\n",
    "        shift = op.ShiftRight(axis=1)\n",
    "        embed = nn.Embed(\n",
    "            num_embeddings=self.vocab_size,\n",
    "            features=self.d_model,\n",
    "            embedding_init=nn.initializers.normal(stddev=1.0),\n",
    "        )\n",
    "        recurrent_stack = [LSTM(d_hidden=self.d_model) for _ in range(self.n_layers)]\n",
    "        dropout = nn.Dropout(\n",
    "            rate=self.dropout_rate,\n",
    "            deterministic=self.deterministic,\n",
    "        )\n",
    "        readout = nn.Dense(features=self.vocab_size)\n",
    "        return cb.serial(shift, embed, recurrent_stack, dropout, readout)(inputs)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "# Dimensions used throughout the example.\n",
    "MAX_LENGTH = 4\n",
    "VOCAB_SIZE = 3\n",
    "BATCH_SIZE = 1\n",
    "D_MODEL = 2\n",
    "\n",
    "model = RNNLM(vocab_size=VOCAB_SIZE, d_model=D_MODEL)\n",
    "\n",
    "# JIT-compiled entry points for initialization and for the forward pass.\n",
    "model_init = jax.jit(model.init)\n",
    "model_apply = jax.jit(model.apply)\n",
    "\n",
    "# Seeded key sequence; later cells draw keys from it with `next(...)`.\n",
    "rnkeyg = random.sequence(seed=1356)\n",
    "# Integer batch of ones with shape (BATCH_SIZE, MAX_LENGTH).\n",
    "sample = jnp.ones((BATCH_SIZE, MAX_LENGTH), dtype=int)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "With debug logging enabled, we can inspect how data flows within combinators."
],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "# Enable DEBUG logging only while the model is initialized so that the\n",
    "# combinators log how data flows through them. The `finally` clause makes\n",
    "# sure the verbose level is switched off again even if initialization fails.\n",
    "root_logger = logging.getLogger()\n",
    "root_logger.setLevel(\"DEBUG\")\n",
    "try:\n",
    "    initial_variables = model_init(\n",
    "        rngs=random.into_collection(\n",
    "            key=next(rnkeyg),\n",
    "            labels=[\"params\", \"carry\", \"dropout\"],\n",
    "        ),\n",
    "        inputs=sample,\n",
    "    )\n",
    "finally:\n",
    "    root_logger.setLevel(\"INFO\")"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "DEBUG:absl:Compiling _threefry_split (7090507712) for args (ShapedArray(uint32[2]),).\n",
      "DEBUG:absl:Compiling _threefry_split (7090529152) for args (ShapedArray(uint32[2]),).\n",
      "DEBUG:root:constrained_call :: ShiftRight stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: Embed stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: LSTM stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: Serial stack_size=1 signature=Signature(n_in=1, n_out=3, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: Select stack_size=1 signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: Parallel stack_size=2 signature=Signature(n_in=2, n_out=3, start_index=0, in_shape=((), ()))\n",
      "DEBUG:root:constrained_call :: initial_state stack_size=1 signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: Identity stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=1, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: ScanLSTMCell stack_size=3 signature=Signature(n_in=3, n_out=3, start_index=0, in_shape=(((), ()), ()))\n",
      "DEBUG:root:constrained_call :: Drop stack_size=3 signature=Signature(n_in=2, n_out=0, start_index=0, in_shape=((), ()))\n",
      "DEBUG:root:constrained_call :: LSTM stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: Serial stack_size=1 signature=Signature(n_in=1, n_out=3, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: Select stack_size=1 signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: Parallel stack_size=2 signature=Signature(n_in=2, n_out=3, start_index=0, in_shape=((), ()))\n",
      "DEBUG:root:constrained_call :: initial_state stack_size=1 signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: Identity stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=1, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: ScanLSTMCell stack_size=3 signature=Signature(n_in=3, n_out=3, start_index=0, in_shape=(((), ()), ()))\n",
      "DEBUG:root:constrained_call :: Drop stack_size=3 signature=Signature(n_in=2, n_out=0, start_index=0, in_shape=((), ()))\n",
      "DEBUG:root:constrained_call :: Dropout stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))\n",
      "DEBUG:root:constrained_call :: Dense stack_size=1 signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))\n",
      "DEBUG:absl:Compiling init (7090507328) for args (ShapedArray(int32[1,4]), ShapedArray(uint32[2]), ShapedArray(uint32[2]), ShapedArray(uint32[2])).\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "# Forward pass with the initialized variables; a fresh key is drawn from the\n",
    "# key sequence for this call.\n",
    "model_apply(\n",
    "    initial_variables,\n",
    "    inputs=sample,\n",
    "    rngs=random.into_collection(\n",
    "        key=next(rnkeyg),\n",
    "        labels=[\"params\", \"carry\", \"dropout\"],\n",
    "    ),\n",
    ")"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "DeviceArray([[[-0.00180463, -0.00309571, -0.00328913],\n",
       "              [-0.00401801, -0.00492572, -0.0054962 ],\n",
       "              [-0.00640121, -0.00592831, -0.0069736 ],\n",
       "              [-0.00915737, -0.00677909, -0.00839544]]], dtype=float32)"
      ]
     },
     "metadata": {},
     "execution_count": 6
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
"version": "3.9.7", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "kernelspec": { "name": "python3", "display_name": "Python 3.9.7 64-bit ('env': venv)" }, "interpreter": { "hash": "8b02d22ec999df61aa0506c585da9f18dfcc9eaff18f3e42f4c27a57de3e3420" } }, "nbformat": 4, "nbformat_minor": 2 }