{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Model training and evaluation"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "A complete example of model training and evaluation: a model is trained on a synthetic stream of random token ids and every checkpoint produced by the training loop is evaluated, logged and saved."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Setup"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "# Uncomment to emulate an arbitrary number of devices running on CPU.\n",
    "\n",
    "# import os\n",
    "# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "import jax\n",
    "from jax import numpy as jnp\n",
    "import optax\n",
    "from flax_extra import combinator as cb\n",
    "from flax_extra import random\n",
    "from flax_extra.training import TrainLoop, TrainTask\n",
    "from flax_extra.evaluation import EvalLoop, EvalTask\n",
    "from flax_extra.checkpoint import (\n",
    "    SummaryLogger,\n",
    "    SummaryWriter,\n",
    "    CheckpointFileReader,\n",
    "    CheckpointFileWriter,\n",
    "    LowestCheckpointFileReader,\n",
    "    LowestCheckpointFileWriter,\n",
    ")\n",
    "from flax_extra.model import RNNLM"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Configuration, data and loss\n",
    "\n",
    "The data is synthetic: each batch holds random token ids and the targets are equal to the inputs."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "MAX_LENGTH = 256\n",
    "BATCH_SIZE = 32\n",
    "VOCAB_SIZE = 2 ** 8\n",
    "D_MODEL = 128\n",
    "SEED = 0\n",
    "CHECKPOINT_DIR = \"/tmp/checkpoints\"\n",
    "TENSORBOARD_DIR = \"/tmp/tensorboard\"\n",
    "\n",
    "model = RNNLM\n",
    "\n",
    "collections = dict(\n",
    "    init=[\"params\", \"carry\", \"dropout\"],\n",
    "    apply=[\"params\", \"carry\", \"dropout\"],\n",
    ")\n",
    "config = dict(\n",
    "    vocab_size=VOCAB_SIZE,\n",
    "    d_model=D_MODEL,\n",
    "    n_layers=2,\n",
    ")\n",
    "\n",
    "\n",
    "def pseudo_data_stream(shape, bounds, rnkey):\n",
    "    \"\"\"Yield an endless stream of random integer batches.\n",
    "\n",
    "    Args:\n",
    "        shape: shape of every generated array.\n",
    "        bounds: a `(minval, maxval)` pair; values are drawn from\n",
    "            the half-open range `[minval, maxval)`.\n",
    "        rnkey: a JAX PRNG key the stream is derived from.\n",
    "\n",
    "    Yields:\n",
    "        `(x, x)` pairs (inputs and identical targets) where `x` is\n",
    "        an `int32` array of the requested shape.\n",
    "    \"\"\"\n",
    "    minval, maxval = bounds\n",
    "    while True:\n",
    "        # Derive a fresh subkey for every batch: sampling with the same\n",
    "        # key again would return the very same batch on each iteration.\n",
    "        rnkey, batch_key = jax.random.split(rnkey)\n",
    "        x = jax.random.randint(\n",
    "            key=batch_key,\n",
    "            shape=shape,\n",
    "            minval=minval,\n",
    "            maxval=maxval,\n",
    "            dtype=jnp.int32,\n",
    "        )\n",
    "        yield x, x\n",
    "\n",
    "\n",
    "# Backward-compatible alias for the original (misspelled) name.\n",
    "presudo_data_steam = pseudo_data_stream\n",
    "\n",
    "\n",
    "def categorical_cross_entropy(outputs, targets):\n",
    "    \"\"\"Mean softmax cross-entropy between scores and integer targets.\n",
    "\n",
    "    Args:\n",
    "        outputs: unnormalized scores (logits) whose last axis indexes\n",
    "            the categories.\n",
    "        targets: integer category ids; they are one-hot encoded to\n",
    "            `outputs.shape[-1]` categories.\n",
    "\n",
    "    Returns:\n",
    "        The mean of the element-wise losses as a scalar.\n",
    "    \"\"\"\n",
    "    n_categories = outputs.shape[-1]\n",
    "    loss = optax.softmax_cross_entropy(\n",
    "        outputs,\n",
    "        jax.nn.one_hot(targets, n_categories),\n",
    "    )\n",
    "    return jnp.mean(loss)\n",
    "\n",
    "\n",
    "# Every consumer of randomness takes its own key from this generator.\n",
    "rnkeyg = random.sequence(seed=SEED)\n",
    "train_datag = pseudo_data_stream(\n",
    "    shape=(BATCH_SIZE, MAX_LENGTH),\n",
    "    bounds=(1, VOCAB_SIZE),\n",
    "    rnkey=next(rnkeyg),\n",
    ")\n",
    "eval_datag = pseudo_data_stream(\n",
    "    shape=(BATCH_SIZE, MAX_LENGTH),\n",
    "    bounds=(1, VOCAB_SIZE),\n",
    "    rnkey=next(rnkeyg),\n",
    ")"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training and evaluation\n",
    "\n",
    "Every checkpoint yielded by the training loop is passed through the evaluation loop, the summary logger and writer, and the checkpoint file writers."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "train_loop = TrainLoop(\n",
    "    # The commented-out `init=` lines are alternative initializers;\n",
    "    # enable exactly one of the three.\n",
    "    # init=model(**config).init,\n",
    "    init=CheckpointFileReader(dir=CHECKPOINT_DIR, target=model(**config).init),\n",
    "    # init=LowestCheckpointFileReader(dir=CHECKPOINT_DIR, target=model(**config).init, metric=\"lnpp\"),\n",
    "    task=TrainTask(\n",
    "        # The `|` dict merge requires Python 3.9+.\n",
    "        apply=model(**config | dict(deterministic=False)).apply,\n",
    "        optimizer=optax.sgd(learning_rate=0.1, momentum=0.9),\n",
    "        loss=categorical_cross_entropy,\n",
    "        data=train_datag,\n",
    "    ),\n",
    "    collections=collections,\n",
    "    mutable_collections=True,\n",
    "    n_steps_per_checkpoint=5,\n",
    "    rnkey=next(rnkeyg),\n",
    "    n_steps=10,\n",
    ")\n",
    "\n",
    "# Stages are applied to each checkpoint in the order listed.\n",
    "process_checkpoint = cb.serial(\n",
    "    EvalLoop(\n",
    "        task=EvalTask(\n",
    "            apply=model(**config).apply,\n",
    "            metrics=dict(lnpp=categorical_cross_entropy),\n",
    "            data=eval_datag,\n",
    "        ),\n",
    "        collections=collections,\n",
    "        rnkey=next(rnkeyg),\n",
    "        n_steps=1,\n",
    "    ),\n",
    "    SummaryLogger(),\n",
    "    SummaryWriter(output_dir=TENSORBOARD_DIR),\n",
    "    CheckpointFileWriter(output_dir=CHECKPOINT_DIR),\n",
    "    LowestCheckpointFileWriter(output_dir=CHECKPOINT_DIR, metric=\"lnpp\"),\n",
    ")\n",
    "\n",
    "for checkpoint in train_loop:\n",
    "    _ = process_checkpoint(checkpoint)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Total model initialization time is 9.26 seconds.\n",
      "The lowest value for the metric eval/lnpp is set to inf.\n",
      "Total number of trainable weights: 328960 ~ 1.2 MB.\n",
      "\n",
      "Step 1: Ran 1 train steps in 6.62 seconds\n",
      "Step 1: train seconds_per_step | 6.61999798\n",
      "Step 1: train gradients_l2norm | 0.00298012\n",
      "Step 1: train weights_l2norm | 14.01909637\n",
      "Step 1: train loss | 5.54786873\n",
      "Step 1: eval lnpp | 5.54727793\n",
      "\n",
      "Step 6: Ran 5 train steps in 7.25 seconds\n",
      "Step 6: train seconds_per_step | 1.45005803\n",
      "Step 6: train gradients_l2norm | 0.00297097\n",
      "Step 6: train weights_l2norm | 14.02048969\n",
      "Step 6: train loss | 5.54690266\n",
      "Step 6: eval lnpp | 5.54723597\n",
      "\n",
      "Step 10: Ran 4 train steps in 1.33 seconds\n",
      "Step 10: train seconds_per_step | 0.33285725\n",
      "Step 10: train gradients_l2norm | 0.00292413\n",
      "Step 10: train weights_l2norm | 14.02241802\n",
      "Step 10: train loss | 5.54660368\n",
      "Step 10: eval lnpp | 5.54718399\n",
      "\n"
     ]
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.9.7",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.9.7 64-bit ('env': venv)"
  },
  "interpreter": {
   "hash": "8b02d22ec999df61aa0506c585da9f18dfcc9eaff18f3e42f4c27a57de3e3420"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}