26 changes: 21 additions & 5 deletions tests/e2e/test_spyre_scoring.py
@@ -18,19 +18,35 @@ def test_serving(remote_openai_server, model, warmup_shapes, backend):
     score_url = remote_openai_server.url_for("/score")
 
     query = "What is the capital of France?"
+    # Number of inputs larger than the warmup batch size of 4
+    # and with a non-uniform token length
     docs = [
-        "The capital of France is Paris.", "The capital of Germany is Berlin."
+        "The capital of France is Paris.", "The capital of Germany is Berlin.",
+        "The capital of Brazil is Brasilia.",
+        "The capital of the country with the best beer is Berlin.",
+        "The capital of the USA is Washington.",
+        "The capital city of Spain is Madrid."
     ]
     vllm_outputs = requests.post(url=score_url,
                                  json={
                                      "text_1": query,
-                                     "text_2": [docs[0], docs[1]]
+                                     "text_2": docs,
                                  }).json()
-    vllm_scores = [o["score"] for o in vllm_outputs["data"]]
 
     ce_model = CrossEncoder(model.name, revision=model.revision)
-    ce_scores = ce_model.predict([(query, docs[0]), (query, docs[1])])
 
+    vllm_scores = [o["score"] for o in vllm_outputs["data"]]
+    ce_scores = ce_model.predict([
+        (query, docs[0]),
+        (query, docs[1]),
+        (query, docs[2]),
+        (query, docs[3]),
+        (query, docs[4]),
+        (query, docs[5]),
+    ])
 
     assert ce_scores[0] == pytest.approx(vllm_scores[0], rel=0.02)
     assert ce_scores[1] == pytest.approx(vllm_scores[1], rel=0.02)
+    assert ce_scores[2] == pytest.approx(vllm_scores[2], rel=0.02)
+    assert ce_scores[3] == pytest.approx(vllm_scores[3], rel=0.02)
+    assert ce_scores[4] == pytest.approx(vllm_scores[4], rel=0.02)
+    assert ce_scores[5] == pytest.approx(vllm_scores[5], rel=0.02)
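
Note: the /score endpoint scores every candidate in "text_2" against "text_1" in a single request, which is what lets the test cover batches larger than the warmup batch size. A minimal standalone sketch of the same call, assuming a vLLM OpenAI-compatible server is already running (the localhost URL below is a placeholder; the test gets the real one from the remote_openai_server fixture):

    import requests

    # Placeholder URL for illustration only.
    score_url = "http://localhost:8000/score"

    resp = requests.post(url=score_url,
                         json={
                             "text_1": "What is the capital of France?",
                             "text_2": ["The capital of France is Paris.",
                                        "The capital of Germany is Berlin."],
                         }).json()

    # The response carries one entry per document in "data",
    # each with a "score" field.
    scores = [o["score"] for o in resp["data"]]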
72 changes: 44 additions & 28 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -5,7 +5,7 @@
 from collections import deque
 from collections.abc import Iterable
 from dataclasses import asdict, dataclass, field
-from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
 
 import torch
 from torch import nn
@@ -1322,31 +1322,18 @@ def forward(
         hidden_states: Union[torch.Tensor, list[torch.Tensor]],
         pooling_metadata: PoolingMetadata,
     ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        # Because we're using transformers to load the pooler
-        # and classifier layers and the assumption there is that
-        # we have a right padded batch, we need to split
-        # and at the batch dimension.
-        if isinstance(hidden_states, torch.Tensor):
-            split_states = torch.split(hidden_states,
-                                       pooling_metadata.prompt_lens.tolist())
-            return [self.pooler(h.T) for h in split_states]
-        else:
-            return [self.pooler(h.T) for h in hidden_states]
-
-
-class ClassifierAdapter(torch.nn.Module):
-
-    def __init__(self, classifier: torch.nn.Module):
-        super().__init__()
-        self.classifier = classifier
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-        if hidden_states.ndim == 2:
-            hidden_states = hidden_states.unsqueeze(dim=0)
-        return self.classifier(hidden_states)
+        hidden_states = torch.split(hidden_states,
+                                    pooling_metadata.prompt_lens.tolist())
+        return [self.pooler(h.unsqueeze(dim=0)) for h in hidden_states]
 
 
-def _transpose(input: torch.Tensor) -> torch.Tensor:
-    return input.T
+def _cls(input: torch.Tensor) -> torch.Tensor:
+    return input[:, 0]
 
 
 class SpyrePoolingModelRunner(WarmupShapesMixin,
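
With ClassifierAdapter gone, PoolerAdapter.forward now restores the batch dimension itself before pooling, and for RoBERTa models the pooling step becomes plain CLS-token extraction. A quick shape sketch of what _cls receives and returns (toy dimensions, not taken from the PR):

    import torch

    h = torch.randn(5, 8)         # one sequence's hidden states: [seq_len=5, hidden=8]
    batched = h.unsqueeze(dim=0)  # [1, 5, 8], as PoolerAdapter.forward now produces
    cls_vec = batched[:, 0]       # _cls keeps the first (CLS) token: shape [1, 8]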
@@ -1397,9 +1384,11 @@ def load_model(self, prompt_lens: Iterable[int],
         if hasattr(class_model, "bert"):
             self.model = class_model.bert
             self._pooler = PoolerAdapter(self.model.pooler)
+            self.model.pooler = None
         elif hasattr(class_model, "roberta"):
             self.model = class_model.roberta
-            self._pooler = PoolerAdapter(_transpose)
+            self._pooler = PoolerAdapter(_cls)
+            self.model.pooler = None
         else:
             raise ValueError(
                 f"Unsupported model {self.model_config.model}: Expected "
@@ -1447,7 +1436,7 @@ def load_model(self, prompt_lens: Iterable[int],
         with set_current_vllm_config(self.vllm_config):
             self.pooler = ClassifierPooler(
                 pooling=self._pooler,
-                classifier=ClassifierAdapter(self.classifier),
+                classifier=self.classifier,
                 act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                     self.model_config),
             )
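
For context, vLLM's ClassifierPooler chains the pooling callable, the classifier head, and the activation function; because the adapter now returns properly batched tensors, the raw classifier can be passed in directly. A simplified conceptual sketch of that composition (not vLLM's actual implementation, which also threads pooling metadata through):

    import torch

    def classify(hidden_states: torch.Tensor, pooling, classifier, act_fn):
        pooled = pooling(hidden_states)  # e.g. BertPooler, or _cls for RoBERTa
        logits = classifier(pooled)      # classification head scores
        return act_fn(logits)            # e.g. sigmoid for cross-encoder scoring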
@@ -1479,10 +1468,37 @@ def update_states(self, scheduler_output: SchedulerOutput):
             self.input_batch.remove_request(req_id)
             self.requests.pop(req_id, None)
 
+    def _uncompress_token_types(self) -> list[list[int]]:
+
+        pooling_metadata = self.input_batch.make_pooling_metadata()
+        pooling_params = pooling_metadata.pooling_params
+
+        token_type_id_requests = dict[int, Any]()
+        for i, param in enumerate(pooling_params):
+            if param.extra_kwargs is not None and \
+                (token_types := param.extra_kwargs.get(
+                    "compressed_token_type_ids")) is not None:
+                token_type_id_requests[i] = token_types
+
+        if len(token_type_id_requests) == 0:
+            return []
+
+        seq_lens = pooling_metadata.prompt_lens
+        token_type_ids = []
+
+        for i, seq_len in enumerate(seq_lens):
+            pos = token_type_id_requests.get(i, seq_len)
+            ids = (torch.arange(seq_lens[i]) >= pos).int()
+            token_type_ids.append(ids)
+
+        return token_type_ids
+
     def _token_types(self, input_ids):
-        from vllm.model_executor.models import bert
-        if hasattr(bert, "_decode_token_type_ids"):
-            return bert._decode_token_type_ids(input_ids)
+        if (token_type_ids_lst := self._uncompress_token_types()):
+            token_type_ids = torch.zeros_like(input_ids)
+            for i, token_types in enumerate(token_type_ids_lst):
+                token_type_ids[i, -len(token_types):] = token_types
+            return token_type_ids
         else:
             locs = torch.where(input_ids == self.sep_token_id, 1, 0)
             return locs.cumsum(dim=1) - locs
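
The compressed form records, per request, the index at which token_type_ids flips from 0 to 1; _uncompress_token_types expands that back into a per-token 0/1 vector, while the SEP-based fallback recovers the same vector from separator positions. A small worked example (illustrative values; 102 is BERT's usual [SEP] id, assumed here in place of self.sep_token_id):

    import torch

    # Expansion path: switch position 4 in a length-6 sequence.
    seq_len, pos = 6, 4
    print((torch.arange(seq_len) >= pos).int())  # tensor([0, 0, 0, 0, 1, 1])

    # Fallback path: derive token types from SEP positions.
    input_ids = torch.tensor([[101, 7, 8, 102, 9, 102]])
    locs = torch.where(input_ids == 102, 1, 0)
    print(locs.cumsum(dim=1) - locs)             # tensor([[0, 0, 0, 0, 1, 1]])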