Commit cb91c4e (parent 928f094)

Adding streaming client for AIbrix experiments (#676)

* add streaming client initial
* move client to separate directory
* update logging
* update readme
* update analyze script
* add goodput
* clean up
* update with routing strategies
* update user option

Signed-off-by: Le Xu <[email protected]>
Co-authored-by: Le Xu <[email protected]>

7 changed files with 415 additions and 144 deletions.
@@ -0,0 +1,35 @@
## Test the client locally

Start a vLLM server:

```shell
export API_KEY=${API_KEY}
python3 -m vllm.entrypoints.openai.api_server --host 0.0.0.0 \
    --port "8000" \
    --model /root/models/deepseek-llm-7b-chat \
    --trust-remote-code \
    --max-model-len "4096" \
    --api-key ${API_KEY} \
    --enable-chunked-prefill
```
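
Optionally, sanity-check that the server is up before starting the benchmark. This assumes the standard OpenAI-compatible `/v1/models` route served by vLLM:

```shell
# Should list the model configured above
curl http://localhost:8000/v1/models \
    -H "Authorization: Bearer ${API_KEY}"
```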

Run the client against a sample workload generated by [the workload generator](../generator/README.md).
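
For reference, `client.py` expects each line of the workload file to be a JSON object with a `timestamp` in milliseconds (relative to the start of the run) and a list of `requests`, each carrying a `prompt`. This is inferred from how `client.py` consumes the workload; the generator's README is the authoritative reference. A hypothetical line might look like:

```json
{"timestamp": 1000, "requests": [{"prompt": "Explain the difference between TTFT and TPOT."}]}
```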

Turn on `--streaming` to collect fine-grained metrics such as `TTFT` (time to first token) and `TPOT` (time per output token):

```shell
export API_KEY=${API_KEY}
python3 client.py \
    --workload-path "../generator/output/constant.jsonl" \
    --endpoint "http://localhost:8000" \
    --model /root/models/deepseek-llm-7b-chat \
    --api-key ${API_KEY} \
    --streaming \
    --output-file-path output.jsonl
```
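
The client also accepts an optional `--routing-strategy` flag, which it attaches to each request as a `routing-strategy` header (see `client.py`). The accepted values depend on the gateway or router you are benchmarking against; `random` below is only an illustration:

```shell
python3 client.py \
    --workload-path "../generator/output/constant.jsonl" \
    --endpoint "http://localhost:8000" \
    --model /root/models/deepseek-llm-7b-chat \
    --api-key ${API_KEY} \
    --streaming \
    --routing-strategy "random" \
    --output-file-path output.jsonl
```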

The results are stored as JSON Lines in `output.jsonl`, one record per request.
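
Each record carries the fields written by the client (latency, throughput, token counts, and, in streaming mode, `ttft` and `tpot`). The values below are purely illustrative, and the exact shape of `input` depends on `wrap_prompt_as_chat_message`:

```json
{"input": [{"role": "user", "content": "..."}], "output": "...", "prompt_tokens": 32, "output_tokens": 256, "total_tokens": 288, "latency": 3.21, "throughput": 79.75, "start_time": 1739927301.12, "current_time": 1739927304.33, "ttft": 0.18, "tpot": 0.0118}
```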

Run the analysis on the collected metrics. For the streaming client, you can also specify a goodput target (`e2e`, `ttft`, or `tpot`) like the following:

```shell
python analyze.py --trace output.jsonl --output output --goodput-target tpot:0.5
```
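
The goodput target has the form `metric:threshold_in_seconds`, where the metric is one of `e2e`, `ttft`, or `tpot`; `analyze.py` reports the fraction of requests whose value for that metric is at or below the threshold. The thresholds below are arbitrary examples:

```shell
# Fraction of requests with end-to-end latency <= 5 seconds
python analyze.py --trace output.jsonl --output output --goodput-target e2e:5

# Fraction of requests with time-to-first-token <= 0.2 seconds
python analyze.py --trace output.jsonl --output output --goodput-target ttft:0.2
```

The script also saves a time-series plot of the collected metrics as `performance_metrics_time_series.pdf` under the directory passed to `--output`.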
@@ -0,0 +1,133 @@
import json
import argparse
import os
import re
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


def parse_goodput_target(goodput_target):
    pattern = r'^(e2e|tpot|ttft):(-?\d+(\.\d+)?)$'
    match = re.match(pattern, goodput_target)

    if match:
        metric = match.group(1)
        threshold = float(match.group(2))  # Convert to float
    else:
        raise ValueError(f"Invalid goodput spec: {goodput_target}")
    return metric, threshold


def main(args):
    input_file = args.trace
    output_path = args.output
    data = []
    with open(input_file, "r") as f:
        for line in f:
            data.append(json.loads(line))
    # Extract metrics
    timestamps = [item.get("start_time", f"Entry {i}") for i, item in enumerate(data)]
    prompt_tokens = [item["prompt_tokens"] for item in data]
    output_tokens = [item["output_tokens"] for item in data]
    total_tokens = [item["total_tokens"] for item in data]
    latencies = [item["latency"] for item in data]
    throughputs = [item["throughput"] for item in data]
    tokens_per_second = [item["total_tokens"] / item["latency"] for item in data]
    ttft = [item["ttft"] if "ttft" in item else 0.0 for item in data]  # Time to First Token
    tpot = [item["tpot"] if "tpot" in item else 0.0 for item in data]  # Time per Output Token

    goodput = None
    if args.goodput_target is not None:
        metric, threshold = parse_goodput_target(args.goodput_target)
        if metric == "e2e":
            goodput = len([item for item in latencies if item <= threshold]) / float(len(latencies))
        elif metric == "ttft":
            goodput = len([item for item in ttft if item <= threshold]) / float(len(ttft))
        elif metric == "tpot":
            goodput = len([item for item in tpot if item <= threshold]) / float(len(tpot))
        else:
            raise ValueError(f"Invalid goodput target: {args.goodput_target}")

    # Sort data by start_time
    sorted_indices = np.argsort(timestamps)
    timestamps = [timestamps[i] for i in sorted_indices]
    prompt_tokens = [prompt_tokens[i] for i in sorted_indices]
    output_tokens = [output_tokens[i] for i in sorted_indices]
    total_tokens = [total_tokens[i] for i in sorted_indices]
    latencies = [latencies[i] for i in sorted_indices]
    throughputs = [throughputs[i] for i in sorted_indices]
    tokens_per_second = [tokens_per_second[i] for i in sorted_indices]
    ttft = [ttft[i] for i in sorted_indices]
    tpot = [tpot[i] for i in sorted_indices]

    # Convert timestamps to pandas datetime (if timestamps are actual time values)
    try:
        timestamps = pd.to_datetime(timestamps, unit='s')
    except Exception:
        timestamps = pd.Series(timestamps)

    # Helper function to calculate statistics
    def calculate_statistics(values):
        values = sorted(values)
        avg = sum(values) / len(values)
        median = np.median(values)
        percentile_99 = np.percentile(values, 99)
        return avg, median, percentile_99

    # Calculate statistics for each metric
    stats = {
        "End-to-End Latency (s)": calculate_statistics(latencies),
        "Throughput": calculate_statistics(throughputs),
        "Tokens per Second": calculate_statistics(tokens_per_second),
        "Prompt Tokens": calculate_statistics(prompt_tokens),
        "Output Tokens": calculate_statistics(output_tokens),
        "Total Tokens": calculate_statistics(total_tokens),
        "Time to First Token (TTFT)": calculate_statistics(ttft),
        "Time per Output Token (TPOT)": calculate_statistics(tpot),
    }

    # Print statistics
    for metric, (avg, median, p99) in stats.items():
        print(f"{metric} Statistics: Average = {avg:.4f}, Median = {median:.4f}, 99th Percentile = {p99:.4f}")
    if goodput is not None:
        # Goodput is the fraction of requests meeting the target, not a rate.
        print(f"Goodput (fraction of requests meeting {args.goodput_target}): {goodput:.4f}")

    # Create a DataFrame for plotting
    df = pd.DataFrame({
        "Timestamp": timestamps,
        "Prompt Tokens": prompt_tokens,
        "Output Tokens": output_tokens,
        "Total Tokens": total_tokens,
        "End-to-End Latency (s)": latencies,
        "Throughput": throughputs,
        "Tokens per Second": tokens_per_second,
        "Time to First Token (TTFT)": ttft,
        "Time per Output Token (TPOT)": tpot,
    }).set_index("Timestamp")

    # Plot each metric in a separate subplot
    num_metrics = len(df.columns)
    fig, axes = plt.subplots(num_metrics, 1, figsize=(12, 4 * num_metrics), sharex=True)

    for ax, (column, values) in zip(axes, df.items()):
        ax.plot(df.index, values, marker='o', linestyle='-', label=column)
        ax.set_ylabel(column)
        ax.legend()
        ax.grid()

    axes[-1].set_xlabel("Time")  # Only set x-axis label for the last subplot
    plt.suptitle("Time Series Analysis of LLM Performance Metrics")
    plt.xticks(rotation=45)
    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Adjust layout to fit the title
    os.makedirs(output_path, exist_ok=True)
    plt.savefig(f"{output_path}/performance_metrics_time_series.pdf")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract and plot performance metrics from a JSONL trace')
    parser.add_argument('--trace', type=str, required=True, help='Input trace containing collected metrics.')
    parser.add_argument('--output', type=str, required=True, help='Output directory for generated plots.')
    parser.add_argument('--goodput-target', type=str, required=False, default=None, help='Goodput target in the format latency_metric:threshold_in_seconds, where latency_metric is one of e2e, ttft, or tpot.')

    args = parser.parse_args()
    main(args)
@@ -0,0 +1,222 @@
import argparse
import logging
import time
import asyncio
import openai
import json
import io
import traceback

from typing import List
from utils import (load_workload, wrap_prompt_as_chat_message)

logging.basicConfig(level=logging.INFO)


async def send_request_streaming(client: openai.AsyncOpenAI,
                                 model: str,
                                 endpoint: str,
                                 prompt: str,
                                 output_file: io.TextIOWrapper,
                                 ):
    start_time = asyncio.get_event_loop().time()
    first_response_time = None
    try:
        stream = await client.chat.completions.create(
            model=model,
            messages=prompt,
            temperature=0,
            max_tokens=2048,
            stream=True,
            stream_options={"include_usage": True},
        )
        text_chunks = []
        prompt_tokens = 0
        output_tokens = 0
        total_tokens = 0

        async for chunk in stream:
            if chunk.choices:
                if chunk.choices[0].delta.content is not None:
                    if not first_response_time:
                        first_response_time = asyncio.get_event_loop().time()
                    output_text = chunk.choices[0].delta.content
                    text_chunks.append(output_text)
            if chunk.usage:
                # With include_usage, token counts arrive on the final chunk only.
                prompt_tokens = chunk.usage.prompt_tokens
                output_tokens = chunk.usage.completion_tokens
                total_tokens = chunk.usage.total_tokens
        response = "".join(text_chunks)
        response_time = asyncio.get_event_loop().time()
        latency = response_time - start_time
        throughput = output_tokens / latency
        ttft = first_response_time - start_time
        tpot = (response_time - first_response_time) / output_tokens
        result = {
            "input": prompt,
            "output": response,
            "prompt_tokens": prompt_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
            "latency": latency,
            "throughput": throughput,
            "start_time": start_time,
            "current_time": asyncio.get_event_loop().time(),
            "ttft": ttft,
            "tpot": tpot,
        }
        logging.info(result)
        # Write result to JSONL file
        output_file.write(json.dumps(result) + "\n")
        output_file.flush()  # Ensure data is written immediately to the file
        return result
    except Exception as e:
        logging.error(f"Error sending request to {endpoint}: {str(e)}")
        traceback.print_exc()
        return None


async def benchmark_streaming(client: openai.AsyncOpenAI,
                              endpoint: str,
                              model: str,
                              load_struct: List,
                              output_file: io.TextIOWrapper):
    batch_tasks = []
    base_time = time.time()
    num_requests = 0
    for requests_dict in load_struct:
        ts = int(requests_dict["timestamp"])
        requests = requests_dict["requests"]
        cur_time = time.time()
        target_time = base_time + ts / 1000.0
        logging.warning(f"Preparing to launch {len(requests)} streaming tasks after {target_time - cur_time:.3f}s")
        if target_time > cur_time:
            await asyncio.sleep(target_time - cur_time)
        formatted_prompts = [wrap_prompt_as_chat_message(request["prompt"]) for request in requests]
        for formatted_prompt in formatted_prompts:
            task = asyncio.create_task(
                send_request_streaming(client=client,
                                       model=model,
                                       endpoint=endpoint,
                                       prompt=formatted_prompt,
                                       output_file=output_file)
            )
            batch_tasks.append(task)
        num_requests += len(requests)
    await asyncio.gather(*batch_tasks)
    logging.warning(f"All {num_requests} requests completed for deployment.")


# Asynchronous request handler
async def send_request_batch(client, model, endpoint, prompt, output_file):
    start_time = asyncio.get_event_loop().time()
    try:
        response = await client.chat.completions.create(
            model=model,
            messages=prompt,
            temperature=0,
            max_tokens=2048
        )

        latency = asyncio.get_event_loop().time() - start_time
        prompt_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        total_tokens = response.usage.total_tokens
        throughput = output_tokens / latency
        output_text = response.choices[0].message.content

        result = {
            "input": prompt,
            "output": output_text,
            "prompt_tokens": prompt_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
            "start_time": start_time,
            "current_time": asyncio.get_event_loop().time(),
            "latency": latency,
            "throughput": throughput
        }
        logging.info(result)
        # Write result to JSONL file
        output_file.write(json.dumps(result) + "\n")
        output_file.flush()  # Ensure data is written immediately to the file

        return result
    except Exception as e:
        logging.error(f"Error sending request to {endpoint}: {str(e)}")
        return None


async def benchmark_batch(client: openai.AsyncOpenAI,
                          endpoint: str,
                          model: str,
                          load_struct: List,
                          output_file: io.TextIOWrapper):
    batch_tasks = []
    base_time = time.time()
    num_requests = 0
    for requests_dict in load_struct:
        ts = int(requests_dict["timestamp"])
        requests = requests_dict["requests"]
        cur_time = time.time()
        target_time = base_time + ts / 1000.0
        logging.warning(f"Preparing to launch {len(requests)} batched tasks after {target_time - cur_time:.3f}s")
        if target_time > cur_time:
            await asyncio.sleep(target_time - cur_time)
        formatted_prompts = [wrap_prompt_as_chat_message(request["prompt"]) for request in requests]
        for formatted_prompt in formatted_prompts:
            task = asyncio.create_task(
                send_request_batch(client, model, endpoint, formatted_prompt, output_file)
            )
            batch_tasks.append(task)
        num_requests += len(requests)
    await asyncio.gather(*batch_tasks)
    logging.warning(f"All {num_requests} requests completed for deployment.")


def main(args):
    logging.info(f"Starting benchmark on endpoint {args.endpoint}")
    with open(args.output_file_path, 'w', encoding='utf-8') as output_file:
        load_struct = load_workload(args.workload_path)
        # Attach the routing header at construction time; mutating
        # client.default_headers after creation is not reliable, since it is
        # a computed property on the openai client.
        default_headers = {}
        if args.routing_strategy is not None:
            default_headers["routing-strategy"] = args.routing_strategy
        client = openai.AsyncOpenAI(
            api_key=args.api_key,
            base_url=args.endpoint + "/v1",
            default_headers=default_headers,
        )
        if not args.streaming:
            logging.info("Using batch client")
            start_time = time.time()
            asyncio.run(benchmark_batch(
                client=client,
                endpoint=args.endpoint,
                model=args.model,
                load_struct=load_struct,
                output_file=output_file,
            ))
            end_time = time.time()
            logging.info(f"Benchmark completed in {end_time - start_time:.2f} seconds")
        else:
            logging.info("Using streaming client")
            start_time = time.time()
            asyncio.run(benchmark_streaming(
                client=client,
                endpoint=args.endpoint,
                model=args.model,
                load_struct=load_struct,
                output_file=output_file,
            ))
            end_time = time.time()
            logging.info(f"Benchmark completed in {end_time - start_time:.2f} seconds")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Benchmark client for AIBrix experiments')
    parser.add_argument("--workload-path", type=str, default=None, help="File path to the workload file.")
    parser.add_argument('--endpoint', type=str, required=True)
    parser.add_argument("--model", type=str, required=True, help="Name of the model.")
    parser.add_argument("--api-key", type=str, required=True, help="API key to the service.")
    parser.add_argument('--output-file-path', type=str, default="output.jsonl")
    parser.add_argument("--streaming", action="store_true", help="Use streaming client.")
    parser.add_argument("--routing-strategy", type=str, required=False, default=None, help="Routing strategy to use.")

    args = parser.parse_args()
    main(args)