import torch
import torchvision.transforms as T
import PIL.Image
from typing import List

size = (224, 224)

class ResizeWithPadding:
    """Resize so the longer side equals ``target_size``, then pad the shorter
    side with ``fill`` to a square canvas (letterboxing), preserving aspect
    ratio."""

    def __init__(self, target_size: int = 224, fill: int = 0, mode: str = "RGB") -> None:
        self.target_size = target_size
        self.fill = fill
        self.mode = mode

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height
        # Scale the longer side to target_size while keeping the aspect ratio.
        if aspect_ratio > 1:
            new_width = self.target_size
            new_height = int(self.target_size / aspect_ratio)
        else:
            new_height = self.target_size
            new_width = int(self.target_size * aspect_ratio)
        # Bicubic for RGB images; nearest-neighbor for masks so label values
        # are never interpolated.
        resized_image = image.resize(
            (new_width, new_height),
            PIL.Image.BICUBIC if self.mode == "RGB" else PIL.Image.NEAREST,
        )
        # Center the resized image on a square canvas filled with `fill`.
        delta_w = self.target_size - new_width
        delta_h = self.target_size - new_height
        padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
        padded_image = PIL.Image.new(self.mode, (self.target_size, self.target_size), self.fill)
        padded_image.paste(resized_image, (padding[0], padding[1]))
        return padded_image
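
# A minimal sanity check (illustrative sketch, not part of the original file):
# letterboxing a wide 640x480 image yields a 224x224 canvas with the content
# scaled to 224x168 and 28 px of fill above and below.
#
#     letterbox = ResizeWithPadding()
#     out = letterbox(PIL.Image.new("RGB", (640, 480)))
#     assert out.size == (224, 224)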

def get_transform(mean: List[float], std: List[float]) -> T.Compose:
    """Image pipeline: letterbox to 224x224, convert to a tensor, normalize."""
    return T.Compose([
        ResizeWithPadding(),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])
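
# Example usage (a sketch; the ImageNet mean/std below are an assumption, not
# values taken from this file -- substitute your dataset's own statistics):
#
#     image_transform = get_transform(
#         mean=[0.485, 0.456, 0.406],
#         std=[0.229, 0.224, 0.225],
#     )
#     x = image_transform(PIL.Image.open("photo.jpg").convert("RGB"))  # hypothetical path
#     # x is a float tensor of shape (3, 224, 224)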

mask_transform = T.Compose([
    ResizeWithPadding(mode="L"),
    T.ToTensor(),
    # ToTensor rescales pixels to [0, 1]; undo that to recover integer class
    # labels. round() guards against float error before the cast truncates.
    T.Lambda(lambda x: (x * 255).round().long()),
])
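
# Paired usage (sketch): apply the same letterboxing to image and mask so they
# stay pixel-aligned; the mask comes out as an int64 label tensor of shape
# (1, 224, 224). Note that padded border pixels receive label 0 (the default
# fill), so 0 should map to a background/ignore class.
#
#     mask = PIL.Image.open("mask.png").convert("L")  # hypothetical path
#     y = mask_transform(mask)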

class EMA:
    """Exponential moving average of a scalar stream, e.g. for smoothing a
    loss curve during training."""

    def __init__(self, alpha: float = 0.9) -> None:
        self.value = None
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        # The first observation initializes the average; after that, blend:
        # value <- alpha * old + (1 - alpha) * new.
        if self.value is None:
            self.value = value
        else:
            self.value = self.alpha * self.value + (1 - self.alpha) * value
        return self.value
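
if __name__ == "__main__":
    # Smoke test (an illustrative sketch, not part of the original file):
    # smooth a short noisy sequence with the EMA helper.
    ema = EMA(alpha=0.9)
    for raw in [1.0, 0.5, 2.0, 1.5]:
        print(f"raw={raw:.2f} smoothed={ema(raw):.4f}")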