
Commit 4817c71

Authored by yannicks1, nikolaospapandreou, tdoublep, joerunde, and sducouedic
Paged attention/ new fms API (#82)
* [Continuous batching] FMS model wrapper (#18)
  * fms wrapper dummy for continuous batching implementation, gating via env var VLLM_SPYRE_USE_CB
  * implementing fms wrapper with correct KV cache management
  * disable prints by default
  * code refactoring fms wrapper
  * fix default path not using CB / fms wrapper
  * correct print when TESTING_CB
  * remove self.past_key_value_states when KV cache is managed by FMS wrapper
  * read out only active pages of KV cache (covers the case where current batch size < max batch size)
  * uniquely distinguishing prefills and decodes
  * reading kv cache dimension from model config
  * cosmetics and comments
  * support for gpt big code models
  * bugfix hard-coded test mask
  * change KV cache type for prefill
  * update tkv in fms wrapper
  * moving fms wrapper to own class
  * reset tkv for new prompt
  * ignoring test_spyre_tensor_parallel.py, since FMS wrapper does not support it
  * removing VLLM_SPYRE_USE_CB, since FMS wrapper is now used by default
  * typing fms wrapper class
* moving model loading into FMS wrapper (#35)
* bugfix idx kv cache update (#40)
* FMS Wrapper for static batching (#39)
  * introducing pseudo fms wrapper for static batching
  * small bug fix
  * bugfix idx kv cache update
* [Continuous Batching] Introducing new env variables (#67)
  * introducing env variables for AIU Spyre KV cache dimensions
  * removing prints
* [Continuous batching] Initial cb test (#52)
  * initial cb test
  * make tkv, active_pages optional in SpyreCausalLM class for the V0 tests
  * format
  * remove manual testing and fix formatting
  * remove tkv2fms
  * remove unnecessary class variables
  * tidy up class variables
  * simplify code: req_ids2idx and active_pages will be reset in prepare input anyway
  * renaming variable
  * removing batch padding in prefill stage
  * indices always a list of Trues since there are no padded or removed sequences
  * fix active/free page handling
  * avoiding unnecessary tensor construction
  * fix sorting inconsistency between token/position_ids and masks
  * refactoring to not require req_ids2idx
  * removing unused class variables, simplifying code
  * use VLLM_SPYRE_MAX_BATCH_SIZE to control (decoding) batch size on AIU Spyre
  * removing unnecessary helper functions for schedule and add_request
  * removing unused argument
* re-enabling TP tests
* addressing feedback: renaming and removing unused code
* removing unnecessary getter function and other feedback
* integrating new FMS API on branch 'paged_attn_mock'
* torch dynamo: mark dynamic/static shapes
* bugfix key_value_states name
* making block_table and slot_mapping args, not class vars
* formatting after browser merge
* nicer handling of arguments for continuous vs static batching
* Implement warmup for continuous batching (#83)
  * fmt
  * freeing block directly and small things
* initialize tkv
* Return empty ModelRunnerOutput if no work
* update mask for decode
* Fix copy/paste error
* adaptive logging (thx joerunde)
* remove warmup shapes for continuous batching
* assuring prefill lengths are multiples of block size 64 in example script
* revert change to warmup shape
* 🎨 fmt
* Added call to update_lazyhandle
* Right padding of prompts (#95)
  * right padding initial implementation
  * fix right padding: remove the right-padded logits before sampling
  * fix typing
* [CB] Fix Tensor Parallelism Error (#103)
  * divide tensor third dimension by number of TP
  * use existing method from vllm to get 'num_kv_heads' (works also for TP > 1)
* support granite-3.2-8b-instruct (#106)
* comments
* adapt to change of arguments in fms
* fix mypy issue
* revising continuous batching scheduler
* [V1] Decoupling static and continuous batching (#116)
  * decoupling static and continuous batching scheduler
  * fix dynamo cache for continuous batching
  * removing warmup shape dependency for continuous batching
* addressing review cosmetics
* fix/refactor: remove last_running and total_running (#112)
* fix comment kv cache tensor initialization

Signed-off-by: Yannick Schnider <[email protected]>
Signed-off-by: Nikolaos Papandreou <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Joe Runde <[email protected]>
Signed-off-by: Sophie du Couédic <[email protected]>
Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: Nikolaos Papandreou <[email protected]>
Co-authored-by: Thomas Parnell <[email protected]>
Co-authored-by: Joe Runde <[email protected]>
Co-authored-by: Sophie du Couédic <[email protected]>
Co-authored-by: Travis Johnson <[email protected]>
1 parent b440505 commit 4817c71
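Taken together, the changes below replace the old tkv/active_pages interface with the paged-attention API (current_tkv_mask, left_padded_prompt_mask, block_table, slot_mapping) and gate it behind VLLM_SPYRE_USE_CB. A minimal usage sketch, assuming a Spyre-enabled vLLM build and the FMS paged_attn_mock branch; the model id and sizes are illustrative, not taken from this commit:

import os

# Env vars used by this commit (values illustrative):
os.environ["VLLM_SPYRE_USE_CB"] = "1"               # enable continuous batching
os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] = "eager"   # as in the example script
os.environ["VLLM_SPYRE_MAX_BATCH_SIZE"] = "2"       # max decode batch size
os.environ["VLLM_SPYRE_MAX_CONTEXT_LENGTH"] = "2048"

from vllm import LLM, SamplingParams

llm = LLM(model="ibm-granite/granite-3.2-8b-instruct")  # assumed model id
out = llm.generate(["Hello, Spyre!"], SamplingParams(max_tokens=65))
print(out[0].outputs[0].text)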

File tree

6 files changed: +572 −269 lines changed

examples/offline_inference_spyre_cb_test.py

Lines changed: 5 additions & 6 deletions

@@ -3,15 +3,14 @@
 
 from vllm import LLM, SamplingParams
 
-max_tokens1 = 10
-max_tokens2 = 5
+# RUN with fms branch: https://github.com/foundation-model-stack/
+# foundation-model-stack/tree/paged_attn_mock
+
+max_tokens1 = 65
+max_tokens2 = 67
 max_tokens3 = 7
-max_tokens = max([max_tokens1, max_tokens2, max_tokens3])
 max_num_seqs = 2  # defines max batch size
 
-os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64'
-os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens)
-
 # defining here to be able to run/debug directly from VSC (not via terminal)
 os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager'
 os.environ['VLLM_SPYRE_USE_CB'] = '1'
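The new header comment pins the example to the FMS paged_attn_mock branch, and the commit message notes that prefill lengths must be multiples of the 64-token block size. A small sketch of that alignment rule; pad_to_block is a hypothetical helper, not part of this diff:

BLOCK_SIZE = 64  # KV-cache block size used throughout this commit

def pad_to_block(prompt_len: int) -> int:
    """Round a prompt length up to the next multiple of BLOCK_SIZE."""
    return ((prompt_len + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE

assert pad_to_block(64) == 64    # already block-aligned
assert pad_to_block(65) == 128   # padded up to the next block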

vllm_spyre/model_executor/model_loader/spyre.py

Lines changed: 98 additions & 85 deletions

@@ -1,6 +1,6 @@
 """Utilities for selecting and loading Spyre models."""
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch._inductor.config
@@ -54,49 +54,68 @@ def __init__(
         # False for finished or padded sequences
         self.indices = None
 
+        # number of right pads (relevant for continuous batching only)
+        self.n_pads_right = 0
+
         # FMS Model
-        fms_model = ContinuousBatchingFmsModel if envs_spyre.VLLM_SPYRE_USE_CB\
-            else StaticBatchingFmsModel
-        self.model = fms_model(
-            model_config,
-            parallel_config,
-            max_prompt_length,
-            max_decode_length,
-        )
+        if envs_spyre.VLLM_SPYRE_USE_CB:
+            self.model = ContinuousBatchingFmsModel(model_config,
+                                                    parallel_config)
+        else:
+            self.model = StaticBatchingFmsModel(
+                model_config,
+                parallel_config,
+                max_prompt_length,
+                max_decode_length,
+            )
 
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         masks: torch.Tensor,
         is_prompt: bool,
-        tkv: Optional[int] = None,
-        active_pages: Optional[list[int]] = None,
+        current_tkv_mask: Optional[torch.Tensor] = None,
+        left_padded_prompt_mask: Optional[torch.Tensor] = None,
+        block_table: Optional[torch.Tensor] = None,
+        slot_mapping: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
         if is_prompt and not envs_spyre.VLLM_SPYRE_USE_CB:
-            self.model.past_key_value_states = None
+            self.model.past_key_value_states = None  # type: ignore
 
-        extra_kwargs = {}
+        extra_kwargs: dict[str, Any] = {}
         if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder":
             # Bug in 2.3.1 fixed in 2.4.1 for SDPA flash
             # cpu impl when padding too much
             extra_kwargs["attn_algorithm"] = "math"
 
-        # normal prefil or decoding step
+        if envs_spyre.VLLM_SPYRE_USE_CB:
+            extra_kwargs["current_tkv_mask"] = current_tkv_mask
+            extra_kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
+            extra_kwargs["block_table"] = block_table
+            extra_kwargs["slot_mapping"] = slot_mapping
+
+        # normal prefill or decoding step
         logits = self.model(
             input_ids,
             position_ids=positions,
             mask=masks,
             use_cache=True,
-            only_last_token=True,
-            tkv=tkv,
-            active_pages=active_pages,
+            only_last_token=not envs_spyre.VLLM_SPYRE_USE_CB,
            **extra_kwargs,
         )
 
-        # removing finished or padded sequences
-        logits = logits[self.indices]
+        if envs_spyre.VLLM_SPYRE_USE_CB:
+            if is_prompt and self.n_pads_right > 0:
+                # get last token before the right padding
+                logits = logits[self.indices, -self.n_pads_right - 1, :]
+            else:
+                # just take last token if no right padding
+                logits = logits[self.indices, -1, :]
+        else:
+            # removing finished or padded sequences
+            logits = logits[self.indices]
 
         return logits
 
@@ -151,11 +170,6 @@ def load_weights(
         **kwargs,
     ) -> None:
 
-        if self.dtype is not model_config.dtype:
-            logger.info(
-                "Ignoring user-provided dtype=%s and using dtype=%s instead.",
-                model_config.dtype, self.dtype)
-
         if model_config.quantization == "gptq":
             if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
                 from fms_mo.aiu_addons.gptq import (  # noqa: F401
@@ -173,13 +187,17 @@ def load_weights(
                 "group_size": quant_cfg['group_size'],
                 "desc_act": quant_cfg['desc_act'],
             }
-            data_type = None
+            self.dtype = None
             model_source = "hf_gptq_aiu"
         else:
             linear_config = {"linear_type": "torch_linear"}
-            data_type = self.dtype
             model_source = "hf"
 
+        if self.dtype is not model_config.dtype:
+            logger.info(
+                "Ignoring user-provided dtype=%s and using dtype=%s instead.",
+                model_config.dtype, self.dtype)
+
         is_local = os.path.isdir(model_config.model)
         model_path = model_config.model
         # Get location of model from HF cache.
@@ -197,7 +215,7 @@ def load_weights(
             variant=model_config.model,
             model_path=model_path,
             source=model_source,
-            data_type=data_type,
+            data_type=self.dtype,
             distributed_strategy=distributed_strategy,
             group=dist.group.WORLD,
             fused_weights=fused_weights,
@@ -245,39 +263,51 @@ def load_weights(
 
 class ContinuousBatchingFmsModel(FmsModelBase):
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        max_prompt_length: int,
-        max_decode_length: int,
-    ) -> None:
-        super().__init__(model_config, parallel_config, max_prompt_length,
-                         max_decode_length)
+    def __init__(self, model_config: ModelConfig,
+                 parallel_config: ParallelConfig) -> None:
 
-        # physical KV cache on AIU Spyre
+        BLOCK_SIZE = 64
         max_batch = envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE
         max_model_len = envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH
 
-        if self.config.model_type == 'llama':
+        # edge case: prompt fills model length: can produce 1 token with prefill
+        max_prompt_length = max_model_len
+        # edge case: prompt will be padded to first block:
+        # can produce 1 token with prefill plus rest of model length
+        max_decode_length = max_model_len - BLOCK_SIZE + 1
+        super().__init__(model_config, parallel_config, max_prompt_length,
+                         max_decode_length)
+
+        # physical KV cache on AIU Spyre: will eventually not live in this class
+        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+
+        if self.config.model_type in {'llama', 'granite'}:
             num_layers = self.config.num_hidden_layers
-            num_kv_heads = self.config.num_key_value_heads
             head_dim = self.config.hidden_size // \
                 self.config.num_attention_heads
         elif self.config.model_type == 'gpt_bigcode':
             num_layers = self.config.n_layer
-            num_kv_heads = 1 if self.config.multi_query else self.config.n_head
             head_dim = self.config.n_embd // self.config.n_head
         else:
-            print(f"[SpyreCausalLM] model type {self.config.model_type} "
-                  f"not supported in ContinuousBatchingFmsModel")
-
-        # (layers)x(k,v)x[max_batch, num_kv_heads, max_model_len, head_dim]
-        self.fms_kv_cache: list[tuple[torch.Tensor, torch.Tensor]] = [
-            (torch.empty((max_batch, num_kv_heads, max_model_len, head_dim)),
-             torch.empty((max_batch, num_kv_heads, max_model_len, head_dim)))
-            for i in range(num_layers)
-        ]
+            raise NotImplementedError(
+                f"[SpyreCausalLM] model type {self.config.model_type} "
+                f"not supported in ContinuousBatchingFmsModel")
+
+        num_blocks = max_batch * max_model_len // BLOCK_SIZE  # 64
+
+        # List[layers] of Tuple[k,v] of
+        # Tensor[num_blocks, BLOCK_SIZE, num_kv_heads, head_dim]
+        self.past_key_value_states = [(torch.zeros(num_blocks,
+                                                   BLOCK_SIZE,
+                                                   num_kv_heads,
+                                                   head_dim,
+                                                   dtype=self.dtype),
+                                       torch.zeros(num_blocks,
+                                                   BLOCK_SIZE,
+                                                   num_kv_heads,
+                                                   head_dim,
+                                                   dtype=self.dtype))
+                                      for _ in range(num_layers)]
 
     def forward(
         self,
@@ -286,50 +316,36 @@ def forward(
         mask: torch.Tensor,
         use_cache: bool,
         only_last_token: bool,
-        tkv: int,
-        active_pages: list[int],
+        current_tkv_mask: torch.Tensor,
+        left_padded_prompt_mask: torch.Tensor,
+        block_table: torch.Tensor,
+        slot_mapping: torch.Tensor,
         **extra_kwargs,
     ) -> torch.Tensor:
 
-        # read-out (dynamic) kv_cache for decoding steps only,
-        # for prefills kv_cache = None
-        if tkv == 0:  # prefil
-            kv_cache = None
-            tkv = input_ids.shape[1]
-        else:  # decode
-            kv_cache = []
-            active_pages_mask = torch.zeros(self.fms_kv_cache[0][0].shape[0],
-                                            dtype=torch.bool)
-            active_pages_mask[active_pages] = True
-            for layer in range(len(self.fms_kv_cache)):
-                kv_cache.append(
-                    (self.fms_kv_cache[layer][0][active_pages_mask, :, :tkv -
-                                                 1, :],
-                     self.fms_kv_cache[layer][1][active_pages_mask, :, :tkv -
-                                                 1, :]))
+        # mark dynamic: Not sure if that's correct/needed here,
+        # copied from fms branch paged_atten_mock
+        if self.past_key_value_states is not None:
+            for layer in self.past_key_value_states:
+                if isinstance(layer, tuple):
+                    for tensor in layer:
+                        torch._dynamo.mark_dynamic(tensor, 2)
 
         output = self.model(
             input_ids,
             position_ids=position_ids,
             mask=mask,
-            past_key_value_states=kv_cache,
+            past_key_value_states=self.past_key_value_states,
             use_cache=use_cache,
             only_last_token=only_last_token,
+            current_tkv_mask=current_tkv_mask,
+            left_padded_prompt_mask=left_padded_prompt_mask,
+            block_table=block_table,
+            slot_mapping=slot_mapping,
             **extra_kwargs,
         )
-        logits, key_value_states = output
-
-        # updating (physical) KV cache: self.fms_kv_cache
-        for idx, page in enumerate(sorted(active_pages)):
-            for layer in range(len(self.fms_kv_cache)):
-                # inserting partial KV cache at correct location
-                # (page, tkv) in the KV cache of the whole batch
-                self.fms_kv_cache[layer][0][
-                    page, :, :tkv, :] = key_value_states[layer][0][
-                        idx, :, :, :]  # [1, 8, L, 128]
-                self.fms_kv_cache[layer][1][
-                    page, :, :tkv, :] = key_value_states[layer][1][
-                        idx, :, :, :]  # [1, 8, L, 128]
+
+        logits, self.past_key_value_states = output
 
         return logits
 
@@ -356,8 +372,6 @@ def forward(
         mask: torch.Tensor,
         use_cache: bool,
         only_last_token: bool,
-        tkv: int,
-        active_pages: list[int],
         **extra_kwargs,
     ) -> torch.Tensor:
 
@@ -371,7 +385,6 @@ def forward(
             **extra_kwargs,
         )
 
-        logits, past_key_value_states = output
-        self.past_key_value_states = past_key_value_states
+        logits, self.past_key_value_states = output
 
         return logits
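The rewritten ContinuousBatchingFmsModel allocates one (K, V) tensor pair per layer with shape [num_blocks, BLOCK_SIZE, num_kv_heads, head_dim] and hands FMS a block_table plus slot_mapping instead of slicing per-page caches itself. A self-contained sketch of that layout with toy sizes; the slot formula is the usual paged-attention convention and is an assumption here, not code from this commit:

import torch

BLOCK_SIZE = 64
max_batch, max_model_len = 2, 256            # toy sizes, not from the commit
num_kv_heads, head_dim, num_layers = 8, 128, 2

# Same formula as __init__ above: one block pool shared by all sequences.
num_blocks = max_batch * max_model_len // BLOCK_SIZE   # 8

# One (K, V) pair per layer, paged by block rather than by sequence.
past_key_value_states = [
    (torch.zeros(num_blocks, BLOCK_SIZE, num_kv_heads, head_dim),
     torch.zeros(num_blocks, BLOCK_SIZE, num_kv_heads, head_dim))
    for _ in range(num_layers)
]

# Toy block table: sequence 0 owns blocks [0, 1], sequence 1 owns [2, 3].
block_table = torch.tensor([[0, 1], [2, 3]])

# Assumed slot convention (standard paged attention): the token at position
# p of sequence s lives in block_table[s, p // BLOCK_SIZE] at in-block
# offset p % BLOCK_SIZE.
s, p = 1, 70
slot = int(block_table[s, p // BLOCK_SIZE]) * BLOCK_SIZE + p % BLOCK_SIZE
print(slot)  # 198 -> block 3, offset 6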

vllm_spyre/platform.py

Lines changed: 23 additions & 20 deletions

@@ -50,6 +50,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if scheduler_config.is_multi_step:
             raise NotImplementedError
 
+        # continuous batching related checks
+        if envs_spyre.VLLM_SPYRE_USE_CB and not envs.VLLM_USE_V1:
+            raise NotImplementedError(
+                "Continuous batching is only implemented for vLLM V1")
+
         # Near future TODO: vLLM will have an api to check whether v0 or v1 is
         # used that isn't just checking the environment variable
 
@@ -69,21 +74,22 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "vllm_spyre.v1.core.scheduler.ContinuousBatchingSpyreScheduler"
             else:
                 scheduler_config.scheduler_cls = \
-                    "vllm_spyre.v1.core.scheduler.SpyreScheduler"
+                    "vllm_spyre.v1.core.scheduler.StaticBatchingSpyreScheduler"
         else:
             scheduler_config.scheduler_cls = \
                 "vllm_spyre.core.scheduler.SpyreScheduler"
 
-        # Override --max-num-seqs to the biggest warmup batch size
-        # And override --max-model-len to the biggest warmup sequence
-        cls._warmup_shapes = None
-        spyre_warmup_shapes = cls.get_warmup_shapes(scheduler_config)
-        max_batch_size = 0
-        max_seq_len = 0
-        for shape in spyre_warmup_shapes:
-            max_batch_size = max(max_batch_size, shape['batch_size'])
-            max_seq_len = max(max_seq_len,
-                              shape['prompt_length'] + shape['new_tokens'])
+        if not envs_spyre.VLLM_SPYRE_USE_CB:
+            # Override --max-num-seqs to the biggest warmup batch size
+            # And override --max-model-len to the biggest warmup sequence
+            cls._warmup_shapes = None
+            spyre_warmup_shapes = cls.get_warmup_shapes(scheduler_config)
+            max_batch_size = 0
+            max_seq_len = 0
+            for shape in spyre_warmup_shapes:
+                max_batch_size = max(max_batch_size, shape['batch_size'])
+                max_seq_len = max(max_seq_len,
+                                  shape['prompt_length'] + shape['new_tokens'])
 
         if envs.VLLM_USE_V1:
             if envs_spyre.VLLM_SPYRE_USE_CB:
@@ -98,11 +104,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 # The v0 scheduler will run out of blocks if this is overridden
                 scheduler_config.max_num_seqs = max_batch_size
 
-        # continuous batching related checks
-        if envs_spyre.VLLM_SPYRE_USE_CB and not envs.VLLM_USE_V1:
-            raise NotImplementedError(
-                "Continuous batching is only implemented for vLLM V1")
-
         cache_config = vllm_config.cache_config
 
         if cache_config and model_config:
@@ -115,7 +116,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             # one single block.
             # - Set the number of blocks to the maximum number of sequences, so
             #   the scheduler always thinks there's a block available
-            model_config.max_model_len = max_seq_len
+            if not envs_spyre.VLLM_SPYRE_USE_CB:
+                model_config.max_model_len = max_seq_len
             cache_config.block_size = model_config.max_model_len
 
         if envs.VLLM_USE_V1:
@@ -166,9 +168,10 @@ def get_warmup_shapes(cls, scheduler_config) -> tuple[dict[str, int], ...]:
                 "The lists in VLLM_SPYRE_WARMUP_PROMPT_LENS and "
                 "VLLM_SPYRE_WARMUP_NEW_TOKENS must have equal length")
 
-        logger.info("VLLM_SPYRE_WARMUP_PROMPT_LENS = %s", wup_prompt_lens)
-        logger.info("VLLM_SPYRE_WARMUP_NEW_TOKENS = %s", wup_new_tokens)
-        logger.info("VLLM_SPYRE_WARMUP_BATCH_SIZES = %s", wup_batch_sizes)
+        if not envs_spyre.VLLM_SPYRE_USE_CB:
+            logger.info("VLLM_SPYRE_WARMUP_PROMPT_LENS = %s", wup_prompt_lens)
+            logger.info("VLLM_SPYRE_WARMUP_NEW_TOKENS = %s", wup_new_tokens)
+            logger.info("VLLM_SPYRE_WARMUP_BATCH_SIZES = %s", wup_batch_sizes)
 
         cls._warmup_shapes = tuple(
             sorted([{