diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 15516ed2ed60..bdb9c4861a76 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1106,6 +1106,51 @@ def _sage_attention_backward_op( raise NotImplementedError("Backward pass is not implemented for Sage attention.") +def _npu_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if return_lse: + raise ValueError("NPU attention backend does not support setting `return_lse=True`.") + + out = npu_fusion_attention( + query, + key, + value, + query.size(2), # num_heads + input_layout="BSND", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] + + return out + + +# Not implemented yet. +def _npu_attention_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.") + + # ===== Context parallel ===== @@ -2131,6 +2176,7 @@ def _native_math_attention( @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_NPU, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _native_npu_attention( query: torch.Tensor, @@ -2146,22 +2192,36 @@ def _native_npu_attention( raise ValueError("`attn_mask` is not supported for NPU attention") if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") - query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) - out = npu_fusion_attention( - query, - key, - value, - query.size(1), # num_heads - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0 - dropout_p, - sync=False, - inner_precise=0, - )[0] - out = out.transpose(1, 2).contiguous() + if _parallel_config is None: + out = npu_fusion_attention( + query, + key, + value, + query.size(2), # num_heads + input_layout="BSND", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] + else: + out = _templated_context_parallel_attention( + query, + key, + value, + None, + dropout_p, + None, + scale, + None, + return_lse, + forward_op=_npu_attention_forward_op, + backward_op=_npu_attention_backward_op, + _parallel_config=_parallel_config, + ) return out