2 changes: 2 additions & 0 deletions docs/diffusers/api/loaders/lora.md
@@ -63,4 +63,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi

::: mindone.diffusers.loaders.lora_pipeline.QwenImageLoraLoaderMixin

::: mindone.diffusers.loaders.lora_pipeline.KandinskyLoraLoaderMixin

::: mindone.diffusers.loaders.lora_base.LoraBaseMixin
4 changes: 4 additions & 0 deletions mindone/diffusers/__init__.py
@@ -94,6 +94,7 @@
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
@@ -234,6 +235,7 @@
"KandinskyV22PriorPipeline",
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
"Kandinsky5T2VPipeline",
"KolorsPAGPipeline",
"KolorsPipeline",
"KolorsImg2ImgPipeline",
@@ -475,6 +477,7 @@
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
Kandinsky3UNet,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
@@ -609,6 +612,7 @@
IFSuperResolutionPipeline,
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
Kandinsky5T2VPipeline,
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
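The `__init__.py` entries above register `Kandinsky5Transformer3DModel` and `Kandinsky5T2VPipeline` in both the lazy-import map and the eager import lists, so both names become importable from the package root. A minimal import sketch; the repository id is borrowed from the `fuse_lora` docstring example further down and is illustrative only:

```py
from mindone.diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel

# Repo id reused from the fuse_lora docstring example below; treat it as a placeholder.
pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
print(type(pipe.transformer))  # expected to be Kandinsky5Transformer3DModel
```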
2 changes: 2 additions & 0 deletions mindone/diffusers/loaders/__init__.py
@@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder):
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"KandinskyLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
],
@@ -97,6 +98,7 @@ def text_encoder_attn_modules(text_encoder):
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
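The loader package export also makes `KandinskyLoraLoaderMixin` importable on its own, which is how a pipeline class opts in to this LoRA API. A hypothetical composition sketch (the custom pipeline class below is illustrative, not part of this PR):

```py
from mindone.diffusers import DiffusionPipeline
from mindone.diffusers.loaders import KandinskyLoraLoaderMixin

# Hypothetical pipeline that reuses the Kandinsky LoRA loading logic; it only
# needs a `transformer` attribute for load_lora_weights to target.
class MyKandinsky5Pipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
    pass
```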
262 changes: 262 additions & 0 deletions mindone/diffusers/loaders/lora_pipeline.py
@@ -4667,6 +4667,268 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)


class KandinskyLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`Kandinsky5Transformer3DModel`].
"""

_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME

@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, ms.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.

Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* of a pretrained model hosted on the Hub.
- A path to a *directory* containing the model weights.
- A MindSpore state dict.

cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
use_safetensors (`bool`, *optional*):
Whether to use safetensors for loading.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata.
"""
# Load the main state dict first which has the LoRA layers
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)

allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True

user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)

is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." # noqa
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

out = (state_dict, metadata) if return_lora_metadata else state_dict
return out

def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, ms.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.

Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
hotswap (`bool`, *optional*):
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
kwargs (`dict`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
"""
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

# Load LoRA into transformer
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
hotswap=hotswap,
)

@classmethod
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
hotswap: bool = False,
metadata=None,
):
"""
Load the LoRA layers specified in `state_dict` into `transformer`.

Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters.
transformer (`Kandinsky5Transformer3DModel`):
The transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
hotswap (`bool`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata.
"""
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
hotswap=hotswap,
)

@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[ms.nn.Cell, ms.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the transformer.

Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to.
transformer_lora_layers (`Dict[str, ms.nn.Cell]` or `Dict[str, ms.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process.
save_function (`Callable`):
The function to use to save the state dictionary.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer.
"""
lora_layers = {}
lora_metadata = {}

if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers`.")

cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)

def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.

Args:
components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing.

Example:
```py
from mindone.diffusers import Kandinsky5T2VPipeline

pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
pipeline.load_lora_weights("path/to/lora.safetensors")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of [`pipe.fuse_lora()`].

Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
super().unfuse_lora(components=components, **kwargs)


class WanLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
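Taken together, the mixin gives Kandinsky pipelines the same LoRA workflow as the other video pipelines: inspect with `lora_state_dict`, load with `load_lora_weights`, then fuse and unfuse around inference. An end-to-end sketch; the LoRA path and adapter name are placeholders, and the repository id is the one used in the `fuse_lora` docstring above:

```py
from mindone.diffusers import Kandinsky5T2VPipeline

pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")

# Optional: inspect the checkpoint first. Every key must contain "lora",
# otherwise load_lora_weights raises "Invalid LoRA checkpoint."
state_dict = pipe.lora_state_dict("path/to/lora.safetensors")  # placeholder path
print(f"{len(state_dict)} LoRA tensors")

# Load into the transformer, fuse for inference, then unfuse to restore the base weights.
pipe.load_lora_weights("path/to/lora.safetensors", adapter_name="kandinsky_lora")
pipe.fuse_lora(lora_scale=0.7)
# ... run inference ...
pipe.unfuse_lora()
```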
2 changes: 2 additions & 0 deletions mindone/diffusers/models/__init__.py
@@ -79,6 +79,7 @@
"transformers.transformer_hidream_image": ["HiDreamImageTransformer2DModel"],
"transformers.transformer_hunyuan_video": ["HunyuanVideoTransformer3DModel"],
"transformers.transformer_hunyuan_video_framepack": ["HunyuanVideoFramepackTransformer3DModel"],
"transformers.transformer_kandinsky": ["Kandinsky5Transformer3DModel"],
"transformers.transformer_ltx": ["LTXVideoTransformer3DModel"],
"transformers.transformer_lumina2": ["Lumina2Transformer2DModel"],
"transformers.transformer_mochi": ["MochiTransformer3DModel"],
@@ -159,6 +160,7 @@
HunyuanDiT2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
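The registry entry maps the lazy-import key `transformers.transformer_kandinsky` to `Kandinsky5Transformer3DModel`, so the denoiser can also be loaded standalone. A rough sketch; the repository id again follows the `fuse_lora` docstring example and the `transformer` subfolder name is an assumption:

```py
from mindone.diffusers import Kandinsky5Transformer3DModel

# Subfolder name assumed to follow the usual diffusers layout; not confirmed by this PR.
transformer = Kandinsky5Transformer3DModel.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V", subfolder="transformer"
)
```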