Skip to content

Commit 8438390

Browse files
committed
Add proper autopad with conv layer tests.
1 parent 50dd338 commit 8438390

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed

tests/onnx2pytorch/operations/test_autopad.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import numpy as np
2+
import onnxruntime as ort
23
import torch
34
import pytest
5+
from onnx import helper, TensorProto
46
from onnx2pytorch.operations.autopad import AutoPad
7+
from onnx2pytorch.convert import ConvertModel
58

69

710
def test_autopad_same_upper_2d():
@@ -104,3 +107,98 @@ def test_autopad_invalid_mode():
104107
"""Test that invalid mode raises error."""
105108
with pytest.raises(ValueError, match="Unsupported auto_pad mode"):
106109
AutoPad(kernel_size=3, stride=1, dilation=1, mode="INVALID")
110+
111+
112+
@pytest.mark.parametrize(
113+
"auto_pad,kernel_shape,strides,dilations,input_shape",
114+
[
115+
# SAME_UPPER with stride=1
116+
("SAME_UPPER", [3, 3], [1, 1], [1, 1], [1, 1, 5, 5]),
117+
# SAME_LOWER with stride=1
118+
("SAME_LOWER", [3, 3], [1, 1], [1, 1], [1, 1, 5, 5]),
119+
# VALID (no padding) - supports dilation
120+
("VALID", [3, 3], [1, 1], [1, 1], [1, 1, 5, 5]),
121+
("VALID", [3, 3], [1, 1], [2, 2], [1, 1, 7, 7]),
122+
# SAME_UPPER with stride=2
123+
("SAME_UPPER", [3, 3], [2, 2], [1, 1], [1, 1, 6, 6]),
124+
# SAME_LOWER with stride=2
125+
("SAME_LOWER", [3, 3], [2, 2], [1, 1], [1, 1, 6, 6]),
126+
# SAME_UPPER with asymmetric kernel
127+
("SAME_UPPER", [3, 5], [1, 2], [1, 1], [1, 1, 10, 10]),
128+
# SAME_LOWER with asymmetric kernel
129+
("SAME_LOWER", [3, 5], [1, 2], [1, 1], [1, 1, 10, 10]),
130+
# SAME_UPPER with larger kernel
131+
("SAME_UPPER", [5, 5], [1, 1], [1, 1], [1, 1, 8, 8]),
132+
# Note: onnxruntime does not support dilation with SAME_UPPER/SAME_LOWER
133+
],
134+
)
135+
def test_autopad_with_conv_onnxruntime(
136+
auto_pad, kernel_shape, strides, dilations, input_shape
137+
):
138+
"""Test AutoPad implementation against onnxruntime using Conv operator."""
139+
np.random.seed(42)
140+
torch.manual_seed(42)
141+
142+
# Create input
143+
X = np.random.randn(*input_shape).astype(np.float32)
144+
in_channels = input_shape[1]
145+
out_channels = 2
146+
147+
# Create random weights for Conv
148+
# W shape: [out_channels, in_channels, kernel_h, kernel_w]
149+
W = np.random.randn(out_channels, in_channels, *kernel_shape).astype(np.float32)
150+
151+
# Create ONNX graph with Conv node using auto_pad
152+
input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)
153+
output_tensor = helper.make_tensor_value_info(
154+
"Y", TensorProto.FLOAT, None
155+
) # Output shape is dynamic
156+
157+
W_initializer = helper.make_tensor(
158+
"W", TensorProto.FLOAT, W.shape, W.flatten().tolist()
159+
)
160+
161+
conv_node = helper.make_node(
162+
"Conv",
163+
inputs=["X", "W"],
164+
outputs=["Y"],
165+
kernel_shape=kernel_shape,
166+
strides=strides,
167+
dilations=dilations,
168+
auto_pad=auto_pad,
169+
)
170+
171+
graph = helper.make_graph(
172+
[conv_node],
173+
"conv_autopad_test",
174+
[input_tensor],
175+
[output_tensor],
176+
[W_initializer],
177+
)
178+
179+
model = helper.make_model(
180+
graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8
181+
)
182+
183+
# Run with onnxruntime to get expected output (onnxruntime will validate the model)
184+
ort_session = ort.InferenceSession(model.SerializeToString())
185+
ort_inputs = {"X": X}
186+
ort_outputs = ort_session.run(None, ort_inputs)
187+
expected_Y = ort_outputs[0]
188+
189+
# Convert to PyTorch and run
190+
o2p_model = ConvertModel(model, experimental=True)
191+
X_torch = torch.from_numpy(X)
192+
193+
with torch.no_grad():
194+
o2p_output = o2p_model(X_torch)
195+
196+
# Compare outputs
197+
torch.testing.assert_close(
198+
o2p_output,
199+
torch.from_numpy(expected_Y),
200+
rtol=1e-5,
201+
atol=1e-5,
202+
msg=f"AutoPad mismatch for {auto_pad} with kernel={kernel_shape}, "
203+
f"stride={strides}, dilation={dilations}",
204+
)

0 commit comments

Comments
 (0)