test.py
import numpy as np
import torch
from torch.utils.data import DataLoader
from typing import List, Tuple

from models.enums.Genre import Genre
from utils.constants import *


class Tester:
    """
    Runs a trained model (optionally a combined classifier/VAE model) over the
    test set and collects per-batch scores and accuracies.
    """

    def __init__(self,
                 model,
                 data_loader_test: DataLoader,
                 data_loader_sentence,
                 model_state_path='',
                 device='cpu'):
        # the saved network as an object
        self.model = model
        self.model_state_path = model_state_path
        self.data_loader_test = data_loader_test
        self.data_loader_sentence = data_loader_sentence
        self.device = device
        # evaluation mode: disables dropout and batch-norm updates
        self.model.eval()

    def test(self):
        """
        Main testing loop: iterates over both data loaders in lockstep and
        returns a log of scores, lengths and per-batch accuracies.
        """
        try:
            # load the saved trained weights; skipped for a CombinedClassifier,
            # whose states are already loaded
            if self.model_state_path:
                self.model.load_state_dict(torch.load(self.model_state_path))

            log = {'final_scores': [],
                   'combination': {'classifier_scores': [], 'vaes_scores': []},
                   'accuracies_per_batch': [],
                   'true_targets': [],
                   'length_lstm': [],
                   'length_vae': []}

            for i, items in enumerate(zip(self.data_loader_test, self.data_loader_sentence)):
                (batch, targets, lengths), (batch2, targets2, lengths2) = items
                accuracy_batch = self._batch_iteration(batch, targets, lengths, log,
                                                       (batch2, targets2, lengths2), i)
                log['accuracies_per_batch'].append(accuracy_batch)
                log['true_targets'].append(targets)
                # if i % 100 == 0:
                #     print('combined accuracy so far', np.mean(log['accuracies_per_batch']), i)
            return log
        except KeyboardInterrupt as e:
            print(f"Killed by user: {e}")
            return False
        except Exception as e:
            print(e)
            raise e

    def _batch_iteration(self,
                         batch: torch.Tensor,
                         targets: torch.Tensor,
                         lengths: torch.Tensor,
                         log,
                         sentencebatch,
                         step):
        """
        Runs a forward pass on one batch, updates the log in place and
        returns the batch accuracy.
        """
        if (step % 100) == 0:
            print(step)

        batch2, targets2, lengths2 = sentencebatch

        # move tensors to the target device and detach them from any graph
        batch = batch.to(self.device).detach()
        targets = targets.to(self.device).detach()
        lengths = lengths.to(self.device).detach()
        if batch2 is not None:
            batch2 = batch2.to(self.device).detach()
            targets2 = targets2.to(self.device).detach()
            lengths2 = lengths2.to(self.device).detach()

        output = self.model.forward(batch, targets, lengths, (batch2, targets2, lengths2), step)
        final_scores_per_class = output
        scores_for_prediction = final_scores_per_class

        # a combined model additionally returns the separate classifier and ELBO scores
        if 'Combined' in type(self.model).__name__:
            final_scores_per_class, (score_classifier, score_elbo) = output
            log['combination']['classifier_scores'].append(score_classifier.detach())
            log['combination']['vaes_scores'].append(score_elbo.detach())
            scores_for_prediction = score_classifier

        # predicted class is the argmax over the class scores
        _, classifications = scores_for_prediction.detach().max(dim=-1)
        accuracy = targets.eq(classifications).float().mean().item()

        log['final_scores'].append(final_scores_per_class.detach())
        log['length_lstm'].append(lengths.detach())
        log['length_vae'].append(lengths2.detach())

        return accuracy
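

# --- Usage sketch (illustrative only, not part of the original file) ---
# A minimal example of how this Tester could be driven, assuming a trained
# `model`, two DataLoaders that yield (batch, targets, lengths) tuples, and a
# saved state dict at 'model_state.pt'; all of these names are hypothetical.
#
#     tester = Tester(model,
#                     data_loader_test=test_loader,
#                     data_loader_sentence=sentence_loader,
#                     model_state_path='model_state.pt',
#                     device='cuda' if torch.cuda.is_available() else 'cpu')
#     log = tester.test()
#     if log:
#         print('mean accuracy:', np.mean(log['accuracies_per_batch']))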