
Commit faec62d

[ONNX] Change AtenScatterReduce to AtenScatterReduceTwoOp for onnx.ScatterElements
This will enable the AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext. Also remove the incorrect AtenScatterReduce-to-linalg pass.
1 parent e632755 commit faec62d
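
For context: onnx.ScatterElements with reduction="add"/"mul" maps to torch's scatter_reduce with reduce="sum"/"prod" and include_self=True, which is the signature aten.scatter_reduce.two carries; aten.scatter.reduce has no include_self operand. A minimal eager-mode sketch of the target semantics (the tensor values here are made up for illustration):

    import torch

    # onnx.ScatterElements(reduction="add"): duplicate indices accumulate,
    # and the original data participates in the reduction.
    data = torch.zeros(1, 5)
    indices = torch.tensor([[1, 1]])
    updates = torch.tensor([[1.0, 2.0]])

    out = data.scatter_reduce(1, indices, updates, reduce="sum", include_self=True)
    print(out)  # tensor([[0., 3., 0., 0., 0.]])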

File tree

3 files changed: +24, -169 lines

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 4 additions & 3 deletions
@@ -645,10 +645,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
           Value cstStrReduction =
               rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);
-
-          rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
+          Value cstTrue =
+              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+          rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceTwoOp>(
               binder.op, resultType, data, constAxis, indices, updates,
-              cstStrReduction);
+              cstStrReduction, cstTrue);
           return success();
         });
   patterns.onOp(
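
The new cstTrue constant feeds the include_self operand of aten.scatter_reduce.two. ONNX ScatterElements reduces updates into the existing data, so the original values must participate in the reduction. A short contrast in eager PyTorch (illustrative values only):

    import torch

    data = torch.full((5,), 10.0)
    index = torch.tensor([0, 0])
    src = torch.tensor([1.0, 2.0])

    # include_self=True: the original element joins the reduction (10 + 1 + 2).
    print(data.scatter_reduce(0, index, src, reduce="sum", include_self=True))
    # tensor([13., 10., 10., 10., 10.])

    # include_self=False: only the scattered updates are reduced (1 + 2).
    print(data.scatter_reduce(0, index, src, reduce="sum", include_self=False))
    # tensor([3., 10., 10., 10., 10.])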

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 0 additions & 148 deletions
@@ -2124,152 +2124,6 @@ class ConvertAtenSliceScatterOp
 };
 } // namespace
 
-namespace {
-static Value
-createLinalgPayloadForReduceScatterOp(OpBuilder &b, Location loc, Operation *op,
-                                      ValueRange payloadArgs, Value self,
-                                      int64_t dim, std::string reduceMode,
-                                      RankedTensorType resultType) {
-  Type resultElementType = resultType.getElementType();
-  Value index = castIntToIndex(b, loc, /*abstractindexElement*/ payloadArgs[0]);
-  // Get the element at self[index].
-  auto selfElement = b.create<tensor::ExtractOp>(loc, self, ValueRange{index});
-  // Get the element at src[index].
-  Value srcElement = convertScalarToDtype(
-      b, loc, /*abstractSrcElement*/ payloadArgs[1], resultElementType);
-  Value accumulatorElement;
-  // Reduce the elements based on different mode.
-  // TODO add more reduce mode here.
-  if (reduceMode == "sum") {
-    // accumulatorElement = selfElement + srcElement;
-    if (isa<mlir::FloatType>(resultElementType))
-      accumulatorElement =
-          b.create<arith::AddFOp>(loc, selfElement, srcElement);
-    else if (isa<mlir::IntegerType>(resultElementType))
-      accumulatorElement =
-          b.create<arith::AddIOp>(loc, selfElement, srcElement);
-  } else {
-    op->emitError("only sum lowering in createLinalgPayloadForReduceScatterOp");
-    return nullptr;
-  }
-  // Prepare source, indices, scatter_dims for scatter op.
-  Value accumulatorElementTensor =
-      b.create<tensor::FromElementsOp>(loc, ValueRange{accumulatorElement});
-  Value indexTensor = b.create<tensor::FromElementsOp>(loc, ValueRange{index});
-  ArrayRef<int64_t> dimArray{dim};
-  auto scatter = b.create<tensor::ScatterOp>(
-      loc, resultType, /*source*/ accumulatorElementTensor,
-      /*dest*/ self, /*indices*/ indexTensor,
-      /*scatter_dims*/ b.getDenseI64ArrayAttr(dimArray),
-      /*unique*/ b.getUnitAttr());
-  return scatter;
-}
-
-class ConvertAtenScatterReduceOp
-    : public OpConversionPattern<AtenScatterReduceOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenScatterReduceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
-      return failure();
-
-    // Get reduce mode, it could be "sum", "prod", "mean", "amax", "amin".
-    std::string reduceMode;
-    if (!matchPattern(op.getReduce(), m_TorchConstantStr(reduceMode)))
-      return rewriter.notifyMatchFailure(
-          op, "only support constant str reduce mode");
-    // TODO: add "prod", "mean", "amax", "amin" mode.
-    if (reduceMode != "sum")
-      return rewriter.notifyMatchFailure(
-          op, "Only support sum reduce mode for now");
-
-    // Get dim.
-    int64_t dim;
-    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
-      return rewriter.notifyMatchFailure(op, "dim must be constant");
-
-    // Prepare input.
-    auto self = adaptor.getSelf();
-    auto selfType = cast<RankedTensorType>(self.getType());
-    int64_t selfRank = selfType.getRank();
-    // TODO: add more input rank support.
-    if (selfRank > 1 || dim > selfRank - 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "Only support self rank==1 for now");
-
-    // Prepare index.
-    Value index = adaptor.getIndex();
-    auto indexType = cast<RankedTensorType>(index.getType());
-    int64_t indexRank = indexType.getRank();
-    SmallVector<int64_t> indexAbstractSizes(indexRank, kUnknownSize);
-    auto abstractIndexType =
-        RankedTensorType::get(makeShapeLLVMCompatible(indexAbstractSizes),
-                              indexType.getElementType());
-    Value abstractindex =
-        rewriter.create<tensor::CastOp>(loc, abstractIndexType, index);
-
-    // Prepare src.
-    Value src = adaptor.getSrc();
-    auto srcType = cast<RankedTensorType>(src.getType());
-    int64_t srcRank = srcType.getRank();
-    SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
-    auto abstractSrcType = RankedTensorType::get(
-        makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
-    Value abstractSrc =
-        rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
-
-    // Prepare result type.
-    const TypeConverter *typeConverter = getTypeConverter();
-    RankedTensorType resultType = cast<RankedTensorType>(
-        typeConverter->convertType(op->getResult(0).getType()));
-
-    // Prepare indexingMaps and iteratorTypes.
-    SmallVector<AffineMap, 3> indexingMaps = {
-        rewriter.getMultiDimIdentityMap(indexRank),
-        rewriter.getMultiDimIdentityMap(srcRank),
-        rewriter.getMultiDimIdentityMap(selfRank),
-    };
-    // Prepare iteratorTypes.
-    SmallVector<utils::IteratorType> iteratorTypes{
-        1, utils::IteratorType::parallel};
-
-    // Implementation of scatter and reduce in linalg.generic.
-    bool err = false;
-    Value result =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, /*resultTensorTypes=*/self.getType(),
-                /*inputs=*/ValueRange({abstractindex, abstractSrc}),
-                /*outputs=*/self, indexingMaps, iteratorTypes,
-                [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) {
-                  // Scatter result after reduce accumulation.
-                  Value scatter = createLinalgPayloadForReduceScatterOp(
-                      builder, loc, op, payloadArgs, self, dim, reduceMode,
-                      resultType);
-                  // Return selfElements to itself, nothing change but a
-                  // placeholder.
-                  if (scatter) {
-                    builder.create<linalg::YieldOp>(
-                        loc, /*selfElement*/ payloadArgs[2]);
-                  }
-                  err = !scatter;
-                })
-            .getResult(0);
-
-    if (err)
-      return rewriter.notifyMatchFailure(
-          op,
-          "failed to create linalg.generic operation for reduce scatter op");
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
-
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertAtenViewAsComplexOp
     : public OpConversionPattern<AtenViewAsComplexOp> {
@@ -2810,8 +2664,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add<ConvertAtenCopyOp>(typeConverter, context);
   target.addIllegalOp<AtenSliceScatterOp>();
   patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
-  target.addIllegalOp<AtenScatterReduceOp>();
-  patterns.add<ConvertAtenScatterReduceOp>(typeConverter, context);
   target.addIllegalOp<AtenViewAsComplexOp>();
   patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
   target.addIllegalOp<AtenViewAsRealOp>();
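
As the commit message says, this pattern was wrong: judging from the deleted payload, the linalg.generic body created a tensor.scatter per element but then yielded the unmodified self element ("nothing change but a placeholder"), so the scatter results were dead and the pattern effectively returned its input; it also only supported rank-1 "sum". For reference, a rank-1 sketch of the semantics the retained aten.scatter_reduce.two path must compute (hypothetical helper, not repo code):

    # Reference semantics of scatter_reduce(dim=0, reduce="sum",
    # include_self=True) on rank-1 inputs; plain Python for clarity.
    def scatter_reduce_sum_1d(self_vals, index, src):
        out = list(self_vals)      # include_self=True: start from self
        for i, idx in enumerate(index):
            out[idx] += src[i]     # duplicate indices accumulate
        return out

    # 3.0 and 4.0 both land in slot 0 on top of the original 1.0.
    assert scatter_reduce_sum_1d([1.0, 2.0], [0, 0], [3.0, 4.0]) == [8.0, 2.0]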

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 20 additions & 18 deletions
@@ -261,15 +261,16 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar
 
 // CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
 func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
-  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
-  // CHECK: %[[ONE:.+]] = torch.constant.int 1
-  // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
-  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
-  // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
-  // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
-  // CHECK: %[[STR:.*]] = torch.constant.str "sum"
-  // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
+  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
+  // CHECK: %[[ZERO:.*]] = torch.constant.int 0
+  // CHECK: %[[FIVE:.*]] = torch.constant.int 1
+  // CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
+  // CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[STR:.*]] = torch.constant.str "sum"
+  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
+  // CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
   %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
   return %0 : !torch.vtensor<[1,5],f32>
 }
@@ -294,15 +295,16 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,
 
 // CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
 func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
-  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
-  // CHECK: %[[ONE:.+]] = torch.constant.int 1
-  // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
-  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
-  // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
-  // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
-  // CHECK: %[[STR:.*]] = torch.constant.str "prod"
-  // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
+  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
+  // CHECK: %[[ZERO:.*]] = torch.constant.int 0
+  // CHECK: %[[FIVE:.*]] = torch.constant.int 1
+  // CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
+  // CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[STR:.*]] = torch.constant.str "prod"
+  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
+  // CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
   %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
   return %0 : !torch.vtensor<[1,5],f32>
 }
