[Experimental] Add a gRPC server for completion request #2478
python/sglang/bench_serving.py
@@ -29,6 +29,7 @@
 from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

 import aiohttp
+import grpc
 import numpy as np
 import requests
 from tqdm.asyncio import tqdm

@@ -39,6 +40,8 @@
     PreTrainedTokenizerFast,
 )

+from sglang.srt.proto import completion_pb2, completion_pb2_grpc
+
 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

 global args

@@ -386,6 +389,118 @@ async def async_request_sglang_generate(
     return output


+async def async_request_sglang_grpc(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    assert api_url.startswith("grpc://"), "gRPC URL must start with grpc://"
+
+    output = RequestFuncOutput()
+    output.prompt_len = request_func_input.prompt_len
+
+    try:
+        # Create gRPC request with same parameters as FastAPI
+        request = completion_pb2.CompletionRequest(
Review comment (referencing python/sglang/bench_serving.py, line 315 at commit 2125898): This doesn't seem to match (e.g., return logprobs or **request_func_input.extra_request_body).
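A hedged sketch of what the reviewer is comparing against: the HTTP path at the cited line builds a JSON payload roughly like the following. Only the names already visible in this diff are certain; the rest is an assumption about the payload shape.

    payload = {
        "text": request_func_input.prompt,
        "sampling_params": {
            "max_new_tokens": request_func_input.output_len,
            "ignore_eos": not args.disable_ignore_eos,
        },
        "return_logprob": ...,  # cited by the reviewer; has no counterpart in the proto request
        "stream": not args.disable_stream,
        **request_func_input.extra_request_body,  # also cited; missing from the gRPC request above
    }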
+            prompt=request_func_input.prompt,
+            temperature=0.0001,
+            top_p=1.0,
+            top_k=-1,
+            max_tokens=request_func_input.output_len,
+            ignore_eos=not args.disable_ignore_eos,
+            stream=not args.disable_stream,
+        )
+
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        ttft = 0.0
+        generated_text = ""
+
+        # Create channel and stub
+        server_addr = api_url[7:]  # Remove grpc:// prefix
+        channel = grpc.aio.insecure_channel(
+            server_addr,
+            options=[
+                ("grpc.max_send_message_length", 100 * 1024 * 1024),
+                ("grpc.max_receive_message_length", 100 * 1024 * 1024),
+            ],
+        )
+
+        try:
+            stub = completion_pb2_grpc.CompletionServiceStub(channel)
+            response_stream = stub.Complete(request)
+
+            async for response in response_stream:
+                timestamp = time.perf_counter()
+
+                # Handle streaming response similar to FastAPI
+                if ttft == 0.0:
+                    ttft = timestamp - st
+                    output.ttft = ttft
+                else:
+                    output.itl.append(timestamp - most_recent_timestamp)
+
+                most_recent_timestamp = timestamp
+
+                # Accumulate text from each response
+                if response.text:
+                    generated_text = (
+                        response.text
+                    )  # Use the latest text, as it contains the full response
+
+                # Check if this is the final response
+                if response.finished:
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = time.perf_counter() - st
+                    output.output_len = request_func_input.output_len
+                    break
+
+            # Ensure the final output values are set
+            if output.success:
+                output.generated_text = generated_text
+                output.success = True
+                output.latency = time.perf_counter() - st
+                output.output_len = request_func_input.output_len
+
+        except grpc.aio.AioRpcError as rpc_error:
+            # Get the trailing metadata
+            try:
+                metadata = await rpc_error.trailing_metadata()
+                metadata_dict = {k: v for k, v in metadata}
+            except Exception:
+                metadata_dict = {}
+
+            # Build a comprehensive error message
+            error_msg = [
+                "gRPC Error:",
+                f"Status code: {rpc_error.code().name} ({rpc_error.code().value})",
+                f"Details: {rpc_error.details()}",
+                f"Debug error string: {rpc_error.debug_error_string()}",
+            ]
+            if metadata_dict:
+                error_msg.append(f"Metadata: {metadata_dict}")
+
+            output.error = "\n".join(error_msg)
+            output.success = False
+
+        except Exception as e:
+            output.error = f"Stream error: {str(e)}\n{traceback.format_exc()}"
+            output.success = False
+
+        finally:
+            await channel.close()
+
+    except Exception as e:
+        output.error = f"Request creation error: {str(e)}\n{traceback.format_exc()}"
+        output.success = False
+
+    if pbar:
+        pbar.update(1)
+
+    return output
+
+
 async def async_request_gserver(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,

@@ -478,6 +593,7 @@ def get_dataset(args, tokenizer):

 ASYNC_REQUEST_FUNCS = {
     "sglang": async_request_sglang_generate,
+    "sglang-grpc": async_request_sglang_grpc,
     "sglang-native": async_request_sglang_generate,
     "sglang-oai": async_request_openai_completions,
     "vllm": async_request_openai_completions,

@@ -917,11 +1033,14 @@ async def limited_request_func(request_func_input, pbar): | |||
lora_name=lora_name, | ||||
extra_request_body=extra_request_body, | ||||
) | ||||
|
||||
test_output = await request_func(request_func_input=test_input) | ||||
|
||||
if not test_output.success: | ||||
error_msg = test_output.error or "Unknown error occurred" | ||||
raise ValueError( | ||||
"Initial test run failed - Please make sure benchmark arguments " | ||||
f"are correctly specified. Error: {test_output.error}" | ||||
f"Initial test run failed - Please make sure benchmark arguments " | ||||
f"are correctly specified.\nError details: {error_msg}" | ||||
) | ||||
else: | ||||
print("Initial test run completed. Starting main benchmark run...") | ||||
|
@@ -1169,6 +1288,9 @@ def run_benchmark(args_: argparse.Namespace):
         else f"http://{args.host}:{args.port}/v1/models"
     )

+    if args.backend == "sglang-grpc":
+        # use grpc address
+        api_url = f"grpc://{args.host}:{args.grpc_port}"
     if args.backend in ["sglang", "sglang-native"]:
         api_url = (
             f"{args.base_url}/generate"

@@ -1316,6 +1438,11 @@ def set_ulimit(target_soft_limit=65535):
         type=int,
         help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
     )
+    parser.add_argument(
+        "--grpc-port",
+        type=int,
+        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+    )

Review comment: Let's document that this is only for the sglang backend and is used only when --backend=sglang-grpc.
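One possible wording along the lines of the review comment (a sketch, not part of this PR):

    parser.add_argument(
        "--grpc-port",
        type=int,
        help="gRPC server port; only used by the sglang backend when --backend=sglang-grpc.",
    )

With that flag, a run against the gRPC server would then look something like (port value is an assumption): python3 -m sglang.bench_serving --backend sglang-grpc --host 127.0.0.1 --grpc-port 30000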
     parser.add_argument(
         "--dataset-name",
         type=str,

New file (gRPC completion servicer; path not shown in this view):

@@ -0,0 +1,74 @@
+import traceback
+from typing import Any, AsyncGenerator, Callable, Dict
+
+import grpc
+
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.proto import completion_pb2, completion_pb2_grpc
+
+
+class CompletionServicer(completion_pb2_grpc.CompletionServiceServicer):
+    def __init__(
+        self,
+        generate_request: Callable[
+            [GenerateReqInput], AsyncGenerator[Dict[str, Any], None]
+        ],
+    ):
+        self.generate_request = generate_request
+
+    async def Complete(
+        self,
+        request: completion_pb2.CompletionRequest,
+        context: grpc.aio.ServicerContext,
+    ) -> AsyncGenerator[completion_pb2.CompletionResponse, None]:
+        try:
+            # Convert gRPC request to internal format
+            adapted_request = GenerateReqInput(
+                text=request.prompt,
+                sampling_params={
+                    "max_new_tokens": request.max_tokens,
+                    "temperature": request.temperature,
+                    "top_p": request.top_p,
Review comment: This part is dangerous: gRPC protobuf uses a default value of 0 for unset fields, which means top_p and top_k could end up different from the intended defaults (1 and -1, respectively). I think we need some mechanism to set the default values. One potential solution is to use optional fields in the protobuf and set the correct defaults when they are not provided.
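A minimal sketch of the suggested fix, assuming the .proto declares these fields as proto3 optional so that presence is detectable via HasField (hypothetical, not part of this PR):

    # Fall back to the intended defaults when the client left the field unset.
    top_p = request.top_p if request.HasField("top_p") else 1.0
    top_k = request.top_k if request.HasField("top_k") else -1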
"top_k": request.top_k, | ||||||
"min_p": request.min_p, | ||||||
"frequency_penalty": request.frequency_penalty, | ||||||
"presence_penalty": request.presence_penalty, | ||||||
"stop": list(request.stop), | ||||||
"ignore_eos": request.ignore_eos, | ||||||
}, | ||||||
stream=request.stream, | ||||||
) | ||||||
+
+            # Process request through tokenizer manager
+            async for content in self.generate_request(adapted_request):
Review comment: We need to handle the client closing the connection and abort the request (I think we can do it using the context object). We can also do it as a follow-up; in that case, can you create a separate issue?
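A rough sketch of one way to handle this with grpc.aio (hypothetical; how the engine aborts an in-flight request is an assumption):

    import asyncio

    try:
        async for content in self.generate_request(adapted_request):
            # grpc.aio.ServicerContext exposes cancelled(); stop streaming once
            # the client has gone away.
            if context.cancelled():
                break
            ...
    except asyncio.CancelledError:
        # grpc.aio cancels the handler coroutine on client disconnect; abort the
        # in-flight engine request here (mechanism assumed), then re-raise.
        raise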
+                # Create response for each token/chunk
+                response = completion_pb2.CompletionResponse(
+                    text=content["text"],  # Send full text so far
+                    finished=False,  # Not finished until last message
+                    usage=completion_pb2.Usage(
+                        prompt_tokens=content["meta_info"]["prompt_tokens"],
+                        completion_tokens=content["meta_info"]["completion_tokens"],
+                        # TODO: fix this
+                        # total_tokens=content["meta_info"]["total_tokens"],
+                    ),
+                )
+                yield response
+
+            # Send final response with finished flag
+            final_response = completion_pb2.CompletionResponse(
+                text=content["text"],  # Final complete text
+                finished=True,
Review comment: QQ: how does a regular HTTP completion request figure out that it is finished?
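For comparison, a hedged guess at how the HTTP path detects completion (assumptions: the /generate SSE stream ends with a "data: [DONE]" sentinel, and meta_info carries a finish_reason once generation stops; verify against the actual server code):

    finished = content["meta_info"].get("finish_reason") is not None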
+            usage=completion_pb2.Usage(
+                    prompt_tokens=content["meta_info"]["prompt_tokens"],
+                    completion_tokens=content["meta_info"]["completion_tokens"],
+                    # TODO: fix this
+                    # total_tokens=content["meta_info"]["total_tokens"],
+                ),
+            )
+            yield final_response
+
+        except Exception as e:
+            # Handle errors consistently
+            error_msg = f"Error in gRPC Complete: {str(e)}\n{traceback.format_exc()}"
+            print(error_msg)
+            await context.abort(grpc.StatusCode.INTERNAL, error_msg)
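For context, a minimal sketch of wiring this servicer into a grpc.aio server; everything except CompletionServicer and the generated add_CompletionServiceServicer_to_server helper is an assumption, not part of this PR:

    import grpc

    from sglang.srt.proto import completion_pb2_grpc

    async def serve(generate_request, grpc_port: int) -> None:
        server = grpc.aio.server()
        completion_pb2_grpc.add_CompletionServiceServicer_to_server(
            CompletionServicer(generate_request), server
        )
        server.add_insecure_port(f"[::]:{grpc_port}")
        await server.start()
        await server.wait_for_termination()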
Review comment (on a change not shown in this view): Is the version restriction necessary?