13
13
from vllm .model_executor .layers .quantization .base_config import (
14
14
QuantizationConfig , QuantizeMethodBase )
15
15
from vllm .model_executor .utils import set_weight_attrs
16
+ from vllm .utils import is_torch_equal_or_newer
16
17
17
18
# Module-level logger (vLLM convention: one logger per module, keyed by name).
logger = init_logger(__name__)
18
19
@@ -22,13 +23,18 @@ class TorchAOConfig(QuantizationConfig):
22
23
23
24
def __init__(self, torchao_config) -> None:
    """Store the TorchAO config and tune torch.compile environment knobs.

    TorchAO quantization relies on tensor subclasses. To enable proper
    compile caching this needs standalone compile (available from
    torch 2.8.0); on torch 2.7.x the compile cache must be disabled
    instead.

    Args:
        torchao_config: the torchao quantization configuration object
            to apply (opaque here; consumed by torchao downstream).
    """
    self.torchao_config = torchao_config

    # Evaluate the 2.8 check once; both branches below depend on it.
    at_least_2_8 = is_torch_equal_or_newer("2.8.0")

    if at_least_2_8:
        # Standalone compile makes caching work with tensor subclasses.
        os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
        logger.info(
            "Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")

    # TODO: remove after the torch dependency is updated to 2.8
    if is_torch_equal_or_newer("2.7.0") and not at_least_2_8:
        os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
        logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
32
38
33
39
def __repr__ (self ) -> str :
34
40
return f"TorchAOConfig({ self .torchao_config } )"
0 commit comments