
Problem during inference #27

@suhuijia

Description


Hi, when running inference, I get the error: `TypeError: flash_attn_func() got an unexpected keyword argument 'return_attn_probs'`.

Also, after simply deleting the `return_attn_probs` argument, it fails again because `hidden_states` is a tuple. I had to change the code as follows:

```python
hidden_states = dispatch_attention_fn(
    query,
    key,
    value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=False,
    backend=self._attention_backend,
    # Reference: huggingface/diffusers#12909
    parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
)
# TODO: fix; dispatch_attention_fn returns a tuple here
hidden_states = hidden_states[0]

hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
```
But the results in the example I finally got don't look very good. What could be causing that?
