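"""Training, evaluation, and inference entry point for a DETR-style SMPL-X
detector (the selected config is copied to config/aios_smplx.py below).

The script follows the usual DETR training-script layout: it builds the model
from a registry, optionally wraps it in MMDistributedDataParallel, supports
EMA weights, resuming, and loading a pretrained checkpoint with ignored keys,
and then runs either evaluation / inference or the full training loop.
"""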
import argparse
import datetime
import json
import random
import time
from pathlib import Path
import os, sys

from util.get_param_dicts import get_param_dict
from util.logger import setup_logger

import numpy as np
import torch
import util.misc as utils
from detrsmpl.data.datasets import build_dataloader
from mmcv.parallel import MMDistributedDataParallel
from engine import evaluate, train_one_epoch, inference
from util.config import DictAction
from util.utils import ModelEma
import shutil
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import config.config as cfg
from datasets.dataset import MultipleDatasets
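# The distributed flags below (--world_size, --rank, --local_rank, --dist_url)
# are consumed by utils.init_distributed_mode().  A hypothetical single-node
# launch (script name and paths are placeholders, not taken from this repo)
# might look like:
#
#   python -m torch.distributed.launch --nproc_per_node=4 <this_script>.py \
#       -c <config.py> --output_dir <output_dir>
#
# Whether torchrun or torch.distributed.launch is expected depends on how
# init_distributed_mode() reads its environment variables.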
def get_args_parser():
    parser = argparse.ArgumentParser('Set transformer detector',
                                     add_help=False)
    parser.add_argument('--config_file', '-c', type=str, required=True)
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file.')
    # parser.add_argument('--exp_name', default='data/log/smplx_test', type=str)
    # dataset parameters
    # training parameters
    parser.add_argument('--output_dir',
                        default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device',
                        default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--pretrain_model_path',
                        help='load from other checkpoint')
    parser.add_argument('--finetune_ignore', type=str, nargs='+')
    parser.add_argument('--start_epoch',
                        default=0,
                        type=int,
                        metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--find_unused_params', action='store_true')
    parser.add_argument('--save_log', action='store_true')
    parser.add_argument('--to_vid', action='store_true')
    parser.add_argument('--inference', action='store_true')

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='global rank of the current process')
    parser.add_argument('--local_rank', default=0, type=int,
                        help='local rank for DistributedDataParallel')
    parser.add_argument('--amp', action='store_true',
                        help='Train with mixed precision')
    parser.add_argument('--inference_input', default=None, type=str)
    return parser
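# build_model_main() resolves the model name from the merged config: the build
# function registered under args.modelname in MODULE_BUILD_FUNCS returns the
# model together with its criterion and the two postprocessor sets.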
def build_model_main(args, cfg):
    print(args.modelname)
    from models.registry import MODULE_BUILD_FUNCS
    assert args.modelname in MODULE_BUILD_FUNCS._module_dict
    build_func = MODULE_BUILD_FUNCS.get(args.modelname)
    model, criterion, postprocessors, postprocessors_aios = build_func(
        args, cfg)
    return model, criterion, postprocessors, postprocessors_aios
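# main() proceeds in stages: distributed init -> config merged onto args ->
# logging setup -> model / EMA / DDP construction -> optimizer and LR scheduler
# -> data loaders -> optional resume or pretrained-checkpoint load -> either
# evaluation / inference (--eval) or the epoch training loop.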
def main(args):
    utils.init_distributed_mode(args)
    print('Loading config file from {}'.format(args.config_file))
    shutil.copy2(args.config_file, 'config/aios_smplx.py')
    from config.config import cfg

    if args.options is not None:
        cfg.merge_from_dict(args.options)
    if args.rank == 0:
        save_cfg_path = os.path.join(args.output_dir, 'config_cfg.py')
        cfg.dump(save_cfg_path)
        save_json_path = os.path.join(args.output_dir, 'config_args_raw.json')
        with open(save_json_path, 'w') as f:
            json.dump(vars(args), f, indent=2)

    # Copy every config key onto args; keys that already exist as command-line
    # arguments are skipped (the `continue` makes the ValueError unreachable).
    cfg_dict = cfg._cfg_dict.to_dict()
    args_vars = vars(args)
    for k, v in cfg_dict.items():
        if k not in args_vars:
            setattr(args, k, v)
        else:
            continue
            raise ValueError('Key {} can be used by args only'.format(k))
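    # Example of a command-line override (key names are illustrative only):
    #   --options lr=1e-4 epochs=60
    # DictAction parses each xxx=yyy pair and cfg.merge_from_dict() applies it
    # before the config keys are copied onto `args` above.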
    # update some new args temporarily
    if not getattr(args, 'use_ema', None):
        args.use_ema = False
    if not getattr(args, 'debug', None):
        args.debug = False

    # setup logger
    os.makedirs(args.output_dir, exist_ok=True)
    logger = setup_logger(output=os.path.join(args.output_dir, 'info.txt'),
                          distributed_rank=args.rank,
                          color=False,
                          name='detr')
    logger.info('git:\n {}\n'.format(utils.get_sha()))
    logger.info('Command: ' + ' '.join(sys.argv))
    writer = None
    if args.rank == 0:
        writer = SummaryWriter(args.output_dir)
        save_json_path = os.path.join(args.output_dir, 'config_args_all.json')
        # print("args:", vars(args))
        with open(save_json_path, 'w') as f:
            json.dump(vars(args), f, indent=2)
        logger.info('Full config saved to {}'.format(save_json_path))
    logger.info('world size: {}'.format(args.world_size))
    logger.info('rank: {}'.format(args.rank))
    logger.info('local_rank: {}'.format(args.local_rank))
    logger.info('args: ' + str(args) + '\n')
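    # At this point the run directory holds info.txt (the log), config_cfg.py,
    # and config_args_raw.json / config_args_all.json (merged config and args);
    # log.txt and checkpoint*.pth are written later during training.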
    if args.frozen_weights is not None:
        assert args.masks, 'Frozen training is meant for segmentation only'
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # build model
    model, criterion, postprocessors, _ = build_model_main(args, cfg)
    wo_class_error = False
    model.to(device)

    # ema
    if args.use_ema:
        ema_m = ModelEma(model, args.ema_decay)
    else:
        ema_m = None

    model_without_ddp = model
    if args.distributed:
        model = MMDistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=args.find_unused_params)
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info('number of params: ' + str(n_parameters))
    logger.info('params:\n' + json.dumps(
        {n: p.numel()
         for n, p in model.named_parameters() if p.requires_grad},
        indent=2))

    param_dicts = get_param_dict(args, model_without_ddp)
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
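    # Datasets are resolved dynamically: each entry of cfg.trainset_partition
    # (and cfg.testset) names a module in datasets/ whose class has the same
    # name, imported via exec() and instantiated via eval().  The training sets
    # are combined with MultipleDatasets using cfg.trainset_partition.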
    logger.info('Creating dataset...')
    if not args.eval:
        trainset = []
        for trainset_i, v in cfg.trainset_partition.items():
            exec('from datasets.' + trainset_i + ' import ' + trainset_i)
            trainset.append(
                eval(trainset_i)(transforms.ToTensor(), 'train'))
        trainset_loader = MultipleDatasets(trainset,
                                           make_same_len=False,
                                           partition=cfg.trainset_partition)
        data_loader_train = build_dataloader(
            trainset_loader,
            args.batch_size,
            0 if 'workers_per_gpu' in args else 1,
            dist=args.distributed)
    exec('from datasets.' + cfg.testset + ' import ' + cfg.testset)
    if not args.inference:
        dataset_val = eval(cfg.testset)(transforms.ToTensor(), 'test')
    else:
        dataset_val = eval(cfg.testset)(args.inference_input, args.output_dir)
    data_loader_val = build_dataloader(
        dataset_val,
        args.batch_size,
        0 if 'workers_per_gpu' in args else 2,
        dist=args.distributed,
        shuffle=False)
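    # LR schedule: OneCycleLR (with 20% of the cycle spent warming up) when
    # args.onecyclelr is set, MultiStepLR with cfg-provided milestones when
    # args.multi_step_lr is set, otherwise a plain StepLR with period lr_drop.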
    if args.onecyclelr:
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.lr,
            steps_per_epoch=len(data_loader_train),
            epochs=args.epochs,
            pct_start=0.2)
    elif args.multi_step_lr:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_drop_list)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])
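    # Resume logic: an existing <output_dir>/checkpoint.pth always takes
    # precedence over --resume; a URL is fetched via torch.hub, anything else
    # is loaded from disk.  Optimizer, LR scheduler, and epoch counter are only
    # restored when not running with --eval.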
    output_dir = Path(args.output_dir)
    if os.path.exists(os.path.join(args.output_dir, 'checkpoint.pth')):
        args.resume = os.path.join(args.output_dir, 'checkpoint.pth')
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if args.use_ema:
            if 'ema_model' in checkpoint:
                ema_m.module.load_state_dict(
                    utils.clean_state_dict(checkpoint['ema_model']))
            else:
                del ema_m
                ema_m = ModelEma(model, args.ema_decay)
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
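    # When not resuming, --pretrain_model_path loads a checkpoint with
    # strict=False after dropping every key whose name contains one of the
    # --finetune_ignore substrings.  Example (key names are hypothetical):
    #   --pretrain_model_path <ckpt.pth> --finetune_ignore class_embed label_enc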
    if (not args.resume) and args.pretrain_model_path:
        checkpoint = torch.load(args.pretrain_model_path,
                                map_location='cpu')['model']
        from collections import OrderedDict
        _ignorekeywordlist = args.finetune_ignore if args.finetune_ignore else []
        ignorelist = []

        def check_keep(keyname, ignorekeywordlist):
            for keyword in ignorekeywordlist:
                if keyword in keyname:
                    ignorelist.append(keyname)
                    return False
            return True

        _tmp_st = OrderedDict({
            k: v
            for k, v in utils.clean_state_dict(checkpoint).items()
            if check_keep(k, _ignorekeywordlist)
        })
        logger.info('Ignore keys: {}'.format(json.dumps(ignorelist, indent=2)))

        # load the filtered state dict non-strictly and log missing/unexpected keys
        _load_output = model_without_ddp.load_state_dict(_tmp_st, strict=False)
        print('loading')
        logger.info(str(_load_output))

        if args.use_ema:
            if 'ema_model' in checkpoint:
                ema_m.module.load_state_dict(
                    utils.clean_state_dict(checkpoint['ema_model']))
            else:
                del ema_m
                ema_m = ModelEma(model, args.ema_decay)
        _load_output = model_without_ddp.load_state_dict(_tmp_st, strict=False)
        logger.info(str(_load_output))
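    # --eval short-circuits training: with --inference and --inference_input it
    # runs inference() on the raw inputs, otherwise it runs evaluate() on the
    # test split configured in cfg.testset, then returns.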
    if args.eval:
        os.environ['EVAL_FLAG'] = 'TRUE'
        if args.inference_input is not None and args.inference:
            inference(model,
                      criterion,
                      postprocessors,
                      data_loader_val,
                      device,
                      args.output_dir,
                      wo_class_error=wo_class_error,
                      args=args)
        else:
            from config.config import cfg
            cfg.result_dir = args.output_dir
            cfg.exp_name = args.pretrain_model_path
            evaluate(model,
                     criterion,
                     postprocessors,
                     data_loader_val,
                     device,
                     args.output_dir,
                     wo_class_error=wo_class_error,
                     args=args)
        return
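    # Training loop: one train_one_epoch() call per epoch; the scheduler is
    # stepped here once per epoch unless OneCycleLR is used (it is passed to
    # train_one_epoch instead), checkpoint.pth is rewritten every epoch with
    # numbered snapshots at lr_drop / save_checkpoint_interval boundaries, and
    # a JSON stats line is appended to <output_dir>/log.txt.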
    print('Start training')
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_start_time = time.time()
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            args.clip_max_norm,
            wo_class_error=wo_class_error,
            lr_scheduler=lr_scheduler,
            args=args,
            logger=(logger if args.save_log else None),
            ema_m=ema_m,
            tf_writer=writer)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']

        if not args.onecyclelr:
            lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every save_checkpoint_interval epochs
            if (epoch + 1) % args.lr_drop == 0 or (
                    epoch + 1) % args.save_checkpoint_interval == 0:
                checkpoint_paths.append(output_dir /
                                        f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                weights = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }
                if args.use_ema:
                    weights.update({
                        'ema_model': ema_m.module.state_dict(),
                    })
                utils.save_on_master(weights, checkpoint_path)

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
        }
        ep_paras = {'epoch': epoch, 'n_parameters': n_parameters}
        log_stats.update(ep_paras)
        try:
            log_stats.update({'now_time': str(datetime.datetime.now())})
        except Exception:
            pass

        epoch_time = time.time() - epoch_start_time
        epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
        log_stats['epoch_time'] = epoch_time_str
        if args.output_dir and utils.is_main_process():
            with (output_dir / 'log.txt').open('a') as f:
                f.write(json.dumps(log_stats) + '\n')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
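# Note: the __spec__ assignment below appears to be a workaround for debugger /
# notebook environments where __spec__ is unset and multiprocessing complains.
# Example single-GPU evaluation run (script name and paths are placeholders):
#   python <this_script>.py -c <config.py> --eval \
#       --pretrain_model_path <checkpoint.pth> --output_dir <output_dir>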
if __name__ == '__main__':
    parser = argparse.ArgumentParser('DETR training and evaluation script',
                                     parents=[get_args_parser()])
    __spec__ = "ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>)"
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)