2 changes: 1 addition & 1 deletion tests/e2e/test_spyre_basic.py
@@ -29,7 +29,7 @@
ids=lambda val: f"TP({val})",
)
@pytest.mark.parametrize("backend", get_spyre_backend_list())
@pytest.mark.parametrize("max_num_seqs", [4],
@pytest.mark.parametrize("max_num_seqs", [1, 4],
Collaborator Author:
This was just convenient for testing. I think I'll remove it so the test suite doesn't get any longer. @joerunde what do you think?

Collaborator:
Maybe we could add one extra test for it instead of parameterizing it here (which creates unnecessary combinations with the other params), so that it still gets tested at least once every iteration. But then it's repeated code, so there's that 🤷

Collaborator Author:
Hmm, honestly I don't know what's better here 😄 @joerunde, any opinions?

Collaborator:
Honestly, the tests right now take a ton of CI time both here and on Jenkins, and since --max-num-seqs=1 isn't really a setting anybody would use in practice, I'd be fine not testing it directly. All the existing CB tests should cover the warmup change, which is the important thing not to break.
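For reference, the dedicated-test alternative floated above might look roughly like this (a sketch only, not part of this PR, assuming it lives in the same test module so the existing imports and fixtures apply; run_output_check is a hypothetical helper standing in for the shared body of test_output):

    @pytest.mark.parametrize("backend", get_spyre_backend_list())
    def test_output_single_seq(model: str, backend: str) -> None:
        # Exercise the batch-size-1 decode path once per CI run without
        # multiplying the parameter grid of test_output.
        run_output_check(model=model, backend=backend, max_num_seqs=1)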

ids=lambda val: f"max_num_seqs({val})")
def test_output(
model: str,
4 changes: 3 additions & 1 deletion vllm_spyre/platform.py
@@ -200,7 +200,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# set env vars for torch_sendnn to consume
os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
vllm_config.model_config.max_model_len)
# min decode batch size is 2 due to symbolic shape constraint in torch
# min value 2 needed for VLLM_DT_MAX_BATCH_SIZE (compiler constraint)
# Note that we can still have decodes of batch size 1 as the env var
# only concerns the max batch size.
os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(
max(vllm_config.scheduler_config.max_num_seqs, 2))
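Taken in isolation, the clamp above behaves as follows (a standalone sketch with illustrative values, not code from this PR):

    import os

    def set_spyre_max_batch_size(max_num_seqs: int) -> None:
        # The compiler needs VLLM_DT_MAX_BATCH_SIZE >= 2, but the scheduler
        # may still run decode batches of size 1 below that cap.
        os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max_num_seqs, 2))

    set_spyre_max_batch_size(1)  # env var is set to "2"
    set_spyre_max_batch_size(8)  # env var is set to "8"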

26 changes: 1 addition & 25 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -290,13 +290,8 @@ def load_model(self, prompt_lens: Iterable[int],
)

def build_input_batch(self) -> SamplingInputBatch:
# Fix for batch size 1: set input batch to fit 2 requests for warmup,
# and reset input batch to fit max_num_seqs requests after warmup
min_seqs_required = 2 if self.warmup_mode else 1

return SamplingInputBatch(
max_num_reqs=max(min_seqs_required,
self.scheduler_config.max_num_seqs),
max_num_reqs=self.scheduler_config.max_num_seqs,
max_model_len=self.model_config.max_model_len,
device=self.device,
pin_memory=self.pin_memory,
@@ -809,8 +804,6 @@ def __init__(

def complete_warmup(self) -> None:
super().complete_warmup()
# Fix for batch size 1: need to update the input_batch after the warmup
self.input_batch = self.build_input_batch()
# get the number or pages from the actual Spyre card after the warmup
# and set it accordingly in the model runner and the kv cache size
n_blocks_avail = self._get_num_blocks_available()
@@ -1098,23 +1091,6 @@ def _prepare_decode(
# mask not needed during decode
mask = None

# add pads for min decode batch size of 2 (Spyre compiler constraint)
if len(cached_request_data.req_ids) == 1:
padd_seq_indices = torch.zeros(1, dtype=torch.bool, device="cpu")
self.model.indices = torch.cat(
(self.model.indices, padd_seq_indices), -1)
assert self.model.indices.size(dim=0) == 2

input_tokens = torch.cat(2 * [input_tokens])
position_ids = torch.cat(2 * [position_ids])
current_tkv_mask = torch.cat(2 * [current_tkv_mask])
left_padded_prompt_mask = torch.cat(2 * [left_padded_prompt_mask])
block_table = torch.cat(2 * [block_table])
slot_mapping = torch.cat(2 * [slot_mapping])

# assert min batch size 2 for decodes (Spyre compiler constraint)
assert len(input_tokens) >= 2

model_inputs = SamplingForwardInputs(
input_tokens=input_tokens,
input_positions=position_ids,
87 changes: 37 additions & 50 deletions vllm_spyre/v1/worker/spyre_worker.py
@@ -330,6 +330,10 @@ def load_model(self):
logger.info("load model took %.3fs", load_model_total_t)

def _warmup_spyre_dynamic_size(self, special_token_ids):
# this setting is required to mark a dimension of size 1 as dynamic
# for pytorch >= 2.7.1 (needed to support batch size 1 for decodes)
from torch.fx.experimental import _config as config
config.backed_size_oblivious = True
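A standalone sketch (not from this PR) of what the flag changes: torch.compile's symbolic shapes normally specialize on sizes 0 and 1, so a graph traced with a dynamic batch dimension would not cover batch size 1; with backed_size_oblivious enabled, a size-1 batch is treated like any other dynamic size. The toy function below is purely illustrative and assumes PyTorch >= 2.7.1.

    import torch
    from torch.fx.experimental import _config as fx_config

    fx_config.backed_size_oblivious = True  # treat size-1 dims as ordinary dynamic sizes

    def decode_step(x: torch.Tensor) -> torch.Tensor:
        # stand-in for a decode forward pass over a [batch, hidden] tensor
        return x * 2.0

    compiled = torch.compile(decode_step, dynamic=True)
    out_bs2 = compiled(torch.randn(2, 8))  # warmup-style call at batch size 2
    out_bs1 = compiled(torch.randn(1, 8))  # batch size 1 should reuse the dynamic graph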

warmup_start_t = time.time()

@@ -347,15 +347,14 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
valid_token_ids_tensor = torch.tensor(valid_token_ids,
dtype=torch.long,
device=torch.device("cpu"))
batch_size = 2
prompt_len = 42
num_decode_tokens = 2

# Sample from the valid token ids
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]
0, len(valid_token_ids_tensor), (2, prompt_len))]

dummy_requests: list[NewRequestData] = [
warmup_req, deploy_req = (
NewRequestData(
req_id="warmup-%d" % (i),
prompt_token_ids=warmup_tokens_tensor[i].tolist(),
@@ -367,24 +370,21 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
block_ids=[0], # not actually used
num_computed_tokens=0,
lora_request=None,
) for i in range(batch_size + 1)
]
add_dummy_request = dummy_requests.pop(-1)
) for i in range(2))

with _maybe_warmup_context():
self._dynamic_warmup(dummy_requests=dummy_requests,
self._dynamic_warmup(request=warmup_req,
prompt_len=prompt_len,
batch_size=batch_size,
valid_token_ids_tensor=valid_token_ids_tensor)

# warmup_mode completes the graph compilation, but we need to do
# one additional prefill to deploy the compiled program to the device,
# the necessary operations are included in the graph and will be removed
# after this execution
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[add_dummy_request],
scheduled_new_reqs=[deploy_req],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={add_dummy_request.req_id: prompt_len},
num_scheduled_tokens={deploy_req.req_id: prompt_len},
total_num_scheduled_tokens=prompt_len,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
@@ -396,7 +396,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
)
logger.info("[WARMUP] Deploying to device...")
self.execute_model(scheduler_output)
self._cleanup_model_runner(request=[add_dummy_request])
self._cleanup_model_runner(request=[deploy_req])

model_runner.complete_warmup()

@@ -550,61 +550,48 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,

def _dynamic_warmup(
self,
dummy_requests: list[NewRequestData],
request: NewRequestData,
prompt_len: int,
batch_size: int,
valid_token_ids_tensor: torch.Tensor,
) -> None:

assert (
_inside_warmup_mode
), "it looks like you are outside the warmup context for warmup"

for i, req in enumerate(dummy_requests):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[req],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req.req_id: prompt_len},
total_num_scheduled_tokens=prompt_len,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[request],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={request.req_id: prompt_len},
total_num_scheduled_tokens=prompt_len,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
logger.info("[WARMUP] Prefill...")

self.execute_model(scheduler_output)
self.execute_model(scheduler_output)

# one decode iteration across all sequences
req_ids = []
new_token_ids = []
new_block_ids = []
num_computed_tokens = []
for req in dummy_requests:
req_ids.append(req.req_id)
new_token_ids.append([
valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (1, )).item()]
]) # placeholder token
new_block_ids.append([req.block_ids])
num_computed_tokens.append(prompt_len)
cached_request_data = CachedRequestData(
req_ids=req_ids,
req_ids=[request.req_id],
resumed_from_preemption=False,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
new_token_ids=[[
valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (1, )).item()]
]],
new_block_ids=[request.block_ids],
num_computed_tokens=[prompt_len],
)

scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_request_data,
num_scheduled_tokens={f"warmup-{i}": 1
for i in range(batch_size)},
total_num_scheduled_tokens=batch_size,
num_scheduled_tokens={request.req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
@@ -615,7 +602,7 @@ def _dynamic_warmup(
)
logger.info("[WARMUP] Decode...")
self.execute_model(scheduler_output)
self._cleanup_model_runner(request=dummy_requests)
self._cleanup_model_runner(request=[request])

def _warmup_model_forward_pass(
self,