1 change: 1 addition & 0 deletions examples/offline_inference/spyre_inference.py
@@ -81,6 +81,7 @@
sampling_params = SamplingParams(max_tokens=args.max_tokens,
temperature=0.0,
ignore_eos=True)
model = "/models/llama-7b-chat"
# Create an LLM.
llm = LLM(model=args.model,
tokenizer=args.model,
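For orientation, a minimal sketch of the offline-inference flow this example script drives. The sampling parameters and the /models/llama-7b-chat path come from the diff above; the prompt, the hard-coded max_tokens value, and the generate/print calls are illustrative assumptions.

from vllm import LLM, SamplingParams

# Greedy decoding with a fixed output length, mirroring the example above.
sampling_params = SamplingParams(max_tokens=20, temperature=0.0, ignore_eos=True)

# The example passes args.model for both model and tokenizer; the local path
# taken from the line added above is used here purely for illustration.
llm = LLM(model="/models/llama-7b-chat", tokenizer="/models/llama-7b-chat")

outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)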
5 changes: 5 additions & 0 deletions vllm_spyre/envs.py
@@ -13,6 +13,7 @@
VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED: int = 0
VLLM_SPYRE_PERF_METRIC_LOGGING_DIR: str = "/tmp"
VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER: bool = False
VLLM_SPYRE_VLLM_MODEL: bool = False
# Prompt logprobs are behind a flag because they're only supported for
# static batching and require passing back the hidden states for the full
# prefill on every request. This could incur a heavy performance penalty in
@@ -99,6 +100,10 @@ def _backend_backwards_compat() -> str:
# By default, prompt_logprobs aren't supported
"VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS":
lambda: bool(int(os.getenv("VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS", "0"))),

# If set, uses the VLLM model instead of fms
"VLLM_SPYRE_VLLM_MODEL":
lambda: bool(int(os.getenv("VLLM_SPYRE_VLLM_MODEL", "0"))),
}
# --8<-- [end:env-vars-definition]
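Because the value is parsed with bool(int(...)), the flag is enabled by exporting VLLM_SPYRE_VLLM_MODEL=1 before launch. A minimal sketch of reading it at runtime, assuming only this diff; the printed messages are illustrative:

import vllm_spyre.envs as envs_spyre

# Equivalent to bool(int(os.getenv("VLLM_SPYRE_VLLM_MODEL", "0"))).
if envs_spyre.VLLM_SPYRE_VLLM_MODEL:
    print("Loading the model through vLLM's model loader")
else:
    print("Loading the model through FMS (default)")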

71 changes: 61 additions & 10 deletions vllm_spyre/model_executor/model_loader/spyre.py
@@ -9,13 +9,14 @@
import torch.nn as nn
from fms.models import get_model
from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf)
from vllm.model_executor.model_loader import get_model as vllm_get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata

import vllm_spyre.envs as envs_spyre
@@ -44,6 +45,7 @@ class SpyreCausalLM(nn.Module):

def __init__(
self,
vllm_config: VllmConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
@@ -69,6 +71,12 @@ def __init__(
self.model = ContinuousBatchingFmsModel(model_config,
parallel_config,
scheduler_config)
elif envs_spyre.VLLM_SPYRE_VLLM_MODEL:
self.model = StaticBatchingVllmModel(
vllm_config,
max_prompt_length,
max_decode_length,
)
else:
self.model = StaticBatchingFmsModel(
model_config,
@@ -122,6 +130,9 @@ def compute_logits(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if envs_spyre.VLLM_SPYRE_VLLM_MODEL:
logits = self.model.compute_logits(hidden_states, sampling_metadata)
return logits
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits

@@ -138,6 +149,7 @@ class FmsModelBase(nn.Module):

def __init__(
self,
vllm_config: VllmConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
max_prompt_length: int,
@@ -146,6 +158,7 @@ def __init__(
) -> None:
super().__init__()

self.vllm_config = vllm_config
self.config: PretrainedConfig = model_config.hf_config
self.dtype = torch.float16 if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == \
'sendnn' else torch.float32
@@ -212,15 +225,18 @@ def load_weights(
# we can use fused weights unless running on Spyre
fused_weights = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn"

self.model = get_model(architecture="hf_configured",
variant=model_config.model,
model_path=model_path,
source=model_source,
data_type=self.dtype,
distributed_strategy=distributed_strategy,
group=dist.group.WORLD,
fused_weights=fused_weights,
linear_config=linear_config)
if envs_spyre.VLLM_SPYRE_VLLM_MODEL:
self.model = vllm_get_model(vllm_config=self.vllm_config)
else:
self.model = get_model(architecture="hf_configured",
variant=model_config.model,
model_path=model_path,
source=model_source,
data_type=self.dtype,
distributed_strategy=distributed_strategy,
group=dist.group.WORLD,
fused_weights=fused_weights,
linear_config=linear_config)

self.model.eval()
torch.set_grad_enabled(False)
@@ -439,3 +455,38 @@ def forward(
logits, self.past_key_value_states = output

return logits

class StaticBatchingVllmModel(FmsModelBase):

def __init__(
self,
vllm_config: VllmConfig,
max_prompt_length: int,
max_decode_length: int,
) -> None:
super().__init__(vllm_config,
vllm_config.model_config,
vllm_config.parallel_config,
max_prompt_length,
max_decode_length,
sendnn_dynamic=False)

def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
**extra_kwargs,
) -> torch.Tensor:

output = self.model(
input_ids,
positions=position_ids,
)

return output

def compute_logits(self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
return self.model.compute_logits(hidden_states,
sampling_metadata)
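Together with the SpyreCausalLM and load_weights branches above, the new class splits a decoding step into forward(), which returns hidden states from the wrapped vLLM model, and compute_logits(), which projects them to vocabulary logits. A hedged usage sketch; the helper name run_static_batch_step and the assumption that the caller already holds a StaticBatchingVllmModel plus prepared tensors are illustrative:

import torch
from vllm.model_executor.sampling_metadata import SamplingMetadata

def run_static_batch_step(model, input_ids: torch.Tensor,
                          position_ids: torch.Tensor,
                          sampling_metadata: SamplingMetadata) -> torch.Tensor:
    # forward() wraps self.model(input_ids, positions=position_ids) and
    # returns hidden states when VLLM_SPYRE_VLLM_MODEL is enabled.
    hidden_states = model(input_ids, position_ids)
    # compute_logits() delegates to the wrapped vLLM model's logits head.
    return model.compute_logits(hidden_states, sampling_metadata)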
14 changes: 12 additions & 2 deletions vllm_spyre/platform.py
@@ -27,7 +27,7 @@
ModelConfig = None
VllmConfig = None
import vllm.envs as envs
from vllm.platforms import Platform, PlatformEnum
from vllm.platforms import Platform, PlatformEnum, _Backend

import vllm_spyre.envs as envs_spyre

@@ -56,13 +56,23 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
"""
return False

@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                         dtype: torch.dtype, kv_cache_dtype: Optional[str],
                         block_size: int, use_v1: bool,
                         use_mla: bool) -> str:
    logger.info("Using Torch SDPA backend.")
    return "vllm_spyre.v1.attention.backends.spyre.SpyreSDPABackend"

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cls._config = vllm_config
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
compilation_config = vllm_config.compilation_config

from vllm.config import CompilationLevel # noqa: E402
compilation_config.pass_config.enable_fusion = False
compilation_config.pass_config.enable_attn_fusion = False
compilation_config.level = CompilationLevel.NO_COMPILATION

if scheduler_config.is_multi_step:
raise NotImplementedError
@@ -329,4 +339,4 @@ def get_max_output_tokens(self, prompt_len: int) -> int:
if prompt_len <= shape['prompt_length']:
max_new_tokens = max(max_new_tokens, shape['new_tokens'])

return max_new_tokens
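The dotted path returned by get_attn_backend_cls above is resolved to a backend class by vLLM when it selects an attention backend. A minimal sketch of that resolution using standard importlib; the helper name resolve_backend is illustrative, and vLLM's own resolution utility may differ:

import importlib

def resolve_backend(qualname: str):
    """Resolve a dotted path such as 'pkg.module.ClassName' to the class."""
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# backend_cls = resolve_backend(
#     "vllm_spyre.v1.attention.backends.spyre.SpyreSDPABackend")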