diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 3c32ed24..a29ed598 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -35,10 +35,13 @@ Buffer::Buffer(int rank, int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); // Common checks + EP_STATIC_ASSERT(NUM_BUFFER_ALIGNMENT_BYTES % sizeof(int4) == 0, "Invalid alignment"); EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0)); EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= std::numeric_limits<int>::max())); + EP_HOST_ASSERT(num_nvl_bytes / sizeof(int4) < std::numeric_limits<int>::max()); + EP_HOST_ASSERT(num_rdma_bytes / sizeof(int4) < std::numeric_limits<int>::max()); EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); if (num_rdma_bytes > 0) @@ -57,6 +60,10 @@ Buffer::Buffer(int rank, CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); num_device_sms = device_prop.multiProcessorCount; + // Number of per-channel bytes cannot be large + EP_HOST_ASSERT(ceil_div(num_nvl_bytes, num_device_sms / 2) < std::numeric_limits<int>::max()); + EP_HOST_ASSERT(ceil_div(num_rdma_bytes, num_device_sms / 2) < std::numeric_limits<int>::max()); + if (num_nvl_bytes > 0) { // Local IPC: alloc local memory and set local IPC handles CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes)); diff --git a/csrc/kernels/buffer.cuh b/csrc/kernels/buffer.cuh index fc9af55b..222f42ac 100644 --- a/csrc/kernels/buffer.cuh +++ b/csrc/kernels/buffer.cuh @@ -11,18 +11,18 @@ private: uint8_t* ptr; public: - int total_bytes; + int64_t total_bytes; __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {} __device__ __forceinline__ Buffer(void*& gbl_ptr, int
num_elems, int offset = 0) { total_bytes = num_elems * sizeof(dtype_t); - ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t); - gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; + ptr = static_cast<uint8_t*>(gbl_ptr) + offset * sizeof(dtype_t); + gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes; } __device__ __forceinline__ Buffer advance_also(void*& gbl_ptr) { - gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; + gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes; return *this; } @@ -35,30 +35,30 @@ template <typename dtype_t, int kNumRanks> struct AsymBuffer { private: uint8_t* ptrs[kNumRanks]; - int num_bytes; + int64_t num_bytes; public: - int total_bytes; + int64_t total_bytes; __device__ __forceinline__ AsymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { EP_STATIC_ASSERT(kNumRanks == 1, ""); num_bytes = num_elems * sizeof(dtype_t); - int per_channel_bytes = num_bytes * num_ranks; + int64_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; - ptrs[0] = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; - gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; + ptrs[0] = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes; } __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { EP_STATIC_ASSERT(kNumRanks > 1, ""); num_bytes = num_elems * sizeof(dtype_t); - int per_channel_bytes = num_bytes * num_ranks; + int64_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms; for (int i = 0; i < kNumRanks; ++i) { - ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; - gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes; + ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes; } } @@ -69,14 +69,14 @@
public: } __device__ __forceinline__ AsymBuffer advance_also(void*& gbl_ptr) { - gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; + gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes; return *this; } template <int kNumAlsoRanks> __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { for (int i = 0; i < kNumAlsoRanks; ++i) - gbl_ptrs[i] = reinterpret_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes; + gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes; return *this; } @@ -97,19 +97,19 @@ private: // NOTES: for non-decoupled case, `recv_ptr` is not used uint8_t* send_ptr; uint8_t* recv_ptr; - int num_bytes; + int64_t num_bytes; public: - int total_bytes; + int64_t total_bytes; __device__ __forceinline__ SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) { num_bytes = num_elems * sizeof(dtype_t); - int per_channel_bytes = num_bytes * num_ranks; + int64_t per_channel_bytes = num_bytes * num_ranks; total_bytes = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1); - send_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id; - recv_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); - gbl_ptr = reinterpret_cast<uint8_t*>(gbl_ptr) + total_bytes; + send_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id; + recv_ptr = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); + gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes; } __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) {