
Fused QKV projections incompatible with training #11903

Open
@bghira

Description

Describe the bug

I've enabled fused QKV projections in SimpleTuner, but it took quite a bit of investigation and effort.

  1. Any PEFT LoRAs become fused as well; the LoRA targets have to be adjusted to include to_qkv instead of the split target layer names (see the first sketch after this list).
  2. The fuse_qkv_projections method on the Attention class does not delete the original split q/k/v layers, wasting VRAM (a possible workaround is sketched in the second example below).
     • The leftover split layers can be inadvertently used for training, e.g. targeted by a PEFT LoRA or the EMA model.
  3. The unfuse_qkv_projections method on the Attention class does not do what one would expect: the weights are not copied back from the fused layer into the splits. It merely marks the fusion as disabled and swaps the attention processors back (see the third sketch below).
  4. EMAModel works perfectly fine with fused QKV projection training and requires no modification.
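
For item 1, a minimal sketch of the adjusted PEFT configuration, assuming `transformer` is a diffusers model on which fuse_qkv_projections() has already been called; the rank, alpha, and the extra `to_out.0` target are illustrative choices, not from this report:

```python
from peft import LoraConfig

# Minimal sketch: once the projections are fused, the adapter has to target the
# fused "to_qkv" layer rather than the split "to_q" / "to_k" / "to_v" layers,
# which the fused attention processor no longer calls.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=["to_qkv", "to_out.0"],  # instead of ["to_q", "to_k", "to_v", "to_out.0"]
)

# `transformer` is assumed to be a fused diffusers model (illustrative name).
transformer.add_adapter(lora_config)
```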
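
For item 2, a possible workaround sketch; drop_split_projections is our own helper, not a diffusers API, and deleting the split layers means unfuse_qkv_projections() can no longer restore them, so this only suits fused-only training:

```python
import torch.nn as nn

def drop_split_projections(model: nn.Module) -> None:
    """Illustrative helper (not a diffusers API): delete the unused split
    projections after fusion so they no longer hold VRAM or get picked up by
    a PEFT LoRA or the EMA model."""
    for module in list(model.modules()):
        if getattr(module, "to_qkv", None) is not None:
            for name in ("to_q", "to_k", "to_v"):
                if hasattr(module, name):
                    # nn.Module.__delattr__ removes the child from _modules
                    delattr(module, name)
```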
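
For item 3, a sketch of manually copying the fused weights back into the split layers before unfusing; copy_fused_weights_to_splits is again our own helper, and it assumes the q/k/v concatenation order (along dim 0) used by the Attention class's fuse path:

```python
import torch

@torch.no_grad()
def copy_fused_weights_to_splits(attn) -> None:
    """Illustrative helper: write to_qkv's weights back into to_q / to_k / to_v.

    `attn` is assumed to be a diffusers Attention module whose to_qkv layer was
    created by fuse_qkv_projections()."""
    if getattr(attn, "to_qkv", None) is None:
        return
    sizes = [attn.to_q.weight.shape[0], attn.to_k.weight.shape[0], attn.to_v.weight.shape[0]]
    q_w, k_w, v_w = attn.to_qkv.weight.split(sizes, dim=0)
    attn.to_q.weight.copy_(q_w)
    attn.to_k.weight.copy_(k_w)
    attn.to_v.weight.copy_(v_w)
    if attn.to_qkv.bias is not None and attn.to_q.bias is not None:
        q_b, k_b, v_b = attn.to_qkv.bias.split(sizes, dim=0)
        attn.to_q.bias.copy_(q_b)
        attn.to_k.bias.copy_(k_b)
        attn.to_v.bias.copy_(v_b)
```

One would apply this to every fused attention module in the model before calling unfuse_qkv_projections(), so the split layers actually hold the trained weights afterwards.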

Reproduction

The setup is a bit too involved for a single minimal reproducer, but individual aspects of this report can be split out into new, more specific issues, each containing all of the information needed to reproduce it.

Logs

System Info

Latest Diffusers main.

Who can help?

@a-r-r-o-w @DN6 @sayakpaul


Labels

bug (Something isn't working)
