Spaces:
Running
Running
| import os | |
| import argparse | |
| import albumentations | |
| from albumentations import HorizontalFlip, Resize, RandomResizedCrop | |
| import torch.backends.cudnn as cudnn | |
| import torchvision.transforms as transforms | |
| from torch.utils.data import DataLoader | |
| from torch.optim import lr_scheduler | |
| import processing | |
| from utils import build_loss, misc | |
| from model.build_model import build_model | |
| from datasets.build_dataset import dataset_generator | |
def parse_args() -> argparse.Namespace:
    """Parse the training command-line options.

    Besides the flags declared below, the returned namespace carries two
    values that are set after parsing:

    * ``save_path`` is replaced by the result of ``misc.increment_path`` applied
      to ``<save_path>/exp1`` (presumably a fresh, non-clashing run directory;
      confirm against ``utils.misc``).
    * ``wandb`` is ``True`` when Weights & Biases was imported and initialised
      successfully, ``False`` otherwise.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # --- Runtime / optimisation settings -------------------------------
    parser.add_argument('--workers', type=int, default=8,
                        metavar='N', help='Dataloader threads.')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='You can override model batch size by specify positive number.')
    parser.add_argument('--device', type=str, default='cuda',
                        help="Whether use cuda, 'cuda' or 'cpu'.")
    parser.add_argument('--epochs', type=int, default=60,
                        help='Epochs number.')
    # The learning rate is fractional: it must be parsed as float. With
    # type=int, any value given on the command line (e.g. "--lr 0.001") is
    # rejected by argparse and only the default could ever be used.
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate.')
    parser.add_argument('--save_path', type=str, default="./logs",
                        help='Where to save logs and checkpoints.')
    parser.add_argument('--dataset_path', type=str, default=r".\iHarmony4",
                        help='Dataset path.')
    parser.add_argument('--print_freq', type=int, default=100,
                        help='Number of iterations then print.')

    # --- Resolutions ---------------------------------------------------
    parser.add_argument('--base_size', type=int, default=256,
                        help='Base size. Resolution of the image input into the Encoder')
    parser.add_argument('--input_size', type=int, default=256,
                        help='Input size. Resolution of the image that want to be generated by the Decoder')
    parser.add_argument('--INR_input_size', type=int, default=256,
                        help='INR input size. Resolution of the image that want to be generated by the Decoder. '
                             'Should be the same as `input_size`')

    # --- Model hyper-parameters ----------------------------------------
    parser.add_argument('--INR_MLP_dim', type=int, default=32,
                        help='Number of channels for INR linear layer.')
    parser.add_argument('--LUT_dim', type=int, default=7,
                        help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076')
    parser.add_argument('--activation', type=str, default='leakyrelu_pe',
                        help='INR activation layer type: leakyrelu_pe, sine')
    parser.add_argument('--pretrained', type=str,
                        default=None,
                        help='Pretrained weight path')
    parser.add_argument('--param_factorize_dim', type=int,
                        default=10,
                        help='The intermediate dimensions of the factorization of the predicted MLP parameters. '
                             'Refer to https://arxiv.org/abs/2011.12026')
    parser.add_argument('--embedding_type', type=str,
                        default="CIPS_embed",
                        help='Which embedding_type to use.')
    parser.add_argument('--optim', type=str,
                        default='adamw',
                        help='Which optimizer to use.')

    # --- Switches --------------------------------------------------------
    # NOTE: the two "store_false" options default to True; passing the flag
    # on the command line turns the feature OFF.
    parser.add_argument('--INRDecode', action="store_false",
                        help='Whether INR decoder. Set it to False if you want to test the baseline '
                             '(https://github.com/SamsungLabs/image_harmonization)')
    parser.add_argument('--isMoreINRInput', action="store_false",
                        help='Whether to cat RGB and mask. See Section 3.4 in the paper.')
    parser.add_argument('--hr_train', action="store_true",
                        help='Whether use hr_train. See section 3.4 in the paper.')
    parser.add_argument('--isFullRes', action="store_true",
                        help='Whether for original resolution. See section 3.4 in the paper.')

    opt = parser.parse_args()

    opt.save_path = misc.increment_path(os.path.join(opt.save_path, "exp1"))

    # Weights & Biases logging is optional: if the package is missing or its
    # initialisation fails, training continues without it. Catch Exception
    # rather than using a bare ``except`` so that Ctrl-C / SystemExit raised
    # during setup still abort the program instead of being swallowed.
    try:
        import wandb
        # Set before init so the flag is part of the config passed to wandb.
        opt.wandb = True
        wandb.init(config=opt, project="INR_Harmonization", name=os.path.basename(opt.save_path))
    except Exception:
        opt.wandb = False

    return opt
def main_process(opt):
    """Build the data pipeline, model, optimizer and scheduler, then train.

    Args:
        opt: Namespace produced by ``parse_args``. The attributes read here are
            ``save_path``, ``dataset_path``, ``input_size``, ``batch_size``,
            ``workers``, ``device``, ``lr``, ``optim`` and ``epochs``.
            ``transform_mean`` and ``transform_var`` are written onto it, and
            ``opt`` is also forwarded to the dataset/model builders and to
            ``processing.train``.
    """
    logger = misc.create_logger(os.path.join(opt.save_path, "log.txt"))
    cudnn.benchmark = True

    # Split files are looked up directly under the dataset root.
    trainset_path = os.path.join(opt.dataset_path, "IHD_train.txt")
    valset_path = os.path.join(opt.dataset_path, "IHD_test.txt")

    # Stored on opt so that code receiving opt can read the same statistics.
    # `transform_var` is passed as the second (std) argument of Normalize.
    opt.transform_mean = [.5, .5, .5]
    opt.transform_var = [.5, .5, .5]
    torch_transform = transforms.Compose([transforms.ToTensor(),
                                          transforms.Normalize(opt.transform_mean, opt.transform_var)])

    # `additional_targets` makes albumentations apply the same spatial
    # transform to the extra 'real_image' and 'object_mask' inputs.
    trainset_alb_transform = albumentations.Compose(
        [
            RandomResizedCrop(opt.input_size, opt.input_size, scale=(0.5, 1.0)),
            HorizontalFlip()],
        additional_targets={'real_image': 'image', 'object_mask': 'image'}
    )
    valset_alb_transform = albumentations.Compose([Resize(opt.input_size, opt.input_size)],
                                                  additional_targets={'real_image': 'image', 'object_mask': 'image'})

    trainset = dataset_generator(trainset_path, trainset_alb_transform, torch_transform, opt, mode='Train')
    valset = dataset_generator(valset_path, valset_alb_transform, torch_transform, opt, mode='Val')

    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only keep workers alive when there are any.
    # This lets `--workers 0` (single-process loading) run.
    keep_workers_alive = opt.workers > 0

    train_loader = DataLoader(trainset, opt.batch_size, shuffle=True, drop_last=True,
                              pin_memory=True,
                              num_workers=opt.workers, persistent_workers=keep_workers_alive)
    val_loader = DataLoader(valset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True,
                            num_workers=opt.workers, persistent_workers=keep_workers_alive)

    model = build_model(opt).to(opt.device)
    loss_fn = build_loss.loss_generator()

    optimizer_params = {
        'lr': opt.lr,
        'weight_decay': 1e-2
    }
    optimizer = misc.get_optimizer(model, opt.optim, optimizer_params)

    # One scheduler step per training iteration: total_steps is
    # epochs * batches-per-epoch. pct_start=0.0 requests no warm-up phase.
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=opt.epochs * len(train_loader),
                                        pct_start=0.0)

    processing.train(train_loader, val_loader, model, optimizer, scheduler, loss_fn, logger, opt)
# Script entry point: parse the options, make sure the run directory exists,
# then hand over to the training pipeline.
if __name__ == '__main__':
    opt = parse_args()

    # Create the output directory up front; exist_ok tolerates a directory
    # that is already present.
    os.makedirs(name=opt.save_path, exist_ok=True)

    main_process(opt)