Commit d467810

feat: add remote offline batch inference with vllm example (#848)
* no-jira: add remote offline batch inference with vllm example
* no-jira: remove notebook output
* feat: add temp workdir for gcs workaround
* feat: add config details to batch inference demo
* chore: ignore jupyter notebooks from codecov

File tree

4 files changed: +283 -0 lines changed

codecov.yml

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
ignore:
- "**/*.ipynb"
- "demo-notebooks/**"

Lines changed: 214 additions & 0 deletions

@@ -0,0 +1,214 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Remote Offline Batch Inference with Ray Data & vLLM Example\n",
    "\n",
    "This notebook presumes:\n",
    "- You have a Ray Cluster URL given to you to run workloads on\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from codeflare_sdk import RayJobClient\n",
    "\n",
    "# Set up the authentication configuration\n",
    "auth_token = \"XXXX\"\n",
    "header = {\"Authorization\": f\"Bearer {auth_token}\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the dashboard URL (provided by the creator of the RayCluster)\n",
    "ray_dashboard = \"XXXX\"  # Replace with the Ray dashboard URL\n",
    "\n",
    "# Initialize the RayJobClient\n",
    "client = RayJobClient(address=ray_dashboard, headers=header, verify=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simple Example Explanation\n",
    "\n",
    "With the RayJobClient instantiated, let's run some batch inference. The following code is stored in `simple_batch_inf.py` and is used as the entrypoint for the RayJob.\n",
    "\n",
    "What this processor configuration does:\n",
    "- Sets up a vLLM engine with your model\n",
    "- Configures settings for GPU processing\n",
    "- Defines batch processing parameters (8 requests per batch, 2 GPU workers)\n",
    "\n",
    "#### Model Source Configuration\n",
    "\n",
    "The `model_source` parameter supports several loading methods:\n",
    "\n",
    "* **Hugging Face Hub** (default): use a repository ID, e.g. `model_source=\"meta-llama/Llama-2-7b-chat-hf\"`\n",
    "* **Local Directory**: use a file path, e.g. `model_source=\"/path/to/my/local/model\"`\n",
    "* **Other Sources**: ModelScope, via environment variables (e.g. `VLLM_MODELSCOPE_DOWNLOADS_DIR`)\n",
    "\n",
    "For complete model support and options, see the [official vLLM documentation](https://docs.vllm.ai/en/latest/models/supported_models.html).\n",
    "\n",
    "```python\n",
    "import ray\n",
    "from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig\n",
    "\n",
    "processor_config = vLLMEngineProcessorConfig(\n",
    "    model_source=\"replace-me\",\n",
    "    engine_kwargs=dict(\n",
    "        enable_lora=False,\n",
    "        dtype=\"half\",\n",
    "        max_model_len=1024,\n",
    "    ),\n",
    "    # Batch size: larger batches increase throughput but reduce fault tolerance\n",
    "    # - Small batches (4-8): better for fault tolerance and memory constraints\n",
    "    # - Large batches (16-32): higher throughput, better GPU utilization\n",
    "    # - Choose based on your Ray Cluster size and memory availability\n",
    "    batch_size=8,\n",
    "    # Concurrency: number of vLLM engine workers to spawn\n",
    "    # - Set to match your total GPU count for maximum utilization\n",
    "    # - Each worker is assigned to a GPU automatically by the Ray scheduler\n",
    "    # - Can use all GPUs across head and worker nodes\n",
    "    concurrency=2,\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the config defined, we can instantiate the processor. This enables batch inference by processing multiple requests through the vLLM engine, with two key steps:\n",
    "- **Preprocess**: Converts each row into a structured chat format with system instructions and user queries, preparing the input for the LLM\n",
    "- **Postprocess**: Extracts only the generated text from the model response, cleaning up the output\n",
    "\n",
    "The processor defines the pipeline that will be applied to each row in the dataset, enabling efficient batch processing through Ray Data's distributed execution framework.\n",
    "\n",
    "```python\n",
    "processor = build_llm_processor(\n",
    "    processor_config,\n",
    "    preprocess=lambda row: dict(\n",
    "        messages=[\n",
    "            {\n",
    "                \"role\": \"system\",\n",
    "                \"content\": \"You are a calculator. Please only output the answer \"\n",
    "                \"of the given equation.\",\n",
    "            },\n",
    "            {\"role\": \"user\", \"content\": f\"{row['id']} ** 3 = ?\"},\n",
    "        ],\n",
    "        sampling_params=dict(\n",
    "            temperature=0.3,\n",
    "            max_tokens=20,\n",
    "            detokenize=False,\n",
    "        ),\n",
    "    ),\n",
    "    postprocess=lambda row: {\n",
    "        \"resp\": row[\"generated_text\"],\n",
    "    },\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Running the Pipeline\n",
    "Now we can run the batch inference pipeline on our data. It will:\n",
    "- Download the model in the background, with vLLM serving it locally on the Ray Cluster for inference\n",
    "- Generate a sample Ray Dataset with 32 rows (0-31) to process\n",
    "- Run the LLM processor on the dataset, triggering the preprocessing, inference, and postprocessing steps\n",
    "- Execute the lazy pipeline and load the results into memory\n",
    "- Iterate through all outputs and print each response\n",
    "\n",
    "```python\n",
    "ds = ray.data.range(32)\n",
    "ds = processor(ds)\n",
    "ds = ds.materialize()\n",
    "\n",
    "for out in ds.take_all():\n",
    "    print(out)\n",
    "    print(\"==========\")\n",
    "```\n",
    "\n",
    "### Job Submission\n",
    "\n",
    "Now we can submit this job against the Ray Cluster using the `RayJobClient` from earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "import shutil\n",
    "\n",
    "# Create a clean directory with ONLY your script\n",
    "temp_dir = tempfile.mkdtemp()\n",
    "shutil.copy(\"simple_batch_inf.py\", temp_dir)\n",
    "\n",
    "entrypoint_command = \"python simple_batch_inf.py\"\n",
    "\n",
    "submission_id = client.submit_job(\n",
    "    entrypoint=entrypoint_command,\n",
    "    runtime_env={\"working_dir\": temp_dir, \"pip\": \"requirements.txt\"},\n",
    ")\n",
    "\n",
    "print(submission_id + \" successfully submitted\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the job's status\n",
    "client.get_job_status(submission_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the job's logs\n",
    "client.get_job_logs(submission_id)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
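
The last two notebook cells fetch the job's status and logs with one-off calls. The same client can also be used to poll until the job reaches a terminal state before pulling the logs. The snippet below is a minimal sketch, not part of this commit; it assumes the `client` and `submission_id` objects from the cells above, that `get_job_status()` returns a Ray `JobStatus` value that compares equal to its string form (e.g. "SUCCEEDED"), and the 10-second interval is an arbitrary illustrative choice.

```python
import time

# Poll until the submitted job reaches a terminal state (sketch, see note above).
terminal_states = ("SUCCEEDED", "FAILED", "STOPPED")

status = client.get_job_status(submission_id)
while status not in terminal_states:
    time.sleep(10)  # illustrative polling interval
    status = client.get_job_status(submission_id)

print(f"Job finished with status: {status}")
print(client.get_job_logs(submission_id))
```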

requirements.txt

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
vllm
transformers
triton>=2.0.0
torch>=2.0.0

simple_batch_inf.py

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
import ray
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig


# 1. Construct a vLLM processor config.
processor_config = vLLMEngineProcessorConfig(
    # The base model.
    model_source="unsloth/Llama-3.2-1B-Instruct",
    # vLLM engine config.
    engine_kwargs=dict(
        enable_lora=False,
        # Older GPUs (e.g. T4) don't support bfloat16. You should remove
        # this line if you're using newer GPUs.
        dtype="half",
        # Reduce the model length to fit small GPUs. You should remove
        # this line if you're using large GPUs.
        max_model_len=1024,
    ),
    # The batch size used in Ray Data.
    batch_size=8,
    # Use one GPU in this example.
    concurrency=1,
    # If you save the LoRA adapter in S3, you can set the following path.
    # dynamic_lora_loading_path="s3://your-lora-bucket/",
)

# 2. Construct a processor using the processor config.
processor = build_llm_processor(
    processor_config,
    preprocess=lambda row: dict(
        # No LoRA model is specified here; just build the chat messages.
        messages=[
            {
                "role": "system",
                "content": "You are a calculator. Please only output the answer "
                "of the given equation.",
            },
            {"role": "user", "content": f"{row['id']} ** 3 = ?"},
        ],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=20,
            detokenize=False,
        ),
    ),
    postprocess=lambda row: {
        "resp": row["generated_text"],
    },
)

# 3. Synthesize a dataset with 32 rows.
ds = ray.data.range(32)
# 4. Apply the processor to the dataset. Note that this line won't kick off
# anything because the processor executes lazily.
ds = processor(ds)
# Materialization kicks off the pipeline execution.
ds = ds.materialize()

# 5. Print all outputs.
for out in ds.take_all():
    print(out)
    print("==========")
