From 52a23231ed1bead691a974ea16e09b457ff26ced Mon Sep 17 00:00:00 2001 From: sachink Date: Thu, 21 Dec 2023 19:58:38 -0800 Subject: [PATCH 01/10] Added op to registry and placeholder funcs --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 12 +++++++++ .../build_tools/abstract_interp_lib_gen.py | 12 +++++++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/constant_alloc.py | 21 +++++++++++++++ 5 files changed, 72 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 6013f6da3cfc..0baf4f1dfd89 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7819,6 +7819,32 @@ def Torch_AtenCosineEmbeddingLossOp : Torch_Op<"aten.cosine_embedding_loss", [ }]; } +def Torch_AtenDiagEmbedOp : Torch_Op<"aten.diag_embed", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::diag_embed : (Tensor, int, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$offset, + Torch_IntType:$dim1, + Torch_IntType:$dim2 + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDiagEmbedOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenDiagEmbedOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 1031f4aa7e53..cfabf62923d6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7552,6 +7552,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.new_empty_strided\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" " return %arg1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.diag_embed\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__._diag_embed_shape_helper(%arg0, %arg2, %arg3) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__._diag_embed_shape_helper(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11197,6 +11205,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.diag_embed\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: 
!torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 338f5e97e100..bc1df6f1c1a5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -53,6 +53,10 @@ def _embedding_bag_helper(weight: List[int], indices: List[int], return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape +# TODO: upstream this +def _diag_embed_shape_helper(self: List[int], dim1: int, dim2: int): + return upstream_shape_functions.unary(self) + def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]: return upstream_shape_functions.unary(self) @@ -807,6 +811,9 @@ def aten〇new_empty〡shape(self: List[int], size: List[int], dtype: Optional[i def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size +def aten〇diag_embed〡shape(self: List[int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> List[int]: + return _diag_embed_shape_helper(self, dim1, dim2) + def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) @@ -3609,6 +3616,11 @@ def aten〇new_empty_strided〡dtype(self_rank_dtype: Tuple[int, int], size: Lis self_rank, self_dtype = self_rank_dtype return self_dtype if dtype is None else dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇diag_embed〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index efee6c852eb4..d37feb8f1579 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -537,6 +537,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)") emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)") emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)") + emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)") # Misc tensor ops. 
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index eb0143b9d06b..c3aa2145613a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -1873,3 +1873,24 @@ def forward(self, a): @register_test_case(module_factory=lambda: EmptyStridedSizeIntStrideModule()) def EmptyStridedSizeIntStrideModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4)) + +# ============================================================================== + + +class AtenDiagEmbedDefaultDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedDefaultDiag()) + def AtenDiagEmbedDefaultDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) \ No newline at end of file From 9cc13939c5e34954949f5310b861d7aa5f3455fe Mon Sep 17 00:00:00 2001 From: sachink Date: Wed, 7 Feb 2024 14:04:23 -0800 Subject: [PATCH 02/10] Completed implementation of diag_embed with unit tests --- include/torch-mlir/Conversion/Utils/Utils.h | 2 + lib/Conversion/TorchToLinalg/DataMovement.cpp | 104 ++++++++++++++++++ lib/Conversion/Utils/Utils.cpp | 27 +++++ .../Transforms/AbstractInterpLibrary.cpp | 85 +++++++++++++- .../build_tools/abstract_interp_lib_gen.py | 29 ++++- .../test_suite/constant_alloc.py | 96 +++++++++++++++- 6 files changed, 335 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index 516954b88fbc..4ca0ce4d2065 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -53,6 +53,8 @@ SmallVector castIndexVectorToInt64Vector(OpBuilder &b, Location loc, SmallVectorImpl &indexValues); +SmallVector getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor, int64_t offset, int64_t dim1, int64_t dim2); + Value getDimOp(OpBuilder &b, Location loc, Value v, int dim); SmallVector getTensorSizesUntilDim(OpBuilder &b, Location loc, diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index dae387422b52..08b6c4b6d72f 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1409,6 +1409,108 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenDiagEmbedOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenDiagEmbedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + + Value input = adaptor.getSelf(); + auto inputType = input.getType().cast(); + auto inputRank = inputType.getRank(); + auto resultRank = inputRank+1; + + int64_t offset; + if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset))) + return rewriter.notifyMatchFailure(op, "offset is not constant"); + + int64_t dim1; + if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) + return rewriter.notifyMatchFailure(op, "dim1 is not constant"); + dim1 = toPositiveDim(dim1, resultRank); + if (!isValidDim(dim1, resultRank)) + return 
rewriter.notifyMatchFailure(op, "dim1 can only be between [" + std::to_string(-resultRank) + "," + std::to_string(resultRank-1) + "]"); + + int64_t dim2; + if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) + return rewriter.notifyMatchFailure(op, "dim2 is not constant"); + dim2 = toPositiveDim(dim2, resultRank); + if (!isValidDim(dim2, resultRank)) + return rewriter.notifyMatchFailure(op, "dim2 can only be between [" + std::to_string(-resultRank) + "," + std::to_string(resultRank-1) + "]"); + + if(dim1 == dim2) + return rewriter.notifyMatchFailure(op, "dim1 and dim2 can not be equal"); + + // add linalg.fill + Type resultElemType = inputType.getElementType(); + auto resultShape = getDiagEmbedResultShape(rewriter, loc, input, offset, dim1, dim2); + Value zeroTensor = + createZeroInitTensor(rewriter, loc, resultShape, resultElemType); + + // add linalg.generic with diagonal access pattern affine indexing maps + SmallVector indexingMaps = { + rewriter.getMultiDimIdentityMap(resultRank), + }; + SmallVector iteratorTypes( + resultRank, utils::IteratorType::parallel); + Value resultTensor = + rewriter + .create( + loc, zeroTensor.getType(), + ValueRange{}, zeroTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value dim1Index = b.create(loc, dim1); + Value dim2Index = b.create(loc, dim2); + + // to pick right element from input, first add all dimensions except last one, then last will be either dim1 or dim2 depending upon lower or upper diagonal defined by offset sign + SmallVector inputIndices; + for(unsigned int i=0; i < resultRank; i++) { + if (i != dim1 && i != dim2) { + inputIndices.push_back(b.create(loc, i)); + } + } + + // adjust output diagonal indices and last input Index based on offset + Value dim1IdxAdjusted; + Value dim2IdxAdjusted; + if (offset < 0) { + Value absOffset = b.create(loc, -offset); + dim1IdxAdjusted = dim1Index; + dim2IdxAdjusted = b.create(loc, dim2Index, absOffset); + inputIndices.push_back(b.create(loc, dim2)); + } + else { + Value constOffset = b.create(loc, offset); + dim1IdxAdjusted = b.create(loc, dim1Index, constOffset); + dim2IdxAdjusted = dim2Index; + inputIndices.push_back(b.create(loc, dim1)); + } + + Value isDiagonal = b.create(loc, arith::CmpIPredicate::eq, dim1IdxAdjusted, dim2IdxAdjusted); + + Value inputElem = b.create(loc, resultElemType, input, inputIndices); + + Value result = rewriter.create(loc, isDiagonal, inputElem, args[0]); + b.create(loc, result); + }) + .getResult(0); + + RankedTensorType resultType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + rewriter.replaceOpWithNewOp(op, resultType, resultTensor); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1443,4 +1545,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 3df9da94b735..bfa5332bbdc1 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -166,6 +166,33 @@ castIndexVectorToInt64Vector(OpBuilder &b, Location loc, return intValues; } +SmallVector getDiagEmbedResultShape(OpBuilder &b, Location loc, Value 
tensor, int64_t offset, int64_t dim1, int64_t dim2) { + auto inputType = tensor.getType().cast(); + auto inputRank = inputType.getRank(); + auto resultRank = inputRank + 1; + + SmallVector resultShape; + Value constZero = b.create(loc, 0); + Value constNegOne = b.create(loc, -1); + Value constOffset = b.create(loc, offset); + Value isNegOffset = b.create(loc, arith::CmpIPredicate::slt, constOffset, constZero); + Value mulOffsetNegOne = b.create(loc, constOffset, constNegOne); + Value absOffset = b.create(loc, isNegOffset, mulOffsetNegOne, constOffset); + + auto lastInputDim = getDimOp(b, loc, tensor, inputRank-1); + Value diagDim = b.create(loc, lastInputDim, absOffset); + + int input_dim_idx = 0; + for (unsigned int i = 0; i < resultRank; i++) { + if (i == dim1 || i == dim2) + resultShape.push_back(diagDim); + else + resultShape.push_back(getDimOp(b, loc, tensor, input_dim_idx++)); + } + + return resultShape; +} + Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) { return b.createOrFold(loc, v, dim); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index cfabf62923d6..376844b339da 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7553,12 +7553,89 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %arg1 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.diag_embed\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" -" %0 = call @__torch__._diag_embed_shape_helper(%arg0, %arg2, %arg3) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" +" %0 = call @__torch__._diag_embed_shape_helper(%arg0, %arg1, %arg2, %arg3) : (!torch.list, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @__torch__._diag_embed_shape_helper(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" +" func.func @__torch__._diag_embed_shape_helper(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %2 = torch.aten.ne.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.lt.int %arg2, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n" +" %5 = torch.aten.ge.int %arg2, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.lt.int %arg3, %1 : !torch.int, 
!torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n" +" %8 = torch.aten.ge.int %arg3, %7 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %15 = torch.aten.add.int %1, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %15 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg2 : !torch.int\n" +" }\n" +" %11 = torch.aten.lt.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" %15 = torch.aten.add.int %1, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %15 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg3 : !torch.int\n" +" }\n" +" %13 = torch.prim.ListConstruct : () -> !torch.list\n" +" %14 = torch.prim.Loop %1, %true, init(%int0) {\n" +" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n" +" %15 = torch.prim.ListConstruct %10, %12 : (!torch.int, !torch.int) -> !torch.list\n" +" %16 = torch.aten.__contains__.int_list %15, %arg4 : !torch.list, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.int) {\n" +" %18 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.operator \"prim.abs.int\"(%arg1) : (!torch.int) -> !torch.int\n" +" %20 = torch.aten.add.int %18, %19 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.append.t %13, %20 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %arg5 : !torch.int\n" +" } else {\n" +" %18 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.append.t %13, %18 : !torch.list, !torch.int -> !torch.list\n" +" %20 = torch.aten.add.int %arg5, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%17 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" return %13 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index bc1df6f1c1a5..df32436df722 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -54,8 +54,31 @@ def _embedding_bag_helper(weight: List[int], indices: List[int], return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape # TODO: upstream this -def _diag_embed_shape_helper(self: List[int], dim1: int, dim2: int): - return upstream_shape_functions.unary(self) +def _diag_embed_shape_helper(self: List[int], offset: int, dim1: int, dim2: int): + self_rank = 
len(self) + result_rank = self_rank + 1 + + assert dim1 != dim2 + assert dim1 < result_rank + assert dim1 >= -(result_rank) + assert dim2 < result_rank + assert dim2 >= -(result_rank) + + if dim1 < 0: + dim1 = result_rank + dim1 + if dim2 < 0: + dim2 = result_rank + dim2 + + result_shape: List[int] = [] + input_dim_idx = 0 + for i in range(result_rank): + if i in (dim1, dim2): + result_shape.append(self[-1] + abs(offset)) + else: + result_shape.append(self[input_dim_idx]) + input_dim_idx += 1 + + return result_shape def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]: return upstream_shape_functions.unary(self) @@ -812,7 +835,7 @@ def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: L return size def aten〇diag_embed〡shape(self: List[int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> List[int]: - return _diag_embed_shape_helper(self, dim1, dim2) + return _diag_embed_shape_helper(self, offset, dim1, dim2) def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index c3aa2145613a..4fc9bbcf61cc 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -1893,4 +1893,98 @@ def forward(self, a): @register_test_case(module_factory=lambda: AtenDiagEmbedDefaultDiag()) def AtenDiagEmbedDefaultDiag_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4)) \ No newline at end of file + module.forward(tu.rand(2, 3, 4)) + + +class AtenDiagEmbedDimDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a, offset=0, dim1=1, dim2=3) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedDimDiag()) + def AtenDiagEmbedDimDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + + +class AtenDiagEmbedOffsetDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a, offset=1, dim1=1, dim2=3) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedOffsetDiag()) + def AtenDiagEmbedOffsetDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + + +class AtenDiagEmbedRevDimDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a, offset=1, dim1=3, dim2=1) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedRevDimDiag()) + def AtenDiagEmbedRevDimDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + + +class AtenDiagEmbedNegOffsetDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a, offset=-1, dim1=1, dim2=3) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedNegOffsetDiag()) + def 
AtenDiagEmbedNegOffsetDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + +class AtenDiagEmbedNonDefault4DDiag(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diag_embed(a, offset=-2, dim1=2, dim2=-2) + + + @register_test_case(module_factory=lambda: AtenDiagEmbedNonDefault4DDiag()) + def AtenDiagEmbedNonDefault4DDiag_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4, 5)) \ No newline at end of file From 0071e8b05ead2784353942acf6678c76ed139920 Mon Sep 17 00:00:00 2001 From: sachink Date: Wed, 7 Feb 2024 17:13:59 -0800 Subject: [PATCH 03/10] Linter fixes --- include/torch-mlir/Conversion/Utils/Utils.h | 4 +- lib/Conversion/TorchToLinalg/DataMovement.cpp | 121 ++++++++++-------- lib/Conversion/Utils/Utils.cpp | 19 ++- .../test_suite/constant_alloc.py | 2 +- 4 files changed, 85 insertions(+), 61 deletions(-) diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index 4ca0ce4d2065..a46290d2d5bf 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -53,7 +53,9 @@ SmallVector castIndexVectorToInt64Vector(OpBuilder &b, Location loc, SmallVectorImpl &indexValues); -SmallVector getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor, int64_t offset, int64_t dim1, int64_t dim2); +SmallVector getDiagEmbedResultShape(OpBuilder &b, Location loc, + Value tensor, int64_t offset, + int64_t dim1, int64_t dim2); Value getDimOp(OpBuilder &b, Location loc, Value v, int dim); diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 2ab4efa320b5..513712f2cad7 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -2001,20 +2001,19 @@ class ConvertAtenDiagonalOp : public OpConversionPattern { } // namespace namespace { -class ConvertAtenDiagEmbedOp - : public OpConversionPattern { +class ConvertAtenDiagEmbedOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenDiagEmbedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - + Location loc = op->getLoc(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); - auto resultRank = inputRank+1; + auto resultRank = inputRank + 1; int64_t offset; if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset))) @@ -2025,23 +2024,28 @@ class ConvertAtenDiagEmbedOp return rewriter.notifyMatchFailure(op, "dim1 is not constant"); dim1 = toPositiveDim(dim1, resultRank); if (!isValidDim(dim1, resultRank)) - return rewriter.notifyMatchFailure(op, "dim1 can only be between [" + std::to_string(-resultRank) + "," + std::to_string(resultRank-1) + "]"); + return rewriter.notifyMatchFailure( + op, "dim1 can only be between [" + std::to_string(-resultRank) + "," + + std::to_string(resultRank - 1) + "]"); int64_t dim2; if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) return rewriter.notifyMatchFailure(op, "dim2 is not constant"); dim2 = toPositiveDim(dim2, resultRank); if (!isValidDim(dim2, resultRank)) - return rewriter.notifyMatchFailure(op, "dim2 can only be between [" + std::to_string(-resultRank) + "," + std::to_string(resultRank-1) + "]"); + return rewriter.notifyMatchFailure( + op, "dim2 can only be between [" + 
std::to_string(-resultRank) + "," + + std::to_string(resultRank - 1) + "]"); - if(dim1 == dim2) + if (dim1 == dim2) return rewriter.notifyMatchFailure(op, "dim1 and dim2 can not be equal"); // add linalg.fill Type resultElemType = inputType.getElementType(); - auto resultShape = getDiagEmbedResultShape(rewriter, loc, input, offset, dim1, dim2); - Value zeroTensor = - createZeroInitTensor(rewriter, loc, resultShape, resultElemType); + auto resultShape = + getDiagEmbedResultShape(rewriter, loc, input, offset, dim1, dim2); + Value zeroTensor = + createZeroInitTensor(rewriter, loc, resultShape, resultElemType); // add linalg.generic with diagonal access pattern affine indexing maps SmallVector indexingMaps = { @@ -2050,52 +2054,65 @@ class ConvertAtenDiagEmbedOp SmallVector iteratorTypes( resultRank, utils::IteratorType::parallel); Value resultTensor = - rewriter - .create( - loc, zeroTensor.getType(), - ValueRange{}, zeroTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value dim1Index = b.create(loc, dim1); - Value dim2Index = b.create(loc, dim2); - - // to pick right element from input, first add all dimensions except last one, then last will be either dim1 or dim2 depending upon lower or upper diagonal defined by offset sign - SmallVector inputIndices; - for(unsigned int i=0; i < resultRank; i++) { - if (i != dim1 && i != dim2) { - inputIndices.push_back(b.create(loc, i)); - } - } - - // adjust output diagonal indices and last input Index based on offset - Value dim1IdxAdjusted; - Value dim2IdxAdjusted; - if (offset < 0) { - Value absOffset = b.create(loc, -offset); - dim1IdxAdjusted = dim1Index; - dim2IdxAdjusted = b.create(loc, dim2Index, absOffset); - inputIndices.push_back(b.create(loc, dim2)); - } - else { - Value constOffset = b.create(loc, offset); - dim1IdxAdjusted = b.create(loc, dim1Index, constOffset); - dim2IdxAdjusted = dim2Index; - inputIndices.push_back(b.create(loc, dim1)); - } - - Value isDiagonal = b.create(loc, arith::CmpIPredicate::eq, dim1IdxAdjusted, dim2IdxAdjusted); - - Value inputElem = b.create(loc, resultElemType, input, inputIndices); - - Value result = rewriter.create(loc, isDiagonal, inputElem, args[0]); - b.create(loc, result); - }) - .getResult(0); + rewriter + .create( + loc, zeroTensor.getType(), ValueRange{}, zeroTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value dim1Index = b.create(loc, dim1); + Value dim2Index = b.create(loc, dim2); + + // to pick right element from input, first add all dimensions + // except last one, then last will be either dim1 or dim2 + // depending upon lower or upper diagonal defined by offset + // sign + SmallVector inputIndices; + for (unsigned int i = 0; i < resultRank; i++) { + if (i != dim1 && i != dim2) { + inputIndices.push_back(b.create(loc, i)); + } + } + + // adjust output diagonal indices and last input Index based + // on offset + Value dim1IdxAdjusted; + Value dim2IdxAdjusted; + if (offset < 0) { + Value absOffset = + b.create(loc, -offset); + dim1IdxAdjusted = dim1Index; + dim2IdxAdjusted = + b.create(loc, dim2Index, absOffset); + inputIndices.push_back( + b.create(loc, dim2)); + } else { + Value constOffset = + b.create(loc, offset); + dim1IdxAdjusted = + b.create(loc, dim1Index, constOffset); + dim2IdxAdjusted = dim2Index; + inputIndices.push_back( + b.create(loc, dim1)); + } + + Value isDiagonal = + b.create(loc, arith::CmpIPredicate::eq, + 
dim1IdxAdjusted, dim2IdxAdjusted); + + Value inputElem = b.create( + loc, resultElemType, input, inputIndices); + + Value result = rewriter.create( + loc, isDiagonal, inputElem, args[0]); + b.create(loc, result); + }) + .getResult(0); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); + rewriter.replaceOpWithNewOp(op, resultType, resultTensor); return success(); } diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index bfa5332bbdc1..3f8779642b6b 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -166,20 +166,25 @@ castIndexVectorToInt64Vector(OpBuilder &b, Location loc, return intValues; } -SmallVector getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor, int64_t offset, int64_t dim1, int64_t dim2) { +SmallVector getDiagEmbedResultShape(OpBuilder &b, Location loc, + Value tensor, int64_t offset, + int64_t dim1, int64_t dim2) { auto inputType = tensor.getType().cast(); auto inputRank = inputType.getRank(); auto resultRank = inputRank + 1; - + SmallVector resultShape; Value constZero = b.create(loc, 0); Value constNegOne = b.create(loc, -1); Value constOffset = b.create(loc, offset); - Value isNegOffset = b.create(loc, arith::CmpIPredicate::slt, constOffset, constZero); - Value mulOffsetNegOne = b.create(loc, constOffset, constNegOne); - Value absOffset = b.create(loc, isNegOffset, mulOffsetNegOne, constOffset); - - auto lastInputDim = getDimOp(b, loc, tensor, inputRank-1); + Value isNegOffset = b.create(loc, arith::CmpIPredicate::slt, + constOffset, constZero); + Value mulOffsetNegOne = + b.create(loc, constOffset, constNegOne); + Value absOffset = + b.create(loc, isNegOffset, mulOffsetNegOne, constOffset); + + auto lastInputDim = getDimOp(b, loc, tensor, inputRank - 1); Value diagDim = b.create(loc, lastInputDim, absOffset); int input_dim_idx = 0; diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index 4fc9bbcf61cc..540fa2d2204d 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -1982,7 +1982,7 @@ def __init__(self): ([-1, -1, -1, -1], torch.float32, True), ]) def forward(self, a): - return torch.ops.aten.diag_embed(a, offset=-2, dim1=2, dim2=-2) + return torch.ops.aten.diag_embed(a, offset=-2, dim1=1, dim2=-3) @register_test_case(module_factory=lambda: AtenDiagEmbedNonDefault4DDiag()) From 1bac2e8e27658afe2c447c076a6ff36edd0ffa5f Mon Sep 17 00:00:00 2001 From: sachink Date: Wed, 7 Feb 2024 21:23:55 -0800 Subject: [PATCH 04/10] Review fixes --- include/torch-mlir/Conversion/Utils/Utils.h | 4 --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 33 +++++++++++++++++++ lib/Conversion/Utils/Utils.cpp | 32 ------------------ .../build_tools/abstract_interp_lib_gen.py | 12 ++++++- 4 files changed, 44 insertions(+), 37 deletions(-) diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index a46290d2d5bf..516954b88fbc 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -53,10 +53,6 @@ SmallVector castIndexVectorToInt64Vector(OpBuilder &b, Location loc, SmallVectorImpl &indexValues); -SmallVector getDiagEmbedResultShape(OpBuilder &b, Location loc, - Value tensor, int64_t offset, - int64_t dim1, int64_t dim2); - Value getDimOp(OpBuilder &b, Location loc, Value v, int 
dim); SmallVector getTensorSizesUntilDim(OpBuilder &b, Location loc, diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 513712f2cad7..0a775b4644d4 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" @@ -2002,6 +2003,38 @@ class ConvertAtenDiagonalOp : public OpConversionPattern { namespace { class ConvertAtenDiagEmbedOp : public OpConversionPattern { + + static SmallVector + getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor, + int64_t offset, int64_t dim1, int64_t dim2) { + auto inputType = tensor.getType().cast(); + auto inputRank = inputType.getRank(); + + // output tensor always has 1 extra dimension + auto resultRank = inputRank + 1; + + // regardless of offset sign, output tensor is same + Value constOffset = b.create(loc, offset); + Value absOffset = b.create(loc, constOffset); + + // diagonal size is determined by last input dimension + auto lastInputDim = getDimOp(b, loc, tensor, inputRank - 1); + Value diagDim = b.create(loc, lastInputDim, absOffset); + + // output shape has same dimensions as input + // except for the diagonal dimensions + int input_dim_idx = 0; + SmallVector resultShape; + for (unsigned int i = 0; i < resultRank; i++) { + if (i == dim1 || i == dim2) + resultShape.push_back(diagDim); + else + resultShape.push_back(getDimOp(b, loc, tensor, input_dim_idx++)); + } + + return resultShape; + } + public: using OpConversionPattern::OpConversionPattern; LogicalResult diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 3f8779642b6b..3df9da94b735 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -166,38 +166,6 @@ castIndexVectorToInt64Vector(OpBuilder &b, Location loc, return intValues; } -SmallVector getDiagEmbedResultShape(OpBuilder &b, Location loc, - Value tensor, int64_t offset, - int64_t dim1, int64_t dim2) { - auto inputType = tensor.getType().cast(); - auto inputRank = inputType.getRank(); - auto resultRank = inputRank + 1; - - SmallVector resultShape; - Value constZero = b.create(loc, 0); - Value constNegOne = b.create(loc, -1); - Value constOffset = b.create(loc, offset); - Value isNegOffset = b.create(loc, arith::CmpIPredicate::slt, - constOffset, constZero); - Value mulOffsetNegOne = - b.create(loc, constOffset, constNegOne); - Value absOffset = - b.create(loc, isNegOffset, mulOffsetNegOne, constOffset); - - auto lastInputDim = getDimOp(b, loc, tensor, inputRank - 1); - Value diagDim = b.create(loc, lastInputDim, absOffset); - - int input_dim_idx = 0; - for (unsigned int i = 0; i < resultRank; i++) { - if (i == dim1 || i == dim2) - resultShape.push_back(diagDim); - else - resultShape.push_back(getDimOp(b, loc, tensor, input_dim_idx++)); - } - - return resultShape; -} - Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) { return b.createOrFold(loc, v, dim); } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 88bb33290bf6..d622b74ed786 100644 --- 
a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -53,7 +53,6 @@ def _embedding_bag_helper(weight: List[int], indices: List[int], return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape -# TODO: upstream this def _diag_embed_shape_helper(self: List[int], offset: int, dim1: int, dim2: int): self_rank = len(self) result_rank = self_rank + 1 @@ -1033,6 +1032,17 @@ def aten〇new_empty〡shape(self: List[int], size: List[int], dtype: Optional[i def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size +@check_shape_function([ + Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=3), # Test explicit dim1 and dim2. + Invocation(TensorOfShape(2, 3, 4), offset=1, dim1=1, dim2=3), # Positive offset. + Invocation(TensorOfShape(2, 3, 4), offset=1, dim1=3, dim2=1), # Reverse dim1 and dim2 + Invocation(TensorOfShape(2, 3, 4), offset=-1, dim1=1, dim2=3), # Negative offset + Invocation(TensorOfShape(2, 3, 4), offset=3), # large `offset`. + ErrorInvocation(TensorOfShape(2)), # Input one-dimensional. + ErrorInvocation(TensorOfShape(2, 3, 4), dim1=1, dim2=1), # `dim1` and `dim2` equal. + ErrorInvocation(TensorOfShape(2, 3, 4), dim1=4, dim2=1), # `dim1` out of bounds. +]) def aten〇diag_embed〡shape(self: List[int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> List[int]: return _diag_embed_shape_helper(self, offset, dim1, dim2) From 36023d6ebc892dec9f62fc249b6bea7336873acd Mon Sep 17 00:00:00 2001 From: sachink Date: Wed, 7 Feb 2024 23:55:16 -0800 Subject: [PATCH 05/10] more linter fixes --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 50 +++++++++---------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 0a775b4644d4..1c8f85899b72 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -2077,7 +2077,7 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { Type resultElemType = inputType.getElementType(); auto resultShape = getDiagEmbedResultShape(rewriter, loc, input, offset, dim1, dim2); - Value zeroTensor = + Value zeroTensor = createZeroInitTensor(rewriter, loc, resultShape, resultElemType); // add linalg.generic with diagonal access pattern affine indexing maps @@ -2087,34 +2087,34 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { SmallVector iteratorTypes( resultRank, utils::IteratorType::parallel); Value resultTensor = - rewriter - .create( + rewriter + .create( loc, zeroTensor.getType(), ValueRange{}, zeroTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value dim1Index = b.create(loc, dim1); - Value dim2Index = b.create(loc, dim2); - + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value dim1Index = b.create(loc, dim1); + Value dim2Index = b.create(loc, dim2); + // to pick right element from input, first add all dimensions // except last one, then last will be either dim1 or dim2 // depending upon lower or upper diagonal defined by offset // sign - SmallVector inputIndices; + 
SmallVector inputIndices; for (unsigned int i = 0; i < resultRank; i++) { if (i != dim1 && i != dim2) { - inputIndices.push_back(b.create(loc, i)); - } - } - + inputIndices.push_back(b.create(loc, i)); + } + } + // adjust output diagonal indices and last input Index based // on offset - Value dim1IdxAdjusted; - Value dim2IdxAdjusted; - if (offset < 0) { + Value dim1IdxAdjusted; + Value dim2IdxAdjusted; + if (offset < 0) { Value absOffset = b.create(loc, -offset); - dim1IdxAdjusted = dim1Index; + dim1IdxAdjusted = dim1Index; dim2IdxAdjusted = b.create(loc, dim2Index, absOffset); inputIndices.push_back( @@ -2124,23 +2124,23 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { b.create(loc, offset); dim1IdxAdjusted = b.create(loc, dim1Index, constOffset); - dim2IdxAdjusted = dim2Index; + dim2IdxAdjusted = dim2Index; inputIndices.push_back( b.create(loc, dim1)); - } - + } + Value isDiagonal = b.create(loc, arith::CmpIPredicate::eq, dim1IdxAdjusted, dim2IdxAdjusted); - + Value inputElem = b.create( loc, resultElemType, input, inputIndices); Value result = rewriter.create( loc, isDiagonal, inputElem, args[0]); - b.create(loc, result); - }) - .getResult(0); + b.create(loc, result); + }) + .getResult(0); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) From d72c1daacb5bd3fc98d08050c188c2047cc3e23a Mon Sep 17 00:00:00 2001 From: sachink Date: Thu, 21 Mar 2024 14:52:47 -0700 Subject: [PATCH 06/10] linter fixes: tool worked this time! --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 2 +- lib/Conversion/TorchToLinalg/Linear.cpp | 22 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 1c8f85899b72..698d4c9c3c93 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -2135,7 +2135,7 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { Value inputElem = b.create( loc, resultElemType, input, inputIndices); - + Value result = rewriter.create( loc, isDiagonal, inputElem, args[0]); b.create(loc, result); diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 3557b27a2eb2..1d6e0a7f36a3 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -882,14 +882,14 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { if (bias.getType().isa()) { Value c0; if (resultDTy.isa()) { - c0 = rewriter.create( - loc, FloatAttr::get(resultDTy, 0.0)); + c0 = rewriter.create(loc, + FloatAttr::get(resultDTy, 0.0)); } else if (resultDTy.isa()) { - c0 = rewriter.create( - loc, IntegerAttr::get(resultDTy, 0)); + c0 = rewriter.create(loc, + IntegerAttr::get(resultDTy, 0)); } - outputTensor = rewriter.create(loc, c0, initTensor) - .getResult(0); + outputTensor = + rewriter.create(loc, c0, initTensor).getResult(0); } else { auto biasType = bias.getType().cast(); @@ -1058,11 +1058,11 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { loc, collapsedType, weight, collapsedDims); conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) - .getResult(0); + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); 
rewriter.replaceOpWithNewOp(op, newResultType, conv); From 4fd096326a958c4b0713c5cf30e285cb0bca64ed Mon Sep 17 00:00:00 2001 From: sachink Date: Thu, 21 Mar 2024 14:59:22 -0700 Subject: [PATCH 07/10] Review fix: Updated message --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 02c3533cbce2..61607333797c 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -2152,7 +2152,8 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { dim1 = toPositiveDim(dim1, resultRank); if (!isValidDim(dim1, resultRank)) return rewriter.notifyMatchFailure( - op, "dim1 can only be between [" + std::to_string(-resultRank) + "," + + op, "dim1 can only be in closed range [" + + std::to_string(-resultRank) + "," + std::to_string(resultRank - 1) + "]"); int64_t dim2; From 21296b8de45ecc936c0edc60a477c4fb8ddfbdc9 Mon Sep 17 00:00:00 2001 From: sachink Date: Thu, 21 Mar 2024 19:05:32 -0700 Subject: [PATCH 08/10] Added tests to ONNX e2e test exceptions --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 3 ++- projects/pt1/e2e_testing/xfail_sets.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 61607333797c..d2953ec22870 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -2162,7 +2162,8 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { dim2 = toPositiveDim(dim2, resultRank); if (!isValidDim(dim2, resultRank)) return rewriter.notifyMatchFailure( - op, "dim2 can only be between [" + std::to_string(-resultRank) + "," + + op, "dim2 can only be in closed range [" + + std::to_string(-resultRank) + "," + std::to_string(resultRank - 1) + "]"); if (dim1 == dim2) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7f8b54dae147..aa9b742168e4 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1862,6 +1862,12 @@ "DiagonalModule_with_dims_and_offset", "DiagonalModule_with_negative_dims", "DiagonalModule_with_offset", + "AtenDiagEmbedDefaultDiag_basic", + "AtenDiagEmbedDimDiag_basic", + "AtenDiagEmbedOffsetDiag_basic", + "AtenDiagEmbedRevDimDiag_basic", + "AtenDiagEmbedNegOffsetDiag_basic", + "AtenDiagEmbedNonDefault4DDiag_basic", "ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMinModuleIncludeSelf", "ScatterReduceFloatProdModuleIncludeSelf", From 4a929b64653b57097783bcf82e8e716887b4d864 Mon Sep 17 00:00:00 2001 From: sachink Date: Fri, 22 Mar 2024 10:50:24 -0700 Subject: [PATCH 09/10] Fix: 1D tensor is a valid test case --- .../jit_ir_importer/build_tools/abstract_interp_lib_gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index d9de402b27df..60a12b8fb396 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1081,7 +1081,7 @@ def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: L Invocation(TensorOfShape(2, 3, 4), offset=1, dim1=3, dim2=1), # 
Reverse dim1 and dim2 Invocation(TensorOfShape(2, 3, 4), offset=-1, dim1=1, dim2=3), # Negative offset Invocation(TensorOfShape(2, 3, 4), offset=3), # large `offset`. - ErrorInvocation(TensorOfShape(2)), # Input one-dimensional. + Invocation(TensorOfShape(2)), # Input one-dimensional. ErrorInvocation(TensorOfShape(2, 3, 4), dim1=1, dim2=1), # `dim1` and `dim2` equal. ErrorInvocation(TensorOfShape(2, 3, 4), dim1=4, dim2=1), # `dim1` out of bounds. ]) From 568fbb0c54791e0d2ab5baf06c5637500e37ad96 Mon Sep 17 00:00:00 2001 From: sachink Date: Fri, 22 Mar 2024 14:56:55 -0700 Subject: [PATCH 10/10] Ran update_abstract_interp_lib.sh --- lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 68cf5a05467e..920383ba5324 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8305,7 +8305,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %16 = torch.aten.__contains__.int_list %15, %arg4 : !torch.list, !torch.int -> !torch.bool\n" " %17 = torch.prim.If %16 -> (!torch.int) {\n" " %18 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" -" %19 = torch.operator \"prim.abs.int\"(%arg1) : (!torch.int) -> !torch.int\n" +" %19 = torch.operator \"prim.abs.int\"(%arg1) : (!torch.int) -> !torch.int \n" " %20 = torch.aten.add.int %18, %19 : !torch.int, !torch.int -> !torch.int\n" " %21 = torch.aten.append.t %13, %20 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.If.yield %arg5 : !torch.int\n"
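
A quick cross-check of the shape rule this series converges on: `_diag_embed_shape_helper` (added as a placeholder in patch 01, completed in patch 02, and annotated/relaxed in patches 04 and 09) grows the two diagonal dimensions to `self[-1] + abs(offset)` and threads the remaining input dimensions through in order, after normalizing negative `dim1`/`dim2` against the rank of the *result* (input rank + 1). The standalone Python sketch below restates that rule and verifies it against eager PyTorch (assuming PyTorch is installed); `diag_embed_shape` is an illustrative name introduced here, not a function from the patch.

    from typing import List

    import torch


    def diag_embed_shape(self: List[int], offset: int = 0,
                         dim1: int = -2, dim2: int = -1) -> List[int]:
        # The result always has one more dimension than the input.
        result_rank = len(self) + 1
        # Negative dims are normalized against the *result* rank,
        # mirroring the helper in abstract_interp_lib_gen.py.
        dim1 = dim1 + result_rank if dim1 < 0 else dim1
        dim2 = dim2 + result_rank if dim2 < 0 else dim2
        assert dim1 != dim2 and 0 <= dim1 < result_rank and 0 <= dim2 < result_rank
        # The last input dimension supplies the diagonal; a nonzero
        # offset pads both diagonal dimensions by |offset|.
        diag = self[-1] + abs(offset)
        result: List[int] = []
        input_dim_idx = 0
        for i in range(result_rank):
            if i in (dim1, dim2):
                result.append(diag)
            else:
                result.append(self[input_dim_idx])
                input_dim_idx += 1
        return result


    # Mirrors the AtenDiagEmbedOffsetDiag e2e case above: (2, 3, 4) -> (2, 5, 3, 5).
    a = torch.rand(2, 3, 4)
    assert list(torch.ops.aten.diag_embed(a, 1, 1, 3).shape) == \
        diag_embed_shape([2, 3, 4], offset=1, dim1=1, dim2=3)

The same rule is what the lowering in DataMovement.cpp computes dynamically in `getDiagEmbedResultShape` with `tensor.dim`/`arith.addi`, so the sketch doubles as a reference for both the shape transfer function and the linalg conversion.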