Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| import torch | |
| from PIL import Image, ImageOps | |
| from torchvision.transforms import ToPILImage, ToTensor | |
# Shared PIL-image <-> tensor converters used by the helpers in this module.
totensor, topil = ToTensor(), ToPILImage()
def resize_and_crop(img, size, crop_type="center"):
    """Resize and crop ``img`` to ``size`` using ``ImageOps.fit``.

    Args:
        img: PIL image to fit.
        size: ``(width, height)`` target. Either entry may be ``None`` to
            keep the image's current extent on that axis.
        crop_type: ``"center"`` uses centering ``(0.5, 0.5)`` (crop around
            the middle); ``"top"`` uses centering ``(0, 0)``, i.e. the crop
            is anchored at the top-left corner.

    Returns:
        The resized and cropped image produced by ``ImageOps.fit``.

    Raises:
        ValueError: if ``crop_type`` is not ``"top"`` or ``"center"``, or if
            ``size`` does not have exactly two entries.
    """
    if crop_type == "top":
        center = (0, 0)
    elif crop_type == "center":
        center = (0.5, 0.5)
    else:
        raise ValueError(
            f"crop_type must be 'top' or 'center', got {crop_type!r}"
        )

    # Unpacking (rather than indexing a copied list) rejects a malformed
    # ``size`` up front instead of failing somewhere inside Pillow.
    width, height = size
    if width is None:
        width = img.size[0]
    if height is None:
        height = img.size[1]
    return ImageOps.fit(img, (width, height), centering=center)
def recover_image(image, init_image, mask, background=False):
    """Blend ``image`` and ``init_image`` through ``mask`` and return a PIL image.

    Every input goes through the module-level ``totensor``. Only index ``[0]``
    of the converted mask is used as the per-pixel weight (presumably the
    first channel of a C x H x W tensor).

    With ``background=False`` the result is
    ``mask * image + (1 - mask) * init_image``; with ``background=True`` the
    two images swap roles. The blend is converted back with ``topil``.
    """
    generated = totensor(image)
    weight = totensor(mask)[0]
    source = totensor(init_image)

    if background:
        fg, bg = source, generated
    else:
        fg, bg = generated, source

    return topil(weight * fg + (1 - weight) * bg)
def preprocess(image, multiple=32):
    """Convert a PIL image into a float tensor scaled to ``[-1, 1]``.

    The image is converted to RGB. Its width and height are rounded *down*
    to a multiple of ``multiple`` (32 by default), and it is resampled with
    Lanczos.

    Args:
        image: PIL image.
        multiple: each output side is floored to a multiple of this value.

    Returns:
        ``torch`` float32 tensor of shape ``(1, 3, H, W)`` with values in
        ``[-1, 1]``.

    Raises:
        ValueError: if ``multiple`` is not positive, or if either side of
            ``image`` is smaller than ``multiple``. Rounding would otherwise
            request a zero-sized resize.
    """
    if multiple < 1:
        raise ValueError(f"multiple must be a positive integer, got {multiple!r}")

    # Same channel handling as prepare_mask_and_masked_image. Without it,
    # grayscale/palette images give a 2-D array that cannot be transposed to
    # NCHW below, and RGBA images would produce four channels.
    image = image.convert("RGB")

    w, h = (side - side % multiple for side in image.size)
    if w == 0 or h == 0:
        raise ValueError(
            f"image size {tuple(image.size)} is smaller than multiple={multiple} "
            "on at least one side"
        )

    image = image.resize((w, h), resample=Image.LANCZOS)
    array = np.array(image).astype(np.float32) / 255.0
    # HWC -> NCHW with a leading batch axis of 1.
    tensor = torch.from_numpy(array[None].transpose(0, 3, 1, 2))
    return 2.0 * tensor - 1.0
def prepare_mask_and_masked_image(image, mask):
    """Build the binary mask tensor and the masked image tensor.

    Args:
        image: PIL image; converted to RGB.
        mask: PIL image; converted to single-channel ``"L"``. Pixels whose
            normalised value is >= 0.5 become 1, all others become 0.

    Returns:
        ``(mask, masked_image)``, where:

        - ``mask`` is a float32 tensor of shape ``(1, 1, H, W)`` holding 0/1.
        - ``masked_image`` is ``image`` scaled to ``[-1, 1]`` as a
          ``(1, 3, H, W)`` float32 tensor, with every pixel set to zero
          wherever ``mask`` is 1.
    """
    # Image: HWC uint8 -> NCHW, then float32 scaled into [-1, 1].
    rgb = np.array(image.convert("RGB"))[None].transpose(0, 3, 1, 2)
    pixels = torch.from_numpy(rgb).to(dtype=torch.float32) / 127.5 - 1.0

    # Mask: normalise grey levels, binarise at 0.5, add batch and channel axes.
    grey = np.array(mask.convert("L")).astype(np.float32) / 255.0
    binary = np.where(grey >= 0.5, 1.0, 0.0).astype(np.float32)[None, None]
    mask_tensor = torch.from_numpy(binary)

    # Keep only the pixels outside the masked region.
    masked = pixels * (mask_tensor < 0.5)
    return mask_tensor, masked