From 851669a2c2b8977f97d705c52c9a40531062d82a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 23 Dec 2024 20:54:55 -0500 Subject: [PATCH 1/8] Split Hopper MMA by warp-tile before instruction tile Fixes #3636 --- csrc/scheduler/hopper_multi_matmul.cpp | 25 ++++++--- tests/cpp/test_matmul.cpp | 73 ++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 8 deletions(-) diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index 2fa0a40ab75..e1095c4289a 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -38,16 +38,25 @@ void HopperMultipleMatmulScheduler::transformLikeMmaOutput( return (is_mma_result) ? idx - 1 : idx; }; - // Original: [..., Mo, No, Mi, Ni] + // The input is originally block tiled so that the inner dims are the CTA tile + // size + // Original: [..., M, N(, K)] + // We split this into warp tiles then instruction tiles + tv->split(apply_k_dim_offset(-2), params_->tile_sizes.warp_tile.m); tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro)); + tv->split(apply_k_dim_offset(-1), params_->tile_sizes.warp_tile.n); tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro)); - // After Split: [..., Mo, No, Mio, Mii, Nio, Nii] - tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}}); - // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii] - tv->merge(apply_k_dim_offset(-4)); - // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii] - tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy); - // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii] + // After Split: [..., Mo, Mw, Mi, No, Nw, Nwi] + tv->reorder({ + {apply_k_dim_offset(-3), apply_k_dim_offset(-5)}, + {apply_k_dim_offset(-2), apply_k_dim_offset(-3)}, + }); + // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni] + + tv->merge(apply_k_dim_offset(-6)); + // After Merge: [..., Mo * No, Mio, Nio, Mii, Nii] + tv->axis(apply_k_dim_offset(-5))->parallelize(ParallelType::TIDy); + // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni] } MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) { diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 9e9395c5e18..99ada8776e3 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -4243,4 +4243,77 @@ TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) { EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); } +// This tests that we can use a small instruction tile with a medium size +// warpgroup tile and a large CTA tile. 
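+// The CTA tile is first split into warp tiles (one warp group per warp tile,
+// bound to TIDy), and each warp tile is then split into instruction tiles, so
+// a single warp group issues several wgmma instructions per K step.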
+TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int64_t M = 2048, N = 2048, K = 8192; + const auto dtype = DataType::Half; + + auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M + auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = fusedMultiplySum(tv0, tv1, {0}); + + // Reorder the accumulator as [M, N, K] + // [K, M, N] -> [M, N, K] + tv2->reorder({{-3, -1}}); + tv2->commitLeafToLogical(); + + auto tv3 = castOp(DataType::Half, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto a_ref = at::randn({K, M, 1}, options); + auto b_ref = at::randn({K, 1, N}, options); + auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf); + + MatMulTileOptions gemm_tile; + // Regardless of the instruction, this should result in 2 warp groups i.e. 256 + // threads + gemm_tile.cta_tile = GemmTile(128, 256, 16); + gemm_tile.warp_tile = GemmTile(64, 256, 16); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_8_16; + mparams.tile_sizes = gemm_tile; + mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {2, 1, 1}; + mparams.promote_prologue_smem_reuse = true; + + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + std::vector inputs = {a_ref, b_ref}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); + EXPECT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); + + auto cg_outputs = ke.run(inputs); + + // Check number of launched threads matches what we expect + EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128); + EXPECT_EQ(ke.lastLaunchParams().bdimy(), 2) + << " expected 2 warp groups (BIDy == 2) but found BIDy==" + << ke.lastLaunchParams().bdimy(); + + // Relax tolerance for larger sum due to large K + EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); +} + } // namespace nvfuser From 8b42cd6b6988611b5dcf3acb17052e1e2d8ee4e4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 31 Dec 2024 08:46:16 -0500 Subject: [PATCH 2/8] Use 4 warpgroups, disable smem epilogue --- tests/cpp/test_matmul.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 99ada8776e3..657b6cc9e95 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -4276,7 +4276,7 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) { // Regardless of the instruction, this should result in 2 warp groups i.e. 
256
   // threads
   gemm_tile.cta_tile = GemmTile(128, 256, 16);
-  gemm_tile.warp_tile = GemmTile(64, 256, 16);
+  gemm_tile.warp_tile = GemmTile(64, 128, 16);
 
   MatmulParams mparams;
   mparams.supported_vec_size = {8, 8, 8};
@@ -4289,9 +4289,12 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
   mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
   mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
   mparams.splitk_factor = 1;
-  mparams.use_smem_epilogue = true;
+  // NOTE: disabling smem use for this test since we currently hit a bank
+  // conflict.
+  // TODO: enable smem epilogue once stmatrix is updated
+  mparams.use_smem_epilogue = false;
   mparams.cluster_dims = {2, 1, 1};
-  mparams.promote_prologue_smem_reuse = true;
+  mparams.promote_prologue_smem_reuse = false;
 
   SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
       ->schedule(&fusion, &mparams);
@@ -4308,8 +4311,8 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
 
   // Check number of launched threads matches what we expect
   EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128);
-  EXPECT_EQ(ke.lastLaunchParams().bdimy(), 2)
-      << " expected 2 warp groups (BIDy == 2) but found BIDy=="
+  EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4)
+      << " expected 4 warp groups (bdimy==4) but found bdimy=="
      << ke.lastLaunchParams().bdimy();
 
   // Relax tolerance for larger sum due to large K

From 521d5ccad1ec1fa10f607d87887d52a1b6147407 Mon Sep 17 00:00:00 2001
From: Jacob Hinkle
Date: Tue, 31 Dec 2024 11:22:05 -0500
Subject: [PATCH 3/8] Use warp_tile for tma_m and tma_n

---
 csrc/scheduler/hopper_multi_matmul.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp
index ba8b35b7dc2..aeb0fbc0325 100644
--- a/csrc/scheduler/hopper_multi_matmul.cpp
+++ b/csrc/scheduler/hopper_multi_matmul.cpp
@@ -499,8 +499,8 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
   // tile is a multiple of the macro size because stmatrix stores results from
   // wgmma to shared memory. For maximum inlining and to reduce shared memory
   // usage, the tma tile is mma_macro size.
-  const int64_t tma_m = getM(params_->mma_macro);
-  const int64_t tma_n = getN(params_->mma_macro);
+  const int64_t tma_m = params_->tile_sizes.warp_tile.m;
+  const int64_t tma_n = params_->tile_sizes.warp_tile.n;
 
   fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
   fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);

From dce16ad9da1dfd566131cbfeab947e54da815ca0 Mon Sep 17 00:00:00 2001
From: Jacob Hinkle
Date: Thu, 2 Jan 2025 09:56:52 -0500
Subject: [PATCH 4/8] Two warp tiles per CTA in each dim, increase instr to
 64_64_16

---
 tests/cpp/test_matmul.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp
index 657b6cc9e95..521913772f1 100644
--- a/tests/cpp/test_matmul.cpp
+++ b/tests/cpp/test_matmul.cpp
@@ -4275,18 +4275,18 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
   MatMulTileOptions gemm_tile;
   // Regardless of the instruction, this should result in 2 warp groups i.e.
256 // threads - gemm_tile.cta_tile = GemmTile(128, 256, 16); - gemm_tile.warp_tile = GemmTile(64, 128, 16); + gemm_tile.cta_tile = GemmTile(256, 256, 16); + gemm_tile.warp_tile = GemmTile(128, 128, 16); MatmulParams mparams; mparams.supported_vec_size = {8, 8, 8}; - mparams.mma_macro = MmaMacro::Hopper_64_8_16; + mparams.mma_macro = MmaMacro::Hopper_64_64_16; mparams.tile_sizes = gemm_tile; mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; mparams.async_gmem_load_operands = true; mparams.circular_buffer_options.circular_buffer_smem_write = true; mparams.circular_buffer_options.circular_buffer_smem_read = false; - mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.circular_buffer_options.smem_circular_buffer_stage = 2; mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; mparams.splitk_factor = 1; // NOTE: disabling smem use for this test since we currrently hit a bank From f5e084c8ae62227b9f54876ff525a107ff2f64de Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 2 Jan 2025 10:32:45 -0500 Subject: [PATCH 5/8] Also split by K I think this covers the motivation for #3616 --- csrc/scheduler/hopper_multi_matmul.cpp | 62 ++++++++++++++++++-------- tests/cpp/test_matmul.cpp | 6 +-- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index aeb0fbc0325..022148d4a84 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -34,29 +34,53 @@ void HopperMultipleMatmulScheduler::transformLikeMmaOutput( bool is_mma_result) { // TODO Add constraints - auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr { - return (is_mma_result) ? idx - 1 : idx; - }; - // The input is originally block tiled so that the inner dims are the CTA tile // size // Original: [..., M, N(, K)] // We split this into warp tiles then instruction tiles - tv->split(apply_k_dim_offset(-2), params_->tile_sizes.warp_tile.m); - tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro)); - tv->split(apply_k_dim_offset(-1), params_->tile_sizes.warp_tile.n); - tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro)); - // After Split: [..., Mo, Mw, Mi, No, Nw, Nwi] - tv->reorder({ - {apply_k_dim_offset(-3), apply_k_dim_offset(-5)}, - {apply_k_dim_offset(-2), apply_k_dim_offset(-3)}, - }); - // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni] - - tv->merge(apply_k_dim_offset(-6)); - // After Merge: [..., Mo * No, Mio, Nio, Mii, Nii] - tv->axis(apply_k_dim_offset(-5))->parallelize(ParallelType::TIDy); - // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni] + if (is_mma_result) { + // Original: [..., M, N, K] + tv->split(-3, params_->tile_sizes.warp_tile.m); + tv->split(-3, getM(params_->mma_macro)); + tv->split(-2, params_->tile_sizes.warp_tile.n); + tv->split(-2, getN(params_->mma_macro)); + // K dimension is present for mma_result + tv->split(-1, params_->tile_sizes.warp_tile.k); + tv->split(-1, getK(params_->mma_macro)); + // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Ko, Kw, Ki] + tv->reorder({ + {-9, -9}, // Mo + {-8, -6}, // Mw + {-7, -3}, // Mi + {-6, -8}, // No + {-5, -5}, // Nw + {-4, -2}, // Ni + {-3, -7}, // Ko + {-2, -4}, // Kw + {-1, -1}, // Ki + }); + // After Reorder: [..., Mo, No, Ko, Mw, Nw, Kw, Mi, Ni, Ki] + tv->merge(-9); + // After Merge: [..., Mo * No, Ko, Mw, Nw, Kw, Mi, Ni] + tv->axis(-8)->parallelize(ParallelType::TIDy); + // After Parallelize: [..., Mo * No (TIDy), Ko, Mw, Nw, Kw, Mi, Ni, Ki] + } else { 
+ // Original: [..., M, N] + tv->split(-2, params_->tile_sizes.warp_tile.m); + tv->split(-2, getM(params_->mma_macro)); + tv->split(-1, params_->tile_sizes.warp_tile.n); + tv->split(-1, getN(params_->mma_macro)); + // After Split: [..., Mo, Mw, Mi, No, Nw, Ni] + tv->reorder({ + {-3, -5}, + {-2, -3}, + }); + // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni] + tv->merge(-6); + // After Merge: [..., Mo * No, Mw, Nw, Mi, Ni] + tv->axis(-5)->parallelize(ParallelType::TIDy); + // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni] + } } MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) { diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 521913772f1..2378eed5599 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -4275,8 +4275,8 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) { MatMulTileOptions gemm_tile; // Regardless of the instruction, this should result in 2 warp groups i.e. 256 // threads - gemm_tile.cta_tile = GemmTile(256, 256, 16); - gemm_tile.warp_tile = GemmTile(128, 128, 16); + gemm_tile.cta_tile = GemmTile(256, 256, 32); + gemm_tile.warp_tile = GemmTile(128, 128, 32); MatmulParams mparams; mparams.supported_vec_size = {8, 8, 8}; @@ -4286,7 +4286,7 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) { mparams.async_gmem_load_operands = true; mparams.circular_buffer_options.circular_buffer_smem_write = true; mparams.circular_buffer_options.circular_buffer_smem_read = false; - mparams.circular_buffer_options.smem_circular_buffer_stage = 2; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; mparams.splitk_factor = 1; // NOTE: disabling smem use for this test since we currrently hit a bank From be705bf24f24ee6b8a9e6f5f1b90c9af5e560220 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 7 Jan 2025 10:34:30 -0500 Subject: [PATCH 6/8] Add ScheduleWithTranslation test (failing) --- tests/cpp/test_matmul.cpp | 60 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 2378eed5599..b1aaf475059 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -4319,4 +4319,64 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) { EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); } +TEST_F(HopperMatmulTest, ScheduleWithTranslation) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int64_t M = 2048, N = 2048, K = 8192; + const auto dtype = DataType::Half; + + auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K + auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // K, N + // Note tv1 has allocation domain + // tv1->setAllocationDomain({tv1->axis(1), tv1->axis(0)}, true); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = matmul(tv0, tv1); + + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto a_ref = at::randn({M, K}, options); + // auto b_ref = at::randn({N, K}, options).t(); + auto b_ref = at::randn({K, N}, options); + auto out_ref = at::matmul(a_ref, b_ref); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 256, 16); + gemm_tile.warp_tile = GemmTile(64, 64, 16); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_64_16; + mparams.tile_sizes = gemm_tile; + 
mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 3; + mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {1, 1, 1}; + mparams.promote_prologue_smem_reuse = true; + + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + std::vector inputs = {a_ref, b_ref}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); + EXPECT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); + + auto cg_outputs = ke.run(inputs); + + // Relax tolerance for larger sum due to large K + EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); +} + } // namespace nvfuser From 5246fb3b119722af338c16fcbddd450a479f7122 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 28 Jan 2025 08:19:54 -0500 Subject: [PATCH 7/8] Update to fix compilation --- tests/cpp/test_matmul.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index d021b2831a1..d8bb60f8421 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -4516,9 +4516,10 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) { KernelExecutor ke; ke.compile(&fusion, inputs); - EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); - EXPECT_FALSE( - PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); + kir::Kernel* kernel = ke.compiledKernel()->kernel(); + ASSERT_TRUE(kernel != nullptr); + EXPECT_TRUE(getBankConflictInfo(kernel).empty()); + EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel)); auto cg_outputs = ke.run(inputs); @@ -4582,9 +4583,10 @@ TEST_F(HopperMatmulTest, ScheduleWithTranslation) { KernelExecutor ke; ke.compile(&fusion, inputs); - EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); - EXPECT_FALSE( - PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); + kir::Kernel* kernel = ke.compiledKernel()->kernel(); + ASSERT_TRUE(kernel != nullptr); + EXPECT_TRUE(getBankConflictInfo(kernel).empty()); + EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel)); auto cg_outputs = ke.run(inputs); From 1dccf222fde821e1e8872a8c98244fb5383784e0 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 28 Jan 2025 16:21:49 -0500 Subject: [PATCH 8/8] Don't do K split. 
Fix TMA offset --- csrc/scheduler/hopper_multi_matmul.cpp | 105 +++++++++++++------------ csrc/scheduler/hopper_multi_matmul.h | 7 +- csrc/scheduler/mma_utils.cpp | 4 +- 3 files changed, 62 insertions(+), 54 deletions(-) diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index 022148d4a84..078d9380bc6 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -29,58 +29,60 @@ namespace nvfuser { -void HopperMultipleMatmulScheduler::transformLikeMmaOutput( - TensorView* tv, - bool is_mma_result) { +void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK( + TensorView* tv) { + // The input is originally block tiled so that the inner dims are the CTA tile + // size + // + // We split this into warp tiles then instruction tiles + // Original: [..., M, N, K] + tv->split(-3, params_->tile_sizes.warp_tile.m); + tv->split(-3, getM(params_->mma_macro)); + tv->split(-2, params_->tile_sizes.warp_tile.n); + tv->split(-2, getN(params_->mma_macro)); + // K dimension is present for mma_result + // We don't need to split by warp_tile.k, since we always have cta_tile.k==warp_tile.k + tv->split(-1, getK(params_->mma_macro)); + // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki] + tv->reorder({ + {-8, -8}, // Mo + {-7, -6}, // Mw + {-6, -3}, // Mi + {-5, -7}, // No + {-4, -5}, // Nw + {-3, -2}, // Ni + {-2, -4}, // Kw + {-1, -1}, // Ki + }); + // After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki] + tv->merge(-8); + // After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni] + tv->axis(-7)->parallelize(ParallelType::TIDy); + // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki] +} + +void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithoutK( + TensorView* tv) { // TODO Add constraints // The input is originally block tiled so that the inner dims are the CTA tile // size - // Original: [..., M, N(, K)] + // Original: [..., M, N] // We split this into warp tiles then instruction tiles - if (is_mma_result) { - // Original: [..., M, N, K] - tv->split(-3, params_->tile_sizes.warp_tile.m); - tv->split(-3, getM(params_->mma_macro)); - tv->split(-2, params_->tile_sizes.warp_tile.n); - tv->split(-2, getN(params_->mma_macro)); - // K dimension is present for mma_result - tv->split(-1, params_->tile_sizes.warp_tile.k); - tv->split(-1, getK(params_->mma_macro)); - // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Ko, Kw, Ki] - tv->reorder({ - {-9, -9}, // Mo - {-8, -6}, // Mw - {-7, -3}, // Mi - {-6, -8}, // No - {-5, -5}, // Nw - {-4, -2}, // Ni - {-3, -7}, // Ko - {-2, -4}, // Kw - {-1, -1}, // Ki - }); - // After Reorder: [..., Mo, No, Ko, Mw, Nw, Kw, Mi, Ni, Ki] - tv->merge(-9); - // After Merge: [..., Mo * No, Ko, Mw, Nw, Kw, Mi, Ni] - tv->axis(-8)->parallelize(ParallelType::TIDy); - // After Parallelize: [..., Mo * No (TIDy), Ko, Mw, Nw, Kw, Mi, Ni, Ki] - } else { - // Original: [..., M, N] - tv->split(-2, params_->tile_sizes.warp_tile.m); - tv->split(-2, getM(params_->mma_macro)); - tv->split(-1, params_->tile_sizes.warp_tile.n); - tv->split(-1, getN(params_->mma_macro)); - // After Split: [..., Mo, Mw, Mi, No, Nw, Ni] - tv->reorder({ - {-3, -5}, - {-2, -3}, - }); - // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni] - tv->merge(-6); - // After Merge: [..., Mo * No, Mw, Nw, Mi, Ni] - tv->axis(-5)->parallelize(ParallelType::TIDy); - // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni] - } + tv->split(-2, params_->tile_sizes.warp_tile.m); + tv->split(-2, getM(params_->mma_macro)); + tv->split(-1, 
params_->tile_sizes.warp_tile.n); + tv->split(-1, getN(params_->mma_macro)); + // After Split: [..., Mo, Mw, Mi, No, Nw, Ni] + tv->reorder({ + {-3, -5}, + {-2, -3}, + }); + // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni] + tv->merge(-6); + // After Merge: [..., Mo * No, Mw, Nw, Mi, Ni] + tv->axis(-5)->parallelize(ParallelType::TIDy); + // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni] } MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) { @@ -370,6 +372,7 @@ void HopperMultipleMatmulScheduler::scheduleOperands() { const std::vector& smem_operands, MmaOperand operand_type) { blockTileTensors(smem_operands); + parallelizeBlocks(smem_operands); for (TensorView* tv : smem_operands) { if (params_->promote_prologue_smem_reuse) { tv->promoteReuse(); @@ -457,7 +460,7 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() { splitk_sums_.push_back(splitk_sum); } - transformLikeMmaOutput(mma_result, /*is_mma_result=*/true); + transformLikeMmaOutputWithK(mma_result); auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( mma_result->getLoopDomain()); mma_result->setAllocationDomain(s.as(), true); @@ -492,7 +495,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { // op. blockTileTensors({d}); parallelizeBlocks({d}); - transformLikeMmaOutput(d, /*is_mma_result=*/false); + transformLikeMmaOutputWithoutK(d); auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( d->getLoopDomain()); @@ -572,7 +575,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { blockTileTensors(tvs_to_schedule); parallelizeBlocks(tvs_to_schedule); for (auto tv : tvs_to_schedule) { - transformLikeMmaOutput(tv, /*is_mma_result=*/false); + transformLikeMmaOutputWithoutK(tv); } // Should not propagate if the dc is a mma output as the mma output has @@ -623,7 +626,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() { for (TensorView* splitk_sum : splitk_sums_) { // Always use serial grid reduction for split-K sum splitk_sum->definition()->as()->requestSerialGridReduction(); - transformLikeMmaOutput(splitk_sum, /*is_mma_result=*/false); + transformLikeMmaOutputWithoutK(splitk_sum); auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( splitk_sum->getLoopDomain()); splitk_sum->setLoopDomain(s.as()); diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index 295b55ee96e..8763789c38c 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -187,7 +187,12 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { // Schedule a block-tiled TensorView like mma output. // Why? WGMMA has a unique output format. TensorViews after the mma-result in // registers must respect this format for correctness. - void transformLikeMmaOutput(TensorView* tv, bool is_mma_result); + // This version is meant to be used on the mma_result, which has a Reduction + // K axis. 
+ void transformLikeMmaOutputWithK(TensorView* tv); + + // This is like the above method, but tv should not have any K dimension + void transformLikeMmaOutputWithoutK(TensorView* tv); private: std::vector canonical_dim_ordering_; diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index b7b53738075..78e8d89aae4 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1321,7 +1321,7 @@ void scheduleStMatrixForMmaOutput( // [128(TIDx), 4(n), 2, 2] -> [128(TIDx), 2(no), 2(ni), 2, 2] tv->split(-3, 2); // [128(TIDx), 2(no), 2(ni), 2, 2] -> [2(no), 128(TIDx), 2(ni), 2, 2] - tv->reorder({{-4, 0}}); + tv->reorder({{-4, -5}}); // [128(TIDx), 2(no), 2(ni), 2, 2] -> [2(no), 128(TIDx), 8 (vectorize)] tv->merge(-3); tv->merge(-2); @@ -1329,7 +1329,7 @@ void scheduleStMatrixForMmaOutput( // Let [M, N] be [64, 16] // After scheduleMmaOutputAllocation: [128(TIDx), 2, 2, 2] // [128(TIDx), 2, 2, 2] -> [2, 128(TIDx), 2, 2] - tv->reorder({{-3, 0}}); + tv->reorder({{-3, -4}}); // [2, 128(TIDx), 2, 2] -> [2, 128(TIDx), 4(vectorize)] tv->merge(-2); }
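
The sketch below is an illustrative aid only; it is not part of the patch series and uses no nvFuser API. It replays the split/reorder/merge bookkeeping that transformLikeMmaOutputWithoutK performs, on plain labeled extents, using the 256x256 CTA tile, 128x128 warp tile, and Hopper_64_64_16 macro from the HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile test above. The Dim, split, reorder, and merge helpers are hypothetical stand-ins for the TensorView scheduling calls.

// Standalone sketch (not nvFuser code): mimic the warp-tile-then-instruction-
// tile decomposition of transformLikeMmaOutputWithoutK on labeled extents.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using Dim = std::pair<std::string, int64_t>; // (label, extent)

namespace {

int64_t wrap(int64_t pos, int64_t n) {
  return pos < 0 ? pos + n : pos;
}

// Split dims[pos] by factor; the factor becomes the inner extent, matching the
// tv->split(axis, factor) calls in the patch.
void split(std::vector<Dim>& dims, int64_t pos, int64_t factor,
           const std::string& outer, const std::string& inner) {
  int64_t i = wrap(pos, static_cast<int64_t>(dims.size()));
  int64_t extent = dims[i].second;
  dims[i] = {outer, extent / factor};
  dims.insert(dims.begin() + i + 1, {inner, factor});
}

// Reorder by (old_pos -> new_pos) pairs; unmentioned dims keep their relative
// order, mirroring TensorView::reorder.
void reorder(std::vector<Dim>& dims,
             const std::vector<std::pair<int64_t, int64_t>>& old2new) {
  const int64_t n = static_cast<int64_t>(dims.size());
  std::vector<Dim> result(n);
  std::vector<bool> old_used(n, false), new_used(n, false);
  for (const auto& [old_pos, new_pos] : old2new) {
    result[wrap(new_pos, n)] = dims[wrap(old_pos, n)];
    old_used[wrap(old_pos, n)] = true;
    new_used[wrap(new_pos, n)] = true;
  }
  int64_t slot = 0;
  for (int64_t o = 0; o < n; ++o) {
    if (old_used[o]) {
      continue;
    }
    while (new_used[slot]) {
      ++slot;
    }
    result[slot++] = dims[o];
  }
  dims = result;
}

// Merge dims[pos] with dims[pos + 1] (outer * inner), like TensorView::merge.
void merge(std::vector<Dim>& dims, int64_t pos) {
  int64_t i = wrap(pos, static_cast<int64_t>(dims.size()));
  dims[i] = {dims[i].first + "*" + dims[i + 1].first,
             dims[i].second * dims[i + 1].second};
  dims.erase(dims.begin() + i + 1);
}

} // namespace

int main() {
  // Block-tiled output: inner dims are the 256x256 CTA tile.
  std::vector<Dim> dims = {{"M", 256}, {"N", 256}};

  split(dims, -2, 128, "Mo", "Mw"); // split M by warp_tile.m
  split(dims, -2, 64, "Mw", "Mi");  // split warp tile by getM(macro)
  split(dims, -1, 128, "No", "Nw"); // split N by warp_tile.n
  split(dims, -1, 64, "Nw", "Ni");  // split warp tile by getN(macro)
  // [Mo, Mw, Mi, No, Nw, Ni]
  reorder(dims, {{-3, -5}, {-2, -3}});
  // [Mo, No, Mw, Nw, Mi, Ni]
  merge(dims, -6);
  // [Mo*No, Mw, Nw, Mi, Ni]; Mo*No is the axis bound to TIDy.

  for (const auto& [name, extent] : dims) {
    std::cout << name << " = " << extent << "\n";
  }
  // Prints Mo*No = 4, Mw = 2, Nw = 2, Mi = 64, Ni = 64: four warp groups on
  // TIDy, each iterating a 2x2 grid of 64x64 instruction tiles.
  return 0;
}

Splitting by the warp tile before the instruction tile is what makes the merged Mo*No axis equal the number of warp tiles per CTA (here 2 * 2 = 4), so binding it to TIDy launches four warp groups, which is what the test checks via lastLaunchParams().bdimy().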