update benchmarking script and add request status code
nwangfw committed Feb 24, 2025
1 parent 9179510 commit 51be6fd
Showing 3 changed files with 102 additions and 56 deletions.
File 1 of 3: benchmarking shell script
@@ -41,13 +41,27 @@ generate_workload() {
local output_len=$2
local api_key=$3
local num_prompts=$4

local model=$5

echo " input_len: $input_len"
echo " output_len: $output_len"
echo " api_key: $api_key"
echo " num_prompts: $num_prompts"
echo " model: $model"
echo "Generating workload for input=$input_len, output=$output_len, API_KEY=$api_key, num_prompts=$num_prompts, model=$model"

echo "python $PATH_PREFIX/gen_benchmark_prompt.py \
$workload \
--input-tokens \"$input_len\" \
--min-output-tokens \"$output_len\" \
--tolerance \"0.2\" \
--qps \"2.0\" \
--host \"localhost\" \
--port \"8010\" \
--api-key \"$api_key\" \
--total-prompts \"$num_prompts\" \
--model \"$model\""

echo "Generating workload for input=$input_len, output=$output_len, API_KEY=$api_key, num_prompts=$num_prompts"
python $PATH_PREFIX/gen_benchmark_prompt.py \
$workload \
--input-tokens "$input_len" \
@@ -57,7 +71,8 @@ generate_workload() {
--host "localhost" \
--port "8010" \
--api-key "$api_key" \
--total-prompts "$num_prompts"
--total-prompts "$num_prompts" \
--model "$model"
}
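The call assembled by generate_workload now carries the model name as well. Below is a minimal Python sketch of the same invocation, mirroring the flag set printed by the echo above; the wrapper function itself is hypothetical, and PATH_PREFIX, the workload file, and the fixed tolerance/qps/host/port values simply restate the shell variables.

import subprocess

def generate_workload(path_prefix: str, workload: str, input_len: int,
                      output_len: int, api_key: str, num_prompts: int,
                      model: str) -> None:
    # Mirrors the shell function's positional contract:
    # input_len, output_len, api_key, num_prompts, model (model is new here).
    cmd = [
        "python", f"{path_prefix}/gen_benchmark_prompt.py", workload,
        "--input-tokens", str(input_len),
        "--min-output-tokens", str(output_len),
        "--tolerance", "0.2",
        "--qps", "2.0",
        "--host", "localhost",
        "--port", "8010",
        "--api-key", api_key,
        "--total-prompts", str(num_prompts),
        "--model", model,  # flag added by this commit
    ]
    subprocess.run(cmd, check=True)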

while [[ $# -gt 0 ]]; do
@@ -108,6 +123,7 @@ while [[ $# -gt 0 ]]; do
;;
--api-key)
LLM_API_KEY=$2
debug_print "Set LLM_API_KEY to: $LLM_API_KEY"
shift 2
;;
--workload)
@@ -160,7 +176,7 @@ while [[ $input_len -le $input_limit ]]; do
output_len=$output_start
while [[ $output_len -le $output_limit ]]; do
# Make sure all arguments are passed in the correct order
generate_workload "$input_len" "$output_len" "$LLM_API_KEY" "$TOTAL"
generate_workload "$input_len" "$output_len" "$LLM_API_KEY" "$TOTAL" "$MODEL"

# Convert rate_start to integer (multiply by 100 and remove decimals)
req_rate=$(echo "$rate_start * 100" | bc | cut -d. -f1)
@@ -170,7 +186,7 @@ while [[ $input_len -le $input_limit ]]; do

WORKLOAD_FILE="$PROMPT_DIR/prompt_in${input_len}_out${output_len}.json"
if [[ -f "$WORKLOAD_FILE" ]]; then
python $PATH_PREFIX/gpu_benchmark.py --backend=vllm --port 8010 --model=$MODEL --request-rate=$actual_rate --num-prompts=$TOTAL --input-len $input_len --output-len $output_len --api-key "$LLM_API_KEY" --stream --workload_dataset_file $WORKLOAD_FILE >> "$OUTPUT_FILE"
python $PATH_PREFIX/gpu_benchmark.py --backend=vllm --port 8010 --model=$MODEL --request-rate=$actual_rate --num-prompts=$TOTAL --input-len $input_len --output-len $output_len --api-key "$LLM_API_KEY" --workload_dataset_file $WORKLOAD_FILE >> "$OUTPUT_FILE"
fi
req_rate=$((req_rate * 2))
done
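The rate sweep avoids floating-point arithmetic in bash by scaling the starting rate by 100 with bc, doubling the resulting integer on each pass, and converting back to a per-second rate before invoking gpu_benchmark.py. A minimal sketch of that arithmetic, with the parts not shown in this hunk (the loop bound and the conversion back to actual_rate) stated as assumptions:

def rate_sweep(rate_start: float, rate_limit: float):
    # Mirrors `req_rate=$(echo "$rate_start * 100" | bc | cut -d. -f1)`:
    # scale by 100 and truncate so the loop can double an integer.
    req_rate = int(rate_start * 100)
    while req_rate <= int(rate_limit * 100):  # assumed loop bound
        actual_rate = req_rate / 100          # assumed conversion back to req/s
        yield actual_rate
        req_rate *= 2                         # mirrors `req_rate=$((req_rate * 2))`

# Example: a sweep starting at 0.5 req/s with an assumed limit of 8 req/s
# visits 0.5, 1.0, 2.0, 4.0, 8.0.
print(list(rate_sweep(0.5, 8)))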
File 2 of 3: gen_benchmark_prompt.py
@@ -39,12 +39,14 @@ def wait(self):
class PromptSelector:
def __init__(self, trace_file: str,
model_endpoint: str = "http://localhost:8888/v1/chat/completions",
model: str = "deepseek-coder-7b",
qps: float = 2.0,
temperature: float = 0.0,
api_key: str = "any_key",
total_prompts: int = 1):
self.trace_file = trace_file
self.model_endpoint = model_endpoint
self.model = model
self.tokenizer = get_tokenizer("", False)
self.rate_limiter = RateLimiter(qps)
self.temperature = temperature
@@ -54,7 +56,7 @@ def count_tokens(self, text: str) -> int:
"""Estimate token count using VLLM's tokenizer."""
return len(self.tokenizer.encode(text))

def get_completion_tokens(self, prompt: str, model: str = "deepseek-coder-33b-instruct") -> Tuple[Optional[int], Dict]:
def get_completion_tokens(self, prompt: str) -> Tuple[Optional[int], Dict]:
"""Get actual completion tokens by querying the model with rate limiting."""
self.rate_limiter.wait()

@@ -64,7 +66,7 @@ def get_completion_tokens(self, prompt: str, model: str = "deepseek-coder-33b-in
}

data = {
"model": model,
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"temperature": self.temperature
}
@@ -204,8 +206,9 @@ def parse_args():
help='API key for model access (default: any_key)')
parser.add_argument('--total-prompts', type=int, default=1,
help='Number of prompts to generate (default: 1)')
return parser.parse_args()

parser.add_argument('--model', type=str, default='deepseek-coder-7b',
help='Model name to use for completion')
return parser.parse_args()
def main():
args = parse_args()
start_time = time.time()
@@ -227,6 +230,7 @@ def main():
trace_file=args.workload_dataset_file,
model_endpoint=model_endpoint,
qps=args.qps,
model=args.model,
temperature=args.temperature,
api_key=args.api_key,
total_prompts=args.total_prompts
Expand Down
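With the new --model flag, the model name flows from the CLI through PromptSelector into the chat-completions payload rather than the hard-coded default that get_completion_tokens previously carried. A minimal sketch of the request that results; the payload shape restates the diff above, while the headers and the POST call itself are assumptions about code outside this hunk, shown with requests for illustration only.

import requests

def get_completion_tokens_sketch(endpoint: str, api_key: str, model: str,
                                 prompt: str, temperature: float = 0.0) -> dict:
    # Assumed headers; only the payload dict is shown in the diff above.
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    data = {
        "model": model,  # taken from --model rather than a hard-coded name
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
    }
    resp = requests.post(endpoint, headers=headers, json=data, timeout=60)
    resp.raise_for_status()
    return resp.json()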
File 3 of 3: gpu_benchmark.py
@@ -169,6 +169,7 @@ async def send_request(
}
if api_key is not None or api_key != "":
headers["Authorization"] = f"Bearer {api_key}"

streaming = stream
if backend == "vllm":
pload = {
@@ -193,71 +194,96 @@ async def send_request(
request_start_time = time.perf_counter()
ts = datetime.now(timezone.utc)
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout) as session:
while True:
# print(f"Sending request: {api_url}:{pload}")
async with session.post(api_url, headers=headers, json=pload) as response:
chunks = []
token_latencies = []
previous_token_time = time.perf_counter()
first = True
status_code = None
error_msg = None
token_latencies = [] # Initialize here
time_to_first = 0.0 # Initialize here

try:
async with aiohttp.ClientSession(timeout=timeout) as session:
while True:
async with session.post(api_url, headers=headers, json=pload) as response:
status_code = response.status
response_status = "success" if status_code == 200 else "failed"

# Capture error response for non-200 status codes
if status_code != 200:
error_response = await response.text()
try:
error_json = json.loads(error_response)
error_msg = error_json.get('error', error_response)
except:
error_msg = error_response
print_err(f"Request {idx} failed with status {status_code}: {error_msg}")
break

chunks = []
previous_token_time = time.perf_counter()
first = True

try:
if streaming:
async for chunk, _ in response.content.iter_chunks():
chunks = [chunk]
now_time = time.perf_counter()
if first:
time_to_first = now_time - previous_token_time
first = False
else:
token_latencies.append(now_time - previous_token_time)
previous_token_time = now_time
# Stream off: Chunks are full response.
# chunks.append(chunk)

output = b"".join(chunks).decode("utf-8")
santicized = output.rstrip("\n\t ")
else:
time_to_first = time.perf_counter() - previous_token_time
output = await response.text()
santicized = output
except Exception as e:
error_msg = f"Failed to read response: {str(e)}"
print_err(f"Failed to read response for request {idx}: {e}")
break

try:
if streaming:
async for chunk, _ in response.content.iter_chunks():
# Stream on: Each chunk in the response is the full response so far
chunks = [chunk]

now_time = time.perf_counter()
if first:
time_to_first = now_time - previous_token_time
first = False
else:
token_latencies.append(now_time - previous_token_time)
previous_token_time = now_time

# Stream off: Chunks are full response.
# chunks.append(chunk)

output = b"".join(chunks).decode("utf-8")
santicized = output.rstrip(
"\n\t "
) # Remove trailing whitespace characters including EOF, and "[DONE]"
else:
time_to_first = time.perf_counter() - previous_token_time
output = await response.text()
santicized = output
ret = load_response(santicized)
if "error" not in ret:
break
error_msg = f"API error: {ret.get('error', 'Unknown error')}"
except Exception as e:
print_err(f"Failed to read response for request {idx}: {e}")
break
try:
ret = load_response(santicized)

# Re-send the request if it failed.
if "error" not in ret:
error_msg = f"Failed to parse response: {str(e)}"
print_err(f"Invalid response for request {idx}: {santicized}: {e}")
break
except Exception as e:
# It's ok to parse failure, santicized output could be jsonl, other format, or internal error.
print_err(f"Invalid response for request {idx}: {santicized}: {e}")
break
except Exception as e:
# It's ok to parse failure, santicized output could be jsonl, other format, or internal error.
print_err(f"Invalid response for request {idx}: {santicized}: {e}")
return

request_end_time = time.perf_counter()
request_latency = request_end_time - request_start_time

if trace:
request_trace = {
"request_id": idx,
"input_tokens": prompt_len,
"output_tokens": output_len
if len(token_latencies) == 0
else len(token_latencies) + 1,
"timestamp": ts.strftime("%Y-%m-%d %H:%M:%S %Z%z"),
"E2E": request_latency,
"status_code": status_code,
"success": status_code == 200 if status_code else False,
# "request_payload": pload
}
if error_msg:
request_trace["error"] = error_msg
if len(token_latencies) > 0:
request_trace["TTFT"] = time_to_first
request_trace["TPOT_mean"] = np.mean(token_latencies) # type: ignore
request_trace["TPOT_P50"] = np.percentile(token_latencies, 50) # type: ignore
request_trace["TPOT_P90"] = np.percentile(token_latencies, 90) # type: ignore
request_trace["TPOT_P99"] = np.percentile(token_latencies, 99) # type: ignore
request_trace["TPOT_mean"] = np.mean(token_latencies)
request_trace["TPOT_P50"] = np.percentile(token_latencies, 50)
request_trace["TPOT_P90"] = np.percentile(token_latencies, 90)
request_trace["TPOT_P99"] = np.percentile(token_latencies, 99)
print(json.dumps(request_trace))
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
if len(token_latencies) > 0:
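After this change, when tracing is enabled every request, successful or not, emits one JSON trace line carrying the HTTP status code, a success flag, and an error message when one was captured; TTFT/TPOT statistics are attached only when token latencies were collected. An illustrative example of one such line follows; all numeric values are made up.

import json

example_trace = {
    "request_id": 0,
    "input_tokens": 512,
    "output_tokens": 128,
    "timestamp": "2025-02-24 00:00:00 UTC+0000",
    "E2E": 1.84,            # end-to-end latency in seconds
    "status_code": 200,
    "success": True,
    # "error": "...",       # present only when the request failed
    "TTFT": 0.21,           # attached only when token latencies were collected
    "TPOT_mean": 0.012,
    "TPOT_P50": 0.011,
    "TPOT_P90": 0.015,
    "TPOT_P99": 0.019,
}
print(json.dumps(example_trace))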
