[Cache Offload] Improve radix cache offload benchmark #2534

Draft · wants to merge 15 commits into base: xiezhq-hierarchical

11 changes: 11 additions & 0 deletions benchmark/hicache/README.md
@@ -0,0 +1,11 @@
## Run benchmark

### Benchmark SGLang with Radix Cache Offload
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --port 30000 --tensor-parallel-size 4 --enable-hierarchical-cache
```

```
python3 bench_sglang.py --num_groups 100 --group_size 100 --context_length 1000 --cache_rate 0.8
```
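
To benchmark different prompt issue orders, the accompanying `bench_order.py` script can be used. A sketch of a possible invocation (the `--order` and related flags come from the script's argument parser; `--port` is assumed to be supplied by the common SGLang client arguments):
```
python3 bench_order.py --num_groups 100 --group_size 100 --context_length 1000 \
    --cache_rate 0.8 --order interleaved --port 30000
```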

206 changes: 206 additions & 0 deletions benchmark/hicache/bench_order.py
@@ -0,0 +1,206 @@
import argparse
import json
import random
import time

import requests
from lorem_text import lorem
from tqdm import tqdm

import sglang as sgl
from sglang import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)

# TODO: To avoid too much unintended tokenizer splitting,
# sample more realistic text from datasets like alpaca, or sample from an nltk corpus?
try:
import nltk

# use brown corpus for generating text
try:
nltk.data.find("corpora/brown")
except LookupError:
nltk.download("brown")
sentences = nltk.corpus.brown.sents()
except ImportError:
sentences = None


def generate_text_nltk(num_tokens):
"""Generates a text with approximately num_tokens."""
num_words = int(num_tokens / 1.93)
text = []
while len(text) < num_words:
sentence = random.choice(sentences)
text.extend(sentence)
return " ".join(text[: int(num_words)])


def generate_text_lorem(num_tokens):
"""Generates a text with approximately num_tokens."""
num_words = int(num_tokens / 1.93) # Assuming average word length
return lorem.words(num_words)


def generate_text(num_tokens, method="nltk"):
if method == "nltk":
if sentences is None:
raise ImportError("Please install nltk to sample from its corpus.")
return generate_text_nltk(num_tokens)
elif method == "lorem":
return generate_text_lorem(num_tokens)
else:
raise ValueError(f"Invalid method: {method}")


def generate_prompts(
num_groups=100,
group_size=100,
context_length=1000,
cache_rate=0.8,
order="random",
max_tokens=1,
method="nltk",
):
"""
Generate prompts for the benchmark.

Args:
num_groups (int): Number of groups, each with shared context.
group_size (int): Number of requests in each group.
context_length (int): Approximate total context length of each prompt, in tokens.
cache_rate (float): Fraction of the context that is shared across prompts within a group (and thus reusable from cache).
order (str): Order in which prompts are issued, one of 'random', 'sequential', or 'interleaved'.
max_tokens (int): Maximum number of tokens to generate per request.
method (str): Method used to generate text, one of 'nltk' or 'lorem'.

Returns:
list: List of generated prompts.
"""
assert order in ["random", "sequential", "interleaved"], "Invalid prompt order"
prompts = []

for _ in tqdm(range(num_groups), desc="Generating prompts"):
shared_context = generate_text(context_length * cache_rate, method)
for _ in range(group_size):
prompt = shared_context + generate_text(
context_length * (1 - cache_rate), method
)
prompts.append({"prompt": prompt, "max_tokens": max_tokens})

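# "sequential" keeps each group's requests contiguous (most favorable for reusing
# the shared prefix), "random" shuffles all requests globally, and "interleaved"
# issues request i of every group before request i+1 of any group, which stresses
# offloading and reloading of the shared prefixes.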
if order == "random":
return random.sample(prompts, len(prompts))
elif order == "sequential":
return prompts
else: # interleaved
interleaved_prompts = [prompts[i::group_size] for i in range(group_size)]
return [item for sublist in interleaved_prompts for item in sublist]


@sgl.function
def test_sgl(s, prompt, max_tokens):
"""SGLang function for generating text based on a prompt."""
s += prompt
s += sgl.gen(max_tokens=max_tokens, ignore_eos=True)


def main(args):
# Initialize SGLang runtime
set_default_backend(select_sglang_backend(args))
result_jsonl = []

# Log current parameters
print(
f"Running with num_threads: {args.num_threads}, output_length: {args.output_length}"
)
print(f"Cache rate: {args.cache_rate}, Context length: {args.context_length}")
print(
f"Group size: {args.group_size}, Num groups: {args.num_groups}, Order: {args.order}"
)

# Generate prompts based on input arguments
prompts = generate_prompts(
num_groups=args.num_groups,
group_size=args.group_size,
context_length=args.context_length,
cache_rate=args.cache_rate,
order=args.order,
max_tokens=args.output_length,
method=args.prompt_pool,
)
print(f"Sample prompt: {prompts[0]['prompt'][:80]}...")

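# Flush the server-side cache so the benchmark starts from a cold radix cache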
url = f"http://localhost:{args.port}/flush_cache"
requests.post(url)
# sgl.flush_cache()
time.sleep(1) # Wait for the cache to be flushed

# Measure the time taken for batch execution
tic = time.time()
test_sgl.run_batch(prompts, num_threads=args.num_threads, progress_bar=True)
toc = time.time()

# Record results
duration = toc - tic
result_jsonl.append(
{
"cache_rate": args.cache_rate,
"context_length": args.context_length,
"group_size": args.group_size,
"num_groups": args.num_groups,
"order": args.order,
"output_length": args.output_length,
"duration": duration,
}
)

# Display throughput information
throughput = len(prompts) / duration
print(f"Throughput: {throughput:.2f} requests per second")

# Write the results to a JSONL file
with open("result.jsonl", "a") as f:
for line in result_jsonl:
f.write(json.dumps(line) + "\n")


if __name__ == "__main__":
# Set up command-line arguments
parser = argparse.ArgumentParser(
description="Benchmark prompt generation and SGLang execution."
)
parser.add_argument(
"--order",
type=str,
default="random",
choices=["random", "sequential", "interleaved"],
help="Order of prompt execution",
)
parser.add_argument(
"--num_groups", "-n", type=int, default=100, help="Number of prompt groups"
)
parser.add_argument(
"--group_size", "-s", type=int, default=100, help="Size of each prompt group"
)
parser.add_argument(
"--context_length", type=int, default=1000, help="Length of the context"
)
parser.add_argument(
"--cache_rate", type=float, default=0.8, help="Cache rate for shared context"
)
parser.add_argument("--output_length", type=int, default=1, help="Output length")
parser.add_argument("--num_threads", type=int, default=64, help="Number of threads")
parser.add_argument(
"--prompt_pool",
type=str,
default="nltk",
help="Method to sample prompts",
choices=["nltk", "lorem"],
)

# args = parser.parse_args()
args = add_common_sglang_args_and_parse(parser)
main(args)