57 changes: 28 additions & 29 deletions vllm_spyre/platform.py
@@ -65,20 +65,33 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if scheduler_config.is_multi_step:
             raise NotImplementedError
 
-        # Near future TODO: vLLM will have an api to check whether v0 or v1 is
-        # used that isn't just checking the environment variable
+        is_decoder = model_config.task == "generate"
+        is_embedding = model_config.task == "embed"
+
+        # v0 is only supported for embedding models, and embedding models must
+        # be run on v0
+        if is_embedding and envs.VLLM_USE_V1:
+            raise ValueError("Embedding models are only supported on v0")
+        elif is_decoder and not envs.VLLM_USE_V1:
+            raise ValueError("Decoder models are only supported on v1")
+        elif not is_decoder and not is_embedding:
+            raise ValueError("Only the 'generate' and 'embed' tasks are "
+                             "supported")
 
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = (
                 f'vllm_spyre{".v1" if envs.VLLM_USE_V1 else ""}'\
                     '.worker.spyre_worker.SpyreWorker')
 
-        if not envs_spyre.VLLM_SPYRE_USE_CB:  # no CB
+        if envs_spyre.VLLM_SPYRE_USE_CB and is_decoder:
+            scheduler_config.scheduler_cls = "vllm_spyre.v1.core."\
+                "scheduler.ContinuousBatchingSpyreScheduler"
+        else:
+            # Static batching or embedding model.
             # Override --max-num-seqs to the biggest warmup batch size
             # And override --max-model-len to the biggest warmup sequence
             cls._warmup_shapes = None
-            max_model_len = model_config.max_model_len \
-                if model_config is not None else sys.maxsize
+            max_model_len = model_config.max_model_len
             spyre_warmup_shapes = cls.get_warmup_shapes(
                 scheduler_config, max_model_len)
             max_batch_size = 0
@@ -87,30 +100,17 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 max_batch_size = max(max_batch_size, shape["batch_size"])
                 max_seq_len = max(max_seq_len,
                                   shape["prompt_length"] + shape["new_tokens"])
-            if model_config is not None:
-                model_config.max_model_len = max_seq_len
 
-            if envs.VLLM_USE_V1:  # No CB with V1
+            model_config.max_model_len = max_seq_len
+            scheduler_config.max_num_seqs = max_batch_size
+
+            if is_decoder:
                 scheduler_config.scheduler_cls = (
                     "vllm_spyre.v1.core.scheduler."\
                         "StaticBatchingSpyreScheduler")
-                scheduler_config.max_num_seqs = max_batch_size
-            else:  # No CB with V0
+            elif is_embedding:
                 scheduler_config.scheduler_cls = (
                     "vllm_spyre.core.scheduler.SpyreScheduler")
-        else:  # CB related checks
-            if not envs.VLLM_USE_V1:  # CB with V0
-                raise NotImplementedError(
-                    "Continuous batching is only implemented for vLLM V1")
-            else:  # CB with V1
-                # As of 0.7.3 the scheduler for V1 isn't actually pluggable like
-                # this yet
-                scheduler_config.scheduler_cls = "vllm_spyre.v1.core."\
-                    "scheduler.ContinuousBatchingSpyreScheduler"
 
         # Cache and model config aren't set in the individual worker procs
         # These are set in the main engine process
 
         # To disable any paged attention ops in the base scheduler, we:
         # - Set the block size (in tokens) to the maximum sequence length
@@ -137,12 +137,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             logger.info(
                 "Overriding configurations based on warmup shapes. "
                 "max_model_len=%d, max_num_seqs=%d, block_size=%d, "
-                "num_gpu_blocks_override=%d",
-                model_config.max_model_len,
-                scheduler_config.max_num_seqs,
-                cache_config.block_size,
-                cache_config.num_gpu_blocks_override,
-            )
+                "num_gpu_blocks_override=%d, max_num_batched_tokens=%d",
+                model_config.max_model_len, scheduler_config.max_num_seqs,
+                cache_config.block_size, cache_config.num_gpu_blocks_override,
+                scheduler_config.max_num_batched_tokens)
 
         # set env vars for torch_sendnn to consume
         os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
@@ -227,7 +225,8 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
"""Returns whether the current platform can support v1 for the supplied
model configuration.
"""
return True
# We don't have an embedding runner for v1 yet
return model_config.task != "embed"

@classmethod
def validate_request(
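The supports_v1 hunk means the platform now reports V1 support per model instead of unconditionally. A hedged sketch of the hook in isolation, with ModelConfig trimmed to the single field the hook reads (the real vLLM ModelConfig carries many more):

    from dataclasses import dataclass

    @dataclass
    class ModelConfig:
        task: str

    def supports_v1(model_config: ModelConfig) -> bool:
        # We don't have an embedding runner for v1 yet
        return model_config.task != "embed"

    print(supports_v1(ModelConfig(task="generate")))  # True: decoders use V1
    print(supports_v1(ModelConfig(task="embed")))     # False: embeddings stay on V0

This keeps the decision in one place: the engine asks the platform whether V1 is viable for the given model, and the task checks in check_and_update_config enforce the same rule with explicit errors.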