Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions csrc/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct Config {
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
}

size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks, bool return_recv_hook) const {
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
Expand All @@ -51,7 +51,7 @@ struct Config {
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2;
const int num_channels = return_recv_hook ? num_sms : num_sms / 2; // one SM per channel for hook mode

size_t num_bytes = 0;
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
Expand All @@ -66,7 +66,7 @@ struct Config {
return num_bytes;
}

size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
size_t get_rdma_buffer_size_hint(int num_max_dispatch_tokens_per_rank, int64_t hidden_bytes, int num_ranks, bool decoupled_mode, bool return_recv_hook) const {
#ifndef DISABLE_NVSHMEM
// Legacy mode
if (num_ranks <= NUM_MAX_NVL_PEERS)
Expand All @@ -79,16 +79,17 @@ struct Config {
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % 2 == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms / 2;
const int num_channels = return_recv_hook ? num_sms : num_sms / 2; // one SM per channel for hook mode
int num_slots_per_rdma_chunk = decoupled_mode ? (num_max_dispatch_tokens_per_rank + num_channels - 1) / num_channels : num_max_rdma_chunked_recv_tokens;

size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;
num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * kNumMaxTopK * sizeof(int64_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * kNumMaxScales * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_slots_per_rdma_chunk * sizeof(int4) * 2;
num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
#else
Expand Down Expand Up @@ -187,4 +188,29 @@ size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
}

/// Returns a size hint, in bytes, for the RDMA buffer of the normal kernels when the
/// per-channel slot count is derived from the maximum number of dispatch tokens per rank.
///
/// @param num_max_dispatch_tokens_per_rank  Upper bound on tokens dispatched per rank; it is
///                                          split evenly (rounded up) across the channels.
/// @param hidden            Number of hidden elements per token; each is sized as `nv_bfloat16`.
/// @param num_nodes         Number of nodes; values `<= 1` yield a hint of 0.
/// @param num_sms           Number of SMs; must be even and must yield at least one channel.
/// @param return_recv_hook  When true, every SM is one channel; otherwise two SMs share a channel.
/// @return The total byte count, rounded up to a multiple of 128.
uint64_t get_normal_hook_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_nodes, int num_sms, bool return_recv_hook) {
    // No buffer is requested when there is at most one node
    if (num_nodes <= 1)
        return 0;

    // Below are some assumptions
    // TODO: add assertions
    constexpr int kNumMaxTopK = 128;
    constexpr int kNumMaxScales = 128;
    EP_HOST_ASSERT(num_sms % 2 == 0);
    const int num_channels = return_recv_hook ? num_sms : num_sms / 2;  // one SM per channel for hook mode
    // Guard the ceil-division below: `num_sms == 0` passes the evenness check but gives zero channels
    EP_HOST_ASSERT(num_channels > 0);

    // Widen before multiplying so no intermediate product is evaluated (or truncated) in `int`
    const uint64_t hidden_bytes = static_cast<uint64_t>(hidden) * sizeof(nv_bfloat16);
    const uint64_t num_slots_per_rdma_chunk = (num_max_dispatch_tokens_per_rank + num_channels - 1) / num_channels;
    const uint64_t num_channel_nodes = static_cast<uint64_t>(num_channels) * static_cast<uint64_t>(num_nodes);

    // Bytes held by a single slot: hidden data, source meta, top-k indices (int64) and
    // weights (float), scales (float), plus one trailing `int4`
    const uint64_t num_bytes_per_slot =
        hidden_bytes +
        static_cast<uint64_t>(internode::get_source_meta_bytes()) +
        kNumMaxTopK * sizeof(int64_t) +
        kNumMaxTopK * sizeof(float) +
        kNumMaxScales * sizeof(float) +
        sizeof(int4);

    uint64_t num_bytes = 0;
    // Per (channel, node) integer metadata; the trailing `* 2` mirrors the doubling applied to
    // the slot storage below (presumably two buffer halves -- confirm against the kernels)
    num_bytes += num_channel_nodes * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
    // Slot storage, doubled in the same way
    num_bytes += num_channel_nodes * num_slots_per_rdma_chunk * num_bytes_per_slot * 2;

    // Round up to the next multiple of 128 bytes
    num_bytes = ((num_bytes + 127) / 128) * 128;
    return num_bytes;
}

} // namespace deep_ep
Loading