flax_extra.batch module

Batch types and data processing functions.

In the Training API, the batch, inputs, and targets are usually represented in an unnormalized form for user convenience: inputs or targets may be given as a bare array instead of a tuple of arrays.

When a TrainLoop is constructed, the data gets normalized.

In the normalized form, a batch is a tuple ((x,…), (y,…)) that consists of:

  • inputs, (x,…), a tuple of a single array or multiple arrays.

    Inputs are passed to the model’s apply(x,…) function as arguments.

  • targets, (y,…), a tuple of arbitrary size or an empty tuple.

    Targets, along with the model outputs (o,…), are passed to a loss(o,…,y,…) function as arguments.
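The following sketch shows how the two halves of a normalized batch might be unpacked into the calls described above. The functions apply and mse_loss are hypothetical stand-ins, not part of flax_extra (a real Flax model's apply also takes the variables first), and NumPy is used in place of jax.numpy so the snippet runs anywhere.

```python
import numpy as np


# Hypothetical model apply function: two inputs, one output.
def apply(x1, x2):
    return (x1 + x2,)


# Hypothetical loss: one model output and one target, scalar result.
def mse_loss(o, y):
    return float(np.mean((o - y) ** 2))


# A normalized batch: ((x,…), (y,…)).
inputs = (np.ones((4, 3)), np.zeros((4, 3)))
targets = (np.full((4, 3), 2.0),)
batch = (inputs, targets)

x, y = batch
outputs = apply(*x)            # apply(x1, x2)
loss = mse_loss(*outputs, *y)  # loss(o, y)
```

Because every part of the normalized form is a tuple, the same unpacking works for any number of inputs, outputs, and targets, including an empty targets tuple.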

flax_extra.batch.normalize_batch(batch: Union[tuple[tuple[jnp.ndarray, ...], tuple[jnp.ndarray, ...]], tuple[Union[tuple[jnp.ndarray, ...], jnp.ndarray], Union[tuple[jnp.ndarray, ...], jnp.ndarray]]]) → tuple[tuple[jnp.ndarray, ...], tuple[jnp.ndarray, ...]]

Converts a batch to its normalized form.

Parameters

batch – a batch to normalize.

Returns

a normalized batch.
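A minimal sketch of what normalization could look like, assuming (from the signature) that it only wraps a bare inputs or targets array into a one-element tuple and passes tuples through unchanged. The name normalize_batch_sketch is hypothetical; this is not the library's implementation, and NumPy arrays stand in for jnp.ndarray.

```python
from typing import Union

import numpy as np

Arrays = tuple[np.ndarray, ...]
ArrayOrArrays = Union[np.ndarray, Arrays]


def _as_tuple(value: ArrayOrArrays) -> Arrays:
    # A bare array becomes a 1-tuple; tuples (including ()) pass through.
    if isinstance(value, tuple):
        return value
    return (value,)


def normalize_batch_sketch(
    batch: tuple[ArrayOrArrays, ArrayOrArrays],
) -> tuple[Arrays, Arrays]:
    """Hypothetical equivalent of normalize_batch: ((x,…), (y,…))."""
    inputs, targets = batch
    return _as_tuple(inputs), _as_tuple(targets)


x = np.zeros((8, 3))
y = np.ones((8,))
single = normalize_batch_sketch((x, y))        # ((x,), (y,))
multi = normalize_batch_sketch(((x, x), ()))   # ((x, x), ())
```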

flax_extra.batch.normalize_batch_per_device(batch: Union[tuple[tuple[jnp.ndarray, ...], tuple[jnp.ndarray, ...]], tuple[Union[tuple[jnp.ndarray, ...], jnp.ndarray], Union[tuple[jnp.ndarray, ...], jnp.ndarray]]], n_devices: int) → tuple[tuple[jnp.ndarray, ...], tuple[jnp.ndarray, ...]]

Converts a batch to the normalized form, splitting the head (leading) axis of the inputs and targets evenly across the given number of devices.

Parameters

batch – a batch to normalize.

n_devices – the number of devices to split the head axis across.

Returns

a normalized batch of sharded arrays.
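A minimal sketch of the per-device split, assuming that "splitting evenly" means reshaping each array from (B, …) to (n_devices, B // n_devices, …), the layout that jax.pmap maps over, and that a batch size not divisible by n_devices is an error. The name normalize_batch_per_device_sketch and the error behavior are hypothetical, not taken from flax_extra; NumPy stands in for jax.numpy, whose arrays support the same reshape call.

```python
import numpy as np


def _as_tuple(value):
    # A bare array becomes a 1-tuple; tuples pass through unchanged.
    return value if isinstance(value, tuple) else (value,)


def _split_head_axis(array, n_devices):
    # Reshape (B, ...) into (n_devices, B // n_devices, ...).
    batch_size = array.shape[0]
    if batch_size % n_devices != 0:
        raise ValueError(
            f"batch size {batch_size} is not divisible by {n_devices} devices"
        )
    return array.reshape((n_devices, batch_size // n_devices) + array.shape[1:])


def normalize_batch_per_device_sketch(batch, n_devices):
    """Hypothetical equivalent of normalize_batch_per_device."""
    inputs, targets = batch
    return (
        tuple(_split_head_axis(a, n_devices) for a in _as_tuple(inputs)),
        tuple(_split_head_axis(a, n_devices) for a in _as_tuple(targets)),
    )


x = np.arange(32.0).reshape(8, 4)
y = np.arange(8)
(xs,), (ys,) = normalize_batch_per_device_sketch((x, y), n_devices=2)
```

In practice n_devices would typically come from jax.device_count(), and each slice xs[i] holds the examples for device i.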