Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions csrc/deep_ep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@ Buffer::Buffer(int rank,
int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);

// Common checks
EP_STATIC_ASSERT(NUM_BUFFER_ALIGNMENT_BYTES % sizeof(int4) == 0, "Invalid alignment");
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
(num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));
EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
(low_latency_mode or num_rdma_bytes <= std::numeric_limits<int>::max()));
EP_HOST_ASSERT(num_nvl_bytes / sizeof(int4) < std::numeric_limits<int>::max());
EP_HOST_ASSERT(num_rdma_bytes / sizeof(int4) < std::numeric_limits<int>::max());
EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode));
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
if (num_rdma_bytes > 0)
Expand All @@ -57,6 +60,10 @@ Buffer::Buffer(int rank,
CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
num_device_sms = device_prop.multiProcessorCount;

// Number of per-channel bytes cannot be large
EP_HOST_ASSERT(ceil_div<int64_t>(num_nvl_bytes, num_device_sms / 2) < std::numeric_limits<int>::max());
EP_HOST_ASSERT(ceil_div<int64_t>(num_rdma_bytes, num_device_sms / 2) < std::numeric_limits<int>::max());

if (num_nvl_bytes > 0) {
// Local IPC: alloc local memory and set local IPC handles
CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
Expand Down
40 changes: 20 additions & 20 deletions csrc/kernels/buffer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ private:
uint8_t* ptr;

public:
int total_bytes;
int64_t total_bytes;

__device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}

__device__ __forceinline__ Buffer(void*& gbl_ptr, int num_elems, int offset = 0) {
total_bytes = num_elems * sizeof(dtype_t);
ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
ptr = static_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t);
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
}

__device__ __forceinline__ Buffer advance_also(void*& gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
return *this;
}

Expand All @@ -35,30 +35,30 @@ template <typename dtype_t, int kNumRanks = 1>
struct AsymBuffer {
private:
uint8_t* ptrs[kNumRanks];
int num_bytes;
int64_t num_bytes;

public:
int total_bytes;
int64_t total_bytes;

__device__ __forceinline__ AsymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks == 1, "");
num_bytes = num_elems * sizeof(dtype_t);

int per_channel_bytes = num_bytes * num_ranks;
int64_t per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
ptrs[0] = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
}

__device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) {
EP_STATIC_ASSERT(kNumRanks > 1, "");
num_bytes = num_elems * sizeof(dtype_t);

int per_channel_bytes = num_bytes * num_ranks;
int64_t per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms;
for (int i = 0; i < kNumRanks; ++i) {
ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
}
}

Expand All @@ -69,14 +69,14 @@ public:
}

__device__ __forceinline__ AsymBuffer advance_also(void*& gbl_ptr) {
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
return *this;
}

template <int kNumAlsoRanks>
__device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) {
for (int i = 0; i < kNumAlsoRanks; ++i)
gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
return *this;
}

Expand All @@ -97,19 +97,19 @@ private:
// NOTES: for non-decoupled case, `recv_ptr` is not used
uint8_t* send_ptr;
uint8_t* recv_ptr;
int num_bytes;
int64_t num_bytes;

public:
int total_bytes;
int64_t total_bytes;

__device__ __forceinline__ SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) {
num_bytes = num_elems * sizeof(dtype_t);

int per_channel_bytes = num_bytes * num_ranks;
int64_t per_channel_bytes = num_bytes * num_ranks;
total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes;
send_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id;
recv_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
}

__device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {
Expand Down