Skip to content

Commit 349fa10

Browse files
committed
[V1] Support DP with Ray
Signed-off-by: Rui Qiao <[email protected]>
1 parent ed522eb commit 349fa10

File tree

8 files changed

+490
-113
lines changed

8 files changed

+490
-113
lines changed

tests/v1/test_async_llm_dp.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,22 @@ async def generate(engine: AsyncLLM,
5959

6060

6161
@pytest.mark.parametrize(
62-
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
62+
"output_kind",
63+
[
64+
RequestOutputKind.DELTA,
65+
RequestOutputKind.FINAL_ONLY,
66+
],
67+
)
68+
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
6369
@pytest.mark.asyncio
64-
async def test_load(output_kind: RequestOutputKind):
70+
async def test_load(output_kind: RequestOutputKind,
71+
data_parallel_backend: str):
6572

6673
with ExitStack() as after:
6774

6875
prompt = "This is a test of data parallel"
6976

77+
engine_args.data_parallel_backend = data_parallel_backend
7078
engine = AsyncLLM.from_engine_args(engine_args)
7179
after.callback(engine.shutdown)
7280

@@ -82,7 +90,6 @@ async def test_load(output_kind: RequestOutputKind):
8290
asyncio.create_task(
8391
generate(engine, request_id, prompt, output_kind,
8492
NUM_EXPECTED_TOKENS)))
85-
8693
# Confirm that we got all the EXPECTED tokens from the requests.
8794
done, pending = await asyncio.wait(tasks,
8895
return_when=asyncio.FIRST_EXCEPTION)

vllm/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,6 +1693,8 @@ class ParallelConfig:
16931693
"""Port for data parallel messaging."""
16941694
data_parallel_master_port: int = 29500
16951695
"""Port of the data parallel master."""
1696+
data_parallel_backend: str = "mp"
1697+
"""Backend to use for data parallel, either "mp" or "ray"."""
16961698
enable_expert_parallel: bool = False
16971699
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
16981700
max_parallel_loading_workers: Optional[int] = None
@@ -1856,6 +1858,10 @@ def __post_init__(self) -> None:
18561858
"please install Ray with `pip install "
18571859
"ray`.") from ray_utils.ray_import_err
18581860
backend = "ray"
1861+
elif self.data_parallel_backend == "ray":
1862+
logger.info("Using ray distributed inference because "
1863+
"data_parallel_backend is ray")
1864+
backend = "ray"
18591865
elif ray_found:
18601866
if self.placement_group:
18611867
backend = "ray"

vllm/engine/arg_utils.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from vllm.transformers_utils.utils import check_gguf_file
3939
from vllm.usage.usage_lib import UsageContext
4040
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
41-
GiB_bytes, is_in_doc_build, is_in_ray_actor)
41+
GiB_bytes, get_ip, is_in_doc_build, is_in_ray_actor)
4242

4343
# yapf: enable
4444

