Skip to content

Commit c0c7f77

Browse files
zhiyi Huzhiyi Hu
authored andcommitted
minor modifications
1 parent 90b8a4b commit c0c7f77

File tree

2 files changed

+33
-34
lines changed

2 files changed

+33
-34
lines changed

csrc/kernels/utils.cuh

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ __forceinline__ __device__ std::pair<WarpRole_Dispatch, int> get_warp_role_dispa
577577
}
578578
}
579579

580-
else if (return_recv_hook) { // hook mode
580+
if (return_recv_hook) { // hook mode
581581
EP_DEVICE_ASSERT(phases != 0);
582582
if ((phases & NORMAL_DECOUPLED_RECV_PHASE) == 0) { // send phase
583583
if (warp_id < kNumDispatchRDMASenderWarps + NUM_MAX_NVL_PEERS) {
@@ -599,22 +599,20 @@ __forceinline__ __device__ std::pair<WarpRole_Dispatch, int> get_warp_role_dispa
599599
}
600600
}
601601

602-
else { // decoupled mode, but no hook
603-
if (not is_forwarder) { // send warps
604-
if (warp_id < kNumDispatchRDMASenderWarps + NUM_MAX_NVL_PEERS) {
605-
return {WarpRole_Dispatch::kRDMASender, -1};
606-
} else {
607-
return {WarpRole_Dispatch::kRDMASenderCoordinator, -1};
608-
}
602+
// decoupled mode, but no hook
603+
if (not is_forwarder) { // send warps
604+
if (warp_id < kNumDispatchRDMASenderWarps + NUM_MAX_NVL_PEERS) {
605+
return {WarpRole_Dispatch::kRDMASender, -1};
606+
} else {
607+
return {WarpRole_Dispatch::kRDMASenderCoordinator, -1};
609608
}
610-
611-
else { // recv warps
612-
if (warp_id < NUM_MAX_NVL_PEERS) {
613-
return {WarpRole_Dispatch::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
614-
}
615-
else {
616-
return {WarpRole_Dispatch::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS};
617-
}
609+
}
610+
else { // recv warps
611+
if (warp_id < NUM_MAX_NVL_PEERS) {
612+
return {WarpRole_Dispatch::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
613+
}
614+
else {
615+
return {WarpRole_Dispatch::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS};
618616
}
619617
}
620618
}
@@ -649,7 +647,8 @@ __forceinline__ __device__ std::pair<WarpRole_Combine, int> get_warp_role_combin
649647
}
650648
}
651649
}
652-
else if (return_recv_hook) { // hook mode
650+
651+
if (return_recv_hook) { // hook mode
653652
EP_DEVICE_ASSERT(phases != 0);
654653
if ((phases & NORMAL_DECOUPLED_RECV_PHASE) == 0) { // send phase
655654
if (warp_id < NUM_MAX_NVL_PEERS) {
@@ -671,22 +670,22 @@ __forceinline__ __device__ std::pair<WarpRole_Combine, int> get_warp_role_combin
671670
return {WarpRole_Combine::kInvalidWarpRole, -1};
672671
}
673672
}
674-
else { // decoupled mode, but no hook
675-
if (is_forwarder_sm) { // send warps
676-
if (warp_id < NUM_MAX_NVL_PEERS) {
677-
auto shuffled_warp_id = warp_id;
678-
shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS;
679-
return {WarpRole_Combine::kNVLSender, shuffled_warp_id};
680-
} else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) {
681-
auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS;
682-
shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders;
683-
return {WarpRole_Combine::kNVLAndRDMAForwarder, shuffled_warp_id};
684-
} else {
685-
return {WarpRole_Combine::kCoordinator, 0};
686-
}
687-
} else { // recv warps
688-
return {WarpRole_Combine::kRDMAReceiver, warp_id};
673+
674+
// decoupled mode, but no hook
675+
if (is_forwarder_sm) { // send warps
676+
if (warp_id < NUM_MAX_NVL_PEERS) {
677+
auto shuffled_warp_id = warp_id;
678+
shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS;
679+
return {WarpRole_Combine::kNVLSender, shuffled_warp_id};
680+
} else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) {
681+
auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS;
682+
shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders;
683+
return {WarpRole_Combine::kNVLAndRDMAForwarder, shuffled_warp_id};
684+
} else {
685+
return {WarpRole_Combine::kCoordinator, 0};
689686
}
687+
} else { // recv warps
688+
return {WarpRole_Combine::kRDMAReceiver, warp_id};
690689
}
691690
}
692691

deep_ep/buffer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
357357
handle: the returned communication handle.
358358
event: the event after executing the kernel (valid only if `async_finish` is set).
359359
"""
360-
decoupled_mode = return_recv_hook
360+
decoupled_mode = return_recv_hook # This mode (decoupled_mode=True, return_recv_hook=False) is implemented to support large buffers without hooks, but offers no practical performance benefit and is not exposed to user for use.
361361
# Default config
362362
config = self.get_dispatch_config(self.group_size) if config is None else config
363363

@@ -420,7 +420,7 @@ def combine(self, x: torch.Tensor, handle: Tuple,
420420
recv_topk_weights: the reduced top-k weights from its dispatch ranks.
421421
event: the event after executing the kernel (valid only if `async_finish` is set).
422422
"""
423-
decoupled_mode = return_recv_hook
423+
decoupled_mode = return_recv_hook # This mode (decoupled_mode=True, return_recv_hook=False) is implemented to support large buffers without hooks, but offers no practical performance benefit and is not exposed to user for use.
424424

425425
# Default config
426426
config = self.get_combine_config(self.group_size) if config is None else config

0 commit comments

Comments
 (0)