Skip to content

Commit 1690320

Browse files
Update llvm-project to 5d6d982df61d16b6d498e6d59dd91c059679d3d8

Update stablehlo to b62dc66da9946b4c400c0d99c9d5bb8e04edaee6

Co-authored-by: Justin Ngo <[email protected]>

---------

Signed-off-by: Justin Ngo <[email protected]>
Signed-off-by: Praveen G <[email protected]>
Co-authored-by: Justin Ngo <[email protected]>
1 parent 4b9b972 commit 1690320

File tree

8 files changed

+938
-540
lines changed

8 files changed

+938
-540
lines changed

externals/llvm-project

Submodule llvm-project updated 4066 files

externals/stablehlo

Submodule stablehlo updated 52 files

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,6 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
127127
RankedTensorType weightTy,
128128
RankedTensorType outputTy, TypeAttr &accType);
129129

130-
// Temporary function to get TOSA const shape
131-
// TODO: Remove this function when getTosaConstShape is available in
132-
// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
133-
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
134-
llvm::ArrayRef<int64_t> shape);
135130
} // namespace tosa
136131
} // namespace mlir
137132

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 395 additions & 123 deletions
Large diffs are not rendered by default.

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
351351

352352
// %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) ->
353353
// tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix.
354-
auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
354+
Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
355355
rewriter, op->getLoc(),
356356
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
357357
indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));
@@ -378,13 +378,18 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
378378
if (!flattenedCoeffValue)
379379
return std::nullopt;
380380

381+
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), indicesMatrixReshapeOp,
382+
flattenedCoeffValue.value())
383+
.failed())
384+
return std::nullopt;
385+
381386
// Multiply the coefficients by the coordinates
382387
// %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>,
383388
// tensor<3xi32>) -> tensor<8x3xi32>
384389
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
385390
rewriter, op->getLoc(),
386391
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
387-
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
392+
indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);
388393

389394
// Sum up the products of the coefficients and coordinates
390395
// %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) ->
@@ -616,7 +621,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
616621
// [[0, 1], [0, 2], [0, 3]] -> [[0, 1], [0, 2], [0, 3]]
617622
// %11 = "tosa.reshape"(%8) {new_shape = array<i64: 3, 2>} : (tensor<3x2xi32>)
618623
// -> tensor<3x2xi32>
619-
auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
624+
Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
620625
rewriter, op->getLoc(),
621626
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
622627
indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));
@@ -643,14 +648,19 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
643648
if (!flattenedCoeffValue)
644649
return std::nullopt;
645650

651+
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), indicesMatrixReshapeOp,
652+
flattenedCoeffValue.value())
653+
.failed())
654+
return std::nullopt;
655+
646656
// Multiply the coefficients by the coordinates.
647657
// [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]]
648658
// %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>,
649659
// tensor<2xi32>) -> tensor<3x2xi32>
650660
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
651661
rewriter, op->getLoc(),
652662
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
653-
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
663+
indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);
654664

