Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 132 additions & 8 deletions csrc/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ dtype_t align_down(dtype_t a, dtype_t b) {
return a / b * b;
}

// Returns the current value of `ptr` (cast to `out_ptr_t`) and then bumps
// `ptr` forward by `count` elements of `count_ptr_t`'s pointee — plain bytes
// with the default `uint8_t*`. Handy for carving typed sub-buffers out of a
// single flat allocation.
template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*>
out_ptr_t advance_ptr(in_ptr_t &ptr, size_t count) {
    const auto current = reinterpret_cast<out_ptr_t>(ptr);
    auto cursor = reinterpret_cast<count_ptr_t>(ptr);
    cursor += count;
    ptr = reinterpret_cast<out_ptr_t>(cursor);
    return current;
}

struct Config {
int num_sms;
int num_max_nvl_chunked_send_tokens;
Expand Down Expand Up @@ -47,22 +54,31 @@ struct Config {
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
}

size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks, bool zero_copy = false) const {
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or zero_copy or num_sms % 2 == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2;
const int sms_per_channel = zero_copy ? 1 : 2;
const int num_channels = num_sms / sms_per_channel;

size_t num_bytes = 0;
if (zero_copy) {
num_bytes += ZCOPY_NOTIFY_NVL_METADATA_OFFSET_INTS * sizeof(int);
num_bytes += std::max(
num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 2) * sizeof(int), // Dispatch
num_channels * num_nvl_ranks * 1 * sizeof(int) // Combine
);
return num_bytes;
}
num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes;
#ifndef DISABLE_NVSHMEM
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes();
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(zero_copy);
#endif
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t);
num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float);
Expand All @@ -71,7 +87,7 @@ struct Config {
return num_bytes;
}

size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks, bool zero_copy = false) const {
#ifndef DISABLE_NVSHMEM
// Legacy mode
if (num_ranks <= NUM_MAX_NVL_PEERS)
Expand All @@ -82,14 +98,16 @@ struct Config {
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_sms % 2 == 0);
EP_HOST_ASSERT(zero_copy or num_sms % 2 == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_channels = num_sms / 2;
const int sms_per_channel = zero_copy ? 1 : 2;
const int num_channels = num_sms / sms_per_channel;

const int rdma_buffer_factor = zero_copy ? 1 : 2; // The zero-copy kernels only have recv buffer
size_t num_bytes = 0;
num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int);
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes(zero_copy) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
Expand Down Expand Up @@ -192,4 +210,110 @@ size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
}

// Typed pointers into one zero-copy dispatch staging slot. A field is nullptr
// when the corresponding payload is not laid out (see InternodeDispatchLayout).
// NOTE: member order is load-bearing — InternodeDispatchLayout initializes
// this struct positionally; do not reorder.
struct InternodeDispatchBuffer {
void* x = nullptr; // token payload (FP8 or BF16, depending on the layout's use_fp8)
float* x_scales = nullptr; // FP8 scaling factors; only set when use_fp8
topk_idx_t* topk_idx = nullptr; // top-k expert indices; only set when with_topk
float* topk_weights = nullptr; // top-k expert weights; only set when with_topk
void* source_meta = nullptr; // per-token source metadata (opaque internode format)
};

