Commit 7705f53

Refactor testing arguments
1 parent 6b17f4f commit 7705f53

3 files changed: +40 -61 lines changed

tests/test_internode.py

Lines changed: 21 additions & 24 deletions
@@ -1,3 +1,4 @@
+import argparse
 import os
 import time
 import torch
@@ -11,13 +12,13 @@
 import test_low_latency


-def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args):
+# noinspection PyShadowingNames
+def test_main(args: argparse.Namespace, num_sms: int,
+              local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int,
+              buffer: deep_ep.Buffer, group: dist.ProcessGroup):
     # Settings
-    num_tokens = args.num_tokens
-    hidden = args.hidden
-    num_topk_groups = args.num_topk_groups
-    num_topk = args.num_topk
-    num_experts = args.num_experts
+    num_tokens, hidden = args.num_tokens, args.hidden
+    num_topk_groups, num_topk, num_experts = args.num_topk_groups, args.num_topk, args.num_experts

     assert num_experts % num_ranks == 0 and num_local_ranks == 8
     if local_rank == 0:
@@ -223,29 +224,28 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
         print('', flush=True)


-# noinspection PyUnboundLocalVariable
-def test_loop(local_rank: int, num_local_ranks: int, args):
+# noinspection PyUnboundLocalVariable,PyShadowingNames
+def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     num_nodes = int(os.getenv('WORLD_SIZE', 1))
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
-    test_ll_compatibility = os.getenv('EP_TEST_LL_COMPATIBILITY', False)
-    if test_ll_compatibility:
+    if args.test_ll_compatibility:
         ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9

     num_sms = 24
-    num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if test_ll_compatibility else 0)
+    num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)

-    buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
+    buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
                             num_qps_per_rank=num_qps_per_rank)
     assert num_local_ranks == 8 and num_ranks > 8
     torch.manual_seed(rank)

     for i in (num_sms, ):
-        test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, args)
+        test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
         if local_rank == 0:
             print('', flush=True)

     # Test compatibility with low latency functions
-    if test_ll_compatibility:
+    if args.test_ll_compatibility:
         buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
         test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)

@@ -255,30 +255,27 @@ def test_loop(local_rank: int, num_local_ranks: int, args):


 if __name__ == '__main__':
-    import argparse
-    parser = argparse.ArgumentParser(description='Test internode expert parallel')
+    parser = argparse.ArgumentParser(description='Test internode EP kernels')
     parser.add_argument('--num-processes', type=int, default=8,
                         help='Number of processes to spawn (default: 8)')
     parser.add_argument('--num-tokens', type=int, default=4096,
                         help='Number of tokens (default: 4096)')
     parser.add_argument('--hidden', type=int, default=7168,
                         help='Hidden dimension size (default: 7168)')
     parser.add_argument('--num-topk-groups', type=int, default=None,
-                        help='Number of top-k groups (default: min(num_nodes, 4))')
+                        help='Number of top-k groups (default: `min(num_nodes, 4)`)')
     parser.add_argument('--num-topk', type=int, default=8,
                         help='Number of top-k experts (default: 8)')
-    parser.add_argument('--num-experts', type=int, default=None,
-                        help='Number of experts (default: calculated as (256 // num_ranks) * num_ranks)')
+    parser.add_argument('--num-experts', type=int, default=256,
+                        help='Number of experts (default: 256)')
+    parser.add_argument('--test-ll-compatibility', action='store_true',
+                        help='Whether to test compatibility with low-latency kernels')
     args = parser.parse_args()

-    # Set default num_topk_groups if not provided
+    # Set default `num_topk_groups` if not provided
     if args.num_topk_groups is None:
         num_nodes = int(os.getenv('WORLD_SIZE', 1))
         args.num_topk_groups = min(num_nodes, 4)

-    # Set default num_experts if not provided
-    if args.num_experts is None:
-        args.num_experts = (256 // args.num_processes) * args.num_processes
-
     num_processes = args.num_processes
     torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
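Note on the env-var removal above: the low-latency compatibility check moves from the `EP_TEST_LL_COMPATIBILITY` environment variable to a proper `--test-ll-compatibility` flag. A minimal sketch of the behavioral difference, assuming the script is launched directly (only the variable and flag names come from the diff; the rest is illustrative): `os.getenv(..., False)` returns a string whenever the variable is set, so even `EP_TEST_LL_COMPATIBILITY=0` was truthy, while `action='store_true'` yields a real boolean.

import argparse
import os

# Old gate (removed): any non-empty value, including "0", enables the test path.
old_flag = os.getenv('EP_TEST_LL_COMPATIBILITY', False)
print(bool(old_flag))  # True whenever the variable is set to any string, even "0"

# New gate: a boolean controlled by `--test-ll-compatibility`.
parser = argparse.ArgumentParser(description='Flag sketch')
parser.add_argument('--test-ll-compatibility', action='store_true',
                    help='Whether to test compatibility with low-latency kernels')
args = parser.parse_args()
print(args.test_ll_compatibility)  # False unless the flag is passed explicitly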

tests/test_intranode.py

Lines changed: 12 additions & 17 deletions
@@ -1,4 +1,4 @@
-import os
+import argparse
 import time
 import torch
 import torch.distributed as dist
@@ -11,12 +11,12 @@
 import test_low_latency


-def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup, args):
+# noinspection PyShadowingNames
+def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks: int, rank: int,
+              buffer: deep_ep.Buffer, group: dist.ProcessGroup):
     # Settings
-    num_tokens = args.num_tokens
-    hidden = args.hidden
-    num_topk = args.num_topk
-    num_experts = args.num_experts
+    num_tokens, hidden = args.num_tokens, args.hidden
+    num_topk, num_experts = args.num_topk, args.num_experts

     assert num_experts % num_ranks == 0
     if local_rank == 0:
@@ -229,8 +229,8 @@ def check_data(check_x, rank_prefix_matrix):
         print('', flush=True)


-# noinspection PyUnboundLocalVariable
-def test_loop(local_rank: int, num_local_ranks: int, args):
+# noinspection PyUnboundLocalVariable,PyShadowingNames
+def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
     test_ll_compatibility, num_rdma_bytes = False, 0
     if test_ll_compatibility:
@@ -242,7 +242,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
     torch.manual_seed(rank)

     for i in (24, ):
-        test_main(i, local_rank, num_ranks, rank, buffer, group, args)
+        test_main(args, i, local_rank, num_ranks, rank, buffer, group)
         if local_rank == 0:
             print('', flush=True)

@@ -257,8 +257,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):


 if __name__ == '__main__':
-    import argparse
-    parser = argparse.ArgumentParser(description='Test intranode expert parallel')
+    parser = argparse.ArgumentParser(description='Test intranode EP kernels')
     parser.add_argument('--num-processes', type=int, default=8,
                         help='Number of processes to spawn (default: 8)')
     parser.add_argument('--num-tokens', type=int, default=4096,
@@ -267,13 +266,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args):
                         help='Hidden dimension size (default: 7168)')
     parser.add_argument('--num-topk', type=int, default=8,
                         help='Number of top-k experts (default: 8)')
-    parser.add_argument('--num-experts', type=int, default=None,
-                        help='Number of experts (default: calculated as (256 // num_ranks) * num_ranks)')
+    parser.add_argument('--num-experts', type=int, default=256,
+                        help='Number of experts (default: 256)')
     args = parser.parse_args()

-    # Set default num_experts if not provided
-    if args.num_experts is None:
-        args.num_experts = (256 // args.num_processes) * args.num_processes
-
     num_processes = args.num_processes
     torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
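Note on the `--num-experts` change above: the default moves from the computed `(256 // args.num_processes) * args.num_processes` to a fixed 256. A small sketch (illustrative values only) of why the two agree whenever the rank count divides 256, which is exactly the `assert num_experts % num_ranks == 0` precondition in `test_main`:

# For any rank count that divides 256 (e.g. the default 8 processes),
# the old computed default and the new fixed default coincide.
for num_ranks in (2, 4, 8, 16, 32):
    old_default = (256 // num_ranks) * num_ranks
    new_default = 256
    assert old_default == new_default
    assert new_default % num_ranks == 0  # the precondition test_main asserts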

tests/test_low_latency.py

Lines changed: 7 additions & 20 deletions
@@ -1,4 +1,4 @@
-import os
+import argparse
 import random
 import torch
 import torch.distributed as dist
@@ -16,7 +16,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
     assert num_experts % num_ranks == 0
     num_local_experts = num_experts // num_ranks

-    # NOTES: the integers greater than 256 exceeds the BF16 precision limit
+    # NOTES: the integers greater than 256 exceed the BF16 precision limit
     rank_offset = 128
     assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'

@@ -98,16 +98,6 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         assert diff < (7e-4 if round_scale else 1e-5), f'Error: {diff=}, {zero_copy=}'
         hash_value ^= hash_tensor(combined_x)

-    def create_test_cast_with_outliers(num_outliers):
-        tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
-        tmp /= tmp.abs().amax(dim=1).view(-1, 1)
-        assert tmp.abs().amax().item() <= 1
-
-        # Create some amax outliers
-        for i in range(num_outliers):
-            tmp[random.randint(0, num_tokens - 1)] *= 1e3
-        return tmp
-
     # noinspection PyShadowingNames
     def large_gemm_with_hook(hook):
         mat_0 = torch.randn((8192, 8192), dtype=torch.float)
@@ -156,13 +146,11 @@ def test_func(zero_copy: bool, return_recv_hook: bool):
     return hash_value


-# noinspection PyUnboundLocalVariable
-def test_loop(local_rank: int, num_local_ranks: int, args):
+# noinspection PyUnboundLocalVariable,PyShadowingNames
+def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
-    num_tokens = args.num_tokens
-    hidden = args.hidden
-    num_topk = args.num_topk
-    num_experts = args.num_experts
+    num_tokens, hidden = args.num_tokens, args.hidden
+    num_topk, num_experts = args.num_topk, args.num_experts

     num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
     if local_rank == 0:
@@ -186,8 +174,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args):

 if __name__ == '__main__':
     # TODO: you may modify NUMA binding for less CPU overhead
-    import argparse
-    parser = argparse.ArgumentParser(description='Test low latency expert parallel')
+    parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
     parser.add_argument('--num-processes', type=int, default=8,
                         help='Number of processes to spawn (default: 8)')
     parser.add_argument('--num-tokens', type=int, default=128,
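All three tests keep the same launcher line, `torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)`. A minimal sketch of how the refactored arguments reach `test_loop` (the worker body below is a stand-in, not the real test): `spawn` invokes the target as `fn(i, *args)`, so each worker gets its process index first, then `num_processes`, then the parsed `argparse.Namespace`.

import argparse
import torch.multiprocessing as mp


def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    # Stand-in worker: spawn passes the process index as the first argument,
    # followed by everything supplied in `args=(...)`.
    print(f'rank {local_rank}/{num_local_ranks}, num_tokens={args.num_tokens}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Spawn wiring sketch')
    parser.add_argument('--num-processes', type=int, default=8)
    parser.add_argument('--num-tokens', type=int, default=128)
    args = parser.parse_args()

    num_processes = args.num_processes
    mp.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)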
