From f622e7f180bdb95c9d3121bb1f78fe459a57812a Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 11 Apr 2023 16:59:29 -0400 Subject: [PATCH] [ARITH][BUGFIX] Fix a bug of iter map floormod(x,2) simplify (#14571) This PR fixes a previous bug introduced in itermap detection. Specifically, y - (x % 2) were simplified to y + (x % 2) - 1. Which is wrong. The working rule is y + ((x + 1) % 2) - 1, but that rule will change the base iterator which is not desirable here. We also removed the rule that simplifies (x + 1) % 2 => 1 - x % 2 as benefit is minimal and it introduces extra negative co-efficients that hurts analysis in general (as negative co-efficients are harder in many cases). --- src/arith/iter_affine_map.cc | 5 -- src/arith/rewrite_simplify.cc | 17 +++- .../unittest/test_arith_canonical_simplify.py | 7 ++ .../unittest/test_arith_iter_affine_map.py | 10 +-- .../unittest/test_arith_rewrite_simplify.py | 22 +++-- ..._tir_transform_inject_software_pipeline.py | 90 ++++++++++--------- 6 files changed, 86 insertions(+), 65 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index e7fc4f2663f8..05af5b40702d 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -898,11 +898,6 @@ class IterMapRewriter : public ExprMutator { PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs); static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) { - if (sign < 0 && is_const_int(rhs->extent, 2)) { - lhs->base -= rhs->scale; - sign = 1; - } - tir::ExprDeepEqual equal; for (size_t i = 0; i < lhs->args.size(); ++i) { IterSplitExpr lvalue = lhs->args[i]; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 528a272ef482..acd74b7031e7 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -306,6 +306,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1, 2)); + // Simplify (x + 1) % 2 + x % 2 => 1 + // NOTE: we should avoid simplifying (x + 1) %2 => 1 - x % 2 though + // mainly because introducing extra negative signs to expression can harm itertaor + // analysis which usually relies on positive itertator co-efficients. + TVM_TRY_REWRITE_IF(floormod(x + c1, 2) + floormod(x, 2), OneWithTypeLike(x), + floormod(c1.Eval()->value, 2) == 1); + TVM_TRY_REWRITE_IF(floormod(x, 2) + floormod(x + c1, 2), OneWithTypeLike(x), + floormod(c1.Eval()->value, 2) == 1); + // canonicalization rule // will try rewrite again after canonicalization. @@ -1018,10 +1027,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), c2.Eval()->value > 0); - TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + c1, 2), floormod(x, 2) * (-1) + 1, - floormod(c1.Eval()->value, 2) == 1); - TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + // (x + 5) % 2 -> (x + 1) %2, (x + 3) % 3 => x + TVM_TRY_REWRITE_IF( + floormod(x + c1, c2), floormod(x + floormod(c1, c2), c2), + c2.Eval()->value > 0 && (c1.Eval()->value >= c2.Eval()->value || c1.Eval()->value < 0)); TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2), c2.Eval()->value > 0); diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 1a4277d92453..c1d7587f430b 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -415,5 +415,12 @@ def test_proddiv_simplify(): ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z)) +def test_floormod_two(): + ck = CanonicalChecker() + flm = tvm.te.floormod + x, y = te.var("x"), te.var("y") + ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 0bb4c98b7b15..5ce729604504 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -199,14 +199,14 @@ def test_compound(): assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)])) -def test_compound_floormod_two(): +def test_compound_floormod_two_regression(): x = tvm.tir.Var("x", "int32") fld = tvm.tir.floordiv flm = tvm.tir.floormod - - # extent of 2 are normalized to positive scale - assert_iter_sum_pattern( - expect_dict={fld(x, 2) * 2 - flm(x, 2) + 1: (8, 0, 1)}, + # regression + # extent of 2 of negative scale cannot be normalized + assert_iter_sum_failure( + [fld(x, 2) * 2 - flm(x, 2) + 1], dom_map=var_dom([(x, 8)]), ) diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 7ecc34c385b6..46ac0f975157 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -392,8 +392,8 @@ class TestSubIndex(BaseCompare): TestCase(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3)), TestCase(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1), TestCase(fld(y, 3) * 3 - y, 0 - flm(y, 3)), - TestCase(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6), - TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5)), + TestCase(y - fld(y - 6, 5) * 5, flm(y + 4, 5) + 6), + TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + 4, 5)), TestCase(y - fld(y + z, 5) * 5, flm(y + z, 5) - z), TestCase(fld(y + z, 5) * 5 - y, z - flm(y + z, 5)), TestCase(y - fld(y - z, 5) * 5, flm(y - z, 5) + z), @@ -554,13 +554,15 @@ class TestFloormodIndex(BaseCompare): TestCase(flm(x + 10, 2), flm(x, 2)), TestCase(flm(x + y * 10, 2), flm(x, 2)), TestCase(flm(x + y * 360, 16), flm(x + y * 8, 16)), - TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1), TestCase(flm(x * (-10), 2), 0), TestCase(flm(x * (-10) + y, 2), flm(y, 2)), TestCase(flm(x + (-10), 2), flm(x, 2)), TestCase(flm(x + y * (-10), 2), flm(x, 2)), TestCase(flm(x * 32 + y, 64), flm(x, 2) * 32 + y, [y >= 0, y < 32]), TestCase(flm(x * 32 - y, 64), flm(x * 32 - y, 64), [y >= 0, y < 32]), + # NOTE: the followng case is covered by canonical simplify + # long range simplifcation in general can be covered by canonical simplify + # TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1), ) @@ -574,13 +576,14 @@ class TestFloorModTwo(BaseCompare): require identifying more related terms in order to apply. (x + c1)//2 - (x+c2)//2 => (x%2)*( c1%2 - c1%2 ) + (c1//2 - c2//2) + + We should not introduce extra negative coeficient to iterators + however during simplification """ x, y, z = te.var("x"), te.var("y"), te.var("z") test_case = tvm.testing.parameter( # Removing offsets from floormod - TestCase(flm(x + 1, 2), flm(x, 2) * (-1) + 1), - TestCase(flm(x + 5, 2), flm(x, 2) * (-1) + 1), TestCase(flm(x, 2) + flm(x + 1, 2), 1), TestCase(flm(x + 1, 2) + flm(x, 2), 1), # Difference of floordiv yields floormod @@ -592,8 +595,13 @@ class TestFloorModTwo(BaseCompare): # Sum of floordiv and floormod to yield floordiv TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)), TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)), - # Removal of floormod where possible - TestCase(flm(x + 1, 2) * 8192, x * (-8192) + 8192, [x >= 0, x < 2]), + # regression: although we can rewrite (x + 1) %2 => 1 - x%2 + # doing so would introduce negative co-efficient to iterators + # which makes later iter map detection harder, in principle we + # should not introduce additional negative signs of iterator in rewriting + TestCase(flm(x + 1, 2), flm(x + 1, 2)), + TestCase(flm(x + 5, 2), flm(x + 1, 2)), + TestCase(flm(x + 1, 2) * 8192, flm(x + 1, 2) * 8192, [x >= 0, x < 2]), ) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 7e59172bdd83..b9f35ed553e1 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -139,8 +139,8 @@ def transformed_simple_compute( for i in T.serial(0, 15): with T.block(): T.reads([A[tx, i + 1]]) - T.writes([B[1 - i % 2, tx, 0]]) - B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + T.writes([B[(i + 1) % 2, tx, 0]]) + B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) with T.block(): T.reads([B[i % 2, tx, 0]]) T.writes([C[tx, i]]) @@ -202,8 +202,8 @@ def transformed_simple_compute_with_other_annotation( ): with T.block(): T.reads([A[tx, i + 1]]) - T.writes([B[1 - i % 2, tx, 0]]) - B[1 - i % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + T.writes([B[(i + 1) % 2, tx, 0]]) + B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) with T.block(): T.reads([B[i % 2, tx, 0]]) T.writes([C[tx, i]]) @@ -266,7 +266,7 @@ def transformed_three_stage_compute( T.where(i == 1) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2) + C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) with T.block(): T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14]) @@ -278,7 +278,7 @@ def transformed_three_stage_compute( with T.block(): T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[1 - i % 2, tx, 0] = B[1 - i % 2, tx, 0] + T.float32(2) + C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) with T.block(): T.reads(C[0:2, tx, 0]) T.writes(D[tx, i]) @@ -291,7 +291,7 @@ def transformed_three_stage_compute( T.where(i < 1) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[1 - i, tx, 0] = B[1 - i, tx, 0] + T.float32(2) + C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) with T.block(): T.reads(C[0:2, tx, 0]) T.writes(D[tx, i + 14]) @@ -391,12 +391,12 @@ def transformed_dag_interleaving( BS[tx, 0] = B[tx, i + 1] + T.float32(2) with T.block(): T.reads(AS[tx, 0]) - T.writes(AL[1 - i % 2, 0, 0]) - AL[1 - i % 2, 0, 0] = AS[tx, 0] + T.writes(AL[(i + 1) % 2, 0, 0]) + AL[(i + 1) % 2, 0, 0] = AS[tx, 0] with T.block(): T.reads(BS[tx, 0]) - T.writes(BL[1 - i % 2, 0, 0]) - BL[1 - i % 2, 0, 0] = BS[tx, 0] + T.writes(BL[(i + 1) % 2, 0, 0]) + BL[(i + 1) % 2, 0, 0] = BS[tx, 0] with T.block(): T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0]) T.writes(C[tx, i]) @@ -475,12 +475,12 @@ def transformed_nested_pipeline_simple( for i in T.serial(0, 15): with T.block(): T.reads([A[tx, i + 1, 0:16]]) - T.writes([A_shared[1 - i % 2, tx, 0, 0:16]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) for j in T.serial(0, 16): with T.block(): T.reads([A[tx, i + 1, j]]) - T.writes([A_shared[1 - i % 2, tx, 0, j]]) - A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j] + T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) + A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] with T.block(): T.reads([A_shared[i % 2, tx, i, 0]]) T.writes([B[0, tx, i, 0]]) @@ -491,10 +491,10 @@ def transformed_nested_pipeline_simple( for j in T.serial(0, 15): with T.block(): T.reads([A_shared[i % 2, tx, i, j + 1]]) - T.writes([B[1 - j % 2, tx, i, 0]]) - B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 0, j + 1] * T.float32( - 2 - ) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_shared[ + i % 2, tx, 0, j + 1 + ] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) @@ -516,8 +516,8 @@ def transformed_nested_pipeline_simple( for j in T.serial(0, 15): with T.block(): T.reads([A_shared[1, tx, 15, j + 1]]) - T.writes([B[1 - j % 2, tx, 15, 0]]) - B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) @@ -603,30 +603,30 @@ def transformed_nested_pipeline_prefetch_inner( for i in T.serial(0, 15): with T.block(): T.reads([A[tx, i + 1, 0:16]]) - T.writes([A_shared[1 - i % 2, tx, 0, 0:16]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) for j in T.serial(0, 16): with T.block(): T.reads([A[tx, i + 1, j]]) - T.writes([A_shared[1 - i % 2, tx, 0, j]]) - A_shared[1 - i % 2, tx, 0, j] = A[tx, i + 1, j] + T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) + A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] with T.block(): T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) for j in T.serial(0, 15): with T.block(): T.reads([A_shared[i % 2, tx, i, j + 1]]) - T.writes([B[1 - j % 2, tx, i, 0]]) - B[1 - j % 2, tx, i, 0] = A_shared[i % 2, tx, 0, j + 1] * T.float32( - 2 - ) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_shared[ + i % 2, tx, 0, j + 1 + ] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) with T.block(): - T.reads([A_shared[1 - i % 2, tx, i + 1, 0]]) + T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]]) T.writes([B[0, tx, i + 1, 0]]) - B[0, tx, i + 1, 0] = A_shared[1 - i % 2, tx, 0, 0] * T.float32(2) + B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2) with T.block(): T.reads([B[1, tx, i, 0]]) T.writes([C[tx, i, 15]]) @@ -640,8 +640,8 @@ def transformed_nested_pipeline_prefetch_inner( for j in T.serial(0, 15): with T.block(): T.reads([A_shared[1, tx, 15, j + 1]]) - T.writes([B[1 - j % 2, tx, 15, 0]]) - B[1 - j % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) @@ -768,8 +768,8 @@ def transformed_nested_pipeline_interleaving( for j in T.serial(0, 15): with T.block(): T.reads([A_local[tx, i, j + 1]]) - T.writes([B[1 - j % 2, tx, i, 0]]) - B[1 - j % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) @@ -799,8 +799,8 @@ def transformed_nested_pipeline_interleaving( for j in T.serial(0, 15): with T.block(): T.reads([A_local[tx, 15, j + 1]]) - T.writes([B[1 - j % 2, tx, 15, 0]]) - B[1 - j % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]]) @@ -929,25 +929,27 @@ def transformed_nested_pipeline_double_buffer( for j in T.serial(0, 15): with T.block(): T.reads([A_local[i % 2, tx, i, j + 1]]) - T.writes([B[1 - j % 2, tx, i, 0]]) - B[1 - j % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32(2) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32( + 2 + ) with T.block(): T.reads([B[j % 2, tx, i, 0]]) T.writes([C[tx, i, j]]) C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) with T.block(): T.reads([A_shared[tx, 0, 0:16]]) - T.writes([A_local[1 - i % 2, 0, 0, 0:16]]) + T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]]) for j in T.serial(0, 16): with T.block(): T.reads([A_shared[tx, 0, j]]) - T.writes([A_local[1 - i % 2, 0, 0, j]]) + T.writes([A_local[(i + 1) % 2, 0, 0, j]]) T.block_attr({"double_buffer_scope": 0}) - A_local[1 - i % 2, 0, 0, j] = A_shared[tx, i + 1, j] + A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j] with T.block(): - T.reads([A_local[1 - i % 2, tx, i + 1, 0]]) + T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]]) T.writes([B[0, tx, i + 1, 0]]) - B[0, tx, i + 1, 0] = A_local[1 - i % 2, 0, 0, 0] * T.float32(2) + B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2) with T.block(): T.reads([B[1, tx, i, 0]]) T.writes([C[tx, i, 15]]) @@ -961,8 +963,8 @@ def transformed_nested_pipeline_double_buffer( for j in T.serial(0, 15): with T.block(): T.reads([A_local[1, tx, 15, j + 1]]) - T.writes([B[1 - j % 2, tx, 15, 0]]) - B[1 - j % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2) with T.block(): T.reads([B[j % 2, tx, 15, 0]]) T.writes([C[tx, 15, j]])