Skip to content

Commit 04583f8

Browse files
committed
refact: removed unnecessary logits processor
Signed-off-by: Wallas Santos <[email protected]>
1 parent 76daf7a commit 04583f8

File tree

2 files changed

+4
-19
lines changed

2 files changed

+4
-19
lines changed

vllm_spyre/model_executor/model_loader/spyre.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from vllm.config import ModelConfig, VllmConfig
1313
from vllm.forward_context import get_forward_context
1414
from vllm.logger import init_logger
15-
from vllm.model_executor.layers.logits_processor import LogitsProcessor
1615
from vllm.model_executor.model_loader.weight_utils import (
1716
download_weights_from_hf)
1817
from vllm.v1.outputs import SamplerOutput
@@ -58,9 +57,6 @@ def __init__(
5857
rank: int,
5958
) -> None:
6059
super().__init__()
61-
self.logits_processor = LogitsProcessor(
62-
vllm_config.model_config.hf_config.vocab_size,
63-
logits_as_input=True)
6460

6561
try:
6662
## Temporary backwards compatibility for 0.10.2
@@ -132,14 +128,6 @@ def forward(
132128

133129
return logits
134130

135-
def compute_logits(
136-
self,
137-
hidden_states: torch.Tensor,
138-
sampling_metadata: SamplingMetadata,
139-
) -> torch.Tensor:
140-
logits = self.logits_processor(None, hidden_states, sampling_metadata)
141-
return logits
142-
143131
def sample(
144132
self,
145133
logits: torch.Tensor,

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -525,13 +525,10 @@ def execute_model(
525525
# Execute the model
526526
attn_metadata = self.build_attn_metadata(model_input)
527527
with set_forward_context(attn_metadata, self.vllm_config):
528-
hidden_states = self.model(input_ids=model_input.input_tokens,
529-
positions=model_input.input_positions,
530-
masks=model_input.input_masks,
531-
is_prompt=model_input.is_prompt)
532-
533-
# Compute the logits.
534-
logits = self.model.compute_logits(hidden_states, None)
528+
logits = self.model(input_ids=model_input.input_tokens,
529+
positions=model_input.input_positions,
530+
masks=model_input.input_masks,
531+
is_prompt=model_input.is_prompt)
535532

536533
is_prefill = cast(bool, model_input.is_prompt)
537534

0 commit comments

Comments
 (0)