diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 13071f581375..d9458bed7ed3 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import asyncio import itertools import os import uuid @@ -26,7 +27,7 @@ async def lifespan(app: FastAPI): # Create prefill clients for i, (host, port) in enumerate(global_args.prefiller_instances): - prefiller_base_url = f'http://{host}:{port}/v1' + prefiller_base_url = f'http://{host}:{port}/' app.state.prefill_clients.append({ 'client': httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), @@ -40,7 +41,7 @@ async def lifespan(app: FastAPI): # Create decode clients for i, (host, port) in enumerate(global_args.decoder_instances): - decoder_base_url = f'http://{host}:{port}/v1' + decoder_base_url = f'http://{host}:{port}/' app.state.decode_clients.append({ 'client': httpx.AsyncClient(timeout=None, base_url=decoder_base_url), @@ -206,7 +207,7 @@ async def handle_completions(request: Request): # Send request to prefill service response = await send_request_to_service(prefill_client_info, - "/completions", req_data, + "v1/completions", req_data, request_id) # Extract the needed fields @@ -223,7 +224,7 @@ async def handle_completions(request: Request): # Stream response from decode service async def generate_stream(): async for chunk in stream_service_response(decode_client_info, - "/completions", + "v1/completions", req_data, request_id=request_id): yield chunk @@ -252,6 +253,60 @@ async def healthcheck(): } +async def send_profile_cmd(request: Request, req_data, profiler_cmd): + assert profiler_cmd in ["start", "stop"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + # Send to all prefiller and decoder, leaving iterator in same state. + tasks = [] + for _ in range(len(app.state.prefill_clients)): + for client in ['prefill', 'decode']: + client_info = get_next_client(request.app, client) + + tasks.append(client_info['client'].post(f"/{profiler_cmd}_profile", + json=req_data, + headers=headers)) + + responses = await asyncio.gather(*tasks) + for r in responses: + r.raise_for_status() + + return responses[0].json() + + +@app.post("/start_profile") +async def start_profile(request: Request): + try: + req_data = await request.json() + return await send_profile_cmd(request, req_data, "start") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - start_profile endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +@app.post("/stop_profile") +async def stop_profile(request: Request): + try: + req_data = await request.json() + return await send_profile_cmd(request, req_data, "stop") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - stop_profile endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + if __name__ == '__main__': global global_args global_args = parse_args()