-
Notifications
You must be signed in to change notification settings - Fork 121
/
Copy pathmetric3d_onnx_export.py
119 lines (98 loc) · 3.63 KB
/
metric3d_onnx_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Export the torch hub model to ONNX format. Normalization is done in the model.
"""
import torch
class Metric3DExportModel(torch.nn.Module):
    """
    Wrapper for exporting a Metric3D meta-architecture to ONNX format.

    ImageNet-style mean/std normalization is folded into the graph so the
    exported model consumes raw 0-255 RGB images. Add custom preprocessing
    and postprocessing here.
    """

    def __init__(self, meta_arch):
        """
        Args:
            meta_arch: Metric3D model exposing ``inference({"input": x})``
                and returning ``(pred_depth, confidence, output_dict)``.
        """
        super().__init__()
        self.meta_arch = meta_arch
        # Fix: create the normalization buffers on CPU instead of calling
        # .cuda() at construction time. Registered buffers move with the
        # module on .cuda()/.to(device) (main() already calls .cuda() on the
        # wrapper), so the hard-coded .cuda() only broke instantiation on
        # CPU-only machines without changing GPU behavior.
        self.register_buffer(
            "rgb_mean", torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "rgb_std", torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
        )
        # Expected (H, W) input resolution for ViT backbones — informational.
        self.input_size = (616, 1064)

    def normalize_image(self, image):
        """Apply per-channel ImageNet mean/std normalization (input in 0-255)."""
        image = image - self.rgb_mean
        image = image / self.rgb_std
        return image

    def forward(self, image):
        """Return the predicted depth map for a raw RGB batch (N, 3, H, W)."""
        image = self.normalize_image(image)
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.meta_arch.inference(
                {"input": image}
            )
        return pred_depth
def update_vit_sampling(model):
    """
    For ViT models running on some TensorRT version, we need to change the interpolation method from bicubic to bilinear.

    Patches ``model.depth_model.encoder.interpolate_pos_encoding`` on this
    instance only and returns the same (mutated) model object.
    """
    import torch.nn as nn
    import math

    def interpolate_pos_encoding_bilinear(self, x, w, h):
        # DINOv2-style positional-encoding resampling, identical to the
        # upstream implementation except mode="bilinear" replaces "bicubic".
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1  # patch tokens, excluding the class token
        N = self.pos_embed.shape[1] - 1  # trained pos-embed grid size (flattened)
        if npatch == N and w == h:
            # Input matches the trained resolution: no resampling needed.
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        # NOTE(review): assumes the trained pos-embed grid is square
        # (sqrt_N x sqrt_N) — consistent with the reshape below.
        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(
                0, 3, 1, 2
            ),
            scale_factor=(sx, sy),
            mode="bilinear",  # Change from bicubic to bilinear
            antialias=self.interpolate_antialias,
        )
        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            previous_dtype
        )

    # Bind the replacement as a bound method on the encoder *instance*, so
    # only this model object is patched — the encoder class is untouched.
    model.depth_model.encoder.interpolate_pos_encoding = (
        interpolate_pos_encoding_bilinear.__get__(
            model.depth_model.encoder, model.depth_model.encoder.__class__
        )
    )
    return model
def main(model_name="metric3d_vit_small", modify_upsample=False):
    """
    Download a Metric3D checkpoint from torch hub and export it to ONNX.

    Args:
        model_name: torch-hub entry-point name (e.g. "metric3d_vit_small").
        modify_upsample: if True, patch the ViT positional-encoding
            interpolation from bicubic to bilinear for TensorRT compatibility.
    """
    hub_model = torch.hub.load("yvanyin/metric3d", model_name, pretrain=True)
    hub_model.cuda().eval()
    if modify_upsample:
        hub_model = update_vit_sampling(hub_model)

    # ViT backbones take 616x1064 inputs; the ConvNet variants take 544x1216.
    height, width = (616, 1064) if "vit" in model_name else (544, 1216)
    sample = torch.randn([1, 3, height, width]).cuda()

    wrapper = Metric3DExportModel(hub_model)
    wrapper.eval()
    wrapper.cuda()

    torch.onnx.export(
        wrapper,
        (sample,),
        f"{model_name}.onnx",
        input_names=["image"],
        output_names=["pred_depth"],
        opset_version=11,
    )
if __name__ == "__main__":
    # CLI entry point: expose main()'s keyword arguments as command-line flags.
    import fire

    fire.Fire(main)