@@ -241,24 +241,31 @@ void Buffer::sync(const std::vector<int> &device_ids,
241241
242242std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
243243Buffer::get_dispatch_layout (const torch::Tensor& topk_idx, int num_experts,
244- std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
244+ std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream, bool return_recv_hook) {
245+ if (return_recv_hook) {
246+ EP_HOST_ASSERT (not async);
247+ }
248+
245249 EP_HOST_ASSERT (topk_idx.dim () == 2 );
246250 EP_HOST_ASSERT (topk_idx.is_contiguous ());
247251 EP_HOST_ASSERT (num_experts > 0 );
248252
249253 // Allocate all tensors on comm stream if set
250254 // NOTES: do not allocate tensors upfront!
251255 auto compute_stream = at::cuda::getCurrentCUDAStream ();
256+ auto launch_stream = return_recv_hook ? compute_stream : comm_stream;
252257 if (allocate_on_comm_stream) {
253258 EP_HOST_ASSERT (previous_event.has_value () and async);
254259 at::cuda::setCurrentCUDAStream (comm_stream);
255260 }
256261
257262 // Wait previous tasks to be finished
258- if (previous_event.has_value ()) {
259- stream_wait (comm_stream, previous_event.value ());
260- } else {
261- stream_wait (comm_stream, compute_stream);
263+ if (not return_recv_hook) {
264+ if (previous_event.has_value ()) {
265+ stream_wait (launch_stream, previous_event.value ());
266+ } else {
267+ stream_wait (launch_stream, compute_stream);
268+ }
262269 }
263270
264271 auto num_tokens = static_cast <int >(topk_idx.size (0 )), num_topk = static_cast <int >(topk_idx.size (1 ));
@@ -275,14 +282,14 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
275282 num_tokens_per_expert.data_ptr <int >(),
276283 is_token_in_rank.data_ptr <bool >(),
277284 num_tokens, num_topk, num_ranks, num_experts,
278- comm_stream );
285+ launch_stream );
279286
280287 // Wait streams
281288 std::optional<EventHandle> event;
282289 if (async) {
283- event = EventHandle (comm_stream );
290+ event = EventHandle (launch_stream );
284291 for (auto & t: {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {
285- t.record_stream (comm_stream );
292+ t.record_stream (launch_stream );
286293 if (allocate_on_comm_stream)
287294 t.record_stream (compute_stream);
288295 }
@@ -291,8 +298,8 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts,
291298 if (allocate_on_comm_stream)
292299 to.has_value () ? to->record_stream (compute_stream) : void ();
293300 }
294- } else {
295- stream_wait (compute_stream, comm_stream );
301+ } else if ( not return_recv_hook) {
302+ stream_wait (compute_stream, launch_stream );
296303 }
297304
298305 // Switch back compute stream
0 commit comments