Add optional nltk sampling; add bench README
Edenzzzz committed Dec 27, 2024
1 parent b6c401d commit 0ca1128
Showing 2 changed files with 81 additions and 45 deletions.
11 changes: 11 additions & 0 deletions benchmark/hicache/README.md
@@ -0,0 +1,11 @@
## Run benchmark

### Benchmark SGLang with Radix Cache Offload
```
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 30000 --tensor-parallel-size 4 --enable-hierarchical-cache
```
Then run the benchmark client:
```
python3 bench_sglang.py --num_groups 100 --group_size 100 --context_length 1000 --cache_rate 0.8
```
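
For reference, the benchmark script below clears the server's radix cache between runs by POSTing to the `/flush_cache` endpoint. The same can be done manually, assuming the server is listening on port 30000 as launched above:
```
curl -X POST http://localhost:30000/flush_cache
```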
115 changes: 70 additions & 45 deletions benchmark/hicache/bench_order.py
@@ -5,17 +5,41 @@

import requests
from lorem_text import lorem

from tqdm import tqdm
import sglang as sgl
from sglang import RuntimeEndpoint, set_default_backend


def generate_text(num_tokens):
    """Generates a text with approximately num_tokens."""
    num_words = int(num_tokens / 1.93)  # Assuming average word length
    return lorem.words(num_words)


from sglang import set_default_backend
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)

# TODO: To avoid too much unintended tokenizer splitting,
# sample more realistic text from a dataset like alpaca, or sample from an nltk corpus?
try:
    import nltk

    # Use the Brown corpus for generating text
    try:
        nltk.data.find("corpora/brown")
    except LookupError:
        nltk.download("brown")
    sentences = nltk.corpus.brown.sents()

    def generate_text(num_tokens):
        """Generates a text with approximately num_tokens."""
        num_words = int(num_tokens / 1.93)  # Assuming a fixed tokens-per-word ratio
        text = []
        while len(text) < num_words:
            sentence = random.choice(sentences)
            text.extend(sentence)
        return " ".join(text[:num_words])

except ImportError:

    def generate_text(num_tokens):
        """Generates a text with approximately num_tokens."""
        num_words = int(num_tokens / 1.93)  # Assuming average word length
        return lorem.words(num_words)
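
Both `generate_text` variants convert the token budget to a word count with a fixed ratio (`num_tokens / 1.93`), so the actual token count depends on the tokenizer in use. A minimal sanity-check sketch, assuming the `transformers` package and access to the Llama 3.1 tokenizer (neither is part of this commit):
```
from transformers import AutoTokenizer

# Hypothetical check of the words-to-tokens heuristic against a real tokenizer.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
sample = generate_text(num_tokens=1000)
n_tokens = len(tokenizer.encode(sample, add_special_tokens=False))
print(f"requested ~1000 tokens, measured {n_tokens}")
```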


def generate_prompts(
    num_groups=100,
    group_size=100,
@@ -41,12 +65,12 @@ def generate_prompts(
    assert order in ["random", "sequential", "interleaved"], "Invalid prompt order"
    prompts = []

    for _ in range(num_groups):
    for _ in tqdm(range(num_groups), desc="Generating prompts"):
        shared_context = generate_text(context_length * cache_rate)
        for _ in range(group_size):
            prompt = shared_context + generate_text(context_length * (1 - cache_rate))
            prompts.append({"prompt": prompt, "max_tokens": max_tokens})

    if order == "random":
        return random.sample(prompts, len(prompts))
    elif order == "sequential":
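
The remaining branches of `generate_prompts` are collapsed in this hunk. As an illustration only (an assumption about the intended behavior, not the committed code), the three orders could be realized like this:
```
import random

# Hypothetical sketch: "random" shuffles everything, "sequential" keeps each
# group's prompts contiguous, and "interleaved" round-robins across groups,
# which maximizes the gap between reuses of a group's shared prefix.
def reorder(prompts, num_groups, group_size, order):
    if order == "random":
        return random.sample(prompts, len(prompts))
    if order == "sequential":
        return prompts  # groups are already contiguous from generation
    # interleaved: first prompt of every group, then second of every group, ...
    return [
        prompts[g * group_size + i]
        for i in range(group_size)
        for g in range(num_groups)
    ]
```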
@@ -63,37 +87,9 @@ def test_sgl(s, prompt, max_tokens):
    s += sgl.gen(max_tokens=max_tokens, ignore_eos=True)
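
The call site that drives `test_sgl` over the generated prompts is collapsed below; a minimal sketch of that pattern, assuming SGLang's standard `run_batch` API and the `--num_threads` argument defined in `__main__` (an assumption, not the committed code):
```
# Hypothetical driver: run every prompt through test_sgl in parallel threads.
states = test_sgl.run_batch(
    prompts,                       # list of {"prompt": ..., "max_tokens": ...} dicts
    num_threads=args.num_threads,  # worker threads issuing requests
    progress_bar=True,
)
```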


def main():
    # Set up command-line arguments
    parser = argparse.ArgumentParser(
        description="Benchmark prompt generation and SGLang execution."
    )
    parser.add_argument(
        "--order",
        type=str,
        default="random",
        choices=["random", "sequential", "interleaved"],
        help="Order of prompt execution",
    )
    parser.add_argument(
        "--num_groups", type=int, default=100, help="Number of prompt groups"
    )
    parser.add_argument(
        "--group_size", type=int, default=100, help="Size of each prompt group"
    )
    parser.add_argument(
        "--context_length", type=int, default=1000, help="Length of the context"
    )
    parser.add_argument(
        "--cache_rate", type=float, default=0.8, help="Cache rate for shared context"
    )
    parser.add_argument("--output_length", type=int, default=1, help="Output length")
    parser.add_argument("--num_threads", type=int, default=64, help="Number of threads")

    args = parser.parse_args()

def main(args):
    # Initialize SGLang runtime
    set_default_backend(RuntimeEndpoint("http://localhost:30000"))
    set_default_backend(select_sglang_backend(args))
    result_jsonl = []

    # Log current parameters
@@ -114,8 +110,9 @@ def main():
        order=args.order,
        max_tokens=args.output_length,
    )

    url = "http://localhost:30000/flush_cache"
    print(f"Sample prompt: {prompts[0]['prompt'][:80]}...")

    url = f"http://localhost:{args.port}/flush_cache"
    requests.post(url)
    # sgl.flush_cache()
    time.sleep(1)  # Wait for the cache to be flushed
@@ -150,4 +147,32 @@ def main():


if __name__ == "__main__":
    main()
    # Set up command-line arguments
    parser = argparse.ArgumentParser(
        description="Benchmark prompt generation and SGLang execution."
    )
    parser.add_argument(
        "--order",
        type=str,
        default="random",
        choices=["random", "sequential", "interleaved"],
        help="Order of prompt execution",
    )
    parser.add_argument(
        "--num_groups", "-n", type=int, default=100, help="Number of prompt groups"
    )
    parser.add_argument(
        "--group_size", "-s", type=int, default=100, help="Size of each prompt group"
    )
    parser.add_argument(
        "--context_length", type=int, default=1000, help="Length of the context"
    )
    parser.add_argument(
        "--cache_rate", type=float, default=0.8, help="Cache rate for shared context"
    )
    parser.add_argument("--output_length", type=int, default=1, help="Output length")
    parser.add_argument("--num_threads", type=int, default=64, help="Number of threads")

    # args = parser.parse_args()
    args = add_common_sglang_args_and_parse(parser)
    main(args)
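
With the common SGLang benchmark arguments merged in, an invocation matching the README might look like the following (a sketch; the exact flag set comes from `add_common_sglang_args_and_parse`, and `--port` is assumed to be among them since `args.port` is used above):
```
python3 bench_order.py --order random --num_groups 100 --group_size 100 \
    --context_length 1000 --cache_rate 0.8 --port 30000
```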
