-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy patheval.py
88 lines (62 loc) · 2.23 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
import argparse
import os
import subprocess
import tensorrt as trt
import torch
import torchvision
import torchvision.transforms as transforms
# Command-line interface: the serialized engine path is required; batch size
# and dataset location are optional.
parser = argparse.ArgumentParser(
    description='Evaluate a serialized TensorRT engine on the CIFAR-10 test split.'
)
# Positional arguments are always required, so no default is given here.
parser.add_argument('engine', type=str, help='Path to the optimized TensorRT engine')
parser.add_argument(
    '--batch_size',
    type=int,
    default=1,
    help='Number of images per inference call (default: 1)',
)
parser.add_argument(
    '--dataset_path',
    type=str,
    default='data/cifar10',
    help='Directory holding (or receiving) the CIFAR-10 data (default: data/cifar10)',
)
args = parser.parse_args()
# The batch size sizes the device buffers and the DataLoader; reject
# non-positive values up front with a usage message.
if args.batch_size < 1:
    parser.error('--batch_size must be a positive integer')
# Per-channel normalization constants applied after ToTensor().
# NOTE(review): presumably these match the training-time preprocessing of the
# model inside the engine — confirm against the training script.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)]
)

# CIFAR-10 test split; download=True lets torchvision fetch it into
# --dataset_path when it is not already present.
test_dataset = torchvision.datasets.CIFAR10(
    args.dataset_path, train=False, transform=transform, download=True
)

# Deterministic ordering (no shuffling) for the evaluation pass.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=args.batch_size, shuffle=False
)
# Load the prebuilt, serialized TensorRT engine and create an execution
# context for running inference with it.
logger = trt.Logger()
runtime = trt.Runtime(logger)
with open(args.engine, 'rb') as f:
    engine = runtime.deserialize_cuda_engine(f.read())
# deserialize_cuda_engine reports errors through the logger and returns None
# (e.g. corrupt file, engine built with an incompatible TensorRT version);
# fail here with a clear message rather than an AttributeError on None below.
if engine is None:
    raise RuntimeError(f'Failed to deserialize TensorRT engine from {args.engine!r}')
context = engine.create_execution_context()
# create_execution_context can likewise return None on failure.
if context is None:
    raise RuntimeError('Failed to create a TensorRT execution context')
# Resolve binding slots by tensor name; the engine must expose tensors named
# 'input' and 'output'.
input_binding_idx = engine.get_binding_index('input')
output_binding_idx = engine.get_binding_index('output')
# get_binding_index returns -1 for an unknown name; without this check the
# device pointer would silently be written to bindings[-1].
if input_binding_idx < 0 or output_binding_idx < 0:
    raise RuntimeError("Engine must have bindings named 'input' and 'output'")

# Full-batch shapes: 3x32x32 images in, one score per CIFAR-10 class (10) out.
input_shape = (args.batch_size, 3, 32, 32)
output_shape = (args.batch_size, 10)

# Pin the input shape for this context; set_binding_shape returns False when
# the shape is outside what the engine accepts.
if not context.set_binding_shape(input_binding_idx, input_shape):
    raise RuntimeError(f'Engine does not accept input shape {input_shape}')

# Preallocated float32 device buffers, reused for every batch.
input_buffer = torch.zeros(input_shape, dtype=torch.float32, device=torch.device('cuda'))
output_buffer = torch.zeros(output_shape, dtype=torch.float32, device=torch.device('cuda'))

# One raw device pointer per engine binding, indexed by binding index.
bindings = [None] * engine.num_bindings
bindings[input_binding_idx] = input_buffer.data_ptr()
bindings[output_binding_idx] = output_buffer.data_ptr()
# Run every test batch through the engine and count correct top-1 predictions;
# the count is then divided by the dataset size to give accuracy in [0, 1].
test_accuracy = 0
# The current stream does not change inside the loop, so look it up once.
stream = torch.cuda.current_stream()
# run through test dataset
for image, label in test_loader:
    # The last batch may be smaller than --batch_size: only the leading rows
    # of the buffers are filled and read; the remaining rows are ignored.
    actual_batch_size = int(image.shape[0])
    input_buffer[0:actual_batch_size].copy_(image)
    # Enqueue on PyTorch's current stream (the stream the copy above was
    # issued on); execute_async_v2 returns False if enqueueing failed.
    if not context.execute_async_v2(bindings, stream.cuda_stream):
        raise RuntimeError('TensorRT failed to enqueue inference')
    stream.synchronize()
    output = output_buffer[0:actual_batch_size]
    label = label.cuda()
    test_accuracy += int(torch.sum(output.argmax(dim=-1) == label))
test_accuracy /= len(test_dataset)
print(f'TEST ACCURACY: {test_accuracy}')