[Bugfix] Fix the failing gte embedding test #18720


Merged
merged 13 commits into vllm-project:main from Isotr0py:fix-gte-test on May 29, 2025

Conversation

Isotr0py
Collaborator

@Isotr0py Isotr0py commented May 26, 2025


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@DarkLight1337
Member

DarkLight1337 commented May 26, 2025

Hmm, perhaps we can try setting dtype="float32" for this model?
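For illustration, a hypothetical sketch of that suggestion; the fixture and parameter names mirror the pooling tests quoted later in this thread, and the exact test-file layout is an assumption:

# Hypothetical sketch: pin fp32 for the offending model in the embedding test.
# Fixture/parameter names follow the tests pasted further down in this thread.
import pytest

@pytest.mark.parametrize("model", ["Alibaba-NLP/gte-Qwen2-1.5B-instruct"])
@pytest.mark.parametrize("dtype", ["float32"])  # instead of "half"
def test_models(hf_runner, vllm_runner, model, dtype) -> None:
    with vllm_runner(model, task="embed", dtype=dtype) as vllm_model:
        ...  # compare against hf_runner outputs as in the existing test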

@Isotr0py
Collaborator Author

Hmmm, we can set fp32 for gte-Qwen2-1.5B, but for ssmits/Qwen2-7B-Instruct-embed-base it seems the CI machine won't have enough VRAM to run it with fp32.
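For intuition, a rough weights-only estimate; the parameter count is an assumption back-solved from the ~13.2 GiB fp16 load reported in the logs below, and activations and runtime buffers come on top:

# Back-of-the-envelope VRAM for weights only; ~7.07B params is an assumption
# inferred from the 13.23 GiB fp16 figure in the CI/local logs.
params = 7.07e9
print(f"fp16: {params * 2 / 2**30:.1f} GiB")  # ~13.2 GiB
print(f"fp32: {params * 4 / 2**30:.1f} GiB")  # ~26.3 GiB, too big for the CI GPU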

Signed-off-by: Isotr0py <[email protected]>
Member

@DarkLight1337 DarkLight1337 left a comment

Got it, see if this works then

@Isotr0py
Collaborator Author

Hmmm, not sure why the test fails on CI while it passes locally. Let me try to reproduce the failure on another machine:

$ pytest -s -v tests/models/language/pooling/test_embedding.py -k half-ssmits/Qwen2-7B-Instruct-embed-base
INFO 05-27 03:00:20 [__init__.py:243] Automatically detected platform cuda.
==================================================================== test session starts ====================================================================
platform linux -- Python 3.12.10, pytest-8.3.5, pluggy-1.6.0 -- /kaggle/working/vllm/.venv/bin/python3
cachedir: .pytest_cache
rootdir: /kaggle/working/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0
collecting ... WARNING 05-27 03:00:22 [interface.py:470] Current platform cuda does not have '_pytestfixturefunction' attribute.
WARNING 05-27 03:00:22 [interface.py:470] Current platform cuda does not have '__test__' attribute.
WARNING 05-27 03:00:22 [interface.py:470] Current platform cuda does not have '__bases__' attribute.
WARNING 05-27 03:00:22 [interface.py:470] Current platform cuda does not have '__test__' attribute.
collected 8 items / 7 deselected / 1 selected                                                                                                               

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 11.36it/s]
INFO 05-27 03:00:31 [__init__.py:31] Available plugins for group vllm.general_plugins:
INFO 05-27 03:00:31 [__init__.py:33] - lora_filesystem_resolver -> vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver
INFO 05-27 03:00:31 [__init__.py:36] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
WARNING 05-27 03:00:41 [arg_utils.py:1591] Compute Capability < 8.0 is not supported by the V1 Engine. Falling back to V0. 
INFO 05-27 03:00:41 [llm_engine.py:230] Initializing a V0 LLM engine (v0.8.5.dev937+g27bebcd89) with config: model='ssmits/Qwen2-7B-Instruct-embed-base', speculative_config=None, tokenizer='ssmits/Qwen2-7B-Instruct-embed-base', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=ssmits/Qwen2-7B-Instruct-embed-base, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=False, use_async_output_proc=False, pooler_config=PoolerConfig(pooling_type='MEAN', normalize=None, softmax=None, step_tag_id=None, returned_token_ids=None), compilation_config={"compile_sizes": [], "inductor_compile_config": {"enable_auto_functionalized_v2": false}, "cudagraph_capture_sizes": [256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 4, 2, 1], "max_capture_size": 256}, use_cached_outputs=False, 
INFO 05-27 03:00:43 [cuda.py:224] Using XFormers backend.
...
INFO 05-27 03:01:09 [default_loader.py:280] Loading weights took 4.65 seconds
INFO 05-27 03:01:09 [model_runner.py:1202] Model loading took 13.2335 GiB and 4.900560 seconds
Adding requests: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2445.30it/s]
Processed prompts: 100%|███████████████████████████████████████████████████| 8/8 [00:00<00:00, 32.72it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
PASSED

============================================================= 1 passed, 7 deselected in 48.12s ==============================================================

Isotr0py added 5 commits May 27, 2025 15:41
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
@Isotr0py
Collaborator Author

Hmmm, it seems the problematic outputs come from hf_runner; here are the passing Test0 outputs locally:

Test0:
Cosine similarity:      0.9999
hf:     array([-0.5522 , -0.01881, -0.438  ,  0.87   , -0.04807, -2.197  ,
       -1.328  , -0.688  , -0.32   , -7.28   ,  1.792  ,  0.3933 ,
        0.1716 , -0.689  ,  0.3757 , -0.3765 ], dtype=float16)
vllm:   [-0.5517578125, -0.0159149169921875, -0.436279296875, 0.86962890625, -0.04962158203125, -2.197265625, -1.3291015625, -0.6865234375, -0.32080078125, -7.28125, 1.7939453125, 0.39306640625, 0.1727294921875, -0.68896484375, 0.37255859375, -0.376220703125]

And the failing test0 outputs from https://buildkite.com/vllm/fastcheck/builds/25254#01971122-90b4-40a0-9e77-86261080bb5b/212-1864:

[2025-05-27T11:00:15Z] E           AssertionError: Test0:
[2025-05-27T11:00:15Z] E           Cosine similarity: 	0.8102
[2025-05-27T11:00:15Z] E           hf:	array([-0.3745 , -0.807  , -1.024  ,  1.553  , -0.02649, -1.482  ,
[2025-05-27T11:00:15Z] E                  -1.159  , -1.292  ,  1.325  , -5.93   ,  0.796  , -1.122  ,
[2025-05-27T11:00:15Z] E                   1.014  , -0.7275 ,  2.037  ,  0.3257 ], dtype=float16)
[2025-05-27T11:00:15Z] E           vllm:	[-0.5537109375, -0.016204833984375, -0.43798828125, 0.87255859375, -0.049407958984375, -2.197265625, -1.3271484375, -0.68798828125, -0.321044921875, -7.28125, 1.7919921875, 0.393310546875, 0.171142578125, -0.6884765625, 0.36865234375, -0.37548828125]

Perhaps something is causing hf_runner to not be initialized properly on CI?
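For reference, the reported cosine similarity can be recomputed from the truncated vectors above (a minimal sketch; the real test compares the full embeddings):

# Recompute the Test0 cosine similarity from the first 16 dims printed above.
import numpy as np

hf = np.array([-0.5522, -0.01881, -0.438, 0.87, -0.04807, -2.197,
               -1.328, -0.688, -0.32, -7.28, 1.792, 0.3933,
               0.1716, -0.689, 0.3757, -0.3765])
vllm = np.array([-0.55175781, -0.01591492, -0.43627930, 0.86962891,
                 -0.04962158, -2.19726562, -1.32910156, -0.68652344,
                 -0.32080078, -7.28125, 1.79394531, 0.39306641,
                 0.17272949, -0.68896484, 0.37255859, -0.37622070])
cos = float(hf @ vllm / (np.linalg.norm(hf) * np.linalg.norm(vllm)))
print(f"{cos:.4f}")  # ~0.9999 for the passing local run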

@noooop
Contributor

noooop commented May 28, 2025

Sorry, I don't have enough GPU memory to run the float32 test, but float16 only scores 0.44361902720248764 on MTEB STS12,

which looks the same as random to me.

100M-parameter models can achieve 0.70+.

Do we really need to test this model in CI?

@DarkLight1337
Member

You need to use mean pooling for this model.

@noooop
Contributor

noooop commented May 28, 2025

You need to use mean pooling for this model.

https://github.com/noooop/snippet/blob/main/benchmarks/test_mteb/test_mteb.py

python benchmarks/test_mteb/test_mteb.py ssmits/Qwen2-7B-Instruct-embed-base

The result of SentenceTransformer is 0.44361902720248764

Wait a moment, is it a bug in SentenceTransformer or Transformers? I will try downgrading.

@DarkLight1337
Member

DarkLight1337 commented May 28, 2025

Based on the note in https://docs.vllm.ai/en/latest/models/supported_models.html#text-embedding, it looks like the Sentence Transformers config is incorrect as well. You should set mean pooling explicitly for both vLLM and Sentence Transformers.
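For illustration, a hedged sketch of what setting mean pooling explicitly on both sides could look like; the vLLM override mirrors the PoolerConfig used later in this thread, and the Sentence Transformers side assembles the pooling module by hand:

# Sketch only: force MEAN pooling on both runners instead of relying on the
# checkpoint's pooler config. Model name is the one under discussion.
from sentence_transformers import SentenceTransformer, models
from vllm import LLM
from vllm.config import PoolerConfig

name = "ssmits/Qwen2-7B-Instruct-embed-base"

# vLLM: override whatever 1_Pooling config the checkpoint ships with.
llm = LLM(model=name, task="embed",
          override_pooler_config=PoolerConfig(pooling_type="MEAN", normalize=False))

# Sentence Transformers: build Transformer + mean Pooling explicitly.
word = models.Transformer(name)
pool = models.Pooling(word.get_word_embedding_dimension(), pooling_mode="mean")
st = SentenceTransformer(modules=[word, pool])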

@Isotr0py
Collaborator Author

Hmmm, but I remember Sentence-Transformers initializes mean pooling by default for models missing a pooler config: https://github.com/UKPLab/sentence-transformers/blob/dd76c033ac2161a2958fee2e18fd0227a81ee921/sentence_transformers/SentenceTransformer.py#L1508-L1530

@noooop
Contributor

noooop commented May 28, 2025

Hmmm, but I remember Sentence-Transformers initialize mean pooling by default for models missing pooler config: https://github.com/UKPLab/sentence-transformers/blob/dd76c033ac2161a2958fee2e18fd0227a81ee921/sentence_transformers/SentenceTransformer.py#L1508-L1530

+1

There is the following output:

No sentence-transformers model found with name ssmits/Qwen2-7B-Instruct-embed-base. Creating a new one with mean pooling.

@noooop
Contributor

noooop commented May 28, 2025

but "pooling_mode_lasttoken": true, in 1_Pooling/1_Pooling_config.json

https://huggingface.co/ssmits/Qwen2-7B-Instruct-embed-base/blob/main/1_Pooling/1_Pooling_config.json

@noooop
Contributor

noooop commented May 28, 2025

MTEB STS12 dataset

Sentence-Transformers 4.1.0 float16 0.44361902720248764 <- I think it uses MEAN pooling.

Default float16 ssmits/Qwen2-7B-Instruct-embed-base 0.44367362158651924 4.433834209841097e-05
MEAN float16 ssmits/Qwen2-7B-Instruct-embed-base 0.44367362158651924 4.433834209841097e-05
LAST float16 ssmits/Qwen2-7B-Instruct-embed-base 0.3198316580165386 8.225774123289438e-05

I think the vLLM code defaults to using MEAN pooling.

Try adding ssmits/Qwen2-7B-Instruct-embed-base to tests/models/language/pooling/test_gte.py and see if it passes the tests.

@Isotr0py
Collaborator Author

Isotr0py commented May 28, 2025

IIRC, ssmits/Qwen2-7B-Instruct-embed-base is just converted from Qwen2-7B-Instruct by removing lm_head and overwriting the config.json. Perhaps we can convert a smaller one to allow testing with fp32?
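For reference, a hypothetical sketch of that kind of conversion with a small checkpoint; the base model choice and output path are placeholders, not what this PR does:

# Hypothetical: strip a small Qwen2 chat model down to its base (no lm_head)
# and save it locally so it could be tested as an embedding model in fp32.
from transformers import AutoModel, AutoTokenizer

src = "Qwen/Qwen2.5-0.5B-Instruct"
out = "./qwen2.5-0.5b-embed-base"   # placeholder output path

model = AutoModel.from_pretrained(src)        # loads the base model without lm_head
tokenizer = AutoTokenizer.from_pretrained(src)
model.save_pretrained(out)                    # save the stripped-down checkpoint locally
tokenizer.save_pretrained(out)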

@noooop
Contributor

noooop commented May 28, 2025

Perhaps something is causing hf_runner to not be initialized properly on CI? +1

Using vllm_runner results in NaN, but using from vllm import LLM directly does not:

# SPDX-License-Identifier: Apache-2.0
import gc
from collections.abc import Sequence

import mteb
import numpy as np
import pytest
import torch

from .mteb_utils import (MTEB_EMBED_TASKS, MTEB_EMBED_TOL, VllmMtebEncoder,
                         run_mteb_embed_task)

model = "ssmits/Qwen2-7B-Instruct-embed-base"


@pytest.mark.parametrize("dtype", ["float16"])
def test_embed_models_mteb(hf_runner, vllm_runner, dtype: str) -> None:

    with vllm_runner(model, task="embed", max_model_len=None,
                     dtype=dtype) as vllm_model:

        vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                              MTEB_EMBED_TASKS)
        # ValueError: Input contains NaN.
    """
    X = array(
        [[nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan],
         ..., [-0.00183868, 0.00259018, -0.00291634, ..., -0.00311279, 0.00650024, -0.01535797],
         [0.00305748, 0.00748825, -0.01096344, ..., 0.00295448, 0.00697327, -0.00764465],
         [nan, nan, nan, ..., nan, nan, nan]], shape=(3108, 3584))
    """

    with hf_runner(model, is_sentence_transformer=True,
                   dtype=dtype) as hf_model:

        st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)

    print("VLLM:", dtype, vllm_main_score)
    print("SentenceTransformer:", dtype, st_main_score)
    print("Difference:", st_main_score - vllm_main_score)

    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)


class VllmEncoder(mteb.Encoder):

    def __init__(self, model, dtype, trust_remote_code: bool = True, **kwargs):
        super().__init__()
        from vllm import LLM

        self.model = LLM(model=model,
                         task="embed",
                         dtype=dtype,
                         trust_remote_code=trust_remote_code,
                         max_num_seqs=4,
                         **kwargs)

    def encode(self, sentences: Sequence[str], **kwargs) -> np.ndarray:
        outputs = self.model.embed(sentences, use_tqdm=False)
        embeds = np.array([o.outputs.embedding for o in outputs])
        return embeds


@pytest.mark.parametrize("dtype", ["float16"])
def test_embed_models_mteb2(hf_runner, vllm_runner, dtype: str) -> None:
    vllm_main_score = run_mteb_embed_task(VllmEncoder(model, dtype=dtype),
                                          MTEB_EMBED_TASKS)

    gc.collect()
    torch.cuda.empty_cache()

    with hf_runner(model, is_sentence_transformer=True,
                   dtype=dtype) as hf_model:

        st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)

    print("VLLM:", dtype, vllm_main_score)
    print("SentenceTransformer:", dtype, st_main_score)
    print("Difference:", st_main_score - vllm_main_score)

    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)

@Isotr0py
Collaborator Author

Hmmm, I tried using Qwen/Qwen2.5-0.5B-Instruct as the embedding model in these MTEB tests with fp32, and it can pass the test after using mean pooling on the vLLM runner:

# SPDX-License-Identifier: Apache-2.0
import gc
from collections.abc import Sequence

import mteb
import numpy as np
import pytest
import torch

from vllm.config import PoolerConfig

from .mteb_utils import (MTEB_EMBED_TASKS, MTEB_EMBED_TOL, VllmMtebEncoder,
                         run_mteb_embed_task)

model = "Qwen/Qwen2.5-0.5B-Instruct"


@pytest.mark.parametrize("dtype", ["float32"])
def test_embed_models_mteb(hf_runner, vllm_runner, dtype: str) -> None:

    vllm_extra_kwargs = {}
    vllm_extra_kwargs["override_pooler_config"] = \
            PoolerConfig(pooling_type="MEAN", normalize=False)

    with vllm_runner(model, task="embed", max_model_len=None,
                     dtype=dtype, **vllm_extra_kwargs) as vllm_model:

        vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                              MTEB_EMBED_TASKS)
        # ValueError: Input contains NaN.
    """
    X = array(
        [[nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan],
         ..., [-0.00183868, 0.00259018, -0.00291634, ..., -0.00311279, 0.00650024, -0.01535797],
         [0.00305748, 0.00748825, -0.01096344, ..., 0.00295448, 0.00697327, -0.00764465],
         [nan, nan, nan, ..., nan, nan, nan]], shape=(3108, 3584))
    """

    with hf_runner(model, is_sentence_transformer=True,
                   dtype=dtype) as hf_model:

        st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)

    print("VLLM:", dtype, vllm_main_score)
    print("SentenceTransformer:", dtype, st_main_score)
    print("Difference:", st_main_score - vllm_main_score)

    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)


class VllmEncoder(mteb.Encoder):

    def __init__(self, model, dtype, trust_remote_code: bool = True, **kwargs):
        super().__init__()
        from vllm import LLM

        self.model = LLM(model=model,
                         task="embed",
                         dtype=dtype,
                         trust_remote_code=trust_remote_code,
                         max_num_seqs=4,
                         **kwargs)

    def encode(self, sentences: Sequence[str], **kwargs) -> np.ndarray:
        outputs = self.model.embed(sentences, use_tqdm=False)
        embeds = np.array([o.outputs.embedding for o in outputs])
        return embeds


@pytest.mark.parametrize("dtype", ["float32"])
def test_embed_models_mteb2(hf_runner, vllm_runner, dtype: str) -> None:
    vllm_extra_kwargs = {}
    vllm_extra_kwargs["override_pooler_config"] = \
            PoolerConfig(pooling_type="MEAN", normalize=False)

    vllm_main_score = run_mteb_embed_task(VllmEncoder(model, dtype=dtype, **vllm_extra_kwargs),
                                          MTEB_EMBED_TASKS)

    gc.collect()
    torch.cuda.empty_cache()

    with hf_runner(model, is_sentence_transformer=True,
                   dtype=dtype) as hf_model:

        st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)

    print("VLLM:", dtype, vllm_main_score)
    print("SentenceTransformer:", dtype, st_main_score)
    print("Difference:", st_main_score - vllm_main_score)

    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)

However, even with the test changed to Qwen/Qwen2.5-0.5B-Instruct and FP32, it can't pass on CI at all... (https://buildkite.com/vllm/fastcheck/builds/25436#01971632-0c7e-4a0c-81ec-6d275c543971/213-1928)

Signed-off-by: Isotr0py <[email protected]>
@mergify mergify bot added the ci/build label May 28, 2025
@Isotr0py
Collaborator Author

Isotr0py commented May 28, 2025

Oh, I finally found out what's wrong... I don't know why, but it seems hf_model on CI is using bidirectional attention (is_causal=False), probably because these converted Qwen2 models don't have is_causal in their config, and the runner reused cached custom model code from Alibaba-NLP/gte-Qwen2-1.5B-instruct, which sets is_causal=False by default...

I can reproduce the reversed CI failure (https://buildkite.com/vllm/fastcheck/builds/25436#01971632-0c7e-4a0c-81ec-6d275c543971/213-1928) locally by adding hf_overrides = {"is_causal": False} now:

Test0:
Cosine similarity:      0.8341
hf:     array([ 0.37173074,  1.1417271 ,  2.046475  ,  1.9221017 ,  0.26261473,
       -0.80174524, -1.0346788 , -5.9414687 , -3.2826447 ,  1.7353998 ,
        2.083949  , -1.9396118 , -3.4990685 , -3.533126  , -1.8913524 ,
        0.84274614], dtype=float32)
vllm:   [1.5029491186141968, 0.2734091281890869, 1.623220443725586, 3.9636447429656982, -0.7554096579551697, -1.0056668519973755, -2.1818337440490723, -6.111933708190918, -3.677251100540161, 4.014033317565918, -0.005377042107284069, -1.6859935522079468, 0.05065629631280899, 0.6660666465759277, 2.703446865081787, -4.743410110473633]

@noooop
Contributor

noooop commented May 29, 2025

it reused cached custom model code from Alibaba-NLP/gte-Qwen2-1.5B-instruct, which sets is_causal=False by default...

I feel that the errors in ssmits/Qwen2-7B-Instruct-embed-base and Alibaba-NLP/gte-Qwen2-1.5B-instruct occur together and are possibly related.

LOL

Comment on lines 13 to 22
# Be careful of the order of models, decoder-only models should be
# placed before encoder-only models, otherwise `Qwen2.5-0.5B-Instruct`
# case won't pass because gte-Qwen2-1.5B-instruct will cache custom
# model code with bidirectional attention.
# [Decoder-only]
pytest.param("BAAI/bge-multilingual-gemma2",
marks=[pytest.mark.core_model]),
pytest.param("intfloat/e5-mistral-7b-instruct",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("Qwen/Qwen2.5-0.5B-Instruct"),
Collaborator Author

It seems the test order is the cause: if we test Alibaba-NLP/gte-Qwen2-1.5B-instruct before ssmits/Qwen2-7B-Instruct-embed-base, the test pipeline fails at ssmits/Qwen2-7B-Instruct-embed-base.

And if we test ssmits/Qwen2-7B-Instruct-embed-base before Alibaba-NLP/gte-Qwen2-1.5B-instruct, the test pipeline passes.

I had only run the single test before, so I couldn't reproduce it because the custom code wasn't cached. 😅

Contributor

@noooop noooop May 29, 2025

Maybe it is caused by _cached_get_attn_backend:

@cache
def _cached_get_attn_backend(

I have also encountered test errors caused by this cache before, and using cache_clear can temporarily fix it.

#18755 (comment)

c746de9
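A minimal sketch of that workaround, assuming the cached function lives in vllm.attention.selector as the snippet above suggests:

# Clear the cached attention backend around each test so one model's backend
# choice cannot leak into the next (the import path is an assumption).
import pytest
from vllm.attention.selector import _cached_get_attn_backend

@pytest.fixture(autouse=True)
def clear_cached_attn_backend():
    _cached_get_attn_backend.cache_clear()   # functools.cache exposes cache_clear()
    yield
    _cached_get_attn_backend.cache_clear()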

Contributor

If this is the cause, we need to clear the attn_backend cache in vllm_runner to solve it entirely.

Collaborator Author

It seems this is more likely a bug in sentence-transformers or transformers than in vLLM:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("Qwen/Qwen2.5-0.5B-Instruct")
embed_ref = model.encode(["vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."])[:, :16]
print(embed_ref)

# array([[ 0.37174752,  1.1417291 ,  2.0464673 ,  1.9220997 ,  0.2626055 ,
#         -0.8017454 , -1.0346682 , -5.9414687 , -3.2826447 ,  1.7354006 ,
#          2.0839477 , -1.9396178 , -3.4990966 , -3.533139  , -1.8913338 ,
#          0.84275854]], dtype=float32)

model1 = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
model1.encode(["vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."])[:, :16]

model2 = SentenceTransformer("Qwen/Qwen2.5-0.5B-Instruct")
embed_bug = model2.encode(["vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."])[:, :16]
print(embed_bug)

# array([[ 1.5029449e+00,  2.7341652e-01,  1.6232374e+00,  3.9635556e+00,
#         -7.5535011e-01, -1.0056579e+00, -2.1817753e+00, -6.1119280e+00,
#         -3.6772854e+00,  4.0140610e+00, -5.4029943e-03, -1.6860790e+00,
#          5.0638936e-02,  6.6610152e-01,  2.7034342e+00, -4.7433991e+00]],
#       dtype=float32)

Seems that trust_remote_code=False is not respected in SentenceTransformer.

Contributor

Alibaba-NLP/gte-Qwen2-1.5B-instruct sets trust_remote_code=True,

so its modeling_qwen module from the Hub was used instead of the built-in version from transformers.

https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct/blob/main/modeling_qwen.py

There is no reload when inference ends.
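A quick diagnostic sketch for this; the exact module path reported will vary, and the point is only whether the class comes from transformers.models.qwen2 or from the dynamically created transformers_modules package:

# Check which implementation the loaded model actually uses.
from sentence_transformers import SentenceTransformer

st = SentenceTransformer("Qwen/Qwen2.5-0.5B-Instruct")
print(type(st[0].auto_model).__module__)
# expected: transformers.models.qwen2.modeling_qwen2
# if it reports something under transformers_modules.*, the cached remote
# modeling_qwen code is being reused for a model that never asked for it.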

Contributor

Using clear_import_cache cannot fix it.

Collaborator Author

Hmmm, clear_import_cache has no effect; the outputs are still different after loading Alibaba-NLP/gte-Qwen2-1.5B-instruct:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=False)
embed_ref = model.encode(["vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."])[:, :16]
print(embed_ref)

# array([[ 0.37174752,  1.1417291 ,  2.0464673 ,  1.9220997 ,  0.2626055 ,
#         -0.8017454 , -1.0346682 , -5.9414687 , -3.2826447 ,  1.7354006 ,
#          2.0839477 , -1.9396178 , -3.4990966 , -3.533139  , -1.8913338 ,
#          0.84275854]], dtype=float32)

from transformers.utils.import_utils import clear_import_cache

SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
clear_import_cache()

from sentence_transformers import SentenceTransformer
from transformers import AutoModel

model2 = SentenceTransformer("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=False)
embed_bug = model2.encode(["vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."])[:, :16]
print(embed_bug)

# array([[ 1.5029449e+00,  2.7341652e-01,  1.6232374e+00,  3.9635556e+00,
#         -7.5535011e-01, -1.0056579e+00, -2.1817753e+00, -6.1119280e+00,
#         -3.6772854e+00,  4.0140610e+00, -5.4029943e-03, -1.6860790e+00,
#          5.0638936e-02,  6.6610152e-01,  2.7034342e+00, -4.7433991e+00]],
#       dtype=float32)

Anyway, since this is possibly an issue on the sentence-transformers side, I think we should create an issue in their repo to get more insight; that would be more effective :)

Contributor

Because sentence_transformers itself imports transformers, there is no clean way to reload transformers.

Contributor

@noooop noooop May 29, 2025

It can be fixed in a very hacky way:

import sys
import importlib


def t1():
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("Qwen/Qwen2.5-0.5B-Instruct")
    embed_ref = model.encode(["vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."])[:, :16]
    print(embed_ref)

    model1 = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
    model1.encode(["vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."])[:, :16]


def t2():
    from sentence_transformers import SentenceTransformer
    model2 = SentenceTransformer("Qwen/Qwen2.5-0.5B-Instruct")
    embed_bug = model2.encode(
        ["vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."])[:, :16]
    print(embed_bug)

def clean():
    # Deleting any of the lines below breaks the fix.
    sentence_transformers_modules = [mod_name for mod_name in sys.modules if mod_name.startswith("sentence_transformers.")]
    for mod_name in sentence_transformers_modules:
        del sys.modules[mod_name]

    transformers_modules = [mod_name for mod_name in sys.modules if mod_name.startswith("transformers.")]
    for mod_name in transformers_modules:
        del sys.modules[mod_name]

    import transformers
    import sentence_transformers

    importlib.reload(transformers)
    importlib.reload(sentence_transformers)

t1()
clean()
t2()

Collaborator Author

After these investigations, I believe there is a bug on the transformers side, because if we load Qwen2 with trust_remote_code=False, it shouldn't load any custom code, even if that code has been trusted and cached before.

Anyway, since the extended pooling CI is green now, let's merge this and leave the cached-module issue to be fixed on the transformers side.

@Isotr0py Isotr0py enabled auto-merge (squash) May 29, 2025 13:25
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label May 29, 2025
Member

@DarkLight1337 DarkLight1337 left a comment

Thanks for investigating this! Can you link the transformers bug report to this PR later?

@vllm-bot vllm-bot merged commit c9479b2 into vllm-project:main May 29, 2025
52 of 57 checks passed
@Isotr0py Isotr0py deleted the fix-gte-test branch May 29, 2025 14:46
@Isotr0py
Collaborator Author

I'm looking into this issue in transformers to make a quick-fix PR for it directly. I'll create a bug report if it's too complicated to fix.

amitm02 pushed a commit to amitm02/vllm that referenced this pull request Jun 1, 2025
amitm02 pushed a commit to amitm02/vllm that referenced this pull request Jun 1, 2025
Labels
ci/build, ready (ONLY add when PR is ready to merge/full CI is needed)
4 participants