forked from sarathknv/adversarial-examples-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_function.py
30 lines (22 loc) · 914 Bytes
/
test_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torch.autograd import Variable
def test(G, f, target, is_targeted, thres, test_loader, epoch, epochs, device, verbose=True):
    """Measure the attack success rate of perturbation generator ``G`` on ``f``.

    For every ``(img, label)`` batch from ``test_loader``:
    - the generator output ``G(img)`` is clamped to ``[-thres, thres]``;
    - it is added to the image;
    - the sum is clamped to ``[0, 1]``;
    - the result is classified by ``f`` (prediction = argmax over dim 1).

    Args:
        G: generator module. It is switched to eval mode and left there on
            return.
        f: classifier under attack. Its mode is NOT changed here; the caller
            is responsible for putting it in eval mode if needed.
        target: class index counted as a success when ``is_targeted`` is true.
        is_targeted: if true, a sample succeeds when the prediction equals
            ``target``. Otherwise it succeeds when the prediction differs from
            its true label.
        thres: maximum absolute value allowed for each perturbation element.
        test_loader: iterable of ``(img, label)`` batches that supports
            ``len``. ``label`` is assumed to be an integer class-index tensor
            (TODO confirm with the data pipeline).
        epoch: zero-based index of the current epoch (used only for display).
        epochs: total number of epochs (used only for display).
        device: device that images and labels are moved to.
        verbose: if true, print an in-place (carriage-return) progress line
            per batch.

    Returns:
        Fraction of all evaluated samples counted as successful, as a float.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    G.eval()
    n_samples = 0
    n_success = 0
    # Pure evaluation: disable autograd so no graph is built or retained
    # through G and f for every batch.
    with torch.no_grad():
        for i, (img, label) in enumerate(test_loader):
            img_real = img.to(device)
            pert = torch.clamp(G(img_real), -thres, thres)
            # Keep the adversarial image inside the valid [0, 1] pixel range.
            img_fake = torch.clamp(pert + img_real, 0, 1)
            pred = f(img_fake).argmax(dim=1)
            if is_targeted:
                y_target = torch.full_like(label, target).to(device)
                n_success += (pred == y_target).sum().item()
            else:
                y_true = label.to(device)
                n_success += (pred != y_true).sum().item()
            n_samples += img.size(0)
            if verbose:
                # Both counters are 1-based for display.
                print('Test [%d/%d]: [%d/%d]' % (epoch + 1, epochs, i + 1, len(test_loader)), end="\r")
    if n_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return n_success / n_samples


# The file appears to be named test_function.py, which matches pytest's default
# test-file pattern; stop pytest from collecting this helper as a test case
# (it would fail looking for fixtures named G, f, target, ...).
test.__test__ = False