diff --git a/tests/e2e/test_spyre_scoring.py b/tests/e2e/test_spyre_scoring.py
index ffa62fb4..613538f4 100644
--- a/tests/e2e/test_spyre_scoring.py
+++ b/tests/e2e/test_spyre_scoring.py
@@ -18,19 +18,35 @@ def test_serving(remote_openai_server, model, warmup_shapes, backend):
     score_url = remote_openai_server.url_for("/score")
     query = "What is the capital of France?"
+    # Number of inputs is larger than the warmup batch size of 4,
+    # and the docs have non-uniform token lengths
     docs = [
-        "The capital of France is Paris.", "The capital of Germany is Berlin."
+        "The capital of France is Paris.", "The capital of Germany is Berlin.",
+        "The capital of Brazil is Brasilia.",
+        "The capital of the country with the best beer is Berlin.",
+        "The capital of the USA is Washington.",
+        "The capital city of Spain is Madrid."
     ]
 
     vllm_outputs = requests.post(url=score_url,
                                  json={
                                      "text_1": query,
-                                     "text_2": [docs[0], docs[1]]
+                                     "text_2": docs,
                                  }).json()
+    vllm_scores = [o["score"] for o in vllm_outputs["data"]]
 
     ce_model = CrossEncoder(model.name, revision=model.revision)
-    ce_scores = ce_model.predict([(query, docs[0]), (query, docs[1])])
-
-    vllm_scores = [o["score"] for o in vllm_outputs["data"]]
+    ce_scores = ce_model.predict([
+        (query, docs[0]),
+        (query, docs[1]),
+        (query, docs[2]),
+        (query, docs[3]),
+        (query, docs[4]),
+        (query, docs[5]),
+    ])
 
     assert ce_scores[0] == pytest.approx(vllm_scores[0], rel=0.02)
     assert ce_scores[1] == pytest.approx(vllm_scores[1], rel=0.02)
+    assert ce_scores[2] == pytest.approx(vllm_scores[2], rel=0.02)
+    assert ce_scores[3] == pytest.approx(vllm_scores[3], rel=0.02)
+    assert ce_scores[4] == pytest.approx(vllm_scores[4], rel=0.02)
+    assert ce_scores[5] == pytest.approx(vllm_scores[5], rel=0.02)
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 99f099ca..51cc73b9 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -5,7 +5,7 @@
 from collections import deque
 from collections.abc import Iterable
 from dataclasses import asdict, dataclass, field
-from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
 
 import torch
 from torch import nn
@@ -1322,31 +1322,18 @@ def forward(
         hidden_states: Union[torch.Tensor, list[torch.Tensor]],
         pooling_metadata: PoolingMetadata,
     ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        # Because we're using transformers to load the pooler
+        # and classifier layers, and the assumption there is that
+        # batches are right-padded, we need to split the packed
+        # hidden states and unsqueeze at the batch dimension.
         if isinstance(hidden_states, torch.Tensor):
-            split_states = torch.split(hidden_states,
-                                       pooling_metadata.prompt_lens.tolist())
-            return [self.pooler(h.T) for h in split_states]
-        else:
-            return [self.pooler(h.T) for h in hidden_states]
-
-
-class ClassifierAdapter(torch.nn.Module):
-
-    def __init__(self, classifier: torch.nn.Module):
-        super().__init__()
-        self.classifier = classifier
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-        if hidden_states.ndim == 2:
-            hidden_states = hidden_states.unsqueeze(dim=0)
-        return self.classifier(hidden_states)
+            hidden_states = torch.split(hidden_states,
+                                        pooling_metadata.prompt_lens.tolist())
+        return [self.pooler(h.unsqueeze(dim=0)) for h in hidden_states]
 
 
-def _transpose(input: torch.Tensor) -> torch.Tensor:
-    return input.T
+def _cls(input: torch.Tensor) -> torch.Tensor:
+    return input[:, 0]
 
 
 class SpyrePoolingModelRunner(WarmupShapesMixin,
@@ -1397,9 +1384,11 @@ def load_model(self, prompt_lens: Iterable[int],
         if hasattr(class_model, "bert"):
             self.model = class_model.bert
             self._pooler = PoolerAdapter(self.model.pooler)
+            self.model.pooler = None
         elif hasattr(class_model, "roberta"):
             self.model = class_model.roberta
-            self._pooler = PoolerAdapter(_transpose)
+            self._pooler = PoolerAdapter(_cls)
+            self.model.pooler = None
         else:
             raise ValueError(
                 f"Unsupported model {self.model_config.model}: Expected "
@@ -1447,7 +1436,7 @@ def load_model(self, prompt_lens: Iterable[int],
         with set_current_vllm_config(self.vllm_config):
             self.pooler = ClassifierPooler(
                 pooling=self._pooler,
-                classifier=ClassifierAdapter(self.classifier),
+                classifier=self.classifier,
                 act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                     self.model_config),
             )
@@ -1479,10 +1468,41 @@ def update_states(self, scheduler_output: SchedulerOutput):
                 self.input_batch.remove_request(req_id)
                 self.requests.pop(req_id, None)
 
+    def _uncompress_token_types(self) -> list[torch.Tensor]:
+        """Expand `compressed_token_type_ids` (the index at which a
+        request's token type flips from 0 to 1) back into one tensor
+        of per-token type ids per request, or [] if no request set it."""
+        pooling_metadata = self.input_batch.make_pooling_metadata()
+        pooling_params = pooling_metadata.pooling_params
+
+        token_type_id_requests = dict[int, Any]()
+        for i, param in enumerate(pooling_params):
+            if param.extra_kwargs is not None and \
+                (token_types := param.extra_kwargs.get(
+                    "compressed_token_type_ids")) is not None:
+                token_type_id_requests[i] = token_types
+
+        if len(token_type_id_requests) == 0:
+            return []
+
+        seq_lens = pooling_metadata.prompt_lens
+        token_type_ids = []
+
+        for i, seq_len in enumerate(seq_lens):
+            # Requests without compressed ids default to all-zero types
+            pos = token_type_id_requests.get(i, seq_len)
+            ids = (torch.arange(seq_lens[i]) >= pos).int()
+            token_type_ids.append(ids)
+
+        return token_type_ids
+
     def _token_types(self, input_ids):
-        from vllm.model_executor.models import bert
-        if hasattr(bert, "_decode_token_type_ids"):
-            return bert._decode_token_type_ids(input_ids)
+        if (token_type_ids_lst := self._uncompress_token_types()):
+            # Batches are left-padded, so right-align the token types
+            token_type_ids = torch.zeros_like(input_ids)
+            for i, token_types in enumerate(token_type_ids_lst):
+                token_type_ids[i, -len(token_types):] = token_types
+            return token_type_ids
         else:
             locs = torch.where(input_ids == self.sep_token_id, 1, 0)
             return locs.cumsum(dim=1) - locs
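
A quick sanity check of the token-type round trip added above, as a minimal self-contained sketch; `seq_len`, `pos`, and `padded_len` are hypothetical values for a single request rather than names from the patch, and only the tensor operations mirror `_uncompress_token_types` and `_token_types`:

import torch

# Hypothetical request: a 7-token prompt whose second segment starts at
# index 3, so the compressed representation carried in
# extra_kwargs["compressed_token_type_ids"] is just the integer 3.
seq_len, pos = 7, 3

# _uncompress_token_types expands that integer into per-token type ids:
token_types = (torch.arange(seq_len) >= pos).int()
assert token_types.tolist() == [0, 0, 0, 1, 1, 1, 1]

# _token_types then right-aligns the result into a left-padded batch row
# (padded here to 10 tokens), leaving the pad positions at type 0:
padded_len = 10
row = torch.zeros(padded_len, dtype=torch.int32)
row[-len(token_types):] = token_types
assert row.tolist() == [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]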