Skip to content

Commit 1fcbfa8

Browse files
schnkmwtsachink
and
sachink
authored
Implement linalg lowering of diag_embed torch op (#2885)
This PR adds lowering of diag_embed to linalg dilect. Tracked in nod-ai/SHARK-ModelDev#288 --------- Co-authored-by: sachink <[email protected]>
1 parent 99b3a5f commit 1fcbfa8

File tree

7 files changed

+438
-0
lines changed

7 files changed

+438
-0
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8429,6 +8429,32 @@ def Torch_AtenCosineEmbeddingLossOp : Torch_Op<"aten.cosine_embedding_loss", [
84298429
}];
84308430
}
84318431

8432+
def Torch_AtenDiagEmbedOp : Torch_Op<"aten.diag_embed", [
8433+
AllowsTypeRefinement,
8434+
HasValueSemantics,
8435+
ReadOnly
8436+
]> {
8437+
let summary = "Generated op for `aten::diag_embed : (Tensor, int, int, int) -> (Tensor)`";
8438+
let arguments = (ins
8439+
AnyTorchTensorType:$self,
8440+
Torch_IntType:$offset,
8441+
Torch_IntType:$dim1,
8442+
Torch_IntType:$dim2
8443+
);
8444+
let results = (outs
8445+
AnyTorchTensorType:$result
8446+
);
8447+
let hasCustomAssemblyFormat = 1;
8448+
let extraClassDefinition = [{
8449+
ParseResult AtenDiagEmbedOp::parse(OpAsmParser &parser, OperationState &result) {
8450+
return parseDefaultTorchOp(parser, result, 4, 1);
8451+
}
8452+
void AtenDiagEmbedOp::print(OpAsmPrinter &printer) {
8453+
printDefaultTorchOp(printer, *this, 4, 1);
8454+
}
8455+
}];
8456+
}
8457+
84328458
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
84338459
AllowsTypeRefinement,
84348460
HasValueSemantics,

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Complex/IR/Complex.h"
2020
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2121
#include "mlir/Dialect/Linalg/IR/Linalg.h"
22+
#include "mlir/Dialect/Math/IR/Math.h"
2223
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2324
#include "mlir/IR/Matchers.h"
2425
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
@@ -2094,6 +2095,159 @@ class ConvertAtenDiagonalOp : public OpConversionPattern<AtenDiagonalOp> {
20942095
};
20952096
} // namespace
20962097

