Skip to content

Commit 3295c6a

Browse files
committed
up
1 parent 8e07445 commit 3295c6a

File tree

5 files changed

+24
-32
lines changed

5 files changed

+24
-32
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from .. import __version__
2525
from ..quantizers import DiffusersAutoQuantizer
26+
from ..quantizers.quantization_config import NunchakuConfig
2627
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
2728
from ..utils.torch_utils import empty_device_cache
2829
from .single_file_utils import (
@@ -442,6 +443,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
442443
)
443444

444445
# This step is better off here than above because `diffusers_format_checkpoint` holds the keys we expect.
446+
# We can move it to a separate function as well.
445447
if quantization_config is not None:
446448
original_modules_to_not_convert = quantization_config.modules_to_not_convert or []
447449
determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert(
@@ -450,12 +452,15 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
450452
if determined_modules_to_not_convert:
451453
determined_modules_to_not_convert.extend(original_modules_to_not_convert)
452454
determined_modules_to_not_convert = list(set(determined_modules_to_not_convert))
453-
logger.info(
455+
logger.debug(
454456
f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {determined_modules_to_not_convert}."
455457
)
456-
quantization_config.modules_to_not_convert = original_modules_to_not_convert
457-
# Update the `quant_config`.
458-
hf_quantizer.quantization_config = quantization_config
458+
modified_quant_config = quantization_config.to_dict()
459+
modified_quant_config["modules_to_not_convert"] = determined_modules_to_not_convert
460+
# TODO: figure out a better way.
461+
modified_quant_config = NunchakuConfig.from_dict(modified_quant_config)
462+
setattr(hf_quantizer, "quantization_config", modified_quant_config)
463+
logger.debug("TODO")
459464

460465
# Check if `_keep_in_fp32_modules` is not None
461466
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (

src/diffusers/loaders/single_file_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2220,7 +2220,7 @@ def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs):
22202220
if k.startswith("single_transformer_blocks."):
22212221
# attention / qkv / norms
22222222
new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
2223-
new_k = new_k.replace(".out_proj.", ".attn.to_out.")
2223+
new_k = new_k.replace(".out_proj.", ".proj_out.")
22242224
new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
22252225
new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
22262226

@@ -2279,7 +2279,11 @@ def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs):
22792279
for k in model_state_dict:
22802280
if k not in new_state_dict:
22812281
# CPU device for now
2282-
new_state_dict[k] = torch.ones_like(k, device="cpu")
2282+
new_state_dict[k] = torch.ones_like(model_state_dict[k], device="cpu")
2283+
2284+
for k in new_state_dict:
2285+
if "single_transformer_blocks.0" in k and k.endswith(".weight"):
2286+
print(f"{k=}")
22832287

22842288
return new_state_dict
22852289

src/diffusers/loaders/single_file_utils_nunchaku.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
_QKV_ANCHORS_NUNCHAKU = ("to_qkv", "add_qkv_proj")
77
_ALLOWED_SUFFIXES_NUNCHAKU = {
88
"bias",
9-
"lora_down",
10-
"lora_up",
9+
"proj_down",
10+
"proj_up",
1111
"qweight",
1212
"smooth_factor",
1313
"smooth_factor_orig",
@@ -66,14 +66,16 @@ def _unpack_qkv_state_dict(
6666
'...to_q.bias', '...to_k.bias', '...to_v.bias' '...to_q.wscales', '...to_k.wscales', '...to_v.wscales'
6767
Returns a NEW dict; original is not modified.
6868
69-
Only keys with suffix in `allowed_suffixes` are processed. Keys with non-divisible-by-3 tensors raise a ValueError.
69+
Only keys with suffix in `allowed_suffixes` are processed. Keys with non-divisible-by-3 tensors raise a ValueError.
7070
"""
7171
anchors = tuple(anchors)
7272
allowed_suffixes = set(allowed_suffixes)
7373

7474
new_sd: dict = {}
75-
for k, v in state_dict.items():
75+
sd_keys = list(state_dict.keys())
76+
for k in sd_keys:
7677
m = _QKV_NUNCHAKU_REGEX.match(k)
78+
v = state_dict.pop(k)
7779
if m:
7880
suffix = m.group("suffix")
7981
if suffix not in allowed_suffixes:

src/diffusers/quantizers/auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656

5757
class DiffusersAutoQuantizer:
5858
"""
59-
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
59+
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
6060
`DiffusersQuantizer` given the `QuantizationConfig`.
6161
"""
6262

src/diffusers/quantizers/quantization_config.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -791,25 +791,6 @@ def post_init(self):
791791

792792
# TODO: should there be a check for rank?
793793

794-
# Copied from diffusers.quantizers.quantization_config.BitsAndBytesConfig.to_diff_dict with BitsAndBytesConfig->NunchakuConfig
795-
def to_diff_dict(self) -> Dict[str, Any]:
796-
"""
797-
Removes all attributes from config which correspond to the default config attributes for better readability and
798-
serializes to a Python dictionary.
799-
800-
Returns:
801-
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
802-
"""
794+
def __repr__(self):
803795
config_dict = self.to_dict()
804-
805-
# get the default config dict
806-
default_config_dict = NunchakuConfig().to_dict()
807-
808-
serializable_config_dict = {}
809-
810-
# only serialize values that differ from the default config
811-
for key, value in config_dict.items():
812-
if value != default_config_dict[key]:
813-
serializable_config_dict[key] = value
814-
815-
return serializable_config_dict
796+
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

0 commit comments

Comments
 (0)