Changes from all commits
120 commits
db53053
more
fzyzcjy Jun 20, 2025
2278722
more
fzyzcjy Jun 20, 2025
661e188
more
fzyzcjy Jun 20, 2025
5567637
more
fzyzcjy Jun 20, 2025
20855ee
more
fzyzcjy Jun 20, 2025
ad11318
more
fzyzcjy Jun 20, 2025
a960335
more
fzyzcjy Jun 21, 2025
9a8d98a
Merge branch 'feat/test_detailed_time' into feat/dev_20250621
fzyzcjy Jun 21, 2025
681bdc5
cherry pick
fzyzcjy Jun 20, 2025
7421672
Merge branch 'feat/num_processes' into feat/dev_20250621
fzyzcjy Jun 21, 2025
56758db
more
fzyzcjy Jun 21, 2025
0a8848a
more
fzyzcjy Jun 21, 2025
cd4af65
more
fzyzcjy Jun 21, 2025
d559938
add profile
fzyzcjy Jun 23, 2025
c30eca6
hack
fzyzcjy Jun 23, 2025
eefd72c
more
fzyzcjy Jun 23, 2025
932a5b8
more
fzyzcjy Jun 23, 2025
626efdc
more
fzyzcjy Jun 23, 2025
d0ba512
more
fzyzcjy Jun 23, 2025
2437fad
more
fzyzcjy Jun 23, 2025
35cd4ee
more
fzyzcjy Jun 23, 2025
5f28f95
more
fzyzcjy Jun 23, 2025
f4aa019
more
fzyzcjy Jun 23, 2025
981ae58
more
fzyzcjy Jun 23, 2025
38d03cf
more
fzyzcjy Jun 23, 2025
4c71e4c
more
fzyzcjy Jun 23, 2025
eb8ddd9
more
fzyzcjy Jun 23, 2025
2b22134
more
fzyzcjy Jun 23, 2025
48162a4
topk_idx i32
fzyzcjy Jun 24, 2025
786760c
temp hack
fzyzcjy Jun 24, 2025
a18e5d0
Revert "temp hack"
fzyzcjy Jun 24, 2025
db75cd0
more
fzyzcjy Jun 24, 2025
67504fd
more
fzyzcjy Jun 24, 2025
d30fe73
temp hack
fzyzcjy Jun 24, 2025
723a001
Revert "temp hack"
fzyzcjy Jun 24, 2025
9351972
more
fzyzcjy Jun 24, 2025
9c8413b
more
fzyzcjy Jun 24, 2025
58150c7
more
fzyzcjy Jun 24, 2025
18c55e7
more
fzyzcjy Jun 24, 2025
11d81ea
more
fzyzcjy Jun 24, 2025
423cae0
more
fzyzcjy Jun 24, 2025
0e3c6dc
more
fzyzcjy Jun 24, 2025
e79e61a
more
fzyzcjy Jun 24, 2025
c258f11
more
fzyzcjy Jun 24, 2025
6ec596a
more
fzyzcjy Jun 24, 2025
f867a3e
more
fzyzcjy Jun 24, 2025
45025b1
more
fzyzcjy Jun 24, 2025
7a2f46d
more
fzyzcjy Jun 24, 2025
611c1c9
more
fzyzcjy Jun 24, 2025
fffc43a
more
fzyzcjy Jun 24, 2025
005dcc9
more
fzyzcjy Jun 24, 2025
6d9a1d0
more
fzyzcjy Jun 24, 2025
399945f
more
fzyzcjy Jun 24, 2025
c375e51
hack
fzyzcjy Jun 24, 2025
8d7c31a
more
fzyzcjy Jun 24, 2025
4ca6c94
morew
fzyzcjy Jun 24, 2025
15917dc
more
fzyzcjy Jun 24, 2025
104cec9
more
fzyzcjy Jun 24, 2025
07a638d
more
fzyzcjy Jun 24, 2025
7f75b4e
more
fzyzcjy Jun 24, 2025
6693f1d
more
fzyzcjy Jun 24, 2025
a51ad7e
more
fzyzcjy Jun 24, 2025
62540c1
temp
fzyzcjy Jun 24, 2025
2185522
more
fzyzcjy Jun 24, 2025
e70ca4b
more
fzyzcjy Jun 24, 2025
9a2cacb
more
fzyzcjy Jun 24, 2025
d2dc94f
fix
fzyzcjy Jun 24, 2025
ce7f0e1
more
fzyzcjy Jun 24, 2025
428c2fb
revert
fzyzcjy Jun 24, 2025
3bd07e3
apply
fzyzcjy Jun 24, 2025
941c1e0
revert
fzyzcjy Jun 24, 2025
a708906
loosen threshold for 768 tokens
fzyzcjy Jun 24, 2025
68241f1
more
fzyzcjy Jul 5, 2025
7f68252
clean
fzyzcjy Jul 5, 2025
9f2e142
more
fzyzcjy Jul 5, 2025
b723ba2
more
fzyzcjy Jul 5, 2025
4cc8f5f
more
fzyzcjy Jul 5, 2025
4d1669b
more
fzyzcjy Jul 5, 2025
687b023
more
fzyzcjy Jul 5, 2025
bf29c08
more
fzyzcjy Jul 5, 2025
5baeead
more
fzyzcjy Jul 5, 2025
87825c5
temp
fzyzcjy Jul 5, 2025
5446a26
more
fzyzcjy Jul 5, 2025
0789e5b
more
fzyzcjy Jul 5, 2025
baa60f6
Revert "temp"
fzyzcjy Jul 5, 2025
a506776
more
fzyzcjy Jul 5, 2025
5d6c983
more
fzyzcjy Jul 5, 2025
e16d5c0
Revert "more"
fzyzcjy Jul 5, 2025
4a13551
Revert "more"
fzyzcjy Jul 5, 2025
2fce45b
temp
fzyzcjy Jul 5, 2025
1755d7f
more
fzyzcjy Jul 5, 2025
a1a61ac
more
fzyzcjy Jul 5, 2025
5c00dfd
temp
fzyzcjy Jul 5, 2025
777bb20
temp
fzyzcjy Jul 5, 2025
f7ea7a5
temp
fzyzcjy Jul 5, 2025
16cbbea
revert temp
fzyzcjy Jul 5, 2025
e26dbe5
more
fzyzcjy Jul 5, 2025
7fe0a4f
more
fzyzcjy Jul 5, 2025
6d49b1c
more
fzyzcjy Jul 5, 2025
a0cdae1
more
fzyzcjy Jul 5, 2025
7250dd5
hack
fzyzcjy Jul 5, 2025
6892f3f
revert temp
fzyzcjy Jul 5, 2025
d559fd7
hack
fzyzcjy Jul 5, 2025
910b899
Revert "hack"
fzyzcjy Jul 5, 2025
37d2d2b
hack
fzyzcjy Jul 5, 2025
f5b4d76
Revert "hack"
fzyzcjy Jul 5, 2025
f4186f7
hack
fzyzcjy Jul 5, 2025
e8d50a3
Revert "hack"
fzyzcjy Jul 5, 2025
3c21dc3
hack
fzyzcjy Jul 5, 2025
6339525
Revert "hack"
fzyzcjy Jul 5, 2025
d755602
hack
fzyzcjy Jul 5, 2025
3291302
Revert "hack"
fzyzcjy Jul 5, 2025
977d341
more
fzyzcjy Jul 6, 2025
e389304
more
fzyzcjy Jul 6, 2025
6c7d6a7
more
fzyzcjy Jul 6, 2025
739b7ff
more
fzyzcjy Jul 6, 2025
2f8b264
hack
fzyzcjy Jul 6, 2025
a50dadc
Revert "hack"
fzyzcjy Jul 6, 2025
ffb0052
more
fzyzcjy Jul 6, 2025
0b894c4
more
fzyzcjy Jul 6, 2025
11 changes: 9 additions & 2 deletions csrc/deep_ep.cpp
@@ -1170,7 +1170,11 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0);
EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous());
EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1));
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);

// HACK: temporarily require int32 top-k indices instead of int64 (matches the kernel change in internode_ll.cu)
// EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt32);

EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous());
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
@@ -1214,7 +1218,10 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
internode_ll::combine(combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
x.data_ptr(),
// topk_idx.data_ptr<int64_t>(),
topk_idx.data_ptr<int32_t>(),
topk_weights.data_ptr<float>(),
src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
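The hunks above flip the host-side contract for low_latency_combine: the assert now requires int32 top-k indices, and the pointer handed to internode_ll::combine changes from int64_t* to int32_t*. A minimal sketch of a more tolerant guard a host wrapper could use instead of the hard assert (hypothetical, not part of this PR; it relies on deep_ep.cpp's existing includes and on expert indices always fitting in 32 bits):

// Hypothetical helper, not in this PR: accept the historical int64 layout and
// narrow it before the kInt32 assert fires; expert indices are < num_experts,
// so the narrowing cannot overflow.
static torch::Tensor ensure_int32_topk_idx(const torch::Tensor& topk_idx) {
    if (topk_idx.scalar_type() == torch::kInt64)
        return topk_idx.to(torch::kInt32);
    EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt32);
    return topk_idx;
}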
2 changes: 1 addition & 1 deletion csrc/kernels/api.cuh
@@ -152,7 +152,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,

void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const void* x, const int32_t* topk_idx_i32, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
106 changes: 82 additions & 24 deletions csrc/kernels/internode_ll.cu
@@ -382,11 +382,18 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
#undef DISPATCH_LAUNCH_CASE
}

// TODO generalize
constexpr int kMaxNumTokensPerSm = 6;
constexpr int kIdxOrWeightDim = 2;
constexpr int kNumActualTopkDivFour = 2;
constexpr int kNumActualTopk = kNumActualTopkDivFour * 4;
constexpr int kWarpSize = 32;

template <int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(1024, 1) void
combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const void* x, const int32_t* topk_idx_i32, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int* atomic_clean_flag,
@@ -489,42 +496,92 @@ combine(void* combined_x,
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;

int self_num_iteration = (sm_id >= num_combined_tokens) ? 0 : (1 + (num_combined_tokens - sm_id - 1) / num_sms);

alignas(16) __shared__ int4 shared_topk_info[kMaxNumTokensPerSm * kIdxOrWeightDim * kNumActualTopkDivFour];
const auto compute_shared_topk_info_addr = [=](int idx_iteration, int idx_iow, int idx_topkdivfour) {
return shared_topk_info
+ idx_iteration * (kIdxOrWeightDim * kNumActualTopkDivFour)
+ idx_iow * kNumActualTopkDivFour
+ idx_topkdivfour;
};

int4 temp_buf;
int prepare_topk_idx_iteration, prepare_topk_idx_iow, prepare_topk_idx_topkdivfour;
static_assert(sizeof(shared_topk_info) / sizeof(shared_topk_info[0]) <= kWarpSize);
if (warp_id == 0) {
int index = thread_id;
prepare_topk_idx_topkdivfour = index % kNumActualTopkDivFour;
index /= kNumActualTopkDivFour;
prepare_topk_idx_iow = index % kIdxOrWeightDim;
index /= kIdxOrWeightDim;
prepare_topk_idx_iteration = index;
}
bool enable_prepare_topk = (warp_id == 0) and (prepare_topk_idx_iteration < self_num_iteration);
if (enable_prepare_topk) {
const int prepare_topk_token_idx = sm_id + prepare_topk_idx_iteration * num_sms;
const int4* src_addr = (
((prepare_topk_idx_iow == 0)
? reinterpret_cast<const int4*>(topk_idx_i32)
: reinterpret_cast<const int4*>(topk_weights))
+ prepare_topk_token_idx * kNumActualTopkDivFour
+ prepare_topk_idx_topkdivfour
);
temp_buf = ld_nc_global(src_addr);
}

// Wait all ranks to arrive
if (responsible_expert_idx < num_experts) {
EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0) {
while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
}
const int recv_flag_responsible_expert_idx = thread_id;
if (recv_flag_responsible_expert_idx < num_experts) {
while (ld_acquire_sys_global(rdma_recv_flag + recv_flag_responsible_expert_idx) == 0);
}

if (enable_prepare_topk) {
int4* smem_addr = compute_shared_topk_info_addr(prepare_topk_idx_iteration, prepare_topk_idx_iow, prepare_topk_idx_topkdivfour);
*smem_addr = temp_buf;
}
cg::this_grid().sync();

__syncthreads();

// Reduce tokens with FP8 cast
EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
if (thread_id < hidden_bf16_int4) {
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
for (int idx_iteration = 0; idx_iteration < self_num_iteration; ++ idx_iteration) {
const int token_idx = sm_id + idx_iteration * num_sms;

// Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk];
float reg_topk_weights[kNumMaxTopk];
alignas(16) int reg_topk_idx[kNumMaxTopk];
alignas(16) float reg_topk_weights[kNumMaxTopk];
auto reg_topk_idx_vec = reinterpret_cast<int4*>(reg_topk_idx);
auto reg_topk_weights_vec = reinterpret_cast<float4*>(reg_topk_weights);
#pragma unroll
for (int i = 0; i < kNumActualTopkDivFour; ++i) {
reg_topk_idx_vec[i] = *compute_shared_topk_info_addr(idx_iteration, 0, i);
}
#pragma unroll
for (int i = 0; i < num_topk; ++ i) {
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
for (int i = 0; i < kNumActualTopkDivFour; ++i) {
reg_topk_weights_vec[i] = *reinterpret_cast<float4*>(compute_shared_topk_info_addr(idx_iteration, 1, i));
}

float combined_values[kNumElemsPerInt4] = {0.0f};
// Read from sources, Reduce
int4 zero4 = {0,0,0,0};
int4 x_vec[kNumActualTopk];
#pragma unroll
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources
for (int i = 0; i < kNumActualTopk; ++i) {
auto rdma_buffer_type = reinterpret_cast<const int*>(static_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
x_vec[i] = (reg_topk_idx[i] >= 0) ? ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id) : zero4;
}

// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
float combined_values[kNumElemsPerInt4] = {0.0f};
#pragma unroll
for (int i = 0; i < kNumActualTopk; ++i) {
if (reg_topk_idx[i] >= 0) {
const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec[i]);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j) combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
}
}

// Write results
@@ -540,7 +597,7 @@ combine(void* combined_x,

void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const void* x, const int32_t* topk_idx_i32, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
int* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
@@ -559,13 +616,14 @@ void combine(void* combined_x,
auto atomic_clean_flag = static_cast<int*>(workspace);
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
EP_HOST_ASSERT(num_combined_tokens <= kMaxNumTokensPerSm * num_sms);

#define COMBINE_LAUNCH_CASE(hidden) { \
auto combine_func = combine<hidden, kNumMaxTopk>; \
LAUNCH_KERNEL(&cfg, combine_func, \
combined_x, \
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
x, topk_idx_i32, topk_weights, src_info, layout_range, \
next_clean, num_next_clean_int, \
atomic_clean_flag, \
num_combined_tokens, hidden, num_topk, \
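In the kernel changes above, warp 0 prefetches every top-k index and weight this SM will need into shared_topk_info before the grid sync: each of the up-to-kMaxNumTokensPerSm tokens owns kIdxOrWeightDim * kNumActualTopkDivFour = 4 16-byte slots (8 int32 indices plus 8 float weights), for at most 24 slots, which is why the static_assert against kWarpSize lets one lane handle one slot. A standalone illustration of that lane-to-slot mapping (plain host C++ reproducing the arithmetic of compute_shared_topk_info_addr with the PR's constants; not kernel code):

// Illustration only: the lane -> shared-slot mapping used by warp 0 above.
// The per-lane decomposition is the inverse of the flat-offset formula, so
// the computed slot always equals the lane id (one distinct slot per lane).
#include <cstdio>

constexpr int kMaxNumTokensPerSm    = 6; // tokens a single SM may iterate over
constexpr int kIdxOrWeightDim       = 2; // 0 = top-k indices, 1 = top-k weights
constexpr int kNumActualTopkDivFour = 2; // 8 top-k values packed as two 16-byte vectors

int main() {
    const int num_slots = kMaxNumTokensPerSm * kIdxOrWeightDim * kNumActualTopkDivFour; // 24 <= 32 lanes
    for (int lane = 0; lane < num_slots; ++lane) {
        int idx = lane;
        const int quad      = idx % kNumActualTopkDivFour; idx /= kNumActualTopkDivFour;
        const int iow       = idx % kIdxOrWeightDim;       idx /= kIdxOrWeightDim;
        const int iteration = idx;
        const int slot = iteration * (kIdxOrWeightDim * kNumActualTopkDivFour)
                       + iow * kNumActualTopkDivFour + quad; // == lane
        std::printf("lane %2d -> iteration %d, %s, quad %d, slot %2d\n",
                    lane, iteration, iow ? "weights" : "indices", quad, slot);
    }
    return 0;
}

The host-side assert num_combined_tokens <= kMaxNumTokensPerSm * num_sms added in the launcher guarantees self_num_iteration never exceeds kMaxNumTokensPerSm, so the prefetched table always covers every iteration of the reduce loop.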
8 changes: 8 additions & 0 deletions csrc/kernels/utils.cuh
@@ -199,6 +199,14 @@ __device__ __forceinline__ int4 ld_nc_global(const int4 *ptr) {
return ret;
}

template <>
__device__ __forceinline__ float4 ld_nc_global(const float4 *ptr) {
float4 ret;
asm volatile(LD_NC_FUNC ".v4.f32 {%0, %1, %2, %3}, [%4];"
: "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(ptr));
return ret;
}

__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
asm volatile("st.relaxed.gpu.global.L1::no_allocate.b8 [%0], %1;" : : "l"(ptr), "h"(static_cast<uint16_t>(val)));
}
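The new overload mirrors the existing int4 specialization, giving a single non-caching 128-bit load for four packed floats. A hedged device-side sketch of how it could be used to fetch one quad of top-k weights (not from this PR; it assumes a 16-byte-aligned topk_weights buffer whose per-token top-k count is a multiple of four):

// Sketch only: pull four consecutive top-k weights for one token with a single
// non-caching 128-bit load via the float4 overload above. Assumes alignment and
// num_topk % 4 == 0; this helper is illustrative, not part of the PR.
__device__ __forceinline__ void load_weight_quad(const float* topk_weights,
                                                 int token_idx, int num_topk,
                                                 int quad, float out[4]) {
    const auto src = reinterpret_cast<const float4*>(topk_weights + token_idx * num_topk) + quad;
    const float4 v = ld_nc_global(src);
    out[0] = v.x; out[1] = v.y; out[2] = v.z; out[3] = v.w;
}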
4 changes: 4 additions & 0 deletions deep_ep/buffer.py
@@ -570,6 +570,10 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
event: the event after executing the kernel (valid only if `async_finish` is set).
hook: the receiving hook function (valid only if `return_recv_hook` is set).
"""

# hack: the low-latency combine kernel currently consumes int32 top-k indices, so cast here
topk_idx = topk_idx.to(torch.int32)

src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
num_max_dispatch_tokens_per_rank, num_experts,
3 changes: 3 additions & 0 deletions setup.py
@@ -21,6 +21,9 @@
nvcc_dlink = []
extra_link_args = []

# TODO merge another PR first
nvcc_flags += ['-lineinfo']

# NVSHMEM flags
if disable_nvshmem:
cxx_flags.append('-DDISABLE_NVSHMEM')
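For context, -lineinfo only embeds source-line correlation for device code (what Nsight Compute and cuda-gdb use for source-level views); unlike -G it does not disable optimizations, so it is a low-overhead flag to carry while profiling, and the TODO indicates it should eventually land through its own PR.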
2 changes: 1 addition & 1 deletion tests/test_internode.py
@@ -248,5 +248,5 @@ def test_loop(local_rank: int, num_local_ranks: int):


if __name__ == '__main__':
num_processes = 8
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
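One caveat with this hunk: it calls os.getenv, but unlike tests/test_intranode.py below, no import os is added in the visible diff, so it assumes the module already imports os; if it does not, the fix is the same one-line import os added to test_intranode.py.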
22 changes: 15 additions & 7 deletions tests/test_intranode.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import time
import torch
import torch.distributed as dist
@@ -12,7 +13,12 @@

def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings
num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
# num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
num_tokens = int(os.environ.get("DEEPEP_TEST_NUM_TOKENS", "4096"))
hidden = int(os.environ.get("DEEPEP_TEST_HIDDEN", "7168"))
num_topk = int(os.environ.get("DEEPEP_TEST_NUM_TOPK", "8"))
num_experts = int(os.environ.get("DEEPEP_TEST_NUM_EXPERTS", str((256 // num_ranks) * num_ranks)))

assert num_experts % num_ranks == 0
if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)
@@ -184,9 +190,9 @@ def check_data(check_x, rank_prefix_matrix):
best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) t={t * 1e3}ms', flush=True)
if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL) t={best_time * 1e3}ms', flush=True)
print('', flush=True)

# Gather the best config from rank 0 and the first test setting
@@ -215,12 +221,12 @@ def check_data(check_x, rank_prefix_matrix):
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) t={t * 1e3}ms', flush=True)
if t < best_time and nvl_chunk_size > 0:
best_time, best_results = t, (num_sms, nvl_chunk_size)

if local_rank == 0:
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL) t={best_time * 1e3}ms', flush=True)
print('', flush=True)


@@ -236,7 +242,9 @@ def test_loop(local_rank: int, num_local_ranks: int):
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
torch.manual_seed(rank)

for i in (24, ):
num_sms = int(os.environ.get("DEEPEP_TEST_NUM_SMS", "24"))

for i in (num_sms, ):
test_main(i, local_rank, num_ranks, rank, buffer, group)
if local_rank == 0:
print('', flush=True)
@@ -252,5 +260,5 @@ def test_loop(local_rank: int, num_local_ranks: int):


if __name__ == '__main__':
num_processes = 8
num_processes = int(os.getenv("DEEPEP_TEST_NUM_PROCESSES", "8"))
torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
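Taken together, these environment knobs allow a scaled-down smoke test without editing the file, e.g. something like DEEPEP_TEST_NUM_PROCESSES=2 DEEPEP_TEST_NUM_TOKENS=512 DEEPEP_TEST_NUM_SMS=24 python tests/test_intranode.py (illustrative invocation; only the variable names introduced above are assumed).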