Jacob Logas
committed on
Fix device usage?
- app.py +2 -1
- util/attack_utils.py +1 -0
- util/feature_extraction_utils.py +1 -2
- util/prepare_utils.py +4 -0
app.py CHANGED
@@ -10,7 +10,6 @@ from align.detector import detect_faces
 from align.align_trans import get_reference_facial_points, warp_and_crop_face
 import torchvision.transforms as transforms
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 to_tensor = transforms.ToTensor()
 
 eps = 0.05
@@ -45,6 +44,8 @@ for root in model_roots:
 
 @spaces.GPU(duration=120)
 def protect(img, progress=gr.Progress(track_tqdm=True)):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     models_attack, V_reduction, dim = prepare_models(
         model_backbones,
         input_size,
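
Note on the app.py hunk above: it moves device selection from import time to call time. On Hugging Face ZeroGPU Spaces, CUDA is attached to the process only while a @spaces.GPU-decorated function is running, so a module-level torch.cuda.is_available() check resolves to "cpu". A minimal sketch of the pattern (run_model and the toy Linear model are illustrative, not from this repo):

import torch
import spaces  # Hugging Face Spaces SDK

model = torch.nn.Linear(4, 2)  # constructed on CPU at import time

# At import time no GPU is attached yet, so a module-level check
# like the one removed above would always resolve to "cpu".

@spaces.GPU(duration=120)
def run_model(x: torch.Tensor) -> torch.Tensor:
    # Resolve the device inside the decorated function, where CUDA
    # has been attached and is_available() reports True.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device)(x.to(device))
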
util/attack_utils.py CHANGED
@@ -55,6 +55,7 @@ class Attack(nn.Module):
 
     def execute(self, images, dir_vec, direction):
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print("Device in Excute:", device)
         self.lpips_loss.to(device)
         images = Variable(images).to(device)
         dir_vec = dir_vec.to(device)
util/feature_extraction_utils.py CHANGED
@@ -5,11 +5,10 @@ import torchvision.transforms as transforms
 import torch.nn as nn
 import numpy as np
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 
 def warp_image(tensor_img, theta_warp, crop_size=112):
     # applies affine transform theta to image and crops it
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     theta_warp = torch.Tensor(theta_warp).unsqueeze(0).to(device)
     grid = F.affine_grid(theta_warp, tensor_img.size())
util/prepare_utils.py CHANGED
@@ -123,6 +123,8 @@ def get_ensemble(
     warp=False,
     theta_warp=None,
 ):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     # function prepares ensemble of feature extractors
     # outputs list of pytorch nn models
     feature_extractor_ensemble = []
@@ -213,6 +215,8 @@ def prepare_models(
     using_subspace,
     V_reduction_root,
 ):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     backbone_dict = {
         "IR_50": IR_50(input_size),
         "IR_152": IR_152(input_size),
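
The util/prepare_utils.py hunks apply the same pattern: each function re-resolves the device on entry instead of reading a module-level global. Re-checking torch.cuda.is_available() per call is cheap and stays correct when availability changes between import and execution; the alternative is threading an explicit device argument through every call. A hedged sketch of the per-call helper approach (get_device and prepare_ensemble are illustrative names, not from this repo):

import torch

def get_device() -> torch.device:
    # Re-resolving per call stays correct on platforms where CUDA
    # availability changes between import time and execution time.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def prepare_ensemble(models: list) -> list:
    # Move every model in the ensemble to the device seen at call time.
    device = get_device()
    return [m.to(device) for m in models]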