Commit 1225073

Integrate LLVM at e2402615a5a76d46a433dfcc1de10b38a1263c9d (#3982)
Update LLVM to llvm/llvm-project@e240261
Update StableHlo to openxla/stablehlo@8cd9444

Updates API calls from:
1. `applyPatternsAndFoldGreedily` -> `applyPatternsGreedily`
2. `applyOpPatternsAndFold` -> `applyOpPatternsGreedily`

This commit also inlines the `BufferizeTypeConverter` in Torch-MLIR, which was removed from the LLVM project in llvm/llvm-project@2ff2e87.

This commit also updates the `AdjustCallingConventions` pass to align with the upstream `TypeConverter` changes. Some tests in `adjust-calling-conventions.mlir` are disabled for the time being since they are not supported even after the pass changes; we will enable them in a separate PR once the `AdjustCallingConventions` pass is fully functional. The fix is tracked by #3983.

TOSA Updates Summary:

Update Torch to TOSA legalizations with TOSA 1.0 ops' forms from LLVM hash 64edde66. Changes include:
- TOSA Pad op's new shape requirement
- TOSA Convolution ops' new acc_type
- TOSA Tile with multiples as a !tosa.shape input

---------

Signed-off-by: Vivek Khandelwal <[email protected]>
Co-authored-by: Justin Ngo <[email protected]>
1 parent af8514c commit 1225073
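Note for downstream code tracking this bump: the greedy rewrite driver entry points keep their signatures and only change names. A minimal migration sketch in a hypothetical pass (the pass and pattern names are illustrative, not part of this commit):

  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  void MyExamplePass::runOnOperation() {
    RewritePatternSet patterns(&getContext());
    patterns.insert<MyExamplePattern>(&getContext()); // hypothetical pattern
    // Before: applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))
    // After: same behavior (folding included), new name.
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }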

29 files changed (+430, -178 lines)

externals/llvm-project

Submodule llvm-project updated 18624 files

externals/stablehlo

Submodule stablehlo updated 143 files

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 11 additions & 0 deletions

@@ -121,6 +121,17 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
 LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
                                   TypeAttr &accType);
 
+// Get accumulator type for TOSA convolution ops
+LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
+                                RankedTensorType inputTy,
+                                RankedTensorType weightTy,
+                                RankedTensorType outputTy, TypeAttr &accType);
+
+// Temporary function to get TOSA const shape
+// TODO: Remove this function when getTosaConstShape is available in
+// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                        llvm::ArrayRef<int64_t> shape);
 } // namespace tosa
 } // namespace mlir

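Both new declarations are consumed by the lowering patterns later in this commit. A condensed sketch of the intended call sites, assuming `rewriter`, `op`, and the ranked tensor types are in scope as in TorchToTosa.cpp (the multiples values are illustrative):

  // Pick the TOSA accumulator type for a convolution; fails on unsupported
  // dtype combinations.
  TypeAttr accType;
  if (failed(tosa::getConvOpsAccType(rewriter, inputTy, weightTy, outputTy,
                                     accType)))
    return rewriter.notifyMatchFailure(
        op, "failed to get accumulator type for convolution ops");

  // Materialize tile multiples as a !tosa.shape value instead of an attribute.
  Value multiples =
      tosa::getTosaConstShape(rewriter, op->getLoc(), /*shape=*/{2, 1, 3});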
lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 3 additions & 3 deletions

@@ -549,7 +549,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
     MLIRContext *context = op->getContext();
-    Type floatDtype = mlir::FloatType::getF64(context);
+    Type floatDtype = mlir::Float64Type::get(context);
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype);
     Value zero =
@@ -569,7 +569,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (isa<AtenLogicalNotOp>(op)) {
     MLIRContext *context = op->getContext();
-    Type floatDtype = mlir::FloatType::getF64(context);
+    Type floatDtype = mlir::Float64Type::get(context);
     Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
     Value zero =
         b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
@@ -1028,7 +1028,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type powType = dtype;
     if (payloadArgs[0].getType().isInteger() ||
         payloadArgs[1].getType().isInteger())
-      powType = mlir::FloatType::getF64(op->getContext());
+      powType = mlir::Float64Type::get(op->getContext());
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], powType);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], powType);
     auto powOp = b.create<math::PowFOp>(loc, lhs, rhs);

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 44 additions & 21 deletions

@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
@@ -2252,6 +2253,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op,
                                        "non-const dilation list unsupported");
 
+  TypeAttr accType;
+  if (failed(tosa::getConvOpsAccType(rewriter, inputTy, weightTy, outputTy,
+                                     accType)))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get accumulator type for convolution ops");
+
   // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
   // Perform the necessary transformations.
   std::optional<Value> nchwToNhwcTransposeConst =
@@ -2365,12 +2372,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     // full convolution
     convOpResult =
         rewriter
-            .create<tosa::Conv2DOp>(op->getLoc(),
-                                    getTypeConverter()->convertType(convOpTy),
-                                    transposedInput, transformedWeight, bias,
-                                    rewriter.getDenseI64ArrayAttr(padding),
-                                    rewriter.getDenseI64ArrayAttr(stride),
-                                    rewriter.getDenseI64ArrayAttr(dilation))
+            .create<tosa::Conv2DOp>(
+                op->getLoc(), getTypeConverter()->convertType(convOpTy),
+                transposedInput, transformedWeight, bias,
+                rewriter.getDenseI64ArrayAttr(padding),
+                rewriter.getDenseI64ArrayAttr(stride),
+                rewriter.getDenseI64ArrayAttr(dilation), accType)
             .getResult();
   } else if (weightShape[1] == 1) {
     // depthwise convolution
@@ -2381,7 +2388,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                 transposedInput, transformedWeight, bias,
                 rewriter.getDenseI64ArrayAttr(padding),
                 rewriter.getDenseI64ArrayAttr(stride),
-                rewriter.getDenseI64ArrayAttr(dilation))
+                rewriter.getDenseI64ArrayAttr(dilation), accType)
             .getResult();
   } else {
     llvm_unreachable("Unhandled convolution type");
@@ -3909,9 +3916,11 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
       }
     }
 
-    auto result = rewriter.create<tosa::TileOp>(
-        op->getLoc(), resultType, reshapedInput,
-        rewriter.getDenseI64ArrayAttr(tileOpShape));
+    auto tileOpMultiples =
+        tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShape);
+
+    auto result = rewriter.create<tosa::TileOp>(op->getLoc(), resultType,
+                                                reshapedInput, tileOpMultiples);
 
     rewriter.replaceOp(op, {result.getResult()});
   }
@@ -4104,9 +4113,11 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
         RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape),
                               rewriter.getIntegerType(32));
 
+    auto tileOpMultiples =
+        tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape);
+
     auto expandedIndices = rewriter.create<tosa::TileOp>(
-        op->getLoc(), tileType, reshapedIndices.getResult(),
-        rewriter.getDenseI64ArrayAttr(tileShape));
+        op->getLoc(), tileType, reshapedIndices.getResult(), tileOpMultiples);
 
     // convert torch style index and dim into tf style indices
     // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
@@ -4445,17 +4456,23 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
     if (needsTiling) {
       auto idxType =
           dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
+
       // indicesTfConcatTensors has a trailing [1] dim for the final concat.
       auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
       maxRankMaxDimShapeTf.push_back(1);
+
      auto tileOpShapeTf(tileOpShape);
      tileOpShapeTf.push_back(1);
+
      auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
                                                idxType.getElementType());
      auto reshapedIdxTensor = indicesTfConcatTensors[i];
+
+      auto tileOpMultiples =
+          tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShapeTf);
+
      indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
-          op->getLoc(), tileOutputTy, reshapedIdxTensor,
-          rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
+          op->getLoc(), tileOutputTy, reshapedIdxTensor, tileOpMultiples);
     }
 
     // Every index tensor now has the same rank and shape
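All TileOp creation sites in this file follow the same mechanical rewrite. A condensed before/after sketch, assuming `rewriter`, `loc`, `resultType`, `input`, and `tileOpShape` are in scope:

  // Before: multiples passed as a dense i64 array attribute.
  // rewriter.create<tosa::TileOp>(loc, resultType, input,
  //                               rewriter.getDenseI64ArrayAttr(tileOpShape));

  // After: multiples materialized as a !tosa.shape SSA value.
  auto multiples = tosa::getTosaConstShape(rewriter, loc, tileOpShape);
  auto tiled = rewriter.create<tosa::TileOp>(loc, resultType, input, multiples);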
@@ -6023,12 +6040,14 @@ class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
           op->getLoc(), fillValueMatchedInputRankType, fillValue,
           rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
 
+      auto tileOpMultiples =
+          tosa::getTosaConstShape(rewriter, op->getLoc(), outType.getShape());
+
       fillValueTargetTensor = rewriter.create<tosa::TileOp>(
           op->getLoc(),
           RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
                                 fillValueElemTy),
-          fillValueMatchedInputRankTensor.getResult(),
-          makeShapeTorchCompatible(outType.getShape()));
+          fillValueMatchedInputRankTensor.getResult(), tileOpMultiples);
     } else {
       if (failed(torchScalarToTosaTensor(
               rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
@@ -6179,7 +6198,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
   }
 
   DenseElementsAttr paddingAttr = DenseIntElementsAttr::get(
-      RankedTensorType::get({rank, 2}, rewriter.getI64Type()),
+      RankedTensorType::get({2 * rank}, rewriter.getI64Type()),
       translatePadsList);
 
   Value padsList1 = rewriter.create<mlir::tosa::ConstOp>(
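This hunk reflects TOSA 1.0's new Pad shape requirement: the padding constant is a rank-1 tensor of 2 * rank elements rather than a {rank, 2} matrix. A worked instance for a rank-2 input padded by one on every side (the flattened per-dimension low/high ordering and values are illustrative):

  // Flattened pads for rank 2: [d0_low, d0_high, d1_low, d1_high] = [1, 1, 1, 1].
  int64_t rank = 2;
  SmallVector<int64_t> translatePadsList{1, 1, 1, 1};
  DenseElementsAttr paddingAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({2 * rank}, rewriter.getI64Type()),
      translatePadsList);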
@@ -7836,9 +7855,11 @@ LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewrite(
                             resultType.getElementType()),
       self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));
 
+  auto selfTileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(),
+                                                     resultShapeIndex0Replaced);
+
   auto selfTiled = rewriter.create<tosa::TileOp>(
-      op->getLoc(), resultType, selfReshaped.getResult(),
-      rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));
+      op->getLoc(), resultType, selfReshaped.getResult(), selfTileOpMultiples);
 
   // Reshape and tile vec2 to shape {resultShape[0], vec2Shape[0]}
   auto vec2Reshaped = rewriter.create<tosa::ReshapeOp>(
@@ -7847,9 +7868,11 @@ LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewrite(
                             resultType.getElementType()),
       vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));
 
+  auto vec2TileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(),
+                                                     resultShapeIndex1Replaced);
+
   auto vec2Tiled = rewriter.create<tosa::TileOp>(
-      op->getLoc(), resultType, vec2Reshaped.getResult(),
-      rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));
+      op->getLoc(), resultType, vec2Reshaped.getResult(), vec2TileOpMultiples);
 
   auto result =
       tosa::createMulOpAndCast(rewriter, op, resultType, selfTiled.getResult(),

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 4 additions & 2 deletions

@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 
@@ -566,11 +567,12 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
 
   // [0] -> [0,0,0]
   SmallVector<int64_t, 1> tileShape({W}); // {3}
+  auto tileOpMultiples =
+      tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape);
   auto tosaFillValuesTileOp = tosa::CreateOpAndInfer<tosa::TileOp>(
       rewriter, op->getLoc(),
       GetTypeFromTensorShape(tileShape, fillValuesType.getElementType()),
-      tosaFillValuesOneReshapeOp.getResult(),
-      rewriter.getDenseI64ArrayAttr(tileShape));
+      tosaFillValuesOneReshapeOp.getResult(), tileOpMultiples);
 
   // [0,0,0] -> [[0,0,0]]
   SmallVector<int64_t, 2> newTosaFillValuesShape({N, W}); // {1,3}

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 58 additions & 0 deletions

@@ -454,5 +454,63 @@ LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
   return success();
 }
 
+// Get accumulator type for TOSA convolution ops
+LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
+                                RankedTensorType inputTy,
+                                RankedTensorType weightTy,
+                                RankedTensorType outputTy, TypeAttr &accType) {
+  auto inputElemTy = inputTy.getElementType();
+  auto weightElemTy = weightTy.getElementType();
+  auto outputElemTy = outputTy.getElementType();
+
+  auto quantTy = dyn_cast<quant::QuantizedType>(inputElemTy);
+  if (quantTy)
+    inputElemTy = quantTy.getStorageType();
+
+  // Get TOSA conv ops acc type based on input, weight, and output types
+  // according to the spec:
+  // https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
+  // https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
+  // https://www.mlplatform.org/tosa/tosa_spec.html#_conv3d
+  //
+  // For undefined dtypes in TOSA like I64 and F64, acc_type will be set to the
+  // output type but does not offer any guarantee on the numerical precision
+  // since such cases will fail TOSA validation.
+  if ((inputElemTy.isF32() && weightElemTy.isF32() && outputElemTy.isF32()) ||
+      (inputElemTy.isF16() && weightElemTy.isF16() && outputElemTy.isF16()) ||
+      (inputElemTy.isBF16() && weightElemTy.isBF16() &&
+       outputElemTy.isBF16())) {
+    accType = mlir::TypeAttr::get(rewriter.getF32Type());
+  } else if (inputElemTy.isInteger(8) &&
+             (weightElemTy.isInteger(8) || weightElemTy.isInteger(4)) &&
+             outputElemTy.isInteger(32)) {
+    accType = mlir::TypeAttr::get(rewriter.getIntegerType(32));
+  } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
+             outputElemTy.isInteger(48)) {
+    accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
+  } else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() &&
+              outputElemTy.isF16()) ||
+             (inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() &&
+              outputElemTy.isF16())) {
+    accType = mlir::TypeAttr::get(rewriter.getF16Type());
+  } else {
+    accType = mlir::TypeAttr::get(outputElemTy);
+  }
+
+  return success();
+}
+
+// Temporary function to get TOSA const shape
+// TODO: Remove this function when getTosaConstShape is available in
+// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                        llvm::ArrayRef<int64_t> shape) {
+  auto attr = rewriter.getIndexTensorAttr(shape);
+  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
+  mlir::Operation *mlir_op =
+      rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+  return mlir_op->getResult(0);
+}
+
 } // namespace tosa
 } // namespace mlir
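For reference, the temporary helper just wraps a tosa::ConstShapeOp whose result feeds ops like tosa.tile. A hedged usage sketch, assuming `rewriter` and `loc` are in scope (multiples are illustrative):

  // Builds a !tosa.shape<3> value holding [1, 3, 1].
  Value multiples = tosa::getTosaConstShape(rewriter, loc, {1, 3, 1});
  // Roughly the IR it creates (syntax approximate):
  //   %0 = tosa.const_shape {value = dense<[1, 3, 1]> : tensor<3xindex>}
  //        : () -> !tosa.shape<3>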

lib/Dialect/TMTensor/Transforms/Bufferize.cpp

Lines changed: 49 additions & 1 deletion

@@ -121,6 +121,14 @@ class BufferizeAnyTMTensorOp : public OpInterfaceConversionPattern<TMTensorOp> {
 };
 
 namespace {
+
+static Value materializeToTensor(OpBuilder &builder, TensorType type,
+                                 ValueRange inputs, Location loc) {
+  assert(inputs.size() == 1);
+  assert(isa<BaseMemRefType>(inputs[0].getType()));
+  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
+}
+
 /// Converts TMTensor operations that work on tensor-type operands or results to
 /// work on buffers.
 struct TMTensorBufferizePass
@@ -133,7 +141,47 @@ struct TMTensorBufferizePass
   void runOnOperation() override {
     MLIRContext &context = getContext();
     ConversionTarget target(context);
-    bufferization::BufferizeTypeConverter typeConverter;
+    // Since the `BufferizeTypeConverter` has been removed here
+    // https://github.com/llvm/llvm-project/commit/2ff2e871f5e632ea493efaf4f2192f8b18a54ab1,
+    // hence we have inlined the converter here.
+    TypeConverter typeConverter;
+    typeConverter.addConversion([](Type type) { return type; });
+    // Convert RankedTensorType to MemRefType.
+    typeConverter.addConversion([](RankedTensorType type) -> Type {
+      return MemRefType::get(type.getShape(), type.getElementType());
+    });
+    // Convert UnrankedTensorType to UnrankedMemRefType.
+    typeConverter.addConversion([](UnrankedTensorType type) -> Type {
+      return UnrankedMemRefType::get(type.getElementType(), 0);
+    });
+    typeConverter.addArgumentMaterialization(materializeToTensor);
+    typeConverter.addSourceMaterialization(materializeToTensor);
+    typeConverter.addTargetMaterialization([](OpBuilder &builder,
+                                              BaseMemRefType type,
+                                              ValueRange inputs,
+                                              Location loc) -> Value {
+      assert(inputs.size() == 1 && "expected exactly one input");
+      if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
+        // MemRef to MemRef cast.
+        assert(inputType != type && "expected different types");
+        // Ranked to unranked casts must be explicit.
+        auto rankedDestType = dyn_cast<MemRefType>(type);
+        if (!rankedDestType)
+          return nullptr;
+        bufferization::BufferizationOptions options;
+        options.bufferAlignment = 0;
+        FailureOr<Value> replacement = castOrReallocMemRefValue(
+            builder, inputs[0], rankedDestType, options);
+        if (failed(replacement))
+          return nullptr;
+        return *replacement;
+      }
+      if (isa<TensorType>(inputs[0].getType())) {
+        // Tensor to MemRef cast.
+        return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
+      }
+      llvm_unreachable("only tensor/memref input types supported");
+    });
 
     // Mark all Standard operations legal.
     target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
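The inlined converter is a drop-in replacement for the deleted upstream class; the remainder of runOnOperation keeps wiring it into the patterns and the partial-conversion driver. A rough sketch of that wiring (names follow the BufferizeAnyTMTensorOp pattern defined in this file; the exact lines are below the hunk shown and may differ):

  RewritePatternSet patterns(&context);
  patterns.add<BufferizeAnyTMTensorOp>(typeConverter, &context);
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    signalPassFailure();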

lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp

Lines changed: 1 addition & 2 deletions

@@ -110,8 +110,7 @@ struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
 
     RewritePatternSet patterns(context);
     patterns.insert<ScalarLoopOpInterfaceLowerToLoopsPattern>(context);
-    if (failed(applyPatternsAndFoldGreedily(getOperation(),
-                                            std::move(patterns)))) {
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      return signalPassFailure();
    }
  }
