Commit 9293326
[MLIR][TORCH] Add support for bitwise_right_shift and bitwise_and.Scalar op
Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent c434736 commit 9293326

7 files changed (+299, -18 lines changed)

e2e_testing/xfail_sets.py
Lines changed: 2 additions & 0 deletions

@@ -1421,4 +1421,6 @@
     "UniformStaticShapeModule_basic",
     "AtenEmbeddingBagStaticModule_basic",
     "EmptyStridedModule_basic",
+    "ElementwiseBitwiseAndScalarInt64Module_basic",
+    "ElementwiseBitwiseAndScalarInt32Module_basic",
 }
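
The two xfail entries name new e2e tests for the scalar variant (defined in this commit's test files, not shown in this excerpt). As a minimal sketch, the computation such a test exercises in eager PyTorch looks like the following; the test-harness wrappers are omitted and the constant 15 is illustrative:

import torch

# Approximation of what ElementwiseBitwiseAndScalarInt64Module_basic computes;
# the real module lives in the e2e test suite.
def forward(x: torch.Tensor) -> torch.Tensor:
    return torch.bitwise_and(x, 15)

x = torch.randint(-100, 100, (3, 4), dtype=torch.int64)
print(forward(x))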

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Lines changed: 94 additions & 0 deletions

@@ -2844,6 +2844,53 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
   }];
 }
 
+def Torch_AtenBitwiseAndScalarOp : Torch_Op<"aten.bitwise_and.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseAndScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseAndScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenBitwiseAnd_ScalarOp : Torch_Op<"aten.bitwise_and_.Scalar", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::bitwise_and_.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseAnd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseAnd_ScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,

@@ -2938,6 +2985,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [
   }];
 }
 
+def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseRightShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseRightShiftTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenBitwiseRightShift_TensorOp : Torch_Op<"aten.bitwise_right_shift_.Tensor", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::bitwise_right_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseRightShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseRightShift_TensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
     AllowsTypeRefinement,
     HasValueSemantics,
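
A quick way to see the newly registered ops is to trace a small module down to the Torch dialect. A hedged sketch, assuming the torch_mlir.compile entry point and its string-valued output_type from this era of the project:

import torch
import torch_mlir  # assumes the in-tree Python package of this era

class RightShift(torch.nn.Module):
    def forward(self, x, y):
        return torch.bitwise_right_shift(x, y)

x = torch.randint(0, 100, (4,), dtype=torch.int32)
y = torch.randint(0, 4, (4,), dtype=torch.int32)
# The printed module should contain a torch.aten.bitwise_right_shift.Tensor op.
print(torch_mlir.compile(RightShift(), (x, y), output_type="torch"))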

lib/Conversion/TorchToLinalg/Uncategorized.cpp
Lines changed: 50 additions & 18 deletions

@@ -300,6 +300,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create<arith::AndIOp>(loc, lhs, rhs);
   }
+  if (auto bitwiseAndScalar = dyn_cast<AtenBitwiseAndScalarOp>(op)) {
+    Type dtype = converter->convertType(bitwiseAndScalar.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::IntegerType>()) {
+      bitwiseAndScalar.emitError(
+          "bitwise_and.Scalar does not support non-integer input dtype.");
+      return nullptr;
+    }
+    Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
+    return b.create<arith::AndIOp>(loc, self, other);
+  }
   if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
     if (bitwiseOrTensor.getType()
             .cast<ValueTensorType>()

@@ -332,6 +345,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create<arith::XOrIOp>(loc, lhs, rhs);
   }
+  if (auto bitwiseRightShiftTensor =
+          dyn_cast<AtenBitwiseRightShiftTensorOp>(op)) {
+    Type dtype = converter->convertType(bitwiseRightShiftTensor.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::IntegerType>()) {
+      bitwiseRightShiftTensor.emitError(
+          "Bitwise_Right_Shift op does not support non-integer input dtype.");
+      return nullptr;
+    }
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    return b.create<arith::ShRSIOp>(loc, lhs, rhs);
+  }
   if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
     MLIRContext *context = op->getContext();
     Type floatDtype = mlir::FloatType::getF64(context);

@@ -571,7 +598,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     if (dtype.isa<mlir::FloatType>()) {
       return b.create<arith::MulFOp>(loc, lhs, rhs);
-    } else if(dtype.isa<mlir::ComplexType>()) {
+    } else if (dtype.isa<mlir::ComplexType>()) {
       return b.create<complex::MulOp>(loc, lhs, rhs);
     } else {
       return b.create<arith::MulIOp>(loc, lhs, rhs);

@@ -1066,7 +1093,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                      .getElementType();
 
     Value self = payloadArgs[0];
-    Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
+    Value threshold =
+        convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
     Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
 
     Value predicate;

@@ -1088,7 +1116,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
 
     Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
     Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
-    Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
+    Value threshold =
+        convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
     Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
 
     Value predicate;

@@ -1197,10 +1226,11 @@ class ConvertElementwiseOp : public ConversionPattern {
             AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
             AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
             AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp,
-            AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
-            AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
-            AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
-            AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
+            AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp,
+            AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
+            AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
+            AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
+            AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
             AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
             AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
             AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,

@@ -1699,7 +1729,8 @@ class ConvertAtenDetachOp : public OpConversionPattern<AtenDetachOp> {
       return failure();
 
     Type resultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, adaptor.getSelf());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
+                                                adaptor.getSelf());
     return success();
   }
 };

@@ -1735,16 +1766,17 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp,
       AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
       AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
-      AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
-      AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
-      AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
-      AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
-      AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
-      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
-      AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp,
-      AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp,
-      AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
-      AtenRealOp, AtenImagOp>();
+      AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
+      AtenBitwiseXorTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp,
+      AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
+      AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
+      AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp,
+      AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp,
+      AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
+      AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp,
+      AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp,
+      AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp,
+      AtenImagOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Lines changed: 25 additions & 0 deletions

@@ -7410,10 +7410,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -9201,6 +9209,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
+"    %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
+"    %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

@@ -9217,6 +9234,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.bitwise_right_shift.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
Lines changed: 22 additions & 0 deletions

@@ -796,9 +796,15 @@ def aten〇bitwise_or〇Tensor〡shape(self: List[int], other: List[int]) -> Lis
 def aten〇bitwise_and〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
+def aten〇bitwise_and〇Scalar〡shape(self: List[int], other: float) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
+def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, other)
+
 def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 

@@ -2265,6 +2271,14 @@ def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_
     dtypes = [self_dtype, other_dtype]
     return promote_dtypes(ranks, dtypes)
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
+                      _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
+def aten〇bitwise_and〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, None]
+    dtypes = [self_dtype, get_dtype_of_scalar(other)]
+    return promote_dtypes(ranks, dtypes)
+
 @check_dtype_function(_check_two_tensor_op())
 def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     other_rank, other_dtype = other_rank_dtype

@@ -2281,6 +2295,14 @@ def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_
     dtypes = [self_dtype, other_dtype]
     return promote_dtypes(ranks, dtypes)
 
+@check_dtype_function(_check_two_tensor_op())
+def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
+    other_rank, other_dtype = other_rank_dtype
+    self_rank, self_dtype = self_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, other_rank]
+    dtypes = [self_dtype, other_dtype]
+    return promote_dtypes(ranks, dtypes)
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
     # Different width
python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
Lines changed: 2 additions & 0 deletions

@@ -301,8 +301,10 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::abs : (Tensor) -> (Tensor)",
         "aten::reciprocal : (Tensor) -> (Tensor)",
         "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
+        "aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)",
         "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)",
+        "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
         "aten::square : (Tensor) -> (Tensor)",
         "aten::unsqueeze : (Tensor, int) -> (Tensor)",
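Because these signatures go through emit_with_mutating_variants, each registered line yields both the functional op and its trailing-underscore in-place variant (e.g. aten.bitwise_and.Scalar and aten.bitwise_and_.Scalar), mirroring the eager API:

import torch

x = torch.tensor([6, 7], dtype=torch.int64)
y = torch.bitwise_and(x, 3)  # functional: aten::bitwise_and.Scalar
x.bitwise_and_(3)            # in-place:   aten::bitwise_and_.Scalar
print(y, x)  # tensor([2, 3]) tensor([2, 3])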