[Misc] Split monolithic config.py into domain-specific modules #18830

Draft · wants to merge 6 commits into base: main
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -22,7 +22,7 @@
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
-from vllm.config import TaskOption, _get_and_verify_dtype
+from vllm.config.model_config import TaskOption, _get_and_verify_dtype
from vllm.connections import global_http_connection
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
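Note: the test switches to the submodule import because `_get_and_verify_dtype` is a private helper and is not re-exported by the new `vllm/config/__init__.py` in this PR, while public names such as `TaskOption` remain importable from either path. A minimal sketch of what that implies for callers, assuming the re-exports in this diff land unchanged:

# Hypothetical check, not part of this PR: with the package split in place,
# the top-level and submodule import paths should resolve to the same object.
from vllm.config import TaskOption as task_option_top_level
from vllm.config.model_config import TaskOption as task_option_submodule

assert task_option_top_level is task_option_submodule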
4,610 changes: 0 additions & 4,610 deletions vllm/config.py

This file was deleted.

201 changes: 201 additions & 0 deletions vllm/config/__init__.py
@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, TypeVar

from vllm.config.cache_config import (BlockSize, CacheConfig, CacheDType,
PrefixCachingHashAlgo)
from vllm.config.compilation_config import CompilationConfig, CompilationLevel
from vllm.config.decoding_config import (DecodingConfig, GuidedDecodingBackend,
GuidedDecodingBackendV1)
from vllm.config.device_config import Device, DeviceConfig
from vllm.config.kvevents_config import KVEventsConfig
from vllm.config.kvtransformer_config import KVTransferConfig
from vllm.config.load_config import LoadConfig, LoadFormat
from vllm.config.lora_config import LoRAConfig
from vllm.config.model_config import (ConfigFormat, HfOverrides, ModelConfig,
                                      ModelDType, ModelImpl, TaskOption,
                                      TokenizerMode)

Check failure on line 17 in vllm/config/__init__.py (GitHub Actions / pre-commit, Ruff F401): vllm/config/__init__.py:17:39: F401 `vllm.config.model_config.ConfigFormat` imported but unused; consider removing, adding to `__all__`, or using a redundant alias.

from vllm.config.multimodal_config import MultiModalConfig
from vllm.config.obervability_config import (DetailedTraceModules,
ObservabilityConfig)
from vllm.config.parallel_config import (DistributedExecutorBackend,
ParallelConfig)
from vllm.config.pass_config import PassConfig
from vllm.config.pooler_config import PoolerConfig
from vllm.config.promptadapter_config import PromptAdapterConfig
from vllm.config.scheduler_config import SchedulerConfig, SchedulerPolicy
from vllm.config.speculative_config import SpeculativeConfig
from vllm.config.tokenizerpool_config import TokenizerPoolConfig
from vllm.config.utils import get_field # noqa
from vllm.config.utils import config, get_attr_docs, is_init_field # noqa
from vllm.config.vllm_config import SupportsMetricsInfo, VllmConfig
from vllm.logger import init_logger

if TYPE_CHECKING:
from _typeshed import DataclassInstance

from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
ConfigType = type[DataclassInstance]
else:
QuantizationConfig = Any
ConfigType = type

logger = init_logger(__name__)

_current_vllm_config: Optional[VllmConfig] = None


@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
"""
Temporarily set the current vLLM config.
Used during model initialization.
We save the current vLLM config in a global variable,
so that all modules can access it, e.g. custom ops
can access the vLLM config to determine how to dispatch.
"""
global _current_vllm_config
old_vllm_config = _current_vllm_config
from vllm.compilation.counter import compilation_counter
num_models_seen = compilation_counter.num_models_seen
try:
_current_vllm_config = vllm_config
yield
except Exception:
raise
else:
logger.debug("enabled custom ops: %s",
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if check_compile and \
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen:
# If the model supports compilation,
# compilation_counter.num_models_seen should be increased
# by at least 1.
# If it is not increased, it means the model does not support
# compilation (does not have @support_torch_compile decorator).
logger.warning(
"`torch.compile` is turned on, but the model %s"
" does not support it. Please open an issue on GitHub"
" if you want it to be supported.",
vllm_config.model_config.model)
finally:
_current_vllm_config = old_vllm_config


def get_current_vllm_config() -> VllmConfig:
if _current_vllm_config is None:
        # In CI, when custom ops/modules are tested directly, the vLLM
        # config is usually not set. In that case, fall back to a default
        # config.
logger.warning("Current vLLM config is not set.")
        return VllmConfig()
return _current_vllm_config


T = TypeVar("T")


def get_layers_from_vllm_config(vllm_config: VllmConfig,
layer_type: type[T]) -> dict[str, T]:
return {
layer_name: layer
for layer_name, layer in
vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, layer_type)
}


__all__ = [
# Cache config
"CacheConfig",
"BlockSize",
"CacheDType",
"PrefixCachingHashAlgo",

# Compilation config
"CompilationConfig",
"CompilationLevel",

# Decoding config
"DecodingConfig",
"GuidedDecodingBackend",
"GuidedDecodingBackendV1",

# Device config
"DeviceConfig",
"Device",

# KV events config
"KVEventsConfig",

    # KV transfer config
"KVTransferConfig",

# Load config
"LoadConfig",
"LoadFormat",

# LoRA config
"LoRAConfig",

# Model config
"ModelConfig",
"TaskOption",
"TokenizerMode",
"ModelDType",
"ModelImpl",
"HfOverrides",

# Multimodal config
"MultiModalConfig",

# Pooler config
"PoolerConfig",

# Observability config
"ObservabilityConfig",
"DetailedTraceModules",

# Parallel config
"ParallelConfig",
"DistributedExecutorBackend",

# Pass config
"PassConfig",

# Prompt adapter config
"PromptAdapterConfig",

# Scheduler config
"SchedulerConfig",
"SchedulerPolicy",

# Speculative config
"SpeculativeConfig",

    # Tokenizer pool config
"TokenizerPoolConfig",

# vLLM config
"VllmConfig",
"SupportsMetricsInfo",

# Others
"set_current_vllm_config",
"get_current_vllm_config",
"get_layers_from_vllm_config",
"get_field",
"config",
"get_attr_docs"
"is_init_field",
"ConfigFormat"
"ConfigType",
"QuantizationConfig"
]
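For readers skimming the new package entry point, here is a minimal usage sketch of the context-manager helpers defined above; it assumes only what the code shows, namely that `VllmConfig()` is constructible with defaults (as `get_current_vllm_config` itself relies on):

# Sketch only: exercise the helpers re-exported by vllm/config/__init__.py.
from vllm.config import (VllmConfig, get_current_vllm_config,
                         set_current_vllm_config)

vllm_config = VllmConfig()  # default config, as used by get_current_vllm_config
with set_current_vllm_config(vllm_config):
    # Inside the context, modules such as custom ops can look up the active
    # config instead of having it threaded through every call site.
    assert get_current_vllm_config() is vllm_config
# On exit, the previously active config (here: None) is restored.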
172 changes: 172 additions & 0 deletions vllm/config/cache_config.py
@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
import hashlib
from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal, Optional, get_args

from pydantic import SkipValidation
from pydantic.dataclasses import dataclass

import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, get_cpu_memory

if TYPE_CHECKING:
from vllm.config.parallel_config import ParallelConfig

logger = init_logger(__name__)

BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
PrefixCachingHashAlgo = Literal["builtin", "sha256"]


@config
@dataclass
class CacheConfig:
"""Configuration for the KV cache."""

block_size: SkipValidation[BlockSize] = None # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.

This config has no static default. If left unspecified by the user, it will
be set in `Platform.check_and_update_configs()` based on the current
platform."""
gpu_memory_utilization: float = 0.9
"""The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
utilization. If unspecified, will use the default value of 0.9. This is a
per-instance limit, and only applies to the current vLLM instance. It does
not matter if you have another vLLM instance running on the same GPU. For
example, if you have two vLLM instances running on the same GPU, you can
set the GPU memory utilization to 0.5 for each instance."""
swap_space: float = 4
"""Size of the CPU swap space per GPU (in GiB)."""
cache_dtype: CacheDType = "auto"
"""Data type for kv cache storage. If "auto", will use model data type.
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
fp8 (=fp8_e4m3)."""
is_attention_free: bool = False
"""Whether the model is attention-free. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
num_gpu_blocks_override: Optional[int] = None
"""Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
if specified. Does nothing if `None`. Used for testing preemption."""
sliding_window: Optional[int] = None
"""Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
enable_prefix_caching: Optional[bool] = None
"""Whether to enable prefix caching. Disabled by default for V0. Enabled by
default for V1."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
"""Set the hash algorithm for prefix caching:\n
- "builtin" is Python's built-in hash.\n
- "sha256" is collision resistant but with certain overheads."""
cpu_offload_gb: float = 0
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
increase the GPU memory size. For example, if you have one 24 GB GPU and
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
Note that this requires fast CPU-GPU interconnect, as part of the model is
loaded from CPU memory to GPU memory on the fly in each model forward pass.
"""
calculate_kv_scales: bool = False
"""This enables dynamic calculation of `k_scale` and `v_scale` when
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
checkpoint if available. Otherwise, the scales will default to 1.0."""

# Will be set after profiling.
num_gpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for GPU memory."""
num_cpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.

Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: list[Any] = []
factors.append(self.cache_dtype)
# `cpu_offload_gb` does not use `torch.compile` yet.
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str

def __post_init__(self) -> None:
self.swap_space_bytes = self.swap_space * GiB_bytes

self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()

def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
return {key: str(value) for key, value in self.__dict__.items()}

def _verify_args(self) -> None:
if self.cpu_offload_gb < 0:
raise ValueError("CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}")

if self.gpu_memory_utilization > 1.0:
raise ValueError(
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")

def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype in get_args(CacheDType):
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts performance. However, it may "
                "cause an accuracy drop without a proper scaling factor.")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

def _verify_prefix_caching(self) -> None:
if not self.enable_prefix_caching:
return

if self.sliding_window is not None and not envs.VLLM_USE_V1:
raise NotImplementedError(
"Prefix caching is not supported with sliding window. "
"Run with --disable-sliding-window to use prefix caching.")

if (self.enable_prefix_caching and self.prefix_caching_hash_algo
not in get_args(PrefixCachingHashAlgo)):
raise ValueError(
"Unknown prefix caching hash algorithm: "
f"{self.prefix_caching_hash_algo}. Must be one of "
f"{get_args(PrefixCachingHashAlgo)}.")

def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
total_cpu_memory = get_cpu_memory()
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
# group are in the same node. However, the GPUs may span multiple nodes.
num_gpus_per_node = parallel_config.tensor_parallel_size
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
"is allocated for the swap space.")
if cpu_memory_usage > 0.7 * total_cpu_memory:
raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warning("Possibly too large swap space. %s", msg)