import torch
import torch.nn as nn
from typing import Dict
from utils.distributed import DistributedMetric
from tqdm import tqdm
from torchpack import distributed as dist
from utils import accuracy
import copy
import torch.nn.functional as F
import numpy as np


def eval_local_lip(model, x, xp, top_norm=1, btm_norm=float('inf'), reduction='mean'):
    """Estimate the local Lipschitz ratio of ``model`` between ``x`` and ``xp``.

    For every sample, computes::

        ||flatten(model(x)) - flatten(model(xp))||_{top_norm}
        -----------------------------------------------------
                 ||flatten(x - xp) + 1e-6||_{btm_norm}

    where flattening keeps dim 0 and merges all remaining dims. When
    ``top_norm == "kl"`` the numerator is instead the KL divergence
    KL(softmax(model(x)) || softmax(model(xp))), summed over dim 1.

    The forward passes run under ``torch.no_grad()`` with the model in eval
    mode; every submodule's train/eval flag is restored afterwards, even if
    a forward pass raises.

    Args:
        model: ``nn.Module`` applied to ``x`` and ``xp``. For ``"kl"`` it is
            presumably expected to return (batch, classes) logits, since softmax
            and the sum are taken over dim 1 -- TODO confirm for other output ranks.
        x: input batch; dim 0 is treated as the batch dimension.
        xp: perturbed batch; must be broadcast-compatible with ``x``
            (normally the same shape).
        top_norm: ``p`` passed to ``torch.norm`` for the output difference,
            or the string ``"kl"`` to use the KL divergence instead.
        btm_norm: ``p`` passed to ``torch.norm`` for the input difference
            (default: infinity norm).
        reduction: ``'mean'`` or ``'sum'`` over the per-sample ratios, or
            ``'none'`` to return the per-sample ratios unreduced.

    Returns:
        A scalar tensor for ``'mean'``/``'sum'``; the tensor of per-sample
        ratios for ``'none'``. In the non-KL branch that tensor has shape
        (batch,).

    Raises:
        ValueError: if ``reduction`` is not one of the supported values. This
            is checked before any forward pass and before the model's mode is
            touched.
    """
    # Fail fast: don't spend two forward passes before rejecting bad input.
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError("Not supported reduction")

    # Record each submodule's own flag rather than just ``model.training``:
    # a blanket ``model.train(True)`` would also flip submodules the caller
    # had deliberately kept in eval mode.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        with torch.no_grad():
            # Merge all non-batch dims so each norm is taken per sample.
            down = torch.flatten(x - xp, start_dim=1)
            # The 1e-6 is added to each element (not to the norm) so the
            # denominator stays non-zero when x == xp.
            denom = torch.norm(down + 1e-6, dim=1, p=btm_norm)
            if top_norm == "kl":
                criterion_kl = nn.KLDivLoss(reduction='none')
                # KLDivLoss(log q, p) = p * (log p - log q), i.e. KL(p_x || p_xp).
                top = criterion_kl(F.log_softmax(model(xp), dim=1), F.softmax(model(x), dim=1))
                ret = torch.sum(top, dim=1) / denom
            else:
                top = torch.flatten(model(x), start_dim=1) - torch.flatten(model(xp), start_dim=1)
                ret = torch.norm(top, dim=1, p=top_norm) / denom
    finally:
        # Undo model.eval() so calling this metric mid-training does not
        # silently disable dropout / batch-norm updates for later steps.
        for module, was_training in saved_modes:
            module.training = was_training

    if reduction == 'mean':
        return torch.mean(ret)
    if reduction == 'sum':
        return torch.sum(ret)
    return ret  # reduction == 'none'