Commit 97ed5b0

disable transformers pooler (#481)
# Description

The pooler is always run in the transformers code, even when its outputs aren't used. In our case, we instantiate the pooler outside of the transformers code so that the vLLM code is used instead. In a small test that I'm using, the total time goes from 1.8s to 1.2s.

Signed-off-by: Max de Bayser <[email protected]>
1 parent 748cc38 commit 97ed5b0

1 file changed

+5
-2
lines changed

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 5 additions & 2 deletions
@@ -1384,11 +1384,9 @@ def load_model(self, prompt_lens: Iterable[int],
             if hasattr(class_model, "bert"):
                 self.model = class_model.bert
                 self._pooler = PoolerAdapter(self.model.pooler)
-                self.model.pooler = None
             elif hasattr(class_model, "roberta"):
                 self.model = class_model.roberta
                 self._pooler = PoolerAdapter(_cls)
-                self.model.pooler = None
             else:
                 raise ValueError(
                     f"Unsupported model {self.model_config.model}: Expected "
@@ -1397,6 +1395,11 @@ def load_model(self, prompt_lens: Iterable[int],
         else:
             raise ValueError(f"Unsupported task {task}")

+        # Disable pooler because in transformers it's
+        # always run even though we don't use the outputs
+        # directly.
+        self.model.pooler = None
+
         model_class_name = type(self.model).__name__
         self.is_roberta = "roberta" in model_class_name.lower()
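The change above can be illustrated with a minimal, self-contained sketch. It assumes the encoder's forward pass skips the pooling step whenever its `pooler` attribute is `None`. `ToyEncoder` and `ToyPooler` are hypothetical stand-ins written for this illustration; they are not transformers or vLLM classes, and the real `PoolerAdapter` is not reproduced here.

```python
class ToyPooler:
    """Hypothetical stand-in for an encoder's built-in pooling head."""

    def __init__(self):
        self.calls = 0  # counts how often the pooler actually runs

    def __call__(self, hidden_states):
        self.calls += 1
        # Pretend pooling: return the first token's hidden state.
        return hidden_states[0]


class ToyEncoder:
    """Hypothetical encoder that runs its pooler only if one is attached."""

    def __init__(self):
        self.pooler = ToyPooler()

    def forward(self, token_ids):
        # Fake "hidden states": one 4-element vector per token.
        hidden_states = [[float(t)] * 4 for t in token_ids]
        # Assumed behaviour: the pooling step is skipped when pooler is None.
        pooled = self.pooler(hidden_states) if self.pooler is not None else None
        return hidden_states, pooled


encoder = ToyEncoder()

# Keep a reference to the pooler so it can be applied outside the encoder,
# then detach it so the forward pass no longer runs it.
external_pooler = encoder.pooler
encoder.pooler = None

hidden_states, pooled = encoder.forward([1, 2, 3])
calls_during_forward = external_pooler.calls

# Pooling now happens once, explicitly, outside the encoder.
result = external_pooler(hidden_states)
```

In this sketch the forward pass returns `None` for the pooled output and the pooler runs only when it is called explicitly afterwards, which is the redundant work the commit aims to avoid.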
