# ProArd / attacks / deepfool.py
# Uploaded by smi08 to the Hugging Face Hub via huggingface_hub (commit 7771996, verified).
import copy
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from .utils import ctx_noparamgrad_and_eval
from .base import Attack, LabelMixin
from typing import Dict
from .utils import batch_multiply
from .utils import clamp
from .utils import is_float_or_torch_tensor
from utils.distributed import DistributedMetric
from tqdm import tqdm
from torchpack import distributed as dist
from utils import accuracy
# Device that DeepFoolAttack.perturb hands to perturb_deepfool:
# the default CUDA device when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def perturb_deepfool(xvar, yvar, predict, nb_iter=50, overshoot=0.02, ord=np.inf, clip_min=0.0, clip_max=1.0,
                     search_iter=0, device=None):
    """
    Compute DeepFool perturbations (Moosavi-Dezfooli et al, 2016).

    Each iteration linearizes the classifier around the current point, picks
    (per sample) the other class whose linearized decision boundary is
    closest, and steps towards it. Samples whose top-scoring class already
    differs from ``yvar`` receive no further step. When ``search_iter > 0``,
    a per-sample binary search afterwards shrinks the perturbation towards
    the smallest tested scale that still changes the predicted class.

    Arguments:
        xvar (torch.Tensor): input batch; dimension 0 is the batch dimension,
            any number of trailing dimensions is supported.
        yvar (torch.Tensor): per-sample labels to move away from (typically the
            model's own predictions).
        predict (nn.Module): forward pass function returning per-class logits
            of shape ``(batch, num_classes)``.
        nb_iter (int): maximum number of DeepFool iterations.
        overshoot (float): how much to overshoot the boundary.
        ord (int): (optional) the order of maximum distortion (inf or 2).
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        search_iter (int): number of binary-search iterations used to shrink
            the perturbation (0 disables the search).
        device (torch.device): kept for backward compatibility and unused;
            all index tensors are created on the device of ``predict``'s output.

    Returns:
        torch.Tensor containing the perturbed input,
        torch.Tensor containing the perturbation

    Raises:
        NotImplementedError: if ``ord`` is neither 2 nor ``np.inf``.
    """
    if ord not in (2, np.inf):
        raise NotImplementedError("Only ord=inf and ord=2 have been implemented")
    # Distance to a linearized boundary is measured in the dual norm:
    # L2 is self-dual, the dual of Linf is L1.
    dual_p = 2 if ord == 2 else 1

    x_orig = xvar.detach()
    x = x_orig.clone().requires_grad_(True)
    # Shape that broadcasts a per-sample (batch,) tensor against x, for any input rank.
    bshape = (-1,) + (1,) * (x.dim() - 1)
    r_tot = torch.zeros_like(x_orig)

    for _ in range(nb_iter):
        if x.grad is not None:
            x.grad.zero_()

        logits = predict(x)
        labels = yvar.to(logits.device)
        batch_i = torch.arange(logits.shape[0], device=logits.device)
        # Classes sorted by ascending logit: the last column is the current prediction.
        df_inds = logits.detach().argsort(dim=-1)
        df_inds_other, df_inds_orig = df_inds[:, :-1], df_inds[:, -1]
        # Samples still predicted as their label must keep moving.
        not_done_inds = df_inds_orig == labels
        if not not_done_inds.any():
            break

        logits[batch_i, df_inds_orig].sum().backward(retain_graph=True)
        grad_orig = x.grad.detach().clone()
        pert = x.new_full((x.shape[0],), np.inf)
        w = torch.zeros_like(x_orig)

        for inds in df_inds_other.T:
            x.grad.zero_()
            logits[batch_i, inds].sum().backward(retain_graph=True)
            grad_cur = x.grad.detach().clone()
            with torch.no_grad():
                w_k = grad_cur - grad_orig
                f_k = logits[batch_i, inds] - logits[batch_i, df_inds_orig]
                pert_k = f_k.abs() / w_k.flatten(1).norm(p=dual_p, dim=-1)
                # Keep, per sample, the class whose linearized boundary is closest.
                closer = pert_k < pert
                pert[closer] = pert_k[closer]
                w[closer] = w_k[closer]

        with torch.no_grad():
            if ord == 2:
                w_norm = w.flatten(1).norm(p=2, dim=-1)
                r_i = (pert + 1e-6).view(bshape) * w / w_norm.view(bshape)
            else:
                r_i = (pert + 1e-6).view(bshape) * w.sign()
            # Accumulate only for samples that are not yet fooled. torch.where (unlike
            # multiplying by a 0/1 mask) also drops inf/NaN steps from degenerate samples
            # (e.g. zero gradient difference), so r_tot stays finite.
            keep = not_done_inds.view(bshape) & torch.isfinite(r_i)
            r_tot += torch.where(keep, r_i, torch.zeros_like(r_i))
            x.copy_(torch.clamp(x_orig + (1. + overshoot) * r_tot, clip_min, clip_max))

    x = x.detach()
    if search_iter > 0:
        dx = x - x_orig
        dx_l_low, dx_l_high = torch.zeros_like(dx), torch.ones_like(dx)
        with torch.no_grad():
            for _ in range(search_iter):
                dx_l = (dx_l_low + dx_l_high) / 2.
                dx_y = predict(x_orig + dx_l * dx).argmax(-1)
                label_stay = dx_y == yvar.to(dx_y.device)
                label_change = ~label_stay
                dx_l_low[label_stay] = dx_l[label_stay]
                dx_l_high[label_change] = dx_l[label_change]
        # dx_l_high holds, per sample, the smallest tested scale at which the label
        # changed (or 1, i.e. the full DeepFool step). The last tested midpoint is
        # NOT used: it may still carry the original label.
        x = x_orig + dx_l_high * dx

    r_tot = x - x_orig
    return x, r_tot
