229 changes: 214 additions & 15 deletions tests/test_low_latency.py
@@ -7,13 +7,14 @@
import numpy as np
from functools import partial
from typing import Optional

import pandas as pd
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back, get_global_token_indices


def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer,
imbalance_factor: float = 1.0, distribution: str = 'lognormal', print_res: bool = True,
use_logfmt: bool = False, seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
@@ -34,9 +35,21 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# NOTES: the last one is for performance testing
# Most of the values in the perf case are lower than the threshold, casting most channels
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)

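# NOTES: rank 0 draws the routing for all ranks so the global expert load matches
# the target imbalance, then scatters each rank its slice; non-source ranks pass
# `scatter_list=None`, since `dist.scatter` only reads it on `src`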
scatter_list = None
if rank == 0:
global_topk_idx = get_global_token_indices(
distribution, num_experts, num_tokens, num_ranks, num_topk, imbalance_factor, seed
)
scatter_list = [
chunk.contiguous() for chunk in torch.chunk(global_topk_idx, num_ranks, dim=0)
]
topk_idx = torch.empty(num_tokens, num_topk, dtype=torch.long, device='cuda')
dist.scatter(tensor=topk_idx, scatter_list=scatter_list, src=0, group=group)

scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
results = {}
results['topk_idx'] = topk_idx.cpu()

topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()

# Randomly mask some positions
@@ -142,11 +155,17 @@ def test_func(return_recv_hook: bool):
num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += (num_logfmt10_bytes if use_logfmt else num_bf16_bytes) * num_selections
results['dispatch_comm_bytes'] = num_dispatch_comm_bytes
results['combine_comm_bytes'] = num_combine_comm_bytes

# Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
if print_res:
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
results['total_time_avg'] = avg_t * 1e6
results['total_time_min'] = min_t * 1e6
results['total_time_max'] = max_t * 1e6

# Separate profiling
for return_recv_hook in (False, True):
@@ -155,14 +174,161 @@ def test_func(return_recv_hook: bool):
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
if print_res:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
results['dispatch_time'] = dispatch_t * 1e6
results['combine_time'] = combine_t * 1e6
else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
return hash_value
if print_res:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
results['dispatch_send_time'] = dispatch_t[0] * 1e6
results['dispatch_recv_time'] = dispatch_t[1] * 1e6
results['combine_send_time'] = combine_t[0] * 1e6
results['combine_recv_time'] = combine_t[1] * 1e6
return results, hash_value

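For context, `get_global_token_indices` is imported from `utils` above and its implementation is not part of this diff. Below is a minimal sketch of one way such a helper could work, assuming a per-expert popularity vector drawn from the chosen distribution plus Gumbel-top-k sampling; only the name and signature are taken from the call site, everything else is an assumption:

import torch

def get_global_token_indices(distribution: str, num_experts: int, num_tokens: int,
                             num_ranks: int, num_topk: int,
                             imbalance_factor: float, seed: int) -> torch.Tensor:
    torch.manual_seed(seed)
    spread = max(imbalance_factor - 1.0, 0.0)  # a factor of 1.0 means roughly uniform load
    if distribution == 'lognormal':
        popularity = (torch.randn(num_experts) * spread).exp()
    elif distribution == 'powerlaw':
        popularity = torch.arange(1, num_experts + 1, dtype=torch.float32).pow(-spread)
    elif distribution == 'gamma':
        popularity = torch.distributions.Gamma(1.0 / max(spread, 1e-3), 1.0).sample((num_experts,))
    else:
        raise ValueError(f'Unknown distribution: {distribution}')
    # Gumbel-top-k: adding Gumbel noise to the log-weights and taking the top-k
    # draws `num_topk` distinct experts per token, proportionally to popularity
    total_tokens = num_tokens * num_ranks
    gumbel = -torch.empty(total_tokens, num_experts).exponential_().log()
    return torch.topk(popularity.log() + gumbel, num_topk, dim=-1).indices.cuda()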
def process_and_display_results(all_results, num_experts, num_ranks, imbalance_factor):
all_topk_idx_list = [result['topk_idx'] for result in all_results]
global_topk_idx = torch.cat(all_topk_idx_list, dim=0)

num_local_experts = num_experts // num_ranks
rank_counts = torch.zeros(num_ranks, dtype=torch.int64)

valid_indices = global_topk_idx[global_topk_idx >= 0]

for rank in range(num_ranks):
start_expert = rank * num_local_experts
end_expert = (rank + 1) * num_local_experts

mask = (valid_indices >= start_expert) & (valid_indices < end_expert)
rank_counts[rank] = mask.sum().item()

max_count = rank_counts.max().item()
min_count = rank_counts.min().item()
avg_count = rank_counts.float().mean().item()
median_count = rank_counts.float().median().item()
std_count = rank_counts.float().std().item()
actual_max_avg = max_count / avg_count if avg_count > 0 else 0

avg_dispatch_bytes = sum(result['dispatch_comm_bytes'] for result in all_results) / len(all_results)
avg_combine_bytes = sum(result['combine_comm_bytes'] for result in all_results) / len(all_results)

avg_total_time = sum(result['total_time_avg'] for result in all_results) / len(all_results)
avg_dispatch_time = sum(result['dispatch_time'] for result in all_results) / len(all_results)
avg_combine_time = sum(result['combine_time'] for result in all_results) / len(all_results)

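# NOTES: times are stored in microseconds and sizes in bytes, so (bytes / 1e9) / (t / 1e6) gives GB/s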
total_bw = (avg_dispatch_bytes + avg_combine_bytes) / 1e9 / (avg_total_time / 1e6)
dispatch_bw = avg_dispatch_bytes / 1e9 / (avg_dispatch_time / 1e6)
combine_bw = avg_combine_bytes / 1e9 / (avg_combine_time / 1e6)

for result in all_results:
if 'topk_idx' in result:
del result['topk_idx']
for key in ['dispatch_comm_bytes', 'combine_comm_bytes']:
if key in result:
del result[key]

df = pd.DataFrame(all_results)
mean_series = df.mean()

mean_series['total_bw'] = total_bw
mean_series['dispatch_bw'] = dispatch_bw
mean_series['combine_bw'] = combine_bw
mean_series['imbalance_factor'] = imbalance_factor
mean_series['max_count'] = float(max_count)
mean_series['min_count'] = float(min_count)
mean_series['avg_count'] = float(avg_count)
mean_series['median_count'] = float(median_count)
mean_series['std_count'] = float(std_count)
mean_series['actual_max_avg'] = float(actual_max_avg)

return mean_series

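As an aside, the per-rank counting loop in `process_and_display_results` has a one-line equivalent, since each rank owns a contiguous block of `num_local_experts` experts; a sketch using the same variable names as above:

# expert_id // num_local_experts is the owning rank under the blocked expert layout
rank_counts = torch.bincount(valid_indices // num_local_experts, minlength=num_ranks)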
def print_summary_tables(final_df):
print("\n" + "="*120)
print(" PERFORMANCE SUMMARY (Statistics across all ranks)")
print("="*120)

# Table 1: Token Distribution Statistics
print("\n--- Token Distribution per Rank ---")
imbalance_df = final_df[['max_count', 'min_count', 'avg_count', 'median_count', 'std_count']].copy()
imbalance_df.columns = ['Max', 'Min', 'Avg', 'Median', 'Std Dev']

formatters = {
'Max': lambda x: f"{x:4.0f}",
'Min': lambda x: f"{x:3.0f}",
'Avg': lambda x: f"{x:5.1f}",
'Median': lambda x: f"{x:6.1f}",
'Std Dev': lambda x: f"{x:6.1f}"
}

for col, formatter in formatters.items():
imbalance_df[col] = imbalance_df[col].apply(formatter)

print(imbalance_df.to_string())

# Table 2: Total Performance
print("\n--- Total Performance (Dispatch + Combine) ---")
total_perf_df = final_df[['total_bw', 'total_time_avg', 'total_time_min', 'total_time_max']].copy()
total_perf_df.columns = ['Total BW', 'Avg Time', 'Min Time', 'Max Time']

formatters = {
'Total BW': lambda x: f"{x:.2f} GB/s",
'Avg Time': lambda x: f"{x:.2f} us",
'Min Time': lambda x: f"{x:.2f} us",
'Max Time': lambda x: f"{x:.2f} us"
}

for col, formatter in formatters.items():
total_perf_df[col] = total_perf_df[col].apply(formatter)

print(total_perf_df.to_string())

# Table 3: Separate Performance
print("\n--- Separate Dispatch & Combine Performance ---")
separate_perf_df = final_df[['dispatch_bw', 'dispatch_time', 'combine_bw', 'combine_time']].copy()
separate_perf_df.columns = ['Dispatch BW', 'Dispatch Time', 'Combine BW', 'Combine Time']

formatters = {
'Dispatch BW': lambda x: f"{x:.2f} GB/s",
'Dispatch Time': lambda x: f"{x:.2f} us",
'Combine BW': lambda x: f"{x:.2f} GB/s",
'Combine Time': lambda x: f"{x:.2f} us"
}

for col, formatter in formatters.items():
separate_perf_df[col] = separate_perf_df[col].apply(formatter)

print(separate_perf_df.to_string())

# Table 4: Hook Performance
print("\n--- Send/Recv Timings (Hook=True) ---")

hook_data = []
for idx in final_df.index:
row = final_df.loc[idx]
hook_data.append([
f"{row['dispatch_send_time']:>6.2f} us",
f"{row['dispatch_recv_time']:>6.2f} us",
f"{row['combine_send_time']:>6.2f} us",
f"{row['combine_recv_time']:>6.2f} us"
])

columns = pd.MultiIndex.from_tuples([
('Dispatch', 'Send'),
('Dispatch', 'Recv'),
('Combine', 'Send'),
('Combine', 'Recv')
])

hook_df = pd.DataFrame(hook_data, index=final_df.index, columns=columns)

print(hook_df.to_string())

print("\n" + "="*120)

# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
@@ -178,16 +344,44 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
allow_mnnvl=args.allow_mnnvl)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=1)

dist.barrier()
if rank == 0:
all_imbalance_summaries = []
for imbalance_factor in args.imbalance_factors:
if rank == 0:
print(f"\n--> Running test for target imbalance factor: {imbalance_factor}", flush=True)
results_dict, _ = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
imbalance_factor=imbalance_factor, distribution=args.distribution,
print_res=False, seed=1)
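# NOTES: `dist.gather_object` reads the destination list only on `dst`;
# non-destination ranks pass `None` and just contribute their local `results_dict`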
if rank == 0:
all_results = [None] * num_ranks
dist.gather_object(results_dict, all_results, dst=0, group=group)

mean_series = process_and_display_results(all_results, num_experts, num_ranks, imbalance_factor)
all_imbalance_summaries.append(mean_series)
else:
dist.gather_object(results_dict, None, dst=0, group=group)

if rank == 0 and all_imbalance_summaries:
df = pd.DataFrame(all_imbalance_summaries)

df['display_index'] = df.apply(
lambda row: f"{row['imbalance_factor']:.1f} (Actual: {row['actual_max_avg']:.2f})",
axis=1
)
df.set_index('display_index', inplace=True)
df.index.name = "Max/Avg Ratio"
print_summary_tables(df)
do_pressure_test = args.pressure_test
for seed in range(int(1e9) if do_pressure_test else 0):
if local_rank == 0:
print(f'Testing with seed {seed} ...', flush=True)
ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
_, ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=seed)
for i in range(20):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=seed) == ref_hash, f'Error: seed={seed}'
_, current_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=seed)
assert current_hash == ref_hash, f'Error: seed={seed}'

# Destroy the buffer runtime and communication group
buffer.destroy()
Expand Down Expand Up @@ -217,6 +411,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
help='Whether to test LogFMT combine')
parser.add_argument("--pressure-test", action='store_true',
help='Whether to do pressure test')
parser.add_argument('--imbalance-factors', type=float, nargs='+', default=[1.0, 2.0, 3.0],
help='A list of target max/avg ratios for per-rank expert load (tokens per expert). '
'Higher values create more load imbalance (e.g., 1.0, 2.0, 3.0). '
'Note: actual ratios may be lower than targets due to token count constraints.')
parser.add_argument('--distribution', type=str, default='lognormal', choices=['lognormal', 'powerlaw', 'gamma'],
help='Distribution used to synthesize the imbalanced expert load.')
args = parser.parse_args()

num_processes = args.num_processes
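For reference, a hypothetical invocation exercising the new options could be `python tests/test_low_latency.py --num-processes 8 --imbalance-factors 1.0 2.0 4.0 --distribution powerlaw`; all flags other than the two added in this diff are assumed from the existing test harness.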