diff --git a/.gitmodules b/.gitmodules
index 8b46098d9615..d3b5f516f86b 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,6 @@
 [submodule "externals/llvm-project"]
 	path = externals/llvm-project
-	url = https://github.com/llvm/llvm-project.git
+	url = https://github.com/iree-org/llvm-project.git
 [submodule "externals/stablehlo"]
 	path = externals/stablehlo
 	url = https://github.com/openxla/stablehlo.git
diff --git a/externals/llvm-project b/externals/llvm-project
index 21473462f762..8007c56fa699 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit 21473462f762a9a2d3140eb8ecaea034f83d9a7c
+Subproject commit 8007c56fa699040c6c921906a5fbfcd5c9bb0953
diff --git a/externals/stablehlo b/externals/stablehlo
index 9018c682b99e..4c0d4841519a 160000
--- a/externals/stablehlo
+++ b/externals/stablehlo
@@ -1 +1 @@
-Subproject commit 9018c682b99eb20d5874a4e38271ce63d7393879
+Subproject commit 4c0d4841519aed22e3689c30b72a0e4228051249
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index be1ea0c3221a..14df4928681a 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -11,6 +11,7 @@
 #define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
 
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"        // from @llvm-project
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"            // from @llvm-project
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"      // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"               // from @llvm-project
@@ -26,8 +27,8 @@ namespace tosa {
 // rounding mode
 Value buildRescale(PatternRewriter &rewriter, Operation *op,
                    ShapedType output_type, Value input_val, double scale,
-                   int64_t input_zp, int64_t output_zp, StringRef rounding_mode,
-                   bool scale32);
+                   int64_t input_zp, int64_t output_zp,
+                   tosa::RoundingMode rounding_mode, bool scale32);
 
 // Creates TOSA rescale op with int32 output
 Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 50056bba5b96..52a8ce73cee1 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -138,7 +138,9 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOpT> {
       // tosa.minimum
       binaryOp = rewriter.create<TosaOpT>(
           op->getLoc(), outTy, lhs, rhs,
-          /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+          /*nan_mode=*/
+          tosa::NanPropagationModeAttr::get(
+              rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
     } else {
       binaryOp =
           tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
@@ -907,7 +909,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
   rewriter.replaceOpWithNewOp<tosa::ClampOp>(
       op, outTy, self, minFloatAttr, maxFloatAttr,
-      /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+      /*nan_mode=*/
+      tosa::NanPropagationModeAttr::get(rewriter.getContext(),
+                                        tosa::NanPropagationMode::PROPAGATE));
 
   return success();
 }
@@ -1237,7 +1241,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
             .create(
                 op->getLoc(), getTypeConverter()->convertType(outputReduceTy),
                 input, reduceDimAttr,
-                /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
+                /*nan_mode=*/
+                tosa::NanPropagationModeAttr::get(
+                    rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE))
             .getResult();
   };
 
@@ -3925,7 +3931,9 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
           op->getLoc(),
           RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
                                 selfElemType),
