Commit 46c6ac4

🐛 fix a bug in tests, add DISABLE_ASSERTS (#375)
# Description

- Added `DISABLE_ASSERTS` to the scheduling tests. It is useful when you want to debug a test run without actually asserting the checked values.
- Fixed a bug in the tests where `server_args.extend(...)` should have been called instead of reassigning `server_args`.

---------

Signed-off-by: Prashant Gupta <[email protected]>
1 parent 7d06a62 commit 46c6ac4
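The new flag works by short-circuiting each assertion: when it is set to `True`, every `assert DISABLE_ASSERTS or <condition>` passes unconditionally, so the checker still executes all engine steps and the collected state can be inspected instead of failing at the first mismatch. A minimal sketch of the pattern follows; `check_step` and `expected_tkv` are illustrative placeholders, not names from the repo.

# Minimal sketch of the assert-guard pattern introduced by this commit.
DISABLE_ASSERTS = False  # flip to True to run the steps without failing


def check_step(scheduler, step: int, expected_tkv: int) -> None:
    # With DISABLE_ASSERTS=True the `or` short-circuits and the assert always
    # passes, so execution continues and the values can be printed or logged.
    assert DISABLE_ASSERTS or scheduler.tkv == expected_tkv, (
        f"Step {step}, tkv: {scheduler.tkv}")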

File tree

3 files changed: +56 -43 lines changed


tests/conftest.py

Lines changed: 2 additions & 3 deletions
@@ -152,12 +152,11 @@ def remote_openai_server(request):
             "VLLM_SPYRE_USE_CB": "1",
             "VLLM_SPYRE_DYNAMO_BACKEND": backend
         }
-        server_args = [
+        server_args.extend([
             "--max_num_seqs",
             str(max_num_seqs), "--max-model-len",
             str(max_model_len)
-        ]
-
+        ])
     else:
         warmup_shape = params['warmup_shape']
         warmup_prompt_length = [t[0] for t in warmup_shape]
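The `extend` change matters because `server_args` is already populated earlier in the fixture; reassigning it with `server_args = [...]` would silently discard those earlier flags, while `.extend()` appends to them. A contrived sketch of the difference (the base flag shown is a placeholder, not the fixture's actual contents):

# Buggy version: reassignment throws away whatever was collected before.
server_args = ["--enforce-eager"]            # hypothetical base args
server_args = ["--max_num_seqs", "2"]        # base args are lost here

# Fixed version: extend keeps the base args and appends the new ones.
server_args = ["--enforce-eager"]
server_args.extend(["--max_num_seqs", "2"])
# -> ["--enforce-eager", "--max_num_seqs", "2"]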

tests/e2e/test_spyre_online.py

Lines changed: 17 additions & 16 deletions
@@ -3,6 +3,18 @@
 from spyre_util import get_spyre_backend_list, get_spyre_model_list
 
 
+def _check_result(client, model, max_tokens=8, temperature=0.0, n=1) -> None:
+    completion = client.completions.create(
+        model=model,
+        prompt="Hello World!",
+        max_tokens=max_tokens,
+        temperature=temperature,
+        n=n,
+    )
+    assert len(completion.choices) == n
+    assert len(completion.choices[0].text) > 0
+
+
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize(
     "tp_size",
@@ -40,20 +52,9 @@ def test_openai_serving(
     """Test online serving using the `vllm serve` CLI"""
 
     client = remote_openai_server.get_client()
-    completion = client.completions.create(model=model,
-                                           prompt="Hello World!",
-                                           max_tokens=5,
-                                           temperature=0.0)
-    assert len(completion.choices) == 1
-    assert len(completion.choices[0].text) > 0
 
-    completion = client.completions.create(model=model,
-                                           prompt="Hello World!",
-                                           max_tokens=5,
-                                           temperature=1.0,
-                                           n=2)
-    assert len(completion.choices) == 2
-    assert len(completion.choices[0].text) > 0
+    _check_result(client, model, n=1)
+    _check_result(client, model, temperature=1.0, n=2)
 
     # rest are SB tests
     if cb:
@@ -73,8 +74,8 @@ def test_openai_serving(
     # Short prompt under context length but requesting too many tokens for
     # the warmup shape should return an empty result
     try:
-        completion = client.completions.create(model=model,
-                                               prompt="Hello World!",
-                                               max_tokens=25)
+        client.completions.create(model=model,
+                                   prompt="Hello World!",
+                                   max_tokens=25)
     except openai.BadRequestError as e:
         assert "warmup" in str(e)

tests/scheduling_utils.py

Lines changed: 37 additions & 24 deletions
@@ -11,6 +11,8 @@
 
 from vllm_spyre.v1.core.scheduler import ContinuousBatchingSpyreScheduler
 
+DISABLE_ASSERTS = False  # used for debugging
+
 
 def augment_checked_steps(
         checked_steps: list[dict[str, Any]]) -> deque[dict[str, Any]]:
@@ -105,11 +107,13 @@ def check_scheduler_inference_steps(
         generated_prompts.append(request.prompt_token_ids)
 
     # Setup the engine
-    engine_args = EngineArgs(model=model,
-                             tokenizer=model,
-                             max_model_len=max_model_len,
-                             max_num_seqs=max_num_seqs,
-                             num_gpu_blocks_override=available_blocks)
+    engine_args = EngineArgs(
+        model=model,
+        tokenizer=model,
+        max_model_len=max_model_len,
+        max_num_seqs=max_num_seqs,
+        num_gpu_blocks_override=available_blocks,
+    )
     vllm_config = engine_args.create_engine_config()
     executor_class = Executor.get_class(vllm_config)
     engine_core = EngineCore(vllm_config=vllm_config,
@@ -139,17 +143,18 @@ def check_scheduler_inference_steps(
             r.request_id for r in request_outputs if r.finished
         ]
 
-        assert (scheduler.tkv == step_ref["tkv"]
-                ), f"Step {step}, tkv: {scheduler.tkv}"
-        assert waiting == step_ref[
-            "waiting"], f"Step {step}, waiting: {waiting}"
-        assert running == step_ref[
-            "running"], f"Step {step}, running: {running}"
-        assert (out_reqs_ids == step_ref["request_outputs"]
-                ), f"Step {step}, request outputs: {out_reqs_ids}"
+        assert DISABLE_ASSERTS or (scheduler.tkv == step_ref["tkv"]
+                                   ), f"Step {step}, tkv: {scheduler.tkv}"
+        assert (DISABLE_ASSERTS or waiting
+                == step_ref["waiting"]), f"Step {step}, waiting: {waiting}"
+        assert (DISABLE_ASSERTS or running
+                == step_ref["running"]), f"Step {step}, running: {running}"
+        assert DISABLE_ASSERTS or (
+            out_reqs_ids == step_ref["request_outputs"]
+        ), f"Step {step}, request outputs: {out_reqs_ids}"
 
         ref_finished_reqs = step_ref.get("finished_requests", [])
-        assert (
+        assert DISABLE_ASSERTS or (
             out_reqs_finished == ref_finished_reqs
         ), f"Step {step}, finished request output: {out_reqs_finished}"
 
@@ -166,27 +171,31 @@ def check_scheduler_inference_steps(
             [len(blocks) for blocks in req_ids2blocks.values()])
 
         if step > 0:
-            assert (
+            assert DISABLE_ASSERTS or (
                 n_reserved_blocks == step_ref["n_reserved_blocks"]
             ), f"Step {step}, n_reserved_blocks: {n_reserved_blocks}"
-            assert (n_used_blocks == step_ref["n_used_blocks"]
-                    ), f"Step {step}, n_used_blocks: {n_used_blocks}"
+            assert DISABLE_ASSERTS or (
+                n_used_blocks == step_ref["n_used_blocks"]
+            ), f"Step {step}, n_used_blocks: {n_used_blocks}"
 
-        assert len(req_ids2blocks) == len(req_ids2reserved_blocks)
+        assert DISABLE_ASSERTS or len(req_ids2blocks) == len(
+            req_ids2reserved_blocks)
         for req_id in req_ids2blocks:
             # current number of used blocks should be less than reserved
-            assert len(
-                req_ids2blocks[req_id]) <= req_ids2reserved_blocks[req_id]
+            assert (DISABLE_ASSERTS or len(req_ids2blocks[req_id])
+                    <= req_ids2reserved_blocks[req_id])
             # update requested/reserved blocks to check in last step
-            # Note: overwrite and not max because of reduce_left_padding()
+            # Note: overwrite and not max
+            # because of reduce_left_padding()
             requested_blocks[req_id] = len(req_ids2blocks[req_id])
            reserved_blocks[req_id] = req_ids2reserved_blocks[req_id]
 
         # last step: check that sequences used all their reserved blocks
         # Note: no early stopping, all sequences produce max_num_tokens
         if len(checked_steps) == 0:
             for req_id in requested_blocks:
-                assert requested_blocks[req_id] == reserved_blocks[req_id]
+                assert (DISABLE_ASSERTS
+                        or requested_blocks[req_id] == reserved_blocks[req_id])
 
         # Perform next step
         step_output = engine_core.step()
@@ -197,15 +206,17 @@ def check_scheduler_inference_steps(
         for output in request_outputs:
             new_token_ids = output.new_token_ids
             new_logprobs = output.new_logprobs.logprobs
-            assert len(new_token_ids) == 1 and len(new_logprobs) == 1
+            assert DISABLE_ASSERTS or len(new_token_ids) == 1 and len(
+                new_logprobs) == 1
 
             collected_outputs[output.request_id]["token_ids"].append(
                 new_token_ids[0])
             collected_outputs[output.request_id]["logprobs"].append(
                 new_logprobs[0][0])
 
     output_keys = sorted(int(k) for k in collected_outputs)
-    assert output_keys[0] == 0 and output_keys[-1] == len(output_keys) - 1
+    assert (DISABLE_ASSERTS
+            or output_keys[0] == 0 and output_keys[-1] == len(output_keys) - 1)
 
     # convert dict of dicts to ordered list and make values immutable
     collected_outputs_new = []
@@ -216,4 +227,6 @@ def check_scheduler_inference_steps(
             output[k] = tuple(list_values)
         collected_outputs_new.append(output)
 
+    # good practice?
+    engine_core.shutdown()
     return collected_outputs_new, generated_prompts
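When debugging a failing scheduling test, the flag can also be flipped for a single run instead of editing this file, for example by patching the module attribute from the test. A sketch, assuming `scheduling_utils` is importable from the test the same way the other test helpers are:

import scheduling_utils


def test_debug_scheduling_case(monkeypatch):
    # Hypothetical debugging aid: disable the per-step assertions for this
    # test only; check_scheduler_inference_steps still executes every engine
    # step and returns the collected outputs for manual inspection.
    monkeypatch.setattr(scheduling_utils, "DISABLE_ASSERTS", True)
    ...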
