# Source code for flax_extra.checkpoint._checkpoint_file

r"""The checkpoint file format."""

import typing
from typing import Any
from dataclasses import field
from flax.core.frozen_dict import FrozenDict
from flax.struct import PyTreeNode, TNode
import optax
from flax_extra.checkpoint._summary import Metrics


# Alias for a ``FrozenDict`` with unconstrained key and value types.
# ``CheckpointFile`` uses it to annotate its ``model_params`` and
# ``model_state`` fields.
FrozenVars = FrozenDict[Any, Any]


# pylint: disable=too-few-public-methods
class CheckpointFile(PyTreeNode):
    r"""The file format to store checkpoint on the local file system.

    A checkpoint bundles the model parameters, the model state, the
    optimizer state, the step number, the metrics and the version of the
    file format. ``metrics`` defaults to an empty ``dict`` and ``version``
    defaults to ``0``; all other fields are required.
    """

    model_params: FrozenVars
    r"""parameters of the model at the checkpoint."""

    model_state: FrozenVars
    r"""a state of the model at the checkpoint."""

    optimizer_state: optax.OptState
    r"""a state of the optimizer at the checkpoint."""

    step: int
    r"""the step number at which this checkpoint occurred."""

    # ``default_factory`` gives every instance its own fresh ``dict``
    # instead of sharing one mutable default between instances.
    metrics: Metrics = field(default_factory=dict)
    r"""metrics at the current checkpoint."""

    version: int = 0
    r"""the file format version."""

    # The definitions below exist only while type checking
    # (``typing.TYPE_CHECKING`` is ``False`` at runtime), so they never
    # replace the methods the class actually has at runtime.
    if typing.TYPE_CHECKING:

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            # This stub informs a type checker that this method is overridden.
            super().__init__()

        def replace(self: TNode, **overrides: Any) -> TNode:
            # This stub informs a type checker that this method is overridden.
            pass