Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion tests/llm_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,15 @@ def get_engine(
# then adjust these limits in the engine's scheduler for tests.

# Setup the engine
# Round max_num_seqs (batch size) up to the nearest power of two for
# Spyre compilation (values that are already a power of two are kept).
# This seems more robust and helps all tests in
# tests/e2e/test_spyre_cb_inference_steps.py pass on Spyre.
max_num_seqs_compiled = 1 << (max_num_seqs - 1).bit_length()
engine_args = EngineArgs(
model=model_name,
tokenizer=model_name,
max_model_len=max(max_model_len, 256),
max_num_seqs=max_num_seqs,
max_num_seqs=max_num_seqs_compiled,
num_gpu_blocks_override=None,
revision=revision,
)
Expand All @@ -202,7 +206,10 @@ def get_engine(
executor_class=executor_class,
log_stats=False)

# Set scheduler configs for max_model_len and max_num_seqs to the
# original values. They were changed for more robust compilation only.
engine_core.scheduler.scheduler_config.max_model_len = max_model_len
engine_core.scheduler.scheduler_config.max_num_seqs = max_num_seqs

if available_blocks is not None:
worker = engine_core.model_executor.driver_worker.worker
Expand Down