
Commit bcd5dd0

[Linalg] Add torch.scatter.reduce to linalg lowering
1 parent 9cf12e1 commit bcd5dd0

File tree

2 files changed: +150 −2 lines


lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 148 additions & 0 deletions
@@ -2124,6 +2124,152 @@ class ConvertAtenSliceScatterOp
 };
 } // namespace
 
+namespace {
+static Value
+createLinalgPayloadForReduceScatterOp(OpBuilder &b, Location loc, Operation *op,
+                                      ValueRange payloadArgs, Value self,
+                                      int64_t dim, std::string reduceMode,
+                                      RankedTensorType resultType) {
+  Type resultElementType = resultType.getElementType();
+  Value index = castIntToIndex(b, loc, /*abstractindexElement*/ payloadArgs[0]);
+  // Get the element at self[index].
+  auto selfElement = b.create<tensor::ExtractOp>(loc, self, ValueRange{index});
+  // Get the current src element and convert it to the result element type.
+  Value srcElement = convertScalarToDtype(
+      b, loc, /*abstractSrcElement*/ payloadArgs[1], resultElementType);
+  Value accumulatorElement;
+  // Combine the two elements according to the reduce mode.
+  // TODO: add more reduce modes here.
+  if (reduceMode == "sum") {
+    // accumulatorElement = selfElement + srcElement;
+    if (isa<mlir::FloatType>(resultElementType))
+      accumulatorElement =
+          b.create<arith::AddFOp>(loc, selfElement, srcElement);
+    else if (isa<mlir::IntegerType>(resultElementType))
+      accumulatorElement =
+          b.create<arith::AddIOp>(loc, selfElement, srcElement);
+  } else {
+    op->emitError("only sum lowering in createLinalgPayloadForReduceScatterOp");
+    return nullptr;
+  }
+  // Prepare source, indices, and scatter_dims for the scatter op.
+  Value accumulatorElementTensor =
+      b.create<tensor::FromElementsOp>(loc, ValueRange{accumulatorElement});
+  Value indexTensor = b.create<tensor::FromElementsOp>(loc, ValueRange{index});
+  ArrayRef<int64_t> dimArray{dim};
+  auto scatter = b.create<tensor::ScatterOp>(
+      loc, resultType, /*source*/ accumulatorElementTensor,
+      /*dest*/ self, /*indices*/ indexTensor,
+      /*scatter_dims*/ b.getDenseI64ArrayAttr(dimArray),
+      /*unique*/ b.getUnitAttr());
+  return scatter;
+}
+
+class ConvertAtenScatterReduceOp
+    : public OpConversionPattern<AtenScatterReduceOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenScatterReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    // Get the reduce mode; it can be "sum", "prod", "mean", "amax", or "amin".
+    std::string reduceMode;
+    if (!matchPattern(op.getReduce(), m_TorchConstantStr(reduceMode)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant str reduce mode");
+    // TODO: add the "prod", "mean", "amax", and "amin" modes.
+    if (reduceMode != "sum")
+      return rewriter.notifyMatchFailure(
+          op, "Only support sum reduce mode for now");
+
+    // Get dim.
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(op, "dim must be constant");
+
+    // Prepare input.
+    auto self = adaptor.getSelf();
+    auto selfType = cast<RankedTensorType>(self.getType());
+    int64_t selfRank = selfType.getRank();
+    // TODO: add support for higher input ranks.
+    if (selfRank > 1 || dim > selfRank - 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "Only support self rank==1 for now");
+
+    // Prepare index.
+    Value index = adaptor.getIndex();
+    auto indexType = cast<RankedTensorType>(index.getType());
+    int64_t indexRank = indexType.getRank();
+    SmallVector<int64_t> indexAbstractSizes(indexRank, kUnknownSize);
+    auto abstractIndexType =
+        RankedTensorType::get(makeShapeLLVMCompatible(indexAbstractSizes),
+                              indexType.getElementType());
+    Value abstractindex =
+        rewriter.create<tensor::CastOp>(loc, abstractIndexType, index);
+
+    // Prepare src.
+    Value src = adaptor.getSrc();
+    auto srcType = cast<RankedTensorType>(src.getType());
+    int64_t srcRank = srcType.getRank();
+    SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
+    auto abstractSrcType = RankedTensorType::get(
+        makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
+    Value abstractSrc =
+        rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
+
+    // Prepare the result type.
+    const TypeConverter *typeConverter = getTypeConverter();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        typeConverter->convertType(op->getResult(0).getType()));
+
+    // Prepare indexingMaps.
+    SmallVector<AffineMap, 3> indexingMaps = {
+        rewriter.getMultiDimIdentityMap(indexRank),
+        rewriter.getMultiDimIdentityMap(srcRank),
+        rewriter.getMultiDimIdentityMap(selfRank),
+    };
+    // Prepare iteratorTypes.
+    SmallVector<utils::IteratorType> iteratorTypes{
+        1, utils::IteratorType::parallel};
+
+    // Implement the scatter-and-reduce inside a linalg.generic op.
+    bool err = false;
+    Value result =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, /*resultTensorTypes=*/self.getType(),
+                /*inputs=*/ValueRange({abstractindex, abstractSrc}),
+                /*outputs=*/self, indexingMaps, iteratorTypes,
+                [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) {
+                  // Scatter the accumulated element back into self.
+                  Value scatter = createLinalgPayloadForReduceScatterOp(
+                      builder, loc, op, payloadArgs, self, dim, reduceMode,
+                      resultType);
+                  // Yield the self element unchanged; the yield is only a
+                  // placeholder.
+                  if (scatter) {
+                    builder.create<linalg::YieldOp>(
+                        loc, /*selfElement*/ payloadArgs[2]);
+                  }
+                  err = !scatter;
+                })
+            .getResult(0);
+
+    if (err)
+      return rewriter.notifyMatchFailure(
+          op,
+          "failed to create linalg.generic operation for reduce scatter op");
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenViewAsComplexOp
     : public OpConversionPattern<AtenViewAsComplexOp> {
@@ -2664,6 +2810,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add<ConvertAtenCopyOp>(typeConverter, context);
   target.addIllegalOp<AtenSliceScatterOp>();
   patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
+  target.addIllegalOp<AtenScatterReduceOp>();
+  patterns.add<ConvertAtenScatterReduceOp>(typeConverter, context);
   target.addIllegalOp<AtenViewAsComplexOp>();
   patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
   target.addIllegalOp<AtenViewAsRealOp>();
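For readers unfamiliar with the op, here is a minimal PyTorch-level sketch of the scatter-with-"sum"-reduction semantics the new pattern handles (rank-1 self, constant dim, "sum" mode): every src element is accumulated into self at the position named by the corresponding index entry. The torch.scatter_reduce call and the concrete values below are illustrative assumptions, not the exact ATen overload or test data used by this commit.

import torch

# Hypothetical rank-1 example: each src[i] is added onto dest[index[i]] along dim 0.
dest = torch.tensor([1., 2., 3., 4., 5.])
index = torch.tensor([0, 2, 0])
src = torch.tensor([10., 20., 30.])

out = torch.scatter_reduce(dest, 0, index, src, reduce="sum")
print(out)  # tensor([41.,  2., 23.,  4.,  5.])

The lowering above expresses the same accumulation by iterating over index and src in a linalg.generic and writing each partial sum back into self through tensor.scatter.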

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 2 additions & 2 deletions
@@ -268,7 +268,7 @@ func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1
 // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
 // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
 // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
-// CHECK: %[[STR:.*]] = torch.constant.str "add"
+// CHECK: %[[STR:.*]] = torch.constant.str "sum"
 // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
   %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
   return %0 : !torch.vtensor<[1,5],f32>
@@ -301,7 +301,7 @@ func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],
 // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
 // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
 // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
-// CHECK: %[[STR:.*]] = torch.constant.str "multiply"
+// CHECK: %[[STR:.*]] = torch.constant.str "prod"
 // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
   %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
   return %0 : !torch.vtensor<[1,5],f32>
