flax_extra.evaluation package
Evaluation API.
- class flax_extra.evaluation.EvalLoop(task: flax_extra.evaluation._loop.EvalTask, rnkey: jax._src.prng.PRNGKeyArray, n_steps: int = 1, collections: Optional[Mapping[str, List[str]]] = None, n_devices: Optional[int] = None)
Bases: object
The evaluation loop runs a fixed number of steps (n_steps), evaluating a model at each step, and then returns the averaged metrics.
- property n_steps: int
The number of steps in the loop.
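To illustrate what "averaged metrics" means, here is a minimal plain-Python sketch of such a loop. It is not the flax_extra implementation: `run_eval_loop` and `fake_step` are hypothetical, each step is assumed to draw one batch and return a dict of metric values, the mean is assumed to be unweighted over steps, and `rnkey`, `collections`, and `n_devices` are ignored.

```python
from typing import Callable, Dict, Iterator, Tuple

# Hypothetical per-step function: takes one batch and returns {label: metric value}.
StepFn = Callable[[Tuple], Dict[str, float]]


def run_eval_loop(step_fn: StepFn, data: Iterator[Tuple], n_steps: int = 1) -> Dict[str, float]:
    """Run `n_steps` evaluation steps and average each metric over them (sketch only)."""
    if n_steps < 1:
        raise ValueError("n_steps must be at least 1")
    totals: Dict[str, float] = {}
    for _ in range(n_steps):
        batch = next(data)  # draw one evaluation batch per step
        for label, value in step_fn(batch).items():
            totals[label] = totals.get(label, 0.0) + float(value)
    # Assumed: unweighted mean, so every step counts equally regardless of batch size.
    return {label: total / n_steps for label, total in totals.items()}


def fake_step(batch: Tuple[float, float]) -> Dict[str, float]:
    # Toy step: a "model output" x compared against a target y.
    x, y = batch
    return {"error": abs(x - y)}


data = iter([(1.0, 0.0), (2.0, 0.0), (3.0, 0.0)])
print(run_eval_loop(fake_step, data, n_steps=2))  # {'error': 1.5}
```

In this sketch only the first two batches are consumed; the third remains in the iterator for a subsequent call.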
- class flax_extra.evaluation.EvalTask(apply: Callable[[...], Any], metrics: Mapping[str, Callable[[...], float]], data: Generator[Union[tuple[tuple[jax._src.numpy.lax_numpy.ndarray, ...], tuple[jax._src.numpy.lax_numpy.ndarray, ...]], tuple[typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray], typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray]]], None, None])
Bases: object
The evaluation task describes how to evaluate a model.
- apply: Callable[[...], Any]
The apply function of the model (a Linen module).
- data: Generator[Union[tuple[tuple[jax._src.numpy.lax_numpy.ndarray, ...], tuple[jax._src.numpy.lax_numpy.ndarray, ...]], tuple[typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray], typing.Union[tuple[jax._src.numpy.lax_numpy.ndarray, ...], jax._src.numpy.lax_numpy.ndarray]]], None, None]
A data stream of evaluation examples.
Each item yielded by the data stream must consist of inputs and targets. Each of them may be a single array or a tuple of arrays, and targets may be an empty tuple.
The following forms are acceptable:
(x, y)
((x,…), (y,…))
((x,…), ())
etc.
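The accepted forms can be normalized with a small helper that wraps a lone array in a 1-tuple. This is a hypothetical sketch (`as_tuple` and `split_item` are not part of flax_extra), and strings stand in for arrays so the printed output is easy to read.

```python
from typing import Any, Tuple


def as_tuple(part: Any) -> Tuple[Any, ...]:
    """Wrap a single array in a 1-tuple; pass tuples (including ()) through unchanged."""
    return part if isinstance(part, tuple) else (part,)


def split_item(item: Tuple[Any, Any]) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
    """Split a data-stream item into (inputs, targets), both as tuples."""
    inputs, targets = item  # every item is a pair: inputs first, targets second
    return as_tuple(inputs), as_tuple(targets)


print(split_item(("x", "y")))                # (('x',), ('y',))
print(split_item((("x", "x2"), ("y",))))     # (('x', 'x2'), ('y',))
print(split_item((("x",), ())))              # (('x',), ())
```

The sketch relies on a single array never being a Python tuple, which holds for NumPy and JAX arrays, so any tuple is read as "multiple arrays".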
Inputs are passed to the model’s apply(x,…) function as arguments.
Model outputs (o,…), together with the targets, are passed to a loss(o,…,y,…) function as arguments.
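The calling convention can be sketched as a function that evaluates one batch. It assumes inputs and targets have already been normalized to tuples and that `apply` is a plain callable; how flax_extra supplies the module's variables and `rnkey` to the Linen apply function is not shown. `evaluate_batch`, `toy_apply`, and the two metrics are hypothetical.

```python
from typing import Any, Callable, Dict, Mapping, Tuple


def evaluate_batch(
    apply: Callable[..., Any],
    metrics: Mapping[str, Callable[..., float]],
    inputs: Tuple[Any, ...],
    targets: Tuple[Any, ...],
) -> Dict[str, float]:
    """Compute every metric for one batch following the (o,…,y,…) convention."""
    outputs = apply(*inputs)  # inputs become positional arguments: apply(x, ...)
    if not isinstance(outputs, tuple):
        outputs = (outputs,)  # treat a single output as the 1-tuple (o,)
    # Each metric receives the model outputs followed by the targets: fn(o, ..., y, ...).
    return {label: float(fn(*outputs, *targets)) for label, fn in metrics.items()}


def toy_apply(x: float) -> float:
    # Toy "model" that doubles its input.
    return 2.0 * x


def abs_error(o: float, y: float) -> float:
    return abs(o - y)


def squared_error(o: float, y: float) -> float:
    return (o - y) ** 2


metrics = {"abs_error": abs_error, "squared_error": squared_error}
print(evaluate_batch(toy_apply, metrics, (3.0,), (4.0,)))
# {'abs_error': 2.0, 'squared_error': 4.0}
```

Because every metric receives the same argument list, all metrics in one mapping must accept the same number of outputs and targets.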
- metrics: Mapping[str, Callable[[...], float]]
Evaluation metrics, keyed by their labels.