diff --git a/vllm_spyre/envs.py b/vllm_spyre/envs.py index f2fb5fc9..fea970a2 100644 --- a/vllm_spyre/envs.py +++ b/vllm_spyre/envs.py @@ -24,6 +24,7 @@ VLLM_SPYRE_GLOO_TIMEOUT_MINUTES: int = 60 VLLM_SPYRE_REQUIRE_PRECOMPILED_DECODERS: bool = False VLLM_SPYRE_SIMPLE_COMPILE_BACKEND: str = "eager" + VLLM_SPYRE_NUM_CPUS: int = 0 logger = init_logger(__name__) @@ -151,6 +152,12 @@ def _backend_backwards_compat() -> str: # are available. "VLLM_SPYRE_SIMPLE_COMPILE_BACKEND": lambda: os.getenv("VLLM_SPYRE_SIMPLE_COMPILE_BACKEND", "eager"), + + # Configures the number of CPUs used when determining multi-threading + # configurations + # Set to 0 to have vllm-spyre attempt to detect the CPU count + "VLLM_SPYRE_NUM_CPUS": + lambda: int(os.getenv("VLLM_SPYRE_NUM_CPUS", "0")), } # --8<-- [end:env-vars-definition] diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py index 59a2f9f3..ec986f10 100644 --- a/vllm_spyre/platform.py +++ b/vllm_spyre/platform.py @@ -432,33 +432,54 @@ def _check_threading_config(cls, worker_count: int): # Try to determine the CPU time/cores that we are allocated cpu_count: float | None = None detection_message = "" - try: - # try to query cgroup CPU limits - with open('/sys/fs/cgroup/cpu.max') as f: - quota_str, period_str = f.read().strip().split() - - if quota_str != 'max': - quota = int(quota_str) - period = int(period_str) - cpu_count = quota / period - detection_message = f"Detected cgroup CPU limit of {cpu_count}" - - except FileNotFoundError: - # file may not exist if not running under cgroups v2 - pass - except Exception as e: - logger.debug( - "Error parsing /sys/fs/cgroup/cpu.max to get CPU info", - exc_info=e) - - # could try `nproc` here, but it is affected by - # OMP_NUM_THREADS itself - - # try os.cpu_count() to get node CPU count - if cpu_count is None and (cpu_count_res := os.cpu_count()) is not None: - cpu_count = float(cpu_count_res) - detection_message = \ - f"Detected {cpu_count} CPUs from `os.cpu_count()`" + + if (num_cpu := envs_spyre.VLLM_SPYRE_NUM_CPUS) > 0: + cpu_count = num_cpu + detection_message = f"VLLM_SPYRE_NUM_CPUS is set to {cpu_count}" + else: + try: + # try to query cgroup CPU limits + with open('/sys/fs/cgroup/cpu.max') as f: + quota_str, period_str = f.read().strip().split() + + if quota_str != 'max': + quota = int(quota_str) + period = int(period_str) + cpu_count = quota / period + detection_message = \ + f"Detected cgroup CPU limit of {cpu_count}" + + except FileNotFoundError: + # file may not exist if not running under cgroups v2 + pass + except Exception as e: + logger.debug( + "Error parsing /sys/fs/cgroup/cpu.max to get CPU info", + exc_info=e) + + # try psutil to get physical core count + if cpu_count is None: + try: + import psutil + cpu_count = float(psutil.cpu_count(logical=False)) + detection_message = \ + f"Detected {cpu_count} physical CPUs from " \ + "psutil.cpu_count(logical=False)" + except ImportError: + logger.info("Install psutil to count physical CPU cores") + pass + except Exception as e: + logger.debug("Error using psutil", exc_info=e) + + # could try `nproc` here, but it is affected by + # OMP_NUM_THREADS itself + + # try os.cpu_count() to get node CPU count + if cpu_count is None and (cpu_count_res := + os.cpu_count()) is not None: + cpu_count = float(cpu_count_res) + detection_message = \ + f"Detected {cpu_count} CPUs from `os.cpu_count()`" # NOTE: math.ceil can output a number for each worker that sums # to a total greater than cpu_count. @@ -474,9 +495,9 @@ def _check_threading_config(cls, worker_count: int): if envs_spyre.VLLM_SPYRE_UPDATE_THREAD_CONFIG: if cpus_per_worker is None: raise RuntimeError( - f"{failed_detection_message} Use " - "VLLM_SPYRE_UPDATE_THREAD_CONFIG=0 and configure manually." - ) + f"{failed_detection_message} Set VLLM_SPYRE_NUM_CPUS or " + "use VLLM_SPYRE_UPDATE_THREAD_CONFIG=0 and configure " + "manually.") for env in THREADING_ENVS: os.environ[env] = str(cpus_per_worker) @@ -518,4 +539,4 @@ def get_max_output_tokens(self, prompt_len: int) -> int: if prompt_len <= shape['prompt_length']: max_new_tokens = max(max_new_tokens, shape['new_tokens']) - return max_new_tokens \ No newline at end of file + return max_new_tokens