[Experimental] Add a gRPC server for completion request #2478

Status: Open. This pull request wants to merge 30 commits into base: main.

Commits (30):
c4ad1ac
add proto definition for completion request
MrAta Dec 12, 2024
6987cd2
add grpc server
MrAta Dec 12, 2024
fdf3d12
add grpc server args
MrAta Dec 12, 2024
fa0881c
add tests
MrAta Dec 12, 2024
8c16262
add benchmark for grpc server
MrAta Dec 12, 2024
baba186
add proto generated files
MrAta Dec 12, 2024
1fe98c0
separate grpc and http server
MrAta Dec 12, 2024
ceae9d4
fix bench serving
MrAta Dec 12, 2024
d8d561e
remove grpc port in bench serving
MrAta Dec 12, 2024
d655ba1
add back server docstrings
MrAta Dec 12, 2024
a28182f
fix formats
MrAta Dec 12, 2024
9c33cf8
fix bench serving
MrAta Dec 13, 2024
3b2a806
fix error handling in server
MrAta Dec 13, 2024
0af6298
make client verbose
MrAta Dec 13, 2024
8f2e26c
update server docstring
MrAta Dec 13, 2024
7435d17
create final chunk explicitly
MrAta Dec 15, 2024
0d7362c
remove duplicate tests
MrAta Dec 15, 2024
eed1dca
add dependencies
MrAta Dec 15, 2024
81c4748
fix local import path
MrAta Dec 15, 2024
4ad057c
pin grpc versions
MrAta Dec 15, 2024
68007cb
remove debug prints
MrAta Dec 15, 2024
60e37af
trigger tests
MrAta Dec 15, 2024
31f8f8e
make fast api default server
MrAta Dec 17, 2024
cf0edac
remove tm dependency in grpc server
MrAta Dec 17, 2024
fb37ac0
launch grpc server inside tokenizer manager
MrAta Dec 17, 2024
9f88fd3
convert grpc launch into a func
MrAta Dec 17, 2024
b89ad96
create server inside the coroutine
MrAta Dec 17, 2024
8b29ed7
remove comments
MrAta Dec 17, 2024
9fbef84
move grpc server creation into loop handler
MrAta Dec 17, 2024
f072559
launch the grpc server in a separate thread
MrAta Dec 18, 2024
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -23,7 +23,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
"psutil", "pydantic", "python-multipart",
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
"xgrammar>=0.1.6"]
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6"]
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6", "grpcio==1.68.1", "grpcio-tools==1.68.1"]
Reviewer comment (Contributor): is version restriction necessary?


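If the exact pin turns out to be unnecessary, a looser range would reduce dependency conflicts. An editor's sketch only; the lower bound here is an assumption, not something this PR establishes:

srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6", "grpcio>=1.68,<2", "grpcio-tools>=1.68,<2"]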
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
131 changes: 129 additions & 2 deletions python/sglang/bench_serving.py
@@ -29,6 +29,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import aiohttp
import grpc
import numpy as np
import requests
from tqdm.asyncio import tqdm
@@ -39,6 +40,8 @@
PreTrainedTokenizerFast,
)

from sglang.srt.proto import completion_pb2, completion_pb2_grpc

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

global args
@@ -386,6 +389,118 @@ async def async_request_sglang_generate(
return output


async def async_request_sglang_grpc(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.startswith("grpc://"), "gRPC URL must start with grpc://"

output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len

try:
# Create gRPC request with same parameters as FastAPI
request = completion_pb2.CompletionRequest(
Reviewer comment (Contributor): This doesn't seem to match (e.g., return logprobs or **request_func_input.extra_request_body,)

prompt=request_func_input.prompt,
temperature=0.0001,
top_p=1.0,
top_k=-1,
max_tokens=request_func_input.output_len,
ignore_eos=not args.disable_ignore_eos,
stream=not args.disable_stream,
)

st = time.perf_counter()
most_recent_timestamp = st
ttft = 0.0
generated_text = ""

# Create channel and stub
server_addr = api_url[7:] # Remove grpc:// prefix
channel = grpc.aio.insecure_channel(
server_addr,
options=[
("grpc.max_send_message_length", 100 * 1024 * 1024),
("grpc.max_receive_message_length", 100 * 1024 * 1024),
],
)

try:
stub = completion_pb2_grpc.CompletionServiceStub(channel)
response_stream = stub.Complete(request)

async for response in response_stream:
timestamp = time.perf_counter()

# Handle streaming response similar to FastAPI
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
else:
output.itl.append(timestamp - most_recent_timestamp)

most_recent_timestamp = timestamp

# Accumulate text from each response
if response.text:
generated_text = (
response.text
) # Use latest text as it contains full response

# Check if this is the final response
if response.finished:
output.generated_text = generated_text
output.success = True
output.latency = time.perf_counter() - st
output.output_len = request_func_input.output_len
break

# Ensure we have final output values set
if output.success:
output.generated_text = generated_text
output.success = True
output.latency = time.perf_counter() - st
output.output_len = request_func_input.output_len

except grpc.aio.AioRpcError as rpc_error:
# Get the trailing metadata
try:
metadata = await rpc_error.trailing_metadata()
metadata_dict = {k: v for k, v in metadata}
except:
metadata_dict = {}

# Build comprehensive error message
error_msg = [
f"gRPC Error:",
f"Status code: {rpc_error.code().name} ({rpc_error.code().value})",
f"Details: {rpc_error.details()}",
f"Debug error string: {rpc_error.debug_error_string()}",
]
if metadata_dict:
error_msg.append(f"Metadata: {metadata_dict}")

output.error = "\n".join(error_msg)
output.success = False

except Exception as e:
output.error = f"Stream error: {str(e)}\n{traceback.format_exc()}"
output.success = False

finally:
await channel.close()

except Exception as e:
output.error = f"Request creation error: {str(e)}\n{traceback.format_exc()}"
output.success = False

if pbar:
pbar.update(1)

return output
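For reference, here is a minimal standalone client for this service. This is an editor's sketch, not part of the PR; it assumes a server from this PR is listening on 127.0.0.1:30001 and uses only request/response fields that appear in this diff:

import asyncio

import grpc

from sglang.srt.proto import completion_pb2, completion_pb2_grpc


async def main():
    async with grpc.aio.insecure_channel("127.0.0.1:30001") as channel:
        stub = completion_pb2_grpc.CompletionServiceStub(channel)
        request = completion_pb2.CompletionRequest(
            prompt="The capital of France is",
            temperature=0.7,
            top_p=1.0,
            top_k=-1,
            max_tokens=32,
            stream=True,
        )
        # Complete() returns a server stream; each message carries the full
        # text generated so far, and the last one sets `finished`.
        async for response in stub.Complete(request):
            print(response.finished, response.text)


asyncio.run(main())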


async def async_request_gserver(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
@@ -478,6 +593,7 @@ def get_dataset(args, tokenizer):

ASYNC_REQUEST_FUNCS = {
"sglang": async_request_sglang_generate,
"sglang-grpc": async_request_sglang_grpc,
"sglang-native": async_request_sglang_generate,
"sglang-oai": async_request_openai_completions,
"vllm": async_request_openai_completions,
@@ -917,11 +1033,14 @@ async def limited_request_func(request_func_input, pbar):
lora_name=lora_name,
extra_request_body=extra_request_body,
)

test_output = await request_func(request_func_input=test_input)

if not test_output.success:
error_msg = test_output.error or "Unknown error occurred"
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}"
f"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified.\nError details: {error_msg}"
)
else:
print("Initial test run completed. Starting main benchmark run...")
@@ -1169,6 +1288,9 @@ def run_benchmark(args_: argparse.Namespace):
else f"http://{args.host}:{args.port}/v1/models"
)

if args.backend == "sglang-grpc":
# use grpc address
api_url = f"grpc://{args.host}:{args.grpc_port}"
if args.backend in ["sglang", "sglang-native"]:
api_url = (
f"{args.base_url}/generate"
@@ -1316,6 +1438,11 @@ def set_ulimit(target_soft_limit=65535):
type=int,
help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
)
parser.add_argument(
"--grpc-port",
type=int,
help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
Reviewer comment (Contributor): Let's note in the help text that this applies only to the sglang backend and is used only when --backend=sglang-grpc?

)
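A possible wording along the lines of the comment above (editor's sketch):

    parser.add_argument(
        "--grpc-port",
        type=int,
        help="Port of the gRPC server. Only used by the sglang backend, "
        "i.e., when --backend=sglang-grpc.",
    )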
parser.add_argument(
"--dataset-name",
type=str,
74 changes: 74 additions & 0 deletions python/sglang/srt/grpc_server.py
@@ -0,0 +1,74 @@
import traceback
from typing import Any, AsyncGenerator, Callable, Dict

import grpc

from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.proto import completion_pb2, completion_pb2_grpc


class CompletionServicer(completion_pb2_grpc.CompletionServiceServicer):
def __init__(
self,
generate_request: Callable[
[GenerateReqInput], AsyncGenerator[Dict[str, Any], None]
],
):
self.generate_request = generate_request

async def Complete(
Reviewer comment (Contributor), suggested change:
- async def Complete(
+ async def complete(

self,
request: completion_pb2.CompletionRequest,
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[completion_pb2.CompletionResponse, None]:
try:
# Convert gRPC request to internal format
adapted_request = GenerateReqInput(
text=request.prompt,
sampling_params={
"max_new_tokens": request.max_tokens,
"temperature": request.temperature,
"top_p": request.top_p,
Reviewer comment (Contributor): this part is dangerous. gRPC protobuf uses a default value of 0 for unset fields, which means top_p and top_k could end up different from the intended defaults (1 and -1 respectively). I think we need some mechanism to set the default value; one potential solution is to use optional fields in protobuf and set the correct default when they are not provided.

"top_k": request.top_k,
"min_p": request.min_p,
"frequency_penalty": request.frequency_penalty,
"presence_penalty": request.presence_penalty,
"stop": list(request.stop),
"ignore_eos": request.ignore_eos,
},
stream=request.stream,
)
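One way to realize the reviewer's suggestion above, as an editor's sketch: declare the sampling fields `optional` in the .proto (proto3 `optional` generates presence tracking, so HasField() works on scalar fields), then fall back to explicit defaults on the server:

    # Hypothetical helper, not part of this PR: use a field only if the
    # client actually set it, instead of trusting protobuf's zero values.
    def _sampling_params_with_defaults(request):
        return {
            "max_new_tokens": request.max_tokens,
            "temperature": request.temperature if request.HasField("temperature") else 1.0,
            "top_p": request.top_p if request.HasField("top_p") else 1.0,
            "top_k": request.top_k if request.HasField("top_k") else -1,
            "min_p": request.min_p if request.HasField("min_p") else 0.0,
        }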

# Process request through tokenizer manager
async for content in self.generate_request(adapted_request):
Reviewer comment (Contributor): we need to handle the client closing the connection and abort the request (I think we can do it using the context object). We can also do it as a follow-up; in that case, can you create a separate issue?

# Create response for each token/chunk
response = completion_pb2.CompletionResponse(
text=content["text"], # Send full text so far
finished=False, # Not finished until last message
usage=completion_pb2.Usage(
prompt_tokens=content["meta_info"]["prompt_tokens"],
completion_tokens=content["meta_info"]["completion_tokens"],
# TODO: fix this
# total_tokens=content["meta_info"]["total_tokens"],
),
)
yield response

# Send final response with finished flag
final_response = completion_pb2.CompletionResponse(
text=content["text"], # Final complete text
finished=True,
Reviewer comment (Contributor): QQ: how does a regular HTTP completion request figure out if it is finished?

usage=completion_pb2.Usage(
prompt_tokens=content["meta_info"]["prompt_tokens"],
completion_tokens=content["meta_info"]["completion_tokens"],
# TODO: fix this
# total_tokens=content["meta_info"]["total_tokens"],
),
)
yield final_response

except Exception as e:
# Handle errors consistently
error_msg = f"Error in gRPC Complete: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
await context.abort(grpc.StatusCode.INTERNAL, error_msg)
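Editor's notes on the two review threads above. On how the HTTP path detects completion: the streaming endpoint terminates its SSE stream with a literal [DONE] sentinel chunk, which the HTTP client in bench_serving.py checks for; this gRPC service instead carries an explicit finished flag inside the response message itself.

On aborting when the client disconnects: grpc.aio surfaces cancellation through the ServicerContext, so the servicer could register a done callback. A minimal sketch, assuming some abort hook exists on the manager side; abort_request and the rid bookkeeping here are hypothetical, not part of this PR:

    async def Complete(self, request, context):
        rid = None  # id of the in-flight generation, once it is known

        def _on_rpc_done(ctx):
            # Called when the RPC terminates for any reason, including a
            # client-side cancel or disconnect.
            if ctx.cancelled() and rid is not None:
                self.abort_request(rid)  # hypothetical abort hook

        context.add_done_callback(_on_rpc_done)
        # ... then proceed with generate_request() as above ...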
56 changes: 53 additions & 3 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -20,17 +20,21 @@
import os
import signal
import sys
import threading
import time
import uuid
from concurrent import futures
from typing import Any, Dict, List, Optional, Union

import fastapi
import grpc
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.grpc_server import CompletionServicer
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
get_dummy_image_processor,
@@ -60,6 +64,7 @@
UpdateWeightsFromDistributedReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.proto import completion_pb2_grpc
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree
@@ -171,6 +176,17 @@ def __init__(
},
)

if server_args.grpc_port:
t = threading.Thread(
target=self._launch_grpc_server_in_loop,
name="gRPCServerThread",
daemon=True,
)
t.start()
logger.info(
f"TokenizerManager: launched gRPC server thread on port {server_args.grpc_port}"
)

async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -566,8 +582,13 @@ def create_handle_loop(self):
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())

-        signal_handler = SignalHandler(self)
-        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+        if threading.current_thread() is threading.main_thread():
+            signal_handler = SignalHandler(self)
+            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+        else:
+            # the thread that is used by grpc server doesn't need to handle the signal
+            logger.warning("Skipping add_signal_handler because not in main thread.")

loop.create_task(self.sigterm_watchdog())

async def sigterm_watchdog(self):
@@ -590,7 +611,6 @@ async def sigterm_watchdog(self):

async def handle_loop(self):
"""The event loop that handles requests"""

while True:
recv_obj: Union[
BatchStrOut,
@@ -789,6 +809,36 @@ def detokenize_top_logprobs_tokens(
ret.append(None)
return ret

def _launch_grpc_server_in_loop(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = loop.run_until_complete(self._create_grpc_server())
# Start the server before run_forever()
loop.run_until_complete(server.start())

logger.info(
f"gRPC server started, listening on {self.server_args.host}:{self.server_args.grpc_port}"
)
# Keep this loop alive so the server remains accessible
loop.run_forever()

async def _create_grpc_server(self):
# Create the server
server = grpc.aio.server(
options=[
("grpc.max_send_message_length", 100 * 1024 * 1024),
("grpc.max_receive_message_length", 100 * 1024 * 1024),
]
)
completion_pb2_grpc.add_CompletionServiceServicer_to_server(
CompletionServicer(self.generate_request), server
)

server.add_insecure_port(
f"{self.server_args.host}:{self.server_args.grpc_port}"
)
return server
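Since the thread is a daemon and its loop runs forever, the gRPC server currently lives and dies with the process. If a graceful stop is ever needed, one option is to schedule server.stop() onto the gRPC thread's loop. An editor's sketch; _grpc_loop and _grpc_server are hypothetical attributes the launch path above would have to store first:

    def _shutdown_grpc_server(self, grace: float = 5.0):
        async def _stop():
            # Drain in-flight RPCs for up to `grace` seconds, then stop the loop.
            await self._grpc_server.stop(grace)
            self._grpc_loop.stop()

        # Hand the coroutine to the gRPC thread's own event loop.
        asyncio.run_coroutine_threadsafe(_stop(), self._grpc_loop)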


class SignalHandler:
def __init__(self, tokenizer_manager):