
Commit 6b17f4f

Use CLI args instead of envs (#273)
* use cli arg for num_processes
* update low-latency
* update intranode
* update internode

Signed-off-by: youkaichao <[email protected]>
1 parent 341bb96 commit 6b17f4f
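In short: each test's tunables (process count, tokens, hidden size, top-k, experts) move from EP_TEST_* environment variables to argparse flags with the same per-test defaults, e.g. `python tests/test_intranode.py --num-tokens 2048` instead of `EP_TEST_NUM_TOKENS=2048 python tests/test_intranode.py` (value illustrative).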

File tree

3 files changed: +83 -26 lines


tests/test_internode.py

Lines changed: 35 additions & 10 deletions
@@ -11,13 +11,13 @@
 import test_low_latency
 
 
-def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
+def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args):
     # Settings
-    num_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '4096'))
-    hidden = int(os.environ.get('EP_TEST_HIDDEN', '7168'))
-    num_topk_groups = int(os.environ.get('EP_TEST_NUM_TOPK_GROUPS', str(min(num_nodes, 4))))
-    num_topk = int(os.environ.get('EP_TEST_NUM_TOPK', '8'))
-    num_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', str((256 // num_ranks) * num_ranks)))
+    num_tokens = args.num_tokens
+    hidden = args.hidden
+    num_topk_groups = args.num_topk_groups
+    num_topk = args.num_topk
+    num_experts = args.num_experts
 
     assert num_experts % num_ranks == 0 and num_local_ranks == 8
     if local_rank == 0:
@@ -224,7 +224,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
 
 
 # noinspection PyUnboundLocalVariable
-def test_loop(local_rank: int, num_local_ranks: int):
+def test_loop(local_rank: int, num_local_ranks: int, args):
     num_nodes = int(os.getenv('WORLD_SIZE', 1))
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
     test_ll_compatibility = os.getenv('EP_TEST_LL_COMPATIBILITY', False)
@@ -240,7 +240,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
     torch.manual_seed(rank)
 
     for i in (num_sms, ):
-        test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
+        test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, args)
     if local_rank == 0:
         print('', flush=True)
 
@@ -255,5 +255,30 @@ def test_loop(local_rank: int, num_local_ranks: int):
 
 
 if __name__ == '__main__':
-    num_processes = int(os.getenv('EP_TEST_NUM_PROCESSES', '8'))
-    torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
+    import argparse
+    parser = argparse.ArgumentParser(description='Test internode expert parallel')
+    parser.add_argument('--num-processes', type=int, default=8,
+                        help='Number of processes to spawn (default: 8)')
+    parser.add_argument('--num-tokens', type=int, default=4096,
+                        help='Number of tokens (default: 4096)')
+    parser.add_argument('--hidden', type=int, default=7168,
+                        help='Hidden dimension size (default: 7168)')
+    parser.add_argument('--num-topk-groups', type=int, default=None,
+                        help='Number of top-k groups (default: min(num_nodes, 4))')
+    parser.add_argument('--num-topk', type=int, default=8,
+                        help='Number of top-k experts (default: 8)')
+    parser.add_argument('--num-experts', type=int, default=None,
+                        help='Number of experts (default: calculated as (256 // num_ranks) * num_ranks)')
+    args = parser.parse_args()
+
+    # Set default num_topk_groups if not provided
+    if args.num_topk_groups is None:
+        num_nodes = int(os.getenv('WORLD_SIZE', 1))
+        args.num_topk_groups = min(num_nodes, 4)
+
+    # Set default num_experts if not provided
+    if args.num_experts is None:
+        args.num_experts = (256 // args.num_processes) * args.num_processes
+
+    num_processes = args.num_processes
+    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
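The notable detail in the new __main__ block is the deferred defaults: --num-topk-groups and --num-experts depend on values only known at runtime (WORLD_SIZE and the process count), so they default to None and are filled in after parse_args(). A minimal standalone sketch of the pattern, not from the repository (flag names match the diff, argv values are illustrative):

import argparse
import os

parser = argparse.ArgumentParser(description='deferred-default sketch')
parser.add_argument('--num-processes', type=int, default=8)
parser.add_argument('--num-topk-groups', type=int, default=None)
parser.add_argument('--num-experts', type=int, default=None)

# An explicit argv list stands in for a real command line such as
#   python tests/test_internode.py --num-processes 16
args = parser.parse_args(['--num-processes', '16'])

if args.num_topk_groups is None:
    num_nodes = int(os.getenv('WORLD_SIZE', 1))   # still env-driven, as in the diff
    args.num_topk_groups = min(num_nodes, 4)
if args.num_experts is None:
    # round 256 down to a multiple of the process count
    args.num_experts = (256 // args.num_processes) * args.num_processes

print(args.num_topk_groups, args.num_experts)     # 1 256 with WORLD_SIZE unset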

tests/test_intranode.py

Lines changed: 27 additions & 9 deletions
@@ -11,12 +11,12 @@
 import test_low_latency
 
 
-def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
+def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args):
     # Settings
-    num_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '4096'))
-    hidden = int(os.environ.get('EP_TEST_HIDDEN', '7168'))
-    num_topk = int(os.environ.get('EP_TEST_NUM_TOPK', '8'))
-    num_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', str((256 // num_ranks) * num_ranks)))
+    num_tokens = args.num_tokens
+    hidden = args.hidden
+    num_topk = args.num_topk
+    num_experts = args.num_experts
 
     assert num_experts % num_ranks == 0
     if local_rank == 0:
@@ -230,7 +230,7 @@ def check_data(check_x, rank_prefix_matrix):
 
 
 # noinspection PyUnboundLocalVariable
-def test_loop(local_rank: int, num_local_ranks: int):
+def test_loop(local_rank: int, num_local_ranks: int, args):
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
     test_ll_compatibility, num_rdma_bytes = False, 0
     if test_ll_compatibility:
@@ -242,7 +242,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
     torch.manual_seed(rank)
 
     for i in (24, ):
-        test_main(i, local_rank, num_ranks, rank, buffer, group)
+        test_main(i, local_rank, num_ranks, rank, buffer, group, args)
     if local_rank == 0:
         print('', flush=True)
 
@@ -257,5 +257,23 @@ def test_loop(local_rank: int, num_local_ranks: int):
 
 
 if __name__ == '__main__':
-    num_processes = int(os.getenv('EP_TEST_NUM_PROCESSES', '8'))
-    torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
+    import argparse
+    parser = argparse.ArgumentParser(description='Test intranode expert parallel')
+    parser.add_argument('--num-processes', type=int, default=8,
+                        help='Number of processes to spawn (default: 8)')
+    parser.add_argument('--num-tokens', type=int, default=4096,
+                        help='Number of tokens (default: 4096)')
+    parser.add_argument('--hidden', type=int, default=7168,
+                        help='Hidden dimension size (default: 7168)')
+    parser.add_argument('--num-topk', type=int, default=8,
+                        help='Number of top-k experts (default: 8)')
+    parser.add_argument('--num-experts', type=int, default=None,
+                        help='Number of experts (default: calculated as (256 // num_ranks) * num_ranks)')
+    args = parser.parse_args()
+
+    # Set default num_experts if not provided
+    if args.num_experts is None:
+        args.num_experts = (256 // args.num_processes) * args.num_processes
+
+    num_processes = args.num_processes
+    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
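Appending args to the spawn tuple works because torch.multiprocessing.spawn invokes the target as fn(i, *args), with i the process index; that is why test_loop gains a trailing args parameter in all three tests. A CPU-only sketch of the calling convention (toy function and payload, not the real test):

import torch.multiprocessing as mp

def test_loop(local_rank, num_local_ranks, args):
    # spawn supplies local_rank (the process index); the rest comes from the args tuple
    print(f'rank {local_rank}/{num_local_ranks}, args={args}')

if __name__ == '__main__':
    # mirrors torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
    mp.spawn(test_loop, args=(2, {'num_tokens': 4096}), nprocs=2)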

tests/test_low_latency.py

Lines changed: 21 additions & 7 deletions
@@ -157,12 +157,12 @@ def test_func(zero_copy: bool, return_recv_hook: bool):
 
 
 # noinspection PyUnboundLocalVariable
-def test_loop(local_rank: int, num_local_ranks: int):
+def test_loop(local_rank: int, num_local_ranks: int, args):
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
-    num_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '128'))
-    hidden = int(os.environ.get('EP_TEST_HIDDEN', '7168'))
-    num_topk = int(os.environ.get('EP_TEST_NUM_TOPK', '8'))
-    num_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', '288'))
+    num_tokens = args.num_tokens
+    hidden = args.hidden
+    num_topk = args.num_topk
+    num_experts = args.num_experts
 
     num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
     if local_rank == 0:
@@ -186,5 +186,19 @@ def test_loop(local_rank: int, num_local_ranks: int):
 
 if __name__ == '__main__':
     # TODO: you may modify NUMA binding for less CPU overhead
-    num_processes = int(os.getenv('EP_TEST_NUM_PROCESSES', '8'))
-    torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
+    import argparse
+    parser = argparse.ArgumentParser(description='Test low latency expert parallel')
+    parser.add_argument('--num-processes', type=int, default=8,
+                        help='Number of processes to spawn (default: 8)')
+    parser.add_argument('--num-tokens', type=int, default=128,
+                        help='Number of tokens (default: 128)')
+    parser.add_argument('--hidden', type=int, default=7168,
+                        help='Hidden dimension size (default: 7168)')
+    parser.add_argument('--num-topk', type=int, default=8,
+                        help='Number of top-k experts (default: 8)')
+    parser.add_argument('--num-experts', type=int, default=288,
+                        help='Number of experts (default: 288)')
+    args = parser.parse_args()
+
+    num_processes = args.num_processes
+    torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
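One behavioural consequence worth noting: the old code silently accepted any EP_TEST_* value that int() could parse, whereas argparse rejects malformed values at startup and documents every knob under --help. A small sketch (not from the repository) checking that the new flag defaults line up with the removed env fallbacks:

import argparse
import os

parser = argparse.ArgumentParser(description='default-equivalence sketch')
parser.add_argument('--num-tokens', type=int, default=128)
parser.add_argument('--num-experts', type=int, default=288)
args = parser.parse_args([])                                   # no flags -> defaults

# the removed env-based path, for comparison (assumes EP_TEST_* are unset)
legacy_tokens = int(os.environ.get('EP_TEST_NUM_TOKENS', '128'))
legacy_experts = int(os.environ.get('EP_TEST_NUM_EXPERTS', '288'))
assert (args.num_tokens, args.num_experts) == (legacy_tokens, legacy_experts) == (128, 288)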
