Merged
5 changes: 2 additions & 3 deletions tests/conftest.py
@@ -152,12 +152,11 @@ def remote_openai_server(request):
"VLLM_SPYRE_USE_CB": "1",
"VLLM_SPYRE_DYNAMO_BACKEND": backend
}
- server_args = [
+ server_args.extend([
Collaborator (PR author) commented: found a 🐛 !

"--max_num_seqs",
str(max_num_seqs), "--max-model-len",
str(max_model_len)
- ]
-
+ ])
else:
warmup_shape = params['warmup_shape']
warmup_prompt_length = [t[0] for t in warmup_shape]
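For context on the 🐛 comment above: assigning a new list to `server_args` replaces whatever CLI arguments the fixture had already built up, whereas `.extend()` appends to them. A minimal standalone sketch of the difference (the base arguments are made up for illustration, not taken from the fixture):

    # assumed starting point: arguments assembled earlier in the fixture
    server_args = ["--model", "some-model"]

    # before the fix: reassignment silently drops the earlier arguments
    server_args = ["--max_num_seqs", "4", "--max-model-len", "256"]
    print(server_args)  # ['--max_num_seqs', '4', '--max-model-len', '256']

    # after the fix: extend() keeps the base arguments and appends the new ones
    server_args = ["--model", "some-model"]
    server_args.extend(["--max_num_seqs", "4", "--max-model-len", "256"])
    print(server_args)  # ['--model', 'some-model', '--max_num_seqs', '4', '--max-model-len', '256']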
33 changes: 17 additions & 16 deletions tests/e2e/test_spyre_online.py
@@ -3,6 +3,18 @@
from spyre_util import get_spyre_backend_list, get_spyre_model_list


+ def _check_result(client, model, max_tokens=8, temperature=0.0, n=1) -> None:
Collaborator commented: 🌶️ !

+ completion = client.completions.create(
+ model=model,
+ prompt="Hello World!",
+ max_tokens=max_tokens,
+ temperature=temperature,
+ n=n,
+ )
+ assert len(completion.choices) == n
+ assert len(completion.choices[0].text) > 0


@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize(
"tp_size",
@@ -40,20 +52,9 @@ def test_openai_serving(
"""Test online serving using the `vllm serve` CLI"""

client = remote_openai_server.get_client()
- completion = client.completions.create(model=model,
- prompt="Hello World!",
- max_tokens=5,
- temperature=0.0)
- assert len(completion.choices) == 1
- assert len(completion.choices[0].text) > 0
-
- completion = client.completions.create(model=model,
- prompt="Hello World!",
- max_tokens=5,
- temperature=1.0,
- n=2)
- assert len(completion.choices) == 2
- assert len(completion.choices[0].text) > 0
+ _check_result(client, model, n=1)
+ _check_result(client, model, temperature=1.0, n=2)

# rest are SB tests
if cb:
@@ -73,8 +74,8 @@
# Short prompt under context length but requesting too many tokens for
# the warmup shape should return an empty result
try:
- completion = client.completions.create(model=model,
- prompt="Hello World!",
- max_tokens=25)
+ client.completions.create(model=model,
+ prompt="Hello World!",
+ max_tokens=25)
except openai.BadRequestError as e:
assert "warmup" in str(e)
61 changes: 37 additions & 24 deletions tests/scheduling_utils.py
@@ -11,6 +11,8 @@

from vllm_spyre.v1.core.scheduler import ContinuousBatchingSpyreScheduler

+ DISABLE_ASSERTS = False  # used for debugging


def augment_checked_steps(
checked_steps: list[dict[str, Any]]) -> deque[dict[str, Any]]:
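The new flag is used throughout the step checks below as `assert DISABLE_ASSERTS or <condition>`: flipping it to True short-circuits every assertion, so a failing scenario can be stepped through in a debugger instead of aborting at the first mismatch. A minimal sketch of the pattern (standalone example, not code from this file):

    DISABLE_ASSERTS = False  # set to True while debugging a failing schedule

    def check_step(observed_tkv: int, expected_tkv: int, step: int) -> None:
        # with DISABLE_ASSERTS=True the `or` short-circuits and the assert never fires
        assert DISABLE_ASSERTS or observed_tkv == expected_tkv, (
            f"Step {step}, tkv: {observed_tkv}")

    check_step(64, 64, step=0)   # passes
    check_step(65, 64, step=1)   # raises AssertionError unless DISABLE_ASSERTS is True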
@@ -105,11 +107,13 @@ def check_scheduler_inference_steps(
generated_prompts.append(request.prompt_token_ids)

# Setup the engine
- engine_args = EngineArgs(model=model,
- tokenizer=model,
- max_model_len=max_model_len,
- max_num_seqs=max_num_seqs,
- num_gpu_blocks_override=available_blocks)
+ engine_args = EngineArgs(
+ model=model,
+ tokenizer=model,
+ max_model_len=max_model_len,
+ max_num_seqs=max_num_seqs,
+ num_gpu_blocks_override=available_blocks,
+ )
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
@@ -139,17 +143,18 @@
r.request_id for r in request_outputs if r.finished
]

- assert (scheduler.tkv == step_ref["tkv"]
- ), f"Step {step}, tkv: {scheduler.tkv}"
- assert waiting == step_ref[
- "waiting"], f"Step {step}, waiting: {waiting}"
- assert running == step_ref[
- "running"], f"Step {step}, running: {running}"
- assert (out_reqs_ids == step_ref["request_outputs"]
- ), f"Step {step}, request outputs: {out_reqs_ids}"
+ assert DISABLE_ASSERTS or (scheduler.tkv == step_ref["tkv"]
+ ), f"Step {step}, tkv: {scheduler.tkv}"
+ assert (DISABLE_ASSERTS or waiting
+ == step_ref["waiting"]), f"Step {step}, waiting: {waiting}"
+ assert (DISABLE_ASSERTS or running
+ == step_ref["running"]), f"Step {step}, running: {running}"
+ assert DISABLE_ASSERTS or (
+ out_reqs_ids == step_ref["request_outputs"]
+ ), f"Step {step}, request outputs: {out_reqs_ids}"

ref_finished_reqs = step_ref.get("finished_requests", [])
- assert (
+ assert DISABLE_ASSERTS or (
out_reqs_finished == ref_finished_reqs
), f"Step {step}, finished request output: {out_reqs_finished}"

@@ -166,27 +171,31 @@
[len(blocks) for blocks in req_ids2blocks.values()])

if step > 0:
- assert (
+ assert DISABLE_ASSERTS or (
n_reserved_blocks == step_ref["n_reserved_blocks"]
), f"Step {step}, n_reserved_blocks: {n_reserved_blocks}"
- assert (n_used_blocks == step_ref["n_used_blocks"]
- ), f"Step {step}, n_used_blocks: {n_used_blocks}"
+ assert DISABLE_ASSERTS or (
+ n_used_blocks == step_ref["n_used_blocks"]
+ ), f"Step {step}, n_used_blocks: {n_used_blocks}"

- assert len(req_ids2blocks) == len(req_ids2reserved_blocks)
+ assert DISABLE_ASSERTS or len(req_ids2blocks) == len(
+ req_ids2reserved_blocks)
for req_id in req_ids2blocks:
# current number of used blocks should be less than reserved
- assert len(
- req_ids2blocks[req_id]) <= req_ids2reserved_blocks[req_id]
+ assert (DISABLE_ASSERTS or len(req_ids2blocks[req_id])
+ <= req_ids2reserved_blocks[req_id])
# update requested/reserved blocks to check in last step
- # Note: overwrite and not max because of reduce_left_padding()
+ # Note: overwrite and not max
+ # because of reduce_left_padding()
requested_blocks[req_id] = len(req_ids2blocks[req_id])
reserved_blocks[req_id] = req_ids2reserved_blocks[req_id]

# last step: check that sequences used all their reserved blocks
# Note: no early stopping, all sequences produce max_num_tokens
if len(checked_steps) == 0:
for req_id in requested_blocks:
- assert requested_blocks[req_id] == reserved_blocks[req_id]
+ assert (DISABLE_ASSERTS
+ or requested_blocks[req_id] == reserved_blocks[req_id])

# Perform next step
step_output = engine_core.step()
@@ -197,15 +206,17 @@
for output in request_outputs:
new_token_ids = output.new_token_ids
new_logprobs = output.new_logprobs.logprobs
- assert len(new_token_ids) == 1 and len(new_logprobs) == 1
+ assert DISABLE_ASSERTS or len(new_token_ids) == 1 and len(
+ new_logprobs) == 1

collected_outputs[output.request_id]["token_ids"].append(
new_token_ids[0])
collected_outputs[output.request_id]["logprobs"].append(
new_logprobs[0][0])

output_keys = sorted(int(k) for k in collected_outputs)
- assert output_keys[0] == 0 and output_keys[-1] == len(output_keys) - 1
+ assert (DISABLE_ASSERTS
+ or output_keys[0] == 0 and output_keys[-1] == len(output_keys) - 1)

# convert dict of dicts to ordered list and make values immutable
collected_outputs_new = []
@@ -216,4 +227,6 @@
output[k] = tuple(list_values)
collected_outputs_new.append(output)

+ # good practice?
+ engine_core.shutdown()
return collected_outputs_new, generated_prompts