flax_extra.checkpoint package

The checkpoint representation and file operations.

class flax_extra.checkpoint.Checkpoint(model_params: flax.core.frozen_dict.FrozenDict[Any, Any], model_state: flax.core.frozen_dict.FrozenDict[Any, Any], optimizer_state: NamedTuple, grads: flax.core.frozen_dict.FrozenDict[Any, Any], loss: float, n_completed_steps: int, elapsed_time: float, step: int)[source]

Bases: object

A representation of the training state at a particular recurrent step.

elapsed_time: float

the time elapsed between the previous checkpoint and this one.

grads: flax.core.frozen_dict.FrozenDict[Any, Any]

gradients at the checkpoint.

loss: float

a loss at the checkpoint.

model_params: flax.core.frozen_dict.FrozenDict[Any, Any]

parameters of the model at the checkpoint.

model_state: flax.core.frozen_dict.FrozenDict[Any, Any]

a state of the model at the checkpoint.

n_completed_steps: int

the number of steps completed between the previous checkpoint and this one.

optimizer_state: NamedTuple

a state of the optimizer at the checkpoint.

step: int

the step number at which this checkpoint occurred.

class flax_extra.checkpoint.CheckpointFile(model_params: flax.core.frozen_dict.FrozenDict[Any, Any], model_state: flax.core.frozen_dict.FrozenDict[Any, Any], optimizer_state: NamedTuple, step: int, metrics: Mapping[str, float] = <factory>, version: int = 0)[source]

Bases: flax.struct.PyTreeNode

The file format used to store a checkpoint on the local file system.

metrics: Mapping[str, float]

metrics at the current checkpoint.

model_params: flax.core.frozen_dict.FrozenDict[Any, Any]

parameters of the model at the checkpoint.

model_state: flax.core.frozen_dict.FrozenDict[Any, Any]

a state of the model at the checkpoint.

optimizer_state: NamedTuple

a state of the optimizer at the checkpoint.

replace(**updates)

Returns a new object replacing the specified fields with new values.

step: int

the step number at which this checkpoint occurred.

version: int = 0

the file format version.
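
Because CheckpointFile derives from flax.struct.PyTreeNode, its instances are immutable dataclasses, so replace returns a modified copy instead of mutating the object. The sketch below imitates that behavior with a standard-library frozen dataclass. CheckpointFileSketch is a hypothetical stand-in with simplified field types, not the library class.

```python
import dataclasses
from typing import Any, Mapping


@dataclasses.dataclass(frozen=True)
class CheckpointFileSketch:
    """Hypothetical stand-in for CheckpointFile with plain-dict fields."""

    model_params: Mapping[str, Any]
    model_state: Mapping[str, Any]
    optimizer_state: Any
    step: int
    metrics: Mapping[str, float] = dataclasses.field(default_factory=dict)
    version: int = 0

    def replace(self, **updates: Any) -> "CheckpointFileSketch":
        # Build a new object; the original instance stays untouched.
        return dataclasses.replace(self, **updates)


initial = CheckpointFileSketch(
    model_params={"w": 0.0}, model_state={}, optimizer_state=(), step=0
)
later = initial.replace(step=100, metrics={"loss": 0.25})
```

Fields that are not named in the call, such as model_params and version here, are carried over unchanged from the original object.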

class flax_extra.checkpoint.CheckpointFileReader(dir: str, target: Callable[[...], Any])[source]

Bases: object

A reader for the checkpoint file format.

dir: str

a directory path for checkpoint files.

read(initializer: Optional[flax_extra.checkpoint._checkpoint_file.CheckpointFile], prefix: Optional[str] = None) Union[flax_extra.checkpoint._checkpoint_file.CheckpointFile, Mapping[str, Any]][source]

Reads the latest checkpoint from the file system.

If the checkpoint file doesn’t exist, the given initial checkpoint will be returned.

If an initial checkpoint isn’t given, the loaded checkpoint will be of the Mapping[str, Any] type.

Parameters

initializer – an initial checkpoint at step 0. It is required to restore type information.

Returns

either the loaded checkpoint file or initial checkpoint.

Raises

TypeError – if the initializer is not of the Checkpoint type.
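
The fallback described above (return the latest stored checkpoint, or the initial one when nothing has been stored yet) can be sketched with the standard library alone. Everything here is an assumption made for illustration: the read_latest helper, the checkpoint_<step>.json file naming, and the JSON serialization are hypothetical and do not reflect how the library stores or names its files.

```python
import json
import os
import re
import tempfile
from typing import Any, Mapping, Optional


def read_latest(
    directory: str,
    initializer: Optional[Mapping[str, Any]],
    prefix: str = "checkpoint_",  # hypothetical file name prefix
) -> Optional[Mapping[str, Any]]:
    """Return the highest-step checkpoint in `directory`, else `initializer`."""
    pattern = re.compile(re.escape(prefix) + r"(\d+)\.json$")
    steps = {}
    if os.path.isdir(directory):
        for name in os.listdir(directory):
            match = pattern.match(name)
            if match:
                steps[int(match.group(1))] = name
    if not steps:
        # No checkpoint file yet: fall back to the initial checkpoint.
        return initializer
    with open(os.path.join(directory, steps[max(steps)])) as file:
        return json.load(file)


with tempfile.TemporaryDirectory() as tmp:
    initial = {"step": 0}
    before = read_latest(tmp, initial)  # empty directory
    for step in (10, 20):
        with open(os.path.join(tmp, f"checkpoint_{step}.json"), "w") as file:
            json.dump({"step": step}, file)
    after = read_latest(tmp, initial)  # picks the file with the highest step
```

Unlike the documented reader, this sketch always returns a plain mapping and does not restore type information from the initializer.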

target: Callable[[...], Any]

init function of the model related to the checkpoint. It is used to restore type information.

class flax_extra.checkpoint.CheckpointFileWriter(output_dir: str, keep: int = 1, overwrite: bool = True)[source]

Bases: object

A writer for the checkpoint file format.

keep: int = 1

a number of checkpoints to keep.

output_dir: str

a directory path for checkpoint files.

overwrite: bool = True

whether to overwrite previous checkpoints.

write(checkpoint: Union[flax_extra.checkpoint._checkpoint.Checkpoint, flax_extra.checkpoint._summary.Summary], prefix: Optional[str] = None) Union[flax_extra.checkpoint._checkpoint.Checkpoint, flax_extra.checkpoint._summary.Summary][source]

Writes a checkpoint to the file system.

Parameters
  • checkpoint – a checkpoint to write.

  • prefix – a string that will be added to the output file name.

Returns

the original checkpoint.

Raises

TypeError – if the checkpoint is of neither the Checkpoint nor the Summary type.
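
To illustrate how a keep limit and a file name prefix might interact, here is a minimal standard-library sketch of a write followed by pruning. The write_and_prune helper, the <prefix><step>.json naming, and the JSON serialization are hypothetical, and the overwrite option is not modeled. Only the documented contract is mirrored: the call returns the checkpoint it was given.

```python
import json
import os
import re
import tempfile
from typing import Any, Mapping


def write_and_prune(
    output_dir: str,
    checkpoint: Mapping[str, Any],
    keep: int = 1,
    prefix: str = "checkpoint_",  # hypothetical file name prefix
) -> Mapping[str, Any]:
    """Write `checkpoint`, keep the `keep` newest files, return the input."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, f"{prefix}{checkpoint['step']}.json")
    with open(path, "w") as file:
        json.dump(checkpoint, file)

    pattern = re.compile(re.escape(prefix) + r"(\d+)\.json$")
    found = []
    for name in os.listdir(output_dir):
        match = pattern.match(name)
        if match:
            found.append((int(match.group(1)), name))
    found.sort()  # ascending by step number
    # Remove everything except the `keep` highest steps.
    for _, name in found[: max(len(found) - keep, 0)]:
        os.remove(os.path.join(output_dir, name))
    return checkpoint


with tempfile.TemporaryDirectory() as tmp:
    for step in (1, 2, 3):
        returned = write_and_prune(tmp, {"step": step, "loss": 1.0 / step}, keep=2)
    remaining = sorted(os.listdir(tmp))
```

Returning the input unchanged lets a writer sit in the middle of a processing chain, so the same checkpoint can be passed on to the next consumer.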

class flax_extra.checkpoint.HighestCheckpointFileReader(dir: str, target: Callable[[...], Any], metric: str, group: str = 'eval')[source]

Bases: flax_extra.checkpoint._checkpoint_file_reader.CheckpointFileReader

A reader of the checkpoint file with the highest metric value.

group: str = 'eval'

a group label.

metric: str

a metric label.
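
The group and metric labels presumably select a single value from the nested metrics of a Summary (Mapping[str, Mapping[str, float]]). The sketch below assumes that the group is the outer key and the metric is the inner key. This is an interpretation of the signatures, not something the reference states, and select_metric and highest are hypothetical helpers.

```python
from typing import Mapping

Metrics = Mapping[str, Mapping[str, float]]


def select_metric(metrics: Metrics, metric: str, group: str = "eval") -> float:
    """Look up one value, for example metrics["eval"]["accuracy"]."""
    return metrics[group][metric]


def highest(summaries: Mapping[int, Metrics], metric: str, group: str = "eval") -> int:
    """Return the step whose selected metric value is the highest."""
    return max(
        summaries,
        key=lambda step: select_metric(summaries[step], metric, group),
    )


# Hypothetical metrics recorded at three steps.
summaries = {
    100: {"train": {"accuracy": 0.90}, "eval": {"accuracy": 0.71}},
    200: {"train": {"accuracy": 0.95}, "eval": {"accuracy": 0.78}},
    300: {"train": {"accuracy": 0.99}, "eval": {"accuracy": 0.74}},
}
best_eval_step = highest(summaries, metric="accuracy")
best_train_step = highest(summaries, metric="accuracy", group="train")
```

With the default group the evaluation accuracy decides, so the selected step can differ from the one with the best training accuracy.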

class flax_extra.checkpoint.HighestCheckpointFileWriter(metric: str, group: str = 'eval', stdout: bool = True, **kwds: Any)[source]

Bases: flax_extra.checkpoint._best_checkpoint_file_writer.BestCheckpointFileWriter

A checkpoint writer that writes a checkpoint each time the highest metric value is observed.

default_metric_value() float[source]

Returns the initial value for the best metric.

is_best(metric: float) bool[source]

Determines whether the metric value is the best.

Parameters

metric – a new value to compare the current best value with.

Returns

True if the new value is the best value, False otherwise.

output_dir: str

a directory path for checkpoint files.

class flax_extra.checkpoint.LowestCheckpointFileReader(dir: str, target: Callable[[...], Any], metric: str, group: str = 'eval')[source]

Bases: flax_extra.checkpoint._checkpoint_file_reader.CheckpointFileReader

A reader of the checkpoint file with the lowest metric value.

group: str = 'eval'

a group label.

metric: str

a metric label.

class flax_extra.checkpoint.LowestCheckpointFileWriter(metric: str, group: str = 'eval', stdout: bool = True, **kwds: Any)[source]

Bases: flax_extra.checkpoint._best_checkpoint_file_writer.BestCheckpointFileWriter

A checkpoint writer that writes a checkpoint each time the lowest metric value is observed.

default_metric_value() float[source]

Returns the initial value for the best metric.

is_best(metric: float) bool[source]

Determines whether the metric value is the best.

Parameters

metric – a new value to compare the current best value with.

Returns

True if the new value is the best value, False otherwise.

output_dir: str

a directory path for checkpoint files.
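
The default_metric_value and is_best pair of the highest and lowest writers above suggests a template-method design: a shared base tracks the best value seen so far, and each subclass defines the starting value and the comparison. The following sketch shows one plausible shape under stated assumptions. The tracker classes, the observe method, the infinite starting values, and the strict comparisons are all hypothetical and not taken from the library source.

```python
import math


class BestTracker:
    """Hypothetical base: subclasses define what "best" means."""

    def __init__(self) -> None:
        self.best = self.default_metric_value()

    def default_metric_value(self) -> float:
        raise NotImplementedError

    def is_best(self, metric: float) -> bool:
        raise NotImplementedError

    def observe(self, metric: float) -> bool:
        """Record `metric`; return True when a checkpoint would be written."""
        if self.is_best(metric):
            self.best = metric
            return True
        return False


class HighestTracker(BestTracker):
    def default_metric_value(self) -> float:
        return -math.inf  # assumption: any finite value beats the start

    def is_best(self, metric: float) -> bool:
        return metric > self.best  # assumption: ties do not count


class LowestTracker(BestTracker):
    def default_metric_value(self) -> float:
        return math.inf

    def is_best(self, metric: float) -> bool:
        return metric < self.best


accuracy = HighestTracker()
written_high = [accuracy.observe(value) for value in (0.6, 0.8, 0.7, 0.9)]
loss = LowestTracker()
written_low = [loss.observe(value) for value in (1.2, 0.9, 1.0, 0.9)]
```

Starting from an infinite value means the first observed metric always counts as the best, so a checkpoint is written on the first evaluation.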

class flax_extra.checkpoint.Summary(model_params: flax.core.frozen_dict.FrozenDict[Any, Any], model_state: flax.core.frozen_dict.FrozenDict[Any, Any], optimizer_state: NamedTuple, grads: flax.core.frozen_dict.FrozenDict[Any, Any], loss: float, n_completed_steps: int, elapsed_time: float, step: int, metrics: Mapping[str, Mapping[str, float]])[source]

Bases: flax_extra.checkpoint._checkpoint.Checkpoint

A checkpoint with related metrics.

metrics: Mapping[str, Mapping[str, float]]

class flax_extra.checkpoint.SummaryLogger[source]

Bases: object

A writer that prints summaries to stdout.

class flax_extra.checkpoint.SummaryWriter(output_dir: str)[source]

Bases: object

A writer that persists summaries in the TensorBoard file format to the specified directory on the local file system.

output_dir: str