-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtest_function.py
More file actions
96 lines (75 loc) · 3.31 KB
/
test_function.py
File metadata and controls
96 lines (75 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from tqdm import tqdm
import torch
import numpy as np
from scipy.special import softmax
import torch.nn.functional as F
from helpers import adjust_keep_rate
class AverageMeter(object):
    """Computes and stores the average and current value.

    Public attributes (read directly by callers):
        val:   the value passed to the most recent ``update`` call.
        sum:   running weighted sum of all values (``val * n`` per update).
        count: running total of the weights ``n``.
        avg:   ``sum / count``, or 0 while ``count`` is still zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new observation.

        Args:
            val: the observed value (e.g. a per-batch mean).
            n: weight of the observation, typically the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. an empty batch recorded
        # with n=0 before anything else); the average stays at its reset
        # value instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
def get_reprs(model, data_loader, use_cuda,dataset):
    """Run the model over a loader and collect its features as one array.

    The model is switched to eval mode. For every batch, ``model.vit`` is
    called under ``torch.no_grad()`` with a fixed 12-entry ``keep_rate``
    schedule; only the first of its four return values is kept (the global
    feature) and passed through ``model.mlp_g``.

    Args:
        model: object exposing ``eval()``, ``vit(data, keep_rate=...)`` and
            ``mlp_g(feat)``.
        data_loader: iterable yielding ``(data, target)`` pairs; the target
            is ignored.
        use_cuda: when truthy, each batch is moved with ``data.cuda()``.
        dataset: dataset name; ``'SUN'`` selects a different ``keep_rate``
            schedule from every other value.

    Returns:
        NumPy array with the per-batch outputs concatenated along axis 0.
    """
    model.eval()

    # The schedule depends only on the dataset, so pick it once up front.
    # The two schedules differ only in their tenth entry.
    if dataset == 'SUN':
        rates = (1, 1, 1, 0.9, 1, 1, 0.9, 1, 1, 1, 1, 1)
    else:
        rates = (1, 1, 1, 0.9, 1, 1, 0.9, 1, 1, 0.9, 1, 1)

    collected = []
    for data, _ in data_loader:
        if use_cuda:
            data = data.cuda()
        with torch.no_grad():
            # only take the global feature
            feat, _, _, _ = model.vit(data, keep_rate=rates)
        # global feature
        feat = model.mlp_g(feat)
        collected.append(feat.cpu().data.numpy())

    return np.concatenate(collected, 0)
def compute_accuracy(pred_labels, true_labels, labels):
    """Return the mean of the per-class accuracies.

    For each entry of ``labels``, the samples whose true label equals that
    entry are selected and the fraction predicted correctly is recorded.
    The result is the unweighted mean of those fractions.

    Args:
        pred_labels: array of predicted labels, one per sample.
        true_labels: array of ground-truth labels, same shape.
        labels: 1-D array of the class labels to average over.

    Returns:
        Mean per-class accuracy as a NumPy float.
    """
    per_class = np.zeros(labels.shape[0])
    for slot, cls in enumerate(labels):
        members = (true_labels == cls)
        hits = np.sum(pred_labels[members] == true_labels[members])
        per_class[slot] = hits / np.sum(members)
    return np.mean(per_class)
def test(args,model, test_seen_loader, test_seen_labels, test_unseen_loader, test_unseen_labels, attrs_mat, use_cuda, gamma):
    """Evaluate the model and print GZSL or ZSL accuracy.

    Features from ``get_reprs`` are scored against the rows of ``attrs_mat``
    by a matrix product, and the arg-max row is taken as the predicted class.

    Args:
        args: namespace providing ``dataset`` (forwarded to ``get_reprs``)
            and ``gzsl`` (selects the generalized evaluation).
        model: network passed through to ``get_reprs``.
        test_seen_loader: loader over seen-class test samples (GZSL only).
        test_seen_labels: integer labels for the seen-class samples.
        test_unseen_loader: loader over unseen-class test samples.
        test_unseen_labels: integer labels for the unseen-class samples.
        attrs_mat: array with one attribute row per class, indexed by label.
        use_cuda: forwarded to ``get_reprs``.
        gamma: calibrated-stacking offset subtracted from the scores of the
            seen classes in GZSL mode.

    Returns:
        None. The metrics are printed.
    """
    # Representation
    with torch.no_grad():
        # Seen-class features are only consumed by the GZSL branch, so the
        # extra forward pass over the seen loader is skipped for plain ZSL.
        seen_reprs = (
            get_reprs(model, test_seen_loader, use_cuda, dataset=args.dataset)
            if args.gzsl else None
        )
        unseen_reprs = get_reprs(model, test_unseen_loader, use_cuda, dataset=args.dataset)

    # Labels
    uniq_test_unseen_labels = np.unique(test_unseen_labels)

    # GZSL
    if args.gzsl:
        uniq_test_seen_labels = np.unique(test_seen_labels)

        # Calibrated stacking: penalise every seen class by gamma.
        Cs_mat = np.zeros(attrs_mat.shape[0])
        Cs_mat[uniq_test_seen_labels] = gamma

        # seen classes
        gzsl_seen_sim = softmax(seen_reprs @ attrs_mat.T, axis=1) - Cs_mat
        gzsl_seen_predict_labels = np.argmax(gzsl_seen_sim, axis=1)
        gzsl_seen_acc = compute_accuracy(gzsl_seen_predict_labels, test_seen_labels, uniq_test_seen_labels)

        # unseen classes
        gzsl_unseen_sim = softmax(unseen_reprs @ attrs_mat.T, axis=1) - Cs_mat
        gzsl_unseen_predict_labels = np.argmax(gzsl_unseen_sim, axis=1)
        gzsl_unseen_acc = compute_accuracy(gzsl_unseen_predict_labels, test_unseen_labels, uniq_test_unseen_labels)

        # Harmonic mean of the two accuracies. When both are exactly zero
        # the formula is 0/0, so report 0 instead of printing nan.
        acc_sum = gzsl_unseen_acc + gzsl_seen_acc
        if acc_sum == 0:
            H = 0.0
        else:
            H = 2 * gzsl_unseen_acc * gzsl_seen_acc / acc_sum

        print('GZSL Seen=%.2f Unseen=%.2f H=%.2f' %(gzsl_seen_acc * 100,gzsl_unseen_acc * 100,H * 100))
    # ZSL
    else:
        # Score only against the unseen classes, then map the arg-max column
        # index back to the actual class label.
        zsl_unseen_sim = unseen_reprs @ attrs_mat[uniq_test_unseen_labels].T
        predict_labels = np.argmax(zsl_unseen_sim, axis=1)
        zsl_unseen_predict_labels = uniq_test_unseen_labels[predict_labels]
        zsl_unseen_acc = compute_accuracy(zsl_unseen_predict_labels, test_unseen_labels, uniq_test_unseen_labels)
        print('CZSL Acc=%.2f' %(zsl_unseen_acc * 100))