[Core] Rework dtype resolution #18751

Merged
5 changes: 1 addition & 4 deletions tests/basic_correctness/test_basic_correctness.py
@@ -60,7 +60,6 @@ def _fix_prompt_embed_outputs(

@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
@@ -69,7 +68,6 @@ def test_models(
hf_runner,
model: str,
backend: str,
dtype: str,
max_tokens: int,
enforce_eager: bool,
enable_prompt_embeds: bool,
@@ -97,7 +95,7 @@ def test_models(
str(i) for i in range(1024)) + " are:"
example_prompts = [prompt]

with hf_runner(model, dtype=dtype) as hf_model:
with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds:
with torch.no_grad():
@@ -106,7 +104,6 @@

with VllmRunner(model,
max_model_len=8192,
dtype=dtype,
enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7) as vllm_model:
7 changes: 6 additions & 1 deletion tests/conftest.py
@@ -324,7 +324,12 @@ def __init__(
trust_remote_code=trust_remote_code,
)
self.device = self.get_default_device()
self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype)
self.dtype = torch_dtype = _get_and_verify_dtype(
self.model_name,
self.config,
dtype=dtype,
is_pooling_model=is_sentence_transformer or is_cross_encoder,
)

model_kwargs = model_kwargs if model_kwargs is not None else {}
model_kwargs.setdefault("torch_dtype", torch_dtype)
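This hunk shows the central interface change of the PR: `_get_and_verify_dtype` now receives the model name and an `is_pooling_model` flag alongside the HF config and the requested dtype, presumably so that `dtype="auto"` can be resolved per model and differently for pooling models. A minimal usage sketch follows; the import path and the example model name are assumptions for illustration, not taken from this diff.

```python
# Usage sketch mirroring the call above. Assumptions: the import path and the
# example model name; verify both against the actual vLLM sources.
from transformers import AutoConfig

from vllm.config import _get_and_verify_dtype

model_name = "intfloat/e5-small-v2"  # hypothetical pooling model
hf_config = AutoConfig.from_pretrained(model_name)

torch_dtype = _get_and_verify_dtype(
    model_name,
    hf_config,
    dtype="auto",            # let the reworked resolution pick the dtype
    is_pooling_model=True,   # pooling models may resolve "auto" differently
)
print(torch_dtype)
```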
11 changes: 4 additions & 7 deletions tests/models/language/pooling/mteb_utils.py
@@ -102,21 +102,18 @@ def mteb_test_embed_models(hf_runner,
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
MTEB_EMBED_TASKS)
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
model_dtype = getattr(
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
vllm_dtype)

with set_default_torch_dtype(model_dtype) and hf_runner(
with set_default_torch_dtype(vllm_dtype) and hf_runner(
model_info.name, is_sentence_transformer=True,
dtype=model_dtype) as hf_model:
dtype=vllm_dtype) as hf_model:

if hf_model_callback is not None:
hf_model_callback(hf_model)

st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)

print("VLLM:", vllm_dtype, vllm_main_score)
print("SentenceTransformer:", model_dtype, st_main_score)
print("VLLM:", vllm_main_score)
print("SentenceTransformers:", st_main_score)
print("Difference:", st_main_score - vllm_main_score)

assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
2 changes: 1 addition & 1 deletion tests/models/language/pooling/test_classification.py
@@ -43,6 +43,6 @@ def test_models(

# the tolerance value of 1e-2 is selected based on the
# half datatype tests in
# tests/models/embedding/language/test_embedding.py
# tests/models/language/pooling/test_embedding.py
assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)
6 changes: 1 addition & 5 deletions tests/models/language/pooling/test_embedding.py
@@ -30,13 +30,11 @@
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model,
dtype: str,
monkeypatch,
) -> None:

@@ -58,13 +56,11 @@ def test_models(
# So we need to strip the input texts to avoid test failing.
example_prompts = [str(s).strip() for s in example_prompts]

with hf_runner(model, dtype=dtype,
is_sentence_transformer=True) as hf_model:
with hf_runner(model, is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)

with vllm_runner(model,
task="embed",
dtype=dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
1 change: 1 addition & 0 deletions tests/models/multimodal/generation/test_whisper.py
@@ -100,6 +100,7 @@ def run_test(

with vllm_runner(
model,
dtype="half",
max_model_len=448,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
2 changes: 1 addition & 1 deletion tests/models/multimodal/processing/test_common.py
@@ -40,7 +40,7 @@ def _test_processing_correctness(
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
seed=0,
dtype="float16",
dtype="auto",
revision=None,
hf_overrides=model_info.hf_overrides,
)
2 changes: 1 addition & 1 deletion tests/samplers/test_no_bad_words.py
@@ -103,7 +103,7 @@ def setup_method(self, method):
add_special_tokens=False)[0]

def test_two_token_bad_word(self, vllm_runner):
with vllm_runner(self.MODEL) as llm:
with vllm_runner(self.MODEL, dtype="half") as llm:
output_token_ids = self._generate(llm)
assert output_token_ids[:2] == [
self.target_token_id1, self.target_token_id2
102 changes: 80 additions & 22 deletions tests/test_utils.py
@@ -17,7 +17,8 @@
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port,
bind_kv_cache, common_broadcastable_dtype,
deprecate_kwargs, get_open_port, is_lossless_cast,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)
@@ -567,12 +568,65 @@ def test_lru_cache():
assert 6 in cache


# yapf: disable
@pytest.mark.parametrize(
("src_dtype", "tgt_dtype", "expected_result"),
[
# Different precision_levels
(torch.bool, torch.int8, True),
(torch.bool, torch.float16, True),
(torch.bool, torch.complex32, True),
(torch.int64, torch.bool, False),
(torch.int64, torch.float16, True),
(torch.int64, torch.complex32, True),
(torch.float64, torch.bool, False),
(torch.float64, torch.int8, False),
(torch.float64, torch.complex32, True),
(torch.complex128, torch.bool, False),
(torch.complex128, torch.int8, False),
(torch.complex128, torch.float16, False),
# precision_level=0
(torch.bool, torch.bool, True),
# precision_level=1
(torch.int8, torch.int16, True),
(torch.int16, torch.int8, False),
(torch.uint8, torch.int8, False),
(torch.int8, torch.uint8, False),
# precision_level=2
(torch.float16, torch.float32, True),
(torch.float32, torch.float16, False),
(torch.bfloat16, torch.float32, True),
(torch.float32, torch.bfloat16, False),
# precision_level=3
(torch.complex32, torch.complex64, True),
(torch.complex64, torch.complex32, False),
],
)
# yapf: enable
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result


# yapf: disable
@pytest.mark.parametrize(
("dtypes", "expected_result"),
[
([torch.bool], torch.bool),
([torch.bool, torch.int8], torch.int8),
([torch.bool, torch.int8, torch.float16], torch.float16),
([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501
],
)
# yapf: enable
def test_common_broadcastable_dtype(dtypes, expected_result):
assert common_broadcastable_dtype(dtypes) == expected_result
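The parametrizations above pin down the intended semantics: dtypes are grouped into precision levels (bool < integer < floating point < complex), casting up a level always counts as lossless, casting down never does, and within a level the target must cover the source's range and resolution. Below is a minimal sketch consistent with those test cases; the helper names match the imports added in this file, but the `_precision_level` grouping and the iinfo/finfo checks are illustrative assumptions, not vLLM's actual implementation.

```python
# Sketch consistent with the tests above; not the actual vLLM implementation.
import torch

_COMPLEX_TO_FLOAT = {
    torch.complex32: torch.float16,
    torch.complex64: torch.float32,
    torch.complex128: torch.float64,
}


def _precision_level(dtype: torch.dtype) -> int:
    # Assumed grouping: bool < integer < floating point < complex
    if dtype == torch.bool:
        return 0
    if dtype.is_complex:
        return 3
    return 2 if dtype.is_floating_point else 1


def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype) -> bool:
    src_level = _precision_level(src_dtype)
    tgt_level = _precision_level(tgt_dtype)
    # Casting up the hierarchy is treated as lossless; casting down is not.
    if src_level != tgt_level:
        return src_level < tgt_level
    if src_dtype == torch.bool:
        return True
    if src_level == 1:  # integers: the target range must cover the source range
        src_info, tgt_info = torch.iinfo(src_dtype), torch.iinfo(tgt_dtype)
        return tgt_info.min <= src_info.min and src_info.max <= tgt_info.max
    # floats and complex: compare the (component) floating-point formats
    src_info = torch.finfo(_COMPLEX_TO_FLOAT.get(src_dtype, src_dtype))
    tgt_info = torch.finfo(_COMPLEX_TO_FLOAT.get(tgt_dtype, tgt_dtype))
    return (tgt_info.min <= src_info.min and src_info.max <= tgt_info.max
            and tgt_info.resolution <= src_info.resolution)


def common_broadcastable_dtype(dtypes):
    # Pick the dtype in the list that the most other dtypes cast to losslessly.
    return max(dtypes,
               key=lambda tgt: sum(is_lossless_cast(src, tgt) for src in dtypes))
```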


def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")

def build_ctx():
return pytest.raises(ModuleNotFoundError,
match="No module named")
return pytest.raises(ModuleNotFoundError, match="No module named")

with build_ctx():
int(placeholder)
@@ -608,6 +662,7 @@ def build_ctx():
_ = placeholder_attr.module


# yapf: disable
@pytest.mark.parametrize(
"obj,key1,key2",
[
@@ -618,6 +673,7 @@ def build_ctx():
# Tests for both keys do not exist
({1: "a", 2: "b"}, 3, 4),
])
# yapf: enable
def test_swap_dict_values(obj, key1, key2):
original_obj = obj.copy()
swap_dict_values(obj, key1, key2)
@@ -631,19 +687,19 @@ def test_swap_dict_values(obj, key1, key2):
assert key1 not in obj


def test_model_specification(parser_with_config,
cli_config_file,
def test_model_specification(parser_with_config, cli_config_file,
cli_config_file_with_model):
# Test model in CLI takes precedence over config
args = parser_with_config.parse_args([
'serve', 'cli-model', '--config', cli_config_file_with_model
])
args = parser_with_config.parse_args(
['serve', 'cli-model', '--config', cli_config_file_with_model])
assert args.model_tag == 'cli-model'
assert args.served_model_name == 'mymodel'

# Test model from config file works
args = parser_with_config.parse_args([
'serve', '--config', cli_config_file_with_model,
'serve',
'--config',
cli_config_file_with_model,
])
assert args.model == 'config-model'
assert args.served_model_name == 'mymodel'
@@ -654,17 +710,19 @@ def test_model_specification(parser_with_config,

# Test using --model option raises error
with pytest.raises(
ValueError,
match=(
"With `vllm serve`, you should provide the model as a positional "
"argument or in a config file instead of via the `--model` option."
),
ValueError,
match=
("With `vllm serve`, you should provide the model as a positional "
"argument or in a config file instead of via the `--model` option."),
):
parser_with_config.parse_args(['serve', '--model', 'my-model'])

# Test other config values are preserved
args = parser_with_config.parse_args([
'serve', 'cli-model', '--config', cli_config_file_with_model,
'serve',
'cli-model',
'--config',
cli_config_file_with_model,
])
assert args.tensor_parallel_size == 2
assert args.trust_remote_code is True
@@ -673,7 +731,7 @@


@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
(None, bool, [1, 2, 3])])
(None, bool, [1, 2, 3])])
@pytest.mark.parametrize("output", [0, 1, 2])
def test_sha256(input: tuple, output: int):
hash = sha256(input)
@@ -682,7 +740,8 @@ def test_sha256(input: tuple, output: int):
assert hash != 0

bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big")
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
byteorder="big")

# hashing again, returns the same value
assert hash == sha256(input)
@@ -698,8 +757,7 @@ def test_sha256(input: tuple, output: int):
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
("inproc://some_identifier", ("inproc", "some_identifier", "")),
]
)
])
def test_split_zmq_path(path, expected):
assert split_zmq_path(path) == expected

@@ -711,8 +769,7 @@ def test_split_zmq_path(path, expected):
"tcp://127.0.0.1", # Missing port
"tcp://[::1]", # Missing port for IPv6
"tcp://:5555", # Missing host
]
)
])
def test_split_zmq_path_invalid(invalid_path):
with pytest.raises(ValueError):
split_zmq_path(invalid_path)
@@ -734,7 +791,8 @@ def test_make_zmq_socket_ipv6():
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)

# Verify that the IPV6 option is set
assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
assert zsock.getsockopt(
zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"

# Clean up
zsock.close()