import argparse
import json
import logging
import os
import random
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader

from utils.tools import ConfigWrapper
from dataset.dataset import SVCDataset
from modules.FastSVC import SVCNN
from modules.discriminator import MelGANMultiScaleDiscriminator
from optimizers.scheduler import StepLRScheduler
from loss.adversarial_loss import GeneratorAdversarialLoss, DiscriminatorAdversarialLoss
from loss.stft_loss import MultiResolutionSTFTLoss
from trainer import Trainer
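

# Example invocations (a sketch; the script name and exact paths depend on your
# checkout and data layout):
#   single GPU:
#     python train.py --data_root dump --config config.json --cp_path exp/fastsvc
#   multi-GPU with torchrun, which supplies the LOCAL_RANK variable read below:
#     torchrun --nproc_per_node=4 train.py --data_root dump --config config.json --cp_path exp/fastsvc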
def main():
    """Run training process."""
    parser = argparse.ArgumentParser(description="Train the FastSVC model.")
    parser.add_argument(
        "--data_root",
        type=str,
        required=True,
        help="dataset root path.",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="configuration file path.",
    )
    parser.add_argument(
        "--cp_path",
        type=str,
        required=True,
        help="checkpoint directory path.",
    )
    parser.add_argument(
        "--pretrain",
        type=str,
        default="",
        help='checkpoint file path to load pretrained params. (default="")',
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="whether to resume training from a checkpoint.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="random seed.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    local_rank = 0
    args.distributed = False
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        # effective when using fixed-size inputs
        # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        torch.backends.cudnn.benchmark = True
        # setup for distributed training
        # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
        args.world_size = torch.cuda.device_count()
        args.distributed = args.world_size > 1
        if args.distributed:
            local_rank = int(os.environ["LOCAL_RANK"])
            torch.cuda.set_device(local_rank)
            print(f"Using {args.world_size} GPUs for training.")
            torch.distributed.init_process_group(backend="nccl")

    # random seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # suppress logging for distributed training
    if local_rank != 0:
        sys.stdout = open(os.devnull, "w")

    # set logger
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG, stream=sys.stdout, format=log_format)
    elif args.verbose > 0:
        logging.basicConfig(level=logging.INFO, stream=sys.stdout, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, stream=sys.stdout, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # load config
    with open(args.config) as f:
        config = ConfigWrapper(**json.load(f))
    config.training_config.rank = local_rank
    config.training_config.distributed = args.distributed
    config.interval_config.out_dir = args.cp_path

    # get dataset
    train_set = SVCDataset(
        args.data_root,
        config.data_config.n_samples,
        config.data_config.sampling_rate,
        config.data_config.hop_size,
        "train",
    )
    valid_set = SVCDataset(
        args.data_root,
        config.data_config.n_samples,
        config.data_config.sampling_rate,
        config.data_config.hop_size,
        "valid",
    )
    dataset = {
        "train": train_set,
        "dev": valid_set,
    }

    # get data loader
    sampler = {"train": None, "dev": None}
    if args.distributed:
        # setup sampler for distributed training
        from torch.utils.data.distributed import DistributedSampler

        sampler["train"] = DistributedSampler(
            dataset=dataset["train"],
            num_replicas=args.world_size,
            rank=local_rank,
            shuffle=True,
        )
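    # DataLoader forbids shuffle=True together with a custom sampler, so shuffling
    # is delegated to DistributedSampler in the distributed case.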
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=not args.distributed,
            batch_size=config.data_config.batch_size,
            num_workers=config.data_config.num_workers,
            sampler=sampler["train"],
            pin_memory=config.data_config.pin_memory,
            drop_last=True,
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False,
            batch_size=config.data_config.batch_size,
            num_workers=config.data_config.num_workers,
            sampler=sampler["dev"],
            pin_memory=config.data_config.pin_memory,
        ),
    }

    # define models
    svc_mdl = SVCNN(config).to(device)
    discriminator = MelGANMultiScaleDiscriminator().to(device)
    model = {
        "generator": svc_mdl,
        "discriminator": discriminator,
    }

    # define criterions
    criterion = {
        "gen_adv": GeneratorAdversarialLoss(
            # keep compatibility
            **config.loss_config.generator_adv_loss_params
        ).to(device),
        "dis_adv": DiscriminatorAdversarialLoss(
            # keep compatibility
            **config.loss_config.discriminator_adv_loss_params
        ).to(device),
    }
    criterion["stft"] = MultiResolutionSTFTLoss(
        **config.loss_config.stft,
    ).to(device)
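    # A typical "stft" block in the JSON config (illustrative values borrowed
    # from common multi-resolution STFT loss setups, not necessarily this repo's
    # defaults) looks like:
    #   "stft": {"fft_sizes": [1024, 2048, 512],
    #            "hop_sizes": [120, 240, 50],
    #            "win_lengths": [600, 1200, 240]}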

    # define optimizers and schedulers
    optimizer = {
        "generator": torch.optim.Adam(
            model["generator"].parameters(), lr=config.optimizer_config.lr
        ),
        "discriminator": torch.optim.Adam(
            model["discriminator"].parameters(), lr=config.optimizer_config.lr
        ),
    }
    scheduler = {
        "generator": StepLRScheduler(
            optimizer["generator"],
            step_size=config.optimizer_config.scheduler_step_size,
            gamma=config.optimizer_config.scheduler_gamma,
        ),
        "discriminator": StepLRScheduler(
            optimizer["discriminator"],
            step_size=config.optimizer_config.scheduler_step_size,
            gamma=config.optimizer_config.scheduler_gamma,
        ),
    }
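    # Assuming StepLRScheduler mirrors torch.optim.lr_scheduler.StepLR, the
    # learning rate is multiplied by scheduler_gamma every scheduler_step_size steps.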
    if args.distributed:
        from torch.nn.parallel import DistributedDataParallel

        # device_ids pins each replica to this process's GPU
        model["generator"] = DistributedDataParallel(
            model["generator"], device_ids=[local_rank]
        )
        model["discriminator"] = DistributedDataParallel(
            model["discriminator"], device_ids=[local_rank]
        )

    # define trainer
    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        sampler=sampler,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # load pretrained parameters from checkpoint
    if args.resume:
        if args.pretrain != "":
            trainer.load_checkpoint(
                args.pretrain, load_only_params=False, dst_train=args.distributed
            )
            logging.info(f"Successfully loaded parameters from {args.pretrain}.")
        else:
            if os.path.isdir(args.cp_path):
                dir_files = os.listdir(args.cp_path)
                cp_files = [fname for fname in dir_files if fname.startswith("checkpoint-")]
                if len(cp_files) == 0:
                    logging.info("No pretrained checkpoints found. Training from scratch...")
                else:
                    cp_files.sort(key=lambda fname: os.path.getmtime(f"{args.cp_path}/{fname}"))
                    latest_cp = f"{args.cp_path}/{cp_files[-1]}"
                    trainer.load_checkpoint(
                        latest_cp, load_only_params=False, dst_train=args.distributed
                    )
                    logging.info(
                        f"No checkpoint specified; using the latest one instead. "
                        f"Successfully loaded parameters from {latest_cp}."
                    )
            else:
                logging.info("Checkpoint directory does not exist. Training from scratch...")
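    # NOTE: the "latest" checkpoint is chosen by file modification time, so
    # copying old files into cp_path can change which checkpoint gets resumed.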

    # run training loop
    try:
        trainer.run()
    finally:
        # the finally block also fires on KeyboardInterrupt, so an interrupted
        # run still saves its latest state
        trainer.save_checkpoint(
            os.path.join(config.interval_config.out_dir, f"checkpoint-{trainer.steps}steps.pkl"),
            args.distributed,
        )
        logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")


if __name__ == "__main__":
    main()