Skip to content

Commit 2787ea0

Browse files
APInt fix (#4070)
`ConvertAtenOp<ValueTensorLiteralOp>` triggers newly added assertions around `APInt` construction (see [llvm/llvm-project#106524](llvm/llvm-project#106524)) when `isSigned == false` and we pass the result of `v.getSExtValue()` that doesn't fit within the specified bit-width. For example, if we build a 32-bit `APInt` using a value of `-1` (returned by `v.getSExtValue()`), that negative number is stored in a `uint64-t` representation as `0xFF FF FF FF FF FF FF FF`. Since we are treating it as unsigned, it obviously doesn't fit into 32-bits, causing the new assertion to fail.
1 parent 32feea5 commit 2787ea0

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2913,7 +2913,7 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
29132913
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
29142914
DenseElementsAttr valueAttr =
29152915
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
2916-
return APInt(bitWidth, v.getSExtValue());
2916+
return APInt(bitWidth, v.getSExtValue(), /*isSigned=*/true);
29172917
});
29182918
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, valueAttr);
29192919
return success();

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,18 @@ func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> {
10251025
}
10261026

10271027
// -----
1028+
1029+
// CHECK-LABEL: @torch.vtensor.literal_si32$basic(
1030+
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<-1> : tensor<1x512xi32>}> : () -> tensor<1x512xi32>
1031+
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<1x512xi32> -> !torch.vtensor<[1,512],si32>
1032+
// CHECK: return %[[VAL_1]] : !torch.vtensor<[1,512],si32>
1033+
func.func @torch.vtensor.literal_si32$basic() -> !torch.vtensor<[1,512],si32> {
1034+
%0 = torch.vtensor.literal(dense<-1> : tensor<1x512xsi32>) : !torch.vtensor<[1,512],si32>
1035+
return %0 : !torch.vtensor<[1,512],si32>
1036+
}
1037+
1038+
// -----
1039+
10281040
// CHECK-LABEL: func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
10291041
// CHECK: %[[VAL_0:.*]] = torch.constant.none
10301042
// CHECK: %[[VAL_1:.*]] = torch.constant.int 0

0 commit comments

Comments
 (0)