1+ import  argparse 
12import  os 
23import  time 
34import  torch 
1112import  test_low_latency 
1213
1314
14- def  test_main (num_sms : int , local_rank : int , num_local_ranks : int , num_ranks : int , num_nodes : int , rank : int , buffer : deep_ep .Buffer , group : dist .ProcessGroup , args ):
15+ # noinspection PyShadowingNames 
16+ def  test_main (args : argparse .Namespace , num_sms : int ,
17+               local_rank : int , num_local_ranks : int , num_ranks : int , num_nodes : int , rank : int ,
18+               buffer : deep_ep .Buffer , group : dist .ProcessGroup ):
1519    # Settings 
16-     num_tokens  =  args .num_tokens 
17-     hidden  =  args .hidden 
18-     num_topk_groups  =  args .num_topk_groups 
19-     num_topk  =  args .num_topk 
20-     num_experts  =  args .num_experts 
20+     num_tokens , hidden  =  args .num_tokens , args .hidden 
21+     num_topk_groups , num_topk , num_experts  =  args .num_topk_groups , args .num_topk , args .num_experts 
2122
2223    assert  num_experts  %  num_ranks  ==  0  and  num_local_ranks  ==  8 
2324    if  local_rank  ==  0 :
@@ -223,29 +224,28 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
223224        print ('' , flush = True )
224225
225226
226- # noinspection PyUnboundLocalVariable 
227- def  test_loop (local_rank : int , num_local_ranks : int , args ):
227+ # noinspection PyUnboundLocalVariable,PyShadowingNames  
228+ def  test_loop (local_rank : int , num_local_ranks : int , args :  argparse . Namespace ):
228229    num_nodes  =  int (os .getenv ('WORLD_SIZE' , 1 ))
229230    rank , num_ranks , group  =  init_dist (local_rank , num_local_ranks )
230-     test_ll_compatibility  =  os .getenv ('EP_TEST_LL_COMPATIBILITY' , False )
231-     if  test_ll_compatibility :
231+     if  args .test_ll_compatibility :
232232        ll_num_tokens , ll_hidden , ll_num_experts , ll_num_topk  =  16 , 5120 , 256 , 9 
233233
234234    num_sms  =  24 
235-     num_qps_per_rank  =  max (num_sms , ll_num_experts  //  num_ranks  if  test_ll_compatibility  else  0 )
235+     num_qps_per_rank  =  max (num_sms , ll_num_experts  //  num_ranks  if  args . test_ll_compatibility  else  0 )
236236
237-     buffer  =  deep_ep .Buffer (group , int (1e9 ), int (1e9 ), low_latency_mode = test_ll_compatibility ,
237+     buffer  =  deep_ep .Buffer (group , int (1e9 ), int (1e9 ), low_latency_mode = args . test_ll_compatibility ,
238238                            num_qps_per_rank = num_qps_per_rank )
239239    assert  num_local_ranks  ==  8  and  num_ranks  >  8 
240240    torch .manual_seed (rank )
241241
242242    for  i  in  (num_sms , ):
243-         test_main (i , local_rank , num_local_ranks , num_ranks , num_nodes , rank , buffer , group ,  args )
243+         test_main (args ,  i , local_rank , num_local_ranks , num_ranks , num_nodes , rank , buffer , group )
244244        if  local_rank  ==  0 :
245245            print ('' , flush = True )
246246
247247    # Test compatibility with low latency functions 
248-     if  test_ll_compatibility :
248+     if  args . test_ll_compatibility :
249249        buffer .clean_low_latency_buffer (ll_num_tokens , ll_hidden , ll_num_experts )
250250        test_low_latency .test_main (ll_num_tokens , ll_hidden , ll_num_experts , ll_num_topk , rank , num_ranks , group , buffer , seed = 1 )
251251
@@ -255,30 +255,27 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
255255
256256
257257if  __name__  ==  '__main__' :
258-     import  argparse 
259-     parser  =  argparse .ArgumentParser (description = 'Test internode expert parallel' )
258+     parser  =  argparse .ArgumentParser (description = 'Test internode EP kernels' )
260259    parser .add_argument ('--num-processes' , type = int , default = 8 ,
261260                       help = 'Number of processes to spawn (default: 8)' )
262261    parser .add_argument ('--num-tokens' , type = int , default = 4096 ,
263262                       help = 'Number of tokens (default: 4096)' )
264263    parser .add_argument ('--hidden' , type = int , default = 7168 ,
265264                       help = 'Hidden dimension size (default: 7168)' )
266265    parser .add_argument ('--num-topk-groups' , type = int , default = None ,
267-                        help = 'Number of top-k groups (default: min(num_nodes, 4))' )
266+                        help = 'Number of top-k groups (default: `min(num_nodes, 4)`)' )
268267    parser .add_argument ('--num-topk' , type = int , default = 8 ,
269268                       help = 'Number of top-k experts (default: 8)' )
270-     parser .add_argument ('--num-experts' , type = int , default = None ,
271-                        help = 'Number of experts (default: calculated as (256 // num_ranks) * num_ranks)' )
269+     parser .add_argument ('--num-experts' , type = int , default = 256 ,
270+                        help = 'Number of experts (default: 256)' )
271+     parser .add_argument ('--test-ll-compatibility' , action = 'store_true' ,
272+                         help = 'whether to test compatibility with low-latency kernels' )
272273    args  =  parser .parse_args ()
273274
274-     # Set default num_topk_groups if not provided 
275+     # Set default `num_topk_groups` if not provided 
275276    if  args .num_topk_groups  is  None :
276277        num_nodes  =  int (os .getenv ('WORLD_SIZE' , 1 ))
277278        args .num_topk_groups  =  min (num_nodes , 4 )
278279
279-     # Set default num_experts if not provided 
280-     if  args .num_experts  is  None :
281-         args .num_experts  =  (256  //  args .num_processes ) *  args .num_processes 
282- 
283280    num_processes  =  args .num_processes 
284281    torch .multiprocessing .spawn (test_loop , args = (num_processes , args ), nprocs = num_processes )
0 commit comments