@@ -20,12 +20,8 @@
 from concurrent.futures import ThreadPoolExecutor, wait
 from typing import Any, Literal, TYPE_CHECKING
 
-import ray
-
 import torch
 
-from ray.util.placement_group import placement_group, remove_placement_group
-from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from torchrl._utils import logger as torchrl_logger
 
 # Import RLvLLMEngine and shared utilities
@@ -51,6 +47,25 @@
     _has_vllm = False
 
 
+def _get_ray():
+    """Import Ray on demand to avoid global import side-effects.
+
+    Returns:
+        ModuleType: The imported Ray module.
+
+    Raises:
+        ImportError: If Ray is not installed.
+    """
+    try:
+        import ray  # type: ignore
+
+        return ray
+    except Exception as e:  # pragma: no cover - surfaced to callers
+        raise ImportError(
+            "ray is not installed. Please install it with `pip install ray`."
+        ) from e
+
+
 class _AsyncvLLMWorker:
     """Async vLLM worker extension for Ray with weight update capabilities."""
 
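Note: `_get_ray()` above defers the Ray import to each call site, so merely importing this module no longer requires (or initializes) Ray. A minimal sketch of the intended call pattern; the `ray_is_ready` helper below is hypothetical and only illustrates the lazy import:

```python
def ray_is_ready() -> bool:
    # Hypothetical helper: Ray is imported only when this runs, never at module import time.
    try:
        ray = _get_ray()
    except ImportError:
        return False  # Ray not installed; callers can fall back or raise later
    return ray.is_initialized()
```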
@@ -264,7 +279,7 @@ async def generate(
                 "vllm is not installed. Please install it with `pip install vllm`."
             )
 
-        from vllm import RequestOutput, SamplingParams, TokensPrompt
+        from vllm import SamplingParams, TokensPrompt
 
         # Track whether input was originally a single prompt
         single_prompt_input = False
@@ -471,11 +486,7 @@ def _gpus_per_replica(engine_args: AsyncEngineArgs) -> int:
     )
 
 
-# Create Ray remote versions
-if ray is not None and _has_vllm:
-    _AsyncLLMEngineActor = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
-else:
-    _AsyncLLMEngineActor = None
+# Ray actor wrapper is created lazily in __init__ to avoid global Ray import.
 
 
 class AsyncVLLM(RLvLLMEngine):
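The module-level `_AsyncLLMEngineActor` removed above is replaced by wrapping the engine class at construction time (see the `__init__` hunk below). A rough sketch of that two-step pattern; the `.options(...)` arguments and the lack of constructor arguments are illustrative assumptions, not the values used here:

```python
ray = _get_ray()
# Step 1: wrap the plain class into a Ray actor class (no actor is created yet).
ActorCls = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
# Step 2: spawn a remote actor instance; name and resources here are illustrative.
actor = ActorCls.options(name="async_llm_replica_0", num_gpus=1).remote()
```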
@@ -580,17 +591,18 @@ def __init__(
             raise ImportError(
                 "vllm is not installed. Please install it with `pip install vllm`."
             )
-        if ray is None:
-            raise ImportError(
-                "ray is not installed. Please install it with `pip install ray`."
-            )
+        # Lazily import ray only when constructing the actor class to avoid global import
 
         # Enable prefix caching by default for better performance
         engine_args.enable_prefix_caching = enable_prefix_caching
 
         self.engine_args = engine_args
         self.num_replicas = num_replicas
-        self.actor_class = actor_class or _AsyncLLMEngineActor
+        if actor_class is None:
+            ray = _get_ray()
+            self.actor_class = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
+        else:
+            self.actor_class = actor_class
         self.actors: list = []
         self._launched = False
         self._service_id = uuid.uuid4().hex[
@@ -605,6 +617,11 @@ def _launch(self):
             torchrl_logger.warning("AsyncVLLMEngineService already launched")
             return
 
+        # Local imports to avoid global Ray dependency
+        ray = _get_ray()
+        from ray.util.placement_group import placement_group
+        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
         torchrl_logger.info(
             f"Launching {self.num_replicas} async vLLM engine actors..."
         )
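The `_launch` hunk above only adds the lazy imports; for context, a minimal sketch of how a placement group and scheduling strategy are typically combined in Ray (the bundle shape, strategy, and bundle index are assumptions, not necessarily what `_launch` does; `ActorCls` is the wrapped actor class from the sketch above):

```python
ray = _get_ray()
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

num_replicas = 2  # example value

# Reserve one GPU bundle per replica and wait until the reservation is granted.
pg = placement_group([{"GPU": 1, "CPU": 1}] * num_replicas, strategy="PACK")  # assumed bundles
ray.get(pg.ready())

# Pin an actor to a specific bundle of the placement group.
strategy = PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
actor = ActorCls.options(num_gpus=1, scheduling_strategy=strategy).remote()
```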
@@ -944,6 +961,7 @@ def generate(
         Returns:
             RequestOutput | list[RequestOutput]: Generated outputs from vLLM.
         """
+        ray = _get_ray()
         # Check if this is a batch request
         if self._is_batch(prompts, prompt_token_ids):
             # Handle batched input by unbinding and sending individual requests
@@ -1068,6 +1086,9 @@ def shutdown(self):
             f"Shutting down {len(self.actors)} async vLLM engine actors..."
         )
 
+        ray = _get_ray()
+        from ray.util.placement_group import remove_placement_group
+
         # Kill all actors
         for i, actor in enumerate(self.actors):
             try:
@@ -1260,6 +1281,7 @@ def _update_weights_with_nccl_broadcast_simple(
         )
 
         updated_weights = 0
+        ray = _get_ray()
         with torch.cuda.device(0):  # Ensure we're on the correct CUDA device
             for name, weight in gpu_weights.items():
                 # Convert dtype to string name (like periodic-mono)
@@ -1336,6 +1358,7 @@ def get_num_unfinished_requests(
                 "AsyncVLLM service must be launched before getting request counts"
             )
 
+        ray = _get_ray()
         if actor_index is not None:
             if not (0 <= actor_index < len(self.actors)):
                 raise IndexError(
@@ -1366,6 +1389,7 @@ def get_cache_usage(self, actor_index: int | None = None) -> float | list[float]:
                 "AsyncVLLM service must be launched before getting cache usage"
             )
 
+        ray = _get_ray()
         if actor_index is not None:
             if not (0 <= actor_index < len(self.actors)):
                 raise IndexError(
@@ -1678,6 +1702,7 @@ def _select_by_requests(self) -> int:
             futures = [
                 actor.get_num_unfinished_requests.remote() for actor in self.actors
             ]
+            ray = _get_ray()
             request_counts = ray.get(futures)
 
             # Find the actor with minimum pending requests
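`_select_by_requests` fans out one `get_num_unfinished_requests` call per actor and then picks the least-loaded replica. A sketch of that selection step over the gathered counts (standalone, with example counts in place of the remote calls above):

```python
# Example per-replica pending-request counts, as ray.get(futures) would return them.
request_counts = [4, 1, 7]
# Index of the replica with the fewest unfinished requests (replica 1 here).
best_index = min(range(len(request_counts)), key=request_counts.__getitem__)
assert best_index == 1
```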
@@ -1705,6 +1730,7 @@ def _select_by_cache_usage(self) -> int:
         else:
             # Query actors directly
             futures = [actor.get_cache_usage.remote() for actor in self.actors]
+            ray = _get_ray()
             cache_usages = ray.get(futures)
 
             # Find the actor with minimum cache usage
@@ -1844,7 +1870,8 @@ def _is_actor_overloaded(self, actor_index: int) -> bool:
         futures = [
             actor.get_num_unfinished_requests.remote() for actor in self.actors
         ]
-        request_counts = ray.get(futures)
+        ray = _get_ray()
+        request_counts = ray.get(futures)
 
         if not request_counts:
             return False
@@ -1893,8 +1920,9 @@ def get_stats(self) -> dict[str, Any]:
             cache_futures = [
                 actor.get_cache_usage.remote() for actor in self.actors
             ]
-            request_counts = ray.get(request_futures)
-            cache_usages = ray.get(cache_futures)
+            ray = _get_ray()
+            request_counts = ray.get(request_futures)
+            cache_usages = ray.get(cache_futures)
 
             for i, (requests, cache_usage) in enumerate(
                 zip(request_counts, cache_usages)