flax_extra.evaluation package

Evaluation API.

class flax_extra.evaluation.EvalLoop(task: flax_extra.evaluation._loop.EvalTask, rnkey: jax._src.prng.PRNGKeyArray, n_steps: int = 1, collections: Optional[Mapping[str, List[str]]] = None, n_devices: Optional[int] = None)

Bases: object

The evaluation loop evaluates a model for n_steps steps, then returns the metrics averaged over those steps.

property n_steps: int

The number of steps in the loop.
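
The averaging described above can be illustrated with a minimal sketch. It does not use flax_extra itself: average_metrics is a hypothetical helper, and the per-step dictionaries stand in for whatever the loop computes internally, which this page does not specify.

```python
from itertools import islice
from typing import Dict, Iterator, Mapping


def average_metrics(
    step_metrics: Iterator[Mapping[str, float]], n_steps: int = 1
) -> Dict[str, float]:
    """Average per-step metric values over at most n_steps steps."""
    totals: Dict[str, float] = {}
    count = 0
    # Consume at most n_steps items from the stream of per-step results.
    for metrics in islice(step_metrics, n_steps):
        for label, value in metrics.items():
            totals[label] = totals.get(label, 0.0) + float(value)
        count += 1
    if count == 0:
        raise ValueError("no evaluation steps were performed")
    return {label: total / count for label, total in totals.items()}


# Hypothetical per-step results, standing in for real evaluation steps.
steps = iter([{"loss": 1.0}, {"loss": 0.5}, {"loss": 0.0}])
averaged = average_metrics(steps, n_steps=2)
```

Only the first two steps are consumed here, so the third result does not contribute to the average.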

class flax_extra.evaluation.EvalTask(apply: Callable[[...], Any], metrics: Mapping[str, Callable[[...], float]], data: Generator[Union[tuple[tuple[jax._src.numpy.lax_numpy.ndarray, ...], tuple[jax._src.numpy.lax_numpy.ndarray, ...]], tuple[typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray], typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray]]], None, None])

Bases: object

The evaluation task describes how to evaluate a model.

apply: Callable[[...], Any]

The apply function of the model (a Linen module).

data: Generator[Union[tuple[tuple[jax._src.numpy.lax_numpy.ndarray, ...], tuple[jax._src.numpy.lax_numpy.ndarray, ...]], tuple[typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray], typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray]]], None, None]

A data stream of evaluation examples.

Each item yielded by the data stream must consist of inputs and targets. Either may be a single array or a tuple of arrays. Targets may also be an empty tuple.

The following forms are acceptable:

  • (x, y)

  • ((x,…), (y,…))

  • ((x,…), ())

  • etc.

Inputs are passed to the model’s apply(x,…) function as arguments.

Targets, together with the model outputs (o,…), are passed to a loss(o,…,y,…) function as arguments.
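
The calling convention above can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: as_tuple, split_item and evaluate_item are hypothetical helpers, plain floats stand in for arrays, and wrapping a single model output into a tuple is assumed rather than documented here.

```python
from typing import Any, Tuple


def as_tuple(value: Any) -> Tuple[Any, ...]:
    """Wrap a single value in a tuple; leave tuples unchanged."""
    return value if isinstance(value, tuple) else (value,)


def split_item(item: Tuple[Any, Any]) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
    """Normalize (x, y), ((x,...), (y,...)) or ((x,...), ()) to two tuples."""
    inputs, targets = item
    return as_tuple(inputs), as_tuple(targets)


def apply(x1: float, x2: float = 0.0) -> float:
    """Toy stand-in for a model's apply function (hypothetical)."""
    return x1 + x2


def loss(o: float, *ys: float) -> float:
    """Toy loss: total absolute error of the output against each target."""
    return sum(abs(o - y) for y in ys)


def evaluate_item(item: Tuple[Any, Any]) -> float:
    inputs, targets = split_item(item)
    outputs = as_tuple(apply(*inputs))  # inputs become positional arguments
    return loss(*outputs, *targets)     # outputs first, then targets


single = evaluate_item((1.0, 3.0))              # form (x, y)
multi = evaluate_item(((1.0, 2.0), (3.0,)))     # form ((x,...), (y,...))
no_targets = evaluate_item(((1.0,), ()))        # form ((x,...), ())
```

With an empty targets tuple the toy loss receives only the model outputs, which is why it accepts a variable number of targets.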

metrics: Mapping[str, Callable[[...], float]]

Evaluation metrics, keyed by their labels.
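
As a rough illustration of such a mapping, the sketch below pairs labels with plain Python callables that return a float. The functions accuracy and mean_absolute_error are hypothetical examples, and Python lists stand in for arrays; real metrics would operate on the model outputs and targets described under data.

```python
from typing import Callable, Mapping, Sequence


def accuracy(outputs: Sequence[int], targets: Sequence[int]) -> float:
    """Fraction of positions where the output equals the target."""
    matches = sum(1 for o, y in zip(outputs, targets) if o == y)
    return matches / len(targets)


def mean_absolute_error(outputs: Sequence[float], targets: Sequence[float]) -> float:
    """Mean of the absolute differences between outputs and targets."""
    return sum(abs(o - y) for o, y in zip(outputs, targets)) / len(targets)


# Labels map to metric functions, matching Mapping[str, Callable[..., float]].
metrics: Mapping[str, Callable[..., float]] = {
    "accuracy": accuracy,
    "mae": mean_absolute_error,
}

outputs = [1, 0, 1, 1]
targets = [1, 1, 1, 0]
# Evaluate every metric on the same outputs and targets.
results = {label: fn(outputs, targets) for label, fn in metrics.items()}
```

The labels are what identify each value in the reported results, so they should be unique and descriptive.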