38 changes: 36 additions & 2 deletions csrc/deep_ep.cpp
@@ -1094,8 +1094,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
const std::optional<torch::Tensor>& x_global_scale,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool use_nvfp4, bool use_ue8m0_for_sf,
bool async, bool return_recv_hook) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);
@@ -1140,8 +1142,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait(launch_stream, compute_stream);

// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16));
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / 2 : hidden},
x.options().dtype(use_nvfp4 ? torch::kUInt8 : (use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
@@ -1151,6 +1153,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
void* packed_recv_x_scales_ptr = nullptr;
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");

EP_HOST_ASSERT(not (use_fp8 and use_nvfp4));
if (use_fp8) {
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % 512 == 0);
@@ -1164,6 +1167,35 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
} else if (use_nvfp4) {
constexpr int kNumPerChannels = 16;
constexpr int NUM_SF_ELEMS_PER_PACK = 4;
constexpr int mTileSize_dim_0 = 32;
constexpr int mTileSize_dim_1 = 4;
constexpr int mTileSize = mTileSize_dim_0 * mTileSize_dim_1;

assert(hidden % kNumPerChannels == 0);
auto l = num_local_experts;
auto m = num_ranks * num_max_dispatch_tokens_per_rank;
auto rm = (m + 127) / 128;
auto rk = (hidden + (kNumPerChannels * NUM_SF_ELEMS_PER_PACK) - 1) / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
// The physical layout is (l, rm, rk, 32, 4, 4).
if (use_ue8m0_for_sf) {
packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
torch::dtype(torch::kInt).device(torch::kCUDA));
} else {
packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
}
// After permute, the logical shape is (32, 4, rm, 4, rk, l)
packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});

// The physical layout is (l, m, k // 2).
// After permute, the logical shape is (m, k // 2, l).
packed_recv_x = packed_recv_x.permute({1, 2, 0});

packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
EP_HOST_ASSERT(packed_recv_x_scales_ptr != nullptr);
}

// Kernel launch
@@ -1174,13 +1206,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
packed_recv_count.data_ptr<int>(),
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
x_global_scale.has_value() ? x_global_scale->data_ptr<float>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0,
use_nvfp4, use_ue8m0_for_sf,
workspace, num_device_sms,
launch_stream, phases);
};
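For context on the NVFP4 branch above, here is a small, self-contained sketch of the shape arithmetic used for the packed FP4 payload and its scale-factor tensor. The hidden size, rank count, and token budget are hypothetical values chosen only to make the numbers concrete; they are not part of this PR.

```cpp
#include <cstdio>

int main() {
    // Hypothetical sizes, for illustration only.
    constexpr int num_local_experts = 8;                  // l
    constexpr int num_ranks = 16;
    constexpr int num_max_dispatch_tokens_per_rank = 128;
    constexpr int hidden = 7168;

    // Constants mirrored from the NVFP4 branch above.
    constexpr int kNumPerChannels = 16;        // FP4 values covered by one scale factor
    constexpr int NUM_SF_ELEMS_PER_PACK = 4;   // scale factors per packed group

    const int l = num_local_experts;
    const int m = num_ranks * num_max_dispatch_tokens_per_rank;           // 2048 tokens
    const int rm = (m + 127) / 128;                                       // 16 row tiles
    const int rk = (hidden + kNumPerChannels * NUM_SF_ELEMS_PER_PACK - 1)
                 / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);             // 112

    // Two FP4 values share one uint8 byte, hence the hidden / 2 last dimension.
    std::printf("packed_recv_x:        (%d, %d, %d) uint8\n", l, m, hidden / 2);
    // Scale factors keep the physical layout (l, rm, rk, 32, 4, 4).
    std::printf("packed_recv_x_scales: (%d, %d, %d, 32, 4, 4)\n", l, rm, rk);
    return 0;
}
```

With these numbers there is one scale factor per 16-element group of the hidden dimension (448 groups for hidden = 7168), which matches rk * 4 = 112 * 4 trailing scale entries per token in the e4m3 case.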
2 changes: 2 additions & 0 deletions csrc/deep_ep.hpp
@@ -167,8 +167,10 @@ struct Buffer {
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
const std::optional<torch::Tensor>& x_global_scale,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool use_nvfp4, bool use_ue8m0_for_sf,
bool async, bool return_recv_hook);

std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
2 changes: 2 additions & 0 deletions csrc/kernels/api.cuh
@@ -145,12 +145,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
const float* x_global_scale,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool use_nvfp4, bool use_ue8m0_for_sf,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases);

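Tying the host-side pieces together, a standalone libtorch sketch (again with hypothetical sizes, allocated on CPU so it runs without a GPU) that re-traces the NVFP4 allocation and permutes from the csrc/deep_ep.cpp hunk, so the logical views handed to the kernel are easy to inspect. This is an illustration, not the PR's code path.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
    // Illustrative values only, matching the sketch above.
    const int64_t l = 8, m = 2048, hidden = 7168;
    const int64_t rm = (m + 127) / 128;      // 16
    const int64_t rk = (hidden + 63) / 64;   // 112

    // Packed FP4 payload: physical layout (l, m, hidden / 2) in uint8.
    auto packed_recv_x = torch::empty({l, m, hidden / 2}, torch::dtype(torch::kUInt8));
    // Scale factors: physical layout (l, rm, rk, 32, 4, 4); the FP8 e4m3 variant is shown,
    // the UE8M0 path allocates the same shape with a 32-bit int dtype.
    auto scales = torch::empty({l, rm, rk, 32, 4, 4}, torch::dtype(torch::kFloat8_e4m3fn));

    // Same permutes as in Buffer::low_latency_dispatch above.
    auto scales_view = scales.permute({3, 4, 1, 5, 2, 0});   // logical (32, 4, rm, 4, rk, l)
    auto x_view = packed_recv_x.permute({1, 2, 0});          // logical (m, hidden / 2, l)

    std::cout << "scales logical shape: " << scales_view.sizes() << std::endl;
    std::cout << "x logical shape:      " << x_view.sizes() << std::endl;
    return 0;
}
```

The permutes only change the logical indexing order; the storage stays in the physical (l, rm, rk, 32, 4, 4) and (l, m, hidden / 2) layouts allocated above, which is what the dispatch kernel writes into.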