2098+
namespace {
2099+
class ConvertAtenDiagEmbedOp : public OpConversionPattern<AtenDiagEmbedOp> {
2100+
2101+
static SmallVector<Value>
2102+
getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor,
2103+
int64_t offset, int64_t dim1, int64_t dim2) {
2104+
auto inputType = tensor.getType().cast<RankedTensorType>();
2105+
auto inputRank = inputType.getRank();
2106+
2107+
// output tensor always has 1 extra dimension
2108+
auto resultRank = inputRank + 1;
2109+
2110+
// regardless of offset sign, output tensor is same
2111+
Value constOffset = b.create<arith::ConstantIndexOp>(loc, offset);
2112+
Value absOffset = b.create<math::AbsIOp>(loc, constOffset);
2113+
2114+
// diagonal size is determined by last input dimension
2115+
auto lastInputDim = getDimOp(b, loc, tensor, inputRank - 1);
2116+
Value diagDim = b.create<arith::AddIOp>(loc, lastInputDim, absOffset);
2117+
2118+
// output shape has same dimensions as input
2119+
// except for the diagonal dimensions
2120+
int input_dim_idx = 0;
2121+
SmallVector<Value> resultShape;
2122+
for (unsigned int i = 0; i < resultRank; i++) {
2123+
if (i == dim1 || i == dim2)
2124+
resultShape.push_back(diagDim);
2125+
else
2126+
resultShape.push_back(getDimOp(b, loc, tensor, input_dim_idx++));
2127+
}
2128+
2129+
return resultShape;
2130+
}
2131+
2132+
public:
2133+
using OpConversionPattern::OpConversionPattern;
2134+
LogicalResult
2135+
matchAndRewrite(AtenDiagEmbedOp op, OpAdaptor adaptor,
2136+
ConversionPatternRewriter &rewriter) const override {
2137+
2138+
Location loc = op->getLoc();
2139+
2140+
Value input = adaptor.getSelf();
2141+
auto inputType = input.getType().cast<RankedTensorType>();
2142+
auto inputRank = inputType.getRank();
2143+
auto resultRank = inputRank + 1;
2144+
2145+
int64_t offset;
2146+
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
2147+
return rewriter.notifyMatchFailure(op, "offset is not constant");
2148+
2149+
int64_t dim1;
2150+
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
2151+
return rewriter.notifyMatchFailure(op, "dim1 is not constant");
2152+
dim1 = toPositiveDim(dim1, resultRank);
2153+
if (!isValidDim(dim1, resultRank))
2154+
return rewriter.notifyMatchFailure(
2155+
op, "dim1 can only be in closed range [" +
2156+
std::to_string(-resultRank) + "," +
2157+
std::to_string(resultRank - 1) + "]");
2158+
2159+
int64_t dim2;
2160+
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2)))
2161+
return rewriter.notifyMatchFailure(op, "dim2 is not constant");
2162+
dim2 = toPositiveDim(dim2, resultRank);
2163+
if (!isValidDim(dim2, resultRank))
2164+
return rewriter.notifyMatchFailure(
2165+
op, "dim2 can only be in closed range [" +
2166+
std::to_string(-resultRank) + "," +
2167+
std::to_string(resultRank - 1) + "]");
2168+
2169+
if (dim1 == dim2)
2170+
return rewriter.notifyMatchFailure(op, "dim1 and dim2 can not be equal");
2171+
2172+
// add linalg.fill
2173+
Type resultElemType = inputType.getElementType();
2174+
auto resultShape =
2175+
getDiagEmbedResultShape(rewriter, loc, input, offset, dim1, dim2);
2176+
Value zeroTensor =
2177+
createZeroInitTensor(rewriter, loc, resultShape, resultElemType);
2178+
2179+
// add linalg.generic with diagonal access pattern affine indexing maps
2180+
SmallVector<AffineMap> indexingMaps = {
2181+
rewriter.getMultiDimIdentityMap(resultRank),
2182+
};
2183+
SmallVector<utils::IteratorType> iteratorTypes(
2184+
resultRank, utils::IteratorType::parallel);
2185+
Value resultTensor =
2186+
rewriter
2187+
.create<linalg::GenericOp>(
2188+
loc, zeroTensor.getType(), ValueRange{}, zeroTensor,
2189+
/*indexingMaps=*/indexingMaps,
2190+
/*iteratorTypes=*/iteratorTypes,
2191+
[&](OpBuilder &b, Location loc, ValueRange args) {
2192+
Value dim1Index = b.create<linalg::IndexOp>(loc, dim1);
2193+
Value dim2Index = b.create<linalg::IndexOp>(loc, dim2);
2194+
2195+
// to pick right element from input, first add all dimensions
2196+
// except last one, then last will be either dim1 or dim2
2197+
// depending upon lower or upper diagonal defined by offset
2198+
// sign
2199+
SmallVector<Value> inputIndices;
2200+
for (unsigned int i = 0; i < resultRank; i++) {
2201+
if (i != dim1 && i != dim2) {
2202+
inputIndices.push_back(b.create<linalg::IndexOp>(loc, i));
2203+
}
2204+
}
2205+
2206+
// adjust output diagonal indices and last input Index based
2207+
// on offset
2208+
Value dim1IdxAdjusted;
2209+
Value dim2IdxAdjusted;
2210+
if (offset < 0) {
2211+
Value absOffset =
2212+
b.create<arith::ConstantIndexOp>(loc, -offset);
2213+
dim1IdxAdjusted = dim1Index;
2214+
dim2IdxAdjusted =
2215+
b.create<arith::AddIOp>(loc, dim2Index, absOffset);
2216+
inputIndices.push_back(
2217+
b.create<linalg::IndexOp>(loc, dim2));
2218+
} else {
2219+
Value constOffset =
2220+
b.create<arith::ConstantIndexOp>(loc, offset);
2221+
dim1IdxAdjusted =
2222+
b.create<arith::AddIOp>(loc, dim1Index, constOffset);
2223+
dim2IdxAdjusted = dim2Index;
2224+
inputIndices.push_back(
2225+
b.create<linalg::IndexOp>(loc, dim1));
2226+
}
2227+
2228+
Value isDiagonal =
2229+
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2230+
dim1IdxAdjusted, dim2IdxAdjusted);
2231+
2232+
Value inputElem = b.create<tensor::ExtractOp>(
2233+
loc, resultElemType, input, inputIndices);
2234+
2235+
Value result = rewriter.create<arith::SelectOp>(
2236+
loc, isDiagonal, inputElem, args[0]);
2237+
b.create<linalg::YieldOp>(loc, result);
2238+
})
2239+
.getResult(0);
2240+
2241+
RankedTensorType resultType = getTypeConverter()
2242+
->convertType(op->getResult(0).getType())
2243+
.cast<RankedTensorType>();
2244+
2245+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, resultTensor);
2246+
return success();
2247+
}
2248+
};
2249+
} // namespace
2250+
20972251
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
20982252
TypeConverter &typeConverter, RewritePatternSet &patterns,
20992253
ConversionTarget &target) {
@@ -2136,4 +2290,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
21362290
patterns.add<ConvertAtenViewAsRealOp>(typeConverter, context);
21372291
target.addIllegalOp<AtenDiagonalOp>();
21382292
patterns.add<ConvertAtenDiagonalOp>(typeConverter, context);
2293+
target.addIllegalOp<AtenDiagEmbedOp>();
2294+
patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
21392295
}

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8253,6 +8253,91 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
82538253
" func.func @\"__torch_mlir_shape_fn.aten.new_empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
82548254
" return %arg1 : !torch.list<int>\n"
82558255
" }\n"
8256+
" func.func @\"__torch_mlir_shape_fn.aten.diag_embed\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
8257+
" %0 = call @__torch__._diag_embed_shape_helper(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
8258+
" return %0 : !torch.list<int>\n"
8259+
" }\n"
8260+
" func.func @__torch__._diag_embed_shape_helper(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
8261+
" %int-1 = torch.constant.int -1\n"
8262+
" %true = torch.constant.bool true\n"
8263+
" %none = torch.constant.none\n"
8264+
" %str = torch.constant.str \"AssertionError: \"\n"
8265+
" %int1 = torch.constant.int 1\n"
8266+
" %int0 = torch.constant.int 0\n"
8267+
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
8268+
" %1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n"
8269+
" %2 = torch.aten.ne.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.bool\n"
8270+
" torch.prim.If %2 -> () {\n"
8271+
" torch.prim.If.yield\n"
8272+
" } else {\n"
8273+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
8274+
" torch.prim.If.yield\n"
8275+
" }\n"
8276+
" %3 = torch.aten.lt.int %arg2, %1 : !torch.int, !torch.int -> !torch.bool\n"
8277+
" torch.prim.If %3 -> () {\n"
8278+
" torch.prim.If.yield\n"
8279+
" } else {\n"
8280+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
8281+
" torch.prim.If.yield\n"
8282+
" }\n"
8283+
" %4 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n"
8284+
" %5 = torch.aten.ge.int %arg2, %4 : !torch.int, !torch.int -> !torch.bool\n"
8285+
" torch.prim.If %5 -> () {\n"
8286+
" torch.prim.If.yield\n"
8287+
" } else {\n"
8288+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
8289+
" torch.prim.If.yield\n"
8290+
" }\n"
8291+
" %6 = torch.aten.lt.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool\n"
8292+
" torch.prim.If %6 -> () {\n"
8293+
" torch.prim.If.yield\n"
8294+
" } else {\n"
8295+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
8296+
" torch.prim.If.yield\n"
8297+
" }\n"
8298+
" %7 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n"
8299+
" %8 = torch.aten.ge.int %arg3, %7 : !torch.int, !torch.int -> !torch.bool\n"
8300+
" torch.prim.If %8 -> () {\n"
8301+
" torch.prim.If.yield\n"
8302+
" } else {\n"
8303+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
8304+
" torch.prim.If.yield\n"
8305+
" }\n"
8306+
" %9 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
8307+
" %10 = torch.prim.If %9 -> (!torch.int) {\n"
8308+
" %15 = torch.aten.add.int %1, %arg2 : !torch.int, !torch.int -> !torch.int\n"
8309+
" torch.prim.If.yield %15 : !torch.int\n"
8310+
" } else {\n"
8311+
" torch.prim.If.yield %arg2 : !torch.int\n"
8312+
" }\n"
8313+
" %11 = torch.aten.lt.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
8314+
" %12 = torch.prim.If %11 -> (!torch.int) {\n"
8315+
" %15 = torch.aten.add.int %1, %arg3 : !torch.int, !torch.int -> !torch.int\n"
8316+
" torch.prim.If.yield %15 : !torch.int\n"
8317+
" } else {\n"
8318+
" torch.prim.If.yield %arg3 : !torch.int\n"
8319+
" }\n"
8320+
" %13 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
8321+
" %14 = torch.prim.Loop %1, %true, init(%int0) {\n"
8322+
" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n"
8323+
" %15 = torch.prim.ListConstruct %10, %12 : (!torch.int, !torch.int) -> !torch.list<int>\n"
8324+
" %16 = torch.aten.__contains__.int_list %15, %arg4 : !torch.list<int>, !torch.int -> !torch.bool\n"
8325+
" %17 = torch.prim.If %16 -> (!torch.int) {\n"
8326+
" %18 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
8327+
" %19 = torch.operator \"prim.abs.int\"(%arg1) : (!torch.int) -> !torch.int \n"
8328+
" %20 = torch.aten.add.int %18, %19 : !torch.int, !torch.int -> !torch.int\n"
8329+
" %21 = torch.aten.append.t %13, %20 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
8330+
" torch.prim.If.yield %arg5 : !torch.int\n"
8331+
" } else {\n"
8332+
" %18 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int\n"
8333+
" %19 = torch.aten.append.t %13, %18 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
8334+
" %20 = torch.aten.add.int %arg5, %int1 : !torch.int, !torch.int -> !torch.int\n"
8335+
" torch.prim.If.yield %20 : !torch.int\n"
8336+
" }\n"
8337+
" torch.prim.Loop.condition %true, iter(%17 : !torch.int)\n"
8338+
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
8339+
" return %13 : !torch.list<int>\n"
8340+
" }\n"
82568341
" func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
82578342
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
82588343
" return %0 : !torch.list<int>\n"
@@ -12516,6 +12601,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1251612601
" }\n"
1251712602
" return %2 : !torch.int\n"
1251812603
" }\n"
12604+
" func.func @\"__torch_mlir_dtype_fn.aten.diag_embed\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
12605+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12606+
" return %0#1 : !torch.int\n"
12607+
" }\n"
1251912608
" func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.int {\n"
1252012609
" %none = torch.constant.none\n"
1252112610
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,6 +1878,12 @@
18781878
"DiagonalModule_with_dims_and_offset",
18791879
"DiagonalModule_with_negative_dims",
18801880
"DiagonalModule_with_offset",
1881+
"AtenDiagEmbedDefaultDiag_basic",
1882+
"AtenDiagEmbedDimDiag_basic",
1883+
"AtenDiagEmbedOffsetDiag_basic",
1884+
"AtenDiagEmbedRevDimDiag_basic",
1885+
"AtenDiagEmbedNegOffsetDiag_basic",
1886+
"AtenDiagEmbedNonDefault4DDiag_basic",
18811887
"ScatterReduceFloatMaxModuleIncludeSelf",
18821888
"ScatterReduceFloatMinModuleIncludeSelf",
18831889
"ScatterReduceFloatProdModuleIncludeSelf",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,32 @@ def _embedding_bag_helper(weight: List[int], indices: List[int],
5353

5454
return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape
5555

56+
def _diag_embed_shape_helper(self: List[int], offset: int, dim1: int, dim2: int):
57+
self_rank = len(self)
58+
result_rank = self_rank + 1
59+
60+
assert dim1 != dim2
61+
assert dim1 < result_rank
62+
assert dim1 >= -(result_rank)
63+
assert dim2 < result_rank
64+
assert dim2 >= -(result_rank)
65+
66+
if dim1 < 0:
67+
dim1 = result_rank + dim1
68+
if dim2 < 0:
69+
dim2 = result_rank + dim2
70+
71+
result_shape: List[int] = []
72+
input_dim_idx = 0
73+
for i in range(result_rank):
74+
if i in (dim1, dim2):
75+
result_shape.append(self[-1] + abs(offset))
76+
else:
77+
result_shape.append(self[input_dim_idx])
78+
input_dim_idx += 1
79+
80+
return result_shape
81+
5682
def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
5783
return upstream_shape_functions.unary(self)
5884

@@ -1057,6 +1083,20 @@ def aten〇new_empty〡shape(self: List[int], size: List[int], dtype: Optional[i
10571083
def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
10581084
return size
10591085

1086+
@check_shape_function([
1087+
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
1088+
Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=3), # Test explicit dim1 and dim2.
1089+
Invocation(TensorOfShape(2, 3, 4), offset=1, dim1=1, dim2=3), # Positive offset.
1090+
Invocation(TensorOfShape(2, 3, 4), offset=1, dim1=3, dim2=1), # Reverse dim1 and dim2
1091+
Invocation(TensorOfShape(2, 3, 4), offset=-1, dim1=1, dim2=3), # Negative offset
1092+
Invocation(TensorOfShape(2, 3, 4), offset=3), # large `offset`.
1093+
Invocation(TensorOfShape(2)), # Input one-dimensional.
1094+
ErrorInvocation(TensorOfShape(2, 3, 4), dim1=1, dim2=1), # `dim1` and `dim2` equal.
1095+
ErrorInvocation(TensorOfShape(2, 3, 4), dim1=4, dim2=1), # `dim1` out of bounds.
1096+
])
1097+
def aten〇diag_embed〡shape(self: List[int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> List[int]:
1098+
return _diag_embed_shape_helper(self, offset, dim1, dim2)
1099+
10601100
def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
10611101
return upstream_shape_functions.unary(self)
10621102

@@ -4200,6 +4240,11 @@ def aten〇new_empty_strided〡dtype(self_rank_dtype: Tuple[int, int], size: Lis
42004240
self_rank, self_dtype = self_rank_dtype
42014241
return self_dtype if dtype is None else dtype
42024242

4243+
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
4244+
def aten〇diag_embed〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> int:
4245+
self_rank, self_dtype = self_rank_dtype
4246+
return self_dtype
4247+
42034248
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) +
42044249
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) +
42054250
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) +

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ def emit_with_mutating_variants(key, **kwargs):
561561
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
562562
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
563563
emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)")
564+
emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")
564565

565566
# Misc tensor ops.
566567
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")

0 commit comments

Comments
 (0)