diff --git a/vllm_spyre/envs.py b/vllm_spyre/envs.py index 62b4185a..bed01251 100644 --- a/vllm_spyre/envs.py +++ b/vllm_spyre/envs.py @@ -23,6 +23,7 @@ VLLM_SPYRE_WORKER_LOG_REDIRECT_DIR: str = "" VLLM_SPYRE_GLOO_TIMEOUT_MINUTES: int = 60 VLLM_SPYRE_REQUIRE_PRECOMPILED_DECODERS: bool = False + VLLM_SPYRE_SIMPLE_COMPILE_BACKEND: str = "eager" logger = init_logger(__name__) @@ -143,7 +144,14 @@ def _backend_backwards_compat() -> str: # disable compilation for decoders "VLLM_SPYRE_REQUIRE_PRECOMPILED_DECODERS": lambda: bool(int(os.getenv("VLLM_SPYRE_REQUIRE_PRECOMPILED_DECODERS", "0")) - ) + ), + + # Simple compile backend for some dynamically compiled operations, like + # gathering logprobs in the sampler. + # Defaults to eager, inductor can be used if python headers and a compiler + # are available. + "VLLM_SPYRE_SIMPLE_COMPILE_BACKEND": + lambda: os.getenv("VLLM_SPYRE_SIMPLE_COMPILE_BACKEND", "eager"), } # --8<-- [end:env-vars-definition] diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py index 801cb2d0..3950f832 100644 --- a/vllm_spyre/platform.py +++ b/vllm_spyre/platform.py @@ -65,9 +65,9 @@ class SpyrePlatform(Platform): _num_spyre_blocks_override: int = -1 # override num of KV cache blocks _config: VllmConfig = None - # TODO: see if this needs to be set + # Backend for dynamic compilation ops # See vllm batched_count_greater_than method - # simple_compile_backend: str = "eager" + simple_compile_backend: str = envs_spyre.VLLM_SPYRE_SIMPLE_COMPILE_BACKEND # Needed by vllm/model_executor/layers/pooler.py:562 current_stream = lambda _: _StreamPlaceholder()