diff --git a/vllm_spyre/v1/core/scheduler.py b/vllm_spyre/v1/core/scheduler.py index 13bc8f481..638227c7b 100644 --- a/vllm_spyre/v1/core/scheduler.py +++ b/vllm_spyre/v1/core/scheduler.py @@ -153,8 +153,9 @@ def __init__(self, *args, **kwargs) -> None: self.block_size = SpyrePlatform.get_block_size() self.max_batch_tkv_limit = os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", default='-1') - assert self.max_batch_tkv_limit != '-1', "Expecting the env var" - "VLLM_DT_MAX_BATCH_TKV_LIMIT to be set in platform.py" + assert self.max_batch_tkv_limit != '-1', ( + "Expecting the env var VLLM_DT_MAX_BATCH_TKV_LIMIT to be set in " + "platform.py") def update_from_output( self, @@ -162,10 +163,9 @@ def update_from_output( model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: # Need an instance of CBSpyreModelRunnerOutput which holds the tkv value - assert isinstance( - model_runner_output, CBSpyreModelRunnerOutput - ), "Expecting an instance of CBSpyreModelRunnerOutput" - "when doing continuous batching." + assert isinstance(model_runner_output, CBSpyreModelRunnerOutput), ( + "Expecting an instance of CBSpyreModelRunnerOutput when doing " + "continuous batching.") self.tkv = model_runner_output.tkv self.n_free_blocks = model_runner_output.n_free_blocks return super(SpyreScheduler,