Spaces:
Running
Running
| import torch | |
def loss_generator(ignore: list = None):
    """Build the registry that maps loss names to loss functions.

    Args:
        ignore: Optional list of loss names to leave out of the returned
            registry. ``None`` or an empty list keeps every entry.

    Returns:
        A new dict mapping each remaining loss name to its callable.

    Raises:
        KeyError: If a name in ``ignore`` is not a registered loss.
    """
    loss_fn = {
        'mse': mse,
        'lut_mse': lut_mse,
        'masked_mse': masked_mse,
        'sample_weighted_mse': sample_weighted_mse,
        'regularize_LUT': regularize_LUT,
        'MaskWeightedMSE': MaskWeightedMSE,
    }
    if ignore:
        for name in ignore:
            # Remove from the registry itself; popping from ``ignore`` (a
            # list) would raise TypeError for string names and would never
            # drop anything from the returned dict.
            loss_fn.pop(name)
    return loss_fn
def mse(pred, gt):
    """Return the mean of the element-wise squared difference of ``pred`` and ``gt``."""
    squared_error = (pred - gt) ** 2
    return torch.mean(squared_error)
def masked_mse(pred, gt, mask):
    """Squared error inside the thresholded mask, normalised per sample.

    For every index along dimension 0, the squared error is summed over the
    positions where ``mask > 100 / 255`` and divided by the sum of the raw
    mask values for that index, clamped to a minimum of 100. The resulting
    per-sample values are averaged.

    Args:
        pred: Predicted tensor.
        gt: Ground-truth tensor, combined element-wise with ``pred``.
        mask: Mask tensor; every dimension after the first is reduced.
            NOTE(review): dimension 0 is presumably the batch -- confirm.

    Returns:
        A scalar tensor on the device of ``pred``.
    """
    reduce_dims = list(range(1, mask.dim()))
    # Normaliser: per-sample mask sum, floored at 100 (presumably so that
    # very small masks do not inflate the loss -- confirm with the author).
    delimin = torch.clamp_min(torch.sum(mask, dim=reduce_dims), 100)
    # Follow the prediction's device rather than forcing CUDA, so the loss
    # also works on CPU-only hosts and on non-default GPUs.
    delimin = delimin.to(pred.device)
    inside = mask > 100 / 255.
    out = torch.sum(inside * (pred - gt) ** 2, dim=reduce_dims)
    out = out / delimin
    return torch.mean(out)
def sample_weighted_mse(pred, gt, mask):
    """Per-sample MSE combined with weights derived from the mask sums.

    Each index along dimension 0 gets a weight equal to the sum of its mask
    values (clamped to a minimum of 100), and the weights are normalised to
    sum to one. The per-sample mean squared error is multiplied by its
    weight and the products are summed.

    Args:
        pred: Predicted tensor.
        gt: Ground-truth tensor, combined element-wise with ``pred``.
        mask: Mask tensor; every dimension after the first is reduced.
            NOTE(review): dimension 0 is presumably the batch -- confirm.

    Returns:
        A scalar tensor on the device of ``pred``.
    """
    reduce_dims = list(range(1, mask.dim()))
    multi_factor = torch.clamp_min(torch.sum(mask, dim=reduce_dims), 100)
    # Follow the prediction's device rather than forcing CUDA, so the loss
    # also works on CPU-only hosts and on non-default GPUs.
    multi_factor = multi_factor.to(pred.device)
    multi_factor = multi_factor / multi_factor.sum()
    out = torch.mean((pred - gt) ** 2, dim=reduce_dims)
    out = out * multi_factor
    return torch.sum(out)
def regularize_LUT(lut):
    """Penalise LUT entries that fall outside the interval [0, 1].

    The result is the mean squared overshoot of the entries greater than 1
    plus the mean square of the entries less than 0. A side with no
    offending entries contributes the Python integer ``0``.
    """
    below = lut[lut < 0.]
    above = lut[lut > 1.]
    # Boolean-mask indexing yields a 1-D selection, so its length is the
    # number of offending entries.
    penalty_below = (below ** 2).mean() if len(below) != 0 else 0
    penalty_above = ((above - 1.) ** 2).mean() if len(above) != 0 else 0
    return penalty_above + penalty_below
def lut_mse(feat, lut_batch):
    """Accumulate pairwise MSE inside consecutive groups of ``feat``.

    ``feat`` is split along dimension 0 into ``feat.shape[0] // lut_batch``
    consecutive groups of ``lut_batch`` entries; trailing entries that do not
    fill a whole group are skipped. Within each group ``mse`` is evaluated
    for every ordered pair (self-pairs included) and the values are summed.
    The total is divided by ``lut_batch``.
    """
    total = 0
    group_count = feat.shape[0] // lut_batch
    for group_idx in range(group_count):
        start = group_idx * lut_batch
        group = feat[start: start + lut_batch]
        for first in group:
            for second in group:
                total += mse(first, second)
    return total / lut_batch
def MaskWeightedMSE(pred, label, mask):
    """Squared error summed per sample and normalised by the mask sum.

    ``label`` is viewed with the shape of ``pred``. For each index along
    dimension 0 the squared error is summed over all remaining dimensions
    and divided by ``pred.size(1)`` times the sum of the mask values (clamped
    to a minimum of 100). The per-sample values are then averaged.

    NOTE(review): ``pred.size(1)`` is presumably the channel count -- confirm.
    """
    target = label.view(pred.size())
    # Every dimension except the first.
    per_sample_dims = get_dims_with_exclusion(target.dim(), 0)
    squared_error = (pred - target) ** 2
    mask_sum = torch.clamp_min(torch.sum(mask, dim=per_sample_dims), 100)
    normaliser = pred.size(1) * mask_sum
    per_sample = torch.sum(squared_error, dim=per_sample_dims) / normaliser
    return torch.mean(per_sample)
def get_dims_with_exclusion(dim, exclude=None):
    """Return ``[0, 1, ..., dim - 1]`` with ``exclude`` removed when given.

    Raises:
        ValueError: If ``exclude`` is not ``None`` and is not one of the
            generated indices.
    """
    axes = list(range(dim))
    if exclude is None:
        return axes
    axes.remove(exclude)
    return axes