// Carves `num_buffers` zero-copy dispatch staging slots out of
// `rdma_fused_buffer`. Each slot occupies a fixed
// NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER bytes; the packed payload
// (x, optional scales, optional top-k, source meta) must fit inside it.
struct InternodeDispatchLayout {
size_t total_bytes = 0; // total reserved bytes across all slots
std::vector<InternodeDispatchBuffer> buffers; // one entry per staging slot

// rdma_fused_buffer: base of the fused RDMA staging region
// num_tokens / hidden / num_topk: per-slot tensor extents
// num_ranks: total EP ranks; source meta is replicated per RDMA rank group
// with_topk: reserve space for top-k indices and weights
// use_fp8: FP8 payload + float scales instead of a BF16 payload
// num_buffers: number of consecutive staging slots to lay out
InternodeDispatchLayout(void* rdma_fused_buffer, int num_tokens, int hidden, int num_topk, int num_ranks, bool with_topk, bool use_fp8, int num_buffers) {
    // Do the size arithmetic in size_t: `num_tokens * hidden` can overflow a
    // 32-bit int for large token counts and hidden sizes.
    const auto tokens = static_cast<size_t>(num_tokens);
    size_t num_bytes_x = 0;
    size_t num_bytes_x_scales = 0;
    size_t num_bytes_topk_idx = 0;
    size_t num_bytes_topk_weights = 0;
    size_t num_bytes_source_meta = 0;

    if (use_fp8) {
        num_bytes_x = tokens * hidden * elementSize(torch::kFloat8_e4m3fn);
        total_bytes += num_bytes_x;

        // FP8 scaling granularity is one float per 128 hidden elements;
        // a non-multiple would silently undersize the scales region.
        EP_HOST_ASSERT(hidden % 128 == 0);
        const int num_scales = hidden / 128;
        num_bytes_x_scales = tokens * num_scales * sizeof(float);
        total_bytes += num_bytes_x_scales;
    } else {
        num_bytes_x = tokens * hidden * elementSize(torch::kBFloat16);
        total_bytes += num_bytes_x;
    }

    if (with_topk) {
        num_bytes_topk_idx = tokens * num_topk * sizeof(topk_idx_t);
        total_bytes += num_bytes_topk_idx;

        num_bytes_topk_weights = tokens * num_topk * sizeof(float);
        total_bytes += num_bytes_topk_weights;
    }

    num_bytes_source_meta = (num_ranks / NUM_MAX_NVL_PEERS) * tokens * internode::get_source_meta_bytes();
    total_bytes += num_bytes_source_meta;
    EP_HOST_ASSERT(total_bytes % 16 == 0);

    // The packed payload must fit inside one fixed-size slot
    EP_HOST_ASSERT(total_bytes <= NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER);

    for (int i = 0; i < num_buffers; ++i) {
        auto tmp = rdma_fused_buffer;
        buffers.push_back({
            advance_ptr(tmp, num_bytes_x), // x
            use_fp8 ? advance_ptr<float*>(tmp, num_bytes_x_scales) : nullptr, // x_scales
            with_topk ? advance_ptr<topk_idx_t*>(tmp, num_bytes_topk_idx) : nullptr, // topk_idx
            with_topk ? advance_ptr<float*>(tmp, num_bytes_topk_weights) : nullptr, // topk_weights
            advance_ptr(tmp, num_bytes_source_meta), // source_meta
        });
        // Every slot occupies the full fixed stride regardless of payload size
        advance_ptr(rdma_fused_buffer, NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER);
    }

    // Report the fixed per-slot footprint, not the (smaller) packed payload size
    total_bytes = NUM_DISPATCH_INPUT_BYTES_PER_ZCOPY_BUFFER * num_buffers;
}
};


// Typed pointers into one zero-copy combine staging slot (see
// InternodeCombineLayout). NOTE: member order is load-bearing — the layout
// initializes this struct positionally; do not reorder.
struct InternodeCombineBuffer {
void* x = nullptr; // BF16 token payload (sized with torch::kBFloat16 in the layout)
float* topk_weights = nullptr; // top-k weights; nullptr when with_topk is false
};

// Carves `num_buffers` zero-copy combine staging slots out of `nvl_buffer`.
// Each slot occupies a fixed NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER bytes;
// the packed payload (BF16 x, optional top-k weights) must fit inside it.
struct InternodeCombineLayout {
size_t total_bytes = 0; // total reserved bytes across all slots
std::vector<InternodeCombineBuffer> buffers; // one entry per staging slot

// Non-mutating variant of advance_ptr: returns `ptr + count` units of
// count_ptr_t's pointee (bytes by default) without modifying `ptr`.
template <typename out_ptr_t = void*, typename count_ptr_t = uint8_t*, typename in_ptr_t = void*>
out_ptr_t advance(const in_ptr_t& ptr, size_t count) {
    return reinterpret_cast<out_ptr_t>(reinterpret_cast<count_ptr_t>(ptr) + count);
}

// nvl_buffer: base of the NVL staging region
// num_tokens / hidden / num_topk: per-slot tensor extents
// with_topk: reserve space for top-k weights after the payload
// num_buffers: number of consecutive staging slots to lay out
InternodeCombineLayout(void* nvl_buffer, int num_tokens, int hidden, int num_topk, bool with_topk, int num_buffers) {
    // Do the size arithmetic in size_t: `num_tokens * hidden` can overflow a
    // 32-bit int for large token counts and hidden sizes.
    const auto tokens = static_cast<size_t>(num_tokens);
    size_t num_bytes_x = 0;
    size_t num_bytes_topk_weights = 0;

    num_bytes_x = tokens * hidden * elementSize(torch::kBFloat16);
    total_bytes += num_bytes_x;

    if (with_topk) {
        num_bytes_topk_weights = tokens * num_topk * sizeof(float);
        total_bytes += num_bytes_topk_weights;
    }

    // The packed payload must fit inside one fixed-size slot
    EP_HOST_ASSERT(total_bytes <= NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER);

    EP_HOST_ASSERT(total_bytes % 16 == 0);

    for (int i = 0; i < num_buffers; ++i) {
        auto tmp = nvl_buffer;
        buffers.push_back({
            tmp, // x
            with_topk ? advance<float*>(tmp, num_bytes_x) : nullptr // topk_weights
        });
        // Every slot occupies the full fixed stride regardless of payload size
        advance_ptr(nvl_buffer, NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER);
    }

    // Report the fixed per-slot footprint, not the (smaller) packed payload size
    total_bytes = NUM_COMBINE_INPUT_BYTES_PER_ZCOPY_BUFFER * num_buffers;
}
};

} // namespace deep_ep
Loading