Draft pull request: changes from all commits (236 commits).
Commit history: 236 commits by fzyzcjy, Jun 20 to Jun 26, 2025. Most are incremental "more" checkpoints; the distinctively named commits are:

- 681bdc5  cherry pick (Jun 20)
- 9a8d98a  Merge branch 'feat/test_detailed_time' into feat/dev_20250621 (Jun 21)
- 7421672  Merge branch 'feat/num_processes' into feat/dev_20250621 (Jun 21)
- 1610620  hook+async (Jun 21)
- 7d658ca  temp revert mv-finishing-flag and reorder-loops (Jun 21)
- 34d3dcb  Revert "temp revert mv-finishing-flag and reorder-loops" (Jun 21)
- fbb2470  revert only reorder-loops (Jun 21)
- 18e6a66  Revert "revert only reorder-loops" (Jun 21)
- 0829e23  prefetch layout (Jun 21)
- d6d25bf  minor (Jun 21)
- 81c554f  Revert "minor" (Jun 21)
- c367a11  Revert "prefetch layout" (Jun 21)
- fd0d6af  again "revert only reorder-loops" (Jun 21)
- a1ba116  Revert "again "revert only reorder-loops"" (Jun 21)
- ef164be  Revert "more" (Jun 22)
- 02fc99b  rm support of hook+async (Jun 22)
- 1d8aa74  dispatch 12 group (Jun 25)
- bff1a00  extract (Jun 26)
13 changes: 11 additions & 2 deletions csrc/deep_ep.cpp
@@ -1159,7 +1159,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
 const torch::Tensor& src_info, const torch::Tensor& layout_range,
 int num_max_dispatch_tokens_per_rank, int num_experts,
 bool zero_copy, bool async, bool return_recv_hook,
-const std::optional<torch::Tensor>& out) {
+const std::optional<torch::Tensor>& out,
+const std::optional<torch::Tensor>& src_signals, uint32_t src_signal_expect_value) {
 #ifndef DISABLE_NVSHMEM
 EP_HOST_ASSERT(low_latency_mode);

@@ -1220,7 +1221,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
 num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
 num_topk, num_experts, rank, num_ranks,
 workspace, num_device_sms,
-launch_stream, phases, zero_copy);
+launch_stream, phases, zero_copy,
+src_signals.has_value() ? src_signals->data_ptr<uint32_t>() : nullptr, src_signal_expect_value);
 };
 launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));

@@ -1247,6 +1249,12 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
 #endif
 }

+void Buffer::notify_src_signals(const torch::Tensor& src_signals, int index) {
+const uint32_t* addr = src_signals.data_ptr<uint32_t>() + index;
+// TODO comm stream or current stream or whatever stream?
+CU_CHECK(cuStreamWriteValue32(at::cuda::getCurrentCUDAStream(), (CUdeviceptr) addr, 1, 0));
+}
+
 torch::Tensor
 Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const {
 #ifndef DISABLE_NVSHMEM

@@ -1312,6 +1320,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer)
 .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch)
 .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine)
+.def("notify_src_signals", &deep_ep::Buffer::notify_src_signals)
 .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer);

 m.def("is_sm90_compiled", deep_ep::is_sm90_compiled);
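The new `notify_src_signals` binding enqueues a 32-bit stream memory write (`cuStreamWriteValue32`) of the value 1 into `src_signals[index]` on the current stream. Unlike an elementwise fill kernel, a stream memory operation is executed by the GPU front-end and needs no free SM, which presumably is the point: it cannot be starved by a persistent combine kernel already occupying the device. For intuition only, a rough PyTorch-level equivalent (a hypothetical sketch, not part of this PR):

```python
import torch

def notify_src_signals_sketch(src_signals: torch.Tensor, index: int) -> None:
    # Rough functional equivalent of Buffer::notify_src_signals: enqueue a
    # write of 1 into src_signals[index] on the current CUDA stream. The real
    # binding issues cuStreamWriteValue32(stream, &src_signals[index], 1, 0)
    # instead of launching a small kernel like this assignment does.
    src_signals[index] = 1
```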
5 changes: 4 additions & 1 deletion csrc/deep_ep.hpp
@@ -144,7 +144,10 @@ struct Buffer {
 const torch::Tensor& src_info, const torch::Tensor& layout_range,
 int num_max_dispatch_tokens_per_rank, int num_experts,
 bool zero_copy, bool async, bool return_recv_hook,
-const std::optional<torch::Tensor>& out = std::nullopt);
+const std::optional<torch::Tensor>& out = std::nullopt,
+const std::optional<torch::Tensor>& src_signals = std::nullopt, uint32_t src_signal_expect_value = 0);
+
+void notify_src_signals(const torch::Tensor& src_signals, int index);

 torch::Tensor
 get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
3 changes: 2 additions & 1 deletion csrc/kernels/api.cuh
@@ -158,7 +158,8 @@ void combine(void* combined_x,
 int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
 int num_topk, int num_experts, int rank, int num_ranks,
 void* workspace, int num_device_sms,
-cudaStream_t stream, int phases, bool zero_copy);
+cudaStream_t stream, int phases, bool zero_copy,
+uint32_t* src_signals, uint32_t src_signal_expect_value);

 } // namespace internode_ll
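Note that the kernel API takes a raw `uint32_t*` here: the host wrapper above passes `nullptr` when the optional `src_signals` tensor is absent, so existing callers presumably fall through to the old, ungated combine path.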
12 changes: 12 additions & 0 deletions csrc/kernels/exception.cuh
@@ -31,6 +31,18 @@ do { \
 } while (0)
 #endif

+#ifndef CU_CHECK
+#define CU_CHECK(cmd) \
+do { \
+CUresult e = (cmd); \
+if (e != CUDA_SUCCESS) { \
+const char *error_str = NULL; \
+cuGetErrorString(e, &error_str); \
+throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \
+} \
+} while (0)
+#endif
+
 #ifndef EP_HOST_ASSERT
 #define EP_HOST_ASSERT(cond) \
 do { \
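`CU_CHECK` follows the same pattern as the existing check macros in this header, but for driver-API calls returning `CUresult`: failures are translated through `cuGetErrorString` into an `EPException`. It guards the new `cuStreamWriteValue32` call in `deep_ep.cpp` above.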
678 changes: 413 additions & 265 deletions csrc/kernels/internode_ll.cu

Large diffs are not rendered by default.
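Judging from the host-side plumbing, this unrendered kernel diff presumably carries the device-side half of the feature: the combine send phase spin-waiting on each `src_signals` slot until it reaches `src_signal_expect_value` before that expert's tokens are sent.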

5 changes: 3 additions & 2 deletions deep_ep/buffer.py
@@ -539,7 +539,8 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
 # noinspection PyTypeChecker
 def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
 handle: tuple, zero_copy: bool = False, async_finish: bool = False,
-return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
+return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
+src_signals: Optional[torch.Tensor] = None, src_signal_expect_value: int = 0) -> \
 Tuple[torch.Tensor, EventOverlap, Callable]:
 """
 A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.

@@ -573,7 +574,7 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
 src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
 combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
 num_max_dispatch_tokens_per_rank, num_experts,
-zero_copy, async_finish, return_recv_hook, out)
+zero_copy, async_finish, return_recv_hook, out, src_signals, src_signal_expect_value)
 tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
 return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
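Putting the pieces together, the intended flow appears to be a producer/consumer handshake between per-expert computation and the combine send phase. Below is a minimal usage sketch, not taken from this PR's tests: `num_local_experts` and `run_expert` are illustrative placeholders, it assumes a PyTorch build with `torch.uint32` support, and it assumes the send phase waits until each signal slot equals `src_signal_expect_value`:

```python
import torch
import deep_ep

# ... buffer, handle, x, topk_idx, topk_weights prepared as in the existing
# low-latency tests ...

num_local_experts = 8  # illustrative placeholder
src_signals = torch.zeros(num_local_experts, dtype=torch.uint32, device='cuda')

# Producer: raise expert i's signal as soon as its output is ready; this
# enqueues a 32-bit write of 1 into src_signals[i] on the current stream.
for i in range(num_local_experts):
    run_expert(i)  # hypothetical per-expert computation
    buffer.notify_src_signals(src_signals, i)

# Consumer: combine with per-expert gating; the expected value 1 matches
# what notify_src_signals writes.
combined_x, event, hook = buffer.low_latency_combine(
    x, topk_idx, topk_weights, handle,
    src_signals=src_signals, src_signal_expect_value=1)
```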
6 changes: 5 additions & 1 deletion setup.py
@@ -19,7 +19,11 @@
 include_dirs = ['csrc/']
 library_dirs = []
 nvcc_dlink = []
-extra_link_args = []
+# NOTE MODIFIED
+extra_link_args = ['-lcuda']
+
+# NOTE MODIFIED
+nvcc_flags += ['-lineinfo']

 # NVSHMEM flags
 if disable_nvshmem:
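The `-lcuda` flag links the CUDA driver library, presumably required because `cuStreamWriteValue32` is a driver-API entry point that the runtime library alone does not provide; `-lineinfo` embeds source-line information in the generated device code so profilers such as Nsight Compute can map SASS back to CUDA source.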
2 changes: 1 addition & 1 deletion tests/test_internode.py
@@ -248,5 +248,5 @@ def test_loop(local_rank: int, num_local_ranks: int):


 if __name__ == '__main__':
-num_processes = 8
+num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
 torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
22 changes: 15 additions & 7 deletions tests/test_intranode.py
@@ -1,3 +1,4 @@
+import os
 import time
 import torch
 import torch.distributed as dist

@@ -12,7 +13,12 @@

 def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
 # Settings
-num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
+# num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
+num_tokens = int(os.environ.get("DEEPEP_TEST_NUM_TOKENS", "4096"))
+hidden = int(os.environ.get("DEEPEP_TEST_HIDDEN", "7168"))
+num_topk = int(os.environ.get("DEEPEP_TEST_NUM_TOPK", "8"))
+num_experts = int(os.environ.get("DEEPEP_TEST_NUM_EXPERTS", str((256 // num_ranks) * num_ranks)))
+
 assert num_experts % num_ranks == 0
 if local_rank == 0:
 print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)

@@ -184,9 +190,9 @@ def check_data(check_x, rank_prefix_matrix):
 best_time, best_results = t, (num_sms, nvl_chunk_size)
 if local_rank == 0:
 print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
-f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
+f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) t={t * 1e3}ms', flush=True)
 if local_rank == 0:
-print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
+print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL) t={best_time * 1e3}ms', flush=True)
 print('', flush=True)

 # Gather the best config from rank 0 and the first test setting

@@ -215,12 +221,12 @@ def check_data(check_x, rank_prefix_matrix):
 t = bench(lambda: buffer.combine(**tune_args))[0]
 if local_rank == 0:
 print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
-f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
+f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) t={t * 1e3}ms', flush=True)
 if t < best_time and nvl_chunk_size > 0:
 best_time, best_results = t, (num_sms, nvl_chunk_size)

 if local_rank == 0:
-print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
+print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL) t={best_time * 1e3}ms', flush=True)
 print('', flush=True)


@@ -236,7 +242,9 @@ def test_loop(local_rank: int, num_local_ranks: int):
 num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
 torch.manual_seed(rank)

-for i in (24, ):
+num_sms = int(os.environ.get("DEEPEP_TEST_NUM_SMS", "24"))
+
+for i in (num_sms, ):
 test_main(i, local_rank, num_ranks, rank, buffer, group)
 if local_rank == 0:
 print('', flush=True)

@@ -252,5 +260,5 @@ def test_loop(local_rank: int, num_local_ranks: int):


 if __name__ == '__main__':
-num_processes = 8
+num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
 torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
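Both test scripts now take their settings from `DEEPEP_TEST_*` environment variables instead of hard-coded constants. A small sketch of driving the overrides programmatically (values are arbitrary; the shell form in the comment is equivalent):

```python
import os
import subprocess

# Equivalent shell invocation:
#   DEEPEP_TEST_NUM_TOKENS=2048 DEEPEP_TEST_HIDDEN=4096 \
#   DEEPEP_TEST_NUM_SMS=32 DEEPEP_TEST_NUM_PROCESSES=4 \
#   python tests/test_intranode.py
env = dict(os.environ,
           DEEPEP_TEST_NUM_TOKENS="2048",   # tokens per rank
           DEEPEP_TEST_HIDDEN="4096",       # hidden dimension
           DEEPEP_TEST_NUM_SMS="32",        # SM count passed to test_main
           DEEPEP_TEST_NUM_PROCESSES="4")   # number of spawned ranks
subprocess.run(["python", "tests/test_intranode.py"], env=env, check=True)
```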