Commit e1164e9

[onnx] Lowerings from onnx.relu and onnx.selu

Started work on the `relu` and `selu` lowerings for ONNX to Torch.

1 parent 099e1f4

File tree: 3 files changed, +85 −1 lines

include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h

Lines changed: 19 additions & 0 deletions
@@ -113,6 +113,25 @@ struct OpBinder {
     return failure();
   }
 
+  ParseResult f32FloatAttr(float &value, StringRef nameSuffix,
+                           float defaultValue = 0.0f) {
+    SmallString<64> name("torch.onnx.");
+    name.append(nameSuffix);
+    auto attr = op->getAttr(name);
+    if (!attr) {
+      value = defaultValue;
+      return success();
+    }
+    if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+      FloatType t = cast<FloatType>(floatAttr.getType());
+      if (t.getWidth() != 32)
+        return failure();
+      value = floatAttr.getValueAsDouble();
+      return success();
+    }
+    return failure();
+  }
+
   ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
                                      std::string defaultValue = "") {
     SmallString<64> name("torch.onnx.");
lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 41 additions & 1 deletion
@@ -26,4 +26,44 @@ using namespace mlir::torch::onnx_c;
 // results in a lot of ONNX test cases that all reduce to the exact same
 // thing here, so we simplify.
 void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
-    OnnxCustomOpConversionPattern &patterns) {}
+    OnnxCustomOpConversionPattern &patterns) {
+  patterns.onOp("Relu", 6,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value operand;
+                  if (binder.tensorOperand(operand) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+                  rewriter.replaceOpWithNewOp<Torch::AtenReluOp>(
+                      binder.op, resultType, operand);
+                  return success();
+                });
+
+  patterns.onOp(
+      "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        float alpha, gamma;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.f32FloatAttr(alpha, "alpha") ||
+            binder.f32FloatAttr(gamma, "gamma") ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
+
+        Value vScale = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), gamma));
+
+        Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));
+
+        rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
+            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
+        return success();
+      });
+}
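The Selu lowering relies on the identity below; the scale/input_scale form matches PyTorch's aten.elu signature (spelled out here for clarity, it is not stated in the commit itself). For ONNX attributes alpha and gamma:

    Selu(x) = gamma * (max(0, x) + min(0, alpha * (exp(x) - 1)))
            = elu(x, alpha, /*scale=*/gamma, /*input_scale=*/1.0)

which is why three torch.constant.float values are materialized and passed straight to AtenEluOp. One caveat worth flagging: if the alpha or gamma attribute is absent, f32FloatAttr substitutes its 0.0f default, whereas the ONNX spec defines Selu defaults of approximately alpha = 1.6733 and gamma = 1.0507.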
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
+// Generally, the test cases accumulated here come from running the importer
+// over all included backend tests that involve simple ops with no model
+// level constants. This is a pragmatic choice which lets us have a lot
+// of tests in this file, whereas the others tend to be more bespoke.
+
+
+// CHECK-LABEL: func.func @test_relu
+func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 14 : si64} {
+  // CHECK: torch.aten.relu %arg0
+  %0 = torch.operator "onnx.Relu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
+  return %0 : !torch.vtensor<[3,4,5],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_selu
+func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
+  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
+  // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
+  // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3
+  // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]]
+  %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
+  return %0 : !torch.vtensor<[3,4,5],f32>
+}
