
Commit cccfaa6

[CI] LLM tests integration
ghstack-source-id: 6362a74
Pull-Request: #3216
1 parent 03ffc82 commit cccfaa6

9 files changed: +76, -65 lines changed


.github/unittest/linux_libs/scripts_llm/install.sh renamed to .github/unittest/llm/scripts_llm/install.sh

Lines changed: 7 additions & 6 deletions

@@ -30,15 +30,15 @@ git submodule sync && git submodule update --init --recursive
 #printf "Installing PyTorch with cu128"
 #if [[ "$TORCH_VERSION" == "nightly" ]]; then
 #  if [ "${CU_VERSION:-}" == cpu ] ; then
-#    pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
+#    pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
 #  else
-#    pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
+#    pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
 #  fi
 #elif [[ "$TORCH_VERSION" == "stable" ]]; then
 #  if [ "${CU_VERSION:-}" == cpu ] ; then
-#    pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
+#    pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
 #  else
-#    pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
+#    pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
 #  fi
 #else
 #  printf "Failed to install pytorch"

@@ -47,9 +47,10 @@ git submodule sync && git submodule update --init --recursive

 # install tensordict
 if [[ "$RELEASE" == 0 ]]; then
-  pip3 install git+https://github.com/pytorch/tensordict.git
+  pip install "pybind11[global]" ninja
+  pip install git+https://github.com/pytorch/tensordict.git
 else
-  pip3 install tensordict
+  pip install tensordict
 fi

 # smoke test

.github/unittest/linux_libs/scripts_llm/run_test.sh renamed to .github/unittest/llm/scripts_llm/run_test.sh

Lines changed: 1 addition & 11 deletions

@@ -23,14 +23,4 @@ lib_dir="${env_dir}/lib"

 conda deactivate && conda activate ./env

-python -c "import transformers, datasets"
-
-pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips
-
-python examples/rlhf/train_rlhf.py \
-  sys.device=cuda:0 sys.ref_device=cuda:0 \
-  model.name_or_path=gpt2 train.max_epochs=2 \
-  data.batch_size=2 train.ppo.ppo_batch_size=2 \
-  train.ppo.ppo_num_epochs=1 reward_model.name_or_path= \
-  train.ppo.episode_length=8 train.ppo.num_rollouts_per_epoch=4 \
-  data.block_size=110 io.logger=csv
+pytest test/llm -vvv --instafail --durations 600 --capture no --error-for-skips

.github/unittest/linux_libs/scripts_llm/setup_env.sh renamed to .github/unittest/llm/scripts_llm/setup_env.sh

Lines changed: 10 additions & 19 deletions

@@ -6,28 +6,19 @@
 # Do not install PyTorch and torchvision here, otherwise they also get cached.

 set -e
-apt-get update && apt-get upgrade -y && apt-get install -y git cmake
+export DEBIAN_FRONTEND=noninteractive
+export TZ=UTC
+apt-get update
+apt-get install -yq --no-install-recommends git wget unzip curl patchelf
 # Avoid error: "fatal: unsafe repository"
 git config --global --add safe.directory '*'
-apt-get install -y wget \
-  gcc \
-  g++ \
-  unzip \
-  curl \
-  patchelf \
-  libosmesa6-dev \
-  libgl1-mesa-glx \
-  libglfw3 \
-  swig3.0 \
-  libglew-dev \
-  libglvnd0 \
-  libgl1 \
-  libglx0 \
-  libegl1 \
-  libgles2
+# The base PyTorch devel image provides compilers, CMake >= 3.22, and most build deps.
+# Install only minimal utilities not guaranteed to be present.

-# Upgrade specific package
-apt-get upgrade -y libstdc++6
+# CMake available in the PyTorch devel image (Ubuntu 22.04) is sufficient.
+
+# Cleanup APT cache
+apt-get clean && rm -rf /var/lib/apt/lists/*

 this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 root_dir="$(git rev-parse --show-toplevel)"

.github/workflows/test-linux-llm.yml

Lines changed: 11 additions & 10 deletions

@@ -21,17 +21,18 @@ permissions:

 jobs:
   unittests:
+    if: ${{ github.event_name == 'push' || (github.event_name == 'pull_request' && contains(join(github.event.pull_request.labels.*.name, ', '), 'llm/')) }}
     strategy:
       matrix:
-        python_version: ["3.9"]
-        cuda_arch_version: ["12.8"]
+        python_version: ["3.12"]
+        cuda_arch_version: ["12.9"]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       repository: pytorch/rl
-      runner: "linux.g5.4xlarge.nvidia.gpu"
+      runner: "linux.g6.4xlarge.experimental.nvidia.gpu"
       # gpu-arch-type: cuda
       # gpu-arch-version: "11.7"
-      docker-image: "nvidia/cudagl:11.4.0-base"
+      docker-image: "pytorch/pytorch:2.8.0-cuda12.9-cudnn9-devel"
       timeout: 120
       script: |
         if [[ "${{ github.ref }}" =~ release/* ]]; then

@@ -43,14 +44,14 @@ jobs:
         fi

         set -euo pipefail
-        export PYTHON_VERSION="3.9"
-        export CU_VERSION="cu117"
+        export PYTHON_VERSION="3.12"
+        export CU_VERSION="cu129"
         export TAR_OPTIONS="--no-same-owner"
         export UPLOAD_CHANNEL="nightly"
         export TF_CPP_MIN_LOG_LEVEL=0
         export TD_GET_DEFAULTS_TO_NONE=1

-        bash .github/unittest/linux_libs/scripts_llm/setup_env.sh
-        bash .github/unittest/linux_libs/scripts_llm/install.sh
-        bash .github/unittest/linux_libs/scripts_llm/run_test.sh
-        bash .github/unittest/linux_libs/scripts_llm/post_process.sh
+        bash .github/unittest/llm/scripts_llm/setup_env.sh
+        bash .github/unittest/llm/scripts_llm/install.sh
+        bash .github/unittest/llm/scripts_llm/run_test.sh
+        bash .github/unittest/llm/scripts_llm/post_process.sh

test/llm/test_updaters.py

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
+from __future__ import annotations

 import argparse
 import gc

torchrl/modules/llm/backends/vllm/vllm_async.py

Lines changed: 46 additions & 18 deletions

@@ -20,12 +20,8 @@
 from concurrent.futures import ThreadPoolExecutor, wait
 from typing import Any, Literal, TYPE_CHECKING

-import ray
-
 import torch

-from ray.util.placement_group import placement_group, remove_placement_group
-from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from torchrl._utils import logger as torchrl_logger

 # Import RLvLLMEngine and shared utilities

@@ -51,6 +47,25 @@
     _has_vllm = False


+def _get_ray():
+    """Import Ray on demand to avoid global import side-effects.
+
+    Returns:
+        ModuleType: The imported Ray module.
+
+    Raises:
+        ImportError: If Ray is not installed.
+    """
+    try:
+        import ray  # type: ignore
+
+        return ray
+    except Exception as e:  # pragma: no cover - surfaced to callers
+        raise ImportError(
+            "ray is not installed. Please install it with `pip install ray`."
+        ) from e
+
+
 class _AsyncvLLMWorker:
     """Async vLLM worker extension for Ray with weight update capabilities."""

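As a side note for readers skimming the diff, here is a minimal sketch of what the on-demand import buys (a hypothetical module, not this PR's code): the module imports cleanly on machines without Ray, and the ImportError only surfaces when a Ray-backed path is actually called.

# Illustration only (hypothetical module, mirrors the _get_ray() pattern above).
def _get_backend():
    try:
        import ray  # deferred import: importing this module never requires Ray
        return ray
    except Exception as exc:
        raise ImportError(
            "ray is not installed. Please install it with `pip install ray`."
        ) from exc


def cpu_only_path(x: int) -> int:
    """Works everywhere, even when Ray is absent."""
    return x + 1


def ray_backed_path(values: list[int]) -> list[int]:
    """Raises ImportError here, not at import time, if Ray is missing."""
    ray = _get_backend()
    if not ray.is_initialized():
        ray.init()
    # Round-trip through the object store just to exercise a Ray API.
    return ray.get([ray.put(v) for v in values])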
@@ -264,7 +279,7 @@ async def generate(
             "vllm is not installed. Please install it with `pip install vllm`."
         )

-        from vllm import RequestOutput, SamplingParams, TokensPrompt
+        from vllm import SamplingParams, TokensPrompt

         # Track whether input was originally a single prompt
         single_prompt_input = False

@@ -471,11 +486,7 @@ def _gpus_per_replica(engine_args: AsyncEngineArgs) -> int:
     )


-# Create Ray remote versions
-if ray is not None and _has_vllm:
-    _AsyncLLMEngineActor = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
-else:
-    _AsyncLLMEngineActor = None
+# Ray actor wrapper is created lazily in __init__ to avoid global Ray import.


 class AsyncVLLM(RLvLLMEngine):
@@ -580,17 +591,18 @@ def __init__(
             raise ImportError(
                 "vllm is not installed. Please install it with `pip install vllm`."
             )
-        if ray is None:
-            raise ImportError(
-                "ray is not installed. Please install it with `pip install ray`."
-            )
+        # Lazily import ray only when constructing the actor class to avoid global import

         # Enable prefix caching by default for better performance
         engine_args.enable_prefix_caching = enable_prefix_caching

         self.engine_args = engine_args
         self.num_replicas = num_replicas
-        self.actor_class = actor_class or _AsyncLLMEngineActor
+        if actor_class is None:
+            ray = _get_ray()
+            self.actor_class = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
+        else:
+            self.actor_class = actor_class
         self.actors: list = []
         self._launched = False
         self._service_id = uuid.uuid4().hex[
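For readers unfamiliar with this Ray idiom, a minimal sketch (hypothetical Counter class, not part of this PR) of what ray.remote(num_cpus=0, num_gpus=0)(cls) does: it turns a plain class into an actor class whose .remote(...) constructor starts a long-lived actor. Deferring that wrapping to __init__ means Ray is only touched when an AsyncVLLM instance is actually built.

import ray


class Counter:
    """Plain Python class; Ray resources are only attached when it is wrapped."""

    def __init__(self) -> None:
        self.value = 0

    def increment(self) -> int:
        self.value += 1
        return self.value


ray.init(ignore_reinit_error=True)

# Wrap at the point of use instead of at module import.
CounterActor = ray.remote(num_cpus=0, num_gpus=0)(Counter)

counter = CounterActor.remote()   # starts the actor process
ref = counter.increment.remote()  # non-blocking, returns an ObjectRef
print(ray.get(ref))               # -> 1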
@@ -605,6 +617,11 @@ def _launch(self):
             torchrl_logger.warning("AsyncVLLMEngineService already launched")
             return

+        # Local imports to avoid global Ray dependency
+        ray = _get_ray()
+        from ray.util.placement_group import placement_group
+        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
         torchrl_logger.info(
             f"Launching {self.num_replicas} async vLLM engine actors..."
         )
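The two utilities imported locally here come from Ray's placement-group API. A small self-contained sketch (hypothetical Worker actor, not the PR's code) of reserving a resource bundle and pinning an actor to it, which is the same mechanism _launch() and shutdown() rely on:

import ray
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init(ignore_reinit_error=True)


@ray.remote(num_cpus=1)
class Worker:
    def ping(self) -> str:
        return "pong"


# Reserve resources up front; "PACK" keeps the bundles co-located.
pg = placement_group([{"CPU": 1}], strategy="PACK")
ray.get(pg.ready())  # block until the reservation is granted

# Schedule the actor inside the reserved bundle.
worker = Worker.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg)
).remote()
print(ray.get(worker.ping.remote()))  # -> "pong"

# Release the reservation once the actors are gone (mirrors shutdown()).
remove_placement_group(pg)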
@@ -944,6 +961,7 @@ def generate(
         Returns:
             RequestOutput | list[RequestOutput]: Generated outputs from vLLM.
         """
+        ray = _get_ray()
         # Check if this is a batch request
         if self._is_batch(prompts, prompt_token_ids):
             # Handle batched input by unbinding and sending individual requests

@@ -1068,6 +1086,9 @@ def shutdown(self):
             f"Shutting down {len(self.actors)} async vLLM engine actors..."
         )

+        ray = _get_ray()
+        from ray.util.placement_group import remove_placement_group
+
         # Kill all actors
         for i, actor in enumerate(self.actors):
             try:
@@ -1260,6 +1281,7 @@ def _update_weights_with_nccl_broadcast_simple(
         )

         updated_weights = 0
+        ray = _get_ray()
         with torch.cuda.device(0):  # Ensure we're on the correct CUDA device
             for name, weight in gpu_weights.items():
                 # Convert dtype to string name (like periodic-mono)

@@ -1336,6 +1358,7 @@ def get_num_unfinished_requests(
                 "AsyncVLLM service must be launched before getting request counts"
             )

+        ray = _get_ray()
        if actor_index is not None:
            if not (0 <= actor_index < len(self.actors)):
                raise IndexError(

@@ -1366,6 +1389,7 @@ def get_cache_usage(self, actor_index: int | None = None) -> float | list[float]
                 "AsyncVLLM service must be launched before getting cache usage"
             )

+        ray = _get_ray()
         if actor_index is not None:
             if not (0 <= actor_index < len(self.actors)):
                 raise IndexError(
@@ -1678,6 +1702,7 @@ def _select_by_requests(self) -> int:
             futures = [
                 actor.get_num_unfinished_requests.remote() for actor in self.actors
             ]
+            ray = _get_ray()
             request_counts = ray.get(futures)

             # Find the actor with minimum pending requests

@@ -1705,6 +1730,7 @@ def _select_by_cache_usage(self) -> int:
         else:
             # Query actors directly
             futures = [actor.get_cache_usage.remote() for actor in self.actors]
+            ray = _get_ray()
             cache_usages = ray.get(futures)

             # Find the actor with minimum cache usage
@@ -1844,7 +1870,8 @@ def _is_actor_overloaded(self, actor_index: int) -> bool:
         futures = [
             actor.get_num_unfinished_requests.remote() for actor in self.actors
         ]
-        request_counts = ray.get(futures)
+        ray = _get_ray()
+        request_counts = ray.get(futures)

         if not request_counts:
             return False

@@ -1893,8 +1920,9 @@ def get_stats(self) -> dict[str, Any]:
                 cache_futures = [
                     actor.get_cache_usage.remote() for actor in self.actors
                 ]
-                request_counts = ray.get(request_futures)
-                cache_usages = ray.get(cache_futures)
+                ray = _get_ray()
+                request_counts = ray.get(request_futures)
+                cache_usages = ray.get(cache_futures)

                 for i, (requests, cache_usage) in enumerate(
                     zip(request_counts, cache_usages)
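The repeated futures = [actor.method.remote() ...] followed by ray.get(futures) in the hunks above is the standard Ray fan-out/gather pattern. A minimal sketch with a hypothetical Replica actor (not this PR's code) for readers who have not used it:

import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class Replica:
    def __init__(self, idx: int) -> None:
        self.idx = idx

    def pending_requests(self) -> int:
        # Stand-in for a real load metric such as unfinished request count.
        return self.idx * 2


# Fan out: fire one non-blocking call per actor, collecting ObjectRefs.
replicas = [Replica.remote(i) for i in range(4)]
futures = [r.pending_requests.remote() for r in replicas]

# Gather: ray.get blocks until every future has resolved.
counts = ray.get(futures)
least_loaded = counts.index(min(counts))
print(counts, "-> route the next request to replica", least_loaded)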