655665
// Sum up the products of the coefficients and coordinates
656666
// [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
@@ -734,10 +744,20 @@ std::optional<Value> convertReduceOpCommon(
734744
RankedTensorType reduce_type =
735745
RankedTensorType::get(shape_vec, reduce_element_type);
736746

737-
auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
738-
val, axis_attr);
747+
Value reduce_op;
748+
if constexpr (std::is_same<T, tosa::ReduceMinOp>() ||
749+
std::is_same<T, tosa::ReduceMaxOp>()) {
750+
// Use default NaN Propagation mode "PROPAGATE" for tosa.reduce_min
751+
// and tosa.reduce_max
752+
reduce_op = CreateOpAndInfer<T>(
753+
rewriter, op->getLoc(), reduce_type, val, axis_attr,
754+
/*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
755+
} else {
756+
reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
757+
val, axis_attr);
758+
}
739759

740-
val = reduce_op.getResult();
760+
val = reduce_op;
741761
}
742762

743763
if (is_quantized) {
@@ -973,6 +993,12 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
973993

974994
if (!input_is_qtype) {
975995
Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
996+
997+
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), val.value(),
998+
div_const)
999+
.failed())
1000+
return std::nullopt;
1001+
9761002
return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
9771003
val.value(), div_const, 0)
9781004
.getResult();
@@ -1021,6 +1047,11 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
10211047
return std::nullopt;
10221048
}
10231049

1050+
Value ordValRank0 = ordVal;
1051+
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input_value, ordVal)
1052+
.failed())
1053+
return std::nullopt;
1054+
10241055
if (fabs(ordLiteralFloat) < epsilon ||
10251056
fabs(static_cast<double>(ordLiteralInt)) < epsilon) {
10261057
op->emitOpError("unimplemented: L0 norm");
@@ -1049,9 +1080,17 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
10491080
rewriter, op, output_type, powVal, axes_elems, keep_dims);
10501081
if (!result)
10511082
return std::nullopt;
1052-
auto reciprocalVal = CreateOpAndInfer<tosa::ReciprocalOp>(
1053-
rewriter, op->getLoc(), ordVal.getType(), ordVal)
1054-
.getResult();
1083+
1084+
Value reciprocalVal =
1085+
CreateOpAndInfer<tosa::ReciprocalOp>(rewriter, op->getLoc(),
1086+
ordValRank0.getType(), ordValRank0)
1087+
.getResult();
1088+
1089+
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), result.value(),
1090+
reciprocalVal)
1091+
.failed())
1092+
return std::nullopt;
1093+
10551094
return CreateOpAndInfer<tosa::PowOp>(rewriter, op->getLoc(), output_type,
10561095
result.value(), reciprocalVal)
10571096
.getResult();

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
//===----------------------------------------------------------------------===//
99

1010
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
11-
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
11+
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
12+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1213
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
1314

1415
namespace mlir {
@@ -301,31 +302,31 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
301302
(src.isF32() && dest.isInteger(8)) ||
302303
(src.isF32() && dest.isBF16()) ||
303304
(src.isF32() && dest.isF16()) ||
304-
(src.isF32() && dest.isFloat8E4M3()) ||
305-
(src.isF32() && dest.isFloat8E5M2()) ||
305+
(src.isF32() && isa<Float8E4M3Type>(dest)) ||
306+
(src.isF32() && isa<Float8E5M2Type>(dest)) ||
306307
// f16 -> *
307308
(src.isF16() && dest.isInteger(32)) ||
308309
(src.isF16() && dest.isInteger(16)) ||
309310
(src.isF16() && dest.isInteger(8)) ||
310311
(src.isF16() && dest.isBF16()) ||
311312
(src.isF16() && dest.isF32()) ||
312-
(src.isF16() && dest.isFloat8E4M3()) ||
313-
(src.isF16() && dest.isFloat8E5M2()) ||
313+
(src.isF16() && isa<Float8E4M3Type>(dest)) ||
314+
(src.isF16() && isa<Float8E5M2Type>(dest)) ||
314315
// bf16 -> *
315316
(src.isBF16() && dest.isInteger(32)) ||
316317
(src.isBF16() && dest.isInteger(16)) ||
317318
(src.isBF16() && dest.isInteger(8)) ||
318319
(src.isBF16() && dest.isF32()) ||
319-
(src.isBF16() && dest.isFloat8E4M3()) ||
320-
(src.isBF16() && dest.isFloat8E5M2()) ||
320+
(src.isBF16() && isa<Float8E4M3Type>(dest)) ||
321+
(src.isBF16() && isa<Float8E5M2Type>(dest)) ||
321322
// fp8e4m3 -> *
322-
(src.isFloat8E4M3() && dest.isBF16()) ||
323-
(src.isFloat8E4M3() && dest.isF32()) ||
324-
(src.isFloat8E4M3() && dest.isF16()) ||
323+
(isa<Float8E4M3Type>(src) && dest.isBF16()) ||
324+
(isa<Float8E4M3Type>(src) && dest.isF32()) ||
325+
(isa<Float8E4M3Type>(src) && dest.isF16()) ||
325326
// fp8e5m2 -> *
326-
(src.isFloat8E5M2() && dest.isBF16()) ||
327-
(src.isFloat8E5M2() && dest.isF32()) ||
328-
(src.isFloat8E5M2() && dest.isF16())) {
327+
(isa<Float8E5M2Type>(src) && dest.isBF16()) ||
328+
(isa<Float8E5M2Type>(src) && dest.isF32()) ||
329+
(isa<Float8E5M2Type>(src) && dest.isF16())) {
329330
return success();
330331
}
331332
// clang-format on
@@ -393,6 +394,11 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
393394
auto zeroValue =
394395
tosa::getConstTensor<float>(rewriter, op, 0, {}, srcElemTy).value();
395396

397+
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue)
398+
.failed())
399+
return rewriter.notifyMatchFailure(
400+
op, "Failed to equalize ranks among operands and result");
401+
396402
auto boolType = srcType.clone(rewriter.getIntegerType(1));
397403
auto isNegative = tosa::CreateOpAndInfer<tosa::GreaterOp>(
398404
rewriter, op->getLoc(), boolType, zeroValue, src);
@@ -488,10 +494,10 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
488494
} else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
489495
outputElemTy.isInteger(48)) {
490496
accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
491-
} else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() &&
492-
outputElemTy.isF16()) ||
493-
(inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() &&
494-
outputElemTy.isF16())) {
497+
} else if ((isa<Float8E4M3Type>(inputElemTy) &&
498+
isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
499+
(isa<Float8E5M2Type>(inputElemTy) &&
500+
isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
495501
accType = mlir::TypeAttr::get(rewriter.getF16Type());
496502
} else {
497503
accType = mlir::TypeAttr::get(outputElemTy);
@@ -500,17 +506,5 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
500506
return success();
501507
}
502508

503-
// Temporary function to get TOSA const shape
504-
// TODO: Remove this function when getTosaConstShape is available in
505-
// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
506-
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
507-
llvm::ArrayRef<int64_t> shape) {
508-
auto attr = rewriter.getIndexTensorAttr(shape);
509-
auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
510-
mlir::Operation *mlir_op =
511-
rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
512-
return mlir_op->getResult(0);
513-
}
514-
515509
} // namespace tosa
516510
} // namespace mlir

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -652,13 +652,13 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
652652
return rewriter.getF32Type();
653653
if (isa<Float64Type>(inputType))
654654
return rewriter.getF64Type();
655-
if (inputType.isFloat8E5M2())
655+
if (isa<Float8E5M2Type>(inputType))
656656
return rewriter.getF32Type();
657-
if (inputType.isFloat8E4M3FN())
657+
if (isa<Float8E4M3FNType>(inputType))
658658
return rewriter.getF32Type();
659-
if (inputType.isFloat8E5M2FNUZ())
659+
if (isa<Float8E5M2FNUZType>(inputType))
660660
return rewriter.getF32Type();
661-
if (inputType.isFloat8E4M3FNUZ())
661+
if (isa<Float8E4M3FNUZType>(inputType))
662662
return rewriter.getF32Type();
663663
if (inputType.isInteger(8))
664664
// this is an intentional deviation from CUDA (which accumulates i8 to i64)

0 commit comments

Comments (0)