-          self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+          self, dimAttr, /*nan_mode=*/
+          tosa::NanPropagationModeAttr::get(
+              rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
     } else {
       reduceOp = rewriter.create<TosaOpT>(
           op->getLoc(),
@@ -3946,14 +3954,18 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
           op->getLoc(),
           RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
                                 indicesElemType),
-          negateOp, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+          negateOp, dimAttr, /*nan_mode=*/
+          tosa::NanPropagationModeAttr::get(
+              rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
     } else {
       // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax
       argMaxOp = rewriter.create<tosa::ArgMaxOp>(
           op->getLoc(),
           RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
                                 indicesElemType),
-          self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+          self, dimAttr, /*nan_mode=*/
+          tosa::NanPropagationModeAttr::get(
+              rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
     }
 
     if (argMaxOp.getType() != indicesType) {
@@ -5202,7 +5214,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, outType, adaptor.getSelf(), minIntAttr, maxIntAttr,
-       /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+       /*nan_mode=*/
+       tosa::NanPropagationModeAttr::get(rewriter.getContext(),
+                                         tosa::NanPropagationMode::PROPAGATE));
   } else {
     FloatAttr minFloatAttr, maxFloatAttr;
     if (outElemTy.isF16()) {
@@ -5231,7 +5245,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, outType, adaptor.getSelf(), minFloatAttr, maxFloatAttr,
-       /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+       /*nan_mode=*/
+       tosa::NanPropagationModeAttr::get(rewriter.getContext(),
+                                         tosa::NanPropagationMode::PROPAGATE));
   }
 
   return success();
@@ -5340,13 +5356,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum
   auto minThresholdCheck = rewriter.create<tosa::MaximumOp>(
      op->getLoc(), resultType, self, min,
-     /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+     /*nan_mode=*/
+     tosa::NanPropagationModeAttr::get(rewriter.getContext(),
+                                       tosa::NanPropagationMode::PROPAGATE));
 
   // yi = min(max(xi, min_valuei), max_valuei)
   // Use default NaN Propagation mode "PROPAGATE" for tosa.minimum
   auto result = rewriter.create<tosa::MinimumOp>(
      op->getLoc(), resultType, minThresholdCheck, max,
-     /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+     /*nan_mode=*/
+     tosa::NanPropagationModeAttr::get(rewriter.getContext(),
+                                       tosa::NanPropagationMode::PROPAGATE));
 
   rewriter.replaceOp(op, result);
   return success();
@@ -5934,7 +5954,10 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
       pooledOutput = rewriter
                          .create<TosaOpT>(
                              op->getLoc(), outputTy, input, kernel, stride, pad,
-                             /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
+                             /*nan_mode=*/
+                             tosa::NanPropagationModeAttr::get(
+                                 rewriter.getContext(),
+                                 tosa::NanPropagationMode::PROPAGATE))
                          .getResult();
     } else if constexpr (std::is_same<TosaOpT, tosa::AvgPool2dOp>::value) {
       TypeAttr accType;
@@ -6830,11 +6853,11 @@ ConvertAtenOp::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Only nearest and bilinear interpolation modes supported");
 
-  std::string mode;
+  tosa::ResizeMode mode;
   if (pyMode == "bilinear") {
-    mode = "BILINEAR";
+    mode = tosa::ResizeMode::BILINEAR;
   } else {
-    mode = "NEAREST_NEIGHBOR";
+    mode = tosa::ResizeMode::NEAREST_NEIGHBOR;
   }
 
   bool alignCorners;
@@ -6896,7 +6919,7 @@ ConvertAtenOp::matchAndRewrite(
       offset = 0;
 
     // If nearest neighbours we need to guarantee we round up.
-    if (mode == "NEAREST_NEIGHBOR" && alignCorners) {
+    if (mode == tosa::ResizeMode::NEAREST_NEIGHBOR && alignCorners) {
       offset += n / 2;
     }
 
@@ -6916,7 +6939,8 @@ ConvertAtenOp::matchAndRewrite(
       tosa::getTosaConstShape(rewriter, op->getLoc(), {offset_y, offset_x});
   auto border =
       tosa::getTosaConstShape(rewriter, op->getLoc(), {border_y, border_x});
-  StringAttr modeAttr = rewriter.getStringAttr(mode);
+
+  auto modeAttr = tosa::ResizeModeAttr::get(rewriter.getContext(), mode);
 
   auto resizeOpResult =
       rewriter
@@ -8610,11 +8634,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // Clamp input to [eps, 1 - eps] when eps is not None
   // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp
   if (!isEpsNone) {
-    zi = rewriter
-             .create<tosa::ClampOp>(
-                 op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr,
-                 /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"))
-             .getResult();
+    zi =
+        rewriter
+            .create<tosa::ClampOp>(
+                op->getLoc(), resultType, self, minFloatAttr, maxFloatAttr,
+                /*nan_mode=*/
+                tosa::NanPropagationModeAttr::get(
+                    rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE))
+            .getResult();
   }
 
   auto one =
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index 02d1390ed148..036f0f2e5110 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -764,7 +765,9 @@ std::optional<Value> convertReduceOpCommon(
       // and tosa.reduce_max
       reduce_op = CreateOpAndInfer(
           rewriter, op->getLoc(), reduce_type, val, axis_attr,
-          /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+          /*nan_mode=*/
+          tosa::NanPropagationModeAttr::get(
+              rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE));
     } else {
       reduce_op = CreateOpAndInfer(rewriter, op->getLoc(), reduce_type, val,
                                    axis_attr);
@@ -777,7 +780,7 @@ std::optional<Value> convertReduceOpCommon(
       RankedTensorType output_rescale_type =
           RankedTensorType::get(shape_vec, output_type.getElementType());
       val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
-                         0, output_zp, "SINGLE_ROUND", true);
+                         0, output_zp, tosa::RoundingMode::SINGLE_ROUND, true);
     }
 
     // Optionally squeeze out the reduced axes.
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index bd902d8e2575..3fc11f4fa13f 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -35,8 +35,8 @@ Value buildRescaleMultiplier(bool scale32, PatternRewriter &rewriter,
 // rounding mode
 Value buildRescale(PatternRewriter &rewriter, Operation *op,
                    ShapedType output_type, Value input_val, double scale,
-                   int64_t input_zp, int64_t output_zp, StringRef rounding_mode,
-                   bool scale32) {
+                   int64_t input_zp, int64_t output_zp,
+                   tosa::RoundingMode rounding_mode, bool scale32) {
   int32_t multiplier;
   int32_t shift;
 
@@ -70,7 +70,8 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
   auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
       rewriter, op->getLoc(), output_type, input_val, multiplier_val,
       shift_val, input_zp_val.value(), output_zp_val.value(),
-      rewriter.getBoolAttr(scale32), rewriter.getStringAttr(rounding_mode),
+      rewriter.getBoolAttr(scale32),
+      tosa::RoundingModeAttr::get(rewriter.getContext(), rounding_mode),
       rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
       rewriter.getBoolAttr(output_unsigned));
 
@@ -87,7 +88,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
   auto output_type = input_type.clone(rewriter.getI32Type());
 
   return buildRescale(rewriter, op, output_type, input_val, input_scale,
-                      input_zp, 0, "SINGLE_ROUND", true);
+                      input_zp, 0, tosa::RoundingMode::SINGLE_ROUND, true);
 }
 
 // Creates a TOSA rescale op based on conv2d parameters.
@@ -146,7 +147,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
    auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
        rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
        shift_val, input_zp_val.value(), output_zp_val.value(),
-       rewriter.getBoolAttr(scale32), rewriter.getStringAttr("DOUBLE_ROUND"),
+       rewriter.getBoolAttr(scale32),
+       tosa::RoundingModeAttr::get(rewriter.getContext(),
+                                   tosa::RoundingMode::DOUBLE_ROUND),
        rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
        rewriter.getBoolAttr(output_unsigned));
 
@@ -188,7 +191,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
    auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
        rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
        shift_val, input_zp_val.value(), output_zp_val.value(),
-       rewriter.getBoolAttr(scale32), rewriter.getStringAttr("DOUBLE_ROUND"),
+       rewriter.getBoolAttr(scale32),
+       tosa::RoundingModeAttr::get(rewriter.getContext(),
+                                   tosa::RoundingMode::DOUBLE_ROUND),
        rewriter.getBoolAttr(true), rewriter.getBoolAttr(input_unsigned),
        rewriter.getBoolAttr(output_unsigned));
 
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index a4888a218fae..7c1767b723bf 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -373,21 +373,21 @@ LogicalResult ClassTypeOp::verify() {
 // PrimLoopOp
 //===----------------------------------------------------------------------===//
 
-OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  assert(point == getRegion());
+OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+  assert(successor.getSuccessor() == &getRegion());
   return getIterArgsInit();
 }
 
 void PrimLoopOp::getSuccessorRegions(
     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
   Region &region = getRegion();
-  if (!point.getRegionOrNull()) {
+  if (!point.getTerminatorPredecessorOrNull()) {
     regions.emplace_back(&region, region.getArguments().slice(1));
     return;
   }
-  assert(point == region);
+  assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == &region);
   regions.emplace_back(&region, region.getArguments().slice(1));
-  regions.emplace_back(getResults());
+  regions.emplace_back(getOperation(), getResults());
 }
 
 bool PrimLoopOp::isForLike() {
@@ -400,7 +400,7 @@
 //===----------------------------------------------------------------------===//
 
 MutableOperandRange
-PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+PrimLoopConditionOp::getMutableSuccessorOperands(RegionSuccessor successor) {
   // Pass all operands except the condition to the successor which is the
   // parent loop op.
   return getIterArgsMutable();
@@ -452,8 +452,8 @@ void PrimIfOp::print(OpAsmPrinter &p) {
 void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
   // The `then` and the `else` region branch back to the parent operation.
-  if (point.getRegionOrNull()) {
-    regions.push_back(RegionSuccessor(getResults()));
+  if (point.getTerminatorPredecessorOrNull()) {
+    regions.push_back(RegionSuccessor(getOperation(), getResults()));
     return;
   }
 
@@ -5321,17 +5321,18 @@ template <typename CalculateOp>
 static void
 getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
                                   SmallVectorImpl<RegionSuccessor> &regions) {
-  if (!point.getRegionOrNull()) {
+  if (!point.getTerminatorPredecessorOrNull()) {
     // First thing the op does is branch into the calculation.
     regions.emplace_back(&op.getCalculation());
     return;
   }
-  if (point == op.getBody()) {
+  Region *region = point.getTerminatorPredecessorOrNull()->getParentRegion();
+  if (region == &op.getBody()) {
     // Body returns control to the outer op, passing through results.
-    regions.emplace_back(op.getResults());
+    regions.emplace_back(op.getOperation(), op.getResults());
     return;
   }
-  assert(point == op.getCalculation());
+  assert(region == &op.getCalculation());
   // Calculation branches to the body.
   regions.emplace_back(&op.getBody());
 }
@@ -5355,7 +5356,7 @@ void DtypeCalculateOp::getSuccessorRegions(
 //===----------------------------------------------------------------------===//
 
 MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
-    RegionBranchPoint point) {
+    RegionSuccessor successor) {
   // The shape operands don't get forwarded to the body.
   // MutableOperandRange always has an owning operation, even if empty, so
   // create a 0-length range.
@@ -5846,7 +5847,7 @@ LogicalResult AtenKthvalueOp::verify() {
 //===----------------------------------------------------------------------===//
 
 MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
-    RegionBranchPoint point) {
+    RegionSuccessor successor) {
   // The dtype operands don't get forwarded to the body.
   // MutableOperandRange always has an owning operation, even if empty, so
   // create a 0-length range.
diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp
index 040b5969ce1d..a3d75a928ded 100644
--- a/lib/RefBackend/RefBackend.cpp
+++ b/lib/RefBackend/RefBackend.cpp
@@ -361,7 +361,8 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
     auto func = getOperation();
     auto *context = &getContext();
     RewritePatternSet patterns(context);
-    populateExpandTanhPattern(patterns);
+    math::populateExpansionPatterns(patterns,
+                                    {math::TanhOp::getOperationName()});
     patterns.add<math::ErfPolynomialApproximation>(patterns.getContext());
     ConversionTarget target(*context);
     target.addLegalDialect();
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 4eee8987f089..12f971cc9767 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1557,7 +1557,7 @@ func.func @torch.aten.isclose$basic(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !to
 // CHECK-DAG:       %[[VAL_8:.*]] = tosa.const_shape {values = dense<[4, 2, 4, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG:       %[[VAL_9:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
 // CHECK-DAG:       %[[VAL_10:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2>
-// CHECK:           %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = "BILINEAR"} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32>
+// CHECK:           %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = BILINEAR} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32>
 // CHECK:           %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x270x480x16xf32>) -> tensor<1x16x270x480xf32>
 // CHECK:           %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32>
 // CHECK:           return %[[VAL_13]] : !torch.vtensor<[1,16,270,480],f32>
@@ -1586,7 +1586,7 @@ func.func @torch.aten.__interpolate.size_list_scale_list.bilinear(%arg0: !torch.
 // CHECK-DAG:       %[[VAL_8:.*]] = tosa.const_shape {values = dense<[4, 2, 4, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG:       %[[VAL_9:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
 // CHECK-DAG:       %[[VAL_10:.*]] = tosa.const_shape {values = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2>
-// CHECK:           %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = "NEAREST_NEIGHBOR"} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32>
+// CHECK:           %[[VAL_11:.*]] = tosa.resize %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {mode = NEAREST_NEIGHBOR} : (tensor<1x135x240x16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x270x480x16xf32>
 // CHECK:           %[[VAL_12:.*]] = tosa.transpose %[[VAL_11]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x270x480x16xf32>) -> tensor<1x16x270x480xf32>
 // CHECK:           %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32>
 // CHECK:           return %[[VAL_13]] : !torch.vtensor<[1,16,270,480],f32>