-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathattacks.py
78 lines (57 loc) · 2.18 KB
/
attacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Adversarial attack class
"""
import os
import torch
class Attack(object):
    """Base class for adversarial attacks.

    Subclasses implement :meth:`forward`, which receives whatever arguments
    are passed to the instance call and must return a pair
    ``(adv_examples, labels)``.

    Arguments:
        attack_type {str} -- Name of the attack, stored as ``attack_name``.
        target_cls {torch.nn.Module} -- Classifier under attack. Its
            ``training`` flag at construction time is recorded in
            ``self.training`` and the device of its first parameter in
            ``self.device``.
        img_type {str} -- ``'int'`` to return adversarial images as
            ``torch.uint8`` values in [0, 255]; any other value (default
            ``'float'``) returns the images produced by ``forward`` unchanged.
    """

    def __init__(self, attack_type, target_cls, img_type='float'):
        self.attack_name = attack_type
        self.target_cls = target_cls
        # Recorded at construction time; NOTE(review): nothing in this class
        # restores this mode after switching the model to eval().
        self.training = target_cls.training
        # Device of the model's first parameter. A parameter-less module would
        # otherwise leak a bare StopIteration out of __init__, so fall back to
        # the CPU device in that case.
        first_param = next(target_cls.parameters(), None)
        self.device = (first_param.device if first_param is not None
                       else torch.device('cpu'))
        self.mode = img_type

    def forward(self, *args):
        """Generate adversarial examples.

        Must be overridden by every attack class and return a pair
        ``(adv_examples, labels)``. The ``'int'`` mode conversion in
        ``__call__`` scales by 255, so float images are presumably expected
        in [0, 1] -- confirm for each concrete attack.
        """
        raise NotImplementedError

    def inference(self, save_path, file_name, data_loader):
        """Attack every batch of ``data_loader``, report accuracy, and save.

        The classifier is switched to eval mode (and left there). For each
        batch the adversarial examples are produced by calling the attack,
        kept on the CPU, and classified by ``target_cls`` to report the running
        accuracy on adversarial inputs. Finally all adversarial images and
        labels are concatenated along dim 0 and written with ``torch.save`` as
        the tuple ``(adversarials, labels)`` to ``save_path/file_name``.

        Arguments:
            save_path {str} -- Output directory (created if missing).
            file_name {str} -- Name of the output file inside ``save_path``.
            data_loader {iterable} -- Yields ``(imgs, labels)`` batches and
                supports ``len()`` (number of batches, used for progress).

        Raises:
            RuntimeError -- If ``data_loader`` yields no batches.
        """
        self.target_cls.eval()
        adv_list = []
        label_list = []
        correct = 0
        accumulated_num = 0.
        total_num = len(data_loader)  # number of batches, for the progress line
        for step, (imgs, labels) in enumerate(data_loader):
            adv_imgs, labels = self(imgs, labels)
            # Store the images exactly as returned (uint8 in 'int' mode).
            adv_list.append(adv_imgs.cpu())
            label_list.append(labels.cpu())
            accumulated_num += labels.size(0)
            if self.mode.lower() == 'int':
                # The classifier is fed float images rescaled back to [0, 1].
                adv_imgs = adv_imgs.float()/255.
            # Evaluation only: avoid building an autograd graph per batch.
            with torch.no_grad():
                outputs = self.target_cls(adv_imgs)
            _, predicted = torch.max(outputs, 1)
            correct += predicted.eq(labels).sum().item()
            acc = 100 * correct / accumulated_num
            print('Progress : {:.2f}% / Accuracy : {:.2f}%'.format(
                (step+1)/total_num*100, acc), end='\r')
        if not adv_list:
            raise RuntimeError("data_loader yielded no batches; nothing to save")
        adversarials = torch.cat(adv_list, 0)
        y = torch.cat(label_list, 0)
        os.makedirs(save_path, exist_ok=True)
        save_path = os.path.join(save_path, file_name)
        torch.save((adversarials, y), save_path)
        print("\n Save Images & Labels")

    def __call__(self, *args, **kwargs):
        """Run the attack with the classifier in eval mode.

        All arguments are forwarded to :meth:`forward`. In ``'int'`` mode the
        adversarial images are scaled by 255, rounded to the nearest integer,
        clamped to [0, 255] and cast to ``torch.uint8``; labels are returned
        untouched.

        Returns:
            tuple -- ``(adv_examples, labels)``.
        """
        self.target_cls.eval()
        adv_examples, labels = self.forward(*args, **kwargs)
        if self.mode.lower() == 'int':
            # Only the images are converted (the original unpacked the uint8
            # tensor into two names, clobbering the labels). Round rather
            # than truncate so float error (e.g. 50.9999) does not drop a
            # level, and clamp so out-of-range values cannot wrap in uint8.
            adv_examples = (adv_examples * 255).round().clamp(0, 255).type(torch.uint8)
        return adv_examples, labels