Skip to content

Commit c434736

Browse files
[MLIR][TORCH] Add support for conversion to int8 dtype
Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent 71ac62f commit c434736

File tree

6 files changed

+76
-10
lines changed

6 files changed

+76
-10
lines changed

e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,9 @@
288288

289289
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
290290
"AtenEmbeddingBagStaticModule_basic",
291+
292+
# Lowering not present for this case
293+
"ElementwiseToDtypeI64ToUI8Module_basic",
291294
}
292295

293296
if torch_version_for_comparison() < version.parse("2.1.0.dev"):

include/torch-mlir/Conversion/Utils/Utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
8787
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
8888
// should be converted builtin types.
8989
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
90-
std::optional<Type> srcOriginalDtype = std::nullopt);
90+
std::optional<Type> srcOriginalDtype = std::nullopt,
91+
std::optional<Type> dstOriginalDtype = std::nullopt);
9192

9293
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
9394
Value torchOptionalInt, Value builtinInt,

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,23 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
988988
Type dtype = converter->convertType(atenToDtype.getType())
989989
.cast<RankedTensorType>()
990990
.getElementType();
991-
Value result = convertScalarToDtype(b, loc, input, dtype);
991+
Type resultElementType;
992+
int64_t dtypeInt;
993+
if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) {
994+
atenToDtype.emitError("unimplemented: dtype must be a constant integer");
995+
return nullptr;
996+
}
997+
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
998+
atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt,
999+
IntegerType::Signless);
1000+
if (failed(maybeResultElementType)) {
1001+
atenToDtype.emitError("unable to convert `dtypeInt` to builtin type");
1002+
return nullptr;
1003+
}
1004+
resultElementType = *maybeResultElementType;
1005+
Value result = convertScalarToDtype(b, loc, input, dtype,
1006+
/*srcOriginalDtype=*/std::nullopt,
1007+
/*dstOriginalDtype=*/resultElementType);
9921008
return result;
9931009
}
9941010
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {

lib/Conversion/Utils/Utils.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
249249
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
250250
// should be converted builtin types.
251251
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
252-
std::optional<Type> srcOriginalDtype) {
252+
std::optional<Type> srcOriginalDtype,
253+
std::optional<Type> dstOriginalDtype) {
253254
Type scalarType = scalar.getType();
254255
if (scalarType == dtype)
255256
return scalar;
@@ -261,14 +262,20 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
261262
return false;
262263
};
263264

264-
// We only support conversion from Byte or Char scalarType not to Byte or Char
265-
// dtype.
265+
// We don't support conversion to Byte dtype.
266266
if (isByteOrChar(dtype)) {
267-
mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
268-
"convertScalarToDtype "
269-
<< scalarType << "(scalar type) -> " << dtype
270-
<< "(dtype)";
271-
return nullptr;
267+
if (!dstOriginalDtype.has_value()) {
268+
mlir::emitError(loc)
269+
<< "unimplemented: for conversion to byte or char type "
270+
"dstOriginalDtype has to be passed to convertScalarToDtype";
271+
return nullptr;
272+
}
273+
if (dstOriginalDtype->isUnsignedInteger()) {
274+
mlir::emitError(loc)
275+
<< "unsupported: conversion to byte type for convertScalarToDtype "
276+
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
277+
return nullptr;
278+
}
272279
}
273280

274281
// If the dtype is i1, i.e., a boolean type.

python/torch_mlir_e2e_test/test_suite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"NativeGroupNormBackwardModule_basic",
1515
"QuantizedMLP_basic",
1616
"ReduceMaxAlongDimUnsignedInt_basic",
17+
"ElementwiseToDtypeI64ToUI8Module_basic",
1718
}
1819

1920
# TODO: Delete once torch 2.1.0 is released

python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,44 @@ def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
16421642
# ==============================================================================
16431643

16441644

1645+
class ElementwiseToDtypeI64ToI8Module(torch.nn.Module):
1646+
1647+
def __init__(self):
1648+
super().__init__()
1649+
1650+
@export
1651+
@annotate_args([None, ([-1, -1], torch.int64, True)])
1652+
def forward(self, x):
1653+
return x.to(torch.int8)
1654+
1655+
1656+
@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToI8Module())
1657+
def ElementwiseToDtypeI64ToI8Module_basic(module, tu: TestUtils):
1658+
module.forward(tu.randint(3, 4, low=-100, high=100))
1659+
1660+
1661+
# ==============================================================================
1662+
1663+
1664+
class ElementwiseToDtypeI64ToUI8Module(torch.nn.Module):
1665+
1666+
def __init__(self):
1667+
super().__init__()
1668+
1669+
@export
1670+
@annotate_args([None, ([-1, -1], torch.int64, True)])
1671+
def forward(self, x):
1672+
return x.to(torch.uint8)
1673+
1674+
1675+
@register_test_case(module_factory=lambda: ElementwiseToDtypeI64ToUI8Module())
1676+
def ElementwiseToDtypeI64ToUI8Module_basic(module, tu: TestUtils):
1677+
module.forward(tu.randint(3, 4, low=-100, high=100))
1678+
1679+
1680+
# ==============================================================================
1681+
1682+
16451683
class ElementwiseLog2Module(torch.nn.Module):
16461684

16471685
def __init__(self):

0 commit comments

Comments
 (0)