Skip to content

[P/D][Misc] Enable profiling in disagg setup #18827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import asyncio
import itertools
import os
import uuid
Expand All @@ -26,7 +27,7 @@ async def lifespan(app: FastAPI):

# Create prefill clients
for i, (host, port) in enumerate(global_args.prefiller_instances):
prefiller_base_url = f'http://{host}:{port}/v1'
prefiller_base_url = f'http://{host}:{port}/'
app.state.prefill_clients.append({
'client':
httpx.AsyncClient(timeout=None, base_url=prefiller_base_url),
Expand All @@ -40,7 +41,7 @@ async def lifespan(app: FastAPI):

# Create decode clients
for i, (host, port) in enumerate(global_args.decoder_instances):
decoder_base_url = f'http://{host}:{port}/v1'
decoder_base_url = f'http://{host}:{port}/'
app.state.decode_clients.append({
'client':
httpx.AsyncClient(timeout=None, base_url=decoder_base_url),
Expand Down Expand Up @@ -206,7 +207,7 @@ async def handle_completions(request: Request):

# Send request to prefill service
response = await send_request_to_service(prefill_client_info,
"/completions", req_data,
"v1/completions", req_data,
request_id)

# Extract the needed fields
Expand All @@ -223,7 +224,7 @@ async def handle_completions(request: Request):
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(decode_client_info,
"/completions",
"v1/completions",
req_data,
request_id=request_id):
yield chunk
Expand Down Expand Up @@ -252,6 +253,60 @@ async def healthcheck():
}


async def send_profile_cmd(request: Request, req_data, profiler_cmd):
    """Broadcast a profiler command to every prefill and decode instance.

    POSTs ``req_data`` as JSON to ``/{profiler_cmd}_profile`` on each client
    stored in ``request.app.state.prefill_clients`` and
    ``request.app.state.decode_clients``, issuing all requests concurrently.

    Args:
        request: Incoming request; only ``request.app.state`` is read.
        req_data: Payload forwarded unchanged as the JSON body.
        profiler_cmd: Either ``"start"`` or ``"stop"``.

    Returns:
        The decoded JSON body of the first response, or ``{}`` when no
        instance is registered or that response carries an empty body.

    Raises:
        ValueError: If ``profiler_cmd`` is neither ``"start"`` nor ``"stop"``.
        Exception: Whatever ``raise_for_status()`` raises for a response
            that does not indicate success.
    """
    # Explicit check instead of ``assert`` so it still runs under ``-O``.
    if profiler_cmd not in ("start", "stop"):
        raise ValueError(
            f"profiler_cmd must be 'start' or 'stop', got {profiler_cmd!r}")

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }

    # Walk the client lists directly rather than pulling from the
    # round-robin helper: every instance is reached exactly once even when
    # the number of prefillers and decoders differs, and the request-routing
    # state is left untouched.
    state = request.app.state
    targets = [
        info['client']
        for info in (*state.prefill_clients, *state.decode_clients)
    ]

    responses = await asyncio.gather(
        *(client.post(f"/{profiler_cmd}_profile",
                      json=req_data,
                      headers=headers) for client in targets))
    for response in responses:
        response.raise_for_status()

    if not responses:
        return {}

    first = responses[0]
    # An empty body has nothing to decode; report it as an empty object
    # instead of failing after the command was already delivered.
    return first.json() if first.content else {}


@app.post("/start_profile")
async def start_profile(request: Request):
    """Handle ``POST /start_profile`` by forwarding a "start" command.

    The request's JSON body is handed to ``send_profile_cmd`` together with
    the command ``"start"``; that call's result becomes the response.

    Raises:
        Exception: Any failure while reading the body or forwarding the
            command is printed with its traceback and then re-raised.
    """
    try:
        req_data = await request.json()
        return await send_profile_cmd(request, req_data, "start")

    except Exception as e:
        import traceback

        print("Error occurred in disagg prefill proxy server"
              " - start_profile endpoint")
        print(e)
        print(traceback.format_exc())
        # Propagate the failure: falling through here would return None and
        # make a failed profiler command look like a success to the caller.
        raise


@app.post("/stop_profile")
async def stop_profile(request: Request):
    """Handle ``POST /stop_profile`` by forwarding a "stop" command.

    The request's JSON body is handed to ``send_profile_cmd`` together with
    the command ``"stop"``; that call's result becomes the response.

    Raises:
        Exception: Any failure while reading the body or forwarding the
            command is printed with its traceback and then re-raised.
    """
    try:
        req_data = await request.json()
        return await send_profile_cmd(request, req_data, "stop")

    except Exception as e:
        import traceback

        print("Error occurred in disagg prefill proxy server"
              " - stop_profile endpoint")
        print(e)
        print(traceback.format_exc())
        # Propagate the failure: falling through here would return None and
        # make a failed profiler command look like a success to the caller.
        raise


if __name__ == '__main__':
global global_args
global_args = parse_args()
Expand Down