Spaces:
Runtime error
Runtime error
| import os | |
| import pytorch_lightning as pl | |
| import torch | |
| from typing import Any | |
| from modules.params.diffusion.inference_params import InferenceParams | |
| from modules.loader.module_loader import GenericModuleLoader | |
| from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams | |
class AbstractTrainer(pl.LightningModule):
    """Base LightningModule providing shared start-up and inference plumbing.

    Subclasses override the implementation hooks ``post_init`` and
    ``generate_output``. This class makes sure ``post_init`` runs only once per
    instance (via ``on_start``) and (re)seeds a dedicated ``torch.Generator``
    used for inference.
    """

    def __init__(self,
                 inference_params: Any,
                 diff_trainer_params: DiffusionTrainerParams,
                 module_loader: GenericModuleLoader,
                 ):
        """
        Args:
            inference_params: Inference configuration. Must expose a
                ``reset_seed_per_generation`` attribute (read in
                ``on_inference_step``); it is also passed to ``generate_output``.
            diff_trainer_params: Trainer parameters; stored and forwarded to
                ``module_loader``.
            module_loader: Callable invoked once here with ``trainer=self`` and
                ``diff_trainer_params``. Presumably it attaches the model
                sub-modules to this trainer -- verify against GenericModuleLoader.
        """
        super().__init__()
        self.inference_params = inference_params
        self.diff_trainer_params = diff_trainer_params
        self.module_loader = module_loader
        # Guards post_init so it is executed only by the first on_start call.
        self.on_start_once_called = False
        self._setup_methods = []
        module_loader(
            trainer=self,
            diff_trainer_params=diff_trainer_params)

    # ------ IMPLEMENTATION HOOKS -------

    def post_init(self, batch=None):
        '''
        Is called after LightningDataModule and LightningModule is created, but before any training/validation/prediction.
        First possible access to the 'trainer' object (e.g. to get 'device').

        The base implementation does nothing.

        Args:
            batch: Unused by the base class. It is optional because
                _on_start_once invokes this hook without arguments; a required
                parameter made the default hook raise TypeError.
        '''

    def generate_output(self, batch, batch_idx, inference_params: InferenceParams):
        '''
        Is called during validation to generate for each batch an output.
        Return the meta information about produced result (where result were stored).
        This is used for the metric evaluation.

        The base implementation does nothing and returns None; subclasses are
        expected to override it.
        '''

    # ------- HELPER FUNCTIONS -------

    def _reset_random_generator(self):
        '''
        Reset the random generator to the same seed across all workers. The generator is used only for inference.

        On first use the generator is created on ``self.device`` and seeded from
        the ``PL_GLOBAL_SEED`` environment variable (42 if unset). On later calls
        it is re-seeded with the seed it was first given, so every reset restores
        the same random stream.
        '''
        if not hasattr(self, "random_generator"):
            self.random_generator = torch.Generator(device=self.device)
            # set seed according to 'seed_everything' in config
            seed = int(os.environ.get("PL_GLOBAL_SEED", 42))
        else:
            # initial_seed() reports the seed last applied via manual_seed().
            seed = self.random_generator.initial_seed()
        self.random_generator.manual_seed(seed)

    # ----- PREDICT HOOKS ------

    def on_predict_start(self):
        self.on_start()

    def predict_step(self, batch, batch_idx):
        '''
        Run one inference step and return what ``generate_output`` produced.

        NOTE: Lightning may collect non-None predict_step outputs when
        ``return_predictions`` is enabled.
        '''
        return self.on_inference_step(batch=batch, batch_idx=batch_idx)

    def on_predict_epoch_start(self):
        self.on_inference_epoch_start()

    # ----- CUSTOM HOOKS -----
    # Global Hooks (Called by Training, Validation and Prediction)

    def _on_start_once(self):
        '''
        Will be called only once by on_start. Thus, it will be called by the first call of train,validation or prediction.
        '''
        if self.on_start_once_called:
            return
        self.on_start_once_called = True
        self.post_init()

    def on_start(self):
        '''
        Called at the beginning of training, validation and prediction.

        In this class it is invoked from on_predict_start; training/validation
        callers are presumably defined in subclasses -- verify.
        '''
        self._on_start_once()

    # ----- Inference Hooks (called by 'validation' and 'predict') ------

    def on_inference_epoch_start(self):
        # reset seed at every inference
        self._reset_random_generator()

    def on_inference_step(self, batch, batch_idx):
        '''
        Generate the output for one batch, optionally re-seeding first.

        Returns:
            Whatever ``generate_output`` returns (documented as meta information
            about the produced result).
        '''
        # Re-seeding per batch makes each generation independent of the
        # batches processed before it.
        if self.inference_params.reset_seed_per_generation:
            self._reset_random_generator()
        return self.generate_output(
            batch=batch, inference_params=self.inference_params, batch_idx=batch_idx)