Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions vllm_spyre/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
VLLM_SPYRE_GLOO_TIMEOUT_MINUTES: int = 60
VLLM_SPYRE_REQUIRE_PRECOMPILED_DECODERS: bool = False
VLLM_SPYRE_SIMPLE_COMPILE_BACKEND: str = "eager"
VLLM_SPYRE_NUM_CPUS: int = 0

logger = init_logger(__name__)

Expand Down Expand Up @@ -151,6 +152,12 @@ def _backend_backwards_compat() -> str:
# are available.
"VLLM_SPYRE_SIMPLE_COMPILE_BACKEND":
lambda: os.getenv("VLLM_SPYRE_SIMPLE_COMPILE_BACKEND", "eager"),

# Configures the number of CPUs used when determining multi-threading
# configurations.
# Set to 0 to have vllm-spyre attempt to detect the CPU count.
"VLLM_SPYRE_NUM_CPUS":
lambda: int(os.getenv("VLLM_SPYRE_NUM_CPUS", "0")),
}
# --8<-- [end:env-vars-definition]

Expand Down
83 changes: 52 additions & 31 deletions vllm_spyre/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,33 +432,54 @@ def _check_threading_config(cls, worker_count: int):
# Try to determine the CPU time/cores that we are allocated
cpu_count: float | None = None
detection_message = ""
try:
# try to query cgroup CPU limits
with open('/sys/fs/cgroup/cpu.max') as f:
quota_str, period_str = f.read().strip().split()

if quota_str != 'max':
quota = int(quota_str)
period = int(period_str)
cpu_count = quota / period
detection_message = f"Detected cgroup CPU limit of {cpu_count}"

except FileNotFoundError:
# file may not exist if not running under cgroups v2
pass
except Exception as e:
logger.debug(
"Error parsing /sys/fs/cgroup/cpu.max to get CPU info",
exc_info=e)

# could try `nproc` here, but it is affected by
# OMP_NUM_THREADS itself

# try os.cpu_count() to get node CPU count
if cpu_count is None and (cpu_count_res := os.cpu_count()) is not None:
cpu_count = float(cpu_count_res)
detection_message = \
f"Detected {cpu_count} CPUs from `os.cpu_count()`"

if (num_cpu := envs_spyre.VLLM_SPYRE_NUM_CPUS) > 0:
cpu_count = num_cpu
detection_message = f"VLLM_SPYRE_NUM_CPUS is set to {cpu_count}"
else:
try:
# try to query cgroup CPU limits
with open('/sys/fs/cgroup/cpu.max') as f:
quota_str, period_str = f.read().strip().split()

if quota_str != 'max':
quota = int(quota_str)
period = int(period_str)
cpu_count = quota / period
detection_message = \
f"Detected cgroup CPU limit of {cpu_count}"

except FileNotFoundError:
# file may not exist if not running under cgroups v2
pass
except Exception as e:
logger.debug(
"Error parsing /sys/fs/cgroup/cpu.max to get CPU info",
exc_info=e)

# try psutil to get physical core count
if cpu_count is None:
try:
import psutil
cpu_count = float(psutil.cpu_count(logical=False))
detection_message = \
f"Detected {cpu_count} physical CPUs from " \
"psutil.cpu_count(logical=False)"
except ImportError:
logger.info("Install psutil to count physical CPU cores")
pass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we log here saying that psutil can be used if it's installed?

I'm also not opposed to adding it as a dependency, it seems very well maintained

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is well maintained, but is also OS-dependent compiled code. We could list out all of the platform_systems it does support, but 🤷. I wish there was a configuration like "try to install by default but don't stop the install if unable" 😅

I'll add a log message indicating the option to install it though!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, perhaps that's a better option since we have to support non-x86 platforms.

except Exception as e:
logger.debug("Error using psutil", exc_info=e)

# could try `nproc` here, but it is affected by
# OMP_NUM_THREADS itself

# try os.cpu_count() to get node CPU count
if cpu_count is None and (cpu_count_res :=
os.cpu_count()) is not None:
cpu_count = float(cpu_count_res)
detection_message = \
f"Detected {cpu_count} CPUs from `os.cpu_count()`"

# NOTE: math.ceil can output a number for each worker that sums
# to a total greater than cpu_count.
Expand All @@ -474,9 +495,9 @@ def _check_threading_config(cls, worker_count: int):
if envs_spyre.VLLM_SPYRE_UPDATE_THREAD_CONFIG:
if cpus_per_worker is None:
raise RuntimeError(
f"{failed_detection_message} Use "
"VLLM_SPYRE_UPDATE_THREAD_CONFIG=0 and configure manually."
)
f"{failed_detection_message} Set VLLM_SPYRE_NUM_CPUS or "
"use VLLM_SPYRE_UPDATE_THREAD_CONFIG=0 and configure "
"manually.")

for env in THREADING_ENVS:
os.environ[env] = str(cpus_per_worker)
Expand Down Expand Up @@ -518,4 +539,4 @@ def get_max_output_tokens(self, prompt_len: int) -> int:
if prompt_len <= shape['prompt_length']:
max_new_tokens = max(max_new_tokens, shape['new_tokens'])

return max_new_tokens
return max_new_tokens