
Commit 8e07445

up
1 parent 4af534b commit 8e07445

4 files changed: +239 -48 lines changed


src/diffusers/loaders/single_file_model.py

Lines changed: 32 additions & 15 deletions
@@ -42,6 +42,7 @@
     convert_ltx_vae_checkpoint_to_diffusers,
     convert_lumina2_to_diffusers,
     convert_mochi_transformer_checkpoint_to_diffusers,
+    convert_nunchaku_flux_to_diffusers,
     convert_sana_transformer_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
@@ -341,18 +342,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             user_agent=user_agent,
         )
         if quantization_config is not None:
-            # For `nunchaku` checkpoints, we might want to determine the `modules_to_not_convert`.
-            original_modules_to_not_convert = quantization_config.modules_to_not_convert or []
-            determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert(
-                quantization_config, checkpoint
-            )
-            if determined_modules_to_not_convert:
-                original_modules_to_not_convert.extend(determined_modules_to_not_convert)
-                original_modules_to_not_convert = list(set(original_modules_to_not_convert))
-                logger.info(
-                    f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {original_modules_to_not_convert}."
-                )
-                quantization_config.modules_to_not_convert = original_modules_to_not_convert
             hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
             hf_quantizer.validate_environment()
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
@@ -433,9 +422,14 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             model = cls.from_config(diffusers_model_config)

         checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-        if not (
-            quantization_config is not None and quantization_config.quant_method == "nunchaku"
-        ) and _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
+        model_state_dict = model.state_dict()
+        # TODO: Only flux nunchaku checkpoint for now. Unify with how checkpoint mappers are done.
+        # For `nunchaku` checkpoints, we might want to determine the `modules_to_not_convert`.
+        if quantization_config is not None and quantization_config.quant_method == "nunchaku":
+            diffusers_format_checkpoint = convert_nunchaku_flux_to_diffusers(
+                checkpoint, model_state_dict=model_state_dict
+            )
+        elif _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
             diffusers_format_checkpoint = checkpoint_mapping_fn(
                 config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
             )
@@ -446,6 +440,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             raise SingleFileComponentError(
                 f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
             )
+
+        # This step is better off here than above because `diffusers_format_checkpoint` holds the keys we expect.
+        if quantization_config is not None:
+            original_modules_to_not_convert = quantization_config.modules_to_not_convert or []
+            determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert(
+                quantization_config, checkpoint
+            )
+            if determined_modules_to_not_convert:
+                determined_modules_to_not_convert.extend(original_modules_to_not_convert)
+                determined_modules_to_not_convert = list(set(determined_modules_to_not_convert))
+                logger.info(
+                    f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {determined_modules_to_not_convert}."
+                )
+            quantization_config.modules_to_not_convert = original_modules_to_not_convert
+            # Update the `quant_config`.
+            hf_quantizer.quantization_config = quantization_config
+
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -473,6 +484,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             unexpected_keys = [
                 param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
             ]
+            for k in unexpected_keys:
+                if "single_transformer_blocks.0" in k:
+                    print(f"Unexpected {k=}")
+            for k in empty_state_dict:
+                if "single_transformer_blocks.0" in k:
+                    print(f"model {k=}")
             device_map = {"": param_device}
             load_model_dict_into_meta(
                 model,
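
The hunks above move the `modules_to_not_convert` bookkeeping to after checkpoint conversion (per the in-diff comment, so that `diffusers_format_checkpoint` holds the expected keys) and route nunchaku checkpoints through `convert_nunchaku_flux_to_diffusers` instead of the generic `checkpoint_mapping_fn`. As a rough, self-contained sketch of the merge the relocated block performs (the function name and logger wiring below are ours, not part of the commit):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def merge_modules_to_not_convert(original, determined):
    # Mirror the relocated block: user-provided entries are kept, checkpoint-derived
    # entries are folded in, and duplicates are dropped via a set round-trip
    # (so the resulting order is unspecified).
    original = original or []
    if not determined:
        return original
    merged = list(set(determined + original))
    logger.info("`modules_to_not_convert` updated from %s to %s", original, merged)
    return merged


# e.g. the user excludes `proj_out`, and checkpoint inspection additionally flags the embedders
print(merge_modules_to_not_convert(["proj_out"], ["x_embedder", "context_embedder"]))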

src/diffusers/loaders/single_file_utils.py

Lines changed: 95 additions & 0 deletions
@@ -2189,6 +2189,101 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
     return converted_state_dict


+# Adapted from https://github.com/nunchaku-tech/nunchaku/blob/3ec299f439f9986a69ded320798cab4e258c871d/nunchaku/models/transformers/transformer_flux_v2.py#L395
+def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs):
+    from .single_file_utils_nunchaku import _unpack_qkv_state_dict
+
+    _SMOOTH_ORIG_RE = re.compile(r"\.smooth_orig(\.|$)")
+    _SMOOTH_RE = re.compile(r"\.smooth(\.|$)")
+
+    new_state_dict = {}
+    model_state_dict = kwargs["model_state_dict"]
+
+    ckpt_keys = list(checkpoint.keys())
+    for k in ckpt_keys:
+        if "qweight" in k:
+            # only the shape information of this tensor is needed
+            v = checkpoint[k]
+            # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
+            for t in ["lora_up", "lora_down"]:
+                new_k = k.replace(".qweight", f".{t}")
+                if new_k not in ckpt_keys:
+                    oc, ic = v.shape
+                    ic = ic * 2  # v is packed into INT8, so we need to double the size
+                    checkpoint[k.replace(".qweight", f".{t}")] = torch.zeros(
+                        (0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
+                    )
+
+    for k, v in checkpoint.items():
+        new_k = k  # start with original, then apply independent replacements
+
+        if k.startswith("single_transformer_blocks."):
+            # attention / qkv / norms
+            new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
+            new_k = new_k.replace(".out_proj.", ".attn.to_out.")
+            new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
+            new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
+
+            # mlp heads
+            new_k = new_k.replace(".mlp_fc1.", ".proj_mlp.")
+            new_k = new_k.replace(".mlp_fc2.", ".proj_out.")
+
+            # smooth params (use regex to avoid substring collisions)
+            new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
+            new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)
+
+            # lora -> proj
+            new_k = new_k.replace(".lora_down", ".proj_down")
+            new_k = new_k.replace(".lora_up", ".proj_up")
+
+        elif k.startswith("transformer_blocks."):
+            # feed-forward (context & base)
+            new_k = new_k.replace(".mlp_context_fc1.", ".ff_context.net.0.proj.")
+            new_k = new_k.replace(".mlp_context_fc2.", ".ff_context.net.2.")
+            new_k = new_k.replace(".mlp_fc1.", ".ff.net.0.proj.")
+            new_k = new_k.replace(".mlp_fc2.", ".ff.net.2.")
+
+            # attention projections
+            new_k = new_k.replace(".qkv_proj_context.", ".attn.add_qkv_proj.")
+            new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
+            new_k = new_k.replace(".out_proj.", ".attn.to_out.0.")
+            new_k = new_k.replace(".out_proj_context.", ".attn.to_add_out.")
+
+            # norms
+            new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
+            new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
+            new_k = new_k.replace(".norm_added_k.", ".attn.norm_added_k.")
+            new_k = new_k.replace(".norm_added_q.", ".attn.norm_added_q.")
+
+            # smooth params
+            new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
+            new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)
+
+            # lora -> proj
+            new_k = new_k.replace(".lora_down", ".proj_down")
+            new_k = new_k.replace(".lora_up", ".proj_up")
+
+        new_state_dict[new_k] = v
+
+    new_state_dict = _unpack_qkv_state_dict(new_state_dict)
+
+    # some remnant keys need to be patched
+    new_sd_keys = list(new_state_dict.keys())
+    for k in new_sd_keys:
+        if "qweight" in k:
+            no_qweight_k = ".".join(k.split(".qweight")[:-1])
+            for unexpected_k in ["wzeros"]:
+                unexpected_k = no_qweight_k + f".{unexpected_k}"
+                if unexpected_k in new_sd_keys:
+                    _ = new_state_dict.pop(unexpected_k)
+    for k in model_state_dict:
+        if k not in new_state_dict:
+            # CPU device for now
+            new_state_dict[k] = torch.ones_like(model_state_dict[k], device="cpu")
+
+    return new_state_dict
+
+
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
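
The renaming above is mostly plain `str.replace`, but the `smooth`/`smooth_orig` suffixes are handled with anchored regexes: `\.smooth(\.|$)` only matches `.smooth` when it ends the key or is followed by another dot, so a key that was just rewritten to `.smooth_factor_orig` cannot be matched a second time. A small standalone sketch (the example keys are made up for illustration):

import re

_SMOOTH_ORIG_RE = re.compile(r"\.smooth_orig(\.|$)")
_SMOOTH_RE = re.compile(r"\.smooth(\.|$)")

# Hypothetical nunchaku-style keys; only the suffix handling matters here.
for key in ["transformer_blocks.0.qkv_proj.smooth", "transformer_blocks.0.qkv_proj.smooth_orig"]:
    new_key = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", key)
    new_key = _SMOOTH_RE.sub(r".smooth_factor\1", new_key)
    print(key, "->", new_key)

# A naive chain of key.replace(".smooth_orig", ...) followed by key.replace(".smooth", ...)
# would rewrite the freshly produced ".smooth_factor_orig" a second time, yielding
# ".smooth_factor_factor_orig"; the (\.|$) anchor is what prevents that.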

src/diffusers/loaders/single_file_utils_nunchaku.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+import re
+
+import torch
+
+
+_QKV_ANCHORS_NUNCHAKU = ("to_qkv", "add_qkv_proj")
+_ALLOWED_SUFFIXES_NUNCHAKU = {
+    "bias",
+    "lora_down",
+    "lora_up",
+    "qweight",
+    "smooth_factor",
+    "smooth_factor_orig",
+    "wscales",
+}
+
+_QKV_NUNCHAKU_REGEX = re.compile(
+    rf"^(?P<prefix>.*)\.(?:{'|'.join(map(re.escape, _QKV_ANCHORS_NUNCHAKU))})\.(?P<suffix>.+)$"
+)
+
+
+def _pick_split_dim(t: torch.Tensor, suffix: str) -> int:
+    """
+    Choose which dimension to split by 3. Heuristics:
+    - 1D -> dim 0
+    - 2D -> prefer dim=1 for 'qweight' (common layout [*, 3*out_features]),
+      otherwise prefer dim=0 (common layout [3*out_features, *]).
+    - If preferred dim isn't divisible by 3, try the other; else error.
+    """
+    shape = list(t.shape)
+    if len(shape) == 0:
+        raise ValueError("Cannot split a scalar into Q/K/V.")
+
+    if len(shape) == 1:
+        dim = 0
+        if shape[dim] % 3 == 0:
+            return dim
+        raise ValueError(f"1D tensor of length {shape[0]} not divisible by 3.")
+
+    # len(shape) >= 2
+    preferred = 1 if suffix == "qweight" else 0
+    other = 0 if preferred == 1 else 1
+
+    if shape[preferred] % 3 == 0:
+        return preferred
+    if shape[other] % 3 == 0:
+        return other
+
+    # Fall back: any dim divisible by 3
+    for d, s in enumerate(shape):
+        if s % 3 == 0:
+            return d
+
+    raise ValueError(f"None of the dims {shape} are divisible by 3 for suffix '{suffix}'.")
+
+
+def _split_qkv(t: torch.Tensor, dim: int):
+    return torch.tensor_split(t, 3, dim=dim)
+
+
+def _unpack_qkv_state_dict(
+    state_dict: dict, anchors=_QKV_ANCHORS_NUNCHAKU, allowed_suffixes=_ALLOWED_SUFFIXES_NUNCHAKU
+):
+    """
+    Convert fused QKV entries (e.g., '...to_qkv.bias', '...qkv_proj.wscales') into separate Q/K/V entries:
+    '...to_q.bias', '...to_k.bias', '...to_v.bias', '...to_q.wscales', '...to_k.wscales', '...to_v.wscales'.
+    Returns a NEW dict; original is not modified.
+
+    Only keys with suffix in `allowed_suffixes` are processed. Keys with non-divisible-by-3 tensors raise a ValueError.
+    """
+    anchors = tuple(anchors)
+    allowed_suffixes = set(allowed_suffixes)
+
+    new_sd: dict = {}
+    for k, v in state_dict.items():
+        m = _QKV_NUNCHAKU_REGEX.match(k)
+        if m:
+            suffix = m.group("suffix")
+            if suffix not in allowed_suffixes:
+                # keep as-is if it's not one of the targeted suffixes
+                new_sd[k] = v
+                continue
+
+            prefix = m.group("prefix")  # everything before .to_qkv/.qkv_proj
+            # Decide split axis
+            split_dim = _pick_split_dim(v, suffix)
+            q, k_, vv = _split_qkv(v, dim=split_dim)
+
+            # Build new keys
+            base_q = f"{prefix}.to_q.{suffix}"
+            base_k = f"{prefix}.to_k.{suffix}"
+            base_v = f"{prefix}.to_v.{suffix}"
+
+            # Write into result dict
+            new_sd[base_q] = q
+            new_sd[base_k] = k_
+            new_sd[base_v] = vv
+        else:
+            # not a fused qkv key
+            new_sd[k] = v
+
+    return new_sd
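
A quick way to sanity-check the helper above is to feed it a toy fused state dict. The snippet below assumes the new file lands at `src/diffusers/loaders/single_file_utils_nunchaku.py` (deduced from the relative import in `single_file_utils.py`); the tensor shapes and the extra `other_stat` suffix are invented for illustration:

import torch

from diffusers.loaders.single_file_utils_nunchaku import _unpack_qkv_state_dict

fused = {
    # 1D bias: split along dim 0 into three (4,) tensors
    "transformer_blocks.0.attn.to_qkv.bias": torch.arange(12.0),
    # 2D qweight: the heuristic prefers dim 1 (layout [*, 3 * out_features])
    "transformer_blocks.0.attn.to_qkv.qweight": torch.zeros(8, 12, dtype=torch.int8),
    # suffix outside the allow-list: carried over unchanged
    "transformer_blocks.0.attn.to_qkv.other_stat": torch.ones(3),
    # key without a qkv anchor: also carried over unchanged
    "transformer_blocks.0.ff.net.0.proj.bias": torch.zeros(6),
}

for name, tensor in _unpack_qkv_state_dict(fused).items():
    print(name, tuple(tensor.shape))
# to_q/to_k/to_v entries show up with shapes (4,) for the bias and (8, 4) for the
# qweight; the other two keys are passed through untouched.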

src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py

Lines changed: 10 additions & 33 deletions
@@ -2,13 +2,7 @@

 from diffusers.utils.import_utils import is_nunchaku_version

-from ...utils import (
-    get_module_from_name,
-    is_accelerate_available,
-    is_nunchaku_available,
-    is_torch_available,
-    logging,
-)
+from ...utils import get_module_from_name, is_accelerate_available, is_nunchaku_available, is_torch_available, logging
 from ...utils.torch_utils import is_fp8_available
 from ..base import DiffusersQuantizer

@@ -20,15 +14,20 @@
 if is_torch_available():
     import torch

-if is_accelerate_available():
-    pass
-
 if is_nunchaku_available():
     from .utils import replace_with_nunchaku_linear

 logger = logging.get_logger(__name__)


+KEY_MAP = {
+    "lora_down": "proj_down",
+    "lora_up": "proj_up",
+    "smooth_orig": "smooth_factor_orig",
+    "smooth": "smooth_factor",
+}
+
+
 class NunchakuQuantizer(DiffusersQuantizer):
     r"""
     Diffusers Quantizer for Nunchaku (https://github.com/nunchaku-tech/nunchaku)
@@ -98,33 +97,11 @@ def create_quantized_param(
         from nunchaku.models.linear import SVDQW4A4Linear

         module, tensor_name = get_module_from_name(model, param_name)
-        state_dict = args[0]
         if tensor_name not in module._parameters and tensor_name not in module._buffers:
             raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

         if isinstance(module, SVDQW4A4Linear):
-            if param_value.ndim == 1:
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(
-                    target_device
-                )
-            elif tensor_name == "qweight":
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(
-                    target_device
-                )
-                # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
-                for t in ["lora_up", "lora_down"]:
-                    # need to check at the state dict level for this
-                    new_tensor_name = param_name.replace(".qweight", f".{t}")
-                    if new_tensor_name not in state_dict:
-                        oc, ic = param_value.shape
-                        ic = ic * 2  # v is packed into INT8, so we need to double the size
-                        module._parameters[t] = torch.zeros(
-                            (0, ic) if t == "lora_down" else (oc, 0), device=param_value.device, dtype=torch.bfloat16
-                        )
-            else:
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(
-                    target_device
-                )
+            module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(target_device)

     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
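
With the artificial `lora_up`/`lora_down` padding now injected during checkpoint conversion (see `convert_nunchaku_flux_to_diffusers` above), `create_quantized_param` reduces to materializing the parameter on the target device. The new `KEY_MAP` is only declared in this hunk and its consumer is not shown, so the helper below is our guess at the intended use, mapping nunchaku parameter suffixes to their diffusers names:

KEY_MAP = {
    "lora_down": "proj_down",
    "lora_up": "proj_up",
    "smooth_orig": "smooth_factor_orig",
    "smooth": "smooth_factor",
}


def rename_suffix(param_name: str) -> str:
    # Try longer suffixes first so ".smooth_orig" is not shadowed by ".smooth".
    for src, dst in sorted(KEY_MAP.items(), key=lambda kv: len(kv[0]), reverse=True):
        if param_name.endswith(f".{src}"):
            return param_name[: -len(src)] + dst
    return param_name


print(rename_suffix("single_transformer_blocks.0.attn.to_qkv.smooth_orig"))
# -> single_transformer_blocks.0.attn.to_qkv.smooth_factor_orig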

0 commit comments