@@ -290,6 +290,7 @@ class EngineArgs:
290290
data_parallel_size_local: Optional[int] = None
291291
data_parallel_address: Optional[str] = None
292292
data_parallel_rpc_port: Optional[int] = None
293+
data_parallel_backend: str = ParallelConfig.data_parallel_backend
293294
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
294295
max_parallel_loading_workers: Optional[
295296
int] = ParallelConfig.max_parallel_loading_workers
@@ -618,6 +619,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
618619
type=int,
619620
help='Port for data parallel RPC '
620621
'communication.')
622+
parallel_group.add_argument('--data-parallel-backend',
623+
'-dpb',
624+
type=str,
625+
help='Backend for data parallel, either '
626+
'"mp" or "ray".')
621627
parallel_group.add_argument(
622628
"--enable-expert-parallel",
623629
**parallel_kwargs["enable_expert_parallel"])
@@ -1050,23 +1056,37 @@ def create_engine_config(
10501056

10511057
# DP address, used in multi-node case for torch distributed group
10521058
# and ZMQ sockets.
1053-
data_parallel_address = self.data_parallel_address if (
1054-
self.data_parallel_address
1055-
is not None) else ParallelConfig.data_parallel_master_ip
1059+
if self.data_parallel_address is None:
1060+
if self.data_parallel_backend == "ray":
1061+
host_ip = get_ip()
1062+
logger.info(
1063+
"Using host IP %s as ray-based data parallel address",
1064+
host_ip)
1065+
data_parallel_address = host_ip
1066+
else:
1067+
assert self.data_parallel_backend == "mp", (
1068+
"data_parallel_backend can only be ray or mp, got %s",
1069+
self.data_parallel_backend)
1070+
data_parallel_address = ParallelConfig.data_parallel_master_ip
1071+
else:
1072+
data_parallel_address = self.data_parallel_address
10561073

10571074
# This port is only used when there are remote data parallel engines,
10581075
# otherwise the local IPC transport is used.
10591076
data_parallel_rpc_port = self.data_parallel_rpc_port if (
10601077
self.data_parallel_rpc_port
10611078
is not None) else ParallelConfig.data_parallel_rpc_port
10621079

1080+
data_parallel_backend = self.data_parallel_backend
1081+
10631082
parallel_config = ParallelConfig(
10641083
pipeline_parallel_size=self.pipeline_parallel_size,
10651084
tensor_parallel_size=self.tensor_parallel_size,
10661085
data_parallel_size=self.data_parallel_size,
10671086
data_parallel_size_local=data_parallel_size_local,
10681087
data_parallel_master_ip=data_parallel_address,
10691088
data_parallel_rpc_port=data_parallel_rpc_port,
1089+
data_parallel_backend=data_parallel_backend,
10701090
enable_expert_parallel=self.enable_expert_parallel,
10711091
max_parallel_loading_workers=self.max_parallel_loading_workers,
10721092
disable_custom_all_reduce=self.disable_custom_all_reduce,

vllm/entrypoints/cli/serve.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from vllm.v1.executor.abstract import Executor
2828
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
2929
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
30-
EngineZmqAddresses, get_engine_client_zmq_addr,
30+
CoreEngineActorManager, EngineZmqAddresses,
31+
get_engine_client_zmq_addr,
3132
wait_for_completion_or_failure,
3233
wait_for_engine_startup)
3334

@@ -229,6 +230,32 @@ def run_multi_api_server(args: argparse.Namespace):
229230
logger.info("Started DP Coordinator process (PID: %d)",
230231
coordinator.proc.pid)
231232

233+
if parallel_config.data_parallel_backend == "ray":
234+
logger.info("Starting ray-based data parallel backend")
235+
236+
engine_actor_manager = CoreEngineActorManager(
237+
local_engine_count=local_engine_count,
238+
vllm_config=vllm_config,
239+
addresses=addresses,
240+
executor_class=Executor.get_class(vllm_config),
241+
log_stats=not engine_args.disable_log_stats,
242+
)
243+
# Start API servers using the manager
244+
api_server_manager = APIServerProcessManager(
245+
target_server_fn=run_api_server_worker_proc,
246+
listen_address=listen_address,
247+
sock=sock,
248+
args=args,
249+
num_servers=num_api_servers,
250+
input_addresses=input_addresses,
251+
output_addresses=output_addresses,
252+
stats_update_address=stats_update_address)
253+
254+
wait_for_completion_or_failure(api_server_manager=api_server_manager,
255+
engine_manager=engine_actor_manager,
256+
coordinator=coordinator)
257+
return
258+
232259
handshake_address = get_engine_client_zmq_addr(
233260
local_only, host, parallel_config.data_parallel_rpc_port)
234261

@@ -277,10 +304,9 @@ def run_multi_api_server(args: argparse.Namespace):
277304
)
278305

279306
# Wait for API servers
280-
wait_for_completion_or_failure(
281-
api_server_manager=api_server_manager,
282-
local_engine_manager=local_engine_manager,
283-
coordinator=coordinator)
307+
wait_for_completion_or_failure(api_server_manager=api_server_manager,
308+
engine_manager=local_engine_manager,
309+
coordinator=coordinator)
284310

285311

286312
def run_api_server_worker_proc(listen_address,

vllm/v1/engine/async_llm.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from vllm.usage.usage_lib import UsageContext
2828
from vllm.utils import Device, cdiv
2929
from vllm.v1.engine import EngineCoreRequest
30-
from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
30+
from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
31+
RayDPClient)
3132
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
3233
from vllm.v1.engine.output_processor import (OutputProcessor,
3334
RequestOutputCollector)
@@ -119,9 +120,15 @@ def __init__(
119120
log_stats=self.log_stats)
120121

121122
# EngineCore (starts the engine in background process).
122-
core_client_class = AsyncMPClient if (
123-
vllm_config.parallel_config.data_parallel_size
124-
== 1) else DPAsyncMPClient
123+
core_client_class: Union[type[RayDPClient], type[DPAsyncMPClient],
124+
type[AsyncMPClient]]
125+
if vllm_config.parallel_config.data_parallel_size > 1:
126+
if vllm_config.parallel_config.data_parallel_backend == "ray":
127+
core_client_class = RayDPClient
128+
else:
129+
core_client_class = DPAsyncMPClient
130+
else:
131+
core_client_class = AsyncMPClient
125132

126133
self.engine_core = core_client_class(
127134
vllm_config=vllm_config,

0 commit comments

Comments (0)