forked from shakes76/PatternAnalysis-2023
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
53 lines (40 loc) · 1.51 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from dataset import get_loaders
from module import create_model
"""
Loads the best model and evaluates it on the test set.
"""
def test(model, data_loader, device):
    """Evaluate ``model`` on every batch of ``data_loader`` and return its accuracy.

    The model is switched to eval mode and run under ``torch.no_grad()``. For
    each batch the predicted class is the ``argmax`` over dimension 1 of the
    model output, which is compared element-wise against the batch labels.

    Args:
        model: a ``torch.nn.Module`` whose output carries class scores along
            dim 1 — presumably ``(batch, n_classes)``; confirm against the
            model definition.
        data_loader: iterable yielding ``(inputs, labels)`` tensor pairs.
        device: device the inputs and labels are moved to before the forward pass.

    Returns:
        Fraction of evaluated samples whose predicted class equals the label,
        as a plain Python ``float``.

    Raises:
        ValueError: if ``data_loader`` yields no samples (accuracy is undefined).
    """
    model.eval()
    # Accumulate on the target device so there is no host sync per batch.
    correct_predictions = torch.zeros((), dtype=torch.long, device=device)
    # Count the samples actually evaluated instead of len(data_loader.dataset):
    # that length overstates the denominator with drop_last=True or a subset
    # sampler, and does not exist for iterable-style datasets.
    total_samples = 0
    with torch.no_grad():
        for batch, labels in data_loader:
            batch, labels = batch.to(device), labels.to(device)
            prediction = model(batch)
            predicted_classes = torch.argmax(prediction, dim=1)
            correct_predictions += (predicted_classes == labels).sum()
            total_samples += labels.shape[0]
    if total_samples == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    # .item() converts the 0-d tensor into a Python int so a plain float is returned.
    return correct_predictions.item() / total_samples
# The name starts with "test", so pytest would otherwise collect this helper as
# a test case (and fail looking for fixtures) whenever it is imported into a
# test module.
test.__test__ = False
def main(checkpoint_path="saved/best_model.pth"):
    """Rebuild the model, load the saved best checkpoint and print test accuracy.

    Args:
        checkpoint_path: path of the checkpoint to evaluate. It must be a dict
            holding the model weights under the ``'model_state_dict'`` key.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Only the third (test) loader is used here.
    _, _, test_data = get_loaders()
    # Only the network is needed for evaluation; the optimizer, loss and
    # scheduler returned by create_model are discarded.
    # NOTE(review): these hyper-parameters must match the ones the checkpoint
    # was trained with, otherwise load_state_dict fails on missing/mismatched
    # parameters — consider storing the config inside the checkpoint.
    model, _, _, _ = create_model(
        input_shape=(256, 256),
        latent_dim=32,
        embed_dim=32,
        attention_mlp_dim=32,
        transformer_mlp_dim=32,
        transformer_heads=4,
        dropout=0.1,
        transformer_layers=4,
        n_blocks=4,
        n_classes=2,
        lr=0.003,
    )
    model = model.to(device)
    # map_location remaps tensors that were saved on a GPU, so a CUDA-trained
    # checkpoint can still be evaluated on a CPU-only machine.
    best_model_state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(best_model_state['model_state_dict'])
    test_acc = test(model, test_data, device)
    # float() also handles a 0-d tensor result, so a plain number is printed
    # rather than a tensor repr such as "tensor(0.9, device='cuda:0')".
    print(f"Test accuracy: {float(test_acc)}")


if __name__ == "__main__":
    main()