Commit f3df4f1

Author: huzhiyi.hzy (committed)

add arg return_recv_hook for get_dispatch_layout, so the kernel will run on the compute stream in hook mode
1 parent 4ce931d commit f3df4f1

File tree (3 files changed, +20, -13 lines):

csrc/deep_ep.cpp
csrc/deep_ep.hpp
deep_ep/buffer.py


csrc/deep_ep.cpp

Lines changed: 17 additions & 10 deletions

@@ -241,24 +241,31 @@ void Buffer::sync(const std::vector<int> &device_ids,
 
 std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
 Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
-                            std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
+                            std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream, bool return_recv_hook) {
+    if (return_recv_hook) {
+        EP_HOST_ASSERT(not async);
+    }
+
     EP_HOST_ASSERT(topk_idx.dim() == 2);
     EP_HOST_ASSERT(topk_idx.is_contiguous());
     EP_HOST_ASSERT(num_experts > 0);
 
     // Allocate all tensors on comm stream if set
     // NOTES: do not allocate tensors upfront!
     auto compute_stream = at::cuda::getCurrentCUDAStream();
+    auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
     if (allocate_on_comm_stream) {
         EP_HOST_ASSERT(previous_event.has_value() and async);
         at::cuda::setCurrentCUDAStream(comm_stream);
     }
 
     // Wait previous tasks to be finished
-    if (previous_event.has_value()) {
-        stream_wait(comm_stream, previous_event.value());
-    } else {
-        stream_wait(comm_stream, compute_stream);
+    if(not return_recv_hook) {
+        if (previous_event.has_value()) {
+            stream_wait(launch_stream, previous_event.value());
+        } else {
+            stream_wait(launch_stream, compute_stream);
+        }
     }
 
     auto num_tokens = static_cast<int>(topk_idx.size(0)), num_topk = static_cast<int>(topk_idx.size(1));
@@ -275,14 +282,14 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
                                 num_tokens_per_expert.data_ptr<int>(),
                                 is_token_in_rank.data_ptr<bool>(),
                                 num_tokens, num_topk, num_ranks, num_experts,
-                                comm_stream);
+                                launch_stream);
 
     // Wait streams
     std::optional<EventHandle> event;
     if (async) {
-        event = EventHandle(comm_stream);
+        event = EventHandle(launch_stream);
         for (auto& t: {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {
-            t.record_stream(comm_stream);
+            t.record_stream(launch_stream);
             if (allocate_on_comm_stream)
                 t.record_stream(compute_stream);
         }
@@ -291,8 +298,8 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
             if (allocate_on_comm_stream)
                 to.has_value() ? to->record_stream(compute_stream) : void();
         }
-    } else {
-        stream_wait(compute_stream, comm_stream);
+    } else if (not return_recv_hook) {
+        stream_wait(compute_stream, launch_stream);
     }
 
     // Switch back compute stream
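For context, a minimal caller-side sketch of the default (non-hook) path that the change above modifies: the layout kernel is launched on comm_stream, so an asynchronous caller waits on the returned event before using the results. This is not from the commit; the names buffer, topk_idx, and num_experts are illustrative, and it assumes the deep_ep package exports Buffer and that EventOverlap provides current_stream_wait(), as in DeepEP's existing usage examples.

import torch
from deep_ep import Buffer  # assumed export, backed by deep_ep/buffer.py

def layout_on_comm_stream(buffer: Buffer, topk_idx: torch.Tensor, num_experts: int):
    # Default path: the kernel runs on comm_stream (launch_stream == comm_stream),
    # so the caller overlaps it with compute and waits on the returned event.
    num_tokens_per_rank, _, num_tokens_per_expert, is_token_in_rank, event = \
        buffer.get_dispatch_layout(topk_idx, num_experts, async_finish=True)
    # ... independent compute work can be enqueued here to overlap with the comm stream ...
    event.current_stream_wait()  # assumed EventOverlap API: compute stream waits for the comm-stream kernel
    return num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank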

csrc/deep_ep.hpp

Lines changed: 1 addition & 1 deletion

@@ -107,7 +107,7 @@ struct Buffer {
 
     std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
     get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
-                        bool async, bool allocate_on_comm_stream);
+                        bool async, bool allocate_on_comm_stream, bool return_recv_hook);
 
     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
     intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,

deep_ep/buffer.py

Lines changed: 2 additions & 2 deletions

@@ -291,7 +291,7 @@ def get_combine_config(num_ranks: int) -> Config:
     # noinspection PyTypeChecker
     def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
                             previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
-                            allocate_on_comm_stream: bool = False) -> \
+                            allocate_on_comm_stream: bool = False, return_recv_hook: bool = False) -> \
             Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, EventOverlap]:
         """
         Calculate the layout required for later communication.
@@ -314,7 +314,7 @@ def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
         """
         num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
             self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
-                                             async_finish, allocate_on_comm_stream)
+                                             async_finish, allocate_on_comm_stream, return_recv_hook)
         return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)
 
     # noinspection PyTypeChecker
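By contrast, a hedged sketch of the new hook mode added by this commit: with return_recv_hook=True the kernel is launched on the compute stream (launch_stream == compute_stream), both stream_wait calls are skipped, and the added EP_HOST_ASSERT(not async) requires async_finish to stay False. The names buffer, topk_idx, and num_experts are again illustrative, not part of the commit.

import torch
from deep_ep import Buffer  # assumed export, as above

def layout_on_compute_stream(buffer: Buffer, topk_idx: torch.Tensor, num_experts: int):
    # Hook mode: the kernel is launched on the current compute stream, so no
    # comm-stream event exists and no cross-stream synchronization is needed.
    num_tokens_per_rank, _, num_tokens_per_expert, is_token_in_rank, _ = \
        buffer.get_dispatch_layout(topk_idx, num_experts,
                                   async_finish=False,            # required: EP_HOST_ASSERT(not async)
                                   allocate_on_comm_stream=False,
                                   return_recv_hook=True)
    # Later kernels enqueued on the compute stream are ordered after the layout kernel.
    return num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank

With the default return_recv_hook=False, the existing comm-stream behavior is unchanged.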
