diff --git a/Hybrid-EP_Intra-node_Implementation.md b/Hybrid-EP_Intra-node_Implementation.md index 65b3199b..2f7734bf 100644 --- a/Hybrid-EP_Intra-node_Implementation.md +++ b/Hybrid-EP_Intra-node_Implementation.md @@ -152,7 +152,7 @@ Refer to `tests/test_hybrid_ep.py` for comprehensive usage examples including: - Performance benchmarking setups ### Important Configuration Note -Here are important parameter settings in `csrc/hybrid_ep/config.cuh`. You can modify these parameters via `HybridEpBuffer.init_config()` or by setting proper environment variables (see `deep_ep/hybrid_ep_buffer.py`) to achieve better performance/usability: +Here are important parameter settings in `csrc/hybrid_ep/config.cuh`. You can modify these parameters via `HybridEPBuffer.init_config()` or by setting proper environment variables (see `deep_ep/hybrid_ep_buffer.py`) to achieve better performance/usability: - HIDDEN_DIM Hidden size (must match model hidden dimension). diff --git a/csrc/hybrid_ep/config.cuh b/csrc/hybrid_ep/config.cuh index 101d6f2a..eb8bf5c8 100644 --- a/csrc/hybrid_ep/config.cuh +++ b/csrc/hybrid_ep/config.cuh @@ -9,6 +9,31 @@ // This will be used to initialize the template param_t for communication kernel. #define MAX_NUM_OF_RANKS_PER_NODE 72 +// Config used for buffer allocation. +struct BufferConfig { + int hidden_dim; + int max_num_of_tokens_per_rank; + int num_of_experts_per_rank; + int num_of_ranks_per_node; + int num_of_nodes; + TOKEN_DATA_TYPE token_data_type; + int num_of_blocks_preprocessing_api; + int num_of_tokens_per_chunk_dispatch_api; + int num_of_tokens_per_chunk_combine_api; + + /* + * Validation check + */ + bool is_valid(){ + bool valid = true; + valid &= (hidden_dim % 512 == 0); + valid &= ((num_of_experts_per_rank * num_of_ranks_per_node) % 4 == 0); + valid &= (num_of_ranks_per_node % 2 == 0); + return valid; + } +}; + +// Config used for hybrid-ep kernel. struct HybridEpConfigInstance { /* * Hybrid-ep Config diff --git a/csrc/hybrid_ep/executor/executor.cu b/csrc/hybrid_ep/executor/executor.cu index a1c23992..80177751 100644 --- a/csrc/hybrid_ep/executor/executor.cu +++ b/csrc/hybrid_ep/executor/executor.cu @@ -7,8 +7,8 @@ Executor::Executor(int local_rank, int node_rank) : local_rank(local_rank), node std::tuple Executor::metadata_preprocess_core( - hybrid_ep::tmp_state_t *preprocessing_tmp, HybridEpConfigInstance config, + hybrid_ep::tmp_state_t *preprocessing_tmp, torch::Tensor global_routing_map, int num_of_tokens_per_rank ) { @@ -16,6 +16,9 @@ Executor::metadata_preprocess_core( // padding for the routing map const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16; + auto num_of_expert = global_routing_map.size(-1); + assert(num_of_expert == config.num_of_experts_per_rank * config.num_of_ranks_per_node * config.num_of_nodes); + // Construt the output tensor of the metadata preprocessing kernel. 
auto sparse_to_dense_map = torch::empty({num_of_tokens_per_rank * config.num_of_nodes, @@ -112,7 +115,16 @@ Executor::dispatch_postprocess(HybridEpConfigInstance config, DispatchBuffers& d torch::Tensor row_id_map, tokens_per_expert; if(args.num_dispatched_tokens == 0 ) { - // Fast return if there are no tokens to dispatch + // Fast return empty tensors if there are no tokens to dispatch + dispatched_tokens = torch::empty({0, config.hidden_dim}, torch::dtype(args.hidden.dtype()).device(torch::kCUDA)); + if(config.forward_dispatch_api) { + dispatched_probs = torch::empty({0}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); + } + if(config.token_data_type == TOKEN_DATA_TYPE::UINT8) { + dispatched_scaling_factor = torch::empty({0, config.hidden_dim / 128}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); + } + row_id_map = torch::empty({0, config.num_of_experts_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + tokens_per_expert = torch::full({config.num_of_experts_per_rank}, 0, torch::dtype(torch::kInt32).device(torch::kCUDA)); return std::make_tuple(dispatched_tokens, dispatched_probs, dispatched_scaling_factor, row_id_map, tokens_per_expert); } @@ -123,12 +135,7 @@ Executor::dispatch_postprocess(HybridEpConfigInstance config, DispatchBuffers& d int num_dispatched_tokens = args.num_dispatched_tokens; int num_permuted_tokens = args.num_permuted_tokens; torch::Tensor num_dispatched_tokens_tensor = args.num_dispatched_tokens_tensor.value(); - // If args.num_dispatched_tokens >= 0, which means that the sync-free model is used. - // Otherwise, we will use the values in args.num_dispatched_tokens_tensor. - if (num_dispatched_tokens < 0) { - num_dispatched_tokens = num_dispatched_tokens_tensor.item(); - } - + if (args.row_id_map.has_value()) { // The row_id_map is valid, which means that the cached model is used. // Then we will use the values in args directly. @@ -179,7 +186,7 @@ Executor::dispatch_postprocess(HybridEpConfigInstance config, DispatchBuffers& d size_t sizeof_token_data_type = get_token_data_type_size(dispatch_buffers.data_type); dispatched_tokens = torch::empty({args.num_dispatched_tokens, config.hidden_dim}, torch::dtype(args.hidden.dtype()).device(torch::kCUDA)); auto res_sz = args.num_dispatched_tokens * config.hidden_dim * sizeof_token_data_type; - CUDA_CHECK(cudaMemcpyAsync(dispatched_tokens.data_ptr(), dispatch_buffers.expert_output_token,res_sz, cudaMemcpyDeviceToDevice, args.stream)); + CUDA_CHECK(cudaMemcpyAsync(dispatched_tokens.data_ptr(), dispatch_buffers.expert_output_token, res_sz, cudaMemcpyDeviceToDevice, args.stream)); if(config.forward_dispatch_api) { dispatched_probs = torch::empty({args.num_dispatched_tokens, diff --git a/csrc/hybrid_ep/executor/executor.cuh b/csrc/hybrid_ep/executor/executor.cuh index b19e407c..755714e6 100644 --- a/csrc/hybrid_ep/executor/executor.cuh +++ b/csrc/hybrid_ep/executor/executor.cuh @@ -66,8 +66,8 @@ public: std::tuple metadata_preprocess_core( - hybrid_ep::tmp_state_t *preprocessing_tmp, HybridEpConfigInstance config, + hybrid_ep::tmp_state_t *preprocessing_tmp, torch::Tensor global_routing_map, int num_of_tokens_per_rank ); diff --git a/csrc/hybrid_ep/hybrid_ep.cu b/csrc/hybrid_ep/hybrid_ep.cu index 8659c3f8..77b1562b 100644 --- a/csrc/hybrid_ep/hybrid_ep.cu +++ b/csrc/hybrid_ep/hybrid_ep.cu @@ -2,11 +2,10 @@ // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved #include "hybrid_ep.cuh" -HybridEpBuffer::HybridEpBuffer(HybridEpConfigInstance config, int local_rank, int node_rank, int group_size, int num_of_ranks_per_node, bool use_fp8_dispatch) - : config(config), local_rank(local_rank), node_rank(node_rank), group_size(group_size), - num_of_ranks_per_node(num_of_ranks_per_node), executor(local_rank, node_rank), use_fp8_dispatch(use_fp8_dispatch) { - - if(group_size <= config.num_of_ranks_per_node) { +HybridEPBuffer::HybridEPBuffer(BufferConfig config, int local_rank, int node_rank, int group_size) + : buffer_config(config), local_rank(local_rank), node_rank(node_rank), group_size(group_size), + executor(local_rank, node_rank) { + if(group_size <= buffer_config.num_of_ranks_per_node) { // If used on only intra-node communication, the dispatch/combine can share same buffers. use_shared_buffer = true; }else{ @@ -18,7 +17,14 @@ HybridEpBuffer::HybridEpBuffer(HybridEpConfigInstance config, int local_rank, in allocate_buffer(); } -HybridEpBuffer::~HybridEpBuffer() { +HybridEPBuffer::~HybridEPBuffer() { + release_buffer(); +} + +void HybridEPBuffer::release_buffer() { + // Synchronize the device to ensure all operations are completed. + CUDA_CHECK(cudaDeviceSynchronize()); + auto free_buffer = [this](void *ptr, bool remote_memory) { if (ptr != nullptr) { if (remote_memory) { @@ -38,10 +44,8 @@ HybridEpBuffer::~HybridEpBuffer() { free_buffer(dispatch_buffers.expert_output_token, true); free_buffer(dispatch_buffers.expert_output_prob, true); } - if (use_fp8_dispatch) { - free_buffer(dispatch_buffers.expert_output_scaling_factor, true); - free_buffer(dispatch_buffers.rdma_inter_node_group_scaling_factor, false); - } + free_buffer(dispatch_buffers.expert_output_scaling_factor, true); + free_buffer(dispatch_buffers.rdma_inter_node_group_scaling_factor, false); free_buffer(dispatch_buffers.rdma_inter_node_group_token,false); free_buffer(dispatch_buffers.rdma_inter_node_group_prob, false); free_buffer(dispatch_buffers.rdma_inter_node_group_flags, false); @@ -52,13 +56,11 @@ HybridEpBuffer::~HybridEpBuffer() { }else{ remote_allocator.close_handle(dispatch_buffers.intra_node_write_completion_flags); } - for (int i = 0; i < config.num_of_ranks_per_node; i++) { + for (int i = 0; i < buffer_config.num_of_ranks_per_node; i++) { if (i != local_rank) { remote_allocator.close_handle(dispatch_buffers.expert_output_token_all_ranks[i]); remote_allocator.close_handle(dispatch_buffers.expert_output_prob_all_ranks[i]); - if (use_fp8_dispatch) { - remote_allocator.close_handle(dispatch_buffers.expert_output_scaling_factor_all_ranks[i]); - } + remote_allocator.close_handle(dispatch_buffers.expert_output_scaling_factor_all_ranks[i]); } } delete[] dispatch_buffers.expert_output_token_all_ranks; @@ -80,7 +82,7 @@ HybridEpBuffer::~HybridEpBuffer() { }else{ remote_allocator.close_handle(combine_buffers.intra_node_write_completion_flags); } - for (int i = 0; i < config.num_of_ranks_per_node; i++) { + for (int i = 0; i < buffer_config.num_of_ranks_per_node; i++) { if (i != local_rank) { remote_allocator.close_handle(combine_buffers.expert_input_token_all_ranks[i]); remote_allocator.close_handle(combine_buffers.expert_input_prob_all_ranks[i]); @@ -90,33 +92,33 @@ HybridEpBuffer::~HybridEpBuffer() { delete[] combine_buffers.expert_input_prob_all_ranks; } -void HybridEpBuffer::allocate_buffer_for_preprocessing() { +void HybridEPBuffer::allocate_buffer_for_preprocessing() { auto preprocessing_tmp_elts = - config.num_of_blocks_preprocessing_api * 
config.num_of_ranks_per_node; + buffer_config.num_of_blocks_preprocessing_api * buffer_config.num_of_ranks_per_node; CUDA_CHECK( cudaMalloc((void **)&this->preprocessing_tmp, preprocessing_tmp_elts * sizeof(hybrid_ep::tmp_state_t))); } -void HybridEpBuffer::allocate_buffer_for_dispatch() { - dispatch_buffers.data_type = config.token_data_type; +void HybridEPBuffer::allocate_buffer_for_dispatch() { + dispatch_buffers.data_type = buffer_config.token_data_type; size_t sizeof_token_data_type = get_token_data_type_size(dispatch_buffers.data_type); // Calculate buffer sizes - auto expert_output_token_elts = max_num_of_tokens_for_experts * config.hidden_dim; + auto expert_output_token_elts = max_num_of_tokens_for_experts * buffer_config.hidden_dim; auto expert_output_prob_elts = max_num_of_tokens_for_experts * - (config.num_of_experts_per_rank * config.num_of_ranks_per_node); - auto expert_output_scaling_factor_elts = max_num_of_tokens_for_experts * (config.hidden_dim / 128); + (buffer_config.num_of_experts_per_rank * buffer_config.num_of_ranks_per_node); + auto expert_output_scaling_factor_elts = max_num_of_tokens_for_experts * (buffer_config.hidden_dim / 128); // Calculate local temp buffer sizes - auto rdma_inter_node_group_token_elts = config.max_num_of_tokens_per_rank * - (config.num_of_nodes - 1) * config.hidden_dim; - auto rdma_inter_node_group_prob_elts = config.max_num_of_tokens_per_rank * (config.num_of_nodes - 1) * - (config.num_of_experts_per_rank * config.num_of_ranks_per_node); - auto rdma_inter_node_group_scaling_factor_elts = config.max_num_of_tokens_per_rank * - (config.num_of_nodes - 1) * (config.hidden_dim / 128); - auto rdma_inter_node_group_flags_elts = (config.max_num_of_tokens_per_rank / - config.num_of_tokens_per_chunk_dispatch_api) * - (config.num_of_nodes - 1); + auto rdma_inter_node_group_token_elts = buffer_config.max_num_of_tokens_per_rank * + (buffer_config.num_of_nodes - 1) * buffer_config.hidden_dim; + auto rdma_inter_node_group_prob_elts = buffer_config.max_num_of_tokens_per_rank * (buffer_config.num_of_nodes - 1) * + (buffer_config.num_of_experts_per_rank * buffer_config.num_of_ranks_per_node); + auto rdma_inter_node_group_scaling_factor_elts = buffer_config.max_num_of_tokens_per_rank * + (buffer_config.num_of_nodes - 1) * (buffer_config.hidden_dim / 128); + auto rdma_inter_node_group_flags_elts = (buffer_config.max_num_of_tokens_per_rank / + buffer_config.num_of_tokens_per_chunk_dispatch_api) * + (buffer_config.num_of_nodes - 1); // Allocate main buffers if (use_shared_buffer) { @@ -124,24 +126,19 @@ void HybridEpBuffer::allocate_buffer_for_dispatch() { assert(combine_buffers.expert_input_prob != nullptr); dispatch_buffers.expert_output_token = combine_buffers.expert_input_token; dispatch_buffers.expert_output_prob = combine_buffers.expert_input_prob; - } - else { + } else { remote_allocator.allocate((void**)&dispatch_buffers.expert_output_token, expert_output_token_elts * sizeof_token_data_type); remote_allocator.allocate((void**)&dispatch_buffers.expert_output_prob, expert_output_prob_elts * sizeof(float)); } - if (use_fp8_dispatch) { - remote_allocator.allocate((void**)&dispatch_buffers.expert_output_scaling_factor, expert_output_scaling_factor_elts * sizeof(float)); - } + remote_allocator.allocate((void**)&dispatch_buffers.expert_output_scaling_factor, expert_output_scaling_factor_elts * sizeof(float)); // Allocate RDMA buffers CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.rdma_inter_node_group_token, rdma_inter_node_group_token_elts * 
sizeof_token_data_type)); CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.rdma_inter_node_group_prob, rdma_inter_node_group_prob_elts * sizeof(float))); - if (use_fp8_dispatch) { - CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.rdma_inter_node_group_scaling_factor, - rdma_inter_node_group_scaling_factor_elts * sizeof(float))); - } + CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.rdma_inter_node_group_scaling_factor, + rdma_inter_node_group_scaling_factor_elts * sizeof(float))); CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.rdma_inter_node_group_flags, rdma_inter_node_group_flags_elts * sizeof(uint64_t))); @@ -159,9 +156,7 @@ void HybridEpBuffer::allocate_buffer_for_dispatch() { MemHandle handles[4]; remote_allocator.get_handle(&handles[0], dispatch_buffers.expert_output_token); remote_allocator.get_handle(&handles[1], dispatch_buffers.expert_output_prob); - if (use_fp8_dispatch) { - remote_allocator.get_handle(&handles[2], dispatch_buffers.expert_output_scaling_factor); - } + remote_allocator.get_handle(&handles[2], dispatch_buffers.expert_output_scaling_factor); if (local_rank == 0) { remote_allocator.get_handle(&handles[3], dispatch_buffers.intra_node_write_completion_flags); } @@ -175,23 +170,23 @@ void HybridEpBuffer::allocate_buffer_for_dispatch() { CUDA_CHECK(cudaGetLastError()); } -void HybridEpBuffer::allocate_buffer_for_combine() { +void HybridEPBuffer::allocate_buffer_for_combine() { // Calculate buffer sizes - auto expert_input_token_elts = max_num_of_tokens_for_experts * config.hidden_dim; + auto expert_input_token_elts = max_num_of_tokens_for_experts * buffer_config.hidden_dim; auto expert_input_prob_elts = max_num_of_tokens_for_experts * - (config.num_of_experts_per_rank * config.num_of_ranks_per_node); + (buffer_config.num_of_experts_per_rank * buffer_config.num_of_ranks_per_node); // Calculate local temp buffer sizes - auto rdma_intra_node_red_token_elts = config.max_num_of_tokens_per_rank * - (config.num_of_nodes - 1) * config.hidden_dim; - auto rdma_intra_node_red_prob_elts = config.max_num_of_tokens_per_rank * (config.num_of_nodes - 1) * - (config.num_of_experts_per_rank * config.num_of_ranks_per_node); - auto rdma_inter_node_group_token_elts = config.max_num_of_tokens_per_rank * - (config.num_of_nodes - 1) * config.hidden_dim; - auto rdma_inter_node_group_prob_elts = config.max_num_of_tokens_per_rank * (config.num_of_nodes - 1) * - (config.num_of_experts_per_rank * config.num_of_ranks_per_node); - auto rdma_inter_node_group_flags_elts = (config.max_num_of_tokens_per_rank / - config.num_of_tokens_per_chunk_combine_api) * - (config.num_of_nodes - 1); + auto rdma_intra_node_red_token_elts = buffer_config.max_num_of_tokens_per_rank * + (buffer_config.num_of_nodes - 1) * buffer_config.hidden_dim; + auto rdma_intra_node_red_prob_elts = buffer_config.max_num_of_tokens_per_rank * (buffer_config.num_of_nodes - 1) * + (buffer_config.num_of_experts_per_rank * buffer_config.num_of_ranks_per_node); + auto rdma_inter_node_group_token_elts = buffer_config.max_num_of_tokens_per_rank * + (buffer_config.num_of_nodes - 1) * buffer_config.hidden_dim; + auto rdma_inter_node_group_prob_elts = buffer_config.max_num_of_tokens_per_rank * (buffer_config.num_of_nodes - 1) * + (buffer_config.num_of_experts_per_rank * buffer_config.num_of_ranks_per_node); + auto rdma_inter_node_group_flags_elts = (buffer_config.max_num_of_tokens_per_rank / + buffer_config.num_of_tokens_per_chunk_combine_api) * + (buffer_config.num_of_nodes - 1); // Allocate main buffers 
remote_allocator.allocate((void**)&combine_buffers.expert_input_token, expert_input_token_elts * sizeof(uint16_t)); @@ -236,11 +231,11 @@ void HybridEpBuffer::allocate_buffer_for_combine() { CUDA_CHECK(cudaGetLastError()); } -void HybridEpBuffer::allocate_buffer() { +void HybridEPBuffer::allocate_buffer() { // Token number at the worst case, all tokens are routed to the same expert. - this->max_num_of_tokens_for_experts = config.max_num_of_tokens_per_rank * - config.num_of_ranks_per_node * - config.num_of_nodes; + this->max_num_of_tokens_for_experts = buffer_config.max_num_of_tokens_per_rank * + buffer_config.num_of_ranks_per_node * + buffer_config.num_of_nodes; assert(this->max_num_of_tokens_for_experts % 4 == 0); // The number of tokens for experts should be divisible by 4, this // is required by the permute make_row_id_map kernel @@ -249,7 +244,7 @@ void HybridEpBuffer::allocate_buffer() { allocate_buffer_for_dispatch(); } -void HybridEpBuffer::exchange_ipc_address(py::object process_group) { +void HybridEPBuffer::exchange_ipc_address(py::object process_group) { try { // Use Python's torch.distributed APIs through py::object auto torch_distributed = py::module_::import("torch.distributed"); @@ -294,20 +289,19 @@ void HybridEpBuffer::exchange_ipc_address(py::object process_group) { } -void HybridEpBuffer::open_handles_from_other_ranks( +void HybridEPBuffer::open_handles_from_other_ranks( std::vector dispatch_handles, std::vector combine_handles) { - // Malloc the pointer arrays used in the dispatch kernel. dispatch_buffers.expert_output_token_all_ranks = - (void **)malloc(config.num_of_ranks_per_node * sizeof(void *)); + (void **)malloc(buffer_config.num_of_ranks_per_node * sizeof(void *)); dispatch_buffers.expert_output_prob_all_ranks = - (float **)malloc(config.num_of_ranks_per_node * sizeof(float *)); + (float **)malloc(buffer_config.num_of_ranks_per_node * sizeof(float *)); dispatch_buffers.expert_output_scaling_factor_all_ranks = - (float **)malloc(config.num_of_ranks_per_node * sizeof(float *)); + (float **)malloc(buffer_config.num_of_ranks_per_node * sizeof(float *)); // Global offset means the position in the multi-node case. - auto global_offset = node_rank * num_of_ranks_per_node; + auto global_offset = node_rank * buffer_config.num_of_ranks_per_node; // Open the dispatch handles for intra_node_write_completion_flags if (local_rank != 0) { @@ -322,7 +316,7 @@ void HybridEpBuffer::open_handles_from_other_ranks( } // Open the handles for export_output - for (int i = 0; i < num_of_ranks_per_node; i++) { + for (int i = 0; i < buffer_config.num_of_ranks_per_node; i++) { MemHandle expert_output_token_handle, expert_output_prob_handle, expert_output_scaling_factor_handle; @@ -341,7 +335,7 @@ void HybridEpBuffer::open_handles_from_other_ranks( &expert_output_token_handle); remote_allocator.open_handle((void**)(&dispatch_buffers.expert_output_prob_all_ranks[i]), &expert_output_prob_handle); - remote_allocator.open_handle((void**)(&dispatch_buffers.expert_output_scaling_factor_all_ranks[i]), + remote_allocator.open_handle((void**)(&dispatch_buffers.expert_output_scaling_factor_all_ranks[i]), &expert_output_scaling_factor_handle); } else { // For local rank, use direct pointer assignment (more efficient, no IPC overhead) @@ -356,9 +350,9 @@ void HybridEpBuffer::open_handles_from_other_ranks( // Malloc the pointer arrays used in the combine kernel. 
combine_buffers.expert_input_token_all_ranks = - (uint16_t **)malloc(config.num_of_ranks_per_node * sizeof(uint16_t *)); + (uint16_t **)malloc(buffer_config.num_of_ranks_per_node * sizeof(uint16_t *)); combine_buffers.expert_input_prob_all_ranks = - (float **)malloc(config.num_of_ranks_per_node * sizeof(float *)); + (float **)malloc(buffer_config.num_of_ranks_per_node * sizeof(float *)); // Open the combine handles for intra_node_write_completion_flags if (local_rank != 0) { MemHandle intra_node_write_completion_flags_handle; @@ -371,7 +365,7 @@ void HybridEpBuffer::open_handles_from_other_ranks( &intra_node_write_completion_flags_handle); } // Open the handles for expert_input - for (int i = 0; i < num_of_ranks_per_node; i++) { + for (int i = 0; i < buffer_config.num_of_ranks_per_node; i++) { MemHandle expert_input_token_handle, expert_input_prob_handle; auto base_ptr = combine_handles[i + global_offset].data_ptr(); // Extract the handles from the tensor. @@ -394,24 +388,51 @@ void HybridEpBuffer::open_handles_from_other_ranks( } } +bool HybridEPBuffer::update_buffer(HybridEpConfigInstance config) { + // If new config requires bigger buffer, we will release the old buffer and allocate a new one. + bool need_reallocate = false; + + need_reallocate |= grow_to(buffer_config.hidden_dim, config.hidden_dim); + need_reallocate |= grow_to(buffer_config.num_of_experts_per_rank,config.num_of_experts_per_rank); + need_reallocate |= grow_to(buffer_config.num_of_ranks_per_node, config.num_of_ranks_per_node); + need_reallocate |= grow_to(buffer_config.num_of_nodes, config.num_of_nodes); + need_reallocate |= grow_to(buffer_config.num_of_blocks_preprocessing_api, config.num_of_blocks_preprocessing_api); + need_reallocate |= grow_to(buffer_config.num_of_tokens_per_chunk_dispatch_api, config.num_of_tokens_per_chunk_dispatch_api); + need_reallocate |= grow_to(buffer_config.num_of_tokens_per_chunk_combine_api, config.num_of_tokens_per_chunk_combine_api); + + // Special case for token data type. 
+ if(get_token_data_type_size(buffer_config.token_data_type) < get_token_data_type_size(config.token_data_type) + && !use_shared_buffer) { + need_reallocate = true; + buffer_config.token_data_type = config.token_data_type; + } + + if(need_reallocate) { + release_buffer(); + allocate_buffer(); + } + return need_reallocate; +} + std::tuple -HybridEpBuffer::metadata_preprocessing(torch::Tensor global_routing_map, int64_t num_of_tokens_per_rank) { +HybridEPBuffer::metadata_preprocessing(HybridEpConfigInstance config, torch::Tensor global_routing_map, int64_t num_of_tokens_per_rank) { // Basic checks assert(global_routing_map.device().is_cuda()); assert(global_routing_map.is_contiguous()); // Run the hybrid-ep metadata preprocessing kernel - return executor.metadata_preprocess_core(preprocessing_tmp, config, global_routing_map, num_of_tokens_per_rank); + return executor.metadata_preprocess_core(config, preprocessing_tmp, global_routing_map, num_of_tokens_per_rank); } std::tuple, c10::optional> -HybridEpBuffer::dispatch(torch::Tensor hidden, c10::optional probs, +HybridEPBuffer::dispatch(HybridEpConfigInstance config, + torch::Tensor hidden, c10::optional probs, c10::optional scaling_factor, torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, c10::optional num_dispatched_tokens_tensor, - int64_t num_dispatched_tokens, + c10::optional num_dispatched_tokens, int64_t num_of_tokens_per_rank, bool with_probs) { // Check the input tensors @@ -438,12 +459,9 @@ HybridEpBuffer::dispatch(torch::Tensor hidden, c10::optional prob args.rdma_to_attn_map = rdma_to_attn_map; args.attn_to_rdma_map = attn_to_rdma_map; args.num_dispatched_tokens_tensor = num_dispatched_tokens_tensor; - if(num_dispatched_tokens < 0 ) { - // In this case, user will set the size of the output buffer by themselves. - args.num_dispatched_tokens = num_dispatched_tokens_tensor.value().item(); - }else { - args.num_dispatched_tokens = num_dispatched_tokens; - } + args.num_dispatched_tokens = (num_dispatched_tokens.has_value()) ? 
+ num_dispatched_tokens.value() : + num_dispatched_tokens_tensor.value().item(); args.num_of_tokens_per_rank = num_of_tokens_per_rank; args.enable_permute = false; args.stream = at::cuda::getCurrentCUDAStream(); @@ -464,7 +482,8 @@ HybridEpBuffer::dispatch(torch::Tensor hidden, c10::optional prob } std::tuple -HybridEpBuffer::combine(torch::Tensor hidden, c10::optional probs, +HybridEPBuffer::combine(HybridEpConfigInstance config, + torch::Tensor hidden, c10::optional probs, torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, int64_t num_of_tokens_per_rank, @@ -516,17 +535,18 @@ HybridEpBuffer::combine(torch::Tensor hidden, c10::optional probs std::tuple, c10::optional, torch::Tensor, torch::Tensor> -HybridEpBuffer::dispatch_with_permute(torch::Tensor hidden, c10::optional probs, +HybridEPBuffer::dispatch_with_permute(HybridEpConfigInstance config, + torch::Tensor hidden, c10::optional probs, c10::optional scaling_factor, torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, c10::optional num_dispatched_tokens_tensor, c10::optional local_expert_routing_map, c10::optional row_id_map, - int64_t num_dispatched_tokens, - int64_t num_permuted_tokens, + c10::optional num_dispatched_tokens, + c10::optional num_permuted_tokens, int64_t num_of_tokens_per_rank, - int64_t pad_multiple, + c10::optional pad_multiple, bool use_host_meta, bool with_probs) { @@ -555,15 +575,12 @@ HybridEpBuffer::dispatch_with_permute(torch::Tensor hidden, c10::optional(); - }else { - args.num_dispatched_tokens = num_dispatched_tokens; - } + args.num_dispatched_tokens = (num_dispatched_tokens.has_value()) ? + num_dispatched_tokens.value() : + num_dispatched_tokens_tensor.value().item(); args.row_id_map = row_id_map; - args.num_permuted_tokens = num_permuted_tokens; - args.pad_multiple = pad_multiple; + args.num_permuted_tokens = (num_permuted_tokens.has_value()) ? num_permuted_tokens.value() : -1; + args.pad_multiple = (pad_multiple.has_value()) ? pad_multiple.value() : 0; args.use_host_meta = use_host_meta; args.num_of_tokens_per_rank = num_of_tokens_per_rank; args.enable_permute = true; @@ -584,13 +601,14 @@ HybridEpBuffer::dispatch_with_permute(torch::Tensor hidden, c10::optional -HybridEpBuffer::combine_with_unpermute(torch::Tensor hidden, c10::optional probs, +HybridEPBuffer::combine_with_unpermute(HybridEpConfigInstance config, + torch::Tensor hidden, c10::optional probs, torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, c10::optional num_dispatched_tokens_tensor, c10::optional row_id_map, - int64_t num_dispatched_tokens, + c10::optional num_dispatched_tokens, int64_t num_of_tokens_per_rank, - int64_t pad_multiple, + c10::optional pad_multiple, bool with_probs) { // Check the input tensors @@ -624,9 +642,11 @@ HybridEpBuffer::combine_with_unpermute(torch::Tensor hidden, c10::optional(); args.row_id_map = row_id_map; - args.num_dispatched_tokens = num_dispatched_tokens; - args.pad_multiple = pad_multiple; + args.pad_multiple = (pad_multiple.has_value()) ? 
pad_multiple.value() : 0; args.num_of_tokens_per_rank = num_of_tokens_per_rank; args.enable_unpermute = true; args.stream = at::cuda::getCurrentCUDAStream(); diff --git a/csrc/hybrid_ep/hybrid_ep.cuh b/csrc/hybrid_ep/hybrid_ep.cuh index 62578cec..bb83ab5a 100644 --- a/csrc/hybrid_ep/hybrid_ep.cuh +++ b/csrc/hybrid_ep/hybrid_ep.cuh @@ -12,80 +12,83 @@ #include #include -class HybridEpBuffer { +class HybridEPBuffer { public: - HybridEpBuffer(HybridEpConfigInstance config, int local_rank, int node_rank, int group_size, int num_of_ranks_per_node, bool use_fp8_dispatch); - ~HybridEpBuffer(); + HybridEPBuffer(BufferConfig config, int local_rank, int node_rank, int group_size); + ~HybridEPBuffer(); + bool update_buffer(HybridEpConfigInstance config); // True means the buffer is reallocated. // Exchange IPC addresses using C++ distributed communication void exchange_ipc_address(pybind11::object process_group); std::tuple - metadata_preprocessing(torch::Tensor global_routing_map, int64_t num_of_tokens_per_rank); + metadata_preprocessing(HybridEpConfigInstance config, torch::Tensor global_routing_map, int64_t num_of_tokens_per_rank); std::tuple, c10::optional> - dispatch(torch::Tensor hidden, c10::optional probs, + dispatch(HybridEpConfigInstance config, + torch::Tensor hidden, c10::optional probs, c10::optional scaling_factor, torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, c10::optional num_dispatched_tokens_tensor, - int64_t num_dispatched_tokens, + c10::optional num_dispatched_tokens, int64_t num_of_tokens_per_rank, bool with_probs); std::tuple - combine(torch::Tensor hidden, c10::optional probs, + combine(HybridEpConfigInstance config, torch::Tensor hidden, c10::optional probs, torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, int64_t num_of_tokens_per_rank, bool with_probs); std::tuple, c10::optional, torch::Tensor, torch::Tensor> - dispatch_with_permute(torch::Tensor hidden, c10::optional probs, + dispatch_with_permute( + HybridEpConfigInstance config, + torch::Tensor hidden, c10::optional probs, c10::optional scaling_factor, torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, c10::optional num_dispatched_tokens_tensor, c10::optional local_expert_routing_map, c10::optional row_id_map, - int64_t num_dispatched_tokens, - int64_t num_permuted_tokens, + c10::optional num_dispatched_tokens, + c10::optional num_permuted_tokens, int64_t num_of_tokens_per_rank, - int64_t pad_multiple, + c10::optional pad_multiple, bool use_host_meta, bool with_probs); std::tuple - combine_with_unpermute(torch::Tensor hidden, c10::optional probs, + combine_with_unpermute( + HybridEpConfigInstance config, + torch::Tensor hidden, c10::optional probs, torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, c10::optional num_dispatched_tokens_tensor, c10::optional row_id_map, - int64_t num_dispatched_tokens, + c10::optional num_dispatched_tokens, int64_t num_of_tokens_per_rank, - int64_t pad_multiple, + c10::optional pad_multiple, bool with_probs); private: ExtendedMemoryAllocator remote_allocator; - HybridEpConfigInstance config; + BufferConfig buffer_config; Executor executor; void allocate_buffer(); void allocate_buffer_for_preprocessing(); void allocate_buffer_for_dispatch(); void allocate_buffer_for_combine(); + void release_buffer(); void open_handles_from_other_ranks(std::vector dispatch_handles, std::vector combine_handles); // 
Meta data of communication group. - int rank; int local_rank; int node_rank; - int num_of_ranks_per_node; int group_size; - int nvlink_domain_size; // Maximum number of tokens for experts. int64_t max_num_of_tokens_for_experts; - bool use_fp8_dispatch; // Only valid on intra-node communication. In this case, the dispatch/combine can share same buffers. bool use_shared_buffer; diff --git a/csrc/hybrid_ep/pybind_hybrid_ep.cu b/csrc/hybrid_ep/pybind_hybrid_ep.cu index 7bdb9b7b..2957312f 100644 --- a/csrc/hybrid_ep/pybind_hybrid_ep.cu +++ b/csrc/hybrid_ep/pybind_hybrid_ep.cu @@ -22,6 +22,29 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("__str__", [](const TOKEN_DATA_TYPE &type) { return type_to_string(type); }); + pybind11::class_(m, "BufferConfig") + .def(py::init<>()) + .def_readwrite("hidden_dim", &BufferConfig::hidden_dim) + .def_readwrite("max_num_of_tokens_per_rank", &BufferConfig::max_num_of_tokens_per_rank) + .def_readwrite("num_of_experts_per_rank", &BufferConfig::num_of_experts_per_rank) + .def_readwrite("num_of_ranks_per_node", &BufferConfig::num_of_ranks_per_node) + .def_readwrite("num_of_nodes", &BufferConfig::num_of_nodes) + .def_readwrite("token_data_type", &BufferConfig::token_data_type) + .def_readwrite("num_of_blocks_preprocessing_api", &BufferConfig::num_of_blocks_preprocessing_api) + .def_readwrite("num_of_tokens_per_chunk_dispatch_api", &BufferConfig::num_of_tokens_per_chunk_dispatch_api) + .def_readwrite("num_of_tokens_per_chunk_combine_api", &BufferConfig::num_of_tokens_per_chunk_combine_api) + .def("__repr__", [](const BufferConfig &config) { + return ""; + }); + pybind11::class_(m, "HybridEpConfigInstance") .def(py::init<>()) // Hybrid-ep Config @@ -77,24 +100,28 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ">"; }); - pybind11::class_(m, "HybridEpBuffer") - .def(py::init()) - .def("exchange_ipc_address", &HybridEpBuffer::exchange_ipc_address) - .def("metadata_preprocessing", &HybridEpBuffer::metadata_preprocessing, - py::kw_only(), py::arg("routing_map"), py::arg("num_of_tokens_per_rank")) - .def("dispatch", &HybridEpBuffer::dispatch, py::kw_only(), py::arg("hidden"), + pybind11::class_(m, "HybridEPBuffer") + .def(py::init()) + .def("update_buffer", &HybridEPBuffer::update_buffer, py::arg("config")) + .def("exchange_ipc_address", &HybridEPBuffer::exchange_ipc_address) + .def("metadata_preprocessing", &HybridEPBuffer::metadata_preprocessing, + py::kw_only(), py::arg("config"), py::arg("routing_map"), py::arg("num_of_tokens_per_rank")) + .def("dispatch", &HybridEPBuffer::dispatch, py::kw_only(), + py::arg("config"), py::arg("hidden"), py::arg("probs") = c10::nullopt, py::arg("scaling_factor") = c10::nullopt, py::arg("sparse_to_dense_map"), py::arg("rdma_to_attn_map"), py::arg("attn_to_rdma_map"), py::arg("num_dispatched_tokens_tensor"), py::arg("num_dispatched_tokens") = -1, py::arg("num_of_tokens_per_rank"), py::arg("with_probs")) - .def("combine", &HybridEpBuffer::combine, py::kw_only(), py::arg("hidden"), + .def("combine", &HybridEPBuffer::combine, py::kw_only(), + py::arg("config"), py::arg("hidden"), py::arg("probs") = c10::nullopt, py::arg("sparse_to_dense_map"), py::arg("rdma_to_attn_map"), py::arg("attn_to_rdma_map"), py::arg("num_of_tokens_per_rank"), py::arg("with_probs")) - .def("dispatch_with_permute", &HybridEpBuffer::dispatch_with_permute, py::kw_only(), py::arg("hidden"), + .def("dispatch_with_permute", &HybridEPBuffer::dispatch_with_permute, py::kw_only(), + py::arg("config"), py::arg("hidden"), py::arg("probs") = c10::nullopt, 
py::arg("scaling_factor") = c10::nullopt, py::arg("sparse_to_dense_map"), py::arg("rdma_to_attn_map"), @@ -103,7 +130,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("num_permuted_tokens") = -1, py::arg("num_of_tokens_per_rank"), py::arg("pad_multiple") = 0, py::arg("use_host_meta") = false, py::arg("with_probs") = false) - .def("combine_with_unpermute", &HybridEpBuffer::combine_with_unpermute, py::kw_only(), py::arg("hidden"), + .def("combine_with_unpermute", &HybridEPBuffer::combine_with_unpermute, py::kw_only(), + py::arg("config"), py::arg("hidden"), py::arg("probs") = c10::nullopt, py::arg("sparse_to_dense_map"), py::arg("rdma_to_attn_map"), py::arg("attn_to_rdma_map"), py::arg("num_dispatched_tokens_tensor"), diff --git a/csrc/hybrid_ep/utils.cuh b/csrc/hybrid_ep/utils.cuh index 271183d4..1209a669 100644 --- a/csrc/hybrid_ep/utils.cuh +++ b/csrc/hybrid_ep/utils.cuh @@ -162,3 +162,41 @@ inline std::string convert_to_nvcc_arch_flags(std::string torch_arch_list) { return nvcc_arch_flags; } + +template +inline bool grow_to(T& dst, const T& src) { + if (dst < src) { dst = src; return true; } + return false; +} + +inline void print_ptr_info(void* p) { + cudaPointerAttributes attr{}; + cudaError_t err = cudaPointerGetAttributes(&attr, p); + if (err != cudaSuccess) { + printf("cudaPointerGetAttributes failed: %s\n", cudaGetErrorString(err)); + return; + } + cudaMemoryType memory_type; +#if CUDART_VERSION >= 10000 + memory_type = attr.type; +#else + memory_type = attr.memoryType; +#endif + std::string memory_type_str; + switch (memory_type) { + case cudaMemoryTypeHost: memory_type_str = "Host"; break; + case cudaMemoryTypeDevice: memory_type_str = "Device"; break; + case cudaMemoryTypeManaged: memory_type_str = "Managed"; break; + default: memory_type_str = "Unregistered/Unknown"; break; + } + printf("type=%s, device=%d\n", memory_type_str.c_str(), attr.device); + + // If this is a device/managed pointer, try to query its allocation range (base + size) + if (memory_type == cudaMemoryTypeDevice || memory_type == cudaMemoryTypeManaged) { + cuInit(0); + CUdeviceptr base = 0; + size_t size = 0; + CUresult r = cuMemGetAddressRange(&base, &size, reinterpret_cast(p)); + printf("alloc_base=%p, alloc_size=%zu bytes\n", reinterpret_cast(base), size); + } +} \ No newline at end of file diff --git a/deep_ep/__init__.py b/deep_ep/__init__.py index cc87716f..f09f54a0 100644 --- a/deep_ep/__init__.py +++ b/deep_ep/__init__.py @@ -3,7 +3,7 @@ from .utils import EventOverlap from .buffer import Buffer -from .hybrid_ep_buffer import HybridEpBuffer +from .hybrid_ep_buffer import HybridEPBuffer # noinspection PyUnresolvedReferences from deep_ep_cpp import Config diff --git a/deep_ep/hybrid_ep_buffer.py b/deep_ep/hybrid_ep_buffer.py index 5e8fa2b8..df64d704 100644 --- a/deep_ep/hybrid_ep_buffer.py +++ b/deep_ep/hybrid_ep_buffer.py @@ -5,31 +5,39 @@ import hybrid_ep_cpp -def indices_to_map(topk_idx: torch.Tensor, topk_weights: torch.Tensor, num_of_tokens: int, num_of_experts: int): +def indices_to_map( + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_of_tokens: int, + num_of_experts: int, +): """ Map the map to the indices. """ # Generate the routing map and the probs according to the topk_idx and topk_weights. 
assert topk_idx is not None - routing_map = torch.zeros(num_of_tokens, num_of_experts, device="cuda", dtype=torch.bool) + routing_map = torch.zeros( + num_of_tokens, num_of_experts, device="cuda", dtype=torch.bool + ) routing_map = routing_map.scatter(1, topk_idx.to(torch.int64), 1).bool() if topk_weights is not None: - probs = torch.zeros(num_of_tokens, num_of_experts, device="cuda", dtype=torch.float32) + probs = torch.zeros( + num_of_tokens, num_of_experts, device="cuda", dtype=torch.float32 + ) probs = probs.scatter(1, topk_idx.to(torch.int64), topk_weights) else: probs = None return routing_map, probs -class HybridEpBuffer: + +class HybridEPBuffer: def __init__( self, group: torch.distributed.ProcessGroup, - # Basic tensor setting + # Parameters for the hybrid-ep buffer allocation hidden_dim: int, max_num_of_tokens_per_rank: int, num_local_experts: int, - num_of_experts: int, - # Use fp8 in dispatch or not. use_fp8: bool = False, # Device-SM occupancy setting num_sms_dispatch_api: int = 32, @@ -68,11 +76,6 @@ def __init__( self.node_rank = self.rank // self.num_of_ranks_per_node # The number of nodes. self.num_of_nodes = self.group_size // self.num_of_ranks_per_node - - self.hidden_dim = hidden_dim - self.max_num_of_tokens_per_rank = max_num_of_tokens_per_rank - self.num_local_experts = num_local_experts - self.num_of_experts = num_of_experts self.use_fp8 = use_fp8 props = torch.cuda.get_device_properties(torch.cuda.current_device()) @@ -86,82 +89,93 @@ def __init__( self.num_sms_dispatch_api = num_sms_dispatch_api self.num_sms_combine_api = num_sms_combine_api - self.init_config() - self.init_buffer() + # Initialize the BufferConfig for the hybrid-ep buffer allocation. + self.config = hybrid_ep_cpp.BufferConfig() + self.config.hidden_dim = hidden_dim + self.config.max_num_of_tokens_per_rank = max(max_num_of_tokens_per_rank, 1024) + self.config.num_of_experts_per_rank = num_local_experts + self.config.num_of_ranks_per_node = self.num_of_ranks_per_node + self.config.num_of_nodes = self.num_of_nodes + # The SMs of preprocessing, chunk size of dispatch and combine will affact the size of intermediate buffers. + self.config.num_of_blocks_preprocessing_api = self.num_sms_preprocessing_api + # The fp8/bf16/fp16 data is communicated in the uint8/uint16 format. 
+ self.config.token_data_type = ( + hybrid_ep_cpp.UINT8 if self.use_fp8 else hybrid_ep_cpp.UINT16 + ) + self.config.num_of_tokens_per_chunk_dispatch_api = int( + os.getenv("NUM_OF_TOKENS_PER_CHUNK_DISPATCH_API", "128") + ) + self.config.num_of_tokens_per_chunk_combine_api = int( + os.getenv("NUM_OF_TOKENS_PER_CHUNK_COMBINE_API", "128") + ) + + # Create C++ buffer - this will allocate all buffers during construction + self.runtime = hybrid_ep_cpp.HybridEPBuffer( + self.config, self.local_rank, self.node_rank, self.group_size + ) + # Exchange IPC addresses using C++ distributed communication + self.runtime.exchange_ipc_address(self.group) - def init_config( + def update_template_config( self, - # Metadata-preprocessing API Config - num_of_threads_per_block_preprocessing_api: int = None, - # Dispatch API Config - num_of_stages_dispatch_api: int = None, - num_of_tokens_per_chunk_dispatch_api: int = None, - device_side_sync_dispatch_api: bool = True, - # Combine API Config - num_of_stages_g2s_combine_api: int = None, - num_of_stages_s2g_combine_api: int = None, - num_of_tokens_per_chunk_combine_api: int = None, - num_of_tokens_per_group_combine_api: int = None, - num_of_additional_in_flight_s2g_combine_api: int = None, - device_side_sync_combine_api: bool = True, + hidden_dim: int = None, + max_num_of_tokens_per_rank: int = None, + num_local_experts: int = None, + use_fp8: bool = None, ): """ - Initialize the HybridEpConfigInstance for the hybrid-ep kernel. - We can contoal the detailed setting of the hybrid-ep kernel. + Initialize the HybridEpConfigInstance which used to control the detailed setting of the hybrid-ep kernel. In common case, no need to change the default setting. """ config = hybrid_ep_cpp.HybridEpConfigInstance() # Initialize the ConfigInstance # Hybrid-ep Config - config.hidden_dim = self.hidden_dim - config.max_num_of_tokens_per_rank = self.max_num_of_tokens_per_rank - config.num_of_experts_per_rank = self.num_local_experts + config.hidden_dim = ( + hidden_dim if hidden_dim is not None else self.config.hidden_dim + ) + config.max_num_of_tokens_per_rank = ( + max_num_of_tokens_per_rank + if max_num_of_tokens_per_rank is not None + else self.config.max_num_of_tokens_per_rank + ) + config.max_num_of_tokens_per_rank = max( + config.max_num_of_tokens_per_rank, self.config.max_num_of_tokens_per_rank + ) + config.num_of_experts_per_rank = ( + num_local_experts + if num_local_experts is not None + else self.config.num_of_experts_per_rank + ) config.num_of_ranks_per_node = self.num_of_ranks_per_node config.num_of_nodes = self.num_of_nodes # Metadata-preprocessing API Config config.num_of_blocks_preprocessing_api = self.num_sms_preprocessing_api - # 1. Try to get the value from the environment variable, Default value: 512 - # 2. If the value is provided, use the provided value. config.num_of_threads_per_block_preprocessing_api = int( os.getenv("NUM_OF_THREADS_PER_BLOCK_PREPROCESSING_API", "512") ) - if num_of_threads_per_block_preprocessing_api is not None: - config.num_of_threads_per_block_preprocessing_api = ( - num_of_threads_per_block_preprocessing_api - ) # Dispatch API Config - if self.use_fp8: - # The fp8 data is communicated in the uint8 format. - config.token_data_type = hybrid_ep_cpp.UINT8 - else: - # The bf16 data is communicated in the uint16 format. 
- config.token_data_type = hybrid_ep_cpp.UINT16 + if use_fp8 is None: + use_fp8 = self.use_fp8 + config.token_data_type = ( + hybrid_ep_cpp.UINT8 if use_fp8 else hybrid_ep_cpp.UINT16 + ) config.num_of_blocks_dispatch_api = self.num_sms_dispatch_api - config.device_side_sync_dispatch_api = device_side_sync_dispatch_api + config.device_side_sync_dispatch_api = True # Dispatch stages config: - # 1. Try to get the value from the environment variable config.num_of_stages_dispatch_api = int( os.getenv("NUM_OF_STAGES_DISPATCH_API", "10") ) config.num_of_tokens_per_chunk_dispatch_api = int( os.getenv("NUM_OF_TOKENS_PER_CHUNK_DISPATCH_API", "128") ) - # 2. If the value is provided, use the provided value. - if num_of_stages_dispatch_api is not None: - config.num_of_stages_dispatch_api = num_of_stages_dispatch_api - if num_of_tokens_per_chunk_dispatch_api is not None: - config.num_of_tokens_per_chunk_dispatch_api = ( - num_of_tokens_per_chunk_dispatch_api - ) # Combine API Config config.num_of_blocks_combine_api = self.num_sms_combine_api - config.device_side_sync_combine_api = device_side_sync_combine_api + config.device_side_sync_combine_api = True # Combine stages config: - # 1. Try to get the value from the environment variable config.num_of_stages_g2s_combine_api = int( os.getenv("NUM_OF_STAGES_G2S_COMBINE_API", "10") ) @@ -177,39 +191,12 @@ def init_config( config.num_of_additional_in_flight_s2g_combine_api = int( os.getenv("NUM_OF_ADDITIONAL_IN_FLIGHT_S2G_COMBINE_API", "2") ) - # 2. If the value is provided, use the provided value. - if num_of_stages_g2s_combine_api is not None: - config.num_of_stages_g2s_combine_api = num_of_stages_g2s_combine_api - if num_of_stages_s2g_combine_api is not None: - config.num_of_stages_s2g_combine_api = num_of_stages_s2g_combine_api - if num_of_tokens_per_chunk_combine_api is not None: - config.num_of_tokens_per_chunk_combine_api = ( - num_of_tokens_per_chunk_combine_api - ) - if num_of_tokens_per_group_combine_api is not None: - config.num_of_tokens_per_group_combine_api = ( - num_of_tokens_per_group_combine_api - ) - if num_of_additional_in_flight_s2g_combine_api is not None: - config.num_of_additional_in_flight_s2g_combine_api = ( - num_of_additional_in_flight_s2g_combine_api - ) - self.config = config - - def init_buffer(self): - """ - Initialize the buffer for the hybrid-ep kernel. - Creates the C++ buffer (which allocates buffers) and exchanges IPC addresses. - """ - assert self.config is not None, "Please initialize the config first." - # Create C++ buffer - this will allocate all buffers during construction - self.runtime = hybrid_ep_cpp.HybridEpBuffer( - self.config, self.local_rank, self.node_rank, self.group_size, self.num_of_ranks_per_node, self.use_fp8 - ) - - # Exchange IPC addresses using C++ distributed communication - self.runtime.exchange_ipc_address(self.group) + # Use the runtime kernel config to update the buffer. 
+ reallocated = self.runtime.update_buffer(config) + if reallocated: + self.runtime.exchange_ipc_address(self.group) + return config def dispatch( self, @@ -217,10 +204,11 @@ def dispatch( scaling_factor: torch.Tensor = None, topk_idx: torch.Tensor = None, topk_weights: torch.Tensor = None, + num_of_experts: int = None, probs: torch.Tensor = None, routing_map: torch.Tensor = None, num_dispatched_tokens_tensor: torch.Tensor = None, - num_dispatched_tokens: int = -1, + num_dispatched_tokens: int = None, handle: tuple = None, ): """ @@ -232,29 +220,41 @@ def dispatch( Backward direction: combine_in_backward <- local_unpermute -> expert_mlp -> local_permute -> dispatch_in_backward """ - num_of_tokens = hidden.shape[0] - assert num_of_tokens <= self.max_num_of_tokens_per_rank, "The number of tokens should be less than or equal to the max number of tokens per rank." + num_of_tokens, hidden_dim = hidden.shape + if routing_map is not None: assert routing_map.dtype == torch.bool + num_of_experts = routing_map.size(-1) else: # Generate the routing map and the probs according to the topk_idx and topk_weights. + assert ( + num_of_experts is not None + ), "The number of experts should be provided on index-based routing." if topk_idx is not None: - routing_map, probs = indices_to_map(topk_idx, topk_weights, num_of_tokens, self.num_of_experts) + routing_map, probs = indices_to_map( + topk_idx, topk_weights, num_of_tokens, num_of_experts + ) assert ( handle is not None or routing_map is not None ), "The handle and routing_map should be both None" # If the handle is not provided, we need to generate the handle using the preprocessing kernel. if handle is None: + config = self.update_template_config( + hidden_dim=hidden_dim, + max_num_of_tokens_per_rank=num_of_tokens, + ) + # The hybrid-ep kernel requires the routing info from all ranks. global_routing_map = torch.empty( num_of_tokens * self.group_size, - self.num_of_experts, + num_of_experts, device="cuda", dtype=torch.bool, ) torch.distributed.all_gather_into_tensor( global_routing_map, routing_map, self.group ) + # Run the metadata preprocessing kernel. ( sparse_to_dense_map, rdma_to_attn_map, @@ -262,6 +262,7 @@ def dispatch( num_dispatched_tokens_tensor, local_expert_routing_map, ) = self.runtime.metadata_preprocessing( + config=config, routing_map=global_routing_map, num_of_tokens_per_rank=num_of_tokens, ) @@ -273,6 +274,7 @@ def dispatch( num_dispatched_tokens_tensor, local_expert_routing_map, num_of_tokens, + config, ) else: ( @@ -282,13 +284,15 @@ def dispatch( num_dispatched_tokens_tensor, local_expert_routing_map, num_of_tokens, + config, ) = handle - if num_dispatched_tokens < 0: + if num_dispatched_tokens is None: num_dispatched_tokens = num_dispatched_tokens_tensor.item() dispatched_token, dispatched_probs, dispatched_scaling_factor = ( self.runtime.dispatch( + config=config, hidden=hidden, probs=probs, scaling_factor=scaling_factor, @@ -317,8 +321,17 @@ def combine( Do not require preprocessing, but the handle is necessary. """ assert handle is not None, "The handle is necessary for combine." 
- sparse_to_dense_map, rdma_to_attn_map, attn_to_rdma_map, num_dispatched_tokens_tensor, local_expert_routing_map, num_of_tokens = handle + ( + sparse_to_dense_map, + rdma_to_attn_map, + attn_to_rdma_map, + num_dispatched_tokens_tensor, + local_expert_routing_map, + num_of_tokens, + config, + ) = handle combined_token, combined_probs = self.runtime.combine( + config=config, hidden=hidden, probs=probs, sparse_to_dense_map=sparse_to_dense_map, @@ -328,7 +341,7 @@ def combine( with_probs=probs is not None, ) return combined_token, combined_probs - + def dispatch_with_permute( self, *, @@ -336,14 +349,17 @@ def dispatch_with_permute( hidden: torch.Tensor, topk_idx: torch.Tensor = None, topk_weights: torch.Tensor = None, + num_of_experts_per_rank: int = None, + num_of_experts: int = None, + use_fp8: bool = None, routing_map: torch.Tensor = None, probs: torch.Tensor = None, scaling_factor: torch.Tensor = None, # Used in the sync-free permute - num_dispatched_tokens: int = -1, - num_permuted_tokens: int = -1, + num_dispatched_tokens: int = None, + num_permuted_tokens: int = None, # If we use permute kernel, the output tensor will be permuted. the result can be directly used in the gemm. - pad_multiple: int = 0, + pad_multiple: int = None, # The handle means the cached info from the first invocation of the dispatch kernel. # The handle includes: # # Output of Metadata Preprocessing @@ -354,6 +370,8 @@ def dispatch_with_permute( # 5. local_expert_routing_map # # Output of Permute Preprocessing # 6. row_id_map + # # Cache for template config + # 7. template_config: HybridEpConfigInstance handle: tuple = None, # If enable this, the produced num_dispatched_tokens will be put on the CPU pinned memory, and the tokens_per_expert will be put on the CPU, which may reduce the times of the sync use_host_meta: bool = True, @@ -362,28 +380,43 @@ def dispatch_with_permute( Dispatch the data to the experts with permute. """ with torch.cuda.nvtx.range("hybrid-ep dispatch with permute phase"): - num_of_tokens_per_rank = hidden.shape[0] - assert num_of_tokens_per_rank <= self.max_num_of_tokens_per_rank, "The number of tokens should be less than or equal to the max number of tokens per rank." + num_of_tokens_per_rank, hidden_dim = hidden.shape if routing_map is not None: assert routing_map.dtype == torch.bool + num_of_experts = routing_map.size(-1) else: # Generate the routing map and the probs according to the topk_idx and topk_weights. if topk_idx is not None: - routing_map, probs = indices_to_map(topk_idx, topk_weights, num_of_tokens_per_rank, self.num_of_experts) - + assert ( + num_of_experts is not None + ), "The number of experts should be provided on index-based routing." + routing_map, probs = indices_to_map( + topk_idx, topk_weights, num_of_tokens_per_rank, num_of_experts + ) + # If the handle is not provided, we need to generate the handle in the first invocation of the dispatch kernel. if handle is None: - assert hidden.size(0) == routing_map.size(0), "The hidden and the routing_map should have the same row number." + assert hidden.size(0) == routing_map.size( + 0 + ), "The hidden and the routing_map should have the same row number." + # Update the template config. + config = self.update_template_config( + hidden_dim=hidden_dim, + max_num_of_tokens_per_rank=num_of_tokens_per_rank, + num_local_experts=num_of_experts_per_rank, + use_fp8=use_fp8, + ) # Global routing map: the routing map for all tokens to all experts. 
global_routing_map = torch.empty( num_of_tokens_per_rank * self.group_size, - self.num_of_experts, + num_of_experts, device="cuda", dtype=torch.bool, ) torch.distributed.all_gather_into_tensor( global_routing_map, routing_map, self.group ) + # Run the metadata preprocessing kernel. row_id_map = None ( sparse_to_dense_map, @@ -392,6 +425,7 @@ def dispatch_with_permute( num_dispatched_tokens_tensor, local_expert_routing_map, ) = self.runtime.metadata_preprocessing( + config=config, routing_map=global_routing_map, num_of_tokens_per_rank=num_of_tokens_per_rank, ) @@ -409,8 +443,9 @@ def dispatch_with_permute( local_expert_routing_map, row_id_map, num_of_tokens_per_rank, + config, ) = handle - + # Dispatch phase ( dispatched_token, @@ -419,6 +454,7 @@ def dispatch_with_permute( row_id_map, tokens_per_expert, ) = self.runtime.dispatch_with_permute( + config=config, hidden=hidden, probs=probs, scaling_factor=scaling_factor, @@ -444,6 +480,7 @@ def dispatch_with_permute( local_expert_routing_map, row_id_map, num_of_tokens_per_rank, + config, ) return ( dispatched_token, @@ -459,9 +496,9 @@ def combine_with_unpermute( # Input tensors hidden: torch.Tensor, probs: torch.Tensor = None, - num_dispatched_tokens: int = -1, + num_dispatched_tokens: int = None, handle: tuple = None, - pad_multiple: int = 0, + pad_multiple: int = None, ): """ Combine the data from the experts with unpermute. @@ -479,19 +516,21 @@ def combine_with_unpermute( _, row_id_map, num_of_tokens_per_rank, + config, ) = handle combined_token, combined_probs = self.runtime.combine_with_unpermute( - hidden = hidden, - probs = probs, - sparse_to_dense_map = sparse_to_dense_map, - rdma_to_attn_map = rdma_to_attn_map, - attn_to_rdma_map = attn_to_rdma_map, - num_dispatched_tokens_tensor = num_dispatched_tokens_tensor, - row_id_map = row_id_map, - num_dispatched_tokens = num_dispatched_tokens, - num_of_tokens_per_rank = num_of_tokens_per_rank, - pad_multiple = pad_multiple, - with_probs = probs is not None, + config=config, + hidden=hidden, + probs=probs, + sparse_to_dense_map=sparse_to_dense_map, + rdma_to_attn_map=rdma_to_attn_map, + attn_to_rdma_map=attn_to_rdma_map, + num_dispatched_tokens_tensor=num_dispatched_tokens_tensor, + row_id_map=row_id_map, + num_dispatched_tokens=num_dispatched_tokens, + num_of_tokens_per_rank=num_of_tokens_per_rank, + pad_multiple=pad_multiple, + with_probs=probs is not None, ) - return combined_token, combined_probs \ No newline at end of file + return combined_token, combined_probs diff --git a/tests/test_hybrid_ep.py b/tests/test_hybrid_ep.py index 35d9988a..b17b28de 100644 --- a/tests/test_hybrid_ep.py +++ b/tests/test_hybrid_ep.py @@ -26,6 +26,13 @@ torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False +def bitwise_equal(a: torch.Tensor, b: torch.Tensor) -> bool: + if a.dtype != b.dtype or a.shape != b.shape or a.device != b.device: + return False + a_bytes = a.contiguous().view(torch.uint8) + b_bytes = b.contiguous().view(torch.uint8) + return torch.equal(a_bytes, b_bytes) + def init_tensor( hidden_dim: int, seq_len: int, @@ -66,7 +73,7 @@ def init_tensor( return hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights -def test_hybrid_ep_correctness(buffer: deep_ep.HybridEpBuffer, ref: TorchRef, use_fp8: bool): +def test_hybrid_ep_correctness(buffer: deep_ep.HybridEPBuffer, ref: TorchRef, use_fp8: bool): hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights = init_tensor( hidden_dim=HIDDEN_DIM, seq_len=NUM_TOKENS_PER_RANK, @@ -89,24 +96,24 @@ def 
test_hybrid_ep_correctness(buffer: deep_ep.HybridEpBuffer, ref: TorchRef, us dispatched_scaling_factor, handle, ) = buffer.dispatch( - hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights if with_probs else None + hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights if with_probs else None, num_of_experts=NUM_OF_EXPERTS, ) - assert torch.allclose(dispatched_hidden_ref, dispatched_hidden) + assert bitwise_equal(dispatched_hidden_ref, dispatched_hidden) if dispatched_probs is not None and dispatched_probs_ref is not None: start, end = ref._local_expert_range() masked_probs = torch.zeros_like(dispatched_probs) masked_probs[:, start:end] = dispatched_probs[:, start:end] - assert torch.allclose(dispatched_probs_ref, dispatched_probs[:, start:end]) + assert bitwise_equal(dispatched_probs_ref, dispatched_probs[:, start:end]) dispatched_probs = masked_probs if ( dispatched_scaling_factor is not None and dispatched_scaling_factor_ref is not None ): - assert torch.allclose( + assert bitwise_equal( dispatched_scaling_factor_ref, dispatched_scaling_factor ) - _, _, _, num_dispatched_tokens, local_expert_routing_map, _ = handle + _, _, _, num_dispatched_tokens, local_expert_routing_map, _, _ = handle num_dispatched_tokens = num_dispatched_tokens.cpu() local_expert_routing_map = local_expert_routing_map[ : num_dispatched_tokens.item() @@ -148,7 +155,7 @@ def test_hybrid_ep_correctness(buffer: deep_ep.HybridEpBuffer, ref: TorchRef, us scaling_factor=scaling_factor, pad_multiple=PAD_MULTIPLE, ) - _, _, _, num_dispatched_tokens_tensor, local_expert_routing_map, _, _ = ( + _, _, _, num_dispatched_tokens_tensor, local_expert_routing_map, _, _, _ = ( handle ) num_dispatched_tokens_tensor = num_dispatched_tokens_tensor.cpu() @@ -172,14 +179,14 @@ def test_hybrid_ep_correctness(buffer: deep_ep.HybridEpBuffer, ref: TorchRef, us enable_permute=True, ) - assert torch.allclose(dispatched_hidden_ref, dispatched_hidden) + assert bitwise_equal(dispatched_hidden_ref, dispatched_hidden) if dispatched_probs is not None and dispatched_probs_ref is not None: - assert torch.allclose(dispatched_probs_ref, dispatched_probs) + assert bitwise_equal(dispatched_probs_ref, dispatched_probs) if ( dispatched_scaling_factor is not None and dispatched_scaling_factor_ref is not None ): - assert torch.allclose( + assert bitwise_equal( dispatched_scaling_factor_ref, dispatched_scaling_factor ) @@ -211,7 +218,7 @@ def test_hybrid_ep_correctness(buffer: deep_ep.HybridEpBuffer, ref: TorchRef, us print(f'[rank {torch.distributed.get_rank()}] Correctness check passed ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})') -def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEpBuffer, group: dist.ProcessGroup, use_fp8: bool, nsys_profile: bool): +def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEPBuffer, group: dist.ProcessGroup, use_fp8: bool, nsys_profile: bool): hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights = init_tensor( hidden_dim=HIDDEN_DIM, seq_len=NUM_TOKENS_PER_RANK, @@ -223,7 +230,7 @@ def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEpBuffer, group: dist.Process # warmup for _ in range(10): dispatched_hidden, dispatched_probs, _, handle = ( - buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights) + buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS) ) # The combine only support bf16 
dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16) @@ -235,14 +242,14 @@ def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEpBuffer, group: dist.Process dispatch_bf16_nvl_recv_bytes = dispatched_hidden.numel() * 2 combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes - dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights} + dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS} t = bench(lambda: buffer.dispatch(**dispatch_args))[0] nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if hidden.dtype == torch.uint8 else dispatch_bf16_nvl_recv_bytes print(f'[rank {rank}] HybridEP dispatch torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): ' f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, nvl_recv_bytes: {nvl_recv_bytes / 1e6:.2f} MB', flush=True) dispatched_hidden, dispatched_probs, _, handle= ( - buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights) + buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS) ) dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16) combine_args = {'hidden': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle} @@ -273,7 +280,7 @@ def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEpBuffer, group: dist.Process # noinspection PyShadowingNames def test_func(): dispatched_hidden, dispatched_probs, _, handle = ( - buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights) + buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS) ) # The combine only support bf16 dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16) @@ -291,7 +298,7 @@ def test_func(): with torch.cuda.nvtx.range(f"hybrid-ep dispatch ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})"): if rank == 0: print(f"profile hybrid-ep dispatch ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})", flush=True) - dispatch_args = {'tensor': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights} + dispatch_args = {'tensor': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS} bench(lambda: buffer.dispatch(**dispatch_args)) with torch.cuda.nvtx.range("hybrid-ep combine"): if rank == 0: @@ -306,12 +313,11 @@ def test_main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) try: for use_fp8 in [True, False]: - buffer = deep_ep.HybridEpBuffer( + buffer = deep_ep.HybridEPBuffer( group=group, hidden_dim=HIDDEN_DIM, max_num_of_tokens_per_rank=MAX_NUM_OF_TOKENS_PER_RANK, num_local_experts=NUM_LOCAL_EXPERTS, - num_of_experts=NUM_OF_EXPERTS, use_fp8=use_fp8, )
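
For reference, a minimal end-to-end sketch of the reworked Python API introduced by this patch, mirroring the updated `tests/test_hybrid_ep.py`: buffer sizing is captured in `BufferConfig` at construction time, `num_of_experts` moves from the constructor to the `dispatch()` call, and the returned handle now carries the per-call `HybridEpConfigInstance` as its last element. The constants, process-group setup, and tensor initialisation below are illustrative placeholders, not values taken from the patch.

```python
import torch
import torch.distributed as dist
import deep_ep

# Assumed example sizes; real values come from the model / test configuration.
HIDDEN_DIM, NUM_TOKENS_PER_RANK, NUM_LOCAL_EXPERTS, NUM_OF_EXPERTS, TOPK = 4096, 4096, 8, 64, 8

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
group = dist.group.WORLD

# Buffer sizing now lives in BufferConfig (built inside the constructor);
# num_of_experts and use_fp8-dependent buffer layout are no longer fixed here.
buffer = deep_ep.HybridEPBuffer(
    group=group,
    hidden_dim=HIDDEN_DIM,
    max_num_of_tokens_per_rank=NUM_TOKENS_PER_RANK,
    num_local_experts=NUM_LOCAL_EXPERTS,
    use_fp8=False,
)

# Placeholder inputs for the BF16 path.
hidden = torch.randn(NUM_TOKENS_PER_RANK, HIDDEN_DIM, device="cuda", dtype=torch.bfloat16)
topk_weights, topk_idx = torch.randn(
    NUM_TOKENS_PER_RANK, NUM_OF_EXPERTS, device="cuda"
).topk(TOPK, dim=-1)

# num_of_experts is now passed per call; the handle returned here caches the
# preprocessing outputs plus the HybridEpConfigInstance used for this shape.
dispatched_hidden, dispatched_probs, _, handle = buffer.dispatch(
    hidden=hidden,
    topk_idx=topk_idx,
    topk_weights=topk_weights,
    num_of_experts=NUM_OF_EXPERTS,
)

# combine() reuses the cached handle (combine only supports bf16 tokens).
combined_hidden, combined_probs = buffer.combine(
    hidden=dispatched_hidden.to(torch.bfloat16),
    probs=dispatched_probs,
    handle=handle,
)
```

`dispatch_with_permute()` / `combine_with_unpermute()` follow the same pattern; because their cached handle now also ends with the config, callers that unpack it positionally need one extra slot, as the updated test assertions show.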