Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Hybrid-EP_Intra-node_Implementation.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ Refer to `tests/test_hybrid_ep.py` for comprehensive usage examples including:
- Performance benchmarking setups

### Important Configuration Note
Here are important parameter settings in `csrc/hybrid_ep/config.cuh`. You can modify these parameters via `HybridEpBuffer.init_config()` or by setting proper environment variables (see `deep_ep/hybrid_ep_buffer.py`) to achieve better performance/usability:
Here are important parameter settings in `csrc/hybrid_ep/config.cuh`. You can modify these parameters via `HybridEPBuffer.init_config()` or by setting proper environment variables (see `deep_ep/hybrid_ep_buffer.py`) to achieve better performance/usability:

- HIDDEN_DIM
Hidden size (must match model hidden dimension).
Expand Down
25 changes: 25 additions & 0 deletions csrc/hybrid_ep/config.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,31 @@
// This will be used to initialize the template param_t for communication kernel.
#define MAX_NUM_OF_RANKS_PER_NODE 72

// Config used for buffer allocation.
// Config used for buffer allocation.
//
// Holds the static sizing parameters used when allocating the hybrid-EP
// communication buffers. Populate all fields, then call is_valid() before
// using the config for allocation.
struct BufferConfig {
  int hidden_dim;                           // model hidden size; must be a positive multiple of 512
  int max_num_of_tokens_per_rank;           // upper bound of tokens handled per rank
  int num_of_experts_per_rank;              // experts hosted on each rank
  int num_of_ranks_per_node;                // must be positive and even
  int num_of_nodes;                         // number of participating nodes
  TOKEN_DATA_TYPE token_data_type;          // element type of the token payload
  int num_of_blocks_preprocessing_api;      // CUDA block count for the preprocessing kernel
  int num_of_tokens_per_chunk_dispatch_api; // chunking granularity for dispatch
  int num_of_tokens_per_chunk_combine_api;  // chunking granularity for combine

  /*
   * Validation check.
   *
   * Returns true iff the sizing parameters satisfy the kernel alignment
   * requirements:
   *   - hidden_dim is a positive multiple of 512,
   *   - the per-node expert count (experts_per_rank * ranks_per_node) is a
   *     positive multiple of 4,
   *   - num_of_ranks_per_node is positive and even.
   * Positivity is checked explicitly: without it, zero (and some negative)
   * values would pass the plain modulo tests.
   */
  bool is_valid() const {
    bool valid = true;
    valid &= (hidden_dim > 0 && hidden_dim % 512 == 0);
    valid &= (num_of_experts_per_rank > 0 &&
              (num_of_experts_per_rank * num_of_ranks_per_node) % 4 == 0);
    valid &= (num_of_ranks_per_node > 0 && num_of_ranks_per_node % 2 == 0);
    return valid;
  }
};

// Config used for hybrid-ep kernel.
struct HybridEpConfigInstance {
/*
* Hybrid-ep Config
Expand Down
25 changes: 16 additions & 9 deletions csrc/hybrid_ep/executor/executor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@ Executor::Executor(int local_rank, int node_rank) : local_rank(local_rank), node

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Executor::metadata_preprocess_core(
hybrid_ep::tmp_state_t *preprocessing_tmp,
HybridEpConfigInstance config,
hybrid_ep::tmp_state_t *preprocessing_tmp,
torch::Tensor global_routing_map,
int num_of_tokens_per_rank
) {
nvtxRangePushA("metadata_preprocess_core in hybrid-ep");
// padding for the routing map
const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16;

auto num_of_expert = global_routing_map.size(-1);
assert(num_of_expert == config.num_of_experts_per_rank * config.num_of_ranks_per_node * config.num_of_nodes);

// Construct the output tensor of the metadata preprocessing kernel.
auto sparse_to_dense_map =
torch::empty({num_of_tokens_per_rank * config.num_of_nodes,
Expand Down Expand Up @@ -112,7 +115,16 @@ Executor::dispatch_postprocess(HybridEpConfigInstance config, DispatchBuffers& d
torch::Tensor row_id_map, tokens_per_expert;

if(args.num_dispatched_tokens == 0 ) {
// Fast return if there are no tokens to dispatch
// Fast return empty tensors if there are no tokens to dispatch
dispatched_tokens = torch::empty({0, config.hidden_dim}, torch::dtype(args.hidden.dtype()).device(torch::kCUDA));
if(config.forward_dispatch_api) {
dispatched_probs = torch::empty({0}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
}
if(config.token_data_type == TOKEN_DATA_TYPE::UINT8) {
dispatched_scaling_factor = torch::empty({0, config.hidden_dim / 128}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
}
row_id_map = torch::empty({0, config.num_of_experts_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
tokens_per_expert = torch::full({config.num_of_experts_per_rank}, 0, torch::dtype(torch::kInt32).device(torch::kCUDA));
return std::make_tuple(dispatched_tokens, dispatched_probs, dispatched_scaling_factor, row_id_map, tokens_per_expert);
}

Expand All @@ -123,12 +135,7 @@ Executor::dispatch_postprocess(HybridEpConfigInstance config, DispatchBuffers& d
int num_dispatched_tokens = args.num_dispatched_tokens;
int num_permuted_tokens = args.num_permuted_tokens;
torch::Tensor num_dispatched_tokens_tensor = args.num_dispatched_tokens_tensor.value();
// If args.num_dispatched_tokens >= 0, which means that the sync-free model is used.
// Otherwise, we will use the values in args.num_dispatched_tokens_tensor.
if (num_dispatched_tokens < 0) {
num_dispatched_tokens = num_dispatched_tokens_tensor.item<int32_t>();
}


if (args.row_id_map.has_value()) {
// The row_id_map is valid, which means that the cached model is used.
// Then we will use the values in args directly.
Expand Down Expand Up @@ -179,7 +186,7 @@ Executor::dispatch_postprocess(HybridEpConfigInstance config, DispatchBuffers& d
size_t sizeof_token_data_type = get_token_data_type_size(dispatch_buffers.data_type);
dispatched_tokens = torch::empty({args.num_dispatched_tokens, config.hidden_dim}, torch::dtype(args.hidden.dtype()).device(torch::kCUDA));
auto res_sz = args.num_dispatched_tokens * config.hidden_dim * sizeof_token_data_type;
CUDA_CHECK(cudaMemcpyAsync(dispatched_tokens.data_ptr(), dispatch_buffers.expert_output_token,res_sz, cudaMemcpyDeviceToDevice, args.stream));
CUDA_CHECK(cudaMemcpyAsync(dispatched_tokens.data_ptr(), dispatch_buffers.expert_output_token, res_sz, cudaMemcpyDeviceToDevice, args.stream));

if(config.forward_dispatch_api) {
dispatched_probs = torch::empty({args.num_dispatched_tokens,
Expand Down
2 changes: 1 addition & 1 deletion csrc/hybrid_ep/executor/executor.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ public:

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
metadata_preprocess_core(
hybrid_ep::tmp_state_t *preprocessing_tmp,
HybridEpConfigInstance config,
hybrid_ep::tmp_state_t *preprocessing_tmp,
torch::Tensor global_routing_map,
int num_of_tokens_per_rank
);
Expand Down
Loading