
Commit 2ea2bc3

[ONNX] Add OnnxToTorch Lowering for GroupNormalization op (#3458)
Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent: 04c6479

2 files changed (+96 lines, -1 line)

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 71 additions & 1 deletion
@@ -1818,6 +1818,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) ||
             binder.s64IntegerAttr(stashType, "stash_type", 1))
           return failure();
+
+        // The torch op has no support for the `stash_type` arg, so we only
+        // accept cases where the stash_type matches the input dtype; that
+        // way no input type conversion is required and the op can still be
+        // supported.
+        auto xType = cast<Torch::ValueTensorType>(x.getType());
+        std::optional<int64_t> stashTypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(stashType);
+        if (!stashTypeIntTorch.has_value())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented support for the given stash_type");
+
+        FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
+            binder.op->getContext(),
+            (torch_upstream::ScalarType)stashTypeIntTorch.value());
+        if (failed(stashDtype))
+          return failure();
+        if (*stashDtype != xType.getOptionalDtype())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: stash_type should be same "
+                         "as the input dtype");
+
         Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getF64FloatAttr(epsilon));
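
For context: stash_type is an ONNX TensorProto dtype code (the default 1 denotes float32) that, per the ONNX spec, names the precision in which the normalization's intermediate mean/variance are computed and stashed. The torch op exposes no such knob, so the pattern only matches when that code resolves to the same dtype as the input, meaning no extra casts are needed. Below is a minimal, hypothetical sketch of the kind of code-to-code table a helper like onnxDtypeIntToTorchDtypeInt consults; it is not the actual torch-mlir helper and only covers the float dtypes relevant to stash_type, using the standard ONNX TensorProto and PyTorch c10::ScalarType integer codes.

#include <cstdint>
#include <optional>

// Hypothetical illustration only -- not the torch-mlir helper. Maps the ONNX
// TensorProto dtype codes most relevant to `stash_type` onto PyTorch
// c10::ScalarType integer codes; anything unrecognized yields std::nullopt.
std::optional<int64_t> onnxDtypeToTorchDtypeSketch(int64_t onnxDtype) {
  switch (onnxDtype) {
  case 1:  // ONNX FLOAT
    return 6;  // torch Float
  case 10: // ONNX FLOAT16
    return 5;  // torch Half
  case 11: // ONNX DOUBLE
    return 7;  // torch Double
  case 16: // ONNX BFLOAT16
    return 15; // torch BFloat16
  default:
    return std::nullopt;
  }
}

With such a mapping, the default stash_type = 1 resolves to torch's Float, so an f32 input passes the dtype check above, while, say, an f16 input combined with the default stash_type is rejected with a match failure.
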
@@ -1826,7 +1848,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           rank = *maybeRank;
         SmallVector<Value> normalized;
         axis = Torch::toPositiveDim(axis, rank);
-        auto xType = cast<Torch::ValueTensorType>(x.getType());
         if (!xType.hasSizes()) {
           return rewriter.notifyMatchFailure(
               binder.op, "Expected input (X) to have sizes");
@@ -2444,4 +2465,53 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             paddingList);
         return success();
       });
+  patterns.onOp(
+      "GroupNormalization", 18,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input, scale, bias;
+        int64_t numGroups, stashType;
+        float epsilon;
+        if (binder.tensorOperandAtIndex(input, 0) ||
+            binder.tensorOperandAtIndex(scale, 1) ||
+            binder.tensorOperandAtIndex(bias, 2) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(numGroups, "num_groups") ||
+            binder.f32FloatAttr(epsilon, "epsilon", 1e-5) ||
+            binder.s64IntegerAttr(stashType, "stash_type", 1))
+          return failure();
+
+        // The torch op has no support for the `stash_type` arg, so we only
+        // accept cases where the stash_type matches the input dtype; that
+        // way no input type conversion is required and the op can still be
+        // supported.
+        std::optional<int64_t> stashTypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(stashType);
+        if (!stashTypeIntTorch.has_value())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented support for the given stash_type");
+
+        FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
+            binder.op->getContext(),
+            (torch_upstream::ScalarType)stashTypeIntTorch.value());
+        if (failed(stashDtype))
+          return failure();
+        auto inputDtype =
+            cast<Torch::ValueTensorType>(input.getType()).getOptionalDtype();
+        if (*stashDtype != inputDtype)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: stash_type != input dtype");
+
+        Value cstEpsilon = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr((double)epsilon));
+        Value cstNumGroups = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(numGroups));
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+            binder.getLoc(), rewriter.getBoolAttr(false));
+        rewriter.replaceOpWithNewOp<Torch::AtenGroupNormOp>(
+            binder.op, resultType, input, cstNumGroups, scale, bias,
+            cstEpsilon, /*cudnn_enabled=*/cstFalse);
+        return success();
+      });
 }
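
To make the semantics of the op being lowered concrete, here is a minimal reference sketch of the GroupNormalization computation itself (not the conversion code above), assuming an NCHW float input flattened into one vector and per-group scale/bias vectors of length num_groups, matching the opset-18 tests below: the C channels are split into num_groups groups, each group is normalized with its own mean and variance over the group's channels and spatial positions, and the result is scaled and shifted.

#include <cmath>
#include <cstdint>
#include <vector>

// Reference sketch of GroupNormalization (assumptions: NCHW layout flattened
// into a single vector, per-group scale/bias, biased (denominator-N) variance).
void groupNormReference(const std::vector<float> &x, std::vector<float> &y,
                        int64_t n, int64_t c, int64_t hw, int64_t numGroups,
                        const std::vector<float> &scale, // size numGroups
                        const std::vector<float> &bias,  // size numGroups
                        float epsilon = 1e-5f) {
  const int64_t channelsPerGroup = c / numGroups;
  const int64_t groupSize = channelsPerGroup * hw;
  y.resize(x.size());
  for (int64_t i = 0; i < n; ++i) {
    for (int64_t g = 0; g < numGroups; ++g) {
      const int64_t base = i * c * hw + g * groupSize;
      // Mean and (biased) variance over every element of this group.
      double mean = 0.0, var = 0.0;
      for (int64_t k = 0; k < groupSize; ++k)
        mean += x[base + k];
      mean /= groupSize;
      for (int64_t k = 0; k < groupSize; ++k) {
        const double d = x[base + k] - mean;
        var += d * d;
      }
      var /= groupSize;
      const double invStd = 1.0 / std::sqrt(var + epsilon);
      // Normalize, then apply this group's scale and bias.
      for (int64_t k = 0; k < groupSize; ++k)
        y[base + k] = static_cast<float>(
            scale[g] * (x[base + k] - mean) * invStd + bias[g]);
    }
  }
}

The conversion does not re-implement this math; it simply forwards num_groups, epsilon, and the scale/bias operands to torch.aten.group_norm with cudnn_enabled fixed to false, leaving the actual computation to the torch op's own lowering.
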

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 25 additions & 0 deletions
@@ -1292,3 +1292,28 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1
   %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32>
   return %0 : !torch.vtensor<[1,1,4,4,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_group_normalization
+func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
+  // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
+  %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
+  return %0 : !torch.vtensor<[3,4,2,2],f32>
+}
+
+// -----
+
+func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
+  // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
+  %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
+  return %0 : !torch.vtensor<[3,4,2,2],f32>
+}
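
These additions are lit/FileCheck tests: in torch-mlir such files are typically driven by torch-mlir-opt's convert-torch-onnx-to-torch pass with the input split on the // ----- separators, and the CHECK lines assert that each onnx.GroupNormalization operator is rewritten into a single torch.aten.group_norm call with the expected epsilon constant, num_groups integer, and cudnn_enabled = false. The second test specifically exercises a non-default epsilon attribute.
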
