Upload 9 files
- src/dataset.py +50 -0
- src/get_loss.py +79 -0
- src/losses.py +498 -0
- src/models/dino.py +37 -0
- src/models/segmentation_head.py +40 -0
- src/models/unet.py +171 -0
- src/models/vit.py +36 -0
- src/train.py +264 -0
- src/utils.py +58 -0
src/dataset.py
ADDED
@@ -0,0 +1,50 @@
from typing import List, Tuple, Callable

import datasets
import torch
from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    def __init__(
        self,
        dataset: datasets.Dataset,
        train: bool = True,
        transform: Callable | None = None,
        target_transform: Callable | None = None,
        test_size: float = 0.25,
    ) -> None:
        super().__init__()
        self.dataset = dataset
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.test_size = test_size

        # Deterministic split: the first `test_size` fraction of indices is
        # held out for validation, the rest is used for training.
        total_size = len(dataset)
        indices = list(range(total_size))
        split = int(self.test_size * total_size)

        if train:
            self.indices = indices[split:]
        else:
            self.indices = indices[:split]

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        item = self.dataset[self.indices[idx]]
        image = item["image"]
        mask = item["mask"]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)
        return image, mask


def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    images = torch.stack([item[0] for item in items])
    masks = torch.stack([item[1] for item in items])
    return images, masks
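A minimal usage sketch for this file (the dataset name comes from src/train.py's default; the torchvision transforms are illustrative assumptions, not this commit's actual pipeline):

from datasets import load_dataset
from torch.utils.data import DataLoader
import torchvision.transforms as T  # assumption: torchvision is available

hf_dataset = load_dataset("mattmdjaga/human_parsing_dataset", split="train")
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
train_ds = SegmentationDataset(hf_dataset, train=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)

Because the split slices a contiguous index range rather than shuffling, the validation set is always the first quarter of the dataset, and two instances built from the same dataset object (train=True and train=False) are guaranteed disjoint.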
src/get_loss.py
ADDED
@@ -0,0 +1,79 @@
from typing import Dict, Callable
import torch.nn as nn
import torch

from losses import (
    SoftDiceLoss, SSLoss, IoULoss, TverskyLoss, FocalTversky_loss, AsymLoss,
    ExpLog_loss, FocalLoss, LovaszSoftmax, TopKLoss, WeightedCrossEntropyLoss,
    SoftDiceLoss_v2, IoULoss_v2, TverskyLoss_v2, FocalTversky_loss_v2,
    AsymLoss_v2, SSLoss_v2,
)


def get_loss(loss_type: str) -> nn.Module:
    if loss_type == "cross_entropy":
        return nn.CrossEntropyLoss()
    elif loss_type == "SoftDiceLoss":
        return SoftDiceLoss()
    elif loss_type == "SSLoss":
        return SSLoss()
    elif loss_type == "IoULoss":
        return IoULoss()
    elif loss_type == "TverskyLoss":
        return TverskyLoss()
    elif loss_type == "FocalTversky_loss":
        tversky_kwargs = {
            "apply_nonlin": None,
            "batch_dice": False,
            "do_bg": True,
            "smooth": 1.0,
            "square": False,
        }
        return FocalTversky_loss(tversky_kwargs=tversky_kwargs)
    elif loss_type == "AsymLoss":
        return AsymLoss()
    elif loss_type == "ExpLog_loss":
        soft_dice_kwargs = {"smooth": 1.0}
        wce_kwargs = {"weight": None}
        return ExpLog_loss(soft_dice_kwargs=soft_dice_kwargs, wce_kwargs=wce_kwargs)
    elif loss_type == "FocalLoss":
        return FocalLoss()
    elif loss_type == "LovaszSoftmax":
        return LovaszSoftmax()
    elif loss_type == "TopKLoss":
        return TopKLoss()
    elif loss_type == "WeightedCrossEntropyLoss":
        return WeightedCrossEntropyLoss()
    elif loss_type == "SoftDiceLoss_v2":
        return SoftDiceLoss_v2()
    elif loss_type == "IoULoss_v2":
        return IoULoss_v2()
    elif loss_type == "TverskyLoss_v2":
        return TverskyLoss_v2()
    elif loss_type == "FocalTversky_loss_v2":
        return FocalTversky_loss_v2()
    elif loss_type == "AsymLoss_v2":
        return AsymLoss_v2()
    elif loss_type == "SSLoss_v2":
        return SSLoss_v2()
    else:
        raise ValueError(f"Unsupported loss type: {loss_type}")


def get_composite_criterion(losses_config: Dict[str, float]) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    # Keep only losses with a non-zero weight; get_loss raises on unknown names.
    losses = []
    weights = []

    for loss_name, weight in losses_config.items():
        if weight != 0.0:
            losses.append(get_loss(loss_name))
            weights.append(weight)

    def composite_loss(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        total_loss = 0.0
        for loss_fn, weight in zip(losses, weights):
            total_loss += weight * loss_fn(output, target)
        return total_loss

    return composite_loss
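A sketch of how the composite criterion is driven: the JSON keys must match the names accepted by get_loss, and the weights below are illustrative, not the commit's actual losses_config.json:

losses_config = {"cross_entropy": 1.0, "SoftDiceLoss_v2": 0.5, "FocalLoss": 0.0}  # zero-weight entries are skipped
criterion = get_composite_criterion(losses_config)

logits = torch.randn(2, 18, 224, 224)          # (batch, num_classes, H, W)
target = torch.randint(0, 18, (2, 224, 224))   # integer class mask
loss = criterion(logits, target)               # weighted sum of the active losses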
src/losses.py
ADDED
@@ -0,0 +1,498 @@
from typing import Callable, List, Tuple, Dict
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


def sum_tensor(inp: torch.Tensor, axes: int | List[int], keepdim: bool = False) -> torch.Tensor:
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.sum(int(ax), keepdim=True)
    else:
        # Sum from the highest axis down so earlier axis indices stay valid.
        for ax in sorted(axes, reverse=True):
            inp = inp.sum(int(ax))
    return inp


def get_tp_fp_fn(net_output: torch.Tensor, gt: torch.Tensor, axes: int | Tuple[int, ...] | None = None, mask: torch.Tensor | None = None, square: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))
    shp_x = net_output.shape
    shp_y = gt.shape
    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))
        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # gt is already one-hot encoded
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)
    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)
    return tp, fp, fn


def softmax_helper(x: torch.Tensor) -> torch.Tensor:
    # Numerically stable softmax over the channel dimension.
    rpt = [1 for _ in range(len(x.size()))]
    rpt[1] = x.size(1)
    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)


def flatten(tensor: torch.Tensor) -> torch.Tensor:
    # (N, C, ...) -> (C, N * ...)
    C = tensor.size(1)
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    transposed = tensor.permute(axis_order).contiguous()
    return transposed.view(C, -1)


class SoftDiceLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1.0, square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]
        dc = dc.mean()
        return -dc


class SoftDiceLoss_v2(nn.Module):
    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        intersection = torch.sum(probs * targets, dim=(0, 2, 3))
        union = torch.sum(probs + targets, dim=(0, 2, 3))
        dl = 1 - (2.0 * intersection + self.smooth) / (union + self.smooth)
        dice_loss = torch.mean(dl)
        return dice_loss


class SSLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1.0, square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.r = 0.1  # relative weight of the two error terms

    def forward(self, net_output: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        shp_x = net_output.shape
        shp_y = gt.shape
        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                gt = gt.view((shp_y[0], 1, *shp_y[1:]))
            if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
                y_onehot = gt
            else:
                gt = gt.long()
                y_onehot = torch.zeros(shp_x)
                if net_output.device.type == "cuda":
                    y_onehot = y_onehot.cuda(net_output.device.index)
                y_onehot.scatter_(1, gt, 1)
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            net_output = self.apply_nonlin(net_output)
        bg_onehot = 1 - y_onehot
        squared_error = (y_onehot - net_output) ** 2
        specificity_part = sum_tensor(squared_error * y_onehot, axes) / (sum_tensor(y_onehot, axes) + self.smooth)
        sensitivity_part = sum_tensor(squared_error * bg_onehot, axes) / (sum_tensor(bg_onehot, axes) + self.smooth)
        ss = self.r * specificity_part + (1 - self.r) * sensitivity_part
        if not self.do_bg:
            if self.batch_dice:
                ss = ss[1:]
            else:
                ss = ss[:, 1:]
        ss = ss.mean()
        return ss


class SSLoss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5) -> None:
        super().__init__()
        self.alpha = alpha

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        intersection = torch.sum(probs * targets_one_hot, dim=(0, 2, 3))
        cardinality = torch.sum(probs + targets_one_hot, dim=(0, 2, 3))
        dice_loss = 1 - (2.0 * intersection + 1e-6) / (cardinality + 1e-6)
        # Cross-entropy must see raw logits and integer targets, not softmaxed probabilities.
        ce_loss = F.cross_entropy(logits, targets, reduction="mean")
        loss = self.alpha * dice_loss.mean() + (1 - self.alpha) * ce_loss
        return loss


class IoULoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1.0, square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        iou = (tp + self.smooth) / (tp + fp + fn + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                iou = iou[1:]
            else:
                iou = iou[:, 1:]
        iou = iou.mean()
        return -iou


class IoULoss_v2(nn.Module):
    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        intersection = torch.sum(probs * targets, dim=(0, 2, 3))
        union = torch.sum(probs + targets, dim=(0, 2, 3)) - intersection
        iou = 1 - (intersection + self.smooth) / (union + self.smooth)
        iou_loss = torch.mean(iou)
        return iou_loss


class TverskyLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1.0, square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.alpha = 0.3  # weight on false positives
        self.beta = 0.7   # weight on false negatives

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        tversky = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                tversky = tversky[1:]
            else:
                tversky = tversky[:, 1:]
        tversky = tversky.mean()
        return -tversky


class TverskyLoss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5, beta: float = 0.5, smooth: float = 1.0) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        tp = torch.sum(probs * targets, dim=(0, 2, 3))
        fp = torch.sum((1 - targets) * probs, dim=(0, 2, 3))
        fn = torch.sum(targets * (1 - probs), dim=(0, 2, 3))
        tversky = 1 - (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        tversky_loss = torch.mean(tversky)
        return tversky_loss


class FocalTversky_loss(nn.Module):
    def __init__(self, tversky_kwargs: Dict, gamma: float = 0.75) -> None:
        super().__init__()
        self.gamma = gamma
        self.tversky = TverskyLoss(**tversky_kwargs)

    def forward(self, net_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # TverskyLoss returns -tversky, so 1 + loss recovers (1 - tversky) in [0, 1].
        tversky_loss = 1 + self.tversky(net_output, target)
        focal_tversky = torch.pow(tversky_loss, self.gamma)
        return focal_tversky


class FocalTversky_loss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.5, smooth: float = 1.0) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        tp = torch.sum(probs * targets, dim=(0, 2, 3))
        fp = torch.sum((1 - targets) * probs, dim=(0, 2, 3))
        fn = torch.sum(targets * (1 - probs), dim=(0, 2, 3))
        focal_tversky = (1 - (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)) ** self.gamma
        focal_tversky_loss = torch.mean(focal_tversky)
        return focal_tversky_loss


class AsymLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1.0, square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.beta = 1.5

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        weight = (self.beta ** 2) / (1 + self.beta ** 2)
        asym = (tp + self.smooth) / (tp + weight * fn + (1 - weight) * fp + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                asym = asym[1:]
            else:
                asym = asym[:, 1:]
        asym = asym.mean()
        return -asym


class AsymLoss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5, gamma: float = 2.0, smooth: float = 1e-5) -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        pos_loss = -self.alpha * (1 - probs) ** self.gamma * targets_one_hot * torch.log(probs + self.smooth)
        neg_loss = -(1 - self.alpha) * probs ** self.gamma * (1 - targets_one_hot) * torch.log(1 - probs + self.smooth)
        loss = pos_loss + neg_loss
        return loss.mean()


class ExpLog_loss(nn.Module):
    def __init__(self, soft_dice_kwargs: Dict, wce_kwargs: Dict, gamma: float = 0.3) -> None:
        super().__init__()
        self.wce = WeightedCrossEntropyLoss(**wce_kwargs)
        self.dc = SoftDiceLoss_v2(**soft_dice_kwargs)
        self.gamma = gamma

    def forward(self, net_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # SoftDiceLoss_v2 returns 1 - dice, so recover the dice coefficient before the log.
        dc = 1 - self.dc(net_output, target)
        wce_loss = self.wce(net_output, target)
        explog_loss = 0.8 * torch.pow(-torch.log(torch.clamp(dc, min=1e-6)), self.gamma) + 0.2 * wce_loss
        return explog_loss


class FocalLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, alpha: float | List[float] | np.ndarray | None = None, gamma: float = 2.0, balance_index: int = 0, smooth: float = 1e-4, size_average: bool = True) -> None:
        super().__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average
        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError("smooth value should be in [0,1]")

    def forward(self, logit: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)
        num_class = logit.shape[1]
        if logit.dim() > 2:
            # (N, C, d1, d2, ...) -> (N * d1 * d2 * ..., C)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)
        alpha = self.alpha
        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha
        else:
            raise TypeError("Not support alpha type")
        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)
        idx = target.cpu().long()
        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)
        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()
        gamma = self.gamma
        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss


def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor:
    # Gradient of the Lovasz extension w.r.t. sorted errors.
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1.0 - intersection / union
    if p > 1:
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


class LovaszSoftmax(nn.Module):
    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        self.reduction = reduction

    def prob_flatten(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        assert input.dim() in [4, 5]
        num_class = input.size(1)
        if input.dim() == 4:
            input = input.permute(0, 2, 3, 1).contiguous()
            input_flatten = input.view(-1, num_class)
        elif input.dim() == 5:
            input = input.permute(0, 2, 3, 4, 1).contiguous()
            input_flatten = input.view(-1, num_class)
        target_flatten = target.view(-1)
        return input_flatten, target_flatten

    def lovasz_softmax_flat(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        num_classes = inputs.size(1)
        losses = []
        for c in range(num_classes):
            target_c = (targets == c).float()
            if num_classes == 1:
                input_c = inputs[:, 0]
            else:
                input_c = inputs[:, c]
            loss_c = (target_c - input_c).abs()
            loss_c_sorted, loss_index = torch.sort(loss_c, 0, descending=True)
            target_c_sorted = target_c[loss_index]
            losses.append(torch.dot(loss_c_sorted, lovasz_grad(target_c_sorted)))
        losses = torch.stack(losses)
        if self.reduction == "none":
            loss = losses
        elif self.reduction == "sum":
            loss = losses.sum()
        else:
            loss = losses.mean()
        return loss

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        inputs, targets = self.prob_flatten(inputs, targets)
        losses = self.lovasz_softmax_flat(inputs, targets)
        return losses


class TopKLoss(nn.Module):
    def __init__(self, weight: torch.Tensor | None = None, ignore_index: int = -100, k: int = 10) -> None:
        super().__init__()
        self.k = k  # keep only the hardest k percent of pixels
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction="none")

    def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pixel_losses = self.cross_entropy(inp, target)
        pixel_losses = pixel_losses.view(-1)
        num_voxels = pixel_losses.numel()
        res, _ = torch.topk(pixel_losses, int(num_voxels * self.k / 100), sorted=False)
        return res.mean()


class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    def __init__(self, weight: torch.Tensor | None = None) -> None:
        super().__init__()
        self.weight = weight

    def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        target = target.long()
        num_classes = inp.size()[1]
        # Move the class dimension to the end, then flatten to (num_pixels, num_classes).
        i0 = 1
        i1 = 2
        while i1 < len(inp.shape):
            inp = inp.transpose(i0, i1)
            i0 += 1
            i1 += 1
        inp = inp.contiguous()
        inp = inp.view(-1, num_classes)
        target = target.view(-1,)
        wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight)
        return wce_loss(inp, target)
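A quick sanity check of the sign conventions (a sketch; the v1 losses return negated overlap scores in [-1, 0] to be minimized, the v2 losses return values in [0, 1]):

logits = torch.randn(2, 18, 64, 64)            # raw network outputs
target = torch.randint(0, 18, (2, 64, 64))     # integer class mask

dice_v1 = SoftDiceLoss()(logits, target)       # ~ -dice: minimizing it maximizes overlap
dice_v2 = SoftDiceLoss_v2()(logits, target)    # ~ 1 - dice
tversky = TverskyLoss_v2(alpha=0.3, beta=0.7)(logits, target)  # beta > alpha penalizes false negatives more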
src/models/dino.py
ADDED
@@ -0,0 +1,37 @@
from transformers import Dinov2Backbone
import torch
import torch.nn as nn

from models.segmentation_head import SegmentationHead


class DINOSegmentationModel(nn.Module):
    def __init__(self, image_size: int = 224, num_classes: int = 18) -> None:
        super().__init__()
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.image_size = image_size
        model_name = "facebook/dinov2-small"
        self.backbone = Dinov2Backbone.from_pretrained(model_name)
        # Freeze the backbone: only the segmentation head is trained.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.segmentation_head = SegmentationHead(in_channels=384, num_classes=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, channels, height, width = x.size()
        assert height == width == self.image_size, "The image must match the size required by the DINO model"
        features = self.backbone(pixel_values=x).feature_maps[0]
        masks = self.segmentation_head(features)
        return masks


def main() -> None:
    # model = DINOSegmentationModel()
    model = SegmentationHead(384, 18)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"params: {num_params/1e6:.2f} M")


if __name__ == "__main__":
    main()
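For reference, a shape walk-through under the defaults (assuming dinov2-small's published config: hidden size 384, patch size 14), so a 224×224 input yields a 16×16 feature map that the head upsamples back to full resolution:

model = DINOSegmentationModel(image_size=224, num_classes=18)
x = torch.randn(1, 3, 224, 224)
features = model.backbone(pixel_values=x).feature_maps[0]  # (1, 384, 16, 16): 224 / 14 = 16
masks = model(x)                                           # (1, 18, 224, 224) after the head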
src/models/segmentation_head.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn


class SegmentationHead(nn.Module):
    def __init__(self, in_channels: int, num_classes: int) -> None:
        super().__init__()
        # Conv blocks with fixed upsampling stages: feature map -> 64 -> 128 -> 224.
        self.head = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(size=(64, 64), mode="bilinear"),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(size=(128, 128), mode="bilinear"),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(size=(224, 224), mode="bilinear"),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, num_classes, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(x)
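Because the nn.Upsample(size=...) stages are hard-coded, this head always emits 224×224 logits regardless of the incoming feature-map size. A resolution-agnostic variant (a hypothetical sketch, not what this commit does) could instead interpolate the logits to the caller's target size:

import torch.nn.functional as F

class ResizeToSize(nn.Module):
    # Hypothetical wrapper: run the head, then resize the logits to an arbitrary output size.
    def __init__(self, head: nn.Module) -> None:
        super().__init__()
        self.head = head

    def forward(self, x: torch.Tensor, out_hw: tuple) -> torch.Tensor:
        return F.interpolate(self.head(x), size=out_hw, mode="bilinear", align_corners=False)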
src/models/unet.py
ADDED
@@ -0,0 +1,171 @@
import torch
from torch import nn


class UNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        # Downsampler
        self.enc_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64)
        )
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128)
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256)
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512)
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # bottleneck
        self.bottleneck_conv = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024)
        )

        # Upsampler
        self.upsample0 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1),
        )
        self.dec_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512)
        )
        self.upsample1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1),
        )
        self.dec_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256)
        )
        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
        )
        self.dec_conv2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128)
        )
        self.upsample3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
        )
        self.dec_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=18, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # encoder
        e0 = self.enc_conv0(x)
        e1 = self.pool0(e0)
        e1 = self.enc_conv1(e1)
        e2 = self.pool1(e1)
        e2 = self.enc_conv2(e2)
        e3 = self.pool2(e2)
        e3 = self.enc_conv3(e3)

        # bottleneck
        b = self.pool3(e3)
        b = self.bottleneck_conv(b)

        # decoder
        d0 = self.upsample0(b)
        d0 = torch.cat([d0, e3], dim=1)
        d0 = self.dec_conv0(d0)

        d1 = self.upsample1(d0)
        d1 = torch.cat([d1, e2], dim=1)
        d1 = self.dec_conv1(d1)

        d2 = self.upsample2(d1)
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec_conv2(d2)

        d3 = self.upsample3(d2)
        d3 = torch.cat([d3, e0], dim=1)
        d3 = self.dec_conv3(d3)
        return d3
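The decoder concatenates each upsampled map with the matching encoder output, which is why every dec_conv block's input channel count is double its upsampler's output (e.g. 512 + 512 = 1024 into dec_conv0). A shape sketch for the 224×224 inputs used elsewhere in this commit:

model = UNet()
x = torch.randn(1, 3, 224, 224)
out = model(x)  # (1, 18, 224, 224): four 2x pools down to 14x14, four 2x upsamples back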
src/models/vit.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
from transformers import ViTModel

from models.segmentation_head import SegmentationHead


class ViTSegmentation(nn.Module):
    def __init__(self, image_size: int = 224, num_classes: int = 18) -> None:
        super().__init__()
        self.mean = [0.5, 0.5, 0.5]
        self.std = [0.5, 0.5, 0.5]
        self.backbone = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self.segmentation_head = SegmentationHead(in_channels=768, num_classes=num_classes)
        # Freeze the backbone: only the segmentation head is trained.
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, channels, height, width = x.size()
        assert height == width == self.backbone.config.image_size, "The image must match the size required by the ViT model"
        outputs = self.backbone(pixel_values=x).last_hidden_state
        patch_dim = int(height / self.backbone.config.patch_size)
        # Drop the [CLS] token, then fold the patch tokens back into a 2D feature map.
        outputs = outputs[:, 1:, :]
        outputs = outputs.permute(0, 2, 1).view(batch_size, -1, patch_dim, patch_dim)
        masks = self.segmentation_head(outputs)
        return masks


def main() -> None:
    model = ViTSegmentation(image_size=224, num_classes=18)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"params: {num_params/1e6:.2f} M")


if __name__ == "__main__":
    main()
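The token-to-grid step above is the crux: vit-base-patch16-224 emits 1 + 196 tokens of width 768 for a 224×224 input (patch size 16, so 224 / 16 = 14 patches per side); dropping the class token and reshaping recovers a 14×14 spatial map. The same shapes in isolation:

tokens = torch.randn(1, 197, 768)  # [CLS] + 14 * 14 patch tokens
grid = tokens[:, 1:, :].permute(0, 2, 1).view(1, 768, 14, 14)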
src/train.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from accelerate import Accelerator
|
| 5 |
+
from accelerate.utils import set_seed
|
| 6 |
+
from matplotlib import cm
|
| 7 |
+
import numpy as np
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import argparse
|
| 10 |
+
import json
|
| 11 |
+
import wandb
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
import torch
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
|
| 16 |
+
from models.unet import UNet
|
| 17 |
+
from dataset import SegmentationDataset, collate_fn
|
| 18 |
+
from utils import get_transform, mask_transform, EMA
|
| 19 |
+
from get_loss import get_composite_criterion
|
| 20 |
+
from models.vit import ViTSegmentation
|
| 21 |
+
from models.dino import DINOSegmentationModel
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
color_map = cm.get_cmap('tab20', 18)
|
| 25 |
+
fixed_colors = np.array([color_map(i)[:3] for i in range(18)]) * 255
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def mask_to_color(mask: np.ndarray) -> np.ndarray:
|
| 29 |
+
h, w = mask.shape
|
| 30 |
+
color_mask = np.zeros((h, w, 3), dtype=np.uint8)
|
| 31 |
+
for class_idx in range(18):
|
| 32 |
+
color_mask[mask == class_idx] = fixed_colors[class_idx]
|
| 33 |
+
return color_mask
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def create_combined_image(
|
| 37 |
+
x: torch.Tensor,
|
| 38 |
+
y: torch.Tensor,
|
| 39 |
+
y_pred: torch.Tensor,
|
| 40 |
+
mean: list[float] = [0.485, 0.456, 0.406],
|
| 41 |
+
std: list[float] = [0.229, 0.224, 0.225]
|
| 42 |
+
) -> np.ndarray:
|
| 43 |
+
batch_size, _, height, width = x.shape
|
| 44 |
+
combined_height = height * 3
|
| 45 |
+
combined_width = width * batch_size
|
| 46 |
+
combined_image = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)
|
| 47 |
+
|
| 48 |
+
for i in range(batch_size):
|
| 49 |
+
image = x[i].cpu().permute(1, 2, 0).numpy()
|
| 50 |
+
image = (image * std + mean).clip(0, 1)
|
| 51 |
+
image = (image * 255).astype(np.uint8)
|
| 52 |
+
true_mask = y[i].cpu().numpy()
|
| 53 |
+
true_mask_color = mask_to_color(true_mask)
|
| 54 |
+
pred_mask = y_pred[i].cpu().numpy()
|
| 55 |
+
pred_mask_color = mask_to_color(pred_mask)
|
| 56 |
+
combined_image[:height, i * width:(i + 1) * width, :] = image
|
| 57 |
+
combined_image[height:2 * height, i * width:(i + 1) * width, :] = true_mask_color
|
| 58 |
+
combined_image[2 * height:, i * width:(i + 1) * width, :] = pred_mask_color
|
| 59 |
+
return combined_image
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def compute_metrics(y_pred: torch.Tensor, y: torch.Tensor, num_classes: int = 18) -> Tuple[float, float, float, float, float, float]:
|
| 63 |
+
pred_mask = y_pred.unsqueeze(-1) == torch.arange(num_classes, device=y_pred.device).reshape(1, 1, 1, -1)
|
| 64 |
+
target_mask = y.unsqueeze(-1) == torch.arange(num_classes, device=y.device).reshape(1, 1, 1, -1)
|
| 65 |
+
class_present = (target_mask.sum(dim=(0, 1, 2)) > 0).float()
|
| 66 |
+
tp = (pred_mask & target_mask).sum(dim=(0, 1, 2)).float()
|
| 67 |
+
fp = (pred_mask & ~target_mask).sum(dim=(0, 1, 2)).float()
|
| 68 |
+
fn = (~pred_mask & target_mask).sum(dim=(0, 1, 2)).float()
|
| 69 |
+
tn = (~pred_mask & ~target_mask).sum(dim=(0, 1, 2)).float()
|
| 70 |
+
overall_tp = tp.sum()
|
| 71 |
+
overall_fp = fp.sum()
|
| 72 |
+
overall_fn = fn.sum()
|
| 73 |
+
overall_tn = tn.sum()
|
| 74 |
+
precision = tp / (tp + fp).clamp(min=1e-8)
|
| 75 |
+
recall = tp / (tp + fn).clamp(min=1e-8)
|
| 76 |
+
accuracy = (tp + tn) / (tp + tn + fp + fn)
|
| 77 |
+
macro_precision = ((precision * class_present).sum() / class_present.sum().clamp(min=1e-8)).item()
|
| 78 |
+
macro_recall = ((recall * class_present).sum() / class_present.sum().clamp(min=1e-8)).item()
|
| 79 |
+
macro_accuracy = accuracy.mean().item()
|
| 80 |
+
micro_precision = (overall_tp / (overall_tp + overall_fp).clamp(min=1e-8)).item()
|
| 81 |
+
micro_recall = (overall_tp / (overall_tp + overall_fn).clamp(min=1e-8)).item()
|
| 82 |
+
global_accuracy = ((y_pred == y).sum() / (y.shape[0] * y.shape[1] * y.shape[2])).item()
|
| 83 |
+
return macro_precision, macro_recall, macro_accuracy, micro_precision, micro_recall, global_accuracy
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def parse_args():
|
| 87 |
+
parser = argparse.ArgumentParser(description="Train a model on human parsing dataset")
|
| 88 |
+
parser.add_argument("--data-path", type=str, default="mattmdjaga/human_parsing_dataset", help="Path to the data")
|
| 89 |
+
parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
|
| 90 |
+
parser.add_argument("--pin-memory", type=bool, default=True, help="Pin memory for DataLoader")
|
| 91 |
+
parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader")
|
| 92 |
+
parser.add_argument("--num-epochs", type=int, default=15, help="Number of training epochs")
|
| 93 |
+
parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
|
| 94 |
+
parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
|
| 95 |
+
parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
|
| 96 |
+
parser.add_argument("--logs-dir", type=str, default="dino-logs", help="Directory for saving logs")
|
| 97 |
+
parser.add_argument("--model", type=str, default="dino", choices=["unet", "vit", "dino"], help="Model class name")
|
| 98 |
+
parser.add_argument("--losses-path", type=str, default="losses_config.json", help="Path to the losses")
|
| 99 |
+
parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Value of the mixed precision")
|
| 100 |
+
parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Value of the gradient accumulation steps")
|
| 101 |
+
parser.add_argument("--project-name", type=str, default="human_parsing_segmentation_ttk", help="WandB project name")
|
| 102 |
+
parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
|
| 103 |
+
parser.add_argument("--log-steps", type=int, default=400, help="Number of steps between logging training images and metrics")
|
| 104 |
+
parser.add_argument("--seed", type=int, default=42, help="Value of the seed")
|
| 105 |
+
return parser.parse_args()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def main() -> None:
|
| 109 |
+
args = parse_args()
|
| 110 |
+
|
| 111 |
+
set_seed(args.seed)
|
| 112 |
+
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)
|
| 113 |
+
|
| 114 |
+
with open(args.losses_path, "r") as fp:
|
| 115 |
+
losses_config = json.load(fp)
|
| 116 |
+
|
| 117 |
+
with accelerator.main_process_first():
|
| 118 |
+
logs_dir = Path(args.logs_dir)
|
| 119 |
+
logs_dir.mkdir(exist_ok=True)
|
| 120 |
+
wandb.init(project=args.project_name, dir=logs_dir)
|
| 121 |
+
wandb.save(args.losses_path)
|
| 122 |
+
|
| 123 |
+
optimizer_class = getattr(torch.optim, args.optimizer)
|
| 124 |
+
|
| 125 |
+
if args.model == "unet":
|
| 126 |
+
model = UNet().to(accelerator.device)
|
| 127 |
+
optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
|
| 128 |
+
elif args.model == "vit":
|
| 129 |
+
model = ViTSegmentation().to(accelerator.device)
|
| 130 |
+
optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
|
| 131 |
+
elif args.model == "dino":
|
| 132 |
+
model = DINOSegmentationModel().to(accelerator.device)
|
| 133 |
+
optimizer = optimizer_class(model.segmentation_head.parameters(), lr=args.learning_rate)
|
| 134 |
+
else:
|
| 135 |
+
raise NotImplementedError("Incorrect model name")
|
| 136 |
+
|
| 137 |
+
transform = get_transform(model.mean, model.std)
|
| 138 |
+
|
| 139 |
+
dataset = load_dataset(args.data_path, split="train")
|
| 140 |
+
train_dataset = SegmentationDataset(dataset, train=True, transform=transform, target_transform=mask_transform)
|
| 141 |
+
valid_dataset = SegmentationDataset(dataset, train=False, transform=transform, target_transform=mask_transform)
|
| 142 |
+
|
| 143 |
+
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)
|
| 144 |
+
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)
|
| 145 |
+
|
| 146 |
+
criterion = get_composite_criterion(losses_config)
|
| 147 |
+
|
| 148 |
+
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * len(train_loader))
|
| 149 |
+
|
| 150 |
+
model, optimizer, train_loader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader, lr_scheduler)
|
| 151 |
+
|
| 152 |
+
best_accuracy = 0
|
| 153 |
+
|
| 154 |
+
print(f"params: {sum([p.numel() for p in model.parameters()])/1e6:.2f} M")
|
| 155 |
+
print(f"trainable params: {sum([p.numel() for p in model.parameters() if p.requires_grad])/1e6:.2f} M")
|
| 156 |
+
|
| 157 |
+
train_loss_ema, train_macro_precision_ema, train_macro_recall_ema, train_macro_accuracy_ema, train_micro_precision_ema, train_micro_recall_ema, train_global_accuracy_ema = EMA(), EMA(), EMA(), EMA(), EMA(), EMA(), EMA()
|
| 158 |
+
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}")
        for index, (x, y) in enumerate(pbar):
            x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    output = model(x)
                    loss = criterion(output, y)
                accelerator.backward(loss)
                train_loss = loss.item()
                grad_norm = None
                _, y_pred = output.max(dim=1)
                train_macro_precision, train_macro_recall, train_macro_accuracy, train_micro_precision, train_micro_recall, train_global_accuracy = compute_metrics(y_pred, y)
                # gradients are only synchronized on the last micro-batch of an
                # accumulation window, so clipping is meaningful only on that step
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                images_to_log = []
                combined_image = create_combined_image(x, y, y_pred)
                images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Train, Epoch {epoch}, Batch {index})"))
                wandb.log({"train_samples": images_to_log})
            pbar.set_postfix({"loss": train_loss_ema(train_loss), "macro_precision": train_macro_precision_ema(train_macro_precision), "macro_recall": train_macro_recall_ema(train_macro_recall), "macro_accuracy": train_macro_accuracy_ema(train_macro_accuracy), "micro_precision": train_micro_precision_ema(train_micro_precision), "micro_recall": train_micro_recall_ema(train_micro_recall), "global_accuracy": train_global_accuracy_ema(train_global_accuracy)})
            log_data = {
                "train/epoch": epoch,
                "train/loss": train_loss,
                "train/macro_accuracy": train_macro_accuracy,
                "train/learning_rate": optimizer.param_groups[0]["lr"],
                "train/macro_precision": train_macro_precision,
                "train/macro_recall": train_macro_recall,
                "train/micro_precision": train_micro_precision,
                "train/micro_recall": train_micro_recall,
                "train/global_accuracy": train_global_accuracy,
            }
            if grad_norm is not None:
                log_data["train/grad_norm"] = grad_norm
            if accelerator.is_main_process:
                wandb.log(log_data)
        accelerator.wait_for_everyone()

        model.eval()
        valid_loss, valid_macro_accuracies, valid_macro_precisions, valid_macro_recalls, valid_global_accuracies, valid_micro_precisions, valid_micro_recalls = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        with torch.inference_mode():
            pbar = tqdm(valid_loader, desc=f"Val epoch {epoch}/{args.num_epochs}")
            for index, (x, y) in enumerate(pbar):  # iterate the tqdm wrapper so the progress bar actually advances
                x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
                output = model(x)
                _, y_pred = output.max(dim=1)
                if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                    images_to_log = []
                    combined_image = create_combined_image(x, y, y_pred)
                    images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Validation, Epoch {epoch})"))
                    wandb.log({"valid_samples": images_to_log})
                valid_macro_precision, valid_macro_recall, valid_macro_accuracy, valid_micro_precision, valid_micro_recall, valid_global_accuracy = compute_metrics(y_pred, y)
                valid_macro_precisions += valid_macro_precision
                valid_macro_recalls += valid_macro_recall
                valid_macro_accuracies += valid_macro_accuracy
                valid_micro_precisions += valid_micro_precision
                valid_micro_recalls += valid_micro_recall
                valid_global_accuracies += valid_global_accuracy
                loss = criterion(output, y)
                valid_loss += loss.item()
        valid_loss = valid_loss / len(valid_loader)
        valid_macro_accuracies = valid_macro_accuracies / len(valid_loader)
        valid_macro_precisions = valid_macro_precisions / len(valid_loader)
        valid_macro_recalls = valid_macro_recalls / len(valid_loader)
        valid_global_accuracies = valid_global_accuracies / len(valid_loader)
        valid_micro_precisions = valid_micro_precisions / len(valid_loader)
        valid_micro_recalls = valid_micro_recalls / len(valid_loader)
        accelerator.print(f"loss: {valid_loss:.3f}, valid_macro_precision: {valid_macro_precisions:.3f}, valid_macro_recall: {valid_macro_recalls:.3f}, valid_macro_accuracy: {valid_macro_accuracies:.3f}, valid_micro_precision: {valid_micro_precisions:.3f}, valid_micro_recall: {valid_micro_recalls:.3f}, valid_global_accuracy: {valid_global_accuracies:.3f}")
        if accelerator.is_main_process:
            wandb.log(
                {
                    "val/epoch": epoch,
                    "val/loss": valid_loss,
                    "val/macro_accuracy": valid_macro_accuracies,
                    "val/macro_precision": valid_macro_precisions,
                    "val/macro_recall": valid_macro_recalls,
                    "val/global_accuracy": valid_global_accuracies,
                    "val/micro_precision": valid_micro_precisions,
                    "val/micro_recall": valid_micro_recalls,
                }
            )
        if valid_global_accuracies > best_accuracy:
            best_accuracy = valid_global_accuracies
            if args.model in ["dino", "vit"]:
                accelerator.save(model.segmentation_head.state_dict(), logs_dir / "checkpoint-best.pth")
            else:
                accelerator.save(model.state_dict(), logs_dir / "checkpoint-best.pth")
            accelerator.print(f"new best_accuracy {best_accuracy}, {epoch=}")
        if epoch % args.save_frequency == 0:
            if args.model in ["dino", "vit"]:
                accelerator.save(model.segmentation_head.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
            else:
                accelerator.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
        accelerator.wait_for_everyone()

    accelerator.wait_for_everyone()
    wandb.finish()


if __name__ == "__main__":
    main()
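For reference, training might be launched through Hugging Face accelerate. The flag names below mirror the argparse definitions in parse_args, and the values are illustrative only, not a recorded invocation:

    accelerate launch src/train.py --model dino --batch-size 16 --num-epochs 40 --losses-path losses.json --logs-dir logs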
src/utils.py
ADDED
@@ -0,0 +1,58 @@
import torch
import torchvision.transforms as T
import PIL.Image
from typing import List


size = (224, 224)


class ResizeWithPadding:
    def __init__(self, target_size: int = 224, fill: int = 0, mode: str = "RGB") -> None:
        self.target_size = target_size
        self.fill = fill
        self.mode = mode

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height
        if aspect_ratio > 1:
            new_width = self.target_size
            new_height = int(self.target_size / aspect_ratio)
        else:
            new_height = self.target_size
            new_width = int(self.target_size * aspect_ratio)
        # bicubic for RGB images; nearest for masks so interpolation cannot invent new class ids
        resized_image = image.resize((new_width, new_height), PIL.Image.BICUBIC if self.mode == "RGB" else PIL.Image.NEAREST)
        delta_w = self.target_size - new_width
        delta_h = self.target_size - new_height
        padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
        padded_image = PIL.Image.new(self.mode, (self.target_size, self.target_size), self.fill)
        padded_image.paste(resized_image, (padding[0], padding[1]))
        return padded_image
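As a quick sanity check of the geometry (a sketch using only the class above): a 640x480 input has aspect ratio 4:3, which is greater than 1, so it is resized to 224x168 and then padded with 28 rows of fill above and below to reach a square 224x224.

resize = ResizeWithPadding(target_size=224)
padded = resize(PIL.Image.new("RGB", (640, 480)))  # resized to 224x168, then padded
assert padded.size == (224, 224)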


def get_transform(mean: List[float], std: List[float]) -> T.Compose:
    return T.Compose([
        ResizeWithPadding(),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])
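get_transform expects the normalization statistics of the chosen backbone (train.py passes model.mean and model.std). As a sketch, with the common ImageNet statistics standing in for a real backbone's values:

transform = get_transform(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
tensor = transform(PIL.Image.new("RGB", (640, 480)))  # tensor.shape == torch.Size([3, 224, 224])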

mask_transform = T.Compose([
    ResizeWithPadding(mode="L"),
    T.ToTensor(),
    # ToTensor rescales 8-bit pixels to [0, 1]; multiplying by 255 recovers the
    # integer class ids and .long() makes them valid CrossEntropyLoss targets
    T.Lambda(lambda x: (x * 255).long()),
])


class EMA:
    def __init__(self, alpha: float = 0.9) -> None:
        self.value = None
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        if self.value is None:
            self.value = value
        else:
            self.value = self.alpha * self.value + (1 - self.alpha) * value
        return self.value
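EMA smooths the noisy per-batch metrics shown in the tqdm postfix during training. A minimal example of its behaviour:

ema = EMA(alpha=0.9)
ema(1.0)  # 1.0, the first observation seeds the average
ema(0.0)  # 0.9 = 0.9 * 1.0 + 0.1 * 0.0
ema(0.0)  # 0.81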