Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -929,8 +929,8 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs,
for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
substitute_value.Set(i, is_one(loops[i]->extent)
? 0
: floordiv(floormod(fused_var, lower * loops[i]->extent), lower));
lower = lower * loops[i]->extent;
: floordiv(floormod(fused_var, loops[i]->extent * lower), lower));
lower = loops[i]->extent * lower;
}
substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower));
Stmt new_stmt = loops.back()->body;
Expand All @@ -947,7 +947,7 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array<StmtSRef>& loop_srefs,
SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt));
// Step 3. Generate a loop to replace the original loops
PrimExpr fused_extent = 1;
for (int i = 0; i < n; i++) {
for (int i = n - 1; i >= 0; --i) {
fused_extent *= loops[i]->extent;
}
fused_extent = analyzer.Simplify(fused_extent);
Expand Down
33 changes: 33 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,5 +822,38 @@ def before(a: T.handle):
assert warning_msg in captured



def test_fused_symbolic_2D_tiling():
@T.prim_func
def before(a: T.handle, b: T.handle, M: T.int32, N: T.int32) -> None:
A = T.match_buffer(a, (M, N))
B = T.match_buffer(b, (M, N))
for i, j in T.grid(M, N):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0

@T.prim_func
def expected(a: T.handle, b: T.handle, M: T.int32, N: T.int32):
A = T.match_buffer(a, (M, N))
B = T.match_buffer(b, (M, N))
for i_0_j_0_fused, i_1, j_1 in T.grid((N + 15) // 16 * ((M + 63) // 64), 64, 16):
with T.block("B"):
vi = T.axis.spatial(M, i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1)
vj = T.axis.spatial(N, i_0_j_0_fused % ((N + 15) // 16) * 16 + j_1)
T.where(i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1 < M and i_0_j_0_fused % ((N + 15) // 16) * 16 + j_1 < N)
B[vi, vj] = A[vi, vj] * T.float32(2.0)

sch = tir.Schedule(before, debug_mask="all")
block_b = sch.get_block("B")
i, j = sch.get_loops(block_b)
i0, i1 = sch.split(i, factors=[None, 64])
j0, j1 = sch.split(j, factors=[None, 16])
sch.reorder(i0, j0, i1, j1)
sch.fuse(i0, j0)
assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=before)


if __name__ == "__main__":
tvm.testing.main()
Loading