Model training and evaluation

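A complete example of model training and evaluation: an RNN language model (`RNNLM`) is trained on randomly generated token sequences, and every checkpoint produced during training is evaluated, logged, and saved.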

[1]:
# Uncomment to emulate an arbitrary number of devices running on CPU.

# import os
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'
[2]:
import jax
from jax import numpy as jnp
import optax
from flax_extra import combinator as cb
from flax_extra import random
from flax_extra.training import TrainLoop, TrainTask
from flax_extra.evaluation import EvalLoop, EvalTask
from flax_extra.checkpoint import (
    SummaryLogger,
    SummaryWriter,
    CheckpointFileReader,
    CheckpointFileWriter,
    LowestCheckpointFileReader,
    LowestCheckpointFileWriter,
)
from flax_extra.model import RNNLM
[3]:
# Model and data sizes. Every tunable number lives here rather than inline.
MAX_LENGTH = 256       # tokens per sequence
BATCH_SIZE = 32        # sequences per batch
VOCAB_SIZE = 2 ** 8    # number of distinct token ids
D_MODEL = 128          # model width
N_LAYERS = 2           # number of stacked layers

model = RNNLM

# Flax variable collections used when initializing ("init") and applying
# ("apply") the model; handed to the training and evaluation loops.
# NOTE(review): "carry" presumably holds recurrent state and "dropout" the
# dropout RNG stream — confirm against flax_extra.
collections = dict(
    init=["params", "carry", "dropout"],
    apply=["params", "carry", "dropout"],
)

# Keyword arguments for constructing the model.
config = dict(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
)

def presudo_data_steam(shape, bounds, rnkey):
    """Yield an endless stream of random integer token batches.

    Each batch is used both as the model input and as the target, i.e. the
    generator yields ``(x, x)``.

    Args:
        shape: shape of every batch, e.g. ``(batch_size, max_length)``.
        bounds: ``(minval, maxval)``; token ids are drawn uniformly from the
            half-open range ``[minval, maxval)``.
        rnkey: a JAX PRNG key. It is split on every step so that successive
            batches differ.

    Yields:
        ``(x, x)`` where ``x`` is an ``int32`` array of the given ``shape``.
    """
    minval, maxval = bounds
    while True:
        # Split the key on every step: reusing one key would produce the
        # very same batch forever, so the model would see a single example.
        rnkey, subkey = jax.random.split(rnkey)
        x = jax.random.randint(
            key=subkey,
            shape=shape,
            minval=minval,
            maxval=maxval,
            dtype=jnp.int32,
        )
        yield x, x

def categorical_cross_entropy(outputs, targets):
    """Average softmax cross-entropy of logits against integer class labels.

    Args:
        outputs: unnormalized logits; the last axis indexes the categories.
        targets: integer category indices. They are one-hot encoded with a
            depth equal to the last axis of ``outputs``.

    Returns:
        The scalar mean of the per-position cross-entropy losses.
    """
    depth = outputs.shape[-1]
    one_hot_targets = jax.nn.one_hot(targets, depth)
    per_position_loss = optax.softmax_cross_entropy(outputs, one_hot_targets)
    return jnp.mean(per_position_loss)

# A single seeded key sequence drives all randomness in the notebook; each
# consumer draws its own key with next(rnkeyg). The training stream takes the
# first key, the evaluation stream the second.
# NOTE(review): bounds start at 1, presumably to keep token id 0 (padding?)
# out of the generated data — confirm.
rnkeyg = random.sequence(seed=0)
train_datag = presudo_data_steam(
    shape=(BATCH_SIZE, MAX_LENGTH),
    bounds=(1, VOCAB_SIZE),
    rnkey=next(rnkeyg),
)
eval_datag = presudo_data_steam(
    rnkey=next(rnkeyg),
    bounds=(1, VOCAB_SIZE),
    shape=(BATCH_SIZE, MAX_LENGTH),
)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[4]:
# Output locations for checkpoints and TensorBoard summaries.
CHECKPOINT_DIR = "/tmp/checkpoints"
TENSORBOARD_DIR = "/tmp/tensorboard"

# Training loop. `init` reads a checkpoint from CHECKPOINT_DIR through
# CheckpointFileReader, wrapping the model's own initializer as `target`
# (presumably used when no checkpoint exists yet — confirm).
# Alternatives for `init`:
#   * model(**config).init — always start from freshly initialized weights;
#   * LowestCheckpointFileReader(dir=CHECKPOINT_DIR, target=model(**config).init,
#     metric="lnpp") — start from the checkpoint with the lowest eval `lnpp`.
train_loop = TrainLoop(
    init=CheckpointFileReader(dir=CHECKPOINT_DIR, target=model(**config).init),
    task=TrainTask(
        # deterministic=False for training (presumably enables dropout).
        apply=model(**config | dict(deterministic=False)).apply,
        optimizer=optax.sgd(learning_rate=0.1, momentum=0.9),
        loss=categorical_cross_entropy,
        data=train_datag,
    ),
    collections=collections,
    mutable_collections=True,
    n_steps_per_checkpoint=5,
    rnkey=next(rnkeyg),
    n_steps=10,
)

# Pipeline applied to every checkpoint the training loop yields: evaluate it
# (metric `lnpp` = mean cross-entropy on one eval batch), log and write the
# summaries, save the checkpoint, and hand it to LowestCheckpointFileWriter,
# which is keyed on `lnpp`.
process_checkpoint = cb.serial(
    EvalLoop(
        task=EvalTask(
            apply=model(**config).apply,
            metrics=dict(lnpp=categorical_cross_entropy),
            data=eval_datag,
        ),
        collections=collections,
        rnkey=next(rnkeyg),
        n_steps=1,
    ),
    SummaryLogger(),
    SummaryWriter(output_dir=TENSORBOARD_DIR),
    CheckpointFileWriter(output_dir=CHECKPOINT_DIR),
    LowestCheckpointFileWriter(output_dir=CHECKPOINT_DIR, metric="lnpp"),
)

# Results are consumed by the pipeline's side effects (logs, files). The return
# value is not bound to `_`, which would shadow IPython's output history.
for checkpoint in train_loop:
    process_checkpoint(checkpoint)
Total model initialization time is 9.26 seconds.
The lowest value for the metric eval/lnpp is set to inf.
Total number of trainable weights: 328960 ~ 1.2 MB.

Step      1: Ran 1 train steps in 6.62 seconds
Step      1: train seconds_per_step | 6.61999798
Step      1: train gradients_l2norm | 0.00298012
Step      1: train   weights_l2norm | 14.01909637
Step      1: train             loss | 5.54786873
Step      1: eval              lnpp | 5.54727793

Step      6: Ran 5 train steps in 7.25 seconds
Step      6: train seconds_per_step | 1.45005803
Step      6: train gradients_l2norm | 0.00297097
Step      6: train   weights_l2norm | 14.02048969
Step      6: train             loss | 5.54690266
Step      6: eval              lnpp | 5.54723597

Step     10: Ran 4 train steps in 1.33 seconds
Step     10: train seconds_per_step | 0.33285725
Step     10: train gradients_l2norm | 0.00292413
Step     10: train   weights_l2norm | 14.02241802
Step     10: train             loss | 5.54660368
Step     10: eval              lnpp | 5.54718399