class DeepFoolAttack(Attack, LabelMixin):
    """
    DeepFool attack.
    [Seyed-Mohsen Moosavi-Dezfooli, Alhussein Fawzi, Pascal Frossard,
    "DeepFool: a simple and accurate method to fool deep neural networks"]

    Arguments:
        predict (nn.Module): forward pass function.
        overshoot (float): how much to overshoot the boundary.
        nb_iter (int): number of iterations.
        search_iter (int): number of binary-search iterations used to shrink
            the final perturbation (0 disables the search).
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        ord (int): (optional) the order of maximum distortion (inf or 2).
    """

    def __init__(
            self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1., ord=np.inf):
        # No loss function is passed to the base class (second argument is None).
        super(DeepFoolAttack, self).__init__(predict, None, clip_min, clip_max)
        self.overshoot = overshoot
        self.nb_iter = nb_iter
        self.search_iter = search_iter
        self.targeted = False
        self.ord = ord
        assert is_float_or_torch_tensor(self.overshoot)

    def perturb(self, x, y=None):
        """
        Given examples x, returns their adversarial counterparts.

        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor. NOTE: currently ignored — labels are
                always obtained from ``_verify_and_process_inputs(x, None)``
                (presumably the model's predictions for an untargeted attack;
                confirm in the base class).

        Returns:
            torch.Tensor containing perturbed inputs,
            torch.Tensor containing the perturbation
        """
        x, y = self._verify_and_process_inputs(x, None)
        x_adv, r_adv = perturb_deepfool(x, y, self.predict, self.nb_iter, self.overshoot, ord=self.ord,
                                        clip_min=self.clip_min, clip_max=self.clip_max, search_iter=self.search_iter,
                                        device=device)
        return x_adv, r_adv

    def eval_deepfool(self, data_loader_dict: Dict) -> Dict:
        """
        Evaluate clean and DeepFool-adversarial loss / top-1 / top-5 accuracy.

        Arguments:
            data_loader_dict (Dict): must contain a ``"val"`` iterable that
                yields ``(images, labels)`` batches.

        Returns:
            Dict with the averaged metrics under ``val_top1``, ``val_top5``,
            ``val_loss``, ``val_advtop1``, ``val_advtop5`` and ``val_advloss``.
        """
        test_criterion = nn.CrossEntropyLoss().cuda()
        val_loss = DistributedMetric()
        val_top1 = DistributedMetric()
        val_top5 = DistributedMetric()
        val_advloss = DistributedMetric()
        val_advtop1 = DistributedMetric()
        val_advtop5 = DistributedMetric()

        self.predict.eval()
        with tqdm(
            total=len(data_loader_dict["val"]),
            desc="Eval",
            disable=not dist.is_master(),
        ) as t:
            for images, labels in data_loader_dict["val"]:
                images, labels = images.cuda(), labels.cuda()
                batch_size = images.shape[0]

                # Clean metrics: no autograd graph is needed for evaluation.
                with torch.no_grad():
                    output = self.predict(images)
                    loss = test_criterion(output, labels)
                val_loss.update(loss, batch_size)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                val_top5.update(acc5[0], batch_size)
                val_top1.update(acc1[0], batch_size)

                # Adversarial metrics. The attack itself needs input gradients, so
                # only the evaluation of the adversarial batch runs under no_grad.
                with ctx_noparamgrad_and_eval(self.predict):
                    images_adv, _ = self.perturb(images, labels)
                    with torch.no_grad():
                        output_adv = self.predict(images_adv)
                        loss_adv = test_criterion(output_adv, labels)
                val_advloss.update(loss_adv, batch_size)
                acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
                val_advtop1.update(acc1_adv[0], batch_size)
                val_advtop5.update(acc5_adv[0], batch_size)

                t.set_postfix(
                    {
                        "loss": val_loss.avg.item(),
                        "top1": val_top1.avg.item(),
                        "top5": val_top5.avg.item(),
                        "adv_loss": val_advloss.avg.item(),
                        "adv_top1": val_advtop1.avg.item(),
                        "adv_top5": val_advtop5.avg.item(),
                        "#samples": val_top1.count.item(),
                        "batch_size": batch_size,
                        "img_size": images.shape[2],
                    }
                )
                t.update()

        val_results = {
            "val_top1": val_top1.avg.item(),
            "val_top5": val_top5.avg.item(),
            "val_loss": val_loss.avg.item(),
            "val_advtop1": val_advtop1.avg.item(),
            "val_advtop5": val_advtop5.avg.item(),
            "val_advloss": val_advloss.avg.item(),
        }
        return val_results
class LinfDeepFoolAttack(DeepFoolAttack):
    """
    DeepFool Attack with order=Linf.

    Arguments:
        predict (nn.Module): forward pass function.
        overshoot (float): how much to overshoot the boundary.
        nb_iter (int): number of iterations.
        search_iter (int): number of binary-search iterations used to shrink
            the final perturbation.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
    """

    def __init__(
            self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1.):
        # Fix the norm to Linf; everything else is forwarded unchanged.
        super(LinfDeepFoolAttack, self).__init__(
            predict=predict, overshoot=overshoot, nb_iter=nb_iter, search_iter=search_iter, clip_min=clip_min,
            clip_max=clip_max, ord=np.inf)
class L2DeepFoolAttack(DeepFoolAttack):
    """
    DeepFool Attack restricted to the L2 norm (``ord=2``).

    Arguments:
        predict (nn.Module): forward pass function.
        overshoot (float): how much to overshoot the boundary.
        nb_iter (int): number of iterations.
        search_iter (int): number of binary-search iterations used to shrink
            the final perturbation.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
    """

    def __init__(self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1.):
        # Same configuration as the generic attack, with the norm pinned to L2.
        super().__init__(
            predict=predict,
            overshoot=overshoot,
            nb_iter=nb_iter,
            search_iter=search_iter,
            clip_min=clip_min,
            clip_max=clip_max,
            ord=2,
        )