Description
A small test:
import tvm
from tvm.script import ir as I
from tvm.script import tir as T

# Register a packed function so the compiled code can print the value of res.
@tvm.register_global_func("tir_print")
def tir_print(A):
    print(A)

@I.ir_module
class TestCeilDiv:
    @T.prim_func
    def main(n: T.int32) -> None:
        T.func_attr({"target": T.target("llvm")})
        res: T.int32 = T.ceildiv(n, T.int32(10))
        T.call_packed("tir_print", res)

mod = TestCeilDiv
exe = tvm.build(mod)
exe["main"](24)  # ceildiv(24, 10) should print 3
Expected behavior
The expected result of ceildiv(24, 10) is 3.
Actual behavior
The result of this test is 4.
Debug
The direct cause is an incorrect rewrite of the TIR by the LowerIntrin pass.
After the LowerIntrin pass, the TIR becomes:
T.Cast("int64", T.Div(n - 2147483631, 10) - -214748364)
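This is wrong: with n = 24 it evaluates to T.Div(-2147483607, 10) + 214748364 = -214748360 + 214748364 = 4, because T.Div truncates toward zero.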
In src/tir/transforms/lower_intrin.cc:156, the offset constant is computed as:
PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
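In my test case, min (the lower bound of the numerator n + 9) is -2147483648 + 9 = -2147483639, and b - 1 is 9.
So -min + (op->b - 1) = 2147483639 + 9 = 2147483648. This exceeds the int32 maximum (2147483647), so constant folding in src/arith/const_fold.h:80 wraps it around to -2147483648.
The ceildiv constant therefore becomes truncdiv(-2147483648, 10) = -214748364 instead of 214748364, which is exactly the -214748364 in the lowered TIR above.
As a result, the transformation is not equivalent to the original expression.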
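The arithmetic above can be replayed outside TVM. The sketch below is a hypothetical model, not TVM code: it mirrors the floordiv rewrite quoted from lower_intrin.cc (c = truncdiv(-min + (b - 1), b), then truncdiv(a + b*c, b) - c) with plain Python integers, optionally wrapping every intermediate to int32. It assumes, as the min value above suggests, that T.ceildiv(n, 10) reaches this pass as floordiv(n + 9, 10). The helper names wrap_i32, truncdiv and lowered_floordiv are made up for illustration.

```python
INT32_MIN = -(2**31)


def wrap_i32(x: int) -> int:
    """Reduce an unbounded Python int to two's-complement int32."""
    return (x + 2**31) % 2**32 - 2**31


def truncdiv(a: int, b: int) -> int:
    """C-style division rounding toward zero (Python's // rounds toward -inf)."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def lowered_floordiv(a: int, b: int, a_min: int, wrap: bool) -> int:
    """Hypothetical model of floordiv(a, b) -> truncdiv(a + b*c, b) - c,
    where c = truncdiv(-a_min + (b - 1), b). With wrap=True every
    intermediate is truncated to int32, as int32 constant folding does."""
    fix = wrap_i32 if wrap else (lambda x: x)
    c = truncdiv(fix(-a_min + (b - 1)), b)
    offset = fix(a + fix(b * c))
    return fix(truncdiv(offset, b) - c)


n, b = 24, 10
a = n + (b - 1)              # assumed: ceildiv(n, b) == floordiv(n + b - 1, b)
a_min = INT32_MIN + (b - 1)  # lower bound of n + 9 for an arbitrary int32 n

print(lowered_floordiv(a, b, a_min, wrap=False))  # 3 (exact integers)
print(lowered_floordiv(a, b, a_min, wrap=True))   # 4 (int32 wraparound)
```

The wrapped run reproduces both the constant -214748364 and the final value 4 seen in the lowered TIR, while the same formula with unbounded integers gives the correct 3.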
Possible fix
To avoid overflow of '(op->b - 1) - min', we can require (b - 1 - a_min) < MAX_VALUE + 1,
that is, (b_max - a_min) < MAX_VALUE + 2, where b_max is the upper bound of b and MAX_VALUE is the maximum value of a's dtype.
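Written out as a short derivation, with M standing for MAX_VALUE of a's dtype and b_max for the constant upper bound of b (notation introduced here only for the derivation), the requirement turns into a lower-bound check on a_min, and the test case sits exactly on the boundary:

```latex
\begin{aligned}
(b - 1) - a_{\min} \le M
  &\iff b - a_{\min} \le M + 1
   \iff b - a_{\min} < M + 2 \\
  &\Longleftarrow\; b_{\max} - a_{\min} < M + 2
   \qquad (\text{since } b \le b_{\max}) \\
  &\iff a_{\min} > b_{\max} - M - 2 \\[4pt]
\text{Test case: } M = 2^{31} - 1,\ b_{\max} = 10,\ a_{\min} = -2^{31} + 9
  &\;\Rightarrow\; b_{\max} - M - 2 = -2^{31} + 9 = a_{\min}
\end{aligned}
```

Since the strict inequality fails for the test case, a guard of this form would skip the rewrite there; the folded value (b - 1) - a_min would be M + 1, i.e. one past the int32 maximum.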
We can avoid this issue by changing the condition of the if statement at src/tir/transforms/lower_intrin.cc:122 to:
if (const_int_bound_a->min_value < 0 &&
    const_int_bound_a->min_value >
        (const_int_bound_b->max_value -
         (Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value) - 2)) {
Of course, this may not be the best fix.
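One way to sanity-check the proposed condition is to transcribe it and compare it against the overflow it is meant to prevent. The sketch below is illustrative only: guard_proposed copies the C++ condition with int32's maximum as MAX_VALUE, and folded_constant_fits is a hypothetical helper that tests whether (b - 1) - min is representable. It scans a small window of a_min values near INT32_MIN rather than the full range.

```python
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def guard_proposed(a_min: int, b_max: int, max_value: int = INT32_MAX) -> bool:
    """Transcription of the proposed if-condition (MAX_VALUE = max_value)."""
    return a_min < 0 and a_min > b_max - max_value - 2


def folded_constant_fits(a_min: int, b: int, max_value: int = INT32_MAX) -> bool:
    """Hypothetical helper: can (b - 1) - a_min be folded without wrapping?"""
    return (b - 1) - a_min <= max_value


# The case from this issue: a_min = INT32_MIN + 9, b = 10.
print(guard_proposed(INT32_MIN + 9, 10))        # False: rewrite is skipped
print(folded_constant_fits(INT32_MIN + 9, 10))  # False: it would overflow

# Whenever the guard accepts, the constant fits for every divisor 1 <= b <= b_max.
for b_max in range(1, 20):
    for a_min in range(INT32_MIN, INT32_MIN + 40):
        if guard_proposed(a_min, b_max):
            assert all(folded_constant_fits(a_min, b) for b in range(1, b_max + 1))
print("guard is safe on the scanned window")
```

Because the condition uses b_max rather than b, it is conservative for non-constant divisors; for a constant divisor it rejects exactly the inputs where the folded constant would exceed MAX_VALUE.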
The FloorMod lowering uses a similar rewrite, so it has the same problem (src/tir/transforms/lower_intrin.cc:217).
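The same kind of model can illustrate the floormod case. This sketch assumes that the floormod branch at lower_intrin.cc:217 reuses the same offset c and returns truncmod(a + b*c, b); that shape is an assumption about the source, not quoted from it, and the helper names are hypothetical. With the operand bound from the test case and a = 33, exact arithmetic gives 3, while int32 wraparound gives -7, which is not even a valid floormod result for a positive divisor.

```python
INT32_MIN = -(2**31)


def wrap_i32(x: int) -> int:
    """Reduce an unbounded Python int to two's-complement int32."""
    return (x + 2**31) % 2**32 - 2**31


def truncdiv(a: int, b: int) -> int:
    """C-style division rounding toward zero."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def truncmod(a: int, b: int) -> int:
    """C-style remainder: the sign follows the dividend."""
    return a - b * truncdiv(a, b)


def lowered_floormod(a: int, b: int, a_min: int, wrap: bool) -> int:
    """Assumed model of floormod(a, b) -> truncmod(a + b*c, b),
    with c = truncdiv(-a_min + (b - 1), b) as in the floordiv branch."""
    fix = wrap_i32 if wrap else (lambda x: x)
    c = truncdiv(fix(-a_min + (b - 1)), b)
    return fix(truncmod(fix(a + fix(b * c)), b))


a, b, a_min = 33, 10, INT32_MIN + 9
print(a % b)                                      # 3 (Python's % is floormod)
print(lowered_floormod(a, b, a_min, wrap=False))  # 3
print(lowered_floormod(a, b, a_min, wrap=True))   # -7
```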