# Source code for flax_extra.checkpoint._best_checkpoint_file_reader
r"""Reading the best metric checkpoints from the local file system."""
from dataclasses import dataclass
from flax_extra.checkpoint._checkpoint_file import CheckpointFile
from flax_extra.checkpoint._checkpoint_file_reader import CheckpointFileReader
from flax_extra.checkpoint._best_checkpoint_file_writer import (
lowest_checkpoint_prefix,
highest_checkpoint_prefix,
)
@dataclass
class LowestCheckpointFileReader(CheckpointFileReader):
    r"""A reader of a lowest metric checkpoint file.

    The file prefix is built with :func:`lowest_checkpoint_prefix` (imported
    from the best-checkpoint writer module) from :attr:`metric` and
    :attr:`group`. The actual reading is delegated to the inherited
    ``CheckpointFileReader.read``.
    """

    metric: str
    r"""a metric label."""

    group: str = "eval"
    r"""a group label."""

    def __call__(self, initializer: CheckpointFile) -> CheckpointFile:
        r"""Reads the latest checkpoint related to the lowest metric value
        from the file system.

        Args:
            initializer: an initial checkpoint at step 0.
                It is required to restore type information.

        Returns:
            either the loaded checkpoint file or initial checkpoint.

        Raises:
            TypeError: if the initializer is not of the :class:`Checkpoint` type.
        """
        # The prefix helper comes from the writer module, presumably so the
        # reader looks up exactly the files the lowest-metric writer produced
        # (confirm against _best_checkpoint_file_writer).
        prefix = lowest_checkpoint_prefix(
            metric_label=self.metric,
            group_label=self.group,
        )
        # `read` is inherited from CheckpointFileReader; the type checker is
        # silenced here as in the original. NOTE(review): confirm `read`'s
        # declared signature accepts `prefix`.
        return super().read(initializer, prefix=prefix)  # type: ignore
@dataclass
class HighestCheckpointFileReader(CheckpointFileReader):
    r"""A reader of a highest metric checkpoint file.

    The file prefix is built with :func:`highest_checkpoint_prefix` (imported
    from the best-checkpoint writer module) from :attr:`metric` and
    :attr:`group`. The actual reading is delegated to the inherited
    ``CheckpointFileReader.read``.
    """

    metric: str
    r"""a metric label."""

    group: str = "eval"
    r"""a group label."""

    def __call__(self, initializer: CheckpointFile) -> CheckpointFile:
        r"""Reads the latest checkpoint related to the highest metric value
        from the file system.

        Args:
            initializer: an initial checkpoint at step 0.
                It is required to restore type information.

        Returns:
            either the loaded checkpoint file or initial checkpoint.

        Raises:
            TypeError: if the initializer is not of the :class:`Checkpoint` type.
        """
        # The prefix helper comes from the writer module, presumably so the
        # reader looks up exactly the files the highest-metric writer produced
        # (confirm against _best_checkpoint_file_writer).
        prefix = highest_checkpoint_prefix(
            metric_label=self.metric,
            group_label=self.group,
        )
        # `read` is inherited from CheckpointFileReader; the type checker is
        # silenced here as in the original. NOTE(review): confirm `read`'s
        # declared signature accepts `prefix`.
        return super().read(initializer, prefix=prefix)  # type: ignore