Implement linalg lowering of diag_embed torch op #2885

Merged
merged 12 commits on Mar 22, 2024
Changes from 9 commits
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8379,6 +8379,32 @@ def Torch_AtenCosineEmbeddingLossOp : Torch_Op<"aten.cosine_embedding_loss", [
}];
}

def Torch_AtenDiagEmbedOp : Torch_Op<"aten.diag_embed", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::diag_embed : (Tensor, int, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$offset,
Torch_IntType:$dim1,
Torch_IntType:$dim2
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDiagEmbedOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenDiagEmbedOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics,
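For context, a quick eager-mode sketch (a reviewer-style note, not part of the diff) of the aten::diag_embed semantics that the generated op and the lowering below need to reproduce; the shapes follow the torch.diag_embed documentation:

import torch

a = torch.randn(2, 3)
# Each row of `a` becomes the diagonal of a 2-D slice appended to the shape.
print(torch.diag_embed(a).shape)                            # torch.Size([2, 3, 3])
# A positive offset shifts the diagonal above the main one (negative: below),
# and dim1/dim2 choose which result dimensions hold the 2-D slices.
print(torch.diag_embed(a, offset=1, dim1=0, dim2=2).shape)  # torch.Size([4, 2, 4])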
155 changes: 155 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
@@ -2094,6 +2095,158 @@ class ConvertAtenDiagonalOp : public OpConversionPattern<AtenDiagonalOp> {
};
} // namespace

namespace {
class ConvertAtenDiagEmbedOp : public OpConversionPattern<AtenDiagEmbedOp> {

static SmallVector<Value>
getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor,
int64_t offset, int64_t dim1, int64_t dim2) {
auto inputType = tensor.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();

// output tensor always has 1 extra dimension
auto resultRank = inputRank + 1;

// the output shape is the same regardless of the offset's sign
Value constOffset = b.create<arith::ConstantIndexOp>(loc, offset);
Value absOffset = b.create<math::AbsIOp>(loc, constOffset);

// the diagonal length is the last input dimension plus |offset|
auto lastInputDim = getDimOp(b, loc, tensor, inputRank - 1);
Value diagDim = b.create<arith::AddIOp>(loc, lastInputDim, absOffset);

// the output shape matches the input shape, except that dim1 and dim2
// both take the diagonal length
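// e.g. (illustration only) for an input of shape [2, 3, 4] with offset=1,
// dim1=1, dim2=3 the result shape is [2, 5, 3, 5]: dims 1 and 3 take the
// diagonal length 4 + |1| = 5 and the remaining dims keep the input sizes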
int input_dim_idx = 0;
SmallVector<Value> resultShape;
for (unsigned int i = 0; i < resultRank; i++) {
if (i == dim1 || i == dim2)
resultShape.push_back(diagDim);
else
resultShape.push_back(getDimOp(b, loc, tensor, input_dim_idx++));
}

return resultShape;
}

public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenDiagEmbedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Location loc = op->getLoc();

Value input = adaptor.getSelf();
auto inputType = input.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
auto resultRank = inputRank + 1;

int64_t offset;
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
return rewriter.notifyMatchFailure(op, "offset is not constant");

int64_t dim1;
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
return rewriter.notifyMatchFailure(op, "dim1 is not constant");
dim1 = toPositiveDim(dim1, resultRank);
if (!isValidDim(dim1, resultRank))
return rewriter.notifyMatchFailure(
op, "dim1 can only be in closed range [" +
std::to_string(-resultRank) + "," +
std::to_string(resultRank - 1) + "]");

int64_t dim2;
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2)))
return rewriter.notifyMatchFailure(op, "dim2 is not constant");
dim2 = toPositiveDim(dim2, resultRank);
if (!isValidDim(dim2, resultRank))
return rewriter.notifyMatchFailure(
op, "dim2 can only be between [" + std::to_string(-resultRank) + "," +
std::to_string(resultRank - 1) + "]");

if (dim1 == dim2)
return rewriter.notifyMatchFailure(op, "dim1 and dim2 can not be equal");

// create a zero-initialized tensor of the result shape
Type resultElemType = inputType.getElementType();
auto resultShape =
getDiagEmbedResultShape(rewriter, loc, input, offset, dim1, dim2);
Value zeroTensor =
createZeroInitTensor(rewriter, loc, resultShape, resultElemType);

// build a linalg.generic over the result tensor with an identity indexing
// map; its body selects the diagonal elements via index comparisons
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(resultRank),
};
SmallVector<utils::IteratorType> iteratorTypes(
resultRank, utils::IteratorType::parallel);
Value resultTensor =
rewriter
.create<linalg::GenericOp>(
loc, zeroTensor.getType(), ValueRange{}, zeroTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value dim1Index = b.create<linalg::IndexOp>(loc, dim1);
Value dim2Index = b.create<linalg::IndexOp>(loc, dim2);

// to pick the right input element, first gather the indices of every
// result dimension other than dim1/dim2; the input's last index is then
// taken from dim1 or dim2, depending on whether the offset sign selects
// an upper or a lower diagonal
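// illustration (assuming the default dim1=-2, dim2=-1): for offset >= 0
// this yields result[..., i, i + offset] = input[..., i]; for offset < 0
// it yields result[..., j + |offset|, j] = input[..., j]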
SmallVector<Value> inputIndices;
for (unsigned int i = 0; i < resultRank; i++) {
if (i != dim1 && i != dim2) {
inputIndices.push_back(b.create<linalg::IndexOp>(loc, i));
}
}

// adjust the output diagonal indices and the last input index according
// to the offset sign
Value dim1IdxAdjusted;
Value dim2IdxAdjusted;
if (offset < 0) {
Value absOffset =
b.create<arith::ConstantIndexOp>(loc, -offset);
dim1IdxAdjusted = dim1Index;
dim2IdxAdjusted =
b.create<arith::AddIOp>(loc, dim2Index, absOffset);
inputIndices.push_back(
b.create<linalg::IndexOp>(loc, dim2));
} else {
Value constOffset =
b.create<arith::ConstantIndexOp>(loc, offset);
dim1IdxAdjusted =
b.create<arith::AddIOp>(loc, dim1Index, constOffset);
dim2IdxAdjusted = dim2Index;
inputIndices.push_back(
b.create<linalg::IndexOp>(loc, dim1));
}

Value isDiagonal =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
dim1IdxAdjusted, dim2IdxAdjusted);

Value inputElem = b.create<tensor::ExtractOp>(
loc, resultElemType, input, inputIndices);

Value result = b.create<arith::SelectOp>(
loc, isDiagonal, inputElem, args[0]);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);

RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, resultTensor);
return success();
}
};
} // namespace

void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -2136,4 +2289,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenViewAsRealOp>(typeConverter, context);
target.addIllegalOp<AtenDiagonalOp>();
patterns.add<ConvertAtenDiagonalOp>(typeConverter, context);
target.addIllegalOp<AtenDiagEmbedOp>();
patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
}
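As a cross-check of the diagonal bookkeeping in the generic body above, here is a small eager-mode reference (a sketch that only handles the default dim1/dim2; the helper name diag_embed_ref is made up for illustration) that mirrors the same zero-fill-then-write strategy and compares it against torch.diag_embed:

import torch

def diag_embed_ref(x, offset=0, dim1=-2, dim2=-1):
    # Sketch: only the default dim1/dim2 placement is handled here.
    assert (dim1, dim2) == (-2, -1)
    n = x.shape[-1] + abs(offset)
    out = torch.zeros(*x.shape[:-1], n, n, dtype=x.dtype)
    for i in range(x.shape[-1]):
        # offset >= 0: write on/above the main diagonal; offset < 0: below it.
        r, c = (i, i + offset) if offset >= 0 else (i - offset, i)
        out[..., r, c] = x[..., i]
    return out

x = torch.randn(2, 3)
assert torch.equal(diag_embed_ref(x, offset=-1), torch.diag_embed(x, offset=-1))
assert torch.equal(diag_embed_ref(x, offset=2), torch.diag_embed(x, offset=2))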
89 changes: 89 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8234,6 +8234,91 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.new_empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.diag_embed\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__._diag_embed_shape_helper(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @__torch__._diag_embed_shape_helper(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %2 = torch.aten.ne.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.lt.int %arg2, %1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n"
" %5 = torch.aten.ge.int %arg2, %4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.lt.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n"
" %8 = torch.aten.ge.int %arg3, %7 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %9 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %10 = torch.prim.If %9 -> (!torch.int) {\n"
" %15 = torch.aten.add.int %1, %arg2 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %15 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg2 : !torch.int\n"
" }\n"
" %11 = torch.aten.lt.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %12 = torch.prim.If %11 -> (!torch.int) {\n"
" %15 = torch.aten.add.int %1, %arg3 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %15 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg3 : !torch.int\n"
" }\n"
" %13 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %14 = torch.prim.Loop %1, %true, init(%int0) {\n"
" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n"
" %15 = torch.prim.ListConstruct %10, %12 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %16 = torch.aten.__contains__.int_list %15, %arg4 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %17 = torch.prim.If %16 -> (!torch.int) {\n"
" %18 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %19 = torch.operator \"prim.abs.int\"(%arg1) : (!torch.int) -> !torch.int\n"
" %20 = torch.aten.add.int %18, %19 : !torch.int, !torch.int -> !torch.int\n"
" %21 = torch.aten.append.t %13, %20 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield %arg5 : !torch.int\n"
" } else {\n"
" %18 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int\n"
" %19 = torch.aten.append.t %13, %18 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %20 = torch.aten.add.int %arg5, %int1 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %20 : !torch.int\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%17 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" return %13 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -12481,6 +12566,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.diag_embed\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -53,6 +53,32 @@ def _embedding_bag_helper(weight: List[int], indices: List[int],

return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape

def _diag_embed_shape_helper(self: List[int], offset: int, dim1: int, dim2: int):
self_rank = len(self)
result_rank = self_rank + 1

assert dim1 != dim2
assert dim1 < result_rank
assert dim1 >= -(result_rank)
assert dim2 < result_rank
assert dim2 >= -(result_rank)

if dim1 < 0:
dim1 = result_rank + dim1
if dim2 < 0:
dim2 = result_rank + dim2

result_shape: List[int] = []
input_dim_idx = 0
for i in range(result_rank):
if i in (dim1, dim2):
result_shape.append(self[-1] + abs(offset))
else:
result_shape.append(self[input_dim_idx])
input_dim_idx += 1

return result_shape
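# Worked example (comment only): for self = [2, 3, 4] with offset=1, dim1=1,
# dim2=3 the helper returns [2, 5, 3, 5]; dims 1 and 3 get self[-1] + |offset|
# = 5 while the remaining dims keep the input sizes.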

def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -1048,6 +1074,20 @@ def aten〇new_empty〡shape(self: List[int], size: List[int], dtype: Optional[i
def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return size

@check_shape_function([
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=3), # Test explicit dim1 and dim2.
Invocation(TensorOfShape(2, 3, 4), offset=1, dim1=1, dim2=3), # Positive offset.
Invocation(TensorOfShape(2, 3, 4), offset=1, dim1=3, dim2=1), # Reverse dim1 and dim2.
Invocation(TensorOfShape(2, 3, 4), offset=-1, dim1=1, dim2=3), # Negative offset.
Invocation(TensorOfShape(2, 3, 4), offset=3), # Large `offset`.
ErrorInvocation(TensorOfShape(2)), # Input one-dimensional.
ErrorInvocation(TensorOfShape(2, 3, 4), dim1=1, dim2=1), # `dim1` and `dim2` equal.
ErrorInvocation(TensorOfShape(2, 3, 4), dim1=4, dim2=1), # `dim1` out of bounds.
])
def aten〇diag_embed〡shape(self: List[int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> List[int]:
return _diag_embed_shape_helper(self, offset, dim1, dim2)

def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -4176,6 +4216,11 @@ def aten〇new_empty_strided〡dtype(self_rank_dtype: Tuple[int, int], size: Lis
self_rank, self_dtype = self_rank_dtype
return self_dtype if dtype is None else dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇diag_embed〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) +
@@ -559,6 +559,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)")
emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")

# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")