import logging
import os

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_only
class SaveCheckpointsCallback(pl.Callback):
    def __init__(
        self,
        model: nn.Module,
        checkpoints_dir: str,
        save_step_frequency: int,
    ):
        r"""Callback to save a checkpoint every `save_step_frequency` steps.

        Args:
            model: nn.Module, the model whose state_dict is saved
            checkpoints_dir: str, directory to save checkpoints to
            save_step_frequency: int, number of training steps between checkpoints
        """
        super().__init__()
        self.model = model
        self.checkpoints_dir = checkpoints_dir
        self.save_step_frequency = save_step_frequency

        # Create the checkpoint directory if it does not already exist.
        os.makedirs(self.checkpoints_dir, exist_ok=True)
    # ``on_batch_end`` hook (PyTorch Lightning < 1.8; newer versions use
    # ``on_train_batch_end``). ``rank_zero_only`` ensures that in distributed
    # training only the rank-0 process writes checkpoints to disk.
    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> None:
        r"""Save a checkpoint at the end of a training batch."""
        global_step = trainer.global_step

        if global_step % self.save_step_frequency == 0:
            checkpoint_path = os.path.join(
                self.checkpoints_dir, "step={}.pth".format(global_step)
            )
            checkpoint = {"step": global_step, "model": self.model.state_dict()}
            torch.save(checkpoint, checkpoint_path)
            logging.info("Save checkpoint to {}".format(checkpoint_path))
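

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original file: it assumes a toy
    # ``ToyModule`` LightningModule trained on random data, and saves a
    # checkpoint to ./checkpoints every 10 optimizer steps.
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

    module = ToyModule()
    dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    loader = DataLoader(dataset, batch_size=8)

    # Register the callback with the Trainer; it fires at the end of each batch.
    callback = SaveCheckpointsCallback(
        model=module,
        checkpoints_dir="./checkpoints",
        save_step_frequency=10,
    )
    trainer = pl.Trainer(max_epochs=1, callbacks=[callback])
    trainer.fit(module, loader)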