@@ -1,7 +1,10 @@
 import numpy as np
+import onnxruntime as ort
 import torch
 import pytest
+from onnx import helper, TensorProto
 from onnx2pytorch.operations.autopad import AutoPad
+from onnx2pytorch.convert import ConvertModel
 
 
 def test_autopad_same_upper_2d():
@@ -104,3 +107,98 @@ def test_autopad_invalid_mode(): |
     """Test that invalid mode raises error."""
     with pytest.raises(ValueError, match="Unsupported auto_pad mode"):
         AutoPad(kernel_size=3, stride=1, dilation=1, mode="INVALID")
+
+
+@pytest.mark.parametrize(
+    "auto_pad,kernel_shape,strides,dilations,input_shape",
+    [
+        # SAME_UPPER with stride=1
+        ("SAME_UPPER", [3, 3], [1, 1], [1, 1], [1, 1, 5, 5]),
+        # SAME_LOWER with stride=1
+        ("SAME_LOWER", [3, 3], [1, 1], [1, 1], [1, 1, 5, 5]),
+        # VALID (no padding) - supports dilation
+        ("VALID", [3, 3], [1, 1], [1, 1], [1, 1, 5, 5]),
+        ("VALID", [3, 3], [1, 1], [2, 2], [1, 1, 7, 7]),
+        # SAME_UPPER with stride=2
+        ("SAME_UPPER", [3, 3], [2, 2], [1, 1], [1, 1, 6, 6]),
+        # SAME_LOWER with stride=2
+        ("SAME_LOWER", [3, 3], [2, 2], [1, 1], [1, 1, 6, 6]),
+        # SAME_UPPER with asymmetric kernel
+        ("SAME_UPPER", [3, 5], [1, 2], [1, 1], [1, 1, 10, 10]),
+        # SAME_LOWER with asymmetric kernel
+        ("SAME_LOWER", [3, 5], [1, 2], [1, 1], [1, 1, 10, 10]),
+        # SAME_UPPER with larger kernel
+        ("SAME_UPPER", [5, 5], [1, 1], [1, 1], [1, 1, 8, 8]),
+        # Note: onnxruntime does not support dilation with SAME_UPPER/SAME_LOWER
+    ],
+)
+def test_autopad_with_conv_onnxruntime(
+    auto_pad, kernel_shape, strides, dilations, input_shape
+):
+    """Test AutoPad implementation against onnxruntime using the Conv operator."""
+    np.random.seed(42)
+    torch.manual_seed(42)
+
+    # Create input
+    X = np.random.randn(*input_shape).astype(np.float32)
+    in_channels = input_shape[1]
+    out_channels = 2
+
+    # Create random weights for Conv
+    # W shape: [out_channels, in_channels, kernel_h, kernel_w]
+    W = np.random.randn(out_channels, in_channels, *kernel_shape).astype(np.float32)
+
+    # Create ONNX graph with a Conv node using auto_pad
+    input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)
+    output_tensor = helper.make_tensor_value_info(
+        "Y", TensorProto.FLOAT, None
+    )  # Output shape is dynamic
+
+    W_initializer = helper.make_tensor(
+        "W", TensorProto.FLOAT, W.shape, W.flatten().tolist()
+    )
+
+    conv_node = helper.make_node(
+        "Conv",
+        inputs=["X", "W"],
+        outputs=["Y"],
+        kernel_shape=kernel_shape,
+        strides=strides,
+        dilations=dilations,
+        auto_pad=auto_pad,
+    )
+
+    graph = helper.make_graph(
+        [conv_node],
+        "conv_autopad_test",
+        [input_tensor],
+        [output_tensor],
+        [W_initializer],
+    )
+
+    model = helper.make_model(
+        graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8
+    )
+
+    # Run with onnxruntime to get the expected output (onnxruntime also validates the model)
+    ort_session = ort.InferenceSession(model.SerializeToString())
+    ort_inputs = {"X": X}
+    ort_outputs = ort_session.run(None, ort_inputs)
+    expected_Y = ort_outputs[0]
+
+    # Convert to PyTorch and run
+    o2p_model = ConvertModel(model, experimental=True)
+    X_torch = torch.from_numpy(X)
+
+    with torch.no_grad():
+        o2p_output = o2p_model(X_torch)
+
+    # Compare outputs
+    torch.testing.assert_close(
+        o2p_output,
+        torch.from_numpy(expected_Y),
+        rtol=1e-5,
+        atol=1e-5,
+        msg=f"AutoPad mismatch for {auto_pad} with kernel={kernel_shape}, "
+        f"stride={strides}, dilation={dilations}",
+    )