@@ -2124,152 +2124,6 @@ class ConvertAtenSliceScatterOp
2124
2124
};
2125
2125
} // namespace
2126
2126
2127
- namespace {
2128
- static Value
2129
- createLinalgPayloadForReduceScatterOp (OpBuilder &b, Location loc, Operation *op,
2130
- ValueRange payloadArgs, Value self,
2131
- int64_t dim, std::string reduceMode,
2132
- RankedTensorType resultType) {
2133
- Type resultElementType = resultType.getElementType ();
2134
- Value index = castIntToIndex (b, loc, /* abstractindexElement*/ payloadArgs[0 ]);
2135
- // Get the element at self[index].
2136
- auto selfElement = b.create <tensor::ExtractOp>(loc, self, ValueRange{index });
2137
- // Get the element at src[index].
2138
- Value srcElement = convertScalarToDtype (
2139
- b, loc, /* abstractSrcElement*/ payloadArgs[1 ], resultElementType);
2140
- Value accumulatorElement;
2141
- // Reduce the elements based on different mode.
2142
- // TODO add more reduce mode here.
2143
- if (reduceMode == " sum" ) {
2144
- // accumulatorElement = selfElement + srcElement;
2145
- if (isa<mlir::FloatType>(resultElementType))
2146
- accumulatorElement =
2147
- b.create <arith::AddFOp>(loc, selfElement, srcElement);
2148
- else if (isa<mlir::IntegerType>(resultElementType))
2149
- accumulatorElement =
2150
- b.create <arith::AddIOp>(loc, selfElement, srcElement);
2151
- } else {
2152
- op->emitError (" only sum lowering in createLinalgPayloadForReduceScatterOp" );
2153
- return nullptr ;
2154
- }
2155
- // Prepare source, indices, scatter_dims for scatter op.
2156
- Value accumulatorElementTensor =
2157
- b.create <tensor::FromElementsOp>(loc, ValueRange{accumulatorElement});
2158
- Value indexTensor = b.create <tensor::FromElementsOp>(loc, ValueRange{index });
2159
- ArrayRef<int64_t > dimArray{dim};
2160
- auto scatter = b.create <tensor::ScatterOp>(
2161
- loc, resultType, /* source*/ accumulatorElementTensor,
2162
- /* dest*/ self, /* indices*/ indexTensor,
2163
- /* scatter_dims*/ b.getDenseI64ArrayAttr (dimArray),
2164
- /* unique*/ b.getUnitAttr ());
2165
- return scatter;
2166
- }
2167
-
2168
- class ConvertAtenScatterReduceOp
2169
- : public OpConversionPattern<AtenScatterReduceOp> {
2170
- public:
2171
- using OpConversionPattern::OpConversionPattern;
2172
- LogicalResult
2173
- matchAndRewrite (AtenScatterReduceOp op, OpAdaptor adaptor,
2174
- ConversionPatternRewriter &rewriter) const override {
2175
- Location loc = op.getLoc ();
2176
- if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
2177
- return failure ();
2178
-
2179
- // Get reduce mode, it could be "sum", "prod", "mean", "amax", "amin".
2180
- std::string reduceMode;
2181
- if (!matchPattern (op.getReduce (), m_TorchConstantStr (reduceMode)))
2182
- return rewriter.notifyMatchFailure (
2183
- op, " only support constant str reduce mode" );
2184
- // TODO: add "prod", "mean", "amax", "amin" mode.
2185
- if (reduceMode != " sum" )
2186
- return rewriter.notifyMatchFailure (
2187
- op, " Only support sum reduce mode for now" );
2188
-
2189
- // Get dim.
2190
- int64_t dim;
2191
- if (!matchPattern (op.getDim (), m_TorchConstantInt (&dim)))
2192
- return rewriter.notifyMatchFailure (op, " dim must be constant" );
2193
-
2194
- // Prepare input.
2195
- auto self = adaptor.getSelf ();
2196
- auto selfType = cast<RankedTensorType>(self.getType ());
2197
- int64_t selfRank = selfType.getRank ();
2198
- // TODO: add more input rank support.
2199
- if (selfRank > 1 || dim > selfRank - 1 )
2200
- return rewriter.notifyMatchFailure (op,
2201
- " Only support self rank==1 for now" );
2202
-
2203
- // Prepare index.
2204
- Value index = adaptor.getIndex ();
2205
- auto indexType = cast<RankedTensorType>(index .getType ());
2206
- int64_t indexRank = indexType.getRank ();
2207
- SmallVector<int64_t > indexAbstractSizes (indexRank, kUnknownSize );
2208
- auto abstractIndexType =
2209
- RankedTensorType::get (makeShapeLLVMCompatible (indexAbstractSizes),
2210
- indexType.getElementType ());
2211
- Value abstractindex =
2212
- rewriter.create <tensor::CastOp>(loc, abstractIndexType, index );
2213
-
2214
- // Prepare src.
2215
- Value src = adaptor.getSrc ();
2216
- auto srcType = cast<RankedTensorType>(src.getType ());
2217
- int64_t srcRank = srcType.getRank ();
2218
- SmallVector<int64_t > srcAbstractSizes (srcRank, kUnknownSize );
2219
- auto abstractSrcType = RankedTensorType::get (
2220
- makeShapeLLVMCompatible (srcAbstractSizes), srcType.getElementType ());
2221
- Value abstractSrc =
2222
- rewriter.create <tensor::CastOp>(loc, abstractSrcType, src);
2223
-
2224
- // Prepare result type.
2225
- const TypeConverter *typeConverter = getTypeConverter ();
2226
- RankedTensorType resultType = cast<RankedTensorType>(
2227
- typeConverter->convertType (op->getResult (0 ).getType ()));
2228
-
2229
- // Prepare indexingMaps and iteratorTypes.
2230
- SmallVector<AffineMap, 3 > indexingMaps = {
2231
- rewriter.getMultiDimIdentityMap (indexRank),
2232
- rewriter.getMultiDimIdentityMap (srcRank),
2233
- rewriter.getMultiDimIdentityMap (selfRank),
2234
- };
2235
- // Prepare iteratorTypes.
2236
- SmallVector<utils::IteratorType> iteratorTypes{
2237
- 1 , utils::IteratorType::parallel};
2238
-
2239
- // Implementation of scatter and reduce in linalg.generic.
2240
- bool err = false ;
2241
- Value result =
2242
- rewriter
2243
- .create <linalg::GenericOp>(
2244
- loc, /* resultTensorTypes=*/ self.getType (),
2245
- /* inputs=*/ ValueRange ({abstractindex, abstractSrc}),
2246
- /* outputs=*/ self, indexingMaps, iteratorTypes,
2247
- [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) {
2248
- // Scatter result after reduce accumulation.
2249
- Value scatter = createLinalgPayloadForReduceScatterOp (
2250
- builder, loc, op, payloadArgs, self, dim, reduceMode,
2251
- resultType);
2252
- // Return selfElements to itself, nothing change but a
2253
- // placeholder.
2254
- if (scatter) {
2255
- builder.create <linalg::YieldOp>(
2256
- loc, /* selfElement*/ payloadArgs[2 ]);
2257
- }
2258
- err = !scatter;
2259
- })
2260
- .getResult (0 );
2261
-
2262
- if (err)
2263
- return rewriter.notifyMatchFailure (
2264
- op,
2265
- " failed to create linalg.generic operation for reduce scatter op" );
2266
- rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultType, result);
2267
-
2268
- return success ();
2269
- }
2270
- };
2271
- } // namespace
2272
-
2273
2127
namespace {
2274
2128
class ConvertAtenViewAsComplexOp
2275
2129
: public OpConversionPattern<AtenViewAsComplexOp> {
@@ -2810,8 +2664,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
2810
2664
patterns.add <ConvertAtenCopyOp>(typeConverter, context);
2811
2665
target.addIllegalOp <AtenSliceScatterOp>();
2812
2666
patterns.add <ConvertAtenSliceScatterOp>(typeConverter, context);
2813
- target.addIllegalOp <AtenScatterReduceOp>();
2814
- patterns.add <ConvertAtenScatterReduceOp>(typeConverter, context);
2815
2667
target.addIllegalOp <AtenViewAsComplexOp>();
2816
2668
patterns.add <ConvertAtenViewAsComplexOp>(typeConverter, context);
2817
2669
target.addIllegalOp <AtenViewAsRealOp>();
0 commit comments