flax_extra.checkpoint package
The checkpoint representation and file operations.
- class flax_extra.checkpoint.Checkpoint(model_params: flax.core.frozen_dict.FrozenDict[Any, Any], model_state: flax.core.frozen_dict.FrozenDict[Any, Any], optimizer_state: NamedTuple, grads: flax.core.frozen_dict.FrozenDict[Any, Any], loss: float, n_completed_steps: int, elapsed_time: float, step: int)
Bases: object
A representation of the training state at a particular training step.
- elapsed_time: float
the time elapsed between the previous checkpoint and this one.
- grads: flax.core.frozen_dict.FrozenDict[Any, Any]
the gradients at the checkpoint.
- loss: float
the loss at the checkpoint.
- model_params: flax.core.frozen_dict.FrozenDict[Any, Any]
the parameters of the model at the checkpoint.
- model_state: flax.core.frozen_dict.FrozenDict[Any, Any]
the state of the model at the checkpoint.
- n_completed_steps: int
the number of steps completed between the previous checkpoint and this one.
- optimizer_state: NamedTuple
the state of the optimizer at the checkpoint.
- step: int
the step number at which this checkpoint was taken.
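Because each checkpoint records both n_completed_steps and elapsed_time for the interval since the previous checkpoint, training throughput can be derived from them. The following is a minimal sketch that uses a hypothetical plain-Python stand-in (`CheckpointTiming`) instead of the real Checkpoint class, and assumes elapsed_time is measured in seconds:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CheckpointTiming:
    """Hypothetical stand-in holding only the timing fields of Checkpoint."""
    step: int
    n_completed_steps: int  # steps completed since the previous checkpoint
    elapsed_time: float     # time since the previous checkpoint (assumed seconds)


def steps_per_second(ckpt: CheckpointTiming) -> float:
    """Average throughput over the interval that ends at this checkpoint."""
    if ckpt.elapsed_time <= 0:
        raise ValueError("elapsed_time must be positive")
    return ckpt.n_completed_steps / ckpt.elapsed_time


ckpt = CheckpointTiming(step=1000, n_completed_steps=100, elapsed_time=4.0)
print(steps_per_second(ckpt))  # 25.0
```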
- class flax_extra.checkpoint.CheckpointFile(model_params: flax.core.frozen_dict.FrozenDict[Any, Any], model_state: flax.core.frozen_dict.FrozenDict[Any, Any], optimizer_state: NamedTuple, step: int, metrics: Mapping[str, float] = <factory>, version: int = 0)
Bases: flax.struct.PyTreeNode
The file format used to store a checkpoint on the local file system.
- metrics: Mapping[str, float]
the metrics at the current checkpoint.
- model_params: flax.core.frozen_dict.FrozenDict[Any, Any]
the parameters of the model at the checkpoint.
- model_state: flax.core.frozen_dict.FrozenDict[Any, Any]
the state of the model at the checkpoint.
- optimizer_state: NamedTuple
the state of the optimizer at the checkpoint.
- replace(**updates)
Returns a new object with the specified fields replaced by new values.
- step: int
the step number at which this checkpoint was taken.
- version: int = 0
the file format version.
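CheckpointFile derives from flax.struct.PyTreeNode, which makes it an immutable dataclass: replace returns a modified copy rather than mutating the instance. The sketch below mimics that behavior with a standard-library frozen dataclass so it runs without Flax. `CheckpointFileLike` is a hypothetical stand-in that uses plain dicts in place of FrozenDict and omits optimizer_state.

```python
import dataclasses
from typing import Any, Mapping


@dataclasses.dataclass(frozen=True)
class CheckpointFileLike:
    """Hypothetical stand-in for CheckpointFile (optimizer_state omitted)."""
    model_params: Mapping[str, Any]
    model_state: Mapping[str, Any]
    step: int
    metrics: Mapping[str, float] = dataclasses.field(default_factory=dict)
    version: int = 0

    def replace(self, **updates: Any) -> "CheckpointFileLike":
        # Build a new instance with some fields overridden; self is untouched.
        return dataclasses.replace(self, **updates)


initial = CheckpointFileLike(model_params={"w": 0.0}, model_state={}, step=0)
later = initial.replace(step=500, metrics={"accuracy": 0.91})

print(initial.step, initial.metrics)  # 0 {}
print(later.step, later.metrics)      # 500 {'accuracy': 0.91}
print(later.version)                  # 0
```

Fields that are not passed to replace, such as model_params here, are shared with the original instance rather than copied.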
- class flax_extra.checkpoint.CheckpointFileReader(dir: str, target: Callable[[...], Any])
Bases: object
A reader for the checkpoint file format.
- dir: str
a directory path for checkpoint files.
- read(initializer: Optional[flax_extra.checkpoint._checkpoint_file.CheckpointFile], prefix: Optional[str] = None) → Union[flax_extra.checkpoint._checkpoint_file.CheckpointFile, Mapping[str, Any]]
Reads the latest checkpoint from the file system.
If no checkpoint file exists, the given initial checkpoint is returned.
If no initial checkpoint is given, the loaded checkpoint is returned as a Mapping[str, Any].
- Parameters
initializer – an initial checkpoint at step 0. It is needed to restore type information.
- Returns
either the loaded checkpoint file or the initial checkpoint.
- Raises
TypeError – if the initializer is not of the Checkpoint type.
- target: Callable[[...], Any]
the init function of the model associated with the checkpoint. It is used to restore type information.
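The fallback rule of read, which loads the latest file if one exists and otherwise returns the initializer, can be illustrated with a small sketch. The naming scheme (`<prefix>checkpoint_<step>.json`), the JSON encoding, and the helper names `latest_checkpoint_path` and `read_or_init` are all hypothetical. The sketch also omits the type restoration that the real reader performs with initializer and target.

```python
import json
import os
import re
import tempfile
from typing import Any, Optional

# Hypothetical on-disk naming scheme: "<prefix>checkpoint_<step>.json".
_PATTERN = re.compile(r"checkpoint_(\d+)\.json")


def latest_checkpoint_path(directory: str, prefix: str = "") -> Optional[str]:
    """Return the path of the file with the highest step number, or None."""
    if not os.path.isdir(directory):
        return None
    best_step, best_path = -1, None
    for name in os.listdir(directory):
        if not name.startswith(prefix):
            continue
        match = _PATTERN.fullmatch(name[len(prefix):])
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_path = os.path.join(directory, name)
    return best_path


def read_or_init(directory: str, initializer: Optional[Any], prefix: str = "") -> Any:
    """Load the latest checkpoint, falling back to `initializer` if none exists."""
    path = latest_checkpoint_path(directory, prefix)
    if path is None:
        # Nothing on disk yet: return the initial checkpoint unchanged.
        return initializer
    with open(path) as f:
        # Type restoration is omitted; the data comes back as a plain dict.
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    print(read_or_init(tmp, initializer={"step": 0}))  # {'step': 0}
    for step in (100, 200):
        with open(os.path.join(tmp, f"checkpoint_{step}.json"), "w") as f:
            json.dump({"step": step}, f)
    print(read_or_init(tmp, initializer={"step": 0}))  # {'step': 200}
```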
- class flax_extra.checkpoint.CheckpointFileWriter(output_dir: str, keep: int = 1, overwrite: bool = True)
Bases: object
A writer for the checkpoint file format.
- keep: int = 1
the number of checkpoints to keep.
- output_dir: str
a directory path for checkpoint files.
- overwrite: bool = True
whether to overwrite previous checkpoints.
- write(checkpoint: Union[flax_extra.checkpoint._checkpoint.Checkpoint, flax_extra.checkpoint._summary.Summary], prefix: Optional[str] = None) → Union[flax_extra.checkpoint._checkpoint.Checkpoint, flax_extra.checkpoint._summary.Summary]
Writes a checkpoint to the file system.
- Parameters
checkpoint – a checkpoint to write.
prefix – a string that will be added to the output file name.
- Returns
the original checkpoint.
- Raises
TypeError – if the checkpoint is neither a Checkpoint nor a Summary.
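The keep attribute limits how many checkpoint files remain on disk. One common way to enforce such a limit is to delete all but the newest keep files after each write. The sketch below shows that approach under a hypothetical `checkpoint_<step>.json` naming scheme. It is not the library's code, and it does not model the overwrite flag.

```python
import os
import re
import tempfile
from typing import List

# Hypothetical on-disk naming scheme: "checkpoint_<step>.json".
_PATTERN = re.compile(r"checkpoint_(\d+)\.json")


def prune_checkpoints(output_dir: str, keep: int = 1) -> List[int]:
    """Delete all but the `keep` newest checkpoint files; return the kept steps."""
    if keep < 1:
        raise ValueError("keep must be at least 1")
    found = []
    for name in os.listdir(output_dir):
        match = _PATTERN.fullmatch(name)
        if match:
            found.append((int(match.group(1)), name))
    found.sort()  # oldest (lowest step) first
    # Everything except the last `keep` entries is stale.
    for _, name in found[:-keep]:
        os.remove(os.path.join(output_dir, name))
    return [step for step, _ in found[-keep:]]


with tempfile.TemporaryDirectory() as tmp:
    for step in (100, 200, 300, 400, 500):
        open(os.path.join(tmp, f"checkpoint_{step}.json"), "w").close()
    print(prune_checkpoints(tmp, keep=2))  # [400, 500]
    print(sorted(os.listdir(tmp)))  # ['checkpoint_400.json', 'checkpoint_500.json']
```

Sorting on the parsed step number rather than on the file name avoids lexicographic surprises such as `checkpoint_1000` sorting before `checkpoint_200`.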
- class flax_extra.checkpoint.HighestCheckpointFileReader(dir: str, target: Callable[[...], Any], metric: str, group: str = 'eval')
Bases: flax_extra.checkpoint._checkpoint_file_reader.CheckpointFileReader
A reader for the checkpoint file with the highest metric value.
- group: str = 'eval'
the metric group label.
- metric: str
the metric label.
- class flax_extra.checkpoint.HighestCheckpointFileWriter(metric: str, group: str = 'eval', stdout: bool = True, **kwds: Any)
Bases: flax_extra.checkpoint._best_checkpoint_file_writer.BestCheckpointFileWriter
A checkpoint writer that writes a checkpoint each time a new highest value of the metric is observed.
- is_best(metric: float) → bool
Determines whether the metric value is the best.
- Parameters
metric – a new value to compare with the current best value.
- Returns
True if the new value is the best value, False otherwise.
- output_dir: str
a directory path for checkpoint files.
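The is_best contract for a highest-metric writer amounts to comparing each new value against the best value seen so far. The following minimal sketch uses a hypothetical `HighestTracker` class and makes three assumptions: the first observed value always counts as the best, a tie does not count, and the stored best is updated in the same call. The library may split these responsibilities differently.

```python
import math


class HighestTracker:
    """Hypothetical helper mirroring the is_best contract for 'highest' metrics."""

    def __init__(self) -> None:
        self.best = -math.inf  # nothing observed yet, so any finite value wins

    def is_best(self, metric: float) -> bool:
        """Return True and record the value if it strictly beats the current best."""
        if metric > self.best:
            self.best = metric
            return True
        return False


tracker = HighestTracker()
print([tracker.is_best(v) for v in (0.70, 0.75, 0.74, 0.80)])
# [True, True, False, True]
```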
- class flax_extra.checkpoint.LowestCheckpointFileReader(dir: str, target: Callable[[...], Any], metric: str, group: str = 'eval')
Bases: flax_extra.checkpoint._checkpoint_file_reader.CheckpointFileReader
A reader for the checkpoint file with the lowest metric value.
- group: str = 'eval'
the metric group label.
- metric: str
the metric label.
- class flax_extra.checkpoint.LowestCheckpointFileWriter(metric: str, group: str = 'eval', stdout: bool = True, **kwds: Any)
Bases: flax_extra.checkpoint._best_checkpoint_file_writer.BestCheckpointFileWriter
A checkpoint writer that writes a checkpoint each time a new lowest value of the metric is observed.
- is_best(metric: float) → bool
Determines whether the metric value is the best.
- Parameters
metric – a new value to compare with the current best value.
- Returns
True if the new value is the best value, False otherwise.
- output_dir: str
a directory path for checkpoint files.
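The writer presumably looks up the tracked value first by its group label (default 'eval') and then by its metric label. This matches the nested Mapping[str, Mapping[str, float]] layout of Summary.metrics. The sketch below simulates which evaluation steps would trigger a write for a lowest-metric writer. The function name `steps_to_write`, the metric name "loss", and the sample values are hypothetical.

```python
from typing import Iterable, List, Mapping, Tuple


def steps_to_write(
    summaries: Iterable[Tuple[int, Mapping[str, Mapping[str, float]]]],
    metric: str,
    group: str = "eval",
) -> List[int]:
    """Return the steps at which a new lowest metrics[group][metric] appears."""
    best = float("inf")
    written = []
    for step, metrics in summaries:
        value = metrics[group][metric]
        if value < best:  # strictly lower than every value seen so far
            best = value
            written.append(step)
    return written


history = [
    (100, {"train": {"loss": 0.90}, "eval": {"loss": 0.85}}),
    (200, {"train": {"loss": 0.70}, "eval": {"loss": 0.72}}),
    (300, {"train": {"loss": 0.60}, "eval": {"loss": 0.75}}),
    (400, {"train": {"loss": 0.50}, "eval": {"loss": 0.64}}),
]
print(steps_to_write(history, metric="loss"))  # [100, 200, 400]
```

Passing group="train" instead would select the training loss, which decreases at every step in this sample, so all four steps would trigger a write.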
- class flax_extra.checkpoint.Summary(model_params: flax.core.frozen_dict.FrozenDict[Any, Any], model_state: flax.core.frozen_dict.FrozenDict[Any, Any], optimizer_state: NamedTuple, grads: flax.core.frozen_dict.FrozenDict[Any, Any], loss: float, n_completed_steps: int, elapsed_time: float, step: int, metrics: Mapping[str, Mapping[str, float]])
Bases: flax_extra.checkpoint._checkpoint.Checkpoint
A checkpoint with associated metrics.
- metrics: Mapping[str, Mapping[str, float]]