diff --git a/docs/user_guide/supported_models.md b/docs/user_guide/supported_models.md
index 13bdaaa91..cf7f4655c 100644
--- a/docs/user_guide/supported_models.md
+++ b/docs/user_guide/supported_models.md
@@ -38,7 +38,7 @@ configurations.
 | [Multilingual-E5-large][] | 1 | 512 | 64 |
 
 [Granite-3.3-8b]: https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
-[Granite-3.3-8b (FP8)]: https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
+[Granite-3.3-8b (FP8)]: https://huggingface.co/ibm-granite/granite-3.3-8b-instruct-FP8
 [Granite-Embedding-125m (English)]: https://huggingface.co/ibm-granite/granite-embedding-125m-english
 [Granite-Embedding-278m (Multilingual)]: https://huggingface.co/ibm-granite/granite-embedding-278m-multilingual
 [BAAI/BGE-Reranker (v2-m3)]: https://huggingface.co/BAAI/bge-reranker-v2-m3
diff --git a/tests/download_model_configs.py b/tests/download_model_configs.py
index a90d140cd..5973a95af 100755
--- a/tests/download_model_configs.py
+++ b/tests/download_model_configs.py
@@ -42,9 +42,6 @@ def download_model_config_from_hf(hf_model_id: str, revision: str = "main"):
 if __name__ == '__main__':
     model_ids = get_supported_models_list()
     for model_id in model_ids:
-        # TODO: get the actual FP8 model config
-        if "-FP8" in model_id:
-            continue
         config = download_hf_model_config(model_id)
         # download_model_config_from_hf(model_id)
         print(f"model_id: {model_id}")
diff --git a/tests/fixtures/model_configs/ibm-granite/granite-3.3-8b-instruct-FP8/config.json b/tests/fixtures/model_configs/ibm-granite/granite-3.3-8b-instruct-FP8/config.json
new file mode 100644
index 000000000..021a8e78c
--- /dev/null
+++ b/tests/fixtures/model_configs/ibm-granite/granite-3.3-8b-instruct-FP8/config.json
@@ -0,0 +1,74 @@
+{
+  "architectures": [
+    "GraniteForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "attention_multiplier": 0.0078125,
+  "bos_token_id": 0,
+  "embedding_multiplier": 12.0,
+  "eos_token_id": 0,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 12800,
+  "logits_scaling": 16.0,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "granite",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 40,
+  "num_key_value_heads": 8,
+  "pad_token_id": 0,
+  "quantization_config": {
+    "config_groups": {
+      "group_0": {
+        "input_activations": {
+          "actorder": null,
+          "block_structure": null,
+          "dynamic": true,
+          "group_size": null,
+          "num_bits": 8,
+          "observer": null,
+          "observer_kwargs": {},
+          "strategy": "token",
+          "symmetric": true,
+          "type": "float"
+        },
+        "output_activations": null,
+        "targets": [
+          "Linear"
+        ],
+        "weights": {
+          "actorder": null,
+          "block_structure": null,
+          "dynamic": false,
+          "group_size": null,
+          "num_bits": 8,
+          "observer": "minmax",
+          "observer_kwargs": {},
+          "strategy": "channel",
+          "symmetric": true,
+          "type": "float"
+        }
+      }
+    },
+    "format": "float-quantized",
+    "global_compression_ratio": null,
+    "ignore": [
+      "lm_head"
+    ],
+    "kv_cache_scheme": null,
+    "quant_method": "compressed-tensors",
+    "quantization_status": "compressed"
+  },
+  "residual_multiplier": 0.22,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000000.0,
+  "tie_word_embeddings": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.55.2",
+  "use_cache": true,
+  "vocab_size": 49159
+}
diff --git a/tests/spyre_util.py b/tests/spyre_util.py
index 788c5670f..a48c3af28 100644
--- a/tests/spyre_util.py
+++ b/tests/spyre_util.py
@@ -284,12 +284,17 @@ def _default_test_models(isEmbeddings=False,
         ]
         return params
 
-    # Full sized decoders
-    # The granite 8b fp8 model is not publicly available yet
+    # Full-size decoders
     granite = ModelInfo(name="ibm-granite/granite-3.3-8b-instruct",
                         revision="51dd4bc2ade4059a6bd87649d68aa11e4fb2529b")
+    granite_fp8 = ModelInfo(
+        name="ibm-granite/granite-3.3-8b-instruct-FP8",
+        revision="4b5990b8d402a75febe0086abbf1e490af494e3d")
     params = [
         pytest.param(granite, marks=[pytest.mark.decoder], id=granite.name),
+        pytest.param(granite_fp8,
+                     marks=[pytest.mark.decoder, pytest.mark.quantized],
+                     id=granite_fp8.name)
     ]
     return params
 
diff --git a/tests/utils/test_model_config_validator.py b/tests/utils/test_model_config_validator.py
index f86d8f0d6..bc201ad20 100644
--- a/tests/utils/test_model_config_validator.py
+++ b/tests/utils/test_model_config_validator.py
@@ -186,10 +186,6 @@
 def test_find_model_by_config(monkeypatch, caplog):
 
     for model_id in get_supported_models_list():
-        # TODO: get the actual FP8 model config
-        if "-FP8" in model_id:
-            continue
-
         model_config_dir = model_configs_dir / model_id
         model_config_file = model_config_dir / "config.json"
 
@@ -198,7 +194,7 @@
             f" Use download_model_configs.py to download it.")
 
         if env.get("HF_HUB_OFFLINE", "0") == "0":
-            # it takes about 3 sec per model to load config from HF:
+            # it takes up to 3 sec per model to load config from HF:
             # vllm.config.ModelConfig.__post_init__():
             # model_info, arch = self.registry.inspect_model_cls(...)
             model_config = ModelConfig(model=str(model_config_dir))
@@ -212,10 +208,15 @@
             assert model_config.model != model_id
 
             models_found = find_known_models_by_model_config(model_config)
-            assert len(models_found) == 1, \
-                (f"More than one model found. Need to add more distinguishing"
+            assert len(models_found) > 0, \
+                (f"Could not find any known models that match the ModelConfig"
+                 f" for model `{model_id}`. Update the entry for `{model_id}`"
+                 f" in `vllm_spyre/config/known_model_configs.json` so that its"
+                 f" parameters are a subset of those in `{model_config_file}`.")
+            assert len(models_found) < 2, \
+                (f"More than one model found. Add more distinguishing"
                  f" parameters for models `{models_found}` in file"
-                 f" `vllm_spyre/config/known_model_configs.json`")
+                 f" `vllm_spyre/config/known_model_configs.json`!")
             assert models_found[0] == model_id
 
             validate(model_config)
diff --git a/vllm_spyre/config/known_model_configs.json b/vllm_spyre/config/known_model_configs.json
index 5feb84ee6..7400d1f8d 100644
--- a/vllm_spyre/config/known_model_configs.json
+++ b/vllm_spyre/config/known_model_configs.json
@@ -14,13 +14,11 @@
     },
     "ibm-granite/granite-3.3-8b-instruct": {
         "model_type": "granite",
-        "attention_dropout": 0.0,
         "vocab_size": 49159
     },
     "ibm-granite/granite-3.3-8b-instruct-FP8": {
        "model_type": "granite",
-        "attention_dropout": 0.1,
-        "vocab_size": 49155,
+        "vocab_size": 49159,
         "quantization_config": {
             "format": "float-quantized"
         }
diff --git a/vllm_spyre/config/runtime_config_validator.py b/vllm_spyre/config/runtime_config_validator.py
index cd8288d56..6645712f4 100644
--- a/vllm_spyre/config/runtime_config_validator.py
+++ b/vllm_spyre/config/runtime_config_validator.py
@@ -4,6 +4,7 @@
 from typing import Any
 
 import yaml
+from pandas.io.json._normalize import nested_to_record as flatten
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
 
@@ -137,16 +138,31 @@ def is_power_of_2(n: int) -> bool:
 
 
 def find_known_models_by_model_config(model_config: ModelConfig) -> list[str]:
+    """
+    Try to find a supported model by comparing the requested model config to
+    the known model configurations. The known model configurations file
+    contains only a minimal subset of model config parameters, just enough
+    to distinguish between the supported models.
+    """
     if known_model_configs is None:
         initialize_known_model_configurations_from_file()
 
     requested_config = model_config.hf_config.__dict__ \
         if model_config.hf_config else {}
 
+    # remove sub-dicts with integer keys so the dictionaries can be flattened
+    requested_config.pop("id2label", None)
+
+    # don't match quantized models unless the requested config is quantized
+    def is_quantized(config: dict) -> bool:
+        return "quantization_config" in config
+
     matching_models = [
         model for model, config in (known_model_configs or {}).items()
-        if config.items() <= requested_config.items()
+        if flatten(config).items() <= flatten(requested_config).items() and (
+            is_quantized(config) == is_quantized(requested_config))
     ]
+
     return matching_models
 
 
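
Note for reviewers: the core of this change is the new flatten-and-subset matching in `find_known_models_by_model_config`. Below is a minimal, self-contained sketch of that logic; the `known` and `requested` dicts are illustrative stand-ins (not the real contents of `known_model_configs.json` or an actual `hf_config`), and only the comparison itself mirrors the patch.

```python
# Minimal sketch of the flatten-and-subset matching; the config dicts
# here are illustrative stand-ins, not the real files.
from pandas.io.json._normalize import nested_to_record as flatten

known = {
    "ibm-granite/granite-3.3-8b-instruct": {
        "model_type": "granite",
        "vocab_size": 49159,
    },
    "ibm-granite/granite-3.3-8b-instruct-FP8": {
        "model_type": "granite",
        "vocab_size": 49159,
        "quantization_config": {"format": "float-quantized"},
    },
}

# A requested FP8 config is a superset of *both* known entries once
# flattened, so the subset test alone would return two matches.
requested = {
    "model_type": "granite",
    "vocab_size": 49159,
    "num_hidden_layers": 40,
    "quantization_config": {"format": "float-quantized", "num_bits": 8},
}


def is_quantized(config: dict) -> bool:
    return "quantization_config" in config


# flatten() turns {"quantization_config": {"format": ...}} into
# {"quantization_config.format": ...} so nested values take part in the
# subset comparison; the is_quantized gate then drops the base entry.
matches = [
    model for model, config in known.items()
    if flatten(config).items() <= flatten(requested).items()
    and is_quantized(config) == is_quantized(requested)
]
print(matches)  # ['ibm-granite/granite-3.3-8b-instruct-FP8']
```

Without the `is_quantized` gate, the FP8 request above would match both entries, which is exactly the ambiguity the new `len(models_found) < 2` assertion in the test guards against.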
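
For a quick manual check of the end-to-end lookup, something along the lines of the test in `tests/utils/test_model_config_validator.py` should work. This is a hypothetical snippet, assuming it is run from the repo root with the new fixture in place (or with Hugging Face access if `model=` is given a hub id):

```python
# Hypothetical manual check, mirroring the test: `model=` points at a
# local directory containing a config.json, as the test does with its
# fixtures under tests/fixtures/model_configs/.
from vllm.config import ModelConfig

from vllm_spyre.config.runtime_config_validator import (
    find_known_models_by_model_config)

model_config = ModelConfig(
    model="tests/fixtures/model_configs/"
          "ibm-granite/granite-3.3-8b-instruct-FP8")
print(find_known_models_by_model_config(model_config))
# expected: ['ibm-granite/granite-3.3-8b-instruct-FP8']
```

The new `pytest.mark.quantized` marker also makes the FP8 parametrization selectable on its own (e.g. `pytest -m quantized`), assuming the marker is registered in the project's pytest configuration.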