@@ -37,14 +37,38 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool) -> str:
-        if selected_backend is not None and selected_backend != _Backend.IPEX:
-            logger.info("Cannot use %s backend on XPU.", selected_backend)
         use_v1 = envs.VLLM_USE_V1
         if not use_v1:
             raise ValueError("XPU backend only supports V1.")
+        TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+            logger.info_once("Using Triton backend on V1 engine.")
+            return TRITON_ATTN_VLLM_V1
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info_once("Using Flash Attention backend on V1 engine.")
+            return FLASH_ATTN_V1
+        elif selected_backend:
+            raise ValueError(
+                f"Invalid attention backend for {cls.device_name}, "
+                f"with use_v1: {use_v1} use_mla: {use_mla}")
+
         logger.info("Using Flash Attention backend on V1 engine.")
         return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
+        """
+        Check if the kv_cache_dtype is supported.
+        XPU only supports fp8 KV cache with the Triton backend.
+        """
+        if envs.is_set("VLLM_ATTENTION_BACKEND") and \
+            envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1":
+            return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
+
+        return False
+
     @classmethod
     def set_device(cls, device: torch.device) -> None:
         """