Spaces:
Configuration error
Configuration error
| import os.path | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| import os | |
| import os.path | |
| import torch | |
| from util.aug_utils import RandomScale, RandomSizeCrop, DivisibleCrop | |
class SingleImageDataset(Dataset):
    """Dataset built from one source image, serving augmented crops of it.

    The dataset reports a length of 1; each ``__getitem__`` call draws a new
    randomly augmented crop from the same source image and advances an
    internal step counter.

    Configuration keys read from ``cfg``:
        ``image_path``: path of the source image to open.
        ``resize_input``: when greater than 0, passed to
            ``transforms.Resize`` and applied to the source image once.
        ``d_divisible_crops``: argument for ``DivisibleCrop``.
        ``jitter_p``: probability of applying the colour jitter.
        ``flip_p``: probability of a horizontal flip.
        ``scale_min`` / ``scale_max``: range handed to ``RandomScale``.
        ``crops_min_cover``: argument for ``RandomSizeCrop``.
        ``source_image_every``: the full (un-augmented) image is added to the
            sample on every step that is a multiple of this value.
    """

    def __init__(self, cfg):
        """Build the transform pipelines and load the source image.

        Args:
            cfg: mapping holding the configuration keys listed in the class
                docstring.
        """
        self.cfg = cfg

        # Convert to a tensor, add a leading dimension of size 1, then apply
        # DivisibleCrop (project helper; presumably crops so the spatial size
        # is divisible by the given value -- confirm in util.aug_utils).
        self.base_transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: transforms.ToTensor()(x).unsqueeze(0)),
                DivisibleCrop(cfg["d_divisible_crops"]),
            ]
        )

        # used to create the internal dataset
        self.input_transforms = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)],
                    p=cfg["jitter_p"],
                ),
                transforms.RandomHorizontalFlip(p=cfg["flip_p"]),
                RandomScale((cfg["scale_min"], cfg["scale_max"])),
                RandomSizeCrop(cfg["crops_min_cover"]),
                self.base_transforms,
            ]
        )

        # open source image
        # Image.open is lazy and keeps the file open; convert() returns an
        # independent in-memory image, so the file can be closed right away.
        with Image.open(cfg["image_path"]) as img:
            self.src_img = img.convert("RGB")
        if cfg["resize_input"] > 0:
            self.src_img = transforms.Resize(cfg["resize_input"])(self.src_img)

        # Starts at -1 so that the first __getitem__ call reports step 0.
        self.step = -1

    def get_img(self):
        """Return the source image passed through ``base_transforms`` only."""
        return self.base_transforms(self.src_img)

    def __getitem__(self, index):
        """Return the sample for the next step.

        ``index`` is ignored: every call increments the step counter and
        draws a new augmented crop.

        Returns:
            A dict with ``"step"`` (the current step counter) and
            ``"input_crop"`` (the source image after ``input_transforms``).
            On steps that are a multiple of ``cfg["source_image_every"]`` it
            also holds ``"input_image"`` (the result of ``get_img()``).
        """
        self.step += 1
        sample = {"step": self.step}
        if self.step % self.cfg["source_image_every"] == 0:
            sample["input_image"] = self.get_img()
        sample["input_crop"] = self.input_transforms(self.src_img)
        return sample

    def __len__(self):
        """Return 1: the dataset exposes a single item."""
        return 1