diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index e05cf92c..42070213 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -37,7 +37,7 @@ static void unload_library(const LibraryHandle& library) { static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, const cudaStream_t& stream, const int& smem_size, - const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + const dim3& grid_dim, const dim3& block_dim, const dim3& cluster_dim) { if (smem_size > 0) DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -51,12 +51,10 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, // NOTES: must use `static` or the `attr` will be deconstructed static LaunchAttrHandle attr; - if (cluster_dim > 1) { - attr.id = cudaLaunchAttributeClusterDimension; - attr.val.clusterDim = {static_cast(cluster_dim), 1, 1}; - config.attrs = &attr; - config.numAttrs = 1; - } + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {cluster_dim.x, cluster_dim.y, cluster_dim.z}; + config.attrs = &attr; + config.numAttrs = 1; return config; } @@ -95,7 +93,7 @@ static void unload_library(const LibraryHandle& library) { static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, const cudaStream_t& stream, const int& smem_size, - const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + const dim3& grid_dim, const dim3& block_dim, const dim3& cluster_dim) { if (smem_size > 0) DG_CUDA_DRIVER_CHECK(cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size)); @@ -113,14 +111,12 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, // NOTES: must use `static` or the `attr` will be deconstructed static LaunchAttrHandle attr; - if (cluster_dim > 1) { - attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - attr.value.clusterDim.x = cluster_dim; - attr.value.clusterDim.y = 1; - attr.value.clusterDim.z = 1; - config.attrs = &attr; - config.numAttrs = 1; - } + attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + attr.value.clusterDim.x = cluster_dim.x; + attr.value.clusterDim.y = cluster_dim.y; + attr.value.clusterDim.z = cluster_dim.z; + config.attrs = &attr; + config.numAttrs = 1; return config; } diff --git a/csrc/jit/kernel_runtime.hpp b/csrc/jit/kernel_runtime.hpp index ba66eeb8..4a5b543a 100644 --- a/csrc/jit/kernel_runtime.hpp +++ b/csrc/jit/kernel_runtime.hpp @@ -12,13 +12,13 @@ struct LaunchArgs { std::pair grid_dim; int num_threads; int smem_size; - int cluster_dim; + std::pair cluster_dim; // (tma_multicast_dim, split_k_dim) - LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): - grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} + LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim_x = 1, const int& cluster_dim_y = 1): + grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim({cluster_dim_x, cluster_dim_y}) {} - LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): - grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} + LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim_x = 1, const int& cluster_dim_y = 1): + grid_dim(grid_dim), 
num_threads(num_threads), smem_size(smem_size), cluster_dim({cluster_dim_x, cluster_dim_y}) {} }; class KernelRuntime final { @@ -102,14 +102,17 @@ class LaunchRuntime { static_cast(launch_args.grid_dim.second), 1}; const dim3& block_dim = {static_cast(launch_args.num_threads), 1, 1}; + const dim3& cluster_dim = {static_cast(launch_args.cluster_dim.first), + static_cast(launch_args.cluster_dim.second), + 1}; auto config = construct_launch_config(kernel, stream, launch_args.smem_size, - grid_dim, block_dim, launch_args.cluster_dim); + grid_dim, block_dim, cluster_dim); // Launch in the derived class if (get_env("DG_JIT_DEBUG")) { - printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n", + printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: {%d, %d}, stream: %ld\n", launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads, - launch_args.smem_size, launch_args.cluster_dim, stream.id()); + launch_args.smem_size, launch_args.cluster_dim.first, launch_args.cluster_dim.second, stream.id()); } Derived::launch_impl(kernel, config, args); } diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index 455223bc..0c5a181b 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -65,6 +65,8 @@ struct GemmConfig { int block_m, block_n, block_k; int num_stages, num_last_stages; + int k_slices; + // Templated device configs int num_sms; int tc_util; @@ -99,7 +101,8 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne const int& block_m, const int& block_n, const int& block_k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, - const int& num_stages, const MulticastConfig& multicast_config) { + const int& num_stages, const MulticastConfig& multicast_config, + const bool split_k) { const int& ab_elem_size = static_cast(c10::elementSize(ab_dtype)); const int& cd_elem_size = static_cast(c10::elementSize(cd_dtype)); @@ -110,7 +113,7 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0; // Different archs have different epilogue pipelines - const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype); + const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, split_k ? torch::kFloat : cd_dtype); // A/B shared memory const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size; @@ -122,7 +125,7 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k); // M-barriers and tensor memory pointers - const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages); + const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages + (split_k ? 
1 : 0)); const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size(); const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type); @@ -146,6 +149,25 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne }; } +template +static void get_k_slices_options(const GemmType& gemm_type, const int k, const int block_k, std::vector& k_slices_options) { + const auto aligned_k = align(k, block_k); + int max_k_slices = 1; + if (ArchSpec::support_split_k() and (gemm_type != GemmType::KGroupedContiguous)) { + // max potential cluster size for sm90 is 8, queried by cudaOccupancyMaxPotentialClusterSize() + max_k_slices = 8; + } + for (int k_slices = 1; k_slices <= max_k_slices; k_slices *= 2) { + if (aligned_k % k_slices != 0) { + continue; + } + if ((aligned_k / k_slices) % block_k != 0) { + continue; + } + k_slices_options.push_back(k_slices); + } +} + template static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type, const int& m, const int& n, const int& k, const int& num_groups, @@ -170,51 +192,71 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const auto& get_num_blocks = [=](const int& block_m, const int& block_n) { return ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups; }; - const auto& get_num_waves = [=](const int& block_m, const int& block_n) { - return ceil_div(get_num_blocks(block_m, block_n), num_sms); + const auto& get_num_waves = [=](const int& block_m, const int& block_n, const int& k_slices, const int& num_sms_available) { + return ceil_div(k_slices * get_num_blocks(block_m, block_n), num_sms_available); }; - const auto& get_last_wave_util = [=](const int& block_m, const int& block_n) { - const auto& num_last_blocks = get_num_blocks(block_m, block_n) % num_sms; - return num_last_blocks == 0 ? num_sms : num_last_blocks; + // block padding should not be taken into account for last_wave_util + const auto& get_block_padding_cost = [=](const int& block_m, const int& block_n, const int& k_slices) { + return ceil_div((block_m - m % block_m) % block_m * k_slices * num_groups, block_m) * ceil_div(n, block_n) + + ceil_div((block_n - n % block_n) % block_n * k_slices * num_groups, block_n) * ceil_div(m, block_m); }; + const auto& get_last_wave_util = [=](const int& block_m, const int& block_n, const int& k_slices, const int& num_sms_available) { + const auto& num_last_blocks = (k_slices * get_num_blocks(block_m, block_n)) % num_sms_available; + return (num_last_blocks == 0 ? 
num_sms_available : num_last_blocks) - get_block_padding_cost(block_m, block_n, k_slices); + }; + + std::vector k_slices_options; + get_k_slices_options(gemm_type, k, block_k, k_slices_options); + int best_k_slices = 1; // Decide block sizes by waves int best_block_m = 0, best_block_n = 0; int best_num_waves = 0, best_last_util = 0; for (const auto& block_m: block_ms) { for (const auto& block_n: block_ns) { - const int& num_waves = get_num_waves(block_m, block_n); - const auto& last_util = get_last_wave_util(block_m, block_n); - if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n, block_k)) - continue; - - bool success = false; - if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) { - success = true; - } else if (num_waves == best_num_waves) { - // Check last wave utilization - success = last_util > best_last_util; - if (last_util == best_last_util) { - // Case 1: same `block_m`, smaller `block_n` (wasted) - success |= block_m == best_block_m and block_n < best_block_n; - // Case 2: same `block_n`, smaller `block_m` (wasted) - success |= block_n == best_block_n and block_m < best_block_m; - // Case 3: different for both `block_m` and `block_n`, larger `block_n` is better - // NOTES: don't pick `block_m/block_n` larger than shape `m/n` in this case - success |= block_m != best_block_m and block_n > best_block_n - and block_n <= n and block_m <= m; + for (const auto& k_slices: k_slices_options) { + // Number of concurrently available SMs is limited by cluster size. + int num_sms_available = ArchSpec::get_num_sms_available_by_cluster_size(k_slices, num_sms); + const int& num_waves = get_num_waves(block_m, block_n, k_slices, num_sms_available); + const auto& last_util = get_last_wave_util(block_m, block_n, k_slices, num_sms_available); + if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n, block_k)) + continue; + + bool success = false; + if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) { + success = true; + } else if (num_waves == best_num_waves) { + // Check last wave utilization + success = last_util > best_last_util; + if (last_util == best_last_util) { + // smaller k_slices is better + success = k_slices < best_k_slices; + if (k_slices == best_k_slices) { + // Case 1: same `block_m`, smaller `block_n` (wasted) + success |= block_m == best_block_m and block_n < best_block_n; + // Case 2: same `block_n`, smaller `block_m` (wasted) + success |= block_n == best_block_n and block_m < best_block_m; + // Case 3: different for both `block_m` and `block_n`, larger `block_n` is better + // NOTES: don't pick `block_m/block_n` larger than shape `m/n` in this case + success |= block_m != best_block_m and block_n > best_block_n + and block_n <= n and block_m <= m; + } + } } - } - // Replace with the new config if successful - if (success) { - best_block_m = block_m, best_block_n = block_n; - best_num_waves = num_waves, best_last_util = last_util; + // Replace with the new config if successful + if (success) { + best_block_m = block_m, best_block_n = block_n; + best_num_waves = num_waves, best_last_util = last_util; + best_k_slices = k_slices; + } } } } DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0); + int k_partitioned = k / best_k_slices; + // Decide the number of TMA multicasts and whether broadcast on A MulticastConfig best_multicast_config = {1, false}; const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( @@ 
-223,10 +265,14 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k bool order[2] = {false, true}; if (best_block_m > best_block_n) std::swap(order[0], order[1]); - for (const bool& is_multicast_on_a: order) { - if (m >= 512 and is_legal[static_cast(is_multicast_on_a)]) { - best_multicast_config = {2, is_multicast_on_a}; - break; + // limited by cluster size + if (best_k_slices <= 1 or + best_k_slices * get_num_blocks(best_block_m, best_block_n) <= ArchSpec::get_num_sms_available_by_cluster_size(best_k_slices * 2, num_sms)) { + for (const bool& is_multicast_on_a: order) { + if (m >= 512 and is_legal[static_cast(is_multicast_on_a)]) { + best_multicast_config = {2, is_multicast_on_a}; + break; + } } } @@ -234,16 +280,18 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k constexpr int smem_capacity = ArchSpec::smem_capacity; int best_num_stages = 0; SharedMemoryConfig best_smem_config; + for (int num_stages = 12; num_stages > 0; -- num_stages) { if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k)) continue; best_smem_config = get_smem_config(gemm_type, kernel_type, - m, n, k, + m, n, k_partitioned, best_block_m, best_block_n, block_k, major_a, major_b, ab_dtype, cd_dtype, - num_stages, best_multicast_config); + num_stages, best_multicast_config, + best_k_slices > 1); if (best_smem_config.smem_size <= smem_capacity) { best_num_stages = num_stages; break; @@ -251,13 +299,19 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k } DG_HOST_ASSERT(best_num_stages != 0); + int num_min_sms = ArchSpec::get_num_sms_available_by_cluster_size(best_k_slices * best_multicast_config.num_multicast, num_sms); // Recompute the minimal number of SMs required // NOTES: less L2 cache usage and less GPU frequency drop - int num_min_sms = num_sms; if (ArchSpec::should_minimize_num_sms()) { - num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves); - num_min_sms = align(num_min_sms, best_multicast_config.num_multicast); - DG_HOST_ASSERT(num_min_sms <= num_sms); + if (best_k_slices > 1) { + int num_partitioned_blocks = best_k_slices * get_num_blocks(best_block_m, best_block_n); + DG_HOST_ASSERT(num_partitioned_blocks <= num_min_sms); + num_min_sms = num_partitioned_blocks; + } else { + num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves); + num_min_sms = align(num_min_sms, best_multicast_config.num_multicast); + DG_HOST_ASSERT(num_min_sms <= num_sms); + } } const auto& config = GemmConfig { @@ -272,7 +326,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k .block_n = best_block_n, .block_k = block_k, .num_stages = best_num_stages, - .num_last_stages = ceil_div(k, block_k) % best_num_stages, + .num_last_stages = ceil_div(k_partitioned, block_k) % best_num_stages, + .k_slices = best_k_slices, .num_sms = num_min_sms, .tc_util = device_runtime->get_tc_util(), .multicast_config = best_multicast_config, @@ -291,12 +346,12 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k ab_dtype, cd_dtype, with_accumulation, num_sms); static std::set printed; if (printed.count(key) == 0) { - printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, " + printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, k_slices: %d, groups: %d, " "A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: 
%d, " "SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, " "SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, " "swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n", - static_cast(gemm_type), static_cast(kernel_type), m, n, k, num_groups, + static_cast(gemm_type), static_cast(kernel_type), m, n, k, best_k_slices, num_groups, static_cast(major_a), static_cast(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype), static_cast(with_accumulation), num_sms, best_block_m, best_block_n, block_k, best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast, diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index e62a13cc..704a155b 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -99,6 +99,15 @@ struct SM100ArchSpec { return false; } + // Since split-k optimization is not yet supported on sm100, we directly return device_num_sms here to make it compatible. + static int get_num_sms_available_by_cluster_size(int cluster_size, int device_num_sms) { + return device_num_sms; + } + + static bool support_split_k() { + return false; + } + static std::pair get_multicast_legality(const GemmType& gemm_type, const int& num_groups, const int& m, const int& n, const int& block_m, const int& block_n, const int& num_sms) { diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp index 133e2da0..f7431352 100644 --- a/csrc/jit_kernels/heuristics/sm90.hpp +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -5,11 +5,13 @@ #include #include "common.hpp" +#include "../../jit/device_runtime.hpp" namespace deep_gemm { struct SM90ArchSpec { static constexpr int smem_capacity = 232448; + static constexpr int num_sms_available_by_cluster_size_h20[8] = {78, 78, 54, 60, 50, 48, 49, 56}; static std::vector get_block_n_candidates(const at::ScalarType& cd_dtype) { // Avoid bank conflicts for FP32 output @@ -84,6 +86,30 @@ struct SM90ArchSpec { return true; } + // Assuming one CTA per SM, max concurrently available number of SMs is limited by cluster size. + // To support other sm90 GPUs than H20, a num_sms_available_by_cluster_size array should be added. + // Otherwise, only cluster_size <= 2 will be selected to make good use of all the SMs. + static int get_num_sms_available_by_cluster_size(int cluster_size, int device_num_sms) { + if (cluster_size > 8 or cluster_size <= 1) { + return device_num_sms; + } else if (cluster_size == 2) { + return (device_num_sms / cluster_size) * cluster_size; + } + int result = (device_num_sms / cluster_size) * cluster_size; + int num_all_sms = device_runtime->get_num_sms(); + if (num_all_sms == 78) { // H20 + result = std::min(result, num_sms_available_by_cluster_size_h20[cluster_size - 1]); + } else { + // For other cases, we return max_num_sms as 1 to avoid using large cluster size. 
+ result = 1; + } + return result; + } + + static bool support_split_k() { + return true; + } + static std::pair get_multicast_legality(const GemmType& gemm_type, const int& num_groups, const int& m, const int& n, const int& block_m, const int& block_n, const int& num_sms) { diff --git a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp index 7b4c4f65..7f9e1870 100644 --- a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -38,6 +38,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, + {}, {}, {}, {}, {}, {}, {}, @@ -49,6 +50,7 @@ static void __instantiate_kernel() {{ get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), args.num_groups, args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.k_slices, args.gemm_config.smem_config.swizzle_cd_mode, args.gemm_config.num_stages, args.gemm_config.num_last_stages, args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, @@ -107,9 +109,9 @@ static void sm90_bf16_gemm(const torch::Tensor& a, .num_groups = 1, .compiled_dims = compiled_dims, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + .launch_args = LaunchArgs({config.num_sms / config.k_slices, config.k_slices}, config.thread_config.num_threads, config.smem_config.smem_size, - config.multicast_config.num_multicast), + config.multicast_config.num_multicast, config.k_slices), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, @@ -160,9 +162,9 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, .num_groups = num_groups, .compiled_dims = compiled_dims, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + .launch_args = LaunchArgs({config.num_sms / config.k_slices, config.k_slices}, config.thread_config.num_threads, config.smem_config.smem_size, - config.multicast_config.num_multicast), + config.multicast_config.num_multicast, config.k_slices), .grouped_layout = m_indices.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, @@ -214,9 +216,9 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, .num_groups = num_groups, .compiled_dims = compiled_dims, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + .launch_args = LaunchArgs({config.num_sms / config.k_slices, config.k_slices}, config.thread_config.num_threads, config.smem_config.smem_size, - config.multicast_config.num_multicast), + config.multicast_config.num_multicast, config.k_slices), .grouped_layout = masked_m.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp index 2f54a353..215e523c 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp @@ -44,6 +44,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, + {}, {}, {}, {}, {}, {}, @@ -54,6 +55,7 @@ static void __instantiate_kernel() {{ get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), args.num_groups, args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + 
args.gemm_config.k_slices, args.gemm_config.num_stages, args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, @@ -117,9 +119,9 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, .num_groups = 1, .compiled_dims = compiled_dims, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + .launch_args = LaunchArgs({config.num_sms / config.k_slices, config.k_slices}, config.thread_config.num_threads, config.smem_config.smem_size, - config.multicast_config.num_multicast), + config.multicast_config.num_multicast, config.k_slices), .gmem_a_ptr = nullptr, .gmem_b_ptr = nullptr, .grouped_layout = nullptr, diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index ac878600..83df0cdc 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -43,6 +43,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, + {}, {}, {}, {}, {}, {}, {}, @@ -54,6 +55,7 @@ static void __instantiate_kernel() {{ get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), args.num_groups, args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.k_slices, args.gemm_config.smem_config.swizzle_cd_mode, args.gemm_config.num_stages, args.gemm_config.num_last_stages, args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, @@ -118,9 +120,9 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, .compiled_dims = compiled_dims, .epilogue_type = epilogue_type, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + .launch_args = LaunchArgs({config.num_sms / config.k_slices, config.k_slices}, config.thread_config.num_threads, config.smem_config.smem_size, - config.multicast_config.num_multicast), + config.multicast_config.num_multicast, config.k_slices), .sfb = sfb.data_ptr(), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, @@ -178,9 +180,9 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + .launch_args = LaunchArgs({config.num_sms / config.k_slices, config.k_slices}, config.thread_config.num_threads, config.smem_config.smem_size, - config.multicast_config.num_multicast), + config.multicast_config.num_multicast, config.k_slices), .sfb = sfb.data_ptr(), .grouped_layout = m_indices.data_ptr(), .tensor_map_a = tensor_map_a, @@ -239,9 +241,9 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + .launch_args = LaunchArgs({config.num_sms / config.k_slices, config.k_slices}, config.thread_config.num_threads, config.smem_config.smem_size, - config.multicast_config.num_multicast), + config.multicast_config.num_multicast, config.k_slices), .sfb = sfb.data_ptr(), .grouped_layout = masked_m.data_ptr(), .tensor_map_a = tensor_map_a, diff --git 
a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh index 237f688c..f8f75d99 100644 --- a/deep_gemm/include/deep_gemm/common/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -11,14 +11,15 @@ enum class KGroupedIndexType { SF_K, }; -template + +template static constexpr uint32_t get_num_1d_blocks_per_group() { // Select the best from candidates uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); for (const auto& candidate: {8u, 16u}) { const auto& usage = kIsMulticastOnA ? - candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N - candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + candidate * BLOCK_N + constexpr_ceil_div(kNumSMs / kSplitKSlices, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs / kSplitKSlices, candidate) * BLOCK_N; // Grouping on M if (usage < min_usage) min_usage = usage, num_best_blocks = candidate; } @@ -32,8 +33,9 @@ template ()> + uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group()> struct Scheduler { int current_iter = -1; @@ -149,7 +151,7 @@ struct Scheduler { } __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { - const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + const auto next_block_idx = (++ current_iter) * (kNumSMs / kSplitKSlices) + blockIdx.x; if constexpr (kGemmType == GemmType::MGroupedMasked) { while (true) { diff --git a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh index d910a2df..d35fd0c6 100644 --- a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -238,12 +238,12 @@ __device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout __device__ __forceinline__ void tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, - const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& num_tma_multicast = 1) { + const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& tma_multicast_cta_mask = 0) { constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); - if (num_tma_multicast == 1) { + if (tma_multicast_cta_mask == 0) { cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); - } else if (cute::block_rank_in_cluster() == 0) { - cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1); + } else if (cute::block_id_in_cluster().x == 0) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, tma_multicast_cta_mask, cache_hint, smem_ptr, crd_0, crd_1); } } diff --git a/deep_gemm/include/deep_gemm/common/split_k_reduce.cuh b/deep_gemm/include/deep_gemm/common/split_k_reduce.cuh new file mode 100644 index 00000000..8a090c65 --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/split_k_reduce.cuh @@ -0,0 +1,65 @@ +#pragma once +#include +#include + +namespace deep_gemm { + +template +static constexpr uint32_t get_reduce_iterations() { + uint32_t num_iterations = 0; + for (int i = 1; i < kSplitKSlices; i <<= 1) { + num_iterations++; + } + return num_iterations; +} + +// Reduce `accums` on registers in place across `kSplitKSlices` CTAs within a cluster, +// using distributed shared memory and memory barriers for communication. +// The reduction is done by a shuffle-down manner. 
+// Before reduction, each thread should hold `kNumAccum` float values in `accums`. +// After reduction, only the leader CTA holds the final results and will write them back to shared memory and then global memory. +template +__device__ __forceinline__ void +split_k_reduce(float* smem_d, float* accums, Barrier* split_k_reduce_empty_barrier, Barrier* split_k_reduce_full_barrier, uint32_t current_iter){ + constexpr uint32_t reduce_iterations = get_reduce_iterations(); + // kSplitKSlices CTAs with same cute::block_id_in_cluster().x and different cute::block_id_in_cluster().y within a cluster make a reduction group. + DG_TRAP_ONLY_DEVICE_ASSERT(cute::cluster_shape().y == kSplitKSlices); + auto k_partition_id = cute::block_id_in_cluster().y; + auto smem_d_u32 = cute::cast_smem_ptr_to_uint(smem_d); + uint32_t reduce_full_barrier_u32 = cute::cast_smem_ptr_to_uint(split_k_reduce_full_barrier); + #pragma unroll + for (uint32_t i = 1, mask = 1; i <= reduce_iterations; i ++, mask <<= 1) { + int peer_rank = (k_partition_id ^ mask) * cute::cluster_shape().x + cute::block_id_in_cluster().x; + + // notify all CTAs within the reduction group that this CTA is ready for the current reduce iteration + if (threadIdx.x < kSplitKSlices) { + split_k_reduce_empty_barrier->arrive(threadIdx.x * cute::cluster_shape().x + cute::block_id_in_cluster().x); + } + // wait for all CTAs within the reduction group to be ready + split_k_reduce_empty_barrier->wait((current_iter * reduce_iterations + i + 1) & 1); + if ((k_partition_id & mask) != 0) { // send + #pragma unroll + for (uint32_t j = 0, smem_offset = threadIdx.x; j < kNumAccum; ++j, smem_offset += kNumMathThreads) { + uint32_t smem_d_shifted_u32 = smem_d_u32 + smem_offset * sizeof(float); + // store to peer smem + cute::store_shared_remote(*reinterpret_cast(&accums[j]), smem_d_shifted_u32, reduce_full_barrier_u32, peer_rank); + } + if (threadIdx.x == 0) { + split_k_reduce_full_barrier->arrive_and_expect_tx(kNumAccum * kNumMathThreads * sizeof(float), peer_rank); + // arrive on local full_barrier, to make barrier phases aligned + split_k_reduce_full_barrier->arrive(); + } + // to make barrier phases aligned + split_k_reduce_full_barrier->wait((current_iter * reduce_iterations + i + 1) & 1); + } else { // receive + split_k_reduce_full_barrier->wait((current_iter * reduce_iterations + i + 1) & 1); + #pragma unroll + for (uint32_t j = 0, smem_offset = threadIdx.x; j < kNumAccum; ++j, smem_offset += kNumMathThreads) { + accums[j] += smem_d[smem_offset]; + } + } + } +} + + +}; // namespace deep_gemm \ No newline at end of file diff --git a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh index 9186e683..0111f386 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -13,6 +13,7 @@ #include #include #include +#include namespace deep_gemm { @@ -21,6 +22,7 @@ using namespace deep_gemm::sm90; template 0, "Invalid split K slices"); + const uint32_t shape_k_partitioned = ceil_div(shape_k, kSplitKSlices); + // Shared memory - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(cd_dtype_t); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * (kSplitKSlices > 1 ?
sizeof(float) : sizeof(cd_dtype_t)); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); // Configs constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; - const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); + const uint32_t num_iterations = ceil_div(shape_k_partitioned, kFullKOfAllStages); const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); const uint32_t lane_idx = get_lane_idx(); @@ -90,6 +95,8 @@ sm90_bf16_gemm_impl(int* grouped_layout, full_barriers[i] = barrier_start_ptr + i; empty_barriers[i] = barrier_start_ptr + kNumStages + i; } + auto split_k_reduce_full_barrier = barrier_start_ptr + 2 * kNumStages; + auto split_k_reduce_empty_barrier = split_k_reduce_full_barrier + 1; // Initialize barriers if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { @@ -102,9 +109,17 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Make initialized barrier visible in async proxy cutlass::arch::fence_barrier_init(); } + if constexpr (kSplitKSlices > 1) { + if (threadIdx.x == 0) { + split_k_reduce_full_barrier->init(1); + split_k_reduce_empty_barrier->init(kSplitKSlices); + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + } // Synchronize all threads to make barrier visible in normal memory model - (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + (kNumTMAMulticast > 1 or kSplitKSlices > 1) ? cute::cluster_sync() : __syncthreads(); struct DivisibleK {}; struct NotDivisibleK {}; @@ -125,7 +140,10 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + const uint32_t k_idx_offset = kSplitKSlices > 1 ? blockIdx.y * shape_k_partitioned : 0; + const uint32_t rank_in_cluster_offset = kSplitKSlices > 1 ? cute::block_id_in_cluster().y * kNumTMAMulticast : 0; if (warp_idx >= kNumMathThreads / 32) { // TMA warp-group for loading data @@ -142,9 +160,12 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Assign TMA multicast number into A and B // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); - const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + uint32_t tma_multicast_cta_mask_a = (kNumTMAMulticast > 1 and kIsTMAMulticastOnA and is_tma_multicast_valid) ? + ((1 << kNumTMAMulticast) - 1) << (kSplitKSlices > 1 ? rank_in_cluster_offset : 0): 0; + uint32_t tma_multicast_cta_mask_b = (kNumTMAMulticast > 1 and not kIsTMAMulticastOnA and is_tma_multicast_valid) ? + ((1 << kNumTMAMulticast) - 1) << (kSplitKSlices > 1 ? rank_in_cluster_offset : 0): 0; + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all // shared memory pointers, e.g. 
`full_barriers` registers, if all the access indices are constant @@ -155,14 +176,13 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[s]; - uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - + const uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K + (kSplitKSlices > 1 ? k_idx_offset : 0); tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), - num_tma_multicast_a); + tma_multicast_cta_mask_a); tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), - num_tma_multicast_b); + tma_multicast_cta_mask_b); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); } @@ -185,6 +205,9 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Math warp-groups for WGMMA cutlass::arch::warpgroup_reg_alloc(); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); @@ -198,7 +221,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, if constexpr (kNumTMAMulticast == 1) { lane_idx == 0 ? empty_barriers[s]->arrive() : void(); } else { - auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx + rank_in_cluster_offset : cute::block_rank_in_cluster(); lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); } }; @@ -261,10 +284,22 @@ sm90_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); // Wait last TMA store to be finished - if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + if (cute::block_id_in_cluster().y == 0 and threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) cute::tma_store_wait<0>(); cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + // Reduce d across CTAs within a cluster + if constexpr (kSplitKSlices > 1) { + split_k_reduce( + reinterpret_cast(smem_d), accum, split_k_reduce_empty_barrier, split_k_reduce_full_barrier, scheduler.current_iter + ); + // Only the leader CTA writes d to global memory + if (cute::block_id_in_cluster().y != 0) { + continue; + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + } + if constexpr (std::is_same_v) { // Write back to shared memory using STSM and issue TMA stores DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh index 4c57cbe0..37065863 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -13,6 +13,7 @@ #include #include #include +#include namespace deep_gemm { @@ -21,6 +22,7 @@ using namespace deep_gemm::sm90; template 0, "Invalid split K slices"); + DG_STATIC_ASSERT(not (kSplitKSlices > 1 and kGemmType == GemmType::KGroupedContiguous), "GemmType::KGroupedContiguous is incompatible with split K"); + const uint32_t shape_k_partitioned = ceil_div(shape_k, kSplitKSlices); + // Shared memory static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ?
sizeof(cute::TmaDescriptor) * 4 : 0); static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); @@ -115,6 +121,8 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier)))); }); + auto split_k_reduce_full_barrier = reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + 2 * kNumStages * static_cast(sizeof(Barrier)))); + auto split_k_reduce_empty_barrier = split_k_reduce_full_barrier + 1; if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { // Load tensormap A/B to shared memory @@ -138,8 +146,17 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, cutlass::arch::fence_barrier_init(); } + if constexpr (kSplitKSlices > 1) { + if (threadIdx.x == 0) { + split_k_reduce_full_barrier->init(1); + split_k_reduce_empty_barrier->init(kSplitKSlices); + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + } + // Synchronize all threads to make barrier visible in normal memory model - (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + (kNumTMAMulticast > 1 or kSplitKSlices > 1) ? cute::cluster_sync() : __syncthreads(); // Pipeline unroll control constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages); @@ -150,7 +167,10 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + const uint32_t k_block_idx_offset = kSplitKSlices > 1 ? blockIdx.y * ceil_div(shape_k_partitioned, BLOCK_K) : 0; + const uint32_t rank_in_cluster_offset = kSplitKSlices > 1 ? cute::block_id_in_cluster().y * kNumTMAMulticast : 0; // TMA and MMA pipeline const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { @@ -173,11 +193,12 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Assign TMA multicast number into A and B // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); - const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - - const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + uint32_t tma_multicast_cta_mask_a = (kNumTMAMulticast > 1 and kIsTMAMulticastOnA and is_tma_multicast_valid) ? + ((1 << kNumTMAMulticast) - 1) << (kSplitKSlices > 1 ? rank_in_cluster_offset : 0): 0; + uint32_t tma_multicast_cta_mask_b = (kNumTMAMulticast > 1 and not kIsTMAMulticastOnA and is_tma_multicast_valid) ? + ((1 << kNumTMAMulticast) - 1) << (kSplitKSlices > 1 ? rank_in_cluster_offset : 0): 0; + const uint32_t& num_k_blocks = ceil_div(kSplitKSlices > 1 ? 
shape_k_partitioned : scheduler.current_shape_k, BLOCK_K); const uint32_t& m_idx = m_block_idx * BLOCK_M; const uint32_t& n_idx = n_block_idx * BLOCK_N; @@ -215,12 +236,12 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Issue TMA auto& full_barrier = *full_barriers[stage_idx]; - const uint32_t& k_idx = k_block_idx * BLOCK_K; - const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; - tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); - tma_copy(&tensor_map_sfb, reinterpret_cast(&full_barrier), smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); - tma_copy(current_tensor_map_a, reinterpret_cast(&full_barrier), smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); - tma_copy(current_tensor_map_b, reinterpret_cast(&full_barrier), smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + const uint32_t& k_idx = (k_block_idx + k_block_idx_offset) * BLOCK_K; + const uint32_t& sf_k_idx = (kSplitKSlices > 1 ? k_block_idx_offset : scheduler.current_sf_k_cumsum) + k_block_idx; + tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), smem_sfa[stage_idx], m_idx, sf_k_idx, tma_multicast_cta_mask_a); + tma_copy(&tensor_map_sfb, reinterpret_cast(&full_barrier), smem_sfb[stage_idx], n_idx, sf_k_idx, tma_multicast_cta_mask_b); + tma_copy(current_tensor_map_a, reinterpret_cast(&full_barrier), smem_a[stage_idx], k_idx, m_idx, tma_multicast_cta_mask_a); + tma_copy(current_tensor_map_b, reinterpret_cast(&full_barrier), smem_b[stage_idx], k_idx, n_idx, tma_multicast_cta_mask_b); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); } } @@ -249,7 +270,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes"); const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); - const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K); + const uint32_t& num_k_blocks = ceil_div(kSplitKSlices > 1 ? shape_k_partitioned : scheduler.current_shape_k, BLOCK_K); float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; float2 scales_b[WGMMA::kNumAccum / 4]; @@ -258,7 +279,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, if constexpr (kNumTMAMulticast == 1) { lane_idx == 0 ? empty_barriers[s]->arrive() : void(); } else { - auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx + rank_in_cluster_offset : cute::block_rank_in_cluster(); lane_idx < kNumTMAMulticast ? 
empty_barriers[s]->arrive(target_cta) : void(); } }; @@ -312,10 +333,22 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, } // Flush previous stores - if (warp_idx % 4 == 0 and cute::elect_one_sync()) + if (cute::block_id_in_cluster().y == 0 and warp_idx % 4 == 0 and cute::elect_one_sync()) cute::tma_store_wait<0>(); cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + // Reduce d across CTAs within a cluster + if constexpr (kSplitKSlices > 1) { + split_k_reduce( + reinterpret_cast(smem_d), final_accum, split_k_reduce_empty_barrier, split_k_reduce_full_barrier, scheduler.current_iter + ); + // Only the leader cta writes d to global memory + if (cute::block_id_in_cluster().y != 0) { + continue; + } + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + } + // Store to D shared memory const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 5a92d7d4..5425fadd 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -14,6 +14,7 @@ #include #include #include +#include namespace deep_gemm { @@ -33,6 +34,7 @@ __device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_ template 0, "Invalid split K slices"); + const uint32_t shape_k_partitioned = ceil_div(shape_k, kSplitKSlices); + // Shared memory static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * (kSplitKSlices > 1 ? sizeof(float) : sizeof(__nv_bfloat16)); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); - const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); + const uint32_t& shape_k_scales_partitioned = ceil_div(shape_k_partitioned, BLOCK_K); + const uint32_t& smem_sfb_size = align(shape_k_scales_partitioned * (kMustUseUniformedScaleB ? 
1 : 2) * sizeof(float), sizeof(Barrier)); // Configs - const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K); + const uint32_t num_total_k_blocks = ceil_div(shape_k_partitioned, BLOCK_K); const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); const uint32_t lane_idx = get_lane_idx(); @@ -106,6 +112,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + auto split_k_reduce_full_barrier = barrier_start_ptr + 2 * kNumStages; + auto split_k_reduce_empty_barrier = split_k_reduce_full_barrier + 1; // Initialize barriers DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); @@ -121,9 +129,18 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Make initialized barrier visible in async proxy cutlass::arch::fence_barrier_init(); } + + if constexpr (kSplitKSlices > 1) { + if (threadIdx.x == 0) { + split_k_reduce_full_barrier->init(1); + split_k_reduce_empty_barrier->init(kSplitKSlices); + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + } // Synchronize all threads to make barrier visible in normal memory model - (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + (kNumTMAMulticast > 1 or kSplitKSlices > 1) ? cute::cluster_sync() : __syncthreads(); // Register reconfigurations constexpr uint32_t kNumTMARegisters = 40; @@ -131,7 +148,10 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + const uint32_t k_block_idx_offset = kSplitKSlices > 1 ? blockIdx.y * num_total_k_blocks : 0; + const uint32_t rank_in_cluster_offset = kSplitKSlices > 1 ? cute::block_id_in_cluster().y * kNumTMAMulticast : 0; // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0; @@ -154,10 +174,11 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Assign TMA multicast number into A and B // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); - const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - + uint32_t tma_multicast_cta_mask_a = (kNumTMAMulticast > 1 and kIsTMAMulticastOnA and is_tma_multicast_valid) ? + ((1 << kNumTMAMulticast) - 1) << (kSplitKSlices > 1 ? rank_in_cluster_offset : 0): 0; + uint32_t tma_multicast_cta_mask_b = (kNumTMAMulticast > 1 and not kIsTMAMulticastOnA and is_tma_multicast_valid) ? + ((1 << kNumTMAMulticast) - 1) << (kSplitKSlices > 1 ? 
rank_in_cluster_offset : 0): 0; for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); @@ -165,18 +186,18 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Issue TMA A constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[stage_idx]; - const uint32_t k_idx = k_block_idx * BLOCK_K; + const uint32_t k_idx = (k_block_idx + k_block_idx_offset) * BLOCK_K; tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), - num_tma_multicast_a); + tma_multicast_cta_mask_a); tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), - smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx(shape_k_scales, 1, k_block_idx), - num_tma_multicast_a); + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx(shape_k_scales, 1, k_block_idx + k_block_idx_offset), + tma_multicast_cta_mask_a); // Issue TMA B tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), - num_tma_multicast_b); + tma_multicast_cta_mask_b); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); } } @@ -209,16 +230,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8; } - uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2); + uint32_t num_sfb = shape_k_scales_partitioned * (num_former_iters >= num_full_iters ? 1 : 2); // Load B scales with math warp-groups // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks if (threadIdx.x >= 32) { auto num_previous_lines = scheduler.get_global_idx(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx); - auto local_sfb = sfb + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales; + auto local_sfb = sfb + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales + + (kSplitKSlices > 1 ? k_block_idx_offset : 0); + #pragma unroll - for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) - st_shared(smem_sfb + i, __ldg(local_sfb + i)); + for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) { + if constexpr (kSplitKSlices > 1) { + st_shared(smem_sfb + i, __ldg(local_sfb + i + (i >= shape_k_scales_partitioned ? shape_k_scales - shape_k_scales_partitioned : 0))); + } else { + st_shared(smem_sfb + i, __ldg(local_sfb + i)); + } + } } cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); @@ -232,7 +260,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, if constexpr (kNumTMAMulticast == 1) { lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void(); } else { - auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx + rank_in_cluster_offset : cute::block_rank_in_cluster(); lane_idx < kNumTMAMulticast ? 
empty_barriers[stage_idx]->arrive(target_cta) : void(); } }; @@ -255,7 +283,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1; // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales); + scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales_partitioned); // Wait TMA arrivals full_barriers[stage_idx]->wait(phase); @@ -329,10 +357,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); // Wait last TMA store to be finished - if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + if (cute::block_id_in_cluster().y == 0 and threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) cute::tma_store_wait<0>(); cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + // Reduce d across CTAs within a cluster + if constexpr (kSplitKSlices > 1) { + split_k_reduce( + reinterpret_cast(smem_d), final_accum, split_k_reduce_empty_barrier, split_k_reduce_full_barrier, scheduler.current_iter + ); + // Only the leader cta writes d to global memory + if (cute::block_id_in_cluster().y != 0) { + continue; + } + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + } + // Write back to shared memory using STSM and issue TMA stores DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); #pragma unroll
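
For reference, the cluster-dimension launch path that csrc/jit/handle.hpp now takes corresponds to the standard CUDA runtime pattern sketched below: cudaLaunchKernelEx with a cudaLaunchAttributeClusterDimension attribute. This is a minimal standalone sketch, not the patch's code; the kernel dummy_kernel and the wrapper launch_with_cluster are hypothetical names.

    #include <cuda_runtime.h>

    __global__ void dummy_kernel(float* out) {
        out[blockIdx.x * blockDim.x + threadIdx.x] = 1.0f;
    }

    // Launch a kernel with an explicit thread-block cluster shape (requires SM90+).
    // Grid dimensions must be divisible by the cluster dimensions; the patch sizes the grid
    // as {num_sms / k_slices, k_slices} and the cluster as {num_multicast, k_slices, 1}.
    cudaError_t launch_with_cluster(float* out, const dim3& grid_dim, const dim3& block_dim,
                                    const dim3& cluster_dim, cudaStream_t stream) {
        cudaLaunchConfig_t config = {};
        config.gridDim = grid_dim;
        config.blockDim = block_dim;
        config.dynamicSmemBytes = 0;
        config.stream = stream;

        // The attribute is attached unconditionally; a {1, 1, 1} cluster is valid on SM90+,
        // which is presumably why the patched helper no longer guards on `cluster_dim > 1`.
        cudaLaunchAttribute attr;
        attr.id = cudaLaunchAttributeClusterDimension;
        attr.val.clusterDim.x = cluster_dim.x;
        attr.val.clusterDim.y = cluster_dim.y;
        attr.val.clusterDim.z = cluster_dim.z;
        config.attrs = &attr;
        config.numAttrs = 1;

        return cudaLaunchKernelEx(&config, dummy_kernel, out);
    }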
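
The split-K candidates enumerated by get_k_slices_options() are the power-of-two slice counts, capped at 8 (the maximum potential cluster size noted for SM90), for which the block_k-aligned K extent splits into slices that are whole multiples of block_k. Below is a host-side sketch of the same enumeration; the helper name enumerate_k_slices is hypothetical.

    #include <vector>

    // Powers of two up to `max_k_slices` such that the K extent, padded up to a multiple of
    // `block_k`, divides evenly into slices that are themselves whole multiples of `block_k`.
    static std::vector<int> enumerate_k_slices(const int k, const int block_k, const int max_k_slices = 8) {
        const int aligned_k = (k + block_k - 1) / block_k * block_k;
        std::vector<int> options;
        for (int k_slices = 1; k_slices <= max_k_slices; k_slices *= 2) {
            if (aligned_k % k_slices == 0 and (aligned_k / k_slices) % block_k == 0)
                options.push_back(k_slices);
        }
        return options;
    }

    // Example: k = 2048, block_k = 128 -> {1, 2, 4, 8}; k = 384, block_k = 128 -> {1},
    // since 384 = 3 * 128 and only power-of-two splits are considered.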
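
Because tma_copy() now takes a CTA mask rather than a multicast count, the call sites shift the multicast mask by the k-partition's base rank inside the (num_multicast, k_slices) cluster. The sketch below shows only that mask arithmetic, ignoring the is_tma_multicast_valid and multicast-on-A/B checks the kernels also apply; make_tma_multicast_cta_mask is a hypothetical helper name.

    #include <cstdint>

    // Returns 0 when multicast is disabled (the sentinel the patched tma_copy() tests),
    // otherwise a contiguous bit mask covering this k-partition's multicast group.
    static uint32_t make_tma_multicast_cta_mask(const uint32_t num_multicast, const uint32_t k_partition_id) {
        if (num_multicast <= 1)
            return 0;
        // Mirrors `((1 << kNumTMAMulticast) - 1) << rank_in_cluster_offset`, where
        // rank_in_cluster_offset = cute::block_id_in_cluster().y * kNumTMAMulticast.
        const uint32_t rank_in_cluster_offset = k_partition_id * num_multicast;
        return ((1u << num_multicast) - 1u) << rank_in_cluster_offset;
    }

    // Example: num_multicast = 2, k_partition_id = 3 -> 0b11000000, i.e. CTA ranks 6 and 7.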
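
split_k_reduce() pairs the kSplitKSlices CTAs of a cluster by XOR-ing the k-partition id with a doubling mask, so after log2(kSplitKSlices) rounds the partition-0 (leader) CTA holds the full sum and is the only one that stores D. The host-side walk-through below reproduces only that pairing for illustration; the actual data movement in the patch goes through distributed shared memory and mbarriers, which this sketch does not model.

    #include <cstdio>

    int main() {
        const int k_slices = 8;   // cluster shape is (num_multicast, k_slices); take cluster.x == 1 here
        const int cluster_x = 1;
        for (int mask = 1; mask < k_slices; mask <<= 1) {
            std::printf("round with mask %d:\n", mask);
            for (int k_partition_id = 0; k_partition_id < k_slices; ++k_partition_id) {
                // Same peer computation as the kernel:
                // peer_rank = (k_partition_id ^ mask) * cluster_shape().x + block_id_in_cluster().x
                const int peer_rank = (k_partition_id ^ mask) * cluster_x;
                if (k_partition_id & mask)
                    std::printf("  partition %d sends its accumulators to cluster rank %d\n", k_partition_id, peer_rank);
                else
                    std::printf("  partition %d receives from cluster rank %d and accumulates\n", k_partition_id, peer_rank);
            }
        }
        return 0;
    }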