@@ -306,7 +306,7 @@ def large_gemm_with_hook(hook):
     # noinspection PyShadowingNames
     def test_dispatch_hook(x, config, handle, return_recv_hook):
         _, _, _, _, _, _, hook = \
-            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook, num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank)
+            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
         torch.cuda.synchronize()
 
@@ -318,7 +318,7 @@ def test_combine_hook(x, config, handle, return_recv_hook):
 
     def test_dispatch_combine_hook(x, config, handle, return_recv_hook):
         recv_x, _, _, _, _, _, hook = \
-            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook, num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank)
+            buffer.dispatch(x=x, config=config, handle=handle, async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
 
         recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x
@@ -471,7 +471,7 @@ def test_func_native(x, config, handle):
     # Tune combine performance
     best_time, best_results = 1e10, None
     for nvl_chunk_size in range(1, 13, 1):
-        for rdma_chunk_size in range(8, 33, 4):
+        for rdma_chunk_size in range(12, 33, 4):
             config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size)
             tune_args = {'x': recv_x, 'handle': handle_native, 'config': config}
             avg_t = bench(lambda: buffer.combine(**tune_args))[0]
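
For context, the hook tests above exercise the receive-hook overlap pattern: with return_recv_hook=True, dispatch returns a hook as its last value instead of blocking on the receive, so independent compute can be launched before the hook is invoked. The sketch below is a minimal illustration of that pattern under the signature shown in the diff; the dispatch_with_overlap helper name, the GEMM shapes, and the call ordering are illustrative assumptions, not part of the test file.

import torch

def dispatch_with_overlap(buffer, x, config, handle):
    # Ask dispatch to return a receive hook rather than waiting for the data
    # (signature as used in the diff above; assumed, not verified here).
    *_, hook = buffer.dispatch(x=x, config=config, handle=handle,
                               async_finish=False, return_recv_hook=True)

    # Launch independent work (a large GEMM, mirroring large_gemm_with_hook)
    # while the dispatched tokens are still in flight.
    mat_0 = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
    mat_1 = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
    _ = mat_0 @ mat_1

    # Only now invoke the hook to complete the receive, then synchronize.
    hook()
    torch.cuda.synchronize()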