# Source code for flax_extra.batch

r"""Batch types and data processing functions.

The batch, inputs, and targets in the **Training API** are usually
represented in an unnormalized form for user convenience.

During construction of :class:`TrainLoop`, the data gets normalized.

In the normalized form, a batch is a tuple, `((x,...), (y,...))`, that consists of:

- inputs, `(x,...)`, a tuple of a single array or multiple arrays.
    Inputs get passed to model's `apply(x,...)` function as arguments.
- targets, `(y,...)`, a tuple of arbitrary size or an empty tuple.
    Targets along with model outputs, `(o,...)`, get passed to a
    `loss(o,...,y,...)` function as arguments.
"""

from typing import Union, Generator
from functools import partial
from jax import numpy as jnp
import redex
from flax_extra import util

# Array type that all batch aliases below are built from.
Array = jnp.ndarray

# Normalized group of arrays: a tuple holding any number of arrays.
# Targets and model outputs are aliases of the very same type as inputs.
Inputs = tuple[Array, ...]  # type: ignore
Targets = Inputs
Outputs = Inputs
# Normalized batch: a pair `(inputs, targets)`.
Batch = tuple[Inputs, Targets]
# Unnormalized group: either a tuple of arrays or a single bare array.
UnnormalizedInputs = Union[Inputs, Array]
# Unnormalized batch: a normalized batch, or a pair whose members may each
# still be in the unnormalized form.
UnnormalizedBatch = Union[Batch, tuple[UnnormalizedInputs, UnnormalizedInputs]]
# Generator yielding unnormalized batches; its send and return types are None.
DataStream = Generator[UnnormalizedBatch, None, None]


def normalize_batch(batch: UnnormalizedBatch) -> Batch:
    r"""Brings a batch into its normalized form.

    Every member of the batch is passed through
    `redex.util.expand_to_tuple`, and the results are collected
    into a tuple in the original order.

    Args:
        batch: a batch to normalize.

    Returns:
        a normalized batch.
    """
    groups = (redex.util.expand_to_tuple(group) for group in batch)
    return tuple(groups)  # type: ignore


def normalize_batch_per_device(batch: UnnormalizedBatch, n_devices: int) -> Batch:
    r"""Brings a batch into its normalized form, splitting the head axis
    of inputs and targets evenly across the number of devices.

    Every member of the batch is first expanded with
    `redex.util.expand_to_tuple`; each resulting item is then passed to
    `util.batch_per_device` together with `n_devices`.

    Args:
        batch: a batch to normalize.
        n_devices: a number of devices, forwarded to
            `util.batch_per_device` as a keyword argument.

    Returns:
        a normalized batch whose arrays were processed by
        `util.batch_per_device`.
    """

    def split_group(group: UnnormalizedInputs) -> Inputs:
        # Normalize the group to a tuple first, then handle each array in it.
        arrays = redex.util.expand_to_tuple(group)
        return tuple(
            util.batch_per_device(array, n_devices=n_devices) for array in arrays
        )

    return tuple(split_group(group) for group in batch)  # type: ignore