diff --git a/csrc/config.hpp b/csrc/config.hpp
index 6bef6255..525e9188 100644
--- a/csrc/config.hpp
+++ b/csrc/config.hpp
@@ -20,6 +20,13 @@ dtype_t align_down(dtype_t a, dtype_t b) {
     return a / b * b;
 }
 
+template <typename out_ptr_t = void*, typename in_ptr_t>
+out_ptr_t advance_ptr(in_ptr_t &ptr, size_t count) {
+    out_ptr_t saved = reinterpret_cast<out_ptr_t>(ptr);
+    ptr = reinterpret_cast<in_ptr_t>(reinterpret_cast<uint8_t*>(ptr) + count);
+    return saved;
+}
+
 struct Config {
     int num_sms;
     int num_max_nvl_chunked_send_tokens;
@@ -47,22 +54,31 @@ struct Config {
         EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
     }
 
-    size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
+    size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks, bool zero_copy = false) const {
         // Below are some assumptions
         // TODO: add assertions
         constexpr int kNumMaxTopK = 128;
         constexpr int kNumMaxScales = 128;
         EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
-        EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
+        EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or zero_copy or num_sms % 2 == 0);
         const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
         const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
-        const int num_channels = num_sms / 2;
+        const int sms_per_channel = zero_copy ? 1 : 2;
+        const int num_channels = num_sms / sms_per_channel;
         size_t num_bytes = 0;
+        if (zero_copy) {
+            num_bytes += ZCOPY_NOTIFY_NVL_METADATA_OFFSET_INTS * sizeof(int);
+            num_bytes += std::max(
+                num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 2) * sizeof(int), // Dispatch
+                num_channels * num_nvl_ranks * 1 * sizeof(int)                         // Combine
+            );
+            return num_bytes;
+        }
         num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
         num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
 #ifndef DISABLE_NVSHMEM
-        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
+        num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(zero_copy);
 #endif
         num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t);
         num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
@@ -71,7 +87,7 @@ struct Config {
         return num_bytes;
     }
 
-    size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
+    size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks, bool zero_copy = false) const {
 #ifndef DISABLE_NVSHMEM
         // Legacy mode
         if (num_ranks <= NUM_MAX_NVL_PEERS)
@@ -82,14 +98,16 @@
         constexpr int kNumMaxTopK = 128;
         constexpr int kNumMaxScales = 128;
         EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
-        EP_HOST_ASSERT(num_sms % 2 == 0);
+        EP_HOST_ASSERT(zero_copy or num_sms % 2 == 0);
         const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
-        const int num_channels = num_sms / 2;
+        const int sms_per_channel = zero_copy ? 1 : 2;
+        const int num_channels = num_sms / sms_per_channel;
+        const int rdma_buffer_factor = zero_copy ? 1 : 2; // The zero-copy kernels only have recv buffer
         size_t num_bytes = 0;
         num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
         num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
-        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
+        num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes(zero_copy) * 2;
         num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2;
         num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
         num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
@@ -192,4 +210,110 @@ size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int
     return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
 }
 
+struct InternodeDispatchBuffer {
+    void* x = nullptr;
+    float* x_scales = nullptr;
+    topk_idx_t* topk_idx = nullptr;
+    float* topk_weights = nullptr;
+    void* source_meta = nullptr;
+};
+
+struct InternodeDispatchLayout {
+    size_t total_bytes = 0;
+    std::vector<InternodeDispatchBuffer> buffers;
+
+    InternodeDispatchLayout(void* rdma_fused_buffer, int num_tokens, int hidden, int num_topk, int num_ranks, bool with_topk, bool use_fp8, int num_buffers) {
+        size_t num_bytes_x = 0;
+        size_t num_bytes_x_scales = 0;
+        size_t num_bytes_topk_idx = 0;
+        size_t num_bytes_topk_weights = 0;
+        size_t num_bytes_source_meta = 0;
+
+        if (use_fp8){
+            num_bytes_x = num_tokens * hidden * elementSize(torch::kFloat8_e4m3fn);
+            total_bytes += num_bytes_x;
+
+            int num_scales = hidden / 128;
+            num_bytes_x_scales = num_tokens * num_scales * sizeof(float);
+            total_bytes += num_bytes_x_scales;
+        }
+        else {
+            num_bytes_x = num_tokens * hidden * elementSize(torch::kBFloat16);
+            total_bytes += num_bytes_x;
+        }
+
+        if (with_topk) {
+            num_bytes_topk_idx = num_tokens * num_topk * sizeof(topk_idx_t);
+            total_bytes += num_bytes_topk_idx;
+
+            num_bytes_topk_weights = num_tokens * num_topk * sizeof(float);
+            total_bytes += num_bytes_topk_weights;
+        }
+
+        num_bytes_source_meta = (num_ranks / NUM_MAX_NVL_PEERS) * num_tokens * internode::get_source_meta_bytes();
+        total_bytes += num_bytes_source_meta;
+        EP_HOST_ASSERT(total_bytes % 16 == 0);
+
+        EP_HOST_ASSERT(total_bytes <= NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER);
+
+        for (int i=0; i<num_buffers; i++) {
+            auto tmp = rdma_fused_buffer;
+            buffers.push_back({
+                advance_ptr(tmp, num_bytes_x),                                            // x
+                use_fp8 ? advance_ptr<float*>(tmp, num_bytes_x_scales) : nullptr,         // x_scales
+                with_topk ? advance_ptr<topk_idx_t*>(tmp, num_bytes_topk_idx) : nullptr,  // topk_idx
+                with_topk ? advance_ptr<float*>(tmp, num_bytes_topk_weights) : nullptr,   // topk_weights
+                advance_ptr(tmp, num_bytes_source_meta),                                  // source_meta
+            });
+            advance_ptr(rdma_fused_buffer, NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER);
+        }
+
+        total_bytes = NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER * num_buffers;
+    }
+};
+
+
+struct InternodeCombineBuffer {
+    void* x = nullptr;
+    float* topk_weights = nullptr;
+};
+
+struct InternodeCombineLayout {
+    size_t total_bytes = 0;
+    std::vector<InternodeCombineBuffer> buffers;
+
+    template <typename out_ptr_t, typename in_ptr_t>
+    out_ptr_t advance(const in_ptr_t& ptr, size_t count) {
+        return reinterpret_cast<out_ptr_t>(reinterpret_cast<uint64_t>(ptr) + count);
+    }
+
+    InternodeCombineLayout(void* nvl_buffer, int num_tokens, int hidden, int num_topk, bool with_topk, int num_buffers) {
+        size_t num_bytes_x = 0;
+        size_t num_bytes_topk_weights = 0;
+
+        num_bytes_x = num_tokens * hidden * elementSize(torch::kBFloat16);
+        total_bytes += num_bytes_x;
+
+        if (with_topk) {
+            num_bytes_topk_weights = num_tokens * num_topk * sizeof(float);
+            total_bytes += num_bytes_topk_weights;
+        }
+
+        EP_HOST_ASSERT(total_bytes <= NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER);
+
+        EP_HOST_ASSERT(total_bytes % 16 == 0);
+
+        for (int i=0; i<num_buffers; i++) {
+            auto tmp = nvl_buffer;
+            buffers.push_back({
+                advance_ptr(tmp, num_bytes_x),
+                with_topk ? advance_ptr<float*>(tmp, num_bytes_x) : nullptr
+            });
+            advance_ptr(nvl_buffer, NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER);
+        }
+
+        total_bytes = NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER * num_buffers;
+    }
+};
+
 } // namespace deep_ep
diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 06bbb830..133d3065 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -13,16 +13,26 @@ namespace deep_ep {
 Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
                bool low_latency_mode, bool explicitly_destroy,
-               bool enable_shrink):
+               bool enable_shrink,
+               int num_zcopy_buffers):
         rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes),
+        num_zcopy_buffers(num_zcopy_buffers),
         enable_shrink(enable_shrink),
         low_latency_mode(low_latency_mode), explicitly_destroy(explicitly_destroy),
         comm_stream(at::cuda::getStreamFromPool(true)) {
+    const bool support_zero_copy = num_zcopy_buffers > 0;
+    EP_HOST_ASSERT(num_zcopy_buffers >= 0);
+
     // Metadata memory
     int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
     int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
+    // Fused (zero-copy) NVL buffer layout is as follows (case of 2 sub-buffers):
+    //      Sub-buffer 0         Sub-buffer 1          Sub-buffer 0           Sub-buffer 1
+    // | Input of Combine 0 | Input of Combine 1 | Output of Dispatch 0 | Output of Dispatch 1 |
+    buffer_fused_bytes = 1l * num_zcopy_buffers * (NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER + NUM_DISPATCH_OUTPUT_BYTES_PER_ZCOPY_BUFFER);
+    int64_t buffer_fused_ptr_bytes = 1l * support_zero_copy * NUM_MAX_NVL_PEERS * sizeof(void*);
     int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);
 
     // Common checks
@@ -48,13 +58,22 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
     if (num_nvl_bytes > 0) {
         // Local IPC: alloc local memory and set local IPC handles
-        CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
+        // Zero-copy: Guarantee buffer_fused_ptr is aligned by placing it at buffer_ptr + num_nvl_bytes (both are aligned)
+        CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + buffer_fused_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes + buffer_fused_ptr_bytes));
         CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
-        buffer_ptrs_gpu =
reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes); + buffer_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + buffer_fused_bytes + barrier_signal_bytes); // Set barrier signals - barrier_signal_ptrs[nvl_rank] = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); - barrier_signal_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes); + barrier_signal_ptrs[nvl_rank] = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + buffer_fused_bytes); + barrier_signal_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + buffer_fused_bytes + barrier_signal_bytes + buffer_ptr_bytes); + + // Set NVL fused (zero-copy send/recv) buffers + buffer_fused_ptrs[nvl_rank] = support_zero_copy ? + reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes) : + nullptr; + buffer_fused_ptrs_gpu = support_zero_copy ? + reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + buffer_fused_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes) : + nullptr; // No need to synchronize, will do a full device sync during `sync` CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); @@ -193,6 +212,10 @@ void Buffer::sync(const std::vector &device_ids, const std::optional& root_unique_id_opt) { EP_HOST_ASSERT(not is_available()); + int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); + + const bool support_zero_copy = num_zcopy_buffers > 0; + // Sync IPC handles if (num_nvl_bytes > 0) { EP_HOST_ASSERT(num_ranks == device_ids.size()); @@ -204,7 +227,8 @@ void Buffer::sync(const std::vector &device_ids, if (offset + i != rank) { std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE); CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess)); - barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes); + barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes + buffer_fused_bytes); + buffer_fused_ptrs[i] = support_zero_copy ? reinterpret_cast(reinterpret_cast(buffer_ptrs[i]) + num_nvl_bytes) : nullptr; } else { EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0); } @@ -213,6 +237,9 @@ void Buffer::sync(const std::vector &device_ids, // Copy all buffer and barrier signal pointers to GPU CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); + if (support_zero_copy) { + CUDA_CHECK(cudaMemcpy(buffer_fused_ptrs_gpu, buffer_fused_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); + } CUDA_CHECK(cudaDeviceSynchronize()); } @@ -231,6 +258,7 @@ void Buffer::sync(const std::vector &device_ids, // Allocate rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES); + rdma_fused_buffer_ptr = support_zero_copy ? 
internode::alloc(NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER * num_zcopy_buffers, NUM_BUFFER_ALIGNMENT_BYTES) : nullptr; // Clean buffer (mainly for low-latency mode) CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes)); @@ -666,7 +694,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> +std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional, std::optional> Buffer::internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, const std::optional& topk_weights, const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, @@ -674,15 +702,19 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { + const std::optional& cached_recv_gbl_rank_prefix_sum_fwd, + int expert_alignment, const Config& config, std::optional& previous_event, + bool zero_copy, int zcopy_buffer_id, + bool async, bool allocate_on_comm_stream) { #ifndef DISABLE_NVSHMEM // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long. // If users of DeepEP need to execute other Python code on other threads, such as KV transfer, their code will get stuck due to GIL // unless we release GIL here. pybind11::gil_scoped_release release; - const int num_channels = config.num_sms / 2; - EP_HOST_ASSERT(config.num_sms % 2 == 0); + const int sms_per_channel = zero_copy ? 
1 : 2; + const int num_channels = config.num_sms / sms_per_channel; + EP_HOST_ASSERT(config.num_sms % sms_per_channel == 0); EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); bool cached_mode = cached_rdma_channel_prefix_matrix.has_value(); @@ -691,6 +723,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionalscalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32); + if (zero_copy) { + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum_fwd->scalar_type() == torch::kInt32); + } } else { EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32); @@ -721,6 +759,10 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionalsize(0) == num_ranks and cached_gbl_channel_prefix_matrix->size(1) == num_channels); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous()); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks); + if (zero_copy) { + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum_fwd->dim() == 2 and cached_recv_gbl_rank_prefix_sum_fwd->is_contiguous()); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum_fwd->size(0) == NUM_MAX_NVL_PEERS and cached_recv_gbl_rank_prefix_sum_fwd->size(1) == num_ranks); + } } else { EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous()); @@ -747,8 +789,6 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionalsize(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); - topk_idx_ptr = topk_idx->data_ptr(); - topk_weights_ptr = topk_weights->data_ptr(); } // FP8 scales checks @@ -759,12 +799,39 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionalscalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt); EP_HOST_ASSERT(x_scales->dim() == 2); EP_HOST_ASSERT(x_scales->size(0) == num_tokens); + EP_HOST_ASSERT(not zero_copy or x_scales->is_contiguous()); num_scales = x_scales->dim() == 1 ? 
1 : static_cast(x_scales->size(1)); - x_scales_ptr = static_cast(x_scales->data_ptr()); scale_token_stride = static_cast(x_scales->stride(0)); scale_hidden_stride = static_cast(x_scales->stride(1)); } + void *rdma_x_ptr; + size_t rdma_workspace_size = config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks, zero_copy); + size_t rdma_buffer_size_needed = rdma_workspace_size; + bool with_topk = topk_idx.has_value(); + bool with_scales = x_scales.has_value(); + if (zero_copy) { + InternodeDispatchLayout layout(rdma_fused_buffer_ptr, num_tokens, hidden, num_topk, num_ranks, with_topk, with_scales, num_zcopy_buffers); + rdma_x_ptr = layout.buffers.at(zcopy_buffer_id).x; + if (with_topk) { + topk_idx_ptr = layout.buffers.at(zcopy_buffer_id).topk_idx; + topk_weights_ptr = layout.buffers.at(zcopy_buffer_id).topk_weights; + } + if (with_scales) { + x_scales_ptr = layout.buffers.at(zcopy_buffer_id).x_scales; + } + } else { + rdma_x_ptr = x.data_ptr(); + if (with_topk) { + topk_idx_ptr = topk_idx->data_ptr(); + topk_weights_ptr = topk_weights->data_ptr(); + } + if (with_scales) { + x_scales_ptr = static_cast(x_scales->data_ptr()); + } + } + EP_HOST_ASSERT(rdma_buffer_size_needed <= num_rdma_bytes); + // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! auto compute_stream = at::cuda::getCurrentCUDAStream(); @@ -787,6 +854,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional num_recv_tokens_per_expert_list; + std::optional recv_gbl_rank_prefix_sum_fwd; // Barrier or send sizes if (cached_mode) { @@ -796,6 +864,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional(), recv_rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), + zero_copy ? 
recv_gbl_rank_prefix_sum_fwd->data_ptr() : nullptr, rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - barrier_signal_ptrs_gpu, rank, comm_stream, - config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + barrier_signal_ptrs_gpu, rank, zero_copy, comm_stream, + rdma_workspace_size, num_nvl_bytes, low_latency_mode); // Synchronize total received tokens and tokens per expert @@ -855,16 +930,20 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } - // Allocate new tensors - auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); - auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); - auto recv_src_meta = std::optional(); - auto recv_rdma_channel_prefix_matrix = std::optional(); - auto recv_gbl_channel_prefix_matrix = std::optional(); - auto send_rdma_head = std::optional(); - auto send_nvl_head = std::optional(); + // Prepare output tensors and assign pointers + at::Tensor recv_x; + std::optional recv_topk_idx; + std::optional recv_topk_weights; + std::optional recv_x_scales; + std::optional recv_src_meta; + std::optional recv_rdma_channel_prefix_matrix; + std::optional recv_gbl_channel_prefix_matrix; + std::optional send_rdma_head; + std::optional send_nvl_head; + void* recv_x_ptr = nullptr; + void *recv_src_meta_ptr = nullptr; + if (not cached_mode) { - recv_src_meta = torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA)); recv_rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); recv_gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); @@ -875,35 +954,91 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionaloptions()); - recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); - recv_topk_idx_ptr = recv_topk_idx->data_ptr(); - recv_topk_weights_ptr = recv_topk_weights->data_ptr(); - } if (x_scales.has_value()) { recv_x_scales = x_scales->dim() == 1 ? 
torch::empty({num_recv_tokens}, x_scales->options()) : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); } + if (zero_copy) { + const size_t x_bytes_per_token = hidden * x.element_size(); + const size_t topk_idx_bytes_per_token = num_topk * sizeof(topk_idx_t); + const size_t topk_weights_bytes_per_token = num_topk * sizeof(float); + const size_t x_scales_bytes_per_token = num_scales * sizeof(float); + const size_t src_meta_bytes_per_token = internode_zcopy::get_source_meta_bytes(); + + // TMA alignment requirements + EP_HOST_ASSERT(x_bytes_per_token % 16 == 0); + EP_HOST_ASSERT(topk_idx_bytes_per_token % 16 == 0); + EP_HOST_ASSERT(topk_weights_bytes_per_token % 16 == 0); + EP_HOST_ASSERT(x_scales_bytes_per_token % 16 == 0); + EP_HOST_ASSERT(src_meta_bytes_per_token % 16 == 0); + + const size_t recv_x_bytes = num_recv_tokens * x_bytes_per_token; + const size_t recv_topk_idx_bytes = num_recv_tokens * topk_idx_bytes_per_token; + const size_t recv_topk_weights_bytes = num_recv_tokens * topk_weights_bytes_per_token; + const size_t recv_x_scales_bytes = num_recv_tokens * x_scales_bytes_per_token; + const size_t recv_src_meta_bytes = num_recv_tokens * src_meta_bytes_per_token; + + void* recv_x_ptr = reinterpret_cast(buffer_fused_ptrs[nvl_rank]) + NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER * num_zcopy_buffers + NUM_DISPATCH_OUTPUT_BYTES_PER_ZCOPY_BUFFER * zcopy_buffer_id; + recv_x = torch::from_blob(recv_x_ptr, {num_recv_tokens, hidden}, torch::TensorOptions().dtype(x.scalar_type()).device(torch::kCUDA)); + if (not cached_mode) { + recv_src_meta_ptr = reinterpret_cast(reinterpret_cast(recv_x_ptr) + recv_x_bytes + recv_topk_idx_bytes + recv_topk_weights_bytes + recv_x_scales_bytes); + recv_src_meta = std::make_optional(torch::from_blob(recv_src_meta_ptr, {num_recv_tokens, internode_zcopy::get_source_meta_bytes()}, torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA))); + } + + if (topk_idx.has_value()) { + recv_topk_idx_ptr = reinterpret_cast(reinterpret_cast(recv_x_ptr) + recv_x_bytes); + recv_topk_weights_ptr = reinterpret_cast(reinterpret_cast(recv_x_ptr) + recv_x_bytes + recv_topk_idx_bytes); + recv_topk_idx = std::make_optional(torch::from_blob(recv_topk_idx_ptr, {num_recv_tokens, num_topk}, torch::TensorOptions().dtype(topk_idx->scalar_type()).device(torch::kCUDA))); + recv_topk_weights = std::make_optional(torch::from_blob(recv_topk_weights_ptr, {num_recv_tokens, num_topk}, torch::TensorOptions().dtype(topk_weights->scalar_type()).device(torch::kCUDA))); + } + if (x_scales.has_value()) { + recv_x_scales_ptr = reinterpret_cast(reinterpret_cast(recv_x_ptr) + recv_x_bytes + recv_topk_idx_bytes + recv_topk_weights_bytes); + recv_x_scales = x_scales->dim() == 1 ? 
+ std::make_optional(torch::from_blob(recv_x_scales_ptr, {num_recv_tokens}, torch::TensorOptions().dtype(x_scales->scalar_type()).device(torch::kCUDA))) : + std::make_optional(torch::from_blob(recv_x_scales_ptr, {num_recv_tokens, num_scales}, torch::TensorOptions().dtype(x_scales->scalar_type()).device(torch::kCUDA))); + } + } else { + recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); + recv_x_ptr = recv_x.data_ptr(); + if (not cached_mode) { + recv_src_meta = torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA)); + recv_src_meta_ptr = recv_src_meta->data_ptr(); + } + + if (topk_idx.has_value()) { + recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); + recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); + recv_topk_idx_ptr = recv_topk_idx->data_ptr(); + recv_topk_weights_ptr = recv_topk_weights->data_ptr(); + } + if (x_scales.has_value()) { + recv_x_scales = x_scales->dim() == 1 ? + torch::empty({num_recv_tokens}, x_scales->options()) : + torch::empty({num_recv_tokens, num_scales}, x_scales->options()); + recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); + } + } // Launch data dispatch // NOTES: the buffer size checks are moved into the `.cu` file - internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, - cached_mode ? nullptr : recv_src_meta->data_ptr(), - x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, + internode::dispatch(recv_x_ptr, recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, + cached_mode ? nullptr : recv_src_meta_ptr, + rdma_x_ptr, x_scales_ptr, topk_idx_ptr, topk_weights_ptr, cached_mode ? nullptr : send_rdma_head->data_ptr(), cached_mode ? nullptr : send_nvl_head->data_ptr(), cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr(), cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr(), rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), + recv_gbl_rank_prefix_sum_fwd.has_value() ? 
recv_gbl_rank_prefix_sum_fwd->data_ptr() : nullptr, is_token_in_rank.data_ptr(), num_tokens, hidden_int4, num_scales, num_topk, num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, + buffer_fused_ptrs_gpu, buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, cached_mode, + zero_copy, num_zcopy_buffers, zcopy_buffer_id, comm_stream, num_channels, low_latency_mode); // Wait streams @@ -920,8 +1055,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionalrecord_stream(comm_stream) : void(); if (allocate_on_comm_stream) @@ -940,6 +1076,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional(recv_gbl_rank_prefix_sum_fwd) : std::nullopt, recv_src_meta, send_rdma_head, send_nvl_head, event}; #else EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); @@ -952,11 +1089,15 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional& bias_0, const std::optional& bias_1, const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, + const std::optional& gbl_rank_prefix_sum_fwd, const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, - const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { + const Config& config, std::optional& previous_event, + bool zero_copy, int zcopy_buffer_id, + bool async, bool allocate_on_comm_stream) { #ifndef DISABLE_NVSHMEM - const int num_channels = config.num_sms / 2; - EP_HOST_ASSERT(config.num_sms % 2 == 0); + const int sms_per_channel = zero_copy ? 
1 : 2; + const int num_channels = config.num_sms / sms_per_channel; + EP_HOST_ASSERT(config.num_sms % sms_per_channel == 0); // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); @@ -965,19 +1106,26 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optionaldim() == 2 and gbl_rank_prefix_sum_fwd->is_contiguous() and gbl_rank_prefix_sum_fwd->scalar_type() == torch::kInt32); + } EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and combined_rdma_head.scalar_type() == torch::kInt32); EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32); auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); auto num_combined_tokens = static_cast(is_combined_token_in_rank.size(0)); EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); - EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes()); + EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes(zero_copy)); EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks); EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels); EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks); EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels); EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and combined_rdma_head.size(1) == num_rdma_ranks); EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS); + if (zero_copy) { + EP_HOST_ASSERT(gbl_rank_prefix_sum_fwd->size(0) == NUM_MAX_NVL_PEERS and gbl_rank_prefix_sum_fwd->size(1) == num_ranks); + } // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! @@ -1020,8 +1168,8 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional(), rdma_rank_prefix_sum.data_ptr(), combined_nvl_head.data_ptr(), rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - barrier_signal_ptrs_gpu, rank, comm_stream, - config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + barrier_signal_ptrs_gpu, rank, zero_copy, comm_stream, + config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks, zero_copy), num_nvl_bytes, false, low_latency_mode); // Assign bias pointers @@ -1035,18 +1183,29 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional(), - x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1], + x_ptr, topk_weights_ptr, bias_ptrs[0], bias_ptrs[1], combined_rdma_head.data_ptr(), combined_nvl_head.data_ptr(), src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), + gbl_rank_prefix_sum_fwd.has_value() ? 
gbl_rank_prefix_sum_fwd->data_ptr() : nullptr, num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, - rank, num_ranks, comm_stream, num_channels, low_latency_mode); + buffer_fused_ptrs_gpu, buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, + rank, num_ranks, zero_copy, zcopy_buffer_id, comm_stream, num_channels, low_latency_mode); // Wait streams std::optional event; @@ -1348,6 +1507,46 @@ Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank #endif } +std::tuple, std::optional, std::optional> +Buffer::get_internode_dispatch_buffer(int num_tokens, int hidden, int num_topk, bool with_topk, bool use_fp8, int buffer_id) const { +#ifndef DISABLE_NVSHMEM + InternodeDispatchLayout layout(rdma_fused_buffer_ptr, num_tokens, hidden, num_topk, num_ranks, with_topk, use_fp8, num_zcopy_buffers); + EP_HOST_ASSERT(0 <= buffer_id and buffer_id < num_zcopy_buffers); + auto buffer = layout.buffers[buffer_id]; + + EP_HOST_ASSERT(layout.total_bytes <= NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER * num_zcopy_buffers); + + auto x = use_fp8 ? torch::from_blob(buffer.x, {num_tokens, hidden}, torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA)) : torch::from_blob(buffer.x, {num_tokens, hidden}, torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA)); + auto x_scales = use_fp8 ? std::make_optional(torch::from_blob(buffer.x_scales, {num_tokens, hidden / 128}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA))) : std::optional(); + auto topk_idx = with_topk ? std::make_optional(torch::from_blob(buffer.topk_idx, {num_tokens, num_topk}, torch::TensorOptions().dtype(c10::CppTypeToScalarType::value).device(torch::kCUDA))) : std::optional(); + auto topk_weights = with_topk ? std::make_optional(torch::from_blob(buffer.topk_weights, {num_tokens, num_topk}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA))) : std::optional(); + + return {x, x_scales, topk_idx, topk_weights}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif +} + +std::tuple> +Buffer::get_internode_combine_buffer(int num_tokens, int hidden, int num_topk, bool with_topk, int buffer_id) const { +#ifndef DISABLE_NVSHMEM + InternodeCombineLayout layout(buffer_fused_ptrs[nvl_rank], num_tokens, hidden, num_topk, with_topk, num_zcopy_buffers); + EP_HOST_ASSERT(0 <= buffer_id and buffer_id < num_zcopy_buffers); + auto buffer = layout.buffers[buffer_id]; + + EP_HOST_ASSERT(layout.total_bytes <= NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER * num_zcopy_buffers); + + auto x = torch::from_blob(buffer.x, {num_tokens, hidden}, torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA)); + auto topk_weights = with_topk ? 
std::make_optional(torch::from_blob(buffer.topk_weights, {num_tokens, num_topk}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA))) : std::optional(); + + return {x, topk_weights}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif +} + bool is_sm90_compiled() { #ifndef DISABLE_SM90_FEATURES return true; @@ -1386,8 +1585,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("num_sms") = 20, py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256, py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256) - .def("get_nvl_buffer_size_hint", &deep_ep::Config::get_nvl_buffer_size_hint) - .def("get_rdma_buffer_size_hint", &deep_ep::Config::get_rdma_buffer_size_hint); + .def("get_nvl_buffer_size_hint", + &deep_ep::Config::get_nvl_buffer_size_hint, + py::arg("hidden_bytes"), + py::arg("num_ranks"), + py::arg("zero_copy") = false) + .def("get_rdma_buffer_size_hint", + &deep_ep::Config::get_rdma_buffer_size_hint, + py::arg("hidden_bytes"), + py::arg("num_ranks"), + py::arg("zero_copy") = false); m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint); pybind11::class_(m, "EventHandle") @@ -1395,7 +1602,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); pybind11::class_(m, "Buffer") - .def(pybind11::init()) + .def(pybind11::init()) .def("is_available", &deep_ep::Buffer::is_available) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) @@ -1412,6 +1619,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("intranode_combine", &deep_ep::Buffer::intranode_combine) .def("internode_dispatch", &deep_ep::Buffer::internode_dispatch) .def("internode_combine", &deep_ep::Buffer::internode_combine) + .def("get_internode_dispatch_buffer", &deep_ep::Buffer::get_internode_dispatch_buffer) + .def("get_internode_combine_buffer", &deep_ep::Buffer::get_internode_combine_buffer) .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer) .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index e4981ced..6eb3985a 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -35,9 +35,18 @@ struct Buffer { void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; void** buffer_ptrs_gpu = nullptr; + // For zero-copy NVLink send/recv + int64_t buffer_fused_bytes; + void* buffer_fused_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + void** buffer_fused_ptrs_gpu = nullptr; + // NVSHMEM Buffer int64_t num_rdma_bytes; void* rdma_buffer_ptr = nullptr; + void* rdma_fused_buffer_ptr = nullptr; + + // Splitting zero-copy buffers for multi-batch overlapping + int num_zcopy_buffers; // Shrink mode buffer bool enable_shrink = false; @@ -82,7 +91,7 @@ struct Buffer { int* moe_recv_rdma_counter_mapped = nullptr; public: - Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy, bool enable_shrink); + Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy, bool enable_shrink, int num_zcopy_buffers); ~Buffer() noexcept(false); @@ -128,7 +137,7 @@ struct Buffer { const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& 
channel_prefix_matrix, const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); - std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> + std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional, std::optional> internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, const std::optional& topk_weights, const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, @@ -136,15 +145,17 @@ struct Buffer { int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + const std::optional& cached_recv_gbl_rank_prefix_sum_fwd, + int expert_alignment, const Config& config, std::optional& previous_event, bool zero_copy, int zcopy_buffer_id, bool async, bool allocate_on_comm_stream); std::tuple, std::optional> internode_combine(const torch::Tensor& x, const std::optional& topk_weights, const std::optional& bias_0, const std::optional& bias_1, const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, + const std::optional& gbl_rank_prefix_sum_fwd, const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, - const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + const Config& config, std::optional& previous_event, bool zero_copy, int zcopy_buffer_id, bool async, bool allocate_on_comm_stream); void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); @@ -172,6 +183,12 @@ struct Buffer { void low_latency_query_mask_buffer(const torch::Tensor& mask_status); void low_latency_clean_mask_buffer(); + + std::tuple, std::optional, std::optional> + get_internode_dispatch_buffer(int num_tokens, int hidden, int num_topk, bool with_topk, bool use_fp8, int buffer_id) const; + + std::tuple> + get_internode_combine_buffer(int num_tokens, int hidden, int num_topk, bool with_topk, int buffer_id) const; }; } // namespace deep_ep diff --git a/csrc/kernels/CMakeLists.txt b/csrc/kernels/CMakeLists.txt index 22e34a38..ee3fec72 100644 --- a/csrc/kernels/CMakeLists.txt +++ b/csrc/kernels/CMakeLists.txt @@ -15,7 +15,8 @@ add_deep_ep_library(runtime_cuda runtime.cu) add_deep_ep_library(layout_cuda layout.cu) add_deep_ep_library(intranode_cuda intranode.cu) add_deep_ep_library(internode_cuda internode.cu) +add_deep_ep_library(internode_zcopy_cuda internode_zcopy.cu) add_deep_ep_library(internode_ll_cuda internode_ll.cu) # Later, we should link all libraries in `EP_CUDA_LIBRARIES` -set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE) +set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_zcopy_cuda 
internode_ll_cuda PARENT_SCOPE) diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh index 0084cd4f..d565835d 100644 --- a/csrc/kernels/api.cuh +++ b/csrc/kernels/api.cuh @@ -84,6 +84,8 @@ namespace internode { int get_source_meta_bytes(); +int get_source_meta_bytes(bool is_zero_copy); + void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, @@ -91,9 +93,10 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe int hidden_int4, int num_scales, int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + int* recv_gbl_rank_prefix_sum_fwd, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** barrier_signal_ptrs, int rank, + int** barrier_signal_ptrs, int rank, bool zero_copy, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode); @@ -103,12 +106,14 @@ void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, flo int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + const int* recv_gbl_rank_prefix_sum_fwd, const bool* is_token_in_rank, int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + void** buffer_fused_ptrs, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, bool is_cached_dispatch, + bool zero_copy, int num_zcopy_buffers, int zcopy_buffer_id, cudaStream_t stream, int num_channels, bool low_latency_mode); void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, @@ -116,7 +121,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** barrier_signal_ptrs, int rank, cudaStream_t stream, + int** barrier_signal_ptrs, int rank, bool zero_copy, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode); @@ -127,13 +132,57 @@ void combine(cudaDataType_t type, const void* bias_0, const void* bias_1, const int* combined_rdma_head, const int* combined_nvl_head, const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum_fwd, int num_tokens, int num_combined_tokens, int hidden, int num_topk, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode); + void** 
buffer_fused_ptrs, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, bool zero_copy, int zcopy_buffer_id, cudaStream_t stream, int num_channels, bool low_latency_mode); } // namespace internode +// Internode zero-copy kernels + +namespace internode_zcopy { + +static constexpr int get_source_meta_bytes() { + return 16; // There is an assertion in the .cu file to make sure this is accurate +} + +// TODO: Zero-copy: Eliminate these duplicate definitions +__host__ __device__ +std::pair get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, + int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, + int num_channels, bool is_dispatch); + +void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, + const void* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights, + int* send_rdma_head, int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + const int* recv_gbl_rank_prefix_sum_fwd, + const bool* is_token_in_rank, + int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_fused_ptrs, void** buffer_ptrs, int num_zcopy_buffers, int buffer_id, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, bool is_cached_dispatch, + cudaStream_t stream, int num_channels, bool low_latency_mode); + +void combine(cudaDataType_t type, + void* combined_x, float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const void* x, const float* topk_weights, + const void* bias_0, const void* bias_1, + const int* combined_rdma_head, const int* combined_nvl_head, + const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum_fwd, + int num_tokens, int num_combined_tokens, int hidden, int num_topk, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_fused_ptrs, void** buffer_ptrs, int buffer_id, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode); + +} // namespace internode_zcopy + // Internode low-latency kernels namespace internode_ll { diff --git a/csrc/kernels/buffer.cuh b/csrc/kernels/buffer.cuh index 7c243d3c..017314be 100644 --- a/csrc/kernels/buffer.cuh +++ b/csrc/kernels/buffer.cuh @@ -11,11 +11,11 @@ private: uint8_t* ptr; public: - int total_bytes; + uint32_t total_bytes; __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {} - __device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) { + __device__ __forceinline__ Buffer(void* &gbl_ptr, uint32_t num_elems, uint32_t offset = 0) { total_bytes = num_elems * sizeof(dtype_t); ptr = reinterpret_cast(gbl_ptr) + offset * sizeof(dtype_t); gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; @@ -30,47 +30,47 @@ public: return reinterpret_cast(ptr); } - __device__ __forceinline__ dtype_t& operator[](int idx) { + __device__ __forceinline__ dtype_t& operator[](uint32_t 
idx) { return buffer()[idx]; } }; -template +template struct AsymBuffer { private: uint8_t* ptrs[kNumRanks]; - int num_bytes; + uint32_t num_bytes; public: - int total_bytes; + uint32_t total_bytes; - __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks, - int sm_id = 0, int num_sms = 1, int offset = 0) { + __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, uint32_t num_elems, uint32_t num_ranks, + uint32_t sm_id = 0, uint32_t num_sms = 1, uint32_t offset = 0) { EP_STATIC_ASSERT(kNumRanks == 1, ""); num_bytes = num_elems * sizeof(dtype_t); - int per_channel_bytes = num_bytes * num_ranks; + uint32_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; ptrs[0] = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; } - __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, - int sm_id = 0, int num_sms = 1, int offset = 0) { + __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, uint32_t num_elems, uint32_t num_ranks, + uint32_t sm_id = 0, uint32_t num_sms = 1, uint32_t offset = 0) { EP_STATIC_ASSERT(kNumRanks > 1, ""); num_bytes = num_elems * sizeof(dtype_t); - int per_channel_bytes = num_bytes * num_ranks; + uint32_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; - for (int i = 0; i < kNumRanks; ++ i) { + for (uint32_t i = 0; i < kNumRanks; ++ i) { ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; gbl_ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + total_bytes; } } - __device__ __forceinline__ void advance(int shift) { + __device__ __forceinline__ void advance(uint32_t shift) { #pragma unroll - for (int i = 0; i < kNumRanks; ++ i) + for (uint32_t i = 0; i < kNumRanks; ++ i) ptrs[i] = ptrs[i] + shift * sizeof(dtype_t); } @@ -79,19 +79,19 @@ public: return *this; } - template + template __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { - for (int i = 0; i < kNumAlsoRanks; ++ i) + for (uint32_t i = 0; i < kNumAlsoRanks; ++ i) gbl_ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + total_bytes; return *this; } - __device__ __forceinline__ dtype_t* buffer(int idx = 0) { + __device__ __forceinline__ dtype_t* buffer(uint32_t idx = 0) { EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case"); return reinterpret_cast(ptrs[0] + num_bytes * idx); } - __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) { + __device__ __forceinline__ dtype_t* buffer_by(uint32_t rank_idx, uint32_t idx = 0) { EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case"); return reinterpret_cast(ptrs[rank_idx] + num_bytes * idx); } @@ -103,33 +103,33 @@ private: // NOTES: for non-decoupled case, `recv_ptr` is not used uint8_t* send_ptr; uint8_t* recv_ptr; - int num_bytes; + uint32_t num_bytes; public: - int total_bytes; + uint32_t total_bytes; - __device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks, - int sm_id = 0, int num_sms = 1) { + __device__ __forceinline__ SymBuffer(void* &gbl_ptr, uint32_t num_elems, uint32_t num_ranks, + uint32_t sm_id = 0, uint32_t num_sms = 1) { num_bytes = num_elems * sizeof(dtype_t); - int per_channel_bytes = num_bytes * num_ranks; + uint32_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms * (static_cast(kDecoupled) + 1); send_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id; recv_ptr = 
reinterpret_cast(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; } - __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) { + __device__ __forceinline__ dtype_t* send_buffer(uint32_t idx = 0) { EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case"); return reinterpret_cast(send_ptr + num_bytes * idx); } - __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) { + __device__ __forceinline__ dtype_t* recv_buffer(uint32_t idx = 0) { EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case"); return reinterpret_cast(recv_ptr + num_bytes * idx); } - __device__ __forceinline__ dtype_t* buffer(int idx = 0) { + __device__ __forceinline__ dtype_t* buffer(uint32_t idx = 0) { EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case"); return reinterpret_cast(send_ptr + num_bytes * idx); } diff --git a/csrc/kernels/configs.cuh b/csrc/kernels/configs.cuh index 4df478fe..087142fb 100644 --- a/csrc/kernels/configs.cuh +++ b/csrc/kernels/configs.cuh @@ -5,6 +5,22 @@ #define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) #define NUM_MAX_LOCAL_EXPERTS 1024 #define NUM_BUFFER_ALIGNMENT_BYTES 128 +// TODO: Zero-copy: Make these runtime configurable options +#define NUM_MAX_SGE_PER_WQE 60 // The DS field of WQE is only 6 bits; 2^6 - 1 = 63 +// In the zero-copy variants of internode kernels, the kernel itself uses very little NVL buffer. +// So we reserve this (heuristically large-enough) length of memory for the dispatch/combine +// kernels and use the following space for metadata sync in notify. +#define ZCOPY_NOTIFY_NVL_METADATA_OFFSET_INTS 65536 +#define ZCOPY_TMA_SMEM_ALIGNMENT 1024 +#define NUM_MAX_ZCOPY_DISPATCH_TOKENS 4096 +// Zero-copy: for the 8-of-256 experts case, *8 is no smaller than the average recv count. 
+#define NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER ((unsigned long)NUM_MAX_ZCOPY_DISPATCH_TOKENS * 16384) +#define NUM_DISPATCH_OUTPUT_BYTES_PER_ZCOPY_BUFFER ((unsigned long)NUM_MAX_ZCOPY_DISPATCH_TOKENS * 16384 * 8) +#define NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER ((unsigned long)NUM_MAX_ZCOPY_DISPATCH_TOKENS * 16384 * 8) +static_assert(NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER % NUM_BUFFER_ALIGNMENT_BYTES == 0 and + NUM_DISPATCH_OUTPUT_BYTES_PER_ZCOPY_BUFFER % NUM_BUFFER_ALIGNMENT_BYTES == 0 and + NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER % NUM_BUFFER_ALIGNMENT_BYTES == 0, + "Zero-copy buffer sizes are not properly aligned"); #define FINISHED_SUM_TAG 1024 #define NUM_WAIT_NANOSECONDS 500 diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh index 7fab70c1..41999fa2 100644 --- a/csrc/kernels/ibgda_device.cuh +++ b/csrc/kernels/ibgda_device.cuh @@ -68,6 +68,12 @@ typedef struct { uint64_t reserved; } __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t; +typedef struct { + uint64_t addr; + uint32_t length; + uint32_t lkey; +} __attribute__((__packed__)) ibgda_sge_t; + __device__ static __forceinline__ nvshmemi_ibgda_device_state_t* ibgda_get_state() { return &nvshmemi_ibgda_device_state_d; @@ -251,6 +257,43 @@ ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey, *out_rkey = device_key.key; } +__device__ static __forceinline__ +uint64_t ibgda_get_my_lkey(uint64_t laddr, __be32 *lkey) { + auto state = ibgda_get_state(); + auto heap_start = reinterpret_cast(nvshmemi_device_state_d.heap_base); + auto log2_cumem_granularity = state->log2_cumem_granularity; + + // Local key + uint64_t idx = (laddr - heap_start) >> log2_cumem_granularity; + auto device_key = state->constmem.lkeys[idx]; + auto lchunk_size = device_key.next_addr - laddr; + *lkey = device_key.key; + + return lchunk_size; +} + +__device__ static __forceinline__ +uint64_t ibgda_get_my_rkey(uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) { + auto state = ibgda_get_state(); + auto heap_start = reinterpret_cast(nvshmemi_device_state_d.heap_base); + auto log2_cumem_granularity = state->log2_cumem_granularity; + + // Remote key + uint64_t roffset = raddr - heap_start; + uint64_t idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe; + nvshmemi_ibgda_device_key_t device_key; + if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) { + device_key = state->constmem.rkeys[idx]; + } else { + device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS]; + } + *out_raddr = reinterpret_cast(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset; + *out_rkey = device_key.key; + + auto rchunk_size = device_key.next_addr - roffset; + return rchunk_size; +} + __device__ static __forceinline__ uint64_t ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) { auto mvars = &qp->mvars; @@ -320,6 +363,67 @@ ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be3 st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg)); } +__device__ static __forceinline__ void +ibgda_write_rdma_write_wqe_warp_multi_sge(nvshmemi_ibgda_device_qp_t *qp, ibgda_sge_t *sge_list, int num_sge, + uint64_t raddr, __be32 rkey, uint16_t wqe_idx, int lane_id) { + void *wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx); + + auto *ctrl_seg_ptr = reinterpret_cast(wqe_ptr); + void *av_seg_ptr = reinterpret_cast(reinterpret_cast(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr)); + auto raddr_seg_ptr = 
reinterpret_cast(reinterpret_cast(av_seg_ptr)); + + if (lane_id == 0) { + // wqe_idx: WQEBB number of the first block of this WQE. + // ds: WQE size in octowords (16-byte units). + ibgda_ctrl_seg_t ctrl_seg; + ctrl_seg = {0}; + const auto ds = 2 + num_sge; + EP_DEVICE_ASSERT(ds < (2<<6)); + ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | ds); + ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; + ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE); + + struct mlx5_wqe_raddr_seg raddr_seg; + raddr_seg.raddr = HtoBE64(raddr); + raddr_seg.rkey = rkey; + raddr_seg.reserved = 0; + + EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16"); + EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16"); + st_na_relaxed(reinterpret_cast(ctrl_seg_ptr), *reinterpret_cast(&ctrl_seg)); + st_na_relaxed(reinterpret_cast(raddr_seg_ptr), *reinterpret_cast(&raddr_seg)); + } + + const ptrdiff_t wq_size_bytes = qp->tx_wq.nwqes << MLX5_SEND_WQE_SHIFT; + const auto wq_start_addr = reinterpret_cast(ibgda_get_wqe_ptr(qp, 0)); + const ptrdiff_t first_data_seg_offset = reinterpret_cast(raddr_seg_ptr) + sizeof(*raddr_seg_ptr) - wq_start_addr; + + for (int i = lane_id; i < num_sge; i += 32) { + __be32 sge_lkey = 0; + uint64_t sge_laddr = sge_list[i].addr; + uint64_t my_chunk_size = ibgda_get_my_lkey(sge_laddr, &sge_lkey); + uint32_t sge_bytes = sge_list[i].length; + EP_DEVICE_ASSERT(sge_bytes <= my_chunk_size); + + struct mlx5_wqe_data_seg *data_seg_ptr; + EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16"); + + ptrdiff_t data_seg_offset = first_data_seg_offset + i * sizeof(*data_seg_ptr); + if (data_seg_offset >= wq_size_bytes) { + data_seg_offset -= wq_size_bytes; + } + data_seg_ptr = reinterpret_cast(wq_start_addr + data_seg_offset); + + struct mlx5_wqe_data_seg data_seg = { + .byte_count = HtoBE32(sge_bytes), + .lkey = sge_lkey, + .addr = HtoBE64(sge_laddr) + }; + st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg)); + } + __syncwarp(); +} + __device__ static __forceinline__ void ibgda_write_empty_recv_wqe(void *out_wqe) { auto *data_seg_ptr = reinterpret_cast(out_wqe); @@ -389,6 +493,35 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, __syncwarp(); } +template +__device__ static __forceinline__ void +nvshmemi_ibgda_put_nbi_warp_multi_sge_parallel(ibgda_sge_t *sge_list, int num_sge, uint64_t req_rptr, size_t bytes, int dst_pe, int qp_id, int lane_id) { + // Get rkey, store them into lanes + __be32 my_rkey = 0; + uint64_t my_raddr = 0; + uint64_t my_chunk_size = 0; + + // Decide how many messages + my_chunk_size = ibgda_get_my_rkey(req_rptr, dst_pe, &my_raddr, &my_rkey); + EP_DEVICE_ASSERT(bytes <= my_chunk_size); + + // Process WQE + uint32_t num_wqebbs = ((num_sge + 2) + 3) / 4; // +3 for Round up + auto qp = ibgda_get_rc(dst_pe, qp_id); + uint64_t base_wqe_idx = 0; + if (lane_id == 0) { + base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqebbs); + } + base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0); + ibgda_write_rdma_write_wqe_warp_multi_sge(qp, sge_list, num_sge, my_raddr, my_rkey, base_wqe_idx, lane_id); + __syncwarp(); + + // Submit + if (lane_id == 0) + ibgda_submit_requests(qp, base_wqe_idx, num_wqebbs); + __syncwarp(); +} + __device__ static __forceinline__ void ibgda_write_amo_add_wqe( nvshmemi_ibgda_device_qp_t *qp, const int &value, uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey, diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 
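The multi-SGE writer above packs one control segment, one remote-address segment, and `num_sge` data segments into a single RDMA WRITE WQE; `ds` counts these 16-byte octowords (and must fit the 6-bit DS field, i.e. stay below 64), the caller reserves `ceil((num_sge + 2) / 4)` 64-byte WQEBBs, and the data segments are filled by the lanes in parallel, wrapping around the work-queue ring when needed. A host-side sketch of the sizing and placement arithmetic, assuming the WQE starts on a WQEBB boundary (illustrative names, not the patch's API):

```cpp
// Illustrative sketch only: WQE sizing and data-segment placement for a multi-SGE
// RDMA WRITE, mirroring the octoword/WQEBB arithmetic above.
#include <cassert>
#include <cstdio>

struct WqeLayout {
    int ds;          // octowords: 1 ctrl + 1 raddr + num_sge data segments
    int num_wqebbs;  // 64-byte basic blocks reserved for this WQE
};

WqeLayout plan_wqe(int num_sge) {
    WqeLayout w;
    w.ds = 2 + num_sge;
    assert(w.ds <= 63 && "DS is a 6-bit field");
    w.num_wqebbs = (w.ds + 3) / 4;            // 4 octowords per 64-byte WQEBB
    return w;
}

// Byte offset of the i-th data segment inside a ring of `nwqes` WQEBBs, for a WQE
// that begins at WQEBB index `wqe_slot`.
long data_seg_offset(int wqe_slot, int i, int nwqes) {
    const long wq_size = static_cast<long>(nwqes) * 64;
    long off = static_cast<long>(wqe_slot) * 64 + 2 * 16 + static_cast<long>(i) * 16;
    if (off >= wq_size) off -= wq_size;       // data segments may wrap past the ring end
    return off;
}

int main() {
    auto w = plan_wqe(60);                    // NUM_MAX_SGE_PER_WQE
    std::printf("ds=%d, wqebbs=%d\n", w.ds, w.num_wqebbs);
    std::printf("seg 5 of a WQE at slot 126 in a 128-entry ring: offset %ld\n",
                data_seg_offset(126, 5, 128));
    return 0;
}
```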
9aaa7ce2..0e13bf54 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -2,6 +2,7 @@ #include #include "configs.cuh" +#include "api.cuh" #include "buffer.cuh" #include "exception.cuh" #include "launch.cuh" @@ -37,20 +38,28 @@ struct SourceMeta { EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); +int get_source_meta_bytes(bool is_zero_copy) { + return is_zero_copy ? internode_zcopy::get_source_meta_bytes() : internode::get_source_meta_bytes(); +} + int get_source_meta_bytes() { return sizeof(SourceMeta); } +template __host__ __device__ __forceinline__ int get_num_bytes_per_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) { - return static_cast(align_up(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4))); + return static_cast(align_up(hidden_int4 * sizeof(int4) + SourceMetaBytes + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4))); } __host__ __device__ __forceinline__ std::pair get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, - int num_channels) { + int num_channels, bool zero_copy, bool is_dispatch) { // Return `int32_t` offset and count to clean + if (zero_copy) { + return internode_zcopy::get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_rdma_recv_buffer_tokens, num_channels, is_dispatch); + } return { (get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int), (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels @@ -60,10 +69,24 @@ std::pair get_rdma_clean_meta(int hidden_int4, int num_scales, int num __host__ __device__ __forceinline__ std::pair get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks, - int num_nvl_recv_buffer_tokens, int num_channels, bool is_dispatch) { + int num_nvl_recv_buffer_tokens, int num_channels, bool zero_copy, bool is_dispatch) { // Return `int32_t` offset and to clean EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); + if (zero_copy) { + if (is_dispatch) { + return { + // The actual recv data is in the fused buffer + ZCOPY_NOTIFY_NVL_METADATA_OFFSET_INTS, + num_nvl_ranks * (2 * num_rdma_ranks + 1) * num_channels, + }; + } + return { + ZCOPY_NOTIFY_NVL_METADATA_OFFSET_INTS, + num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels, + }; + } + return { (num_nvl_recv_buffer_tokens * get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_nvl_ranks * num_channels) / sizeof(int), num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels, @@ -90,6 +113,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in const int nvl_clean_offset, const int nvl_num_int_clean, int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + int* recv_gbl_rank_prefix_sum_fwd, void* rdma_buffer_ptr, void** buffer_ptrs, int** barrier_signal_ptrs, int rank, const nvshmem_team_t rdma_team) { @@ -177,11 +201,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in auto nvl_send_num_tokens_per_expert = 
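In the zero-copy path the NVL buffer no longer stages token payloads, so cleaning starts at the reserved metadata offset and only covers the per-channel notify counters; dispatch and combine differ only in how many ints each (channel, NVL rank) pair uses. A small host sketch of the count returned above (illustrative helper, not part of the patch):

```cpp
// Illustrative sketch only: number of int32 slots to clean in the NVL buffer for the
// zero-copy path, mirroring get_nvl_clean_meta above.
#include <cstdio>
#include <utility>

constexpr int kNvlMetadataOffsetInts = 65536;   // ZCOPY_NOTIFY_NVL_METADATA_OFFSET_INTS

std::pair<int, int> zcopy_nvl_clean_meta(int num_rdma_ranks, int num_nvl_ranks,
                                         int num_channels, bool is_dispatch) {
    // Dispatch: per (channel, NVL rank) pair, one prefix-start and one prefix-end slot
    // per RDMA rank plus a finish-signal slot; combine keeps one extra slot per pair.
    const int ints_per_pair = is_dispatch ? (2 * num_rdma_ranks + 1)
                                          : (2 * num_rdma_ranks + 2);
    return {kNvlMetadataOffsetInts, num_nvl_ranks * ints_per_pair * num_channels};
}

int main() {
    auto [offset, count] = zcopy_nvl_clean_meta(/*rdma*/ 4, /*nvl*/ 8, /*channels*/ 24, true);
    std::printf("clean %d ints starting at int offset %d\n", count, offset);
    return 0;
}
```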
AsymBuffer(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); auto nvl_recv_num_tokens_per_rank = AsymBuffer(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); auto nvl_recv_num_tokens_per_expert = AsymBuffer(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); + // Only used for the zero-copy case + auto nvl_send_num_tokens_per_rank_fwd = AsymBuffer(nvl_send_buffer, kNumRDMARanks * NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_rank_fwd = AsymBuffer(nvl_recv_buffer, kNumRDMARanks * NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS); // Clean up for later data dispatch auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes + - nvl_send_num_tokens_per_expert.total_bytes <= nvl_clean_offset * sizeof(int)); + nvl_send_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank_fwd.total_bytes <= nvl_clean_offset * sizeof(int)); #pragma unroll for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; @@ -245,6 +272,23 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in moe_recv_expert_counter_mapped[thread_id] = sum; } + if (recv_gbl_rank_prefix_sum_fwd) { + if (thread_id < NUM_MAX_NVL_PEERS) { + #pragma unroll + for (int i = 0; i < num_ranks; i++) { + nvl_send_num_tokens_per_rank_fwd.buffer(nvl_rank)[i] = recv_gbl_rank_prefix_sum[i]; + } + } + barrier_block(barrier_signal_ptrs, nvl_rank); + + if (thread_id < NUM_MAX_NVL_PEERS) { + #pragma unroll + for (int i = 0; i < num_ranks; i++) { + recv_gbl_rank_prefix_sum_fwd[thread_id * num_ranks + i] = nvl_recv_num_tokens_per_rank_fwd.buffer(thread_id)[i]; + } + } + } + // Finally barrier if (thread_id == 32) nvshmem_sync_with_same_gpu_idx(rdma_team); @@ -309,9 +353,11 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe int hidden_int4, int num_scales, int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + int* recv_gbl_rank_prefix_sum_fwd, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, int** barrier_signal_ptrs, int rank, + bool zero_copy, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode) { #define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ @@ -326,6 +372,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe nvl_clean_meta.first, nvl_clean_meta.second, \ rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ + recv_gbl_rank_prefix_sum_fwd, \ rdma_buffer_ptr, \ buffer_ptrs, barrier_signal_ptrs, rank, \ cpu_rdma_team); } break @@ -334,13 +381,15 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; // Get clean meta - auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); - auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels, true); + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels, zero_copy, 
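The new `recv_gbl_rank_prefix_sum_fwd` exchange above gives every local rank a copy of all eight NVL peers' receive prefix sums: each peer publishes its own `recv_gbl_rank_prefix_sum` through the NVL buffer, and after the node barrier each rank gathers them into a flattened `[peer][global rank]` matrix. The zero-copy dispatch later indexes this matrix to find where tokens from a given source rank start inside a peer's fused output buffer. A host-side simulation of the exchange (illustrative, assuming 8 peers and 16 global ranks):

```cpp
// Illustrative sketch only: all-gather of per-rank receive prefix sums across NVL
// peers, mirroring the recv_gbl_rank_prefix_sum_fwd exchange above.
#include <cstdio>
#include <vector>

int main() {
    const int num_nvl_peers = 8, num_ranks = 16;

    // Each peer's own recv_gbl_rank_prefix_sum (dummy monotone data for illustration).
    std::vector<std::vector<int>> own(num_nvl_peers, std::vector<int>(num_ranks));
    for (int p = 0; p < num_nvl_peers; ++p)
        for (int r = 0; r < num_ranks; ++r)
            own[p][r] = (r + 1) * (p + 1);

    // "Exchange": peer p publishes its row, every peer q gathers all rows into its
    // local copy fwd[q], flattened as [p * num_ranks + r]. On the GPU this goes
    // through the NVL buffers with a node barrier in between.
    std::vector<std::vector<int>> fwd(num_nvl_peers,
                                      std::vector<int>(num_nvl_peers * num_ranks));
    for (int q = 0; q < num_nvl_peers; ++q)
        for (int p = 0; p < num_nvl_peers; ++p)
            for (int r = 0; r < num_ranks; ++r)
                fwd[q][p * num_ranks + r] = own[p][r];

    // The dispatch forwarder later reads its local copy at [dst_peer * num_ranks + g - 1]
    // to find where tokens from global rank g start in that peer's fused buffer.
    std::printf("peer 0's view of peer 3: first=%d, total=%d\n",
                fwd[0][3 * num_ranks], fwd[0][3 * num_ranks + num_ranks - 1]);
    return 0;
}
```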
true); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels, zero_copy, true); EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(not zero_copy or recv_gbl_rank_prefix_sum_fwd); + // Launch kernel SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream); SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); @@ -997,13 +1046,35 @@ void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, flo int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + const int* recv_gbl_rank_prefix_sum_fwd, const bool* is_token_in_rank, int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + void** buffer_fused_ptrs, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, bool is_cached_dispatch, + bool zero_copy, int num_zcopy_buffers, int zcopy_buffer_id, cudaStream_t stream, int num_channels, bool low_latency_mode) { + if (zero_copy) { + // TODO: Zero-copy: Support ue8m0 scenario + EP_HOST_ASSERT((scale_token_stride == num_scales and scale_hidden_stride == 1) or (scale_token_stride == 0 and scale_hidden_stride == 0)); + return internode_zcopy::dispatch( + recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, recv_src_meta, + x, x_scales, topk_idx, topk_weights, + send_rdma_head, send_nvl_head, + recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, + rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, + gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, + recv_gbl_rank_prefix_sum_fwd, + is_token_in_rank, + num_tokens, hidden_int4, num_scales, num_topk, num_experts, + rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, + buffer_fused_ptrs, buffer_ptrs, num_zcopy_buffers, zcopy_buffer_id, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, + rank, num_ranks, is_cached_dispatch, + stream, num_channels, low_latency_mode + ); + } + constexpr int kNumDispatchRDMASenderWarps = 7; constexpr int kNumTMABytesPerWarp = 16384; constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS; @@ -1040,7 +1111,7 @@ void dispatch(void* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, flo #undef DISPATCH_LAUNCH_CASE } -template +template __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset, const int nvl_num_int_clean, int* combined_rdma_head, int num_combined_tokens, int num_channels, @@ -1141,7 +1212,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in } __syncwarp(); - for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 2) { + for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; 
dst_rdma_rank += num_channels * kNumBlocksPerChannel - 2) { // Iterate in reverse order int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1]; int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id]; @@ -1188,7 +1259,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** barrier_signal_ptrs, int rank, cudaStream_t stream, + int** barrier_signal_ptrs, int rank, bool zero_copy, cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool is_cached_dispatch, bool low_latency_mode) { const int num_threads = std::max(128, 32 * num_channels); @@ -1196,19 +1267,22 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; const int kNumTMABytesPerWarp = 8192; const int smem_size = kNumTMABytesPerWarp * num_warps; + const int num_blocks_per_channel = zero_copy ? 1 : 2; // Get clean meta - auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); - auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels, is_cached_dispatch); + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels, zero_copy, is_cached_dispatch); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels, zero_copy, is_cached_dispatch); EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); - EP_HOST_ASSERT(num_channels * 2 > 3); + EP_HOST_ASSERT(num_channels * num_blocks_per_channel > 3); // Launch kernel - auto cached_notify_func = low_latency_mode ? cached_notify : cached_notify; - SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream); + auto cached_notify_func = + zero_copy ? (low_latency_mode ? cached_notify : cached_notify) : + (low_latency_mode ? 
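In the regular path each channel owns two blocks of `cached_notify`, while the zero-copy variant folds the work into one block per channel; that is why the loop above strides by `num_channels * kNumBlocksPerChannel - 2` and the host asserts strictly more than three blocks. A trivial sketch of the resulting grid size (illustrative helper name):

```cpp
// Illustrative sketch only: grid geometry for cached_notify with and without
// zero-copy, mirroring num_blocks_per_channel above.
#include <cassert>
#include <cstdio>

int cached_notify_grid(int num_channels, bool zero_copy) {
    const int num_blocks_per_channel = zero_copy ? 1 : 2;
    // The kernel requires strictly more than 3 blocks in total (see the
    // EP_HOST_ASSERT above); blocks beyond the first two run the `sm_id - 2` loop.
    assert(num_channels * num_blocks_per_channel > 3);
    return num_channels * num_blocks_per_channel;
}

int main() {
    std::printf("24 channels -> normal: %d blocks, zero-copy: %d blocks\n",
                cached_notify_grid(24, false), cached_notify_grid(24, true));
    return 0;
}
```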
cached_notify : cached_notify); + SETUP_LAUNCH_CONFIG(num_channels * num_blocks_per_channel, num_threads, stream); SET_SHARED_MEMORY_FOR_TMA(cached_notify_func); LAUNCH_KERNEL(&cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second, @@ -1828,10 +1902,27 @@ void combine(cudaDataType_t type, const void* bias_0, const void* bias_1, const int* combined_rdma_head, const int* combined_nvl_head, const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum_fwd, int num_tokens, int num_combined_tokens, int hidden, int num_topk, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) { + void** buffer_fused_ptrs, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, bool zero_copy, int zcopy_buffer_id, cudaStream_t stream, int num_channels, bool low_latency_mode) { + if (zero_copy) { + return internode_zcopy::combine( + type, combined_x, combined_topk_weights, + is_combined_token_in_rank, + x, topk_weights, + bias_0, bias_1, + combined_rdma_head, combined_nvl_head, + src_meta, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, + recv_gbl_rank_prefix_sum_fwd, + num_tokens, num_combined_tokens, hidden, num_topk, + rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, + buffer_fused_ptrs, buffer_ptrs, zcopy_buffer_id, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, + rank, num_ranks, stream, num_channels, low_latency_mode + ); + } + constexpr int kNumCombineForwarderWarps = 24; constexpr int kNumTMABytesPerSenderWarp = 16384; constexpr int kNumTMABytesPerForwarderWarp = 9248; diff --git a/csrc/kernels/internode_zcopy.cu b/csrc/kernels/internode_zcopy.cu new file mode 100644 index 00000000..8c87a18d --- /dev/null +++ b/csrc/kernels/internode_zcopy.cu @@ -0,0 +1,1209 @@ +#include "configs.cuh" +#include "api.cuh" +#include "buffer.cuh" +#include "exception.cuh" +#include "launch.cuh" +#include "utils.cuh" +#include "ibgda_device.cuh" + +namespace deep_ep { + +namespace internode { + +extern nvshmem_team_t cpu_rdma_team; + +} + +namespace internode_zcopy { + +// TODO: Zero-copy: Eliminate duplicate definitions +struct SourceMeta { + int src_rdma_rank, is_token_in_nvl_rank_bits; + int token_idx; // Used in local token dispatch as an indirect reference to the original token + int dummy; + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers"); + + __forceinline__ SourceMeta() = default; + + // TODO: faster encoding + __device__ __forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks, int token_idx = 0) { + src_rdma_rank = rdma_rank; + is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0]; + this->token_idx = token_idx; + #pragma unroll + for (int i = 1; i < NUM_MAX_NVL_PEERS; ++ i) + is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i; + } + + __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const { + return (is_token_in_nvl_rank_bits >> nvl_rank) & 1; + } +}; + +EP_STATIC_ASSERT(sizeof(SourceMeta) == get_source_meta_bytes(), "Invalid size of `SourceMeta`"); + +template +__host__ __device__ __forceinline__ +int get_num_bytes_per_token(int hidden_int4, 
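The zero-copy `SourceMeta` above keeps the per-NVL-rank membership as a bit mask and adds a `token_idx` so a locally dispatched token can be resolved back to the sender's original buffers by index instead of by copy. A standalone sketch of the same packing, assuming 8 NVL peers (the struct below is illustrative, not the patch's type):

```cpp
// Illustrative sketch only: SourceMeta-style packing of the "is this token sent to
// NVL rank i" flags into one int, plus the token_idx indirection for local tokens.
#include <cstdio>

struct SourceMetaSketch {
    int src_rdma_rank;
    int is_token_in_nvl_rank_bits;
    int token_idx;   // index of the original token on the sender, used for local dispatch
    int dummy;       // padding so the struct stays 16 bytes

    SourceMetaSketch(int rdma_rank, const bool* in_nvl_ranks, int idx)
        : src_rdma_rank(rdma_rank), is_token_in_nvl_rank_bits(in_nvl_ranks[0]),
          token_idx(idx), dummy(0) {
        for (int i = 1; i < 8; ++i)
            is_token_in_nvl_rank_bits |= static_cast<int>(in_nvl_ranks[i]) << i;
    }
    bool is_token_in_nvl_rank(int nvl_rank) const {
        return (is_token_in_nvl_rank_bits >> nvl_rank) & 1;
    }
};
static_assert(sizeof(SourceMetaSketch) == 16, "must stay 16 bytes");

int main() {
    bool in_rank[8] = {true, false, false, true, false, false, false, true};
    SourceMetaSketch m(/*rdma_rank=*/2, in_rank, /*token_idx=*/41);
    std::printf("bits=0x%x, in rank 3: %d, in rank 4: %d\n",
                m.is_token_in_nvl_rank_bits, m.is_token_in_nvl_rank(3), m.is_token_in_nvl_rank(4));
    return 0;
}
```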
int num_scales, int num_topk_idx, int num_topk_weights) { + return static_cast(align_up(hidden_int4 * sizeof(int4) + SourceMetaBytes + num_scales * sizeof(float) + num_topk_idx * sizeof(topk_idx_t) + num_topk_weights * sizeof(float), sizeof(int4))); +} + +__host__ __device__ +std::pair get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, + int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, + int num_channels, bool is_dispatch) { + if (is_dispatch) { + return { + get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 1 * num_channels / sizeof(int) + // recv buffer + get_source_meta_bytes() * ceil_div(NUM_MAX_ZCOPY_DISPATCH_TOKENS, num_channels) * num_rdma_ranks * 1 * num_channels / sizeof(int), // send buffer for SourceMeta + (NUM_MAX_NVL_PEERS * 2 + 2) * num_rdma_ranks * 2 * num_channels + // meta + 3 * num_rdma_ranks * 1 * num_channels * sizeof(uint64_t) / sizeof(int) // head & tail & recv finish signal + }; + } + return { + get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels / sizeof(int), // recv buffer + 2 * num_rdma_ranks * 1 * num_channels * sizeof(uint64_t) / sizeof(int) // head & tail + }; +} + +template +__forceinline__ __device__ static int translate_dst_rdma_rank(const int dst_rdma_rank, const int nvl_rank) { + return kLowLatencyMode ? (dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank; +} + +// At most 8 RDMA ranks to be sent +constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { + return num_rdma_ranks < 8 ? num_rdma_ranks : 8; +} + +template +__global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 2 + NUM_MAX_NVL_PEERS) * 32), 1) // 2 + 1 + 1 + 8 = 12 +dispatch(int4* recv_x, float* recv_x_scales, topk_idx_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta, + const int4* x, const float* x_scales, const topk_idx_t* topk_idx, const float* topk_weights, + int* send_rdma_head, int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + const int* recv_gbl_rank_prefix_sum_fwd, + const bool* is_token_in_rank, + int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_fused_ptrs, void** buffer_ptrs, int num_zcopy_buffers, int buffer_id, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks) { + enum class WarpRole { + kRDMASender, + kRDMAAndNVLForwarder, + kForwarderCoordinator, + kNVLReceivers + }; + + const auto sm_id = static_cast(blockIdx.x); + const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_channels = static_cast(gridDim.x), channel_id = sm_id; + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + + const auto role_meta = [=]() -> std::pair { + if (warp_id < kNumDispatchRDMASenderWarps) { + return {WarpRole::kRDMASender, -1}; + } else if (warp_id == kNumDispatchRDMASenderWarps) { + return {WarpRole::kNVLReceivers, -1}; + } else if (warp_id == 
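Each token occupies one contiguous record in the RDMA recv ring: the hidden data, the 16-byte source meta, optional FP8 scales, and optional top-k indices and weights, rounded up to a 16-byte boundary. A host sketch of that size, assuming a 16-byte `SourceMeta` and 8-byte `topk_idx_t` (consistent with the `sizeof(int64_t)` use later in this hunk; helper names illustrative):

```cpp
// Illustrative sketch only: bytes occupied by one token in the RDMA ring, mirroring
// get_num_bytes_per_token above. Assumes 16-byte SourceMeta and 8-byte top-k indices.
#include <cstdio>
#include <cstddef>

size_t align_up_16(size_t x) { return (x + 15) / 16 * 16; }

size_t bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk) {
    return align_up_16(hidden_int4 * 16              // hidden data, in int4 (16-byte) units
                       + 16                          // SourceMeta
                       + num_scales * sizeof(float)  // FP8 scales, if any
                       + num_topk * 8                // top-k indices (topk_idx_t)
                       + num_topk * sizeof(float));  // top-k weights
}

int main() {
    // Example: hidden = 7168 bytes (448 int4), 56 FP8 scales, top-8 routing.
    std::printf("%zu bytes per token\n", bytes_per_rdma_token(448, 56, 8));
    return 0;
}
```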
kNumDispatchRDMASenderWarps + 1) { + return {WarpRole::kForwarderCoordinator, -1}; + } else { + return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id - kNumDispatchRDMASenderWarps - 2) % NUM_MAX_NVL_PEERS}; + } + }(); + auto warp_role = role_meta.first; + auto target_rank = role_meta.second; // Not applicable for RDMA senders and NVL receivers + EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 2 + NUM_MAX_NVL_PEERS); + + // Data checks + EP_DEVICE_ASSERT(num_topk <= 32); + + // RDMA symmetric layout + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); + auto hidden_bytes = hidden_int4 * sizeof(int4); + auto num_bytes_per_rdma_token = get_num_bytes_per_token(hidden_int4, num_scales, num_topk, num_topk); + + auto rdma_channel_data = SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); + EP_DEVICE_ASSERT(num_tokens <= NUM_MAX_ZCOPY_DISPATCH_TOKENS); + auto rdma_channel_src_meta = SymBuffer(rdma_buffer_ptr, ceil_div(NUM_MAX_ZCOPY_DISPATCH_TOKENS, num_channels), kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_meta = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + // Each element indicates the RDMA rank has finished receiving (1) or not (0) + // TODO: Zero-copy: Replace this with polling CQ + auto rdma_channel_recv_finish_signal = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + + // NVL buffer layouts + // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers" + void *ws_rr_buffer_ptr = nullptr; + void *ws_rr_fused_buffer_ptr = nullptr; + if (warp_role == WarpRole::kRDMAAndNVLForwarder) { + ws_rr_fused_buffer_ptr = shift_ptr(buffer_fused_ptrs[target_rank], NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER * num_zcopy_buffers + NUM_DISPATCH_OUTPUT_BYTES_PER_ZCOPY_BUFFER * buffer_id); + ws_rr_buffer_ptr = buffer_ptrs[target_rank]; + } + if (warp_role == WarpRole::kNVLReceivers) { + ws_rr_buffer_ptr = buffer_ptrs[nvl_rank]; + } + + ws_rr_buffer_ptr = reinterpret_cast(reinterpret_cast(ws_rr_buffer_ptr) + ZCOPY_NOTIFY_NVL_METADATA_OFFSET_INTS); + auto nvl_channel_prefix_start = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels); + auto nvl_channel_prefix_end = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels); + auto nvl_channel_finish_signal = AsymBuffer(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels); + + // RDMA sender warp synchronization + __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; + auto sync_rdma_sender_smem = []() { asm volatile("barrier.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps) * 32)); }; + + // Forward warp synchronization + __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; + __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; + auto sync_forwarder_smem = []() { asm volatile("barrier.sync 1, %0;" :: "r"((NUM_MAX_NVL_PEERS + 1) * 32)); }; + + // RDMA multi-sge list + __shared__ __align__(16) ibgda_sge_t sge_list_buf[kNumDispatchRDMASenderWarps * NUM_MAX_SGE_PER_WQE]; + ibgda_sge_t *sge_list = sge_list_buf + warp_id * 
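Each zero-copy dispatch block partitions its warps into RDMA senders, one NVL receiver warp, one forwarder coordinator, and eight RDMA-and-NVL forwarders, with each forwarder pinned to one NVL peer and the assignment rotated by the channel id so channels do not all start on the same peer. The sender-warp count is a compile-time parameter; the launch-bounds comment above suggests 2 in this kernel, which the sketch below assumes (illustrative code, not the kernel's):

```cpp
// Illustrative sketch only: warp-role assignment of the zero-copy dispatch block.
#include <cstdio>
#include <utility>

enum class WarpRole { kRDMASender, kRDMAAndNVLForwarder, kForwarderCoordinator, kNVLReceivers };

std::pair<WarpRole, int> role_of(int warp_id, int channel_id, int num_sender_warps) {
    const int num_nvl_peers = 8;
    if (warp_id < num_sender_warps)      return {WarpRole::kRDMASender, -1};
    if (warp_id == num_sender_warps)     return {WarpRole::kNVLReceivers, -1};
    if (warp_id == num_sender_warps + 1) return {WarpRole::kForwarderCoordinator, -1};
    // Remaining warps each serve one NVL peer; the channel id rotates the mapping.
    return {WarpRole::kRDMAAndNVLForwarder,
            (warp_id + channel_id - num_sender_warps - 2) % num_nvl_peers};
}

int main() {
    const int senders = 2;   // assumed from the launch-bounds comment (2 + 1 + 1 + 8 = 12)
    for (int w = 0; w < senders + 2 + 8; ++w) {
        auto [role, target] = role_of(w, /*channel_id=*/1, senders);
        std::printf("warp %2d -> role %d, target NVL rank %d\n",
                    w, static_cast<int>(role), target);
    }
    return 0;
}
```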
NUM_MAX_SGE_PER_WQE; + + const size_t scale_bytes = num_scales * sizeof(float); + + if (warp_role == WarpRole::kRDMASender) { + // NOTES: in case of splitting the issued put at the end of the buffer + EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); + + // Get tasks + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Clean shared memory + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); + (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0; + + // Send number of tokens in this channel by `-value - 1` + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); + + for (int i = warp_id; i < kNumRDMARanks; i += kNumDispatchRDMASenderWarps) { + int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; + + auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank); + if (lane_id < NUM_MAX_NVL_PEERS) { + dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1; + } else if (lane_id < NUM_MAX_NVL_PEERS * 2) { + dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1; + } else if (lane_id == NUM_MAX_NVL_PEERS * 2) { + dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; + } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { + dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; + } + __syncwarp(); + + // Issue RDMA for non-local ranks + if (dst_rdma_rank != rdma_rank) { + nvshmemi_ibgda_put_nbi_warp(reinterpret_cast(rdma_channel_meta.recv_buffer(rdma_rank)), + reinterpret_cast(rdma_channel_meta.send_buffer(dst_rdma_rank)), + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2), + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), + channel_id, lane_id, 0); + } + } + sync_rdma_sender_smem(); + + // Iterate over tokens and copy into buffer + int64_t token_idx; + int num_sge_per_rdma_token = 2; + if (num_scales > 0) num_sge_per_rdma_token ++; + if (num_topk > 0) num_sge_per_rdma_token += 2; + + for (int i = warp_id; i < kNumRDMARanks; i += kNumDispatchRDMASenderWarps) { + int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; + + int last_issued_tail = 0; + int num_tokens_to_send = 0; + num_tokens_to_send = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; + if (channel_id > 0) { + num_tokens_to_send -= rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; + } + + int issue_token_idx = 0; + int num_tokens_to_issue = min(num_tokens_to_send, num_max_rdma_chunked_send_tokens); + int num_sge_per_msg = num_sge_per_rdma_token * num_tokens_to_issue; + EP_DEVICE_ASSERT(num_sge_per_msg <= NUM_MAX_SGE_PER_WQE); + for (token_idx = token_start_idx; token_idx < token_end_idx; token_idx ++) { + // Read RDMA rank existence + uint64_t is_token_in_rank_uint64 = 0; + int cached_rdma_channel_head = 0, rdma_tail_idx = -1; + if (lane_id == 0) { + is_token_in_rank_uint64 = *reinterpret_cast(is_token_in_rank + token_idx * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS); + // Acquire next tail + if (is_token_in_rank_uint64 != 0) { + rdma_tail_idx = rdma_send_channel_next_tail[dst_rdma_rank] ++; + 
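The per-channel prefix counts above are sent as `-value - 1` rather than the raw value: a negative word doubles as the "already written" flag the receiver polls for, and the same transformation decodes it, with zero round-tripping through -1. A tiny sketch of the encode/decode pair (illustrative helpers):

```cpp
// Illustrative sketch only: the `-value - 1` encoding used for the per-channel prefix
// metadata. Negative values double as an "already written" flag; 0 maps to -1 and back.
#include <cassert>
#include <cstdio>

int encode_meta(int value) { return -value - 1; }   // value >= 0  ->  encoded < 0
int decode_meta(int meta)  { return -meta - 1; }

int main() {
    const int values[] = {0, 1, 17, 4096};
    for (int v : values) {
        int m = encode_meta(v);
        assert(m < 0 && decode_meta(m) == v);
        std::printf("%d -> %d -> %d\n", v, m, decode_meta(m));
    }
    return 0;
}
```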
// Since in the zcopy case we use the SGE buffer exclusively, we do not + // wait for the head here and directly proceed to fill the SGE instead, + // without risk of overwriting the src buffer of an ongoing RDMA WRITE. + } + + // Store RDMA head for combine + if (not kCachedMode) + send_rdma_head[token_idx * kNumRDMARanks + dst_rdma_rank] = rdma_tail_idx; + } + __syncwarp(); + + auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, 0); + if (recv_is_token_in_rank_uint64 == 0) { + continue; + } + + auto token_x = x + token_idx * hidden_int4; + auto token_x_scales = x_scales + token_idx * num_scales; + auto token_topk_idx = topk_idx + token_idx * num_topk; + auto token_topk_weights = topk_weights + token_idx * num_topk; + auto token_source_meta_rdma_buf = rdma_channel_src_meta.buffer(dst_rdma_rank) + rdma_tail_idx; + + SourceMeta src_meta; + // Construct source meta + if (lane_id == 0) { + auto recv_is_token_in_rank_values = reinterpret_cast(&is_token_in_rank_uint64); + src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values, token_idx); + st_na_global(token_source_meta_rdma_buf, src_meta); + } + __syncwarp(); + + // Prepare SGE + if (dst_rdma_rank != rdma_rank) { + int sge_idx = 0; + if (lane_id == 0) { + sge_list[issue_token_idx * num_sge_per_rdma_token + sge_idx].addr = reinterpret_cast(token_x); + sge_list[issue_token_idx * num_sge_per_rdma_token + sge_idx].length = hidden_bytes; + sge_idx++; + sge_list[issue_token_idx * num_sge_per_rdma_token + sge_idx].addr = reinterpret_cast(token_source_meta_rdma_buf); + sge_list[issue_token_idx * num_sge_per_rdma_token + sge_idx].length = sizeof(SourceMeta); + sge_idx++; + if (num_scales > 0) { + sge_list[issue_token_idx * num_sge_per_rdma_token + sge_idx].addr = reinterpret_cast(token_x_scales); + sge_list[issue_token_idx * num_sge_per_rdma_token + sge_idx].length = scale_bytes; + sge_idx++; + } + if (num_topk > 0) { + sge_list[issue_token_idx * num_sge_per_rdma_token + sge_idx].addr = reinterpret_cast(token_topk_idx); + sge_list[issue_token_idx * num_sge_per_rdma_token + sge_idx].length = num_topk * sizeof(topk_idx_t); + sge_idx++; + sge_list[issue_token_idx * num_sge_per_rdma_token + sge_idx].addr = reinterpret_cast(token_topk_weights); + sge_list[issue_token_idx * num_sge_per_rdma_token + sge_idx].length = num_topk * sizeof(float); + sge_idx++; + } + } + __syncwarp(); + } + + issue_token_idx ++; + if (issue_token_idx < num_tokens_to_issue) { + continue; + } + + // Actually wait for the tail now + auto start_time = clock64(); + while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) { + cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank))); + + // Timeout check + if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d, num_max_rdma_chunked_recv_tokens: %d \n", + channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, rdma_tail_idx, num_max_rdma_chunked_recv_tokens); + trap(); + } + } + + // Issue the WRITE operation + if (dst_rdma_rank != rdma_rank) { + size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue; + int dst_slot_idx = last_issued_tail % num_max_rdma_chunked_recv_tokens; + auto dst_ptr = reinterpret_cast(rdma_channel_data.buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); + EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); + 
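Tokens bound for one RDMA rank are issued in chunks of at most `num_max_rdma_chunked_send_tokens`: the warp stages the SGE entries for the whole chunk first, waits for the remote ring to have room only once the chunk is full, and then writes it into slot `last_issued_tail % capacity` (the divisibility assert above guarantees a chunk never wraps the ring). A host-side sketch of that chunking loop, with the remote-head polling collapsed into a stub (illustrative, not the kernel's control flow verbatim):

```cpp
// Illustrative sketch only: chunked issue of tokens into a remote ring of
// `recv_capacity` slots, mirroring the sender loop above.
#include <algorithm>
#include <cassert>
#include <cstdio>

int main() {
    const int recv_capacity = 128;   // num_max_rdma_chunked_recv_tokens
    const int max_chunk     = 16;    // num_max_rdma_chunked_send_tokens
    // The kernel asserts this, so a chunk never wraps around the ring end.
    assert(recv_capacity % max_chunk == 0);

    int num_tokens_to_send = 200, last_issued_tail = 0, head = 0;
    while (num_tokens_to_send > 0) {
        const int chunk = std::min(num_tokens_to_send, max_chunk);

        // Flow control: the real code polls the remote head until there is room.
        while (last_issued_tail + chunk - head > recv_capacity)
            head = last_issued_tail;   // stub: pretend the receiver caught up

        const int dst_slot = last_issued_tail % recv_capacity;
        assert(dst_slot + chunk <= recv_capacity);
        std::printf("issue %2d tokens into slots [%3d, %3d)\n", chunk, dst_slot, dst_slot + chunk);

        last_issued_tail += chunk;     // then bump the remote tail by `chunk`
        num_tokens_to_send -= chunk;
    }
    return 0;
}
```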
nvshmemi_ibgda_put_nbi_warp_multi_sge_parallel(sge_list, num_sge_per_msg, dst_ptr, num_bytes_per_msg, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id); + } else { + // The local forwarder will fetch the SourceMeta from the SourceMeta RDMA + // source buffer. Simply updating the tail suffices. + memory_fence(); + } + + last_issued_tail += num_tokens_to_issue; + __syncwarp(); + + issue_token_idx = 0; + num_tokens_to_send -= num_tokens_to_issue; + if (lane_id == 0) { + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); + } + __syncwarp(); + num_tokens_to_issue = min(num_tokens_to_send, num_max_rdma_chunked_send_tokens); + num_sge_per_msg = num_sge_per_rdma_token * num_tokens_to_issue; + } + } + // Wait for all RDMA ranks to finish receiving + if (warp_id == 0 and lane_id < kNumRDMARanks) { + auto start_time = clock64(); + constexpr uint64_t expected = 1; + while (true) { + auto signal = ld_volatile_global(rdma_channel_recv_finish_signal.buffer()); + if (signal == expected) { + break; + } + if (clock64() - start_time > NUM_TIMEOUT_CYCLES or signal > expected) { + printf("DeepEP zcopy dispatch recv completion signal corruption or timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, signal: %#lx, expected: %#lx\n", + channel_id, rdma_rank, nvl_rank, lane_id, signal, expected); + trap(); + } + } + } + } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) { + // RDMA consumers and NVL producers + const auto dst_nvl_rank = target_rank; + const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; + const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); + const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks); + + // Dynamic TMA shared memory layout + const size_t kNumHiddenTMABytesPerWarp = 16384; + extern __shared__ int4 tma_smem[]; + char *tma_smem_aligned = reinterpret_cast(align_up(reinterpret_cast(tma_smem), ZCOPY_TMA_SMEM_ALIGNMENT)); + __shared__ uint64_t tma_mbarrier[NUM_MAX_NVL_PEERS]; + + // Dedicated TMA shared memory for scales + const size_t kNumScalesTMABytesPerWarp = 512; + __shared__ __align__(kNumScalesTMABytesPerWarp) char tma_smem_scales[NUM_MAX_NVL_PEERS][kNumScalesTMABytesPerWarp]; + __shared__ uint64_t tma_mbarrier_scales[NUM_MAX_NVL_PEERS]; + + EP_DEVICE_ASSERT(hidden_bytes <= kNumHiddenTMABytesPerWarp); + EP_DEVICE_ASSERT(scale_bytes <= kNumScalesTMABytesPerWarp); + + char *smem_ptrs[NUM_MAX_NVL_PEERS]; + #pragma unroll + for (size_t i = 0; i < NUM_MAX_NVL_PEERS; ++ i) { + smem_ptrs[i] = tma_smem_aligned + kNumHiddenTMABytesPerWarp * i; + } + EP_DEVICE_ASSERT( + reinterpret_cast(smem_ptrs[NUM_MAX_NVL_PEERS - 1]) + kNumHiddenTMABytesPerWarp <= + reinterpret_cast(tma_smem) + kNvlFwdTMASMemLen); + + // Wait counters to arrive + int num_tokens_to_recv_from_rdma = 0, num_tokens_to_recv_from_rdma_saved = 0, src_rdma_channel_prefix = 0; + bool finish_signaled = false; + auto rdma_signal_finish = + [&rdma_channel_recv_finish_signal, &finish_signaled, rdma_rank, nvl_rank, channel_id](int src_rdma_rank, int qp_id) { + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_recv_finish_signal.buffer(rdma_rank), 1, + translate_dst_rdma_rank(src_rdma_rank, nvl_rank), qp_id, src_rdma_rank == rdma_rank); + finish_signaled = true; + }; + EP_DEVICE_ASSERT(kNumRDMARanks <= 32); + auto start_time = clock64(); + int start_sum = -1, end_sum = -1; + if (lane_id < kNumRDMARanks) { + while (true) { 
+ auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); + auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); + auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2); + auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1); + if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) { + // Notify NVL ranks + start_sum = -meta_0 - 1, end_sum = -meta_1 - 1; + EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum); + st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + nvl_rank * kNumRDMARanks + lane_id, -start_sum - 1); + st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + nvl_rank * kNumRDMARanks + lane_id, -end_sum - 1); + + // Save RDMA channel received token count + src_rdma_channel_prefix = -meta_2 - 1; + auto src_rdma_channel_prefix_1 = -meta_3 - 1; + num_tokens_to_recv_from_rdma = num_tokens_to_recv_from_rdma_saved = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; + if (not kCachedMode) + recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1; + src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; + EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0); + // TODO: Zero-copy: We can check if the recv count is 0 on the sender side instead + if (num_tokens_to_recv_from_rdma == 0 and dst_nvl_rank == lane_id % NUM_MAX_NVL_PEERS) { + // Need to immediately send finish signal here, as we won't receive any + // tokens and thus will not trigger the logic upon token reception. + // As of the second condition, see comments on the other atomic add + rdma_signal_finish(lane_id, channel_id); + } + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP zcopy dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3); + trap(); + } + } + } + __syncwarp(); + + // Shift cached head + send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank; + + // Wait shared memory to be cleaned + sync_forwarder_smem(); + + // Forward tokens from RDMA buffer + // NOTES: always start from the local rank + int src_rdma_rank = sm_id % kNumRDMARanks; + int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0; + const uint64_t output_buffer_size = NUM_DISPATCH_OUTPUT_BYTES_PER_ZCOPY_BUFFER / (hidden_bytes + num_topk * sizeof(int64_t) + num_topk * sizeof(float) + scale_bytes); + while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) { + // Find next source RDMA rank (round-robin) + start_time = clock64(); + while (true) { + src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks; + if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { + if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail) + cached_rdma_channel_tail = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank))); + if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf("DeepEP zcopy dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: 
%d, expected: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma); + trap(); + } + } + auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank); + auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank); + + int total_offset = 0; + if (src_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank > 0) { + total_offset = recv_gbl_rank_prefix_sum_fwd[dst_nvl_rank * num_ranks + src_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank - 1]; + } + + int num_recv_tokens = recv_gbl_rank_prefix_sum_fwd[dst_nvl_rank * num_ranks + num_ranks - 1]; + auto target_nvl_rank_x_bytes = num_recv_tokens * hidden_bytes; + auto target_nvl_rank_topk_idx_bytes = num_recv_tokens * num_topk * sizeof(topk_idx_t); + auto target_nvl_rank_topk_weights_bytes = num_recv_tokens * num_topk * sizeof(float); + auto target_nvl_rank_x_scales_bytes = num_recv_tokens * scale_bytes; + + if (num_recv_tokens > output_buffer_size) { + if (lane_id == 0) { + printf("DeepEP dispatch output buffer overflow: RDMA %d, dst_nvl_rank %d, %d > %lu\n", + rdma_rank, dst_nvl_rank, num_recv_tokens, output_buffer_size); + } + trap(); + } + + auto target_nvl_rank_x = ws_rr_fused_buffer_ptr; + auto target_nvl_rank_topk_idx = reinterpret_cast(reinterpret_cast(ws_rr_fused_buffer_ptr) + target_nvl_rank_x_bytes); + auto target_nvl_rank_topk_weights = reinterpret_cast(reinterpret_cast(ws_rr_fused_buffer_ptr) + target_nvl_rank_x_bytes + target_nvl_rank_topk_idx_bytes); + auto target_nvl_rank_x_scales = reinterpret_cast(reinterpret_cast(ws_rr_fused_buffer_ptr) + target_nvl_rank_x_bytes + target_nvl_rank_topk_idx_bytes + target_nvl_rank_topk_weights_bytes); + auto target_nvl_rank_src_meta = reinterpret_cast(reinterpret_cast(ws_rr_fused_buffer_ptr) + target_nvl_rank_x_bytes + target_nvl_rank_topk_idx_bytes + target_nvl_rank_topk_weights_bytes + target_nvl_rank_x_scales_bytes); + + int start_offset = __shfl_sync(0xffffffff, start_sum, src_rdma_rank); + int end_offset = __shfl_sync(0xffffffff, end_sum, src_rdma_rank); + total_offset += start_offset; + + if (lane_id == src_rdma_rank and not finish_signaled) { + // There will be NUM_MAX_NVL_PEERS threads in this that learn + // about the first condition, each belonging to a different forwarder warp. The + // second condition guarantees only 1 warp (and thus 1 thread) will issue the + // atomic add. + if (src_rdma_tail == num_tokens_to_recv_from_rdma_saved and dst_nvl_rank == src_rdma_rank % NUM_MAX_NVL_PEERS) { + rdma_signal_finish(lane_id, channel_id); + } + } + + // Iterate over every token from the RDMA buffer + for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) { + // Wait for previous TMA transfers to finish, if any + tma_store_wait<0>(); + __syncwarp(); + + const bool is_local_token = src_rdma_rank == rdma_rank; + void *shifted; + if (is_local_token) { + shifted = rdma_channel_src_meta.buffer(src_rdma_rank) + i; + } else { + auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; + shifted = rdma_channel_data.buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; + } + auto src_meta = ld_nc_global(reinterpret_cast(reinterpret_cast(shifted) + + (is_local_token ? 0 : hidden_bytes))); + __syncwarp(); + bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); + if (lane_id == src_rdma_rank) { + --num_tokens_to_recv_from_rdma; + auto cached_head = is_in_dst_nvl_rank ? 
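The forwarder writes straight into the destination peer's fused buffer, which is carved into five back-to-back regions sized by the peer's total receive count: hidden data, top-k indices, top-k weights, scales, and source metadata. A host sketch of those offsets and a capacity check in the same spirit as the overflow trap above (assuming 8-byte top-k indices and a 16-byte `SourceMeta`; names illustrative):

```cpp
// Illustrative sketch only: byte offsets of the regions inside one zero-copy dispatch
// output (fused) buffer, mirroring the pointer arithmetic above.
#include <cassert>
#include <cstdio>
#include <cstddef>

struct FusedLayout { size_t x, topk_idx, topk_weights, x_scales, src_meta, end; };

FusedLayout fused_offsets(size_t num_recv_tokens, size_t hidden_bytes,
                          size_t num_topk, size_t scale_bytes) {
    FusedLayout l;
    l.x            = 0;
    l.topk_idx     = l.x            + num_recv_tokens * hidden_bytes;
    l.topk_weights = l.topk_idx     + num_recv_tokens * num_topk * 8;
    l.x_scales     = l.topk_weights + num_recv_tokens * num_topk * sizeof(float);
    l.src_meta     = l.x_scales     + num_recv_tokens * scale_bytes;
    l.end          = l.src_meta     + num_recv_tokens * 16;
    return l;
}

int main() {
    const size_t budget = 4096ul * 16384 * 8;   // NUM_DISPATCH_OUTPUT_BYTES_PER_ZCOPY_BUFFER
    auto l = fused_offsets(/*tokens*/ 32768, /*hidden*/ 7168, /*topk*/ 8, /*scales*/ 224);
    assert(l.end <= budget);                    // same spirit as the kernel's overflow trap()
    std::printf("topk_idx @ %zu, scales @ %zu, meta @ %zu, total %zu of %zu bytes\n",
                l.topk_idx, l.x_scales, l.src_meta, l.end, budget);
    return 0;
}
```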
total_offset : -1; + if (not kCachedMode) { + send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head; + } + if (src_meta.src_rdma_rank != src_rdma_rank or src_meta.is_token_in_nvl_rank_bits == 0 or src_meta.is_token_in_nvl_rank_bits >= (1<
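For completeness, the completion handshake used throughout this kernel (and in the truncated remainder of the forwarder loop) stands in for CQ polling, per the TODO above: for each source RDMA rank exactly one forwarder warp, the one whose NVL target equals `src_rdma_rank % NUM_MAX_NVL_PEERS`, atomically adds 1 to the sender's per-channel finish word, either immediately when no tokens are expected or once the full tail has been observed, and the sender's first warp spins until that word reads exactly 1, trapping on anything larger. A host-side sketch of the handshake with threads standing in for the two GPUs (illustrative only):

```cpp
// Illustrative sketch only: the dispatch completion signal used in place of CQ
// polling. The receiving side adds 1 to the sender's finish word; the sender spins
// until it reads exactly 1 and treats any larger value as corruption.
#include <atomic>
#include <cassert>
#include <cstdio>
#include <thread>

int main() {
    std::atomic<int> finish_signal{0};   // lives on the sender, one per (channel, dst RDMA rank)

    std::thread receiver([&] {
        // ... receive every expected token for this channel, then signal exactly once:
        finish_signal.fetch_add(1, std::memory_order_release);
    });

    // Sender side: wait for the acknowledgement before the dispatch may complete.
    int signal;
    while ((signal = finish_signal.load(std::memory_order_acquire)) != 1) {
        assert(signal == 0 && "signal corruption");   // the kernel traps on signal > 1
    }
    receiver.join();
    std::printf("dispatch acknowledged, signal=%d\n", signal);
    return 0;
}
```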