Source code for flax_extra.checkpoint._best_checkpoint_file_writer

r"""Writing the best metric checkpoints to the local file system."""

import abc
from typing import Any, Optional
from flax.training import checkpoints
from flax_extra import console
from flax_extra.checkpoint._summary import Summary
from flax_extra.checkpoint._checkpoint_file_writer import CheckpointFileWriter


class BestCheckpointFileWriter(CheckpointFileWriter, metaclass=abc.ABCMeta):
    r"""A base class for the best metric checkpoint writers."""

    @property
    def best_metric(self) -> float:
        r"""the current value for the best metric."""
        return self._best_metric

    @abc.abstractmethod
    def default_metric_value(self) -> float:
        r"""an initial value for the best metric."""

    @abc.abstractmethod
    def is_best(self, metric: float) -> bool:
        r"""Determines whether the metric value is the best.

        Args:
            metric: a new value to compare the current best value with.

        Returns:
            `True` if a new value is the best value, `False` otherwise.
        """

    def __init__(
        self,
        metric: str,
        group: str,
        prefix: str,
        **kwds: Any,
    ) -> None:
        r"""Initializes the writer.

        Args:
            metric: a metric label.
            group: a group label (such as "train" or "eval").
            prefix: a string that will be added to the output file name.
        """
        super().__init__(**kwds)

        self._group_label = group
        self._metric_label = metric
        self._prefix = prefix
        self._best_metric = (
            _load_best_metric(
                self._group_label,
                self._metric_label,
                self.output_dir,
                self._prefix,
            )
            or self.default_metric_value()
        )

    def __call__(self, summary: Summary) -> Summary:  # type: ignore
        r"""Keeps track of the best metric value and writes a checkpoint
        to the local file system when the value get updated.

        Args:
            summary: a checkpoint that also incudes metrics.

        Returns:
            an original summary.

        Raises:
            TypeError: if the checkpoint is not of the :class:`Summary` type.
        """
        if not isinstance(summary, Summary):
            raise TypeError(
                f"Cannot write a checkpoint to `{self.output_dir}`. "
                "Expecting checkpoint of type `Summary`, "
                f"but have got `{type(summary)}`."
            )

        metric = _read_best_metric(summary, self._group_label, self._metric_label)
        if self.is_best(metric):
            self._best_metric = metric
            _ = super().write(summary, self._prefix)
            return summary

        return summary


[docs]class LowestCheckpointFileWriter(BestCheckpointFileWriter): r"""A checkpoint writer that writes a checkpoint each time the lowest metric value is observed.""" def __init__( self, metric: str, group: str = "eval", stdout: bool = True, **kwds: Any, ) -> None: r"""Initializes the lowest metric checkpoint writer. Args: metric: a metric label. group: a group label. stdout: whether to write to stdout. """ prefix = lowest_checkpoint_prefix(metric_label=metric, group_label=group) super().__init__(metric, group, prefix=prefix, **kwds) console.log( f"The lowest value for the metric {self._group_label}/{self._metric_label} " f"is set to {self._best_metric:.8f}.", stdout=stdout, )
[docs] def default_metric_value(self) -> float: return float("+inf")
[docs] def is_best(self, metric: float) -> bool: return min(self._best_metric, metric) == metric
[docs]class HighestCheckpointFileWriter(BestCheckpointFileWriter): r"""A checkpoint writer that writes a checkpoint each time the highest metric value is observed.""" def __init__( self, metric: str, group: str = "eval", stdout: bool = True, **kwds: Any, ): r"""Initializes the highest metric checkpoint writer. Args: metric: a metric label. group: a group label. stdout: whether to write to stdout. """ prefix = lowest_checkpoint_prefix(metric_label=metric, group_label=group) super().__init__(metric, group, prefix=prefix, **kwds) console.log( f"The highest value for the metric {self._group_label}/{self._metric_label} " f"is set to {self._best_metric:.8f}.", stdout=stdout, )
[docs] def default_metric_value(self) -> float: return float("-inf")
[docs] def is_best(self, metric: float) -> bool: return max(self._best_metric, metric) == metric
def lowest_checkpoint_prefix(group_label: str, metric_label: str) -> str: r"""Returns a prefix for a lowest metric checkpoint.""" return f"lowest_{group_label}_{metric_label}_" def highest_checkpoint_prefix(group_label: str, metric_label: str) -> str: r"""Returns a prefix for a highest metric checkpoint.""" return f"highest_{group_label}_{metric_label}_" def _load_best_metric( group_label: str, metric_label: str, checkpoint_dir: str, prefix: str, ) -> Optional[float]: metric = None checkpoint_dict = checkpoints.restore_checkpoint( checkpoint_dir, target=None, prefix=prefix, ) if checkpoint_dict: group = checkpoint_dict.get("metrics").get(group_label) if group is None: raise ValueError( f"Cannot load the metric {group_label}/{metric_label} " f"from `{checkpoint_dir}`, because restored checkpoint " "does not contain the metric's group." ) metric = group.get(metric_label) if metric is None: raise ValueError( f"Cannot load the metric {group_label}/{metric_label} " f"from `{checkpoint_dir}`, because restored checkpoint " "does not contain the metric." ) return metric def _read_best_metric( summary: Summary, group_label: str, metric_label: str, ) -> float: group = summary.metrics.get(group_label) if group is None: raise ValueError( f"Cannot read the metric {group_label}/{metric_label} " "from a summary, because it does not contain metric's group." ) metric = group.get(metric_label) if metric is None: raise ValueError( f"Cannot read the metric {group_label}/{metric_label} " "from a summary, because it does not contain the metric." ) return metric