Support FLUX OneTrainer LoRA formats (incl. DoRA) #7590

Merged: 21 commits, Jan 28, 2025

Commits
8b4f411
Add a test state dict for the OneTrainer DoRA format.
RyanJDick Jan 20, 2025
5bd6428
Add is_state_dict_likely_in_flux_onetrainer_format() util function.
RyanJDick Jan 21, 2025
faa4fa0
Expand unit tests to test for confusion between FLUX LoRA formats.
RyanJDick Jan 21, 2025
4f369e3
First draft of DoRALayer. Not tested yet.
RyanJDick Jan 21, 2025
dfa253e
Add utils for working with Kohya LoRA keys.
RyanJDick Jan 21, 2025
908976a
Add support for LyCoris-style LoRA keys in lora_model_from_flux_diffu…
RyanJDick Jan 22, 2025
7eee4da
Further updates to lora_model_from_flux_diffusers_state_dict() so tha…
RyanJDick Jan 22, 2025
206f261
Add utils for loading FLUX OneTrainer DoRA models.
RyanJDick Jan 22, 2025
409b69e
Fix typo in DoRALayer.
RyanJDick Jan 23, 2025
f4a0b78
Update FLUX invocations to support LoRAs that modify the T5 text enco…
RyanJDick Jan 23, 2025
1054283
Fix bug in FLUX T5 Kohya-style LoRA key parsing.
RyanJDick Jan 23, 2025
b8eed2b
Relax lora_layers_from_flux_diffusers_grouped_state_dict(...) so that…
RyanJDick Jan 23, 2025
0db6639
Add FLUX OneTrainer model probing.
RyanJDick Jan 23, 2025
5ea7953
Update GGMLTensor with ops necessary to work with ConcatenatedLoRALayer.
RyanJDick Jan 24, 2025
28514ba
Update ConcatenatedLoRALayer to work with all sub-layer types.
RyanJDick Jan 24, 2025
5d472ac
Move quantized weight handling for patch layers up from ConcatenatedL…
RyanJDick Jan 24, 2025
e7fb435
Update DoRALayer with a custom get_parameters() override that 1) appl…
RyanJDick Jan 24, 2025
7fef569
Update frontend graph building logic to support FLUX LoRAs that modif…
RyanJDick Jan 24, 2025
5357d6e
Rename ConcatenatedLoRALayer to MergedLayerPatch. And other minor cle…
RyanJDick Jan 24, 2025
6c919e1
Handle DoRA layer device casting when model is partially-loaded.
RyanJDick Jan 24, 2025
229834a
Performance optimizations for LoRAs applied on top of GGML-quantized …
RyanJDick Jan 24, 2025
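
Several of the commits above introduce a `DoRALayer` for the OneTrainer DoRA format. As a reference for what that entails, the sketch below is a minimal, self-contained illustration of the weight-decomposed update that DoRA applies on top of a plain LoRA delta. It is an assumption-laden sketch (the tensor names, the per-output-row norm axis, and the magnitude initialization follow common reference implementations), not the code added in this PR.

```python
import torch

def dora_merged_weight(
    w_orig: torch.Tensor,     # (out_features, in_features) original weight
    lora_down: torch.Tensor,  # (rank, in_features)
    lora_up: torch.Tensor,    # (out_features, rank)
    magnitude: torch.Tensor,  # (out_features,) learned magnitude vector
    scale: float = 1.0,
) -> torch.Tensor:
    """Sketch of a DoRA merge: rescale the direction of (W + delta) by a learned magnitude."""
    delta = lora_up @ lora_down                 # low-rank update, same shape as w_orig
    direction = w_orig + scale * delta
    norm = direction.norm(dim=1, keepdim=True)  # per-output-row norm (axis varies by implementation)
    return magnitude.unsqueeze(1) * direction / norm

# Tiny smoke test with random tensors.
w = torch.randn(8, 16)
down, up = torch.randn(4, 16), torch.randn(8, 4)
m = w.norm(dim=1)  # magnitude is typically initialized from the original weight's norm
print(dora_merged_weight(w, down, up, m).shape)  # torch.Size([8, 16])
```
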
36 changes: 33 additions & 3 deletions invokeai/app/invocations/flux_lora_loader.py
@@ -8,7 +8,7 @@
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType

@@ -21,14 +21,17 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
t5_encoder: Optional[T5EncoderField] = OutputField(
default=None, description=FieldDescriptions.t5_encoder, title="T5 Encoder"
)


@invocation(
"flux_lora_loader",
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
version="1.2.0",
classification=Classification.Prototype,
)
class FluxLoRALoaderInvocation(BaseInvocation):
@@ -50,6 +53,12 @@ class FluxLoRALoaderInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)

def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
@@ -62,6 +71,8 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.')
if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.')

output = FluxLoRALoaderOutput()

@@ -82,6 +93,14 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
weight=self.weight,
)
)
if self.t5_encoder is not None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
output.t5_encoder.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)

return output

@@ -91,7 +110,7 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
title="FLUX LoRA Collection Loader",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
version="1.2.0",
classification=Classification.Prototype,
)
class FLUXLoRACollectionLoader(BaseInvocation):
@@ -113,6 +132,12 @@ class FLUXLoRACollectionLoader(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField | None = InputField(
default=None,
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)

def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
output = FluxLoRALoaderOutput()
@@ -140,4 +165,9 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)

if self.t5_encoder is not None:
if output.t5_encoder is None:
output.t5_encoder = self.t5_encoder.model_copy(deep=True)
output.t5_encoder.loras.append(lora)

return output
4 changes: 2 additions & 2 deletions invokeai/app/invocations/flux_model_loader.py
@@ -40,7 +40,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
version="1.0.5",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
@@ -87,7 +87,7 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
43 changes: 41 additions & 2 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -19,7 +19,7 @@
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo

@@ -71,13 +71,45 @@ def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
prompt = [self.prompt]

t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_encoder_config = t5_encoder_info.config
assert t5_encoder_config is not None

with (
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
t5_encoder_info.model_on_device() as (cached_weights, t5_text_encoder),
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))

# Determine if the model is quantized.
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
# slower inference than direct patching, but is agnostic to the quantization format.
if t5_encoder_config.format in [ModelFormat.T5Encoder, ModelFormat.Diffusers]:
model_is_quantized = False
elif t5_encoder_config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
model_is_quantized = True
else:
raise ValueError(f"Unsupported model format: {t5_encoder_config.format}")

# Apply LoRA models to the T5 encoder.
# Note: We apply the LoRA after the encoder has been moved to its target device for faster patching.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=t5_text_encoder,
patches=self._t5_lora_iterator(context),
prefix=FLUX_LORA_T5_PREFIX,
dtype=t5_text_encoder.dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)

t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

context.util.signal_progress("Running T5 encoder")
@@ -132,3 +164,10 @@ def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

def _t5_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.t5_encoder.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
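
The comments in `_t5_encode` above explain why quantized T5 checkpoints force sidecar patching. As a rough standalone illustration of that trade-off (plain tensors, not the actual `LayerPatcher` API): direct patching folds the LoRA delta into the weight once, which requires a writable full-precision weight, while sidecar patching leaves the base weight untouched and adds the low-rank residual at forward time, so it works regardless of how the base weight is stored.

```python
import torch

x = torch.randn(2, 16)            # batch of activations
w = torch.randn(8, 16)            # original linear weight (out_features, in_features)
down, up = torch.randn(4, 16), torch.randn(8, 4)
scale = 0.8                       # lora weight * layer scale

# Direct patching: merge the delta into the weight, then run the normal linear op.
w_patched = w + scale * (up @ down)
y_direct = torch.nn.functional.linear(x, w_patched)

# Sidecar patching: keep the original weight (it may be quantized and not directly
# writable) and add the low-rank residual next to the base output at forward time.
y_base = torch.nn.functional.linear(x, w)
y_sidecar = y_base + scale * torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)

print(torch.allclose(y_direct, y_sidecar, atol=1e-5))  # True, up to float error
```
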
1 change: 1 addition & 0 deletions invokeai/app/invocations/model.py
@@ -68,6 +68,7 @@ class CLIPField(BaseModel):
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


class VAEField(BaseModel):
@@ -7,7 +7,6 @@
CustomModuleMixin,
)
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer

@@ -22,25 +21,6 @@ def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight:
return x


def concatenated_lora_forward(
input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
) -> torch.Tensor:
"""An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer."""
x_chunks: list[torch.Tensor] = []
for lora_layer in concatenated_lora_layer.lora_layers:
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
if lora_layer.mid is not None:
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
x_chunk *= lora_weight * lora_layer.scale()
x_chunks.append(x_chunk)

# TODO(ryand): Generalize to support concat_axis != 0.
assert concatenated_lora_layer.concat_axis == 0
x = torch.cat(x_chunks, dim=-1)
return x


def autocast_linear_forward_sidecar_patches(
orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]]
) -> torch.Tensor:
@@ -66,8 +46,6 @@ def autocast_linear_forward_sidecar_patches(
output += linear_lora_forward(orig_input, patch, patch_weight)
elif isinstance(patch, LoRALayer):
output += linear_lora_forward(input, patch, patch_weight)
elif isinstance(patch, ConcatenatedLoRALayer):
output += concatenated_lora_forward(input, patch, patch_weight)
else:
unprocessed_patches_and_weights.append((patch, patch_weight))
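
For context on the deletion above: `concatenated_lora_forward` computed one low-rank residual per sub-layer and concatenated them along the output axis, which is how several LoRAs can patch a single fused weight (for example a combined qkv projection); per the commit list, that responsibility moves to the renamed `MergedLayerPatch`. A rough standalone sketch of the idea, with made-up shapes rather than the real layer classes:

```python
import torch

def concatenated_lora_residual(
    x: torch.Tensor,
    sub_layers: list[tuple[torch.Tensor, torch.Tensor, float]],  # (down, up, scale) per sub-layer
) -> torch.Tensor:
    """Compute one LoRA residual per sub-layer and concatenate along the output features."""
    chunks = []
    for down, up, scale in sub_layers:
        chunk = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)
        chunks.append(scale * chunk)
    return torch.cat(chunks, dim=-1)

# Three sub-layers jointly patching a fused projection with 3 * 8 output features.
x = torch.randn(2, 16)
subs = [(torch.randn(4, 16), torch.randn(8, 4), 1.0) for _ in range(3)]
print(concatenated_lora_residual(x, subs).shape)  # torch.Size([2, 24])
```
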

@@ -3,6 +3,8 @@
import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


class CustomModuleMixin:
@@ -42,6 +44,20 @@ def _aggregate_patch_parameters(
device: torch.device | None = None,
):
"""Helper function that aggregates the parameters from all patches into a single dict."""
# HACK(ryand): If the original parameters are in a quantized format whose weights can't be accessed, we replace
# them with dummy tensors on the 'meta' device. This allows patch layers to access the shapes of the original
# parameters. But, of course, any sub-layers that need to access the actual values of the parameters will fail.
for param_name in orig_params.keys():
param = orig_params[param_name]
if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
pass
elif type(param) is GGMLTensor:
# Move to device and dequantize here. Doing it in the patch layer can result in redundant casts /
# dequantizations.
orig_params[param_name] = param.to(device=device).get_dequantized_tensor()
else:
orig_params[param_name] = torch.empty(get_param_shape(param), device="meta")

params: dict[str, torch.Tensor] = {}

for patch, patch_weight in patches_and_weights:
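
The HACK comment in `_aggregate_patch_parameters` above relies on "meta" tensors as shape-only stand-ins for quantized parameters whose values cannot be read. The snippet below is a simplified, standalone illustration of that trick (independent of InvokeAI's classes): a tensor on the meta device carries shape and dtype but allocates no storage, so downstream code can inspect shapes without materializing the weights.

```python
import torch

# Pretend this quantized weight's values are not directly accessible; only its shape is known.
opaque_shape = (4096, 4096)

# A placeholder on the 'meta' device: shape and dtype only, no storage allocated.
placeholder = torch.empty(opaque_shape, device="meta")

print(placeholder.shape)   # torch.Size([4096, 4096])
print(placeholder.device)  # meta

# Reading the values is impossible, which is exactly the point:
try:
    placeholder.cpu()      # copying out of a meta tensor has no data to copy
except (RuntimeError, NotImplementedError):
    print("no data to copy; only the shape is available")
```
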
8 changes: 8 additions & 0 deletions invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -31,6 +31,10 @@
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
lora_model_from_flux_onetrainer_state_dict,
)
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format

@@ -84,8 +88,12 @@ def _load_model(
elif config.format == ModelFormat.LyCORIS:
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
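
With the OneTrainer branch added, the LyCORIS path now tries several FLUX state-dict detectors in order before giving up. The sketch below is a condensed, hypothetical version of that ordered dispatch; the predicates here are toy stand-ins, whereas the real functions imported above inspect Kohya/OneTrainer key patterns.

```python
from typing import Any, Callable

StateDict = dict[str, Any]
# Each entry pairs a "does this look like format X?" predicate with its converter.
Detector = tuple[Callable[[StateDict], bool], Callable[[StateDict], str]]

def load_flux_lora(state_dict: StateDict, detectors: list[Detector]) -> str:
    """Try each format detector in order; the first match wins, otherwise fail loudly."""
    for looks_like, convert in detectors:
        if looks_like(state_dict):
            return convert(state_dict)
    raise ValueError("LoRA model is in an unsupported FLUX format")

# Toy usage: match on marker substrings instead of real key-pattern analysis.
detectors: list[Detector] = [
    (lambda sd: any(k.startswith("lora_unet_") for k in sd), lambda sd: "kohya"),
    (lambda sd: any("dora_scale" in k for k in sd), lambda sd: "onetrainer"),
]
print(load_flux_lora({"foo.dora_scale": 0}, detectors))  # onetrainer
```
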
6 changes: 5 additions & 1 deletion invokeai/backend/model_manager/probe.py
@@ -46,6 +46,9 @@
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -283,7 +286,7 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_")):
elif key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
return ModelType.LoRA
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
@@ -632,6 +635,7 @@ def get_format(self) -> ModelFormat:
def get_base_type(self) -> BaseModelType:
if (
is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
or is_state_dict_likely_in_flux_onetrainer_format(self.checkpoint)
or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
or is_state_dict_likely_flux_control(self.checkpoint)
):
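
The probe change above widens the set of Kohya/LyCORIS-style key prefixes that identify a checkpoint as a LoRA and adds the OneTrainer detector to base-type detection. A minimal illustration of that kind of prefix probing (prefixes copied from the diff; the surrounding function and the example keys are simplified):

```python
LORA_KEY_PREFIXES = ("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")

def looks_like_lora(state_dict: dict) -> bool:
    """Return True if any key uses a known Kohya/LyCORIS-style LoRA prefix."""
    return any(key.startswith(LORA_KEY_PREFIXES) for key in state_dict)

print(looks_like_lora({"lora_transformer_single_blocks_0_linear1.lora_down.weight": None}))  # True
print(looks_like_lora({"decoder.conv_in.weight": None}))  # False
```
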
55 changes: 0 additions & 55 deletions invokeai/backend/patches/layers/concatenated_lora_layer.py

This file was deleted.
