flax_extra.training package
Training API.
- class flax_extra.training.TrainLoop(init: Union[Callable[[...], Any], flax_extra.checkpoint._checkpoint_file_reader.CheckpointFileReader], task: flax_extra.training._loop.TrainTask, rnkey: jax._src.prng.PRNGKeyArray, input_sample: Optional[Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray]] = None, n_steps_per_checkpoint: int = 1, n_steps: int = 0, collections: Optional[Mapping[str, List[str]]] = None, mutable_collections: Union[bool, str, Container[str], flax.core.scope.DenyList] = DenyList(deny='intermediates'), devices: Optional[List[jaxlib.xla_extension.Device]] = None, stdout: bool = True)
Bases: object
The training loop updates model parameters and yields checkpoints describing the training state at specified steps.
- property n_steps: int
the total number of steps in the loop.
- property n_steps_per_checkpoint: int
the number of steps between checkpoints.
- next_checkpoint() → flax_extra.checkpoint._checkpoint.Checkpoint
Runs the steps remaining until the next checkpoint.
- Returns
a checkpoint.
- next_step() → flax_extra.checkpoint._checkpoint.Checkpoint
Runs a single step.
- Returns
a checkpoint.
- run(n_steps: int) → Generator[flax_extra.checkpoint._checkpoint.Checkpoint, None, None]
Runs an arbitrary number of steps, yielding checkpoints.
- Parameters
n_steps – the number of steps to run.
- Yields
a checkpoint.
- property step: int
the current step.
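The relationship between next_step(), next_checkpoint() and run() can be illustrated with a toy stand-in. This is only a sketch, not flax_extra's implementation: ToyLoop and ToyCheckpoint are hypothetical names, no model is trained, and the assumption that checkpoints fall on multiples of n_steps_per_checkpoint is a guess based on the descriptions above.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class ToyCheckpoint:
    """Hypothetical stand-in for a checkpoint; it records only the step."""
    step: int


class ToyLoop:
    """Hypothetical stand-in mimicking the documented TrainLoop interface."""

    def __init__(self, n_steps_per_checkpoint: int = 1):
        self.n_steps_per_checkpoint = n_steps_per_checkpoint
        self.step = 0

    def next_step(self) -> ToyCheckpoint:
        # A real loop would update model parameters here.
        self.step += 1
        return ToyCheckpoint(step=self.step)

    def next_checkpoint(self) -> ToyCheckpoint:
        # Assumption: run until the next multiple of n_steps_per_checkpoint.
        k = self.n_steps_per_checkpoint
        remaining = k - self.step % k  # always at least one step
        checkpoint = None
        for _ in range(remaining):
            checkpoint = self.next_step()
        return checkpoint

    def run(self, n_steps: int) -> Iterator[ToyCheckpoint]:
        # Assumption: yield a checkpoint each time a boundary is reached.
        target = self.step + n_steps
        while self.step < target:
            checkpoint = self.next_step()
            if self.step % self.n_steps_per_checkpoint == 0:
                yield checkpoint


loop = ToyLoop(n_steps_per_checkpoint=3)
first = loop.next_step()          # advances to step 1
second = loop.next_checkpoint()   # runs the 2 remaining steps, reaching step 3
steps = [c.step for c in loop.run(6)]  # runs steps 4..9
```

Under these assumptions, steps collects the checkpoints taken at steps 6 and 9. The actual TrainLoop may place its checkpoints differently.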
- class flax_extra.training.TrainTask(apply: Callable[[...], Any], optimizer: optax._src.base.GradientTransformation, loss: Callable[[...], float], data: Generator[Union[tuple[tuple[jax._src.numpy.lax_numpy.ndarray, ...], tuple[jax._src.numpy.lax_numpy.ndarray, ...]], tuple[typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray], typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray]]], None, None])
Bases: object
The training task describes how to train a model.
- apply: Callable[[...], Any]
the apply function of the model (a Linen module).
- data: Generator[Union[tuple[tuple[jax._src.numpy.lax_numpy.ndarray, ...], tuple[jax._src.numpy.lax_numpy.ndarray, ...]], tuple[typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray], typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray]]], None, None]
a data stream of training examples.
An item, yielded by the data stream, must consist of inputs and targets. Each of them may be represented as a single array or multiple arrays. Targets may be represented as an empty tuple.
The following forms are acceptable:
(x, y)
((x,…), (y,…))
((x,…), ())
etc.
Inputs are passed to the model’s apply(x,…) function as arguments.
Targets, along with the model outputs (o,…), are passed to the loss(o,…,y,…) function as arguments.
- loss: Callable[[...], float]
a loss function.
The function must accept all model outputs and all targets as arguments.
- optimizer: optax._src.base.GradientTransformation
an optimizer.
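The calling convention described for data and loss can be sketched as follows. This is an illustration under stated assumptions, not the library's code: as_tuple and evaluate_item are hypothetical helpers, the toy apply and loss functions are made up, and plain Python lists stand in for arrays so that the sketch runs without JAX.

```python
from typing import Any, Callable, Tuple


def as_tuple(value: Any) -> Tuple[Any, ...]:
    """Wrap a single array in a tuple; leave tuples unchanged."""
    return value if isinstance(value, tuple) else (value,)


def evaluate_item(apply: Callable[..., Any], loss: Callable[..., float], item) -> float:
    """Hypothetical helper: route one (inputs, targets) item through apply and loss."""
    inputs, targets = item
    outputs = as_tuple(apply(*as_tuple(inputs)))  # apply(x,…)
    return loss(*outputs, *as_tuple(targets))     # loss(o,…,y,…)


# Toy stand-ins for a model's apply function.
def double(x):
    return [2.0 * v for v in x]


def add(x1, x2):
    return [a + b for a, b in zip(x1, x2)]


# Toy losses: one takes an output and a target, one takes an output only.
def mse(o, y):
    return sum((a - b) ** 2 for a, b in zip(o, y)) / len(o)


def mean_square(o):
    return sum(a * a for a in o) / len(o)


# Form (x, y): a single input array and a single target array.
loss_xy = evaluate_item(double, mse, ([1.0, 2.0], [2.0, 4.0]))

# Form ((x,…), (y,…)): two input arrays and one target array.
loss_multi = evaluate_item(add, mse, (([1.0, 2.0], [1.0, 1.0]), ([2.0, 2.0],)))

# Form ((x,…), ()): no targets, so the loss receives model outputs only.
loss_no_targets = evaluate_item(double, mean_square, (([1.0, 2.0],), ()))
```

The point of the sketch is that the signatures of apply and loss must match the shape of the items the data stream yields: every input becomes a positional argument of apply, and every output and target becomes a positional argument of loss.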