-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathpytorch_to_onnx.py
39 lines (28 loc) · 1010 Bytes
/
pytorch_to_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os

import numpy as np
import onnxruntime as onnxrt
import torch
import torchvision.models as models
# Load ResNet-50 with ImageNet-pretrained weights.
# `pretrained=True` is deprecated since torchvision 0.13; its documented exact
# equivalent is the IMAGENET1K_V1 weights enum. Fall back on older releases
# that predate the weights API.
if hasattr(models, "ResNet50_Weights"):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
# Switch to inference behaviour (batch-norm uses its running statistics)
# before tracing, so the exported graph matches eval-time predictions.
model.eval()

# Random input used only to trace the graph. Without `dynamic_axes`, the
# exported input keeps this static shape: a batch of one 3x224x224 image.
dummy_input = torch.randn(1, 3, 224, 224)
input_names = ["actual_input"]
output_names = ["output"]

# torch.onnx.export does not create missing parent directories.
os.makedirs("outputs", exist_ok=True)

# Convert to ONNX, embedding the trained parameters in the file
# (export_params=True).
torch.onnx.export(
    model,
    dummy_input,
    "outputs/resnet50.onnx",
    verbose=False,
    input_names=input_names,
    output_names=output_names,
    export_params=True,
)
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Return *tensor* as a NumPy array on the CPU.

    ``Tensor.numpy()`` refuses tensors that require grad, so the tensor is
    detached first. ``detach()`` is harmless for tensors that do not require
    grad, which makes the original ``requires_grad`` branch unnecessary.
    For a tensor already on the CPU, the returned array may share memory
    with the input.
    """
    return tensor.detach().cpu().numpy()
# Smoke-test the exported model with ONNX Runtime and check that it agrees
# with the PyTorch model on the same input.
# Pin the CPU provider explicitly: onnxruntime >= 1.9 raises if `providers` is
# omitted on builds shipping more than one execution provider (e.g. the GPU
# package).
onnx_session = onnxrt.InferenceSession(
    "outputs/resnet50.onnx", providers=["CPUExecutionProvider"]
)
onnx_inputs = {onnx_session.get_inputs()[0].name: to_numpy(dummy_input)}
onnx_output = onnx_session.run(None, onnx_inputs)

# Reference output from the eager PyTorch model. The tolerances follow the
# PyTorch "export to ONNX" tutorial; a mismatch raises AssertionError.
with torch.no_grad():
    torch_output = model(dummy_input)
np.testing.assert_allclose(
    to_numpy(torch_output), onnx_output[0], rtol=1e-03, atol=1e-05
)

# Index of the highest score in the first output. dummy_input is random
# noise, so this is only a sanity print, not a meaningful prediction.
img_label_idx = np.argmax(onnx_output[0])
print(img_label_idx)