2 changes: 1 addition & 1 deletion tests/e2e/test_spyre_async_llm.py
@@ -86,7 +86,7 @@ async def test_abort(
model=model,
tokenizer=model,
max_model_len=128,
-max_num_seqs=8,
+max_num_seqs=2,
block_size=2048,
))
has_unfinished_requests = \
2 changes: 1 addition & 1 deletion tests/e2e/test_spyre_basic.py
@@ -228,7 +228,7 @@ def test_full_batch_scheduling(model: str, backend: str, monkeypatch):
engine_args = EngineArgs(model=model,
tokenizer=model,
max_num_batched_tokens=max_batched_tokens,
-max_num_seqs=4)
+max_num_seqs=batch_size)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
2 changes: 1 addition & 1 deletion tests/e2e/test_spyre_cb.py
@@ -23,7 +23,7 @@
@pytest.mark.cb
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize("backend", get_spyre_backend_list())
@pytest.mark.parametrize("max_num_seqs", [2, 4],
@pytest.mark.parametrize("max_num_seqs", [1, 2, 4],
ids=lambda val: f"max_num_seqs({val})")
def test_cb_output(
model: str,
5 changes: 0 additions & 5 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -669,11 +669,6 @@ def __init__(
super().__init__(vllm_config=vllm_config,
is_driver_worker=is_driver_worker)

-        # TODO: remove this limitation once we update the warm-up logic to
-        # support batch_size=1
-        assert vllm_config.scheduler_config.max_num_seqs >= 2, "Currently, " \
-            "continuous batching needs config to set batch_size >= 2"

self.block_size = 64

# TODO: move to a KV cache manager
26 changes: 26 additions & 0 deletions vllm_spyre/v1/worker/spyre_worker.py
@@ -28,6 +28,7 @@
import vllm_spyre.perf_metrics as perf_metrics
from vllm_spyre.model_executor.model_loader import spyre_setup
from vllm_spyre.platform import SpyrePlatform
+from vllm_spyre.v1.worker.spyre_input_batch import InputBatch
from vllm_spyre.v1.worker.spyre_model_runner import (
ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner)

@@ -317,6 +318,18 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
prompt_len = 42
num_decode_tokens = 2

+        # Fix for batch size 1: set input batch to fit 2 requests for warmup
+        if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
+            model_runner.input_batch = InputBatch(
+                max_num_reqs=2,
+                max_model_len=model_runner.vllm_config.model_config.
+                max_model_len,
+                device=model_runner.device,
+                pin_memory=model_runner.pin_memory,
+                vocab_size=model_runner.vllm_config.model_config.
+                get_vocab_size(),
+            )

Collaborator:
Alternatively, could the InputBatch construct itself with

    self.max_num_reqs = min(max_num_reqs, 2)

since we know that it will always need at least 2? Then we avoid reconstructing it in the worker here, and we have a much smaller diff to back out once we can lift this bs >= 2 restriction.

Collaborator (Author):
Not sure I follow here. It has to be >= 2 for the warmup; with min(1, 2) we would still fail?

Collaborator:
Would it work to directly set model_runner.input_batch.max_num_reqs = 2, instead of instantiating a new InputBatch?

Collaborator (Author):
No, because the InputBatch initialization uses max_num_reqs to build its internal state, not just to store it as an attribute.

Collaborator (Author):
max_num_seqs occurs 17 times in the __init__ of InputBatch. It is not a single attribute, but is used to construct several attributes. So re-initializing is simpler.

# Sample from the valid token ids
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]
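To illustrate the point made in the review thread above, here is a minimal sketch (not the actual vllm_spyre InputBatch, and with made-up buffer names): max_num_reqs is consumed at construction time to size several internal buffers, so reassigning the attribute afterwards would not grow them.

```python
import torch


class SketchInputBatch:
    # Minimal sketch with hypothetical attribute names: max_num_reqs sizes
    # several buffers in __init__, so it is not a single attribute that can
    # simply be overwritten after construction.
    def __init__(self, max_num_reqs: int, max_model_len: int,
                 vocab_size: int, device: str = "cpu"):
        self.max_num_reqs = max_num_reqs
        # Each allocation below depends on max_num_reqs:
        self.token_ids = torch.zeros(max_num_reqs, max_model_len,
                                     dtype=torch.long, device=device)
        self.num_computed_tokens = torch.zeros(max_num_reqs,
                                               dtype=torch.int32)
        self.logit_bias = torch.zeros(max_num_reqs, vocab_size, device=device)


batch = SketchInputBatch(max_num_reqs=1, max_model_len=128, vocab_size=32000)
batch.max_num_reqs = 2   # the buffers above are still sized for 1 request
batch = SketchInputBatch(max_num_reqs=2, max_model_len=128,
                         vocab_size=32000)  # re-initializing resizes everything
```

This is why the PR re-instantiates InputBatch for the warm-up rather than patching the attribute in place.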
@@ -412,6 +425,19 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
self.execute_model(scheduler_output)
self._cleanup_model_runner(request=[add_dummy_request])

+        # Fix for batch size 1: reset input batch to fit max_num_seqs requests
+        if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
+            model_runner.input_batch = InputBatch(
+                max_num_reqs=model_runner.vllm_config.scheduler_config.
+                max_num_seqs,
+                max_model_len=model_runner.vllm_config.model_config.
+                max_model_len,
+                device=model_runner.device,
+                pin_memory=model_runner.pin_memory,
+                vocab_size=model_runner.vllm_config.model_config.
+                get_vocab_size(),
+            )

# get the number of pages from the actual Spyre card after the warmup
# and set it accordingly in the model runner and the kv cache size
n_blocks_avail = self._get_num_blocks_available()
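Taken together, the warm-up temporarily swaps in a two-slot InputBatch and afterwards rebuilds one sized to the configured max_num_seqs. A minimal sketch of the same swap-and-restore pattern expressed as a context manager follows; warmup_input_batch is a hypothetical helper that is not part of this PR, and it restores the original batch object rather than rebuilding one, which is equivalent only if warm-up leaves no per-request state behind.

```python
from contextlib import contextmanager

from vllm_spyre.v1.worker.spyre_input_batch import InputBatch


@contextmanager
def warmup_input_batch(model_runner, min_reqs: int = 2):
    # Hypothetical helper (not part of this PR): ensure the input batch can
    # hold at least `min_reqs` requests during warm-up, then restore the
    # originally configured batch afterwards.
    original = model_runner.input_batch
    cfg = model_runner.vllm_config
    if cfg.scheduler_config.max_num_seqs < min_reqs:
        model_runner.input_batch = InputBatch(
            max_num_reqs=min_reqs,
            max_model_len=cfg.model_config.max_model_len,
            device=model_runner.device,
            pin_memory=model_runner.pin_memory,
            vocab_size=cfg.model_config.get_vocab_size(),
        )
    try:
        yield model_runner.input_batch
    finally:
        # Restore the configured-size batch once warm-up is done.
        model_runner.input_batch = original
```

The warm-up code would then wrap its dummy-request section in `with warmup_input_batch(model_runner):`, keeping the construct/restore pair in one place and making the bs >= 2 workaround easier to remove later.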