@@ -2124,6 +2124,152 @@ class ConvertAtenSliceScatterOp
 };
 } // namespace
 
+namespace {
+static Value
+createLinalgPayloadForReduceScatterOp(OpBuilder &b, Location loc,
+                                      Operation *op, ValueRange payloadArgs,
+                                      Value self, int64_t dim,
+                                      std::string reduceMode,
+                                      RankedTensorType resultType) {
+  Type resultElementType = resultType.getElementType();
+  Value index =
+      castIntToIndex(b, loc, /*abstractIndexElement*/ payloadArgs[0]);
+  // Get the element at self[index].
+  auto selfElement = b.create<tensor::ExtractOp>(loc, self, ValueRange{index});
+  // Get the element at src[index].
+  Value srcElement = convertScalarToDtype(
+      b, loc, /*abstractSrcElement*/ payloadArgs[1], resultElementType);
+  Value accumulatorElement;
+  // Combine the two elements according to the reduce mode.
+  // TODO: add more reduce modes here.
+  if (reduceMode == "sum") {
+    // accumulatorElement = selfElement + srcElement;
+    if (isa<mlir::FloatType>(resultElementType))
+      accumulatorElement =
+          b.create<arith::AddFOp>(loc, selfElement, srcElement);
+    else if (isa<mlir::IntegerType>(resultElementType))
+      accumulatorElement =
+          b.create<arith::AddIOp>(loc, selfElement, srcElement);
+  } else {
+    op->emitError("only sum reduce mode is supported in "
+                  "createLinalgPayloadForReduceScatterOp");
+    return nullptr;
+  }
+  // Prepare the source, indices, and scatter_dims operands for the scatter op.
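+  // tensor.scatter produces a new tensor that is `self` with the accumulated
+  // element written at position `index` along `dim`.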
+  Value accumulatorElementTensor =
+      b.create<tensor::FromElementsOp>(loc, ValueRange{accumulatorElement});
+  Value indexTensor = b.create<tensor::FromElementsOp>(loc, ValueRange{index});
+  SmallVector<int64_t> dimArray{dim};
+  auto scatter = b.create<tensor::ScatterOp>(
+      loc, resultType, /*source*/ accumulatorElementTensor,
+      /*dest*/ self, /*indices*/ indexTensor,
+      /*scatter_dims*/ b.getDenseI64ArrayAttr(dimArray),
+      /*unique*/ b.getUnitAttr());
+  return scatter;
+}
+
+class ConvertAtenScatterReduceOp
+    : public OpConversionPattern<AtenScatterReduceOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenScatterReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    // Get the reduce mode; it can be "sum", "prod", "mean", "amax", or "amin".
+    std::string reduceMode;
+    if (!matchPattern(op.getReduce(), m_TorchConstantStr(reduceMode)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant str reduce mode");
+    // TODO: add the "prod", "mean", "amax", and "amin" modes.
+    if (reduceMode != "sum")
+      return rewriter.notifyMatchFailure(
+          op, "only support sum reduce mode for now");
+
+    // Get dim.
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(op, "dim must be constant");
+
+    // Prepare input.
+    auto self = adaptor.getSelf();
+    auto selfType = cast<RankedTensorType>(self.getType());
+    int64_t selfRank = selfType.getRank();
+    // TODO: add support for higher input ranks.
+    if (selfRank > 1 || dim > selfRank - 1)
+      return rewriter.notifyMatchFailure(
+          op, "only support self rank == 1 for now");
+
+    // Prepare index.
+    Value index = adaptor.getIndex();
+    auto indexType = cast<RankedTensorType>(index.getType());
+    int64_t indexRank = indexType.getRank();
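+    // Cast index (and src below) to fully dynamic "abstract" shapes; the
+    // linalg.generic below is built against these types, so it does not
+    // depend on the operands' static sizes.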
+    SmallVector<int64_t> indexAbstractSizes(indexRank, kUnknownSize);
+    auto abstractIndexType =
+        RankedTensorType::get(makeShapeLLVMCompatible(indexAbstractSizes),
+                              indexType.getElementType());
+    Value abstractIndex =
+        rewriter.create<tensor::CastOp>(loc, abstractIndexType, index);
+
+    // Prepare src.
+    Value src = adaptor.getSrc();
+    auto srcType = cast<RankedTensorType>(src.getType());
+    int64_t srcRank = srcType.getRank();
+    SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
+    auto abstractSrcType = RankedTensorType::get(
+        makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
+    Value abstractSrc =
+        rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
+
+    // Prepare the result type.
+    const TypeConverter *typeConverter = getTypeConverter();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        typeConverter->convertType(op->getResult(0).getType()));
+
+    // Prepare indexingMaps: identity maps for index, src, and self.
+    SmallVector<AffineMap, 3> indexingMaps = {
+        rewriter.getMultiDimIdentityMap(indexRank),
+        rewriter.getMultiDimIdentityMap(srcRank),
+        rewriter.getMultiDimIdentityMap(selfRank),
+    };
+    // Prepare iteratorTypes.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        1, utils::IteratorType::parallel);
+
+    // Implement the scatter and reduce with a linalg.generic op.
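+    // Each payload invocation combines one src element into self via the
+    // tensor.scatter built in the payload, and yields the self element
+    // unchanged as a placeholder result.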
+    bool err = false;
+    Value result =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, /*resultTensorTypes=*/self.getType(),
+                /*inputs=*/ValueRange({abstractIndex, abstractSrc}),
+                /*outputs=*/self, indexingMaps, iteratorTypes,
+                [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) {
+                  // Scatter the reduce-accumulated element.
+                  Value scatter = createLinalgPayloadForReduceScatterOp(
+                      builder, loc, op, payloadArgs, self, dim, reduceMode,
+                      resultType);
+                  // Yield the self element back unchanged; it is only a
+                  // placeholder.
+                  if (scatter) {
+                    builder.create<linalg::YieldOp>(
+                        loc, /*selfElement*/ payloadArgs[2]);
+                  }
+                  err = !scatter;
+                })
+            .getResult(0);
+
+    if (err)
+      return rewriter.notifyMatchFailure(
+          op,
+          "failed to create linalg.generic operation for reduce scatter op");
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenViewAsComplexOp
     : public OpConversionPattern<AtenViewAsComplexOp> {
@@ -2664,6 +2810,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add<ConvertAtenCopyOp>(typeConverter, context);
   target.addIllegalOp<AtenSliceScatterOp>();
   patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
+  target.addIllegalOp<AtenScatterReduceOp>();
+  patterns.add<ConvertAtenScatterReduceOp>(typeConverter, context);
   target.addIllegalOp<AtenViewAsComplexOp>();
   patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
   target.addIllegalOp<AtenViewAsRealOp>();