 import test_low_latency


-def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
+def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args):
     # Settings
-    num_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '4096'))
-    hidden = int(os.environ.get('EP_TEST_HIDDEN', '7168'))
-    num_topk_groups = int(os.environ.get('EP_TEST_NUM_TOPK_GROUPS', str(min(num_nodes, 4))))
-    num_topk = int(os.environ.get('EP_TEST_NUM_TOPK', '8'))
-    num_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', str((256 // num_ranks) * num_ranks)))
+    num_tokens = args.num_tokens
+    hidden = args.hidden
+    num_topk_groups = args.num_topk_groups
+    num_topk = args.num_topk
+    num_experts = args.num_experts

     assert num_experts % num_ranks == 0 and num_local_ranks == 8
     if local_rank == 0:
@@ -224,7 +224,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):


 # noinspection PyUnboundLocalVariable
-def test_loop(local_rank: int, num_local_ranks: int):
+def test_loop(local_rank: int, num_local_ranks: int, args):
     num_nodes = int(os.getenv('WORLD_SIZE', 1))
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
     test_ll_compatibility = os.getenv('EP_TEST_LL_COMPATIBILITY', False)
@@ -240,7 +240,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
     torch.manual_seed(rank)

     for i in (num_sms, ):
-        test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
+        test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, args)
         if local_rank == 0:
             print('', flush=True)

@@ -255,5 +255,30 @@ def test_loop(local_rank: int, num_local_ranks: int):


 if __name__ == '__main__':
-    num_processes = int(os.getenv('EP_TEST_NUM_PROCESSES', '8'))
-    torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
+    import argparse
+    parser = argparse.ArgumentParser(description='Test internode expert parallelism')
+    parser.add_argument('--num-processes', type=int, default=8,
+                        help='Number of processes to spawn per node (default: 8)')
+    parser.add_argument('--num-tokens', type=int, default=4096,
+                        help='Number of tokens (default: 4096)')
+    parser.add_argument('--hidden', type=int, default=7168,
+                        help='Hidden dimension size (default: 7168)')
+    parser.add_argument('--num-topk-groups', type=int, default=None,
+                        help='Number of top-k groups (default: min(num_nodes, 4))')
+    parser.add_argument('--num-topk', type=int, default=8,
+                        help='Number of top-k experts (default: 8)')
+    parser.add_argument('--num-experts', type=int, default=None,
+                        help='Number of experts (default: (256 // num_ranks) * num_ranks)')
+    args = parser.parse_args()
+
+    # Derive defaults that depend on the node and process count
+    num_nodes = int(os.getenv('WORLD_SIZE', 1))
+    if args.num_topk_groups is None:
+        args.num_topk_groups = min(num_nodes, 4)
+    # Keep the old default of (256 // num_ranks) * num_ranks, with num_ranks = num_nodes * num_processes
+    if args.num_experts is None:
+        num_ranks = num_nodes * args.num_processes
+        args.num_experts = (256 // num_ranks) * num_ranks
+
+    num_processes = args.num_processes
+    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
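
For reference, a hypothetical launch using the new flags instead of the removed EP_TEST_* environment variables; the script path below is assumed (it is not shown in this diff), and any flag left unset falls back to the same default the old environment variable provided:

    # script path assumed, not shown in the diff
    WORLD_SIZE=1 python tests/test_internode.py --num-processes 8 --num-tokens 4096 --hidden 7168 --num-topk 8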