Source code for flax_extra.checkpoint._summary_writer
r"""Writing summaries to the Tensorboard."""
from dataclasses import dataclass
import contextlib
from flax.metrics import tensorboard
from flax_extra.checkpoint._summary import Summary
[docs]@dataclass
class SummaryWriter:
r"""The writer persists summaries in Tensorboard file format
to the specified directory on the local file system."""
output_dir: str
def __call__(self, summary: Summary) -> Summary:
with contextlib.closing(tensorboard.SummaryWriter(self.output_dir)) as writer:
step = summary.step
for group_label, group in summary.metrics.items():
for metric_label, metric_value in group.items():
writer.scalar(
f"{group_label}/{metric_label}",
metric_value,
step,
)
return summary