Skip to content

Commit

Permalink
[Arith] Fix solve inequality of unbound var ranges (apache#14582)
Browse files Browse the repository at this point in the history
Fix an issue in `FindBestRange` where if `IntSet::max()` if inf, it could not take part in general arithmetic computations.
  • Loading branch information
wrongtest-intellif authored Apr 11, 2023
1 parent 9fb9fd6 commit fb2ae1a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
25 changes: 17 additions & 8 deletions src/arith/int_constraints.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,26 +153,35 @@ Range IntGroupBounds::FindBestRange(const Map<Var, Range>& vranges_addl) const {

for (const PrimExpr& low : lowers) {
for (const PrimExpr& upp : uppers) {
PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3);
// Since diff may depend on some other variables, we compute its overapproximation
PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3);
Optional<PrimExpr> diff_over;
PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3);
IntSet diff_set1 = EvalSet(diff_1, var_intsets);
if (diff_set1.HasUpperBound()) {
diff_over = analyzer.Simplify(diff_set1.max(), 3);
}

// low is the lower bound for v*coef, but we need the lower bound for v.
// We use rounding-up division to compute it. Since we want to use a single formula
PrimExpr low_divided = analyzer.Simplify(floordiv(low + coef - 1, coef), 3);

// Compute another difference which may be more precise (or not).
PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3);
PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3);

PrimExpr diff_over =
analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1;
IntSet diff_set2 = EvalSet(diff_2, var_intsets);
if (diff_set2.HasUpperBound()) {
PrimExpr diff_over_2 = analyzer.Simplify(diff_set2.max(), 3);
diff_over = diff_over.defined() ? (analyzer.CanProve(diff_over_2 - diff_over.value() < 0)
? diff_over_2
: diff_over.value())
: diff_over_2;
}

// If it is provable that the new one is strictly better than the current best one,
// then replace it. Note that we are biased towards earlier pairs which should be simpler.
if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) {
if (diff_over.defined() && (!best_diff_over.defined() ||
analyzer.CanProve(diff_over.value() - best_diff_over < 0))) {
best_lower = low_divided;
best_diff_over = diff_over;
best_diff_over = diff_over.value();
}
}
}
Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_arith_solve_linear_inequality.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,20 @@ def test_no_solution():
assert not rel


def test_unbound_var_range():
x = te.var("x0")
free_var = te.var("fv")
vranges = {x: tvm.ir.Range.from_min_extent(0, tvm.tir.Cast("int32", 1 + tvm.tir.log(free_var)))}
problem = [x > 3]
solution = arith.solve_linear_inequalities(
problem,
[x],
vranges,
)
assert len(solution.variables) == 1
assert len(solution.ranges) == 0
assert len(solution.relations) == 3


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit fb2ae1a

Please sign in to comment.