Hi, when running inference I get `TypeError: flash_attn_func() got an unexpected keyword argument 'return_attn_probs'`.
After simply removing the `return_attn_probs` argument, it then fails because `hidden_states` is a tuple. The following change is needed:
```python
hidden_states = dispatch_attention_fn(
    query,
    key,
    value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=False,
    backend=self._attention_backend,
    # Reference: huggingface/diffusers#12909
    parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
)
## TODO: fix — dispatch_attention_fn returns a tuple here, so take the output tensor
hidden_states = hidden_states[0]
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
```
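One way to make the workaround above less brittle is to unwrap the attention output defensively, since some attention backends return a bare tensor while others return a tuple of `(output, ...)`. This is a minimal sketch with a hypothetical helper name (`unpack_attn_output` is not part of the library):

```python
def unpack_attn_output(out):
    """Return the attention output tensor whether the backend
    returned a bare tensor or a tuple like (output, attn_probs)."""
    if isinstance(out, tuple):
        return out[0]  # first element is the output tensor by convention
    return out
```

With this helper, `hidden_states = unpack_attn_output(hidden_states)` would replace the hard-coded `hidden_states[0]` indexing and keep working if the backend is later changed to return a plain tensor.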
However, the results produced by the example after this change don't look very good. What could be the cause?