# Source code for flax_extra.checkpoint._checkpoint
r"""A representation of the training state."""
from typing import Any
from dataclasses import dataclass
from flax.core.frozen_dict import FrozenDict
import optax
# Alias for a flax ``FrozenDict`` with arbitrary keys and values; used by the
# ``Checkpoint`` dataclass for model parameters, model state, and gradients.
FrozenVars = FrozenDict[Any, Any]
# The checkpoint dataclass intentionally carries one field per piece of
# training state, which exceeds pylint's instance-attribute limit.
# pylint: disable=too-many-instance-attributes
@dataclass
class Checkpoint:
    r"""A snapshot of the training state at a particular training step.

    This is a plain (mutable) dataclass. Every field is required: there are
    no defaults, so all values must be supplied to the constructor, either
    positionally in the order declared below or by keyword.
    """

    model_params: FrozenVars
    r"""Parameters of the model at the checkpoint."""
    model_state: FrozenVars
    r"""State of the model at the checkpoint."""
    optimizer_state: optax.OptState
    r"""State of the optimizer at the checkpoint."""
    grads: FrozenVars
    r"""Gradients at the checkpoint."""
    loss: float
    r"""Loss value at the checkpoint."""
    n_completed_steps: int
    r"""Number of steps completed between the previous checkpoint and this one."""
    elapsed_time: float
    r"""Time elapsed between the previous checkpoint and this one."""
    step: int
    r"""Step number at which this checkpoint occurred."""