flax_extra.training package

Training API.

class flax_extra.training.TrainLoop(init: Union[Callable[[...], Any], flax_extra.checkpoint._checkpoint_file_reader.CheckpointFileReader], task: flax_extra.training._loop.TrainTask, rnkey: jax._src.prng.PRNGKeyArray, input_sample: Optional[Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray]] = None, n_steps_per_checkpoint: int = 1, n_steps: int = 0, collections: Optional[Mapping[str, List[str]]] = None, mutable_collections: Union[bool, str, Container[str], flax.core.scope.DenyList] = DenyList(deny='intermediates'), devices: Optional[List[jaxlib.xla_extension.Device]] = None, stdout: bool = True)

Bases: object

The training loop updates model parameters and yields checkpoints describing training state at specified steps.

property n_steps: int

the total number of steps in the loop.

property n_steps_per_checkpoint: int

the number of steps between checkpoints.

next_checkpoint() → flax_extra.checkpoint._checkpoint.Checkpoint

Runs the steps remaining until the next checkpoint.

Returns

a checkpoint.

next_step() → flax_extra.checkpoint._checkpoint.Checkpoint

Runs a single step.

Returns

a checkpoint.

run(n_steps: int) → Generator[flax_extra.checkpoint._checkpoint.Checkpoint, None, None]

Runs an arbitrary number of steps, yielding checkpoints.

Parameters

n_steps – the number of steps to run.

Yields

a checkpoint.

property step: int

the current step.
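
The relationship between step, n_steps_per_checkpoint, next_step(), next_checkpoint() and run() can be pictured with a toy counter. This is only a sketch, not the flax_extra implementation: ToyLoop is a hypothetical class, a plain dict stands in for a Checkpoint object, and the rule that run() yields only when a checkpoint boundary is reached is an assumption made for illustration.

```python
class ToyLoop:
    """Hypothetical stand-in that mimics only the step counting of a loop."""

    def __init__(self, n_steps_per_checkpoint=1):
        self.n_steps_per_checkpoint = n_steps_per_checkpoint
        self.step = 0

    def next_step(self):
        # A real training loop would update model parameters here.
        self.step += 1
        return {"step": self.step}  # placeholder for a checkpoint object

    def next_checkpoint(self):
        # Steps left until the next multiple of n_steps_per_checkpoint.
        interval = self.n_steps_per_checkpoint
        remaining = interval - self.step % interval
        for _ in range(remaining):
            checkpoint = self.next_step()
        return checkpoint

    def run(self, n_steps):
        # Assumption: a checkpoint is yielded each time a boundary is reached.
        for _ in range(n_steps):
            checkpoint = self.next_step()
            if self.step % self.n_steps_per_checkpoint == 0:
                yield checkpoint


loop = ToyLoop(n_steps_per_checkpoint=5)
loop.next_step()                # advances to step 1
first = loop.next_checkpoint()  # runs the 4 steps remaining up to step 5
steps = [c["step"] for c in loop.run(10)]  # steps 6 through 15
```

In this sketch, next_checkpoint() always stops on a multiple of n_steps_per_checkpoint, so calling it after a partial interval runs only the remainder of that interval.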

class flax_extra.training.TrainTask(apply: Callable[[...], Any], optimizer: optax._src.base.GradientTransformation, loss: Callable[[...], float], data: Generator[Union[tuple[tuple[jax._src.numpy.lax_numpy.ndarray, ...], tuple[jax._src.numpy.lax_numpy.ndarray, ...]], tuple[typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray], typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray]]], None, None])

Bases: object

The training task describes how to train a model.

apply: Callable[[...], Any]

the apply function of the model (a Flax Linen module).

data: Generator[Union[tuple[tuple[jax._src.numpy.lax_numpy.ndarray, ...], tuple[jax._src.numpy.lax_numpy.ndarray, ...]], tuple[typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray], typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray]]], None, None]

a data stream of training examples.

Each item yielded by the data stream must consist of inputs and targets. Either may be a single array or a tuple of arrays. Targets may also be an empty tuple.

The following forms are acceptable:

  • (x, y)

  • ((x,…), (y,…))

  • ((x,…), ())

  • etc.

Inputs are passed to the model’s apply(x,…) function as arguments.

Model outputs, (o,…), followed by targets are passed to the loss(o,…,y,…) function as arguments.
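
The argument routing described above can be sketched in plain Python. This is a minimal illustration rather than the library's code: as_tuple, toy_apply, toy_loss and evaluate_item are hypothetical names, and plain floats stand in for arrays.

```python
def as_tuple(value):
    """Wrap a single array in a 1-tuple; leave tuples (even empty) unchanged."""
    return value if isinstance(value, tuple) else (value,)


def toy_apply(x):
    # Hypothetical stand-in for a model's apply function: doubles its input.
    return 2.0 * x


def toy_loss(output, target):
    # Hypothetical squared-error loss over one output and one target.
    return (output - target) ** 2


def evaluate_item(apply, loss, item):
    """Unpack one data item and route it through apply and loss."""
    inputs, targets = item
    outputs = as_tuple(apply(*as_tuple(inputs)))
    # Outputs come first, followed by targets: loss(o, ..., y, ...).
    return loss(*outputs, *as_tuple(targets))


# (x, y) and ((x,), (y,)) describe the same training example.
loss_plain = evaluate_item(toy_apply, toy_loss, (3.0, 5.0))
loss_tupled = evaluate_item(toy_apply, toy_loss, ((3.0,), (5.0,)))
# With empty targets, the loss receives model outputs only.
loss_no_targets = evaluate_item(toy_apply, lambda o: o ** 2, ((3.0,), ()))
```

Normalizing both sides to tuples first means a single unpacking rule covers every accepted form, including the case of empty targets.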

loss: Callable[[...], float]

a loss function.

The function must accept all model outputs, followed by all targets, as arguments.

optimizer: optax._src.base.GradientTransformation

an optimizer.