From 49445f9dd4a8d856a2fdd4505f7152b94e8a044f Mon Sep 17 00:00:00 2001
From: ayrnb <641876696@qq.com>
Date: Thu, 10 Jul 2025 19:40:41 +0800
Subject: [PATCH 1/4] ll dispatch tma

---
 csrc/kernels/internode_ll.cu | 212 +++++++++++++++++++----------------
 1 file changed, 114 insertions(+), 98 deletions(-)

diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index dc03c65a..bfed9009 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -7,6 +7,7 @@ namespace deep_ep {
 
 namespace internode_ll {
 
+
 template <int kNumThreads>
 __launch_bounds__(kNumThreads, 1)
 __global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                                          int* clean_1, int num_clean_int_1) {
@@ -36,7 +37,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                   clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }
 
-template <bool kUseFP8, bool kUseUE8M0, int kHidden>
+template <bool kUseFP8, bool kUseUE8M0, int kHidden, int kNumTMABytesPerWarp>
 __global__ __launch_bounds__(1024, 1) void
 dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
@@ -82,6 +83,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     constexpr int kNumMaxWarpGroups = 32;
     __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
 
+
+    // Sending phase
     if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
         goto LOW_LATENCY_DISPATCH_RECV;
@@ -91,7 +94,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     // 2. The last warp for reading `topk_idx` and count for per-expert information
     if (warp_id < num_warps - 1) {
         constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
-        EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden");
+        EP_STATIC_ASSERT(kHidden % kNumElemsPerRead == 0, "Invalid hidden");
         EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
         const auto num_threads = (num_warps - 1) * 32;
         const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
@@ -125,7 +128,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
 
             // Reduce amax and scale
             EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
-            amax = warp_reduce_max<16>(amax);
+            amax = half_warp_reduce_max(amax);
             calculate_fp8_scales(amax, scale, scale_inv, round_scale);
             if (lane_id == 0 or lane_id == 16)
                 rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
@@ -165,6 +168,14 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
             const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
             const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
             UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+            // Kept for reference: a possible TMA path for the send side
+            // if (lane_id == 0) {
+            //     tma_load_1d(tma_buffer, src_int4_ptr, tma_mbarrier, num_int4_per_msg * sizeof(int4));
+            //     mbarrier_arrive_and_expect_tx(tma_mbarrier, num_int4_per_msg * sizeof(int4));
+            //     mbarrier_wait(tma_mbarrier, tma_phase);
+            //     tma_store_1d(tma_buffer, dst_int4_ptr, num_int4_per_msg * sizeof(int4));
+            //     tma_store_wait();
+            // }
         }
 
         // Increase counter after finishing
@@ -285,6 +296,21 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         num_recv_tokens = shared_num_recv_tokens[warp_group_id];
         recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
 
+        // TMA shared memory and barrier initialization
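+        // Each warp owns `kNumTMABytesPerWarp` bytes of dynamic shared memory:
+        // a quarter-token staging area followed by an 8-byte TMA mbarrier
+        // (NOTES: this layout assumes `hidden_int4` is divisible by 4)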
+        extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
+        auto quarter_hidden_int4 = hidden_int4 / 4;
+        auto quarter_hidden_bytes = quarter_hidden_int4 * static_cast<int>(sizeof(int4));
+        auto tma_buffer_for_warp = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;
+        auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer_for_warp + quarter_hidden_bytes);
+        uint32_t tma_phase = 0;
+        if (lane_id == 0) {
+            mbarrier_init(tma_mbarrier, 1);
+            fence_view_async_shared();
+            fence_barrier_init();
+            EP_DEVICE_ASSERT(quarter_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
+        }
+        __syncwarp();
+
         // Copy tokens
         EP_DEVICE_ASSERT(num_scales <= 64);
         for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
@@ -293,18 +319,45 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
             if (lane_id == 0)
                 recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
             __syncwarp();
-
+
             // Copy data
-            // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
+            // NOTES: 4 TMA stages per token, a quarter of the hidden size each
             const auto src_data = reinterpret_cast<const int4*>(reinterpret_cast<const uint8_t*>(src_src_idx) + sizeof(int4));
             const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
-            UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
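+            // Copy the token through the per-warp TMA buffer in four pipelined
+            // chunks: bulk-load a quarter token into shared memory, wait on the
+            // mbarrier, then bulk-store it to the packed output in global memory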
+            #pragma unroll
+            for (int j = 0; j < 4; ++j) {
+                if (lane_id == 0) {
+                    tma_load_1d(tma_buffer_for_warp, src_data + j * quarter_hidden_int4, tma_mbarrier, quarter_hidden_bytes);
+                    mbarrier_arrive_and_expect_tx(tma_mbarrier, quarter_hidden_bytes);
+                }
+                __syncwarp();
+                mbarrier_wait(tma_mbarrier, tma_phase);
+
+                if (lane_id == 0) {
+                    tma_store_1d(tma_buffer_for_warp, dst_data + j * quarter_hidden_int4, quarter_hidden_bytes, false);
+                    // Drain the bulk store before reusing the buffer in the next iteration
+                    tma_store_wait();
+                }
+                __syncwarp();
+            }
 
             if constexpr (kUseFP8) {
-                // Equivalent CuTe layout:
-                // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
+                // Copy scales; equivalent CuTe layout:
+                // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
                 const auto src_scales = reinterpret_cast<const float*>(reinterpret_cast<const uint8_t*>(src_data) + hidden_bytes);
+                // const auto smem_scales = reinterpret_cast<float*>(tma_buffer_for_warp + data_bytes);
                 const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
                 const auto token_idx = recv_token_begin_idx + i;
                 const auto token_stride = num_elems_per_pack;
@@ -313,14 +366,17 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
                     const auto pack_idx = lane_id / num_elems_per_pack;
                     const auto elem_idx = lane_id % num_elems_per_pack;
                     auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
+                    // auto scale = extract_required_scale_format<kUseUE8M0>(smem_scales[lane_id]);
                     recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                 }
                 if (lane_id + 32 < num_scales) {
                     const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
                     const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
                     auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
+                    // auto scale = extract_required_scale_format<kUseUE8M0>(smem_scales[lane_id + 32]);
                     recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                 }
+
             }
         }
     }
@@ -347,6 +403,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const auto num_warps = num_warp_groups * num_warps_per_group;
     const auto num_sms = ceil_div(num_experts, num_warp_groups);
     EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
+    constexpr int kNumTMABytesPerWarp = 4096; // 4 KB per warp: quarter-token staging buffer + mbarrier
+    const int smem_size = kNumTMABytesPerWarp * num_warps;
 
     // Workspace checks
     auto atomic_counter_per_expert = static_cast<int*>(workspace);
@@ -358,11 +416,12 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
 
 #define DISPATCH_LAUNCH_CASE(hidden) { \
 auto dispatch_func = dispatch