Skip to content

Commit 94c8b6a

Browse files
committed
[onnx] Lowering for onnx.Selu
Started work on the `Selu` lowering from ONNX to Torch.
1 parent 099e1f4 commit 94c8b6a

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,25 @@ struct OpBinder {
113113
return failure();
114114
}
115115

116+
/// Binds the float attribute "torch.onnx.<nameSuffix>" of the bound op into
/// `value`.
///
/// \param value        Out-parameter receiving the parsed attribute value.
/// \param nameSuffix   Attribute name suffix; looked up as
///                     "torch.onnx.<nameSuffix>".
/// \param defaultValue Value stored into `value` when the attribute is absent.
/// \returns success() when the attribute is absent (default used) or is an
///          f32 FloatAttr; failure() for a non-float attribute or a float of
///          any width other than 32.
ParseResult f32FloatAttr(float &value, StringRef nameSuffix,
                         float defaultValue = 0.0f) {
  SmallString<64> name("torch.onnx.");
  name.append(nameSuffix);
  auto attr = op->getAttr(name);
  if (!attr) {
    // Missing attribute is not an error: ONNX attributes are optional and
    // carry spec-defined defaults supplied by the caller.
    value = defaultValue;
    return success();
  }
  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
    // FloatAttr always carries a FloatType, so this cast cannot fail.
    FloatType t = cast<FloatType>(floatAttr.getType());
    if (t.getWidth() != 32)
      return failure();
    // Width was verified to be 32 above, so narrowing the double returned by
    // getValueAsDouble() back to float is exact; make the narrowing explicit
    // to avoid -Wconversion noise.
    value = static_cast<float>(floatAttr.getValueAsDouble());
    return success();
  }
  return failure();
}
134+
116135
ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
117136
std::string defaultValue = "") {
118137
SmallString<64> name("torch.onnx.");

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,33 @@ using namespace mlir::torch::onnx_c;
2626
// results in a lot of ONNX test cases that all reduce to the exact same
2727
// thing here, so we simplify.
2828
// Populates patterns for ONNX ops in the default domain, names Q through Z.
// ONNX Selu(x) = gamma * (max(0, x) + min(0, alpha * (exp(x) - 1))), which
// maps directly onto torch.aten.elu(x, alpha, scale = gamma, input_scale = 1).
void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
    OnnxCustomOpConversionPattern &patterns) {

  patterns.onOp(
      "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        float alpha, gamma;
        Value operand;
        // When the attributes are omitted, the ONNX spec mandates
        // alpha = 1.67326319... and gamma = 1.05070102... (not 0.0, which
        // f32FloatAttr would otherwise default to and which would lower an
        // attribute-less Selu to the zero function).
        if (binder.tensorOperand(operand) ||
            binder.f32FloatAttr(alpha, "alpha", 1.67326319217681884765625f) ||
            binder.f32FloatAttr(gamma, "gamma", 1.05070102214813232421875f) ||
            binder.tensorResultType(resultType))
          return failure();

        Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), alpha));

        Value vScale = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), gamma));

        // ONNX Selu applies no scaling inside the exponential, so
        // input_scale is fixed at 1.0.
        Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));

        rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
        return success();
      });
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
// Generally, the test cases accumulated here come from running the importer
// over all included backend tests that involve simple ops with no model
// level constants. This is a pragmatic choice which lets us have a lot
// of tests in this file, whereas the others tend to be more bespoke.


// CHECK-LABEL: func.func @test_selu
func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
  // onnx.Selu lowers to torch.aten.elu: F2 is alpha, F3 is scale (gamma),
  // and F1 is the fixed input_scale of 1; CHECK-DAG because the three
  // constants may be emitted in any order.
  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
  // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
  // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3
  // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]]
  %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

0 commit comments

Comments
 (0)