Commit 04c6479
[ONNX] Add onnx parser for LpPool operator (#3449)
Similar to #3435. Solves nod-ai/SHARK-ModelDev#728.
1 parent: 6f94c7b

File tree

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

2 files changed: +172 −0 lines changed
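The patch lowers onnx.LpPool to Torch average-pooling ops using the identity below (a sketch in our notation, not text from the commit; N is the number of elements in the kernel window, k_j the kernel sizes):

\[
\operatorname{LpPool}_p(x)\big|_w
  = \Big(\sum_{i \in w} |x_i|^p\Big)^{1/p}
  = \Big(N \cdot \operatorname{AvgPool}\!\big(|x|^p\big)\big|_w\Big)^{1/p},
\qquad N = \prod_j k_j .
\]

The pattern below implements exactly this shape: aten.abs, aten.pow by p, an avg_pool whose division is undone (by multiplying with numElements in 1D, or by passing divisor_override = 1 in 2D/3D), and a final aten.pow by 1/p.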

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 116 additions & 0 deletions
@@ -1687,6 +1687,122 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         return success();
       });
 
+  patterns.onOp(
+      "LpPool", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        std::string autoPad;
+        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
+          return failure();
+        if (autoPad != "NOTSET") {
+          // TODO: Add support for `auto_pad` != "NOTSET"
+          return rewriter.notifyMatchFailure(
+              binder.op, "unsupported conversion: auto_pad != NOTSET");
+        }
+
+        Torch::ValueTensorType resultType;
+        Value operand;
+        int64_t ceilMode, p;
+        if (binder.tensorOperand(operand) ||
+            binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
+            binder.s64IntegerAttr(p, "p", 2) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        // Determine the rank of the input tensor.
+        std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
+        if (!maybeRank)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unimplemented: unranked tensor");
+        unsigned rank = *maybeRank;
+        // Only 1D, 2D, and 3D LpPool are supported.
+        if (rank > 5 || rank < 3) {
+          return failure();
+        }
+
+        SmallVector<int64_t> kernel, padding, strides, dilations;
+        SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
+        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}) ||
+            binder.s64IntegerArrayAttr(padding, "pads", defaultPadding) ||
+            binder.s64IntegerArrayAttr(
+                strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1)) ||
+            binder.s64IntegerArrayAttr(dilations, "dilations", {})) {
+          return failure();
+        }
+        if (kernel.size() != rank - 2) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "kernel list size does not match the number of axes");
+        }
+        if (padding.size() != 2 * (rank - 2)) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "padding list size does not match twice the number of axes");
+        }
+        if (strides.size() != rank - 2) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "strides list size does not match the number of axes");
+        }
+        if (dilations.size() > 0) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "dilation is not supported by torch.aten.avgpool op "
+                         "and therefore it is not supported for LpPool.");
+        }
+
+        SmallVector<Value> cstKernel, cstPadding, cstStrides;
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Value numElements = cstOne;
+        for (int64_t i : kernel) {
+          cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+          numElements = rewriter.create<Torch::AtenMulOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              cstKernel.back(), numElements);
+        }
+        Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstKernel);
+        Value paddingList = createConstantIntList(binder, rewriter, padding);
+        Value stridesList = createConstantIntList(binder, rewriter, strides);
+        Value cstCeilMode =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
+        // ONNX LpPool has no count_include_pad attribute, but we set it to
+        // true so that, in the 1D case, multiplying by numElements correctly
+        // undoes the division. In the 2D/3D cases the division is avoided
+        // via divisor_override.
+        Value cstCountIncludePad =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        Value pv = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), p));
+        auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
+        Value abs = rewriter.create<Torch::AtenAbsOp>(binder.getLoc(),
+                                                      inputTensorType, operand);
+        Value pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
+            binder.getLoc(), inputTensorType, abs, pv);
+        Value avgPool;
+        if (rank == 3) {
+          avgPool = rewriter.create<Torch::AtenAvgPool1dOp>(
+              binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
+              paddingList, cstCeilMode, cstCountIncludePad);
+          avgPool = rewriter.create<Torch::AtenMulScalarOp>(
+              binder.getLoc(), resultType, avgPool, numElements);
+        } else if (rank == 4) {
+          avgPool = rewriter.create<Torch::AtenAvgPool2dOp>(
+              binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
+              paddingList, cstCeilMode, cstCountIncludePad,
+              /*divisor_override=*/cstOne);
+        } else { // rank == 5
+          avgPool = rewriter.create<Torch::AtenAvgPool3dOp>(
+              binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
+              paddingList, cstCeilMode, cstCountIncludePad,
+              /*divisor_override=*/cstOne);
+        }
+        Value invP = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(double{1.0 / p}));
+        rewriter.replaceOpWithNewOp<Torch::AtenPowTensorScalarOp>(
+            binder.op, resultType, avgPool, invP);
+        return success();
+      });
+
   patterns.onOp(
       "LayerNormalization", 17,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
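A standalone numerical sanity check of the identity the pattern relies on, for a 1D window (our sketch, not part of the patch; plain C++ with no Torch/MLIR dependencies, all names illustrative):

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // 1D input, kernel_shape = 2, stride 1, no padding, p = 2.
  const std::vector<double> x = {1.0, -2.0, 3.0, -4.0, 5.0};
  const int k = 2;
  const int p = 2;
  for (std::size_t start = 0; start + k <= x.size(); ++start) {
    // Direct ONNX definition: (sum over the window of |x|^p)^(1/p).
    double sum = 0.0;
    for (int i = 0; i < k; ++i)
      sum += std::pow(std::abs(x[start + i]), p);
    const double direct = std::pow(sum, 1.0 / p);
    // Lowered form: avg_pool of |x|^p (a division by k), the division
    // undone by multiplying with numElements = k, then the 1/p power.
    const double avg = sum / k; // what aten.avg_pool1d computes per window
    const double lowered = std::pow(avg * k, 1.0 / p);
    assert(std::fabs(direct - lowered) < 1e-12);
    std::printf("window %zu: %f\n", start, direct);
  }
  return 0;
}

The same cancellation happens in the 2D/3D branches, except that avg_pool2d/3d accept divisor_override, so the division is skipped outright instead of being multiplied back.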

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 56 additions & 0 deletions
@@ -274,6 +274,62 @@ func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.
 
 // -----
 
+// CHECK-LABEL: func.func @test_lppool_2d
+func.func @test_lppool_2d(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64} {
+  // CHECK: %[[I1:.*]] = torch.constant.int 1
+  // CHECK: %[[I2:.*]] = torch.constant.int 2
+  // CHECK: %[[NE:.*]] = torch.aten.mul %[[I2]], %[[I1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[I2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[NE1:.*]] = torch.aten.mul %[[I2_1]], %[[NE]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[K:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_3:.*]] = torch.constant.int 0
+  // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_1]], %[[I0_2]], %[[I0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[STR:.*]] = torch.prim.ListConstruct %[[I1_1]], %[[I1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[CEIL:.*]] = torch.constant.bool false
+  // CHECK: %[[CIP:.*]] = torch.constant.bool true
+  // CHECK: %[[P:.*]] = torch.constant.int 2
+  // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,32,32],f32> -> !torch.vtensor<[1,3,32,32],f32>
+  // CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[P]] : !torch.vtensor<[1,3,32,32],f32>, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
+  // CHECK: %[[AVG:.*]] = torch.aten.avg_pool2d %[[POW]], %[[K]], %[[STR]], %[[PAD]], %[[CEIL]], %[[CIP]], %[[I1]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,3,31,31],f32>
+  // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
+  // CHECK: torch.aten.pow.Tensor_Scalar %[[AVG]], %[[INVP]] : !torch.vtensor<[1,3,31,31],f32>, !torch.float -> !torch.vtensor<[1,3,31,31],f32>
+  %0 = torch.operator "onnx.LpPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32>
+  return %0 : !torch.vtensor<[1,3,31,31],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lppool_1d
+func.func @test_lppool_1d(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64} {
+  // CHECK: %[[I1:.*]] = torch.constant.int 1
+  // CHECK: %[[I2:.*]] = torch.constant.int 2
+  // CHECK: %[[NE:.*]] = torch.aten.mul %[[I2]], %[[I1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[K:.*]] = torch.prim.ListConstruct %[[I2]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[STR:.*]] = torch.prim.ListConstruct %[[I1_1]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[CEIL:.*]] = torch.constant.bool false
+  // CHECK: %[[CIP:.*]] = torch.constant.bool true
+  // CHECK: %[[P:.*]] = torch.constant.int 2
+  // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,32],f32> -> !torch.vtensor<[1,3,32],f32>
+  // CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[P]] : !torch.vtensor<[1,3,32],f32>, !torch.int -> !torch.vtensor<[1,3,32],f32>
+  // CHECK: %[[AVG:.*]] = torch.aten.avg_pool1d %[[POW]], %[[K]], %[[STR]], %[[PAD]], %[[CEIL]], %[[CIP]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,31],f32>
+  // CHECK: %[[MUL:.*]] = torch.aten.mul.Scalar %[[AVG]], %[[NE]] : !torch.vtensor<[1,3,31],f32>, !torch.int -> !torch.vtensor<[1,3,31],f32>
+  // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
+  // CHECK: torch.aten.pow.Tensor_Scalar %[[MUL]], %[[INVP]] : !torch.vtensor<[1,3,31],f32>, !torch.float -> !torch.vtensor<[1,3,31],f32>
+  %0 = torch.operator "onnx.LpPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32>
+  return %0 : !torch.vtensor<[1,3,31],f32>
+}
+
+// -----
+
 // CHECK-LABEL : func.func @test_layer_norm
 func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>)
 attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
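As a quick arithmetic check on the shapes and constants in these tests (our derivation, not part of the file): with spatial size 32, kernel 2, stride 1, zero padding, and ceil_mode false, the standard pooling output-size formula gives

\[
o = \left\lfloor \frac{i + 2\,\mathrm{pad} - k}{s} \right\rfloor + 1
  = \left\lfloor \frac{32 + 0 - 2}{1} \right\rfloor + 1 = 31,
\]

matching the [1,3,31,31] and [1,3,31] result types, and the default p = 2 yields the final exponent 1/p = 5.000000e-01 seen in both torch.constant.float CHECK lines.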
