diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py index 14711fa3..d2fe0bdb 100644 --- a/vllm_spyre/model_executor/model_loader/spyre.py +++ b/vllm_spyre/model_executor/model_loader/spyre.py @@ -12,7 +12,6 @@ from vllm.config import ModelConfig, VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.model_loader.weight_utils import ( download_weights_from_hf) from vllm.v1.outputs import SamplerOutput @@ -58,9 +57,6 @@ def __init__( rank: int, ) -> None: super().__init__() - self.logits_processor = LogitsProcessor( - vllm_config.model_config.hf_config.vocab_size, - logits_as_input=True) try: ## Temporary backwards compatibility for 0.10.2 @@ -132,14 +128,6 @@ def forward( return logits - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - logits = self.logits_processor(None, hidden_states, sampling_metadata) - return logits - def sample( self, logits: torch.Tensor, diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index 150ce912..40e57b43 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -525,13 +525,10 @@ def execute_model( # Execute the model attn_metadata = self.build_attn_metadata(model_input) with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model(input_ids=model_input.input_tokens, - positions=model_input.input_positions, - masks=model_input.input_masks, - is_prompt=model_input.is_prompt) - - # Compute the logits. - logits = self.model.compute_logits(hidden_states, None) + logits = self.model(input_ids=model_input.input_tokens, + positions=model_input.input_positions, + masks=model_input.input_masks, + is_prompt=model_input.is_prompt) is_prefill = cast(bool, model_input.is_prompt)