diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 86b5e1e0ab7c..11c8e7a4b9d1 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -60,7 +60,6 @@ def _fix_prompt_embed_outputs( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) -@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) @@ -69,7 +68,6 @@ def test_models( hf_runner, model: str, backend: str, - dtype: str, max_tokens: int, enforce_eager: bool, enable_prompt_embeds: bool, @@ -97,7 +95,7 @@ def test_models( str(i) for i in range(1024)) + " are:" example_prompts = [prompt] - with hf_runner(model, dtype=dtype) as hf_model: + with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) if enable_prompt_embeds: with torch.no_grad(): @@ -106,7 +104,6 @@ def test_models( with VllmRunner(model, max_model_len=8192, - dtype=dtype, enforce_eager=enforce_eager, enable_prompt_embeds=enable_prompt_embeds, gpu_memory_utilization=0.7) as vllm_model: diff --git a/tests/conftest.py b/tests/conftest.py index 26674483f7ae..6336c6c2ce01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -324,7 +324,12 @@ def __init__( trust_remote_code=trust_remote_code, ) self.device = self.get_default_device() - self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype) + self.dtype = torch_dtype = _get_and_verify_dtype( + self.model_name, + self.config, + dtype=dtype, + is_pooling_model=is_sentence_transformer or is_cross_encoder, + ) model_kwargs = model_kwargs if model_kwargs is not None else {} model_kwargs.setdefault("torch_dtype", torch_dtype) diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index f4837ae952c3..f45168bc0f1d 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -102,21 +102,18 @@ def mteb_test_embed_models(hf_runner, vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS) vllm_dtype = vllm_model.model.llm_engine.model_config.dtype - model_dtype = getattr( - vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype", - vllm_dtype) - with set_default_torch_dtype(model_dtype) and hf_runner( + with set_default_torch_dtype(vllm_dtype) and hf_runner( model_info.name, is_sentence_transformer=True, - dtype=model_dtype) as hf_model: + dtype=vllm_dtype) as hf_model: if hf_model_callback is not None: hf_model_callback(hf_model) st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) - print("VLLM:", vllm_dtype, vllm_main_score) - print("SentenceTransformer:", model_dtype, st_main_score) + print("VLLM:", vllm_main_score) + print("SentenceTransformers:", st_main_score) print("Difference:", st_main_score - vllm_main_score) assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL) diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index 44af3df08a86..57b3cb58d88b 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -43,6 +43,6 @@ def test_models( # the tolerance value of 1e-2 is selected based on the # half datatype tests in - # tests/models/embedding/language/test_embedding.py + # tests/models/language/pooling/test_embedding.py assert torch.allclose(hf_output, vllm_output, 1e-3 if dtype == "float" else 1e-2) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 306cfdf37707..8f82c8091af3 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -30,13 +30,11 @@ pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) -@pytest.mark.parametrize("dtype", ["half"]) def test_models( hf_runner, vllm_runner, example_prompts, model, - dtype: str, monkeypatch, ) -> None: @@ -58,13 +56,11 @@ def test_models( # So we need to strip the input texts to avoid test failing. example_prompts = [str(s).strip() for s in example_prompts] - with hf_runner(model, dtype=dtype, - is_sentence_transformer=True) as hf_model: + with hf_runner(model, is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) with vllm_runner(model, task="embed", - dtype=dtype, max_model_len=None, **vllm_extra_kwargs) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 4e48bdbd0428..d0b85842a3d8 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -100,6 +100,7 @@ def run_test( with vllm_runner( model, + dtype="half", max_model_len=448, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 572fa366d332..d7f950c23d95 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -40,7 +40,7 @@ def _test_processing_correctness( tokenizer_mode=model_info.tokenizer_mode, trust_remote_code=model_info.trust_remote_code, seed=0, - dtype="float16", + dtype="auto", revision=None, hf_overrides=model_info.hf_overrides, ) diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py index 355e3adcf5f3..f9688b4b9b27 100644 --- a/tests/samplers/test_no_bad_words.py +++ b/tests/samplers/test_no_bad_words.py @@ -103,7 +103,7 @@ def setup_method(self, method): add_special_tokens=False)[0] def test_two_token_bad_word(self, vllm_runner): - with vllm_runner(self.MODEL) as llm: + with vllm_runner(self.MODEL, dtype="half") as llm: output_token_ids = self._generate(llm) assert output_token_ids[:2] == [ self.target_token_id1, self.target_token_id2 diff --git a/tests/test_utils.py b/tests/test_utils.py index 0b88d05efeaa..dd8777f06888 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,7 +17,8 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, PlaceholderModule, StoreBoolean, - bind_kv_cache, deprecate_kwargs, get_open_port, + bind_kv_cache, common_broadcastable_dtype, + deprecate_kwargs, get_open_port, is_lossless_cast, make_zmq_path, make_zmq_socket, memory_profiling, merge_async_iterators, sha256, split_zmq_path, supports_kw, swap_dict_values) @@ -567,12 +568,65 @@ def test_lru_cache(): assert 6 in cache +# yapf: disable +@pytest.mark.parametrize( + ("src_dtype", "tgt_dtype", "expected_result"), + [ + # Different precision_levels + (torch.bool, torch.int8, True), + (torch.bool, torch.float16, True), + (torch.bool, torch.complex32, True), + (torch.int64, torch.bool, False), + (torch.int64, torch.float16, True), + (torch.int64, torch.complex32, True), + (torch.float64, torch.bool, False), + (torch.float64, torch.int8, False), + (torch.float64, torch.complex32, True), + (torch.complex128, torch.bool, False), + (torch.complex128, torch.int8, False), + (torch.complex128, torch.float16, False), + # precision_level=0 + (torch.bool, torch.bool, True), + # precision_level=1 + (torch.int8, torch.int16, True), + (torch.int16, torch.int8, False), + (torch.uint8, torch.int8, False), + (torch.int8, torch.uint8, False), + # precision_level=2 + (torch.float16, torch.float32, True), + (torch.float32, torch.float16, False), + (torch.bfloat16, torch.float32, True), + (torch.float32, torch.bfloat16, False), + # precision_level=3 + (torch.complex32, torch.complex64, True), + (torch.complex64, torch.complex32, False), + ], +) +# yapf: enable +def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result): + assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result + + +# yapf: disable +@pytest.mark.parametrize( + ("dtypes", "expected_result"), + [ + ([torch.bool], torch.bool), + ([torch.bool, torch.int8], torch.int8), + ([torch.bool, torch.int8, torch.float16], torch.float16), + ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501 + ], +) +# yapf: enable +def test_common_broadcastable_dtype(dtypes, expected_result): + assert common_broadcastable_dtype(dtypes) == expected_result + + def test_placeholder_module_error_handling(): placeholder = PlaceholderModule("placeholder_1234") def build_ctx(): - return pytest.raises(ModuleNotFoundError, - match="No module named") + return pytest.raises(ModuleNotFoundError, match="No module named") with build_ctx(): int(placeholder) @@ -608,6 +662,7 @@ def build_ctx(): _ = placeholder_attr.module +# yapf: disable @pytest.mark.parametrize( "obj,key1,key2", [ @@ -618,6 +673,7 @@ def build_ctx(): # Tests for both keys do not exist ({1: "a", 2: "b"}, 3, 4), ]) +# yapf: enable def test_swap_dict_values(obj, key1, key2): original_obj = obj.copy() swap_dict_values(obj, key1, key2) @@ -631,19 +687,19 @@ def test_swap_dict_values(obj, key1, key2): assert key1 not in obj -def test_model_specification(parser_with_config, - cli_config_file, +def test_model_specification(parser_with_config, cli_config_file, cli_config_file_with_model): # Test model in CLI takes precedence over config - args = parser_with_config.parse_args([ - 'serve', 'cli-model', '--config', cli_config_file_with_model - ]) + args = parser_with_config.parse_args( + ['serve', 'cli-model', '--config', cli_config_file_with_model]) assert args.model_tag == 'cli-model' assert args.served_model_name == 'mymodel' # Test model from config file works args = parser_with_config.parse_args([ - 'serve', '--config', cli_config_file_with_model, + 'serve', + '--config', + cli_config_file_with_model, ]) assert args.model == 'config-model' assert args.served_model_name == 'mymodel' @@ -654,17 +710,19 @@ def test_model_specification(parser_with_config, # Test using --model option raises error with pytest.raises( - ValueError, - match=( - "With `vllm serve`, you should provide the model as a positional " - "argument or in a config file instead of via the `--model` option." - ), + ValueError, + match= + ("With `vllm serve`, you should provide the model as a positional " + "argument or in a config file instead of via the `--model` option."), ): parser_with_config.parse_args(['serve', '--model', 'my-model']) # Test other config values are preserved args = parser_with_config.parse_args([ - 'serve', 'cli-model', '--config', cli_config_file_with_model, + 'serve', + 'cli-model', + '--config', + cli_config_file_with_model, ]) assert args.tensor_parallel_size == 2 assert args.trust_remote_code is True @@ -673,7 +731,7 @@ def test_model_specification(parser_with_config, @pytest.mark.parametrize("input", [(), ("abc", ), (None, ), - (None, bool, [1, 2, 3])]) + (None, bool, [1, 2, 3])]) @pytest.mark.parametrize("output", [0, 1, 2]) def test_sha256(input: tuple, output: int): hash = sha256(input) @@ -682,7 +740,8 @@ def test_sha256(input: tuple, output: int): assert hash != 0 bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big") + assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), + byteorder="big") # hashing again, returns the same value assert hash == sha256(input) @@ -698,8 +757,7 @@ def test_sha256(input: tuple, output: int): ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")), ("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address ("inproc://some_identifier", ("inproc", "some_identifier", "")), - ] -) + ]) def test_split_zmq_path(path, expected): assert split_zmq_path(path) == expected @@ -711,8 +769,7 @@ def test_split_zmq_path(path, expected): "tcp://127.0.0.1", # Missing port "tcp://[::1]", # Missing port for IPv6 "tcp://:5555", # Missing host - ] -) + ]) def test_split_zmq_path_invalid(invalid_path): with pytest.raises(ValueError): split_zmq_path(invalid_path) @@ -734,7 +791,8 @@ def test_make_zmq_socket_ipv6(): zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type) # Verify that the IPV6 option is set - assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses" + assert zsock.getsockopt( + zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses" # Clean up zsock.close() diff --git a/vllm/config.py b/vllm/config.py index 6cec97a5f11b..5776fc5f3531 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -24,6 +24,7 @@ from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator, model_validator) from pydantic.dataclasses import dataclass +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig from typing_extensions import deprecated, runtime_checkable @@ -42,15 +43,16 @@ ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, - try_get_generation_config, uses_mrope) + try_get_generation_config, try_get_safetensors_metadata, uses_mrope) from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes, - LayerBlockType, cuda_device_count_stateless, - get_cpu_memory, get_open_port, is_torch_equal_or_newer, - random_uuid, resolve_obj_by_qualname) + LayerBlockType, common_broadcastable_dtype, + cuda_device_count_stateless, get_cpu_memory, + get_open_port, is_torch_equal_or_newer, random_uuid, + resolve_obj_by_qualname) if TYPE_CHECKING: from _typeshed import DataclassInstance @@ -540,7 +542,24 @@ def __post_init__(self) -> None: self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision) - self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype) + + supported_tasks, task = self._resolve_task(self.task) + self.supported_tasks = supported_tasks + self.task = task + if self.task in ("draft", "generate"): + self.truncation_side = "left" + else: + self.truncation_side = "right" + + self.pooler_config = self._init_pooler_config() + + self.dtype = _get_and_verify_dtype( + self.model, + self.hf_config, + self.dtype, + is_pooling_model=self.runner_type == "pooling", + revision=self.revision, + ) # Workaround for Gemma 2 which uses interleaved sliding window # attention, but it's not specified in its config. TODO: remove this @@ -597,16 +616,6 @@ def __post_init__(self) -> None: raise ValueError( "`override_neuron_config` is only supported on Neuron.") - supported_tasks, task = self._resolve_task(self.task) - self.supported_tasks = supported_tasks - self.task = task - if self.task in ("draft", "generate"): - self.truncation_side = "left" - else: - self.truncation_side = "right" - - self.pooler_config = self._init_pooler_config() - self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() @@ -692,7 +701,6 @@ def _get_encoder_config(self): self.model, self.revision) def _init_pooler_config(self) -> Optional["PoolerConfig"]: - if self.runner_type == "pooling": if isinstance(self.override_pooler_config, dict): self.override_pooler_config = PoolerConfig( @@ -3064,13 +3072,37 @@ def compute_hash(self) -> str: "bfloat16": torch.bfloat16, } -_ROCM_NOT_SUPPORTED_DTYPE: list[str] = [] # +# model_type -> reason +_FLOAT16_NOT_SUPPORTED_MODELS = { + "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", + "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", + "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", + "glm4": "Numerical instability. Please use bfloat16 or float32 instead.", +} -def _get_and_verify_dtype( +def _is_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103 + return False + + return True + + +def _check_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: + reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type] + raise ValueError(f"The model type {model_type!r} " + f"does not support float16. Reason: {reason}") + + return True + + +def _find_dtype( + model_id: str, config: PretrainedConfig, - dtype: Union[str, torch.dtype], -) -> torch.dtype: + *, + revision: Optional[str], +): # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. config_dtype = getattr(config, "torch_dtype", None) @@ -3082,75 +3114,111 @@ def _get_and_verify_dtype( if config_dtype is None and hasattr(config, "vision_config"): config_dtype = getattr(config.vision_config, "torch_dtype", None) + # Try to read the dtype of the weights if they are in safetensors format + if config_dtype is None: + repo_mt = try_get_safetensors_metadata(model_id, revision=revision) + + if repo_mt and (files_mt := repo_mt.files_metadata): + param_dtypes: set[torch.dtype] = { + _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] + for file_mt in files_mt.values() + for dtype_str in file_mt.parameter_count + if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE + } + + if param_dtypes: + return common_broadcastable_dtype(param_dtypes) + if config_dtype is None: config_dtype = torch.float32 - if isinstance(dtype, str): - dtype = dtype.lower() - if dtype == "auto": - # Set default dtype from model config - if config_dtype == torch.float32: - # Following common practice, we use float16 for float32 models - torch_dtype = torch.float16 - else: - torch_dtype = config_dtype + return config_dtype - if config.model_type == "plamo2": - logger.warning( - "For PLaMo2, we cast models to bfloat16 instead of using " - "float16 by default. This is because float16 does not work." - ) - torch_dtype = torch.bfloat16 - # Deal with torch dtype fallback for device compatibility. - from vllm.platforms import current_platform - if torch_dtype not in current_platform.supported_dtypes: - device_name = current_platform.get_device_name() +def _resolve_auto_dtype( + model_type: str, + config_dtype: torch.dtype, + *, + is_pooling_model: bool, +): + from vllm.platforms import current_platform - if ((capability := current_platform.get_device_capability()) - is None): - compute_str = "" - else: - version_str = capability.as_version_str() - compute_str = f" (with compute capability {version_str})" - fallback_dtype = current_platform.supported_dtypes[0] - logger.warning( - "Your %s device%s doesn't support %s. " \ - "Falling back to %s for compatibility.", - device_name, compute_str, torch_dtype, fallback_dtype - ) - torch_dtype = fallback_dtype + supported_dtypes = [ + dtype for dtype in current_platform.supported_dtypes + if _is_valid_dtype(model_type, dtype) + ] - if current_platform.is_hpu() and torch_dtype == torch.float16: - logger.warning( - "For HPU, we cast models to bfloat16 instead of " - "using float16 by default. Please specify `dtype` if you " - "want to use float16.") - torch_dtype = torch.bfloat16 - elif dtype == "float16" and config.model_type == "plamo2": - logger.warning( - "For PLaMo2, using float16 is unstable and might cause " - "unexpected behavior. Please use bfloat16 or float32 instead.") - torch_dtype = torch.float16 + if is_pooling_model and torch.float16 in supported_dtypes: + preferred_dtype = torch.float16 + else: + preferred_dtype = supported_dtypes[0] + + # Downcast for float32 models + if config_dtype == torch.float32: + config_dtype = preferred_dtype + + if config_dtype in supported_dtypes: + return config_dtype + + # Ensure device compatibility + device_name = current_platform.get_device_name() + device_capability = current_platform.get_device_capability() + + if device_capability is None: + device_str = f"{device_name!r}" + else: + version_str = device_capability.as_version_str() + device_str = f"{device_name!r} (with compute capability {version_str})" + + logger.warning( + "Your device %s doesn't support %s. " + "Falling back to %s for compatibility.", + device_str, + config_dtype, + preferred_dtype, + ) + + return preferred_dtype + + +def _get_and_verify_dtype( + model_id: str, + config: PretrainedConfig, + dtype: Union[str, torch.dtype], + *, + is_pooling_model: bool, + revision: Optional[str] = None, +) -> torch.dtype: + config_dtype = _find_dtype(model_id, config, revision=revision) + model_type = config.model_type + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + # Set default dtype from model config + torch_dtype = _resolve_auto_dtype( + model_type, + config_dtype, + is_pooling_model=is_pooling_model, + ) else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: - raise ValueError(f"Unknown dtype: {dtype}") + raise ValueError(f"Unknown dtype: {dtype!r}") torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] elif isinstance(dtype, torch.dtype): torch_dtype = dtype else: raise ValueError(f"Unknown dtype: {dtype}") - # Verify the dtype. + _check_valid_dtype(model_type, torch_dtype) + if torch_dtype != config_dtype: if torch_dtype == torch.float32: # Upcasting to float32 is allowed. logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) - pass elif config_dtype == torch.float32: # Downcasting from float32 to float16 or bfloat16 is allowed. logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) - pass else: # Casting between float16 and bfloat16 is allowed with a warning. logger.warning("Casting %s to %s.", config_dtype, torch_dtype) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index c79c603c02eb..eaffaac78cce 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -28,7 +28,7 @@ class CpuPlatform(Platform): dispatch_key: str = "CPU" @property - def supported_dtypes(self) -> list: + def supported_dtypes(self) -> list[torch.dtype]: if self.get_cpu_architecture() == CpuArchEnum.POWERPC: return [torch.bfloat16, torch.float32] elif sys.platform.startswith( diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 2ed71a4d334b..8774f95a2f60 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -4,12 +4,12 @@ import json import os import time -from functools import cache +from functools import cache, partial from pathlib import Path -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, TypeVar, Union import huggingface_hub -from huggingface_hub import hf_hub_download +from huggingface_hub import get_safetensors_metadata, hf_hub_download from huggingface_hub import list_repo_files as hf_list_repo_files from huggingface_hub import try_to_load_from_cache from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, @@ -93,10 +93,15 @@ class ConfigFormat(str, enum.Enum): MISTRAL = "mistral" -def with_retry(func: Callable[[], Any], - log_msg: str, - max_retries: int = 2, - retry_delay: int = 2): +_R = TypeVar("_R") + + +def with_retry( + func: Callable[[], _R], + log_msg: str, + max_retries: int = 2, + retry_delay: int = 2, +) -> _R: for attempt in range(max_retries): try: return func() @@ -109,6 +114,8 @@ def with_retry(func: Callable[[], Any], time.sleep(retry_delay) retry_delay *= 2 + raise AssertionError("Should not be reached") + # @cache doesn't cache exceptions @cache @@ -840,3 +847,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): return resolve_obj_by_qualname(function_name)() else: return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() + + +def try_get_safetensors_metadata( + model: str, + *, + revision: Optional[str] = None, +): + get_safetensors_metadata_partial = partial( + get_safetensors_metadata, + model, + revision=revision, + token=os.getenv('HF_TOKEN', None), + ) + + try: + return with_retry(get_safetensors_metadata_partial, + "Error retrieving safetensors") + except Exception: + return None diff --git a/vllm/utils.py b/vllm/utils.py index c1213d463c21..3b6acaa37ae0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -37,8 +37,8 @@ _ArgumentGroup) from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict -from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, - Iterable, Iterator, KeysView, Mapping) +from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator, + Hashable, Iterable, Iterator, KeysView, Mapping) from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps @@ -979,6 +979,53 @@ def get_dtype_size(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() +# bool = 0, int = 1, float = 2, complex = 3 +def _get_precision_level(dtype: torch.dtype) -> int: + # NOTE: Complex dtypes return `is_floating_point=False` + return ((dtype != torch.bool) + dtype.is_floating_point + + dtype.is_complex * 2) + + +def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): + """ + Test whether it is lossless to cast a tensor from + `src_dtype` to `tgt_dtype`. + """ + if src_dtype == tgt_dtype: + return True + + src_level = _get_precision_level(src_dtype) + tgt_level = _get_precision_level(tgt_dtype) + + if src_level < tgt_level: + return True + if src_level > tgt_level: + return False + + # Compare integral types + if not src_dtype.is_floating_point and not src_dtype.is_complex: + src_info = torch.iinfo(src_dtype) + tgt_info = torch.iinfo(tgt_dtype) + return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max + + # Compare floating-point types + src_info = torch.finfo(src_dtype) + tgt_info = torch.finfo(tgt_dtype) + return (src_info.min >= tgt_info.min and src_info.max <= tgt_info.max + and src_info.resolution >= tgt_info.resolution) + + +def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): + """ + Get the common `dtype` where all of the other `dtypes` can be + cast to it without losing any information. + """ + return max( + dtypes, + key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), + ) + + # `collections` helpers def is_list_of( value: object,