-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathExport.py
34 lines (26 loc) · 945 Bytes
/
Export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import resnet
import os
# Export TorchScript versions of two ResNet-18 variants from the local
# `resnet` module so they can be loaded without this Python source.
#
# NOTE(review): MODEL_DIR is resolved against the current working directory,
# not this file's location; run the script from its own directory.
MODEL_DIR = '../model'
os.makedirs(MODEL_DIR, exist_ok=True)

# Both models are exported on CPU; a single device object serves both.
device = torch.device('cpu')


def _dummy_batch():
    """Return an all-zero float32 batch shaped (4, 3, 224, 224) for tracing."""
    return torch.zeros([4, 3, 224, 224], dtype=torch.float32)


def _prepare(model):
    """Move *model* to ``device``, switch it to eval mode and return it."""
    model.to(device)
    model.eval()
    return model


def _trace_and_save(model, example_inputs, filename):
    """Trace *model* with *example_inputs* and save it as MODEL_DIR/filename.

    *example_inputs* is a tuple of positional ``forward()`` arguments, which is
    the form documented for ``torch.jit.trace``.
    """
    traced = torch.jit.trace(model, example_inputs)
    traced.save(os.path.join(MODEL_DIR, filename))


# --- Feature extractor -------------------------------------------------------
featExtractor = _prepare(resnet.resnetc18(pretrained=True))
example_input = _dummy_batch()
# Sanity-check forward pass. no_grad avoids building an autograd graph that
# would only be discarded; the printed types and shapes are unaffected.
with torch.no_grad():
    output = featExtractor(example_input)
print(type(output))
# NOTE(review): iterating here suggests forward() returns a sequence of
# feature tensors -- confirm against resnet.resnetc18.
for item in output:
    print(type(item), item.dim(), item.size())
_trace_and_save(featExtractor, (example_input,), 'resnetc18-features.pt')

print()

# --- Siamese network ---------------------------------------------------------
siameseNetwork = _prepare(resnet.resnets18(pretrained=True))
# forward() takes two image batches; keep them in a tuple so torch.jit.trace
# unpacks them into two positional arguments.
example_inputs = (_dummy_batch(), _dummy_batch())
with torch.no_grad():
    output = siameseNetwork(*example_inputs)
print(type(output), output.dim(), output.size())
_trace_and_save(siameseNetwork, example_inputs, 'resnets18-siamese.pt')