import torch
import cv2
import numpy as np
import torchvision
import os
import random

from utils.misc import prepare_cooridinate_input, customRandomCrop
from datasets.build_INR_dataset import Implicit2DGenerator
import albumentations
from albumentations import Resize, RandomResizedCrop, HorizontalFlip
from torch.utils.data import DataLoader


class dataset_generator(torch.utils.data.Dataset):
    def __init__(self, dataset_txt, alb_transforms, torch_transforms, opt, area_keep_thresh=0.2, mode='Train'):
        super().__init__()
        self.opt = opt
        self.root_path = opt.dataset_path
        self.mode = mode
        self.alb_transforms = alb_transforms
        self.torch_transforms = torch_transforms
        self.kp_t = area_keep_thresh

        with open(dataset_txt, 'r') as f:
            self.dataset_samples = [os.path.join(self.root_path, x.strip()) for x in f.readlines()]

        self.INR_dataset = Implicit2DGenerator(opt, self.mode)

    def __len__(self):
        return len(self.dataset_samples)

    def __getitem__(self, idx):
        composite_image = self.dataset_samples[idx]

        if self.opt.hr_train:
            if self.opt.isFullRes:
                "During dataset preprocessing, the HAdobe5k images are resized to a lower resolution " \
                "for quick loading, so when `opt.isFullRes` is True we redirect the path to the " \
                "original-resolution HAdobe5k copy."
                composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori")

        # Derive the real-image and mask paths from the composite-image path.
        real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg'
        mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png'

        composite_image = cv2.imread(composite_image)
        composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)

        real_image = cv2.imread(real_image)
        real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask)
        mask = mask[:, :, 0].astype(np.float32) / 255.
| """ | |
| If set `opt.hr_train` to True: | |
| Apply multi resolution crop for HR image train. Specifically, for 1024/2048 `input_size` (not fullres), | |
| the training phase is first to RandomResizeCrop 1024/2048 `input_size`, then to random crop a `base_size` | |
| patch to feed in multiINR process. For inference, just resize it. | |
| While for fullres, the RandomResizeCrop is removed and just do a random crop. For inference, just keep the size. | |
| BTW, we implement LR and HR mixing train. I.e., the following `random.random() < 0.5` | |
| """ | |
        if self.opt.hr_train:
            if self.mode == 'Train' and self.opt.isFullRes:
                if random.random() < 0.5:  # LR mix training
                    mixTransform = albumentations.Compose(
                        [
                            RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
                            HorizontalFlip()],
                        additional_targets={'real_image': 'image', 'object_mask': 'image'}
                    )

                    origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
                    origin_bg_ratio = 1 - origin_fg_ratio

                    "Ensure that neither the foreground nor the background disappears after the transformation."
                    valid_augmentation = False
                    transform_out = None
                    time = 0

                    while not valid_augmentation:
                        time += 1
                        # Some images have extreme fg/bg ratios; after 20 failed attempts,
                        # fall back to a plain resize so such images do not stall the loop.
                        if time == 20:
                            tmp_transform = albumentations.Compose(
                                [Resize(self.opt.base_size, self.opt.base_size)],
                                additional_targets={'real_image': 'image',
                                                    'object_mask': 'image'})
                            transform_out = tmp_transform(image=composite_image, real_image=real_image,
                                                          object_mask=mask)
                            valid_augmentation = True
                        else:
                            transform_out = mixTransform(image=composite_image, real_image=real_image,
                                                         object_mask=mask)
                            valid_augmentation = check_augmented_sample(transform_out['object_mask'],
                                                                        origin_fg_ratio,
                                                                        origin_bg_ratio,
                                                                        self.kp_t)

                    composite_image = transform_out['image']
                    real_image = transform_out['real_image']
                    mask = transform_out['object_mask']
                else:
                    # Pad so the original resolution is divisible by 4 (required by the
                    # pixel-aligned crop), and enforce a minimum side length of 256.
                    if real_image.shape[0] < 256:
                        bottom_pad = 256 - real_image.shape[0]
                    else:
                        bottom_pad = (4 - real_image.shape[0] % 4) % 4
                    if real_image.shape[1] < 256:
                        right_pad = 256 - real_image.shape[1]
                    else:
                        right_pad = (4 - real_image.shape[1] % 4) % 4

                    composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad,
                                                         cv2.BORDER_REPLICATE)
                    real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
                    mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
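                    # For example, a 1023x2048 image gets bottom_pad = (4 - 1023 % 4) % 4 = 1 and
                    # right_pad = (4 - 2048 % 4) % 4 = 0, so the padded 1024x2048 image divides
                    # cleanly at the 1/2 and 1/4 pyramid levels used below.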

        origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
        origin_bg_ratio = 1 - origin_fg_ratio

        "Ensure that neither the foreground nor the background disappears after the transformation."
        valid_augmentation = False
        transform_out = None
        time = 0

        if self.opt.hr_train:
            if self.mode == 'Train':
                if not self.opt.isFullRes:
                    if random.random() < 0.5:  # LR mix training
                        mixTransform = albumentations.Compose(
                            [
                                RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
                                HorizontalFlip()],
                            additional_targets={'real_image': 'image', 'object_mask': 'image'}
                        )

                        while not valid_augmentation:
                            time += 1
                            # Extreme fg/bg-ratio fallback: after 20 failed attempts, just resize.
                            if time == 20:
                                tmp_transform = albumentations.Compose(
                                    [Resize(self.opt.base_size, self.opt.base_size)],
                                    additional_targets={'real_image': 'image',
                                                        'object_mask': 'image'})
                                transform_out = tmp_transform(image=composite_image, real_image=real_image,
                                                              object_mask=mask)
                                valid_augmentation = True
                            else:
                                transform_out = mixTransform(image=composite_image, real_image=real_image,
                                                             object_mask=mask)
                                valid_augmentation = check_augmented_sample(transform_out['object_mask'],
                                                                            origin_fg_ratio,
                                                                            origin_bg_ratio,
                                                                            self.kp_t)
                    else:
                        while not valid_augmentation:
                            time += 1
                            # Extreme fg/bg-ratio fallback: after 20 failed attempts, just resize.
                            if time == 20:
                                tmp_transform = albumentations.Compose(
                                    [Resize(self.opt.input_size, self.opt.input_size)],
                                    additional_targets={'real_image': 'image',
                                                        'object_mask': 'image'})
                                transform_out = tmp_transform(image=composite_image, real_image=real_image,
                                                              object_mask=mask)
                                valid_augmentation = True
                            else:
                                transform_out = self.alb_transforms(image=composite_image, real_image=real_image,
                                                                    object_mask=mask)
                                valid_augmentation = check_augmented_sample(transform_out['object_mask'],
                                                                            origin_fg_ratio,
                                                                            origin_bg_ratio,
                                                                            self.kp_t)

                    composite_image = transform_out['image']
                    real_image = transform_out['real_image']
                    mask = transform_out['object_mask']

                origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])

                full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)

                # Entry 0 of each list holds the whole image resized to `base_size`.
                tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
                                                       additional_targets={'real_image': 'image',
                                                                           'object_mask': 'image'})
                transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
                compos_list = [self.torch_transforms(transform_out['image'])]
                real_list = [self.torch_transforms(transform_out['real_image'])]
                mask_list = [
                    torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
                coord_map_list = []

                valid_augmentation = False
                while not valid_augmentation:
                    # RSC strategy: randomly crop pixel-aligned patches at different resolutions.
                    transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord],
                                                               self.opt.base_size, self.opt.base_size)
                    valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio)

                compos_list.append(self.torch_transforms(transform_out[0]))
                real_list.append(self.torch_transforms(transform_out[1]))
                mask_list.append(
                    torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
                # The cropped coordinate map is appended twice, which keeps `coord_map_list`
                # the same length (four entries) as the image/mask lists.
                coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
                coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))

                # Build the 1/2- and 1/4-resolution levels of the crop pyramid, reusing the
                # same crop center (c_h, c_w) so all levels stay pixel-aligned.
                for n in range(2):
                    tmp_comp = cv2.resize(composite_image, (
                        composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
                    tmp_real = cv2.resize(real_image,
                                          (real_image.shape[1] // 2 ** (n + 1), real_image.shape[0] // 2 ** (n + 1)))
                    tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
                    tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)

                    transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord],
                                                               self.opt.base_size // 2 ** (n + 1),
                                                               self.opt.base_size // 2 ** (n + 1), c_h, c_w)
                    compos_list.append(self.torch_transforms(transform_out[0]))
                    real_list.append(self.torch_transforms(transform_out[1]))
                    mask_list.append(
                        torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
                    coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))

                out_comp = compos_list
                out_real = real_list
                out_mask = mask_list
                out_coord = coord_map_list

                fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
                    self.torch_transforms, transform_out[0], transform_out[1], mask)

                return {
                    'file_path': self.dataset_samples[idx],
                    'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
                    'composite_image': out_comp,
                    'real_image': out_real,
                    'mask': out_mask,
                    'coordinate_map': out_coord,
                    'composite_image0': out_comp[0],
                    'real_image0': out_real[0],
                    'mask0': out_mask[0],
                    'coordinate_map0': out_coord[0],
                    'composite_image1': out_comp[1],
                    'real_image1': out_real[1],
                    'mask1': out_mask[1],
                    'coordinate_map1': out_coord[1],
                    'composite_image2': out_comp[2],
                    'real_image2': out_real[2],
                    'mask2': out_mask[2],
                    'coordinate_map2': out_coord[2],
                    'composite_image3': out_comp[3],
                    'real_image3': out_real[3],
                    'mask3': out_mask[3],
                    'coordinate_map3': out_coord[3],
                    'fg_INR_coordinates': fg_INR_coordinates,
                    'bg_INR_coordinates': bg_INR_coordinates,
                    'fg_INR_RGB': fg_INR_RGB,
                    'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
                    'bg_INR_RGB': bg_INR_RGB
                }
            else:
                if not self.opt.isFullRes:
                    tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
                                                           additional_targets={'real_image': 'image',
                                                                               'object_mask': 'image'})
                    transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
                    coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])

                    "Generate the INR dataset."
                    mask = (torchvision.transforms.ToTensor()(
                        transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
                    mask = np.bool_(mask.numpy())

                    fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
                        self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)

                    return {
                        'file_path': self.dataset_samples[idx],
                        'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
                        'composite_image': self.torch_transforms(transform_out['image']),
                        'real_image': self.torch_transforms(transform_out['real_image']),
                        'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
                        # The default collate_fn converts this numpy array to a Tensor automatically.
                        'coordinate_map': coordinate_map,
                        'fg_INR_coordinates': fg_INR_coordinates,
                        'bg_INR_coordinates': bg_INR_coordinates,
                        'fg_INR_RGB': fg_INR_RGB,
                        'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
                        'bg_INR_RGB': bg_INR_RGB
                    }
                else:
                    coordinate_map = prepare_cooridinate_input(mask)

                    "Generate the INR dataset."
                    mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1)
                    mask_tmp = np.bool_(mask_tmp.numpy())

                    fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
                        self.torch_transforms, composite_image, real_image, mask_tmp)

                    return {
                        'file_path': self.dataset_samples[idx],
                        'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
                        'composite_image': self.torch_transforms(composite_image),
                        'real_image': self.torch_transforms(real_image),
                        'mask': mask[np.newaxis, ...].astype(np.float32),
                        # The default collate_fn converts this numpy array to a Tensor automatically.
                        'coordinate_map': coordinate_map,
                        'fg_INR_coordinates': fg_INR_coordinates,
                        'bg_INR_coordinates': bg_INR_coordinates,
                        'fg_INR_RGB': fg_INR_RGB,
                        'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
                        'bg_INR_RGB': bg_INR_RGB
                    }

        # Standard (non-HR) path: augment until both fg and bg survive, then build the INR sample.
        while not valid_augmentation:
            time += 1
            # Extreme fg/bg-ratio fallback: after 20 failed attempts, just resize.
            if time == 20:
                tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
                                                       additional_targets={'real_image': 'image',
                                                                           'object_mask': 'image'})
                transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
                valid_augmentation = True
            else:
                transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask)
                valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio,
                                                            origin_bg_ratio,
                                                            self.kp_t)

        coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])

        "Generate the INR dataset."
        mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
        mask = np.bool_(mask.numpy())

        fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
            self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)

        return {
            'file_path': self.dataset_samples[idx],
            'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
            'composite_image': self.torch_transforms(transform_out['image']),
            'real_image': self.torch_transforms(transform_out['real_image']),
            'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
            # The default collate_fn converts this numpy array to a Tensor automatically.
            'coordinate_map': coordinate_map,
            'fg_INR_coordinates': fg_INR_coordinates,
            'bg_INR_coordinates': bg_INR_coordinates,
            'fg_INR_RGB': fg_INR_RGB,
            'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
            'bg_INR_RGB': bg_INR_RGB
        }


def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh):
    """Reject an augmentation if the foreground or background area ratio shrinks below
    `area_keep_thresh` times its original value."""
    current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
    current_bg_ratio = 1 - current_fg_ratio
    if current_fg_ratio < origin_fg_ratio * area_keep_thresh or current_bg_ratio < origin_bg_ratio * area_keep_thresh:
        return False
    return True
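# Example: with area_keep_thresh=0.2 and an original fg ratio of 0.30, an augmented
# mask is rejected once its fg ratio drops below 0.30 * 0.2 = 0.06 or its bg ratio
# drops below 0.70 * 0.2 = 0.14, and the augmentation is re-sampled.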


def check_hr_crop_sample(mask, origin_fg_ratio):
    """Accept an HR crop only if it keeps at least 80% of the original foreground ratio."""
    current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
    if current_fg_ratio < 0.8 * origin_fg_ratio:
        return False
    return True
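

# Minimal usage sketch, assuming an `opt` namespace that exposes the attributes read
# above (`dataset_path`, `hr_train`, `isFullRes`, `base_size`, `input_size`), a split
# file `train.txt`, and ImageNet-style normalization. The paths, sizes, and transforms
# here are placeholders (`Implicit2DGenerator` may read further `opt` fields); adjust
# them to the real config before running.
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(dataset_path='./iHarmony4', hr_train=False, isFullRes=False,
                    base_size=256, input_size=256)

    # Spatial augmentations; `additional_targets` must match the keys used in __getitem__.
    alb_transforms = albumentations.Compose(
        [RandomResizedCrop(opt.input_size, opt.input_size, scale=(0.5, 1.0)), HorizontalFlip()],
        additional_targets={'real_image': 'image', 'object_mask': 'image'})
    torch_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    dataset = dataset_generator('train.txt', alb_transforms, torch_transforms, opt, mode='Train')
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    batch = next(iter(loader))
    print(batch['composite_image'].shape, batch['mask'].shape)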