|
29 | 29 | import vllm_spyre.perf_metrics as perf_metrics |
30 | 30 | from vllm_spyre.model_executor.model_loader import spyre_setup |
31 | 31 | from vllm_spyre.platform import SpyrePlatform |
32 | | -from vllm_spyre.v1.worker.spyre_input_batch import InputBatch |
| 32 | +from vllm_spyre.v1.worker.spyre_input_batch import SamplingInputBatch |
33 | 33 | from vllm_spyre.v1.worker.spyre_model_runner import ( |
34 | 34 | ContinuousBatchingSpyreModelRunner, SpyrePoolingModelRunner, |
35 | 35 | StaticBatchingSpyreModelRunner, SupportedTask) |
@@ -329,7 +329,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): |
329 | 329 |
|
330 | 330 | # Fix for batch size 1: set input batch to fit 2 requests for warmup |
331 | 331 | if model_runner.vllm_config.scheduler_config.max_num_seqs == 1: |
332 | | - model_runner.input_batch = InputBatch( |
| 332 | + model_runner.input_batch = SamplingInputBatch( |
333 | 333 | max_num_reqs=2, |
334 | 334 | max_model_len=model_runner.vllm_config.model_config. |
335 | 335 | max_model_len, |
@@ -388,7 +388,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): |
388 | 388 |
|
389 | 389 | # Fix for batch size 1: reset input batch to fit max_num_seqs requests |
390 | 390 | if model_runner.vllm_config.scheduler_config.max_num_seqs == 1: |
391 | | - model_runner.input_batch = InputBatch( |
| 391 | + model_runner.input_batch = SamplingInputBatch( |
392 | 392 | max_num_reqs=model_runner.vllm_config.scheduler_config. |
393 | 393 | max_num_seqs, |
394 | 394 | max_model_len=model_runner.vllm_config.model_config. |
|
0 commit comments