import time
import datetime

import torch
import tqdm
from torch.nn import DataParallel

from lib.config import cfg


class Trainer(object):
    def __init__(self, network):
        # move the network to this process's GPU; cfg.local_rank is set by
        # the distributed launcher (one process per GPU)
        device = torch.device('cuda:{}'.format(cfg.local_rank))
        network = network.to(device)
        if cfg.distributed:
            # wrap the network so gradients are synchronized across processes
            network = torch.nn.parallel.DistributedDataParallel(
                network,
                device_ids=[cfg.local_rank],
                output_device=cfg.local_rank
            )
        self.network = network
        self.local_rank = cfg.local_rank
        self.device = device

    def reduce_loss_stats(self, loss_stats):
        # average each loss tensor down to a scalar for logging
        reduced_losses = {k: torch.mean(v) for k, v in loss_stats.items()}
        return reduced_losses
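    # Illustrative example (hypothetical values): under DataParallel each loss
    # arrives as a per-GPU vector, e.g. {'vote_loss': tensor([0.52, 0.48])};
    # reduce_loss_stats collapses it to {'vote_loss': tensor(0.50)} so the
    # recorder can log plain scalars.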

    def to_cuda(self, batch):
        # move every tensor in the batch to the training device; the 'meta'
        # entry holds non-tensor bookkeeping and stays on the CPU
        for k in batch:
            if k == 'meta':
                continue
            if isinstance(batch[k], (tuple, list)):
                batch[k] = [b.to(self.device) for b in batch[k]]
            else:
                batch[k] = batch[k].to(self.device)
        return batch

    def train(self, epoch, data_loader, optimizer, recorder):
        max_iter = len(data_loader)
        self.network.train()
        end = time.time()
        for iteration, batch in enumerate(data_loader):
            data_time = time.time() - end
            iteration = iteration + 1

            batch = self.to_cuda(batch)
            output, loss, loss_stats, image_stats = self.network(batch)

            # training stage: loss; optimizer; scheduler
            optimizer.zero_grad()
            loss = loss.mean()
            loss.backward()
            # clip gradients element-wise to stabilize training
            torch.nn.utils.clip_grad_value_(self.network.parameters(), 40)
            optimizer.step()

            # only the rank-0 process logs and records
            if cfg.local_rank > 0:
                continue

            # data recording stage: loss_stats, time, image_stats
            recorder.step += 1

            loss_stats = self.reduce_loss_stats(loss_stats)
            recorder.update_loss_stats(loss_stats)

            batch_time = time.time() - end
            end = time.time()
            recorder.batch_time.update(batch_time)
            recorder.data_time.update(data_time)

            if iteration % cfg.log_interval == 0 or iteration == (max_iter - 1):
                # print training state
                eta_seconds = recorder.batch_time.global_avg * (max_iter - iteration)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                lr = optimizer.param_groups[0]['lr']
                memory = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0

                training_state = ' '.join(['eta: {}', '{}', 'lr: {:.6f}', 'max_mem: {:.0f}'])
                training_state = training_state.format(eta_string, str(recorder), lr, memory)
                print(training_state)

            if iteration % cfg.record_interval == 0 or iteration == (max_iter - 1):
                # record loss_stats and image_stats
                recorder.update_image_stats(image_stats)
                recorder.record('train')

    def val(self, epoch, data_loader, evaluator=None, recorder=None):
        self.network.eval()
        torch.cuda.empty_cache()
        val_loss_stats = {}
        data_size = len(data_loader)
        for batch in tqdm.tqdm(data_loader):
            batch = self.to_cuda(batch)
            with torch.no_grad():
                output, loss, loss_stats, image_stats = self.network(batch)
                if evaluator is not None:
                    evaluator.evaluate(output, batch)

            # accumulate per-batch losses so they can be averaged below
            loss_stats = self.reduce_loss_stats(loss_stats)
            for k, v in loss_stats.items():
                val_loss_stats.setdefault(k, 0)
                val_loss_stats[k] += v

        loss_state = []
        for k in val_loss_stats.keys():
            val_loss_stats[k] /= data_size
            loss_state.append('{}: {:.4f}'.format(k, val_loss_stats[k]))
        print(' '.join(loss_state))

        if evaluator is not None:
            result = evaluator.summarize()
            val_loss_stats.update(result)

        if recorder:
            recorder.record('val', epoch, val_loss_stats, image_stats)
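

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way a Trainer like this is typically
# driven. The factory helpers referenced below -- make_network,
# make_data_loader, make_optimizer, make_recorder, make_evaluator -- and
# num_epochs are hypothetical stand-ins for whatever the surrounding project
# provides; only Trainer is defined in this module.
#
#   network = make_network(cfg)
#   trainer = Trainer(network)
#   optimizer = make_optimizer(cfg, network)
#   recorder = make_recorder(cfg)
#   train_loader = make_data_loader(cfg, is_train=True)
#   val_loader = make_data_loader(cfg, is_train=False)
#   evaluator = make_evaluator(cfg)
#
#   for epoch in range(num_epochs):
#       trainer.train(epoch, train_loader, optimizer, recorder)
#       trainer.val(epoch, val_loader, evaluator, recorder)
# ---------------------------------------------------------------------------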