Model training and evaluation¶
A complete example of model training and evaluation.
[1]:
# Uncomment to emulate an arbitrary number of devices running on CPU.
# import os
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'
[2]:
import jax
from jax import numpy as jnp
import optax
from flax_extra import combinator as cb
from flax_extra import random
from flax_extra.training import TrainLoop, TrainTask
from flax_extra.evaluation import EvalLoop, EvalTask
from flax_extra.checkpoint import (
SummaryLogger,
SummaryWriter,
CheckpointFileReader,
CheckpointFileWriter,
LowestCheckpointFileReader,
LowestCheckpointFileWriter,
)
from flax_extra.model import RNNLM
[3]:
# Geometry of the synthetic batches: (BATCH_SIZE, MAX_LENGTH) token ids.
MAX_LENGTH = 256
BATCH_SIZE = 32
# 2 ** 8 == 256 distinct token ids.
VOCAB_SIZE = 2 ** 8
D_MODEL = 128

# The model class itself (not an instance); it is instantiated later
# as `model(**config)`.
model = RNNLM

# Collection names handed to the train and eval loops, keyed by the
# phase ("init" / "apply") they are used for.
collections = {
    "init": ["params", "carry", "dropout"],
    "apply": ["params", "carry", "dropout"],
}

# Keyword arguments for the model constructor.
config = {
    "vocab_size": VOCAB_SIZE,
    "d_model": D_MODEL,
    "n_layers": 2,
}
def presudo_data_steam(shape, bounds, rnkey):
    """Endlessly yield batches of pseudo-random integer tokens.

    Each batch is returned as an ``(inputs, targets)`` pair in which both
    elements are the same array, i.e. the model is asked to reproduce its
    input.

    The name keeps its original spelling (sic, "pseudo data stream")
    because other cells call it by this name.

    Args:
        shape: Shape of every generated batch.
        bounds: ``(minval, maxval)`` pair; values are drawn from the
            half-open integer range ``[minval, maxval)``.
        rnkey: JAX PRNG key that seeds the stream.

    Yields:
        Tuples ``(x, x)`` where ``x`` is an ``int32`` array of ``shape``.
    """
    minval, maxval = bounds
    while True:
        # JAX PRNG keys are stateless: sampling repeatedly with the same
        # key returns the same values. Derive a fresh sub-key per batch so
        # the stream actually varies from step to step.
        rnkey, batch_key = jax.random.split(rnkey)
        x = jax.random.randint(
            key=batch_key,
            shape=shape,
            minval=minval,
            maxval=maxval,
            dtype=jnp.int32,
        )
        yield x, x
def categorical_cross_entropy(outputs, targets):
    """Mean softmax cross-entropy between model outputs and integer targets.

    Args:
        outputs: Unnormalized scores (logits); the size of the last axis
            is taken as the number of categories.
        targets: Integer category ids, one-hot encoded internally to
            match the last axis of ``outputs``.

    Returns:
        The cross-entropy averaged over all elements, as a scalar array.
    """
    one_hot_targets = jax.nn.one_hot(targets, outputs.shape[-1])
    per_element_loss = optax.softmax_cross_entropy(outputs, one_hot_targets)
    return jnp.mean(per_element_loss)
# Single seeded source of random keys for the whole notebook; every
# consumer below takes its own key with `next(rnkeyg)`.
rnkeyg = random.sequence(seed=0)

# Independent synthetic streams for training and evaluation. The lower
# bound of 1 keeps token id 0 out of the data
# (presumably reserved, e.g. for padding — TODO confirm).
train_datag = presudo_data_steam(
    shape=(BATCH_SIZE, MAX_LENGTH),
    bounds=(1, VOCAB_SIZE),
    rnkey=next(rnkeyg),
)
eval_datag = presudo_data_steam(
    shape=(BATCH_SIZE, MAX_LENGTH),
    bounds=(1, VOCAB_SIZE),
    rnkey=next(rnkeyg),
)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[4]:
# Training loop: iterating over it yields checkpoints (consumed by the
# `for` loop further down in this cell).
train_loop = TrainLoop(
    # Alternative initializers, kept for reference:
    #   fresh initialization, ignoring any saved checkpoint
    # init=model(**config).init,
    #   start from the checkpoint with the lowest "lnpp" metric
    # init=LowestCheckpointFileReader(dir="/tmp/checkpoints", target=model(**config).init, metric="lnpp"),
    # NOTE(review): presumably restores from /tmp/checkpoints when a
    # checkpoint exists and otherwise falls back to `target` — confirm.
    init=CheckpointFileReader(dir="/tmp/checkpoints", target=model(**config).init),
    task=TrainTask(
        # Same model configuration, but with deterministic=False for training.
        apply=model(**{**config, "deterministic": False}).apply,
        optimizer=optax.sgd(learning_rate=0.1, momentum=0.9),
        loss=categorical_cross_entropy,
        data=train_datag,
    ),
    collections=collections,
    mutable_collections=True,
    # In the recorded run, checkpoints were logged at steps 1, 6 and 10.
    n_steps_per_checkpoint=5,
    rnkey=next(rnkeyg),
    n_steps=10,
)
# Per-checkpoint pipeline; `cb.serial` chains the stages in the order
# listed: evaluate, log, write summaries, then write checkpoint files.
process_checkpoint = cb.serial(
    EvalLoop(
        task=EvalTask(
            # Evaluation uses the base config (no deterministic=False override).
            apply=model(**config).apply,
            # Reported as "eval lnpp" in the logged output.
            metrics={"lnpp": categorical_cross_entropy},
            data=eval_datag,
        ),
        collections=collections,
        rnkey=next(rnkeyg),
        n_steps=1,
    ),
    SummaryLogger(),
    SummaryWriter(output_dir="/tmp/tensorboard"),
    CheckpointFileWriter(output_dir="/tmp/checkpoints"),
    # Tracks the checkpoint with the lowest "lnpp" metric.
    LowestCheckpointFileWriter(output_dir="/tmp/checkpoints", metric="lnpp"),
)

# Run training: each checkpoint produced by the loop is pushed through
# the pipeline above. The result is bound to `_` and not used.
for checkpoint in train_loop:
    _ = process_checkpoint(checkpoint)
Total model initialization time is 9.26 seconds.
The lowest value for the metric eval/lnpp is set to inf.
Total number of trainable weights: 328960 ~ 1.2 MB.
Step 1: Ran 1 train steps in 6.62 seconds
Step 1: train seconds_per_step | 6.61999798
Step 1: train gradients_l2norm | 0.00298012
Step 1: train weights_l2norm | 14.01909637
Step 1: train loss | 5.54786873
Step 1: eval lnpp | 5.54727793
Step 6: Ran 5 train steps in 7.25 seconds
Step 6: train seconds_per_step | 1.45005803
Step 6: train gradients_l2norm | 0.00297097
Step 6: train weights_l2norm | 14.02048969
Step 6: train loss | 5.54690266
Step 6: eval lnpp | 5.54723597
Step 10: Ran 4 train steps in 1.33 seconds
Step 10: train seconds_per_step | 0.33285725
Step 10: train gradients_l2norm | 0.00292413
Step 10: train weights_l2norm | 14.02241802
Step 10: train loss | 5.54660368
Step 10: eval lnpp | 5.54718399