@@ -1818,6 +1818,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) ||
             binder.s64IntegerAttr(stashType, "stash_type", 1))
           return failure();
+
+        // The torch op has no support for the `stash_type` arg, so we
+        // just check that the stash_type matches the input dtype; that
+        // case requires no input type conversion and hence can be
+        // supported.
+        auto xType = cast<Torch::ValueTensorType>(x.getType());
+        std::optional<int64_t> stashTypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(stashType);
+        if (!stashTypeIntTorch.has_value())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented support for the given stash_type");
+
+        FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
+            binder.op->getContext(),
+            (torch_upstream::ScalarType)stashTypeIntTorch.value());
+        if (failed(stashDtype))
+          return failure();
+        if (*stashDtype != xType.getOptionalDtype())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: stash_type should be same "
+                         "as the input dtype");
+
         Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
             rewriter.getF64FloatAttr(epsilon));
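
A note on the dtype plumbing above: `onnxDtypeIntToTorchDtypeInt` translates an ONNX `TensorProto` dtype code (the `stash_type` attribute, which defaults to 1, i.e. ONNX FLOAT) into a torch `ScalarType` code, which `Torch::getTypeForScalarType` then turns into an MLIR type. Below is a minimal sketch of that mapping for the floating-point cases relevant here, assuming the standard ONNX and torch enum values; it is illustrative only, the real helper lives in the OnnxToTorch utilities and covers more dtypes.

```cpp
#include <cstdint>
#include <optional>

// Illustrative sketch (not the actual helper): map the ONNX TensorProto
// dtype codes that `stash_type` can reasonably take onto torch ScalarType
// codes. The numeric values are the standard ONNX/torch enum values.
static std::optional<int64_t> onnxDtypeIntToTorchDtypeIntSketch(int64_t dtype) {
  switch (dtype) {
  case 1:  // ONNX FLOAT    -> torch Float (f32)
    return 6;
  case 10: // ONNX FLOAT16  -> torch Half (f16)
    return 5;
  case 11: // ONNX DOUBLE   -> torch Double (f64)
    return 7;
  case 16: // ONNX BFLOAT16 -> torch BFloat16
    return 15;
  default: // anything else is rejected by the pattern above
    return std::nullopt;
  }
}
```
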
@@ -1826,7 +1848,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         rank = *maybeRank;
         SmallVector<Value> normalized;
         axis = Torch::toPositiveDim(axis, rank);
-        auto xType = cast<Torch::ValueTensorType>(x.getType());
         if (!xType.hasSizes()) {
           return rewriter.notifyMatchFailure(
               binder.op, "Expected input (X) to have sizes");
@@ -2444,4 +2465,53 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             paddingList);
         return success();
       });
+  patterns.onOp(
+      "GroupNormalization", 18,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input, scale, bias;
+        int64_t numGroups, stashType;
+        float epsilon;
+        if (binder.tensorOperandAtIndex(input, 0) ||
+            binder.tensorOperandAtIndex(scale, 1) ||
+            binder.tensorOperandAtIndex(bias, 2) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(numGroups, "num_groups") ||
+            binder.f32FloatAttr(epsilon, "epsilon", 1e-5) ||
+            binder.s64IntegerAttr(stashType, "stash_type", 1))
+          return failure();
+
+        // The torch op has no support for the `stash_type` arg, so we
+        // just check that the stash_type matches the input dtype; that
+        // case requires no input type conversion and hence can be
+        // supported.
+        std::optional<int64_t> stashTypeIntTorch =
+            onnxDtypeIntToTorchDtypeInt(stashType);
+        if (!stashTypeIntTorch.has_value())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented support for the given stash_type");
+
+        FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
+            binder.op->getContext(),
+            (torch_upstream::ScalarType)stashTypeIntTorch.value());
+        if (failed(stashDtype))
+          return failure();
+        auto inputDtype =
+            cast<Torch::ValueTensorType>(input.getType()).getOptionalDtype();
+        if (*stashDtype != inputDtype)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: stash_type != input dtype");
+
+        Value cstEpsilon = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr((double)epsilon));
+        Value cstNumGroups = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(numGroups));
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+            binder.getLoc(), rewriter.getBoolAttr(false));
+        rewriter.replaceOpWithNewOp<Torch::AtenGroupNormOp>(
+            binder.op, resultType, input, cstNumGroups, scale, bias, cstEpsilon,
+            /*cudnn_enabled=*/cstFalse);
+        return success();
+      });
 }
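
For reference, `Torch::AtenGroupNormOp` (emitted here with `cudnn_enabled=false`, which keeps the lowering backend-neutral) normalizes each of `num_groups` channel groups by that group's own mean and variance, then applies the per-channel weight and bias. A self-contained sketch of that semantics on an NCHW buffer, independent of torch-mlir (all names below are ours):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference semantics of group normalization on an NCHW buffer:
// for each (batch, group), normalize over channelsPerGroup * H * W
// elements, then apply per-channel scale/bias. Assumes C % numGroups == 0
// and scale/bias each have C elements.
std::vector<float> groupNorm(const std::vector<float> &x, size_t N, size_t C,
                             size_t H, size_t W, size_t numGroups,
                             const std::vector<float> &scale,
                             const std::vector<float> &bias,
                             float epsilon = 1e-5f) {
  std::vector<float> out(x.size());
  size_t cpg = C / numGroups;     // channels per group
  size_t groupSize = cpg * H * W; // elements normalized together
  for (size_t n = 0; n < N; ++n) {
    for (size_t g = 0; g < numGroups; ++g) {
      size_t base = n * C * H * W + g * groupSize;
      // Mean and (biased) variance over the group.
      double mean = 0.0, var = 0.0;
      for (size_t i = 0; i < groupSize; ++i)
        mean += x[base + i];
      mean /= groupSize;
      for (size_t i = 0; i < groupSize; ++i) {
        double d = x[base + i] - mean;
        var += d * d;
      }
      var /= groupSize;
      double invStd = 1.0 / std::sqrt(var + epsilon);
      // Normalize, then apply the per-channel affine parameters.
      for (size_t c = 0; c < cpg; ++c) {
        size_t ch = g * cpg + c; // absolute channel index
        for (size_t i = 0; i < H * W; ++i) {
          size_t idx = base + c * H * W + i;
          out[idx] = static_cast<float>((x[idx] - mean) * invStd) * scale[ch] +
                     bias[ch];
        }
      }
    }
  }
  return out;
}
```
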