
Commit e02e4d2

Support EP24 for internode kernels (#432)
* Support EP24 for internode kernels.
* Skip check for round_scale test
1 parent d981409 commit e02e4d2

File tree: 4 files changed, +5 -3 lines


csrc/kernels/launch.cuh

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ cfg.dynamicSmemBytes = smem_size;
 #define SWITCH_RDMA_RANKS(case_macro) \
     switch (num_ranks / NUM_MAX_NVL_PEERS) { \
         case 2: case_macro(2); \
+        case 3: case_macro(3); \
         case 4: case_macro(4); \
         case 6: case_macro(6); \
         case 8: case_macro(8); \
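The new `case 3` branch is what covers EP24: with 8 NVLink peers per node (the divisor `NUM_MAX_NVL_PEERS` in the macro), a 24-rank job spans 24 / 8 = 3 RDMA ranks, which previously had no branch. A minimal sketch of that arithmetic (not part of the diff; the helper below is illustrative only):

NUM_MAX_NVL_PEERS = 8  # NVLink peers per node, as assumed by the divisor in the macro above

def num_rdma_ranks(num_ranks: int) -> int:
    # Mirrors `num_ranks / NUM_MAX_NVL_PEERS` in SWITCH_RDMA_RANKS
    return num_ranks // NUM_MAX_NVL_PEERS

# EP24 now maps onto the new `case 3` branch
assert num_rdma_ranks(24) == 3
# Previously supported sizes keep their existing branches
assert [num_rdma_ranks(n) for n in (16, 32, 48, 64)] == [2, 4, 6, 8]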

deep_ep/buffer.py

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ def get_dispatch_config(num_ranks: int) -> Config:
         4: Config(Buffer.num_sms, 6, 256, 6, 128),
         8: Config(Buffer.num_sms, 6, 256, 6, 128),
         16: Config(Buffer.num_sms, 36, 288, 20, 128),
-        24: Config(Buffer.num_sms, 8, 288, 32, 128),
+        24: Config(Buffer.num_sms, 32, 288, 8, 128),
         32: Config(Buffer.num_sms, 32, 288, 8, 128),
         48: Config(Buffer.num_sms, 32, 288, 8, 128),
         64: Config(Buffer.num_sms, 32, 288, 8, 128),
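After this change the 24-rank dispatch config uses the same chunking values as the 32/48/64-rank entries. A minimal sketch of how a caller picks it up, assuming the deep_ep package is installed and `get_dispatch_config` (the function shown in the hunk above) is reachable as a static method on `Buffer`:

from deep_ep import Buffer

# Look up the tuned dispatch config for a 24-rank (3-node) EP group
config_24 = Buffer.get_dispatch_config(24)
# The 32-rank entry shares the same chunking parameters after this commit
config_32 = Buffer.get_dispatch_config(32)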

tests/test_internode.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def test_main(args: argparse.Namespace, num_sms: int,
     time.sleep(1)
 
     # Config
-    rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (48, 96, 144, 160) else 512)
+    rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (24, 48, 96, 144, 160) else 512)
     config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
 
     # Test dispatch
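The internode test now gives 24-rank runs the larger NVL buffer setting (720 instead of 512), in line with the other multi-node sizes. A standalone restatement of that selection logic (the `pick_buffer_sizes` helper is hypothetical, extracted from the test for illustration):

def pick_buffer_sizes(num_ranks: int) -> tuple:
    # Same expression as in the test: 24 now joins the sizes that need the larger NVL buffer
    rdma_buffer_size = 128
    nvl_buffer_size = 720 if num_ranks in (24, 48, 96, 144, 160) else 512
    return rdma_buffer_size, nvl_buffer_size

assert pick_buffer_sizes(24) == (128, 720)   # EP24 after this commit
assert pick_buffer_sizes(16) == (128, 512)   # unchanged for other sizes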

tests/test_low_latency.py

Lines changed: 2 additions & 1 deletion
@@ -158,7 +158,8 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
     topk_idx[failed_topk_idx] = -1
     diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
     assert torch.isnan(combined_x).sum().item() == 0
-    assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}'
+    if not round_scale:
+        assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}'
     hash_value ^= hash_tensor(combined_x)
 
     # Clean buffer API
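The accuracy threshold is now skipped when `round_scale` is set; the commit message does not state the reason, but presumably rounded FP8 scales push the combine error past the fixed tolerances, so only the NaN check and the hash comparison remain in that mode. A standalone restatement of the guarded check (the `check_combine_diff` helper is hypothetical):

def check_combine_diff(diff: float, dispatch_use_fp8: bool, round_scale: bool) -> None:
    # Mirrors the test: no fixed tolerance is enforced when scales are rounded
    if round_scale:
        return
    tolerance = 9e-4 if dispatch_use_fp8 else 1e-5
    assert diff < tolerance, f'Error: {diff=}, {dispatch_use_fp8=}'

check_combine_diff(5e-4, dispatch_use_fp8=True, round_scale=False)   # within tolerance, passes
check_combine_diff(2e-3, dispatch_use_fp8=True, round_scale=True)    # check skipped, no assertion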
