[Feature] support TransformerEngine to enable communication overlap #2627

Draft · wants to merge 32 commits into base: main

Commits (32)
ba9077f
profile of M sizes for Torch native and TE
Oct 27, 2024
a7d7995
change it to bf16
Oct 27, 2024
9812cd1
update comments
Nov 3, 2024
c3ab480
install scripts for hyperbolic
Nov 4, 2024
cc61d98
update install.sh
Nov 4, 2024
0f49870
update install
Nov 4, 2024
04e06d4
update install
Nov 4, 2024
b122fb5
Merge branch 'sgl-project:main' into comm_overlap
Zhuohao-Li Nov 18, 2024
3fce4d9
update enable_te support in llamaforcasuallm
Zhuohao-Li Nov 24, 2024
be9b2f8
update enable_te support in args
Zhuohao-Li Nov 24, 2024
4df3035
update enable_te support in args with running model
Zhuohao-Li Nov 24, 2024
64690d6
update enable_te support in Llama
Zhuohao-Li Nov 26, 2024
1f38017
update enable_te support in Llama
Zhuohao-Li Nov 26, 2024
6504538
TE example
Dec 12, 2024
8ef1aa8
TE example with torch
Dec 12, 2024
eff8ba7
integration doc
Zhuohao-Li Dec 27, 2024
9b0763e
benchmark result with Llama/TP
Zhuohao-Li Dec 27, 2024
dc85834
variable max_new_token in benchmark result with Llama/TP
Zhuohao-Li Dec 27, 2024
a7b1465
minor changee on TEllama
Zhuohao-Li Dec 27, 2024
43294ad
minor change on benchmark script
Zhuohao-Li Dec 27, 2024
b805a81
benchmark result from sglang.server_latency
Zhuohao-Li Dec 27, 2024
f7a2d62
merge updates from upstream
Zhuohao-Li Dec 28, 2024
7f584ae
benchmark update using sglang.one_batch
Zhuohao-Li Dec 28, 2024
96f6731
update scripts to benchmark using sglang.one_batch
Zhuohao-Li Dec 28, 2024
05786ad
update readme
Zhuohao-Li Dec 28, 2024
f1738dd
minor change and reorg of tellama
Zhuohao-Li Dec 28, 2024
1a23360
fix readme error
Zhuohao-Li Dec 28, 2024
91868af
upload example results
Zhuohao-Li Dec 28, 2024
43de8bd
minor change on visualization
Zhuohao-Li Dec 28, 2024
243c708
update readme (fix launching instruction)
Zhuohao-Li Dec 28, 2024
09b12e3
update readme (fix figure bug)
Zhuohao-Li Dec 28, 2024
d7edaf4
rm figrues
Zhuohao-Li Dec 28, 2024
53 changes: 53 additions & 0 deletions comm_overlap/README.md
@@ -0,0 +1,53 @@
### TE integration in sglang (v0.1.0)

#### How to run the code

It is recommended to use the pre-built Docker image, which ships with TE (v1.14.0.dev0+994f19d) and torch (2.5.1+cu124). Pull it with `docker pull zhuohaol/sglang-te:latest`, and please do not modify the torch version inside the image.

To launch the sglang server with TE enabled, run:

```bash
docker run -it --shm-size 32g --gpus all -p 30001:30001 --ipc=host --rm zhuohaol/sglang-te:latest

git clone -b zhuohaol-comm-overlap https://github.com/Zhuohao-Li/sglang.git

cd sglang/python

python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --enable-te
```

The way requests are sent to the server remains the same.
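
For reference, a minimal request against the `/generate` endpoint, mirroring the payload used by the benchmark scripts under `comm_overlap/benchmark`, could look like the sketch below (the prompt and sampling parameters are just placeholders; adjust the port to match the `--port` you launched with):

```python
import requests

# Assumes the server was launched with --port 30000 as in the command above.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The president of USA,",
        "sampling_params": {"max_new_tokens": 32, "temperature": 0},
    },
)
print(response.json())
```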

Key changes:

- In `sglang.srt.models`, we add `llama_te_sgl.py` to support Llama models with TE (`llama_te.py` is an alternative version, but it is not recommended at the moment).
- In `sglang.model_executor.model_runner.py` and `server_args.py`, we add the corresponding arguments to enable TE when launching the engine.
- In `comm_overlap.profile`, we provide scripts to profile performance with torch/TE.
- In `comm_overlap.benchmark`, we provide benchmark instructions for the latency of the TE model.
- In `comm_overlap.TE`, we provide an example of using TE features (communication overlap and FP8) for inference; a generic sketch of the overlap idea follows this list.
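
For background, the communication overlap targeted here means launching tensor-parallel collectives asynchronously so they run concurrently with independent compute instead of strictly after it. The snippet below is not sglang or TE code, just a minimal generic illustration of that idea with `torch.distributed`, assuming an NCCL process group has already been initialized:

```python
import torch
import torch.distributed as dist


def overlapped_step(x, w_a, w_b):
    """Overlap an all-reduce of one partial result with an independent matmul.

    Generic illustration only: assumes dist.init_process_group("nccl") has been
    called and that x, w_a, w_b already live on the current CUDA device.
    """
    partial = x @ w_a  # tensor-parallel partial result that must be reduced
    handle = dist.all_reduce(partial, async_op=True)  # start communication
    independent = x @ w_b  # compute that does not depend on the reduction
    handle.wait()  # block only once the reduced value is actually needed
    return partial + independent
```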

#### Benchmark

To run the benchmark, it is recommended to use `sglang.bench_one_batch`:

native pytorch:
```bash
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 16 64 128 --input-len 256 512 --output-len 32 256 --run-name test_run --tp 4
```

TE:
```bash
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 16 64 128 --input-len 256 512 --output-len 32 256 --run-name test_run --tp 4 --enable-te
```

`llama_latency_figure_visual.py` visualizes the prefill and decode latency of both the torch and TE models. Please replace the hard-coded data with the numbers from the `result.jsonl` generated by the benchmark script. The other `.jsonl` files were generated by older benchmark scripts via `sglang.benchmark_server_latency`.
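
If you prefer to pull the numbers programmatically rather than hard-coding them, a small loader along the lines below can feed the plotting script; the field names (`batch_size`, `input_len`, `output_len`, `prefill_latency`, `median_decode_latency`) are assumptions about the `result.jsonl` schema and may need adjusting to the keys actually written by `sglang.bench_one_batch`:

```python
import json


def load_results(path="result.jsonl"):
    """Group benchmark records by (input_len, output_len) configuration.

    The key names below are assumed, not guaranteed; check one line of the
    jsonl produced by sglang.bench_one_batch and rename them accordingly.
    """
    by_config = {}
    with open(path) as f:
        for line in f:
            rec = json.loads(line)
            key = (rec["input_len"], rec["output_len"])
            by_config.setdefault(key, []).append(
                (rec["batch_size"], rec["prefill_latency"], rec["median_decode_latency"])
            )
    return by_config
```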

#### TODO

- [ ] Evaluate TE with different configurations and quantify the performance gain.
- [ ] Add support for more models
- [ ] Add FP8 support

#### Reference

- We use NVIDIA [TE](https://github.com/NVIDIA/TransformerEngine/tree/main) (v1.14.0.dev0+994f19d) APIs to build the model.
69 changes: 69 additions & 0 deletions comm_overlap/TE/example_pytorch.py
@@ -0,0 +1,69 @@
import statistics
import time

import torch
import torch.nn as nn
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
num_iterations = 100

# Initialize TE model and inputs.
te_model = te.Linear(in_features, out_features, bias=True).cuda()
# Initialize Torch model and inputs
torch_model = nn.Linear(in_features, out_features, bias=True).cuda()

with torch.no_grad():
    torch_model.weight.copy_(te_model.weight)
    torch_model.bias.copy_(te_model.bias)

inp = torch.randn(hidden_size, in_features, device="cuda")

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

# warmup GPUs
print("Warm up...")
for _ in range(10):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        _ = te_model(inp)
    _ = torch_model(inp)

# test torch performance
torch_times = []
torch.cuda.synchronize()
print("testing torch performance")
for _ in range(num_iterations):
    start = time.perf_counter()
    out = torch_model(inp)
    torch.cuda.synchronize()
    torch_times.append(time.perf_counter() - start)

# test te performance
te_times = []
torch.cuda.synchronize()
print("testing te performance")
for _ in range(num_iterations):
    start = time.perf_counter()
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = te_model(inp)
    torch.cuda.synchronize()
    te_times.append(time.perf_counter() - start)

te_mean = statistics.mean(te_times) * 1000 # convert to ms
te_std = statistics.stdev(te_times) * 1000
torch_mean = statistics.mean(torch_times) * 1000
torch_std = statistics.stdev(torch_times) * 1000


print("\nResults (in milliseconds):")
print(f"TransformerEngine: {te_mean:.3f} ± {te_std:.3f} ms")
print(f"PyTorch: {torch_mean:.3f} ± {torch_std:.3f} ms")
print(f"Speedup: {torch_mean/te_mean:.2f}x")

# loss = out.sum()
# loss.backward()
83 changes: 83 additions & 0 deletions comm_overlap/benchmark/llama_latency_figure_visual.py
@@ -0,0 +1,83 @@
import matplotlib.pyplot as plt
import numpy as np

batch_sizes = [1, 16, 64, 128]
configs = [(256, 32), (256, 256), (512, 32), (512, 256)]

torch_prefill = [
[0.02312, 0.04599, 0.46298, 0.29263], # 256/32
[0.02305, 0.04196, 0.14980, 0.29213], # 256/256
[0.02562, 0.07940, 0.28881, 0.59434], # 512/32
[0.02357, 0.07765, 0.28795, 0.57709], # 512/256
]

torch_decode = [
[0.00382, 0.00428, 0.00527, 0.00589], # 256/32
[0.00381, 0.00427, 0.00529, 0.00595], # 256/256
[0.00382, 0.00432, 0.00546, 0.00617], # 512/32
[0.00381, 0.00432, 0.00541, 0.00631], # 512/256
]

te_prefill = [
[0.02251, 0.04727, 0.14864, 0.28941], # 256/32
[0.02318, 0.04203, 0.14796, 0.29031], # 256/256
[0.19420, 0.27644, 0.28941, 0.94313], # 512/32
[0.02361, 0.07754, 0.28755, 0.57659], # 512/256
]

te_decode = [
[0.00381, 0.00427, 0.00528, 0.00589], # 256/32
[0.00380, 0.00427, 0.00529, 0.00593], # 256/256
[0.00382, 0.00432, 0.00546, 0.00618], # 512/32
[0.00381, 0.00433, 0.00542, 0.00629], # 512/256
]

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 15))

x = np.arange(len(batch_sizes))
width = 0.1

for i, (input_len, output_len) in enumerate(configs):
    ax1.bar(
        x + i * width * 2,
        torch_prefill[i],
        width,
        label=f"Torch-{input_len}/{output_len}",
        alpha=0.8,
    )
    ax1.bar(
        x + i * width * 2 + width,
        te_prefill[i],
        width,
        label=f"TE-{input_len}/{output_len}",
        alpha=0.8,
    )

for i, (input_len, output_len) in enumerate(configs):
    ax2.bar(
        x + i * width * 2,
        torch_decode[i],
        width,
        label=f"Torch-{input_len}/{output_len}",
        alpha=0.8,
    )
    ax2.bar(
        x + i * width * 2 + width,
        te_decode[i],
        width,
        label=f"TE-{input_len}/{output_len}",
        alpha=0.8,
    )

for ax, title in [(ax1, "Prefill Latency"), (ax2, "Median Decode Latency")]:
    ax.set_xlabel("Batch Size")
    ax.set_ylabel("Latency (s)")
    ax.set_title(f"Torch vs TE {title} Comparison")
    ax.set_xticks(x + width * 3.5)
    ax.set_xticklabels(batch_sizes)
    ax.grid(True, linestyle="--", alpha=0.7)
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

plt.tight_layout()

plt.savefig("latency_comparison_pd.png", bbox_inches="tight")
94 changes: 94 additions & 0 deletions comm_overlap/benchmark/llama_te.py
@@ -0,0 +1,94 @@
import atexit
import json
import os
import signal
import statistics
import subprocess
import time

import requests


def start_server():
    print("start server...")
    server_cmd = (
        "python -m sglang.launch_server "
        "--model-path meta-llama/Meta-Llama-3.1-8B-Instruct "
        "--port 30001 --host 0.0.0.0 "
        "--enable-te --tp 4"
    )

    # Launch the server from the sglang/python directory, then restore the cwd.
    original_dir = os.getcwd()
    server_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "python"
    )
    os.chdir(server_dir)

    server_process = subprocess.Popen(server_cmd.split())

    os.chdir(original_dir)

    # Make sure the server is killed when this script exits.
    atexit.register(lambda: os.kill(server_process.pid, signal.SIGTERM))

    print("waiting for server to start...")
    time.sleep(30)
    return server_process


def test_latency(num_requests=50):
    url = "http://localhost:30001/generate"
    headers = {"Content-Type": "application/json"}

    max_new_tokens = 100

    payload = {
        "text": "The president of USA,",
        "sampling_params": {"max_new_tokens": max_new_tokens, "temperature": 0},
    }

    latencies = []

    print(f"start to test latency for {num_requests} requests...")

    for i in range(num_requests):
        start_time = time.time()
        try:
            response = requests.post(url, headers=headers, json=payload)
            if response.status_code == 200:
                end_time = time.time()
                latency = (end_time - start_time) * 1000  # convert to milliseconds
                latencies.append(latency)
                print(f"request {i+1}: {latency:.2f}ms")
                print(f"response: {response.json()}")
            else:
                print(f"request {i+1} failed: HTTP {response.status_code}")
                print(f"error: {response.text}")
        except Exception as e:
            print(f"request {i+1} failed: {str(e)}")

        time.sleep(0.5)

    if latencies:
        results = {
            "min latency": f"{min(latencies):.2f}ms",
            "max latency": f"{max(latencies):.2f}ms",
            "mean latency": f"{statistics.mean(latencies):.2f}ms",
            "median latency": f"{statistics.median(latencies):.2f}ms",
        }

        # f-string so the configured max_new_tokens ends up in the filename
        with open(f"latency_results_te_{max_new_tokens}.json", "w") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print("\ntest results:")
        for key, value in results.items():
            print(f"{key}: {value}")


if __name__ == "__main__":
    server_process = start_server()
    try:
        test_latency()
    finally:
        print("closing server...")
        server_process.terminate()
        server_process.wait()
92 changes: 92 additions & 0 deletions comm_overlap/benchmark/llama_torch.py
@@ -0,0 +1,92 @@
import atexit
import json
import os
import signal
import statistics
import subprocess
import time

import requests


def start_server():
    print("starting server...")
    server_cmd = (
        "python -m sglang.launch_server "
        "--model-path meta-llama/Meta-Llama-3.1-8B-Instruct "
        "--port 30001 --host 0.0.0.0 "
        "--tp 4"
    )

    # Launch the server from the sglang/python directory, then restore the cwd.
    original_dir = os.getcwd()
    server_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "python"
    )
    os.chdir(server_dir)

    server_process = subprocess.Popen(server_cmd.split())

    os.chdir(original_dir)

    # Make sure the server is killed when this script exits.
    atexit.register(lambda: os.kill(server_process.pid, signal.SIGTERM))

    print("waiting for server to start...")
    time.sleep(30)
    return server_process


def test_latency(num_requests=50):
    url = "http://localhost:30001/generate"
    headers = {"Content-Type": "application/json"}

    max_new_tokens = 100

    payload = {
        "text": "The president of USA,",
        "sampling_params": {"max_new_tokens": max_new_tokens, "temperature": 0},
    }

    latencies = []

    print(f"start to test latency for {num_requests} requests...")

    for i in range(num_requests):
        start_time = time.time()
        try:
            response = requests.post(url, headers=headers, json=payload)
            if response.status_code == 200:
                end_time = time.time()
                latency = (end_time - start_time) * 1000  # convert to milliseconds
                latencies.append(latency)
                print(f"request {i+1}: {latency:.2f}ms")
                print(f"response: {response.json()}")
            else:
                print(f"request {i+1} failed: HTTP {response.status_code}")
                print(f"error: {response.text}")
        except Exception as e:
            print(f"request {i+1} failed: {str(e)}")

        time.sleep(0.5)

    if latencies:
        results = {
            "min latency": f"{min(latencies):.2f}ms",
            "max latency": f"{max(latencies):.2f}ms",
            "mean latency": f"{statistics.mean(latencies):.2f}ms",
            "median latency": f"{statistics.median(latencies):.2f}ms",
        }

        # f-string so the configured max_new_tokens ends up in the filename
        with open(f"latency_results_torch_{max_new_tokens}.json", "w") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print("\ntest results:")
        for key, value in results.items():
            print(f"{key}: {value}")


if __name__ == "__main__":
    server_process = start_server()
    try:
        test_latency()
    finally:
        print("closing server...")
        server_process.terminate()
        server_process.wait()