Commit 4af534b
Commit message: "up"
1 parent df58c80

5 files changed: +200 -66 lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 32 additions & 2 deletions
@@ -190,6 +190,23 @@ def _get_mapping_function_kwargs(mapping_fn, **kwargs):
     return mapping_kwargs
 
 
+def _maybe_determine_modules_to_not_convert(quantization_config, state_dict):
+    if quantization_config is None:
+        return None
+    else:
+        is_nunchaku = quantization_config.quant_method == "nunchaku"
+        if not is_nunchaku:
+            return None
+        else:
+            no_qweight = set()
+            for key in state_dict:
+                if key.endswith(".weight"):
+                    # module name is everything except the last piece after "."
+                    module_name = ".".join(key.split(".")[:-1])
+                    no_qweight.add(module_name)
+            return sorted(no_qweight)
+
+
 class FromOriginalModelMixin:
     """
     Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
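To make the new helper concrete, here is a small standalone trace of the same loop on a toy nunchaku-style state dict; the key names are invented for illustration and are not taken from a real checkpoint. Modules whose parameters still end in ".weight" (rather than ".qweight") were left unquantized, so they are the ones to exclude from conversion. The next hunk wires this helper into `from_single_file`.

import torch  # not needed for the trace itself, shown for parity with the real loader

state_dict = {
    "transformer_blocks.0.attn.to_q.qweight": None,
    "transformer_blocks.0.attn.to_q.wscales": None,
    "transformer_blocks.0.norm1.linear.weight": None,
    "proj_out.weight": None,
    "proj_out.bias": None,
}

no_qweight = set()
for key in state_dict:
    if key.endswith(".weight"):
        # module name is everything except the last piece after "."
        no_qweight.add(".".join(key.split(".")[:-1]))

print(sorted(no_qweight))
# ['proj_out', 'transformer_blocks.0.norm1.linear']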
@@ -324,6 +341,18 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 user_agent=user_agent,
             )
         if quantization_config is not None:
+            # For `nunchaku` checkpoints, we might want to determine the `modules_to_not_convert`.
+            original_modules_to_not_convert = quantization_config.modules_to_not_convert or []
+            determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert(
+                quantization_config, checkpoint
+            )
+            if determined_modules_to_not_convert:
+                original_modules_to_not_convert.extend(determined_modules_to_not_convert)
+                original_modules_to_not_convert = list(set(original_modules_to_not_convert))
+                logger.info(
+                    f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {original_modules_to_not_convert}."
+                )
+                quantization_config.modules_to_not_convert = original_modules_to_not_convert
             hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
             hf_quantizer.validate_environment()
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)

@@ -404,8 +433,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             model = cls.from_config(diffusers_model_config)
 
         checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-
-        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
+        if not (
+            quantization_config is not None and quantization_config.quant_method == "nunchaku"
+        ) and _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
             diffusers_format_checkpoint = checkpoint_mapping_fn(
                 config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
             )
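With the plumbing above, loading a pre-quantized nunchaku single-file checkpoint would look roughly like the following minimal sketch. The checkpoint path and the choice of `FluxTransformer2DModel` are placeholders, not part of this commit; only the `quantization_config` handling is what the diff introduces, and the `NunchakuConfig` import path follows the file edited at the bottom of this commit. Running it requires a supported CUDA GPU because of the `post_init` hardware check.

import torch
from diffusers import FluxTransformer2DModel
from diffusers.quantizers.quantization_config import NunchakuConfig

# Hypothetical pre-quantized nunchaku single-file checkpoint; the path is a placeholder.
ckpt_path = "path/to/svdq-int4-transformer.safetensors"

quant_config = NunchakuConfig(weight_dtype="int4", rank=32)
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)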

src/diffusers/models/model_loading_utils.py

Lines changed: 7 additions & 0 deletions
@@ -297,6 +297,13 @@ def load_model_dict_into_meta(
             offload_index = offload_weight(param, param_name, offload_folder, offload_index)
         elif param_device == "cpu" and state_dict_index is not None:
             state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
+        # This check below might be a bit counter-intuitive in nature. This is because we're checking if the param
+        # or its module is quantized and if so, we're proceeding with creating a quantized param. This is because
+        # of the way pre-trained models are loaded. They're initialized under "meta" device, where
+        # quantization layers are first injected. Hence, for a model that is either pre-quantized or supplemented
+        # with a `quantization_config` during `from_pretrained`, we expect `check_if_quantized_param` to return True.
+        # Then depending on the quantization backend being used, we run the actual quantization step under
+        # `create_quantized_param`.
         elif is_quantized and (
             hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
         ):
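The new comment describes an ordering worth seeing end to end: the model skeleton is built on the meta device, quantized layers are injected, and only afterwards are individual parameters materialized onto the real device. Below is a standalone sketch of that flow using plain `torch` and `accelerate`, not diffusers internals; the quantizer step is only indicated in a comment.

import torch
import torch.nn as nn
from accelerate import init_empty_weights

# 1) Build the skeleton on the meta device: no memory is allocated yet.
with init_empty_weights():
    model = nn.Linear(8, 8)
print(model.weight.device)  # meta

# 2) At this point a quantizer would swap eligible nn.Linear modules for quantized ones
#    (e.g. SVDQW4A4Linear), which is why `check_if_quantized_param` can see quantized modules.

# 3) Materialize parameters one by one from the state dict onto the target device.
state_dict = {"weight": torch.randn(8, 8), "bias": torch.zeros(8)}
for name, tensor in state_dict.items():
    model._parameters[name] = nn.Parameter(tensor.to("cpu"), requires_grad=False)
print(model.weight.device)  # cpu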

src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py

Lines changed: 45 additions & 53 deletions
@@ -20,6 +20,11 @@
 if is_torch_available():
     import torch
 
+if is_accelerate_available():
+    pass
+
+if is_nunchaku_available():
+    from .utils import replace_with_nunchaku_linear
 
 logger = logging.get_logger(__name__)
 
@@ -35,10 +40,6 @@ class NunchakuQuantizer(DiffusersQuantizer):
 
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        dtype_map = {"int4": torch.int8}
-        if is_fp8_available():
-            dtype_map = {"nvfp4": torch.float8_e4m3fn}
-        self.dtype_map = dtype_map
 
     def validate_environment(self, *args, **kwargs):
         if not torch.cuda.is_available():
@@ -74,14 +75,13 @@ def check_if_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ):
-        # TODO: revisit
-        # Check if the param_name is not in self.modules_to_not_convert
-        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
-            return False
-        else:
-            # We only quantize the weight of nn.Linear
-            module, _ = get_module_from_name(model, param_name)
-            return isinstance(module, torch.nn.Linear)
+        from nunchaku.models.linear import SVDQW4A4Linear
+
+        module, _ = get_module_from_name(model, param_name)
+        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
+            return True
+
+        return False
 
     def create_quantized_param(
         self,
@@ -98,42 +98,33 @@ def create_quantized_param(
         from nunchaku.models.linear import SVDQW4A4Linear
 
         module, tensor_name = get_module_from_name(model, param_name)
+        state_dict = args[0]
         if tensor_name not in module._parameters and tensor_name not in module._buffers:
             raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
 
-        if self.pre_quantized:
-            if tensor_name in module._parameters:
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
-            if tensor_name in module._buffers:
-                module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device))
-
-        elif isinstance(module, torch.nn.Linear):
-            # TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module.
-            # But we need to have a utility that can take a pretrained param value and quantize it. Not sure
-            # how to do that yet.
-            # Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better
-            # way to do it?
-            is_param = tensor_name in module._parameters
-            is_buffer = tensor_name in module._buffers
-            new_module = SVDQW4A4Linear.from_linear(
-                module, precision=self.quantization_config.precision, rank=self.quantization_config.rank
-            )
-            module_name = ".".join(param_name.split(".")[:-1])
-            if "." in module_name:
-                parent_name, leaf = module_name.rsplit(".", 1)
-                parent = model.get_submodule(parent_name)
+        if isinstance(module, SVDQW4A4Linear):
+            if param_value.ndim == 1:
+                module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(
+                    target_device
+                )
+            elif tensor_name == "qweight":
+                module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(
+                    target_device
+                )
+                # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
+                for t in ["lora_up", "lora_down"]:
+                    # need to check at the state dict level for this
+                    new_tensor_name = param_name.replace(".qweight", f".{t}")
+                    if new_tensor_name not in state_dict:
+                        oc, ic = param_value.shape
+                        ic = ic * 2  # v is packed into INT8, so we need to double the size
+                        module._parameters[t] = torch.zeros(
+                            (0, ic) if t == "lora_down" else (oc, 0), device=param_value.device, dtype=torch.bfloat16
+                        )
             else:
-                parent, leaf = model, module_name
-
-            # rebind
-            # this will result into
-            # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'.
-            if is_param:
-                new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-            elif is_buffer:
-                new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-
-            setattr(parent, leaf, new_module)
+                module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(
+                    target_device
+                )
 
     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
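To make the shape bookkeeping in `create_quantized_param` concrete: the packed `qweight` stores two INT4 values per int8 element along the input dimension, and when the checkpoint carries no low-rank branch, rank-0 placeholders are registered for `lora_down`/`lora_up`. A standalone illustration follows; the tensor sizes are made up. The hunk after it handles module replacement before weights are loaded.

import torch

qweight = torch.zeros(3072, 1536, dtype=torch.int8)  # hypothetical packed 4-bit weight
oc, ic = qweight.shape
ic = ic * 2  # two INT4 values per INT8 element, so the logical in_features doubles

# Rank-0 placeholders stand in for a missing low-rank branch.
lora_down = torch.zeros((0, ic), dtype=torch.bfloat16)
lora_up = torch.zeros((oc, 0), dtype=torch.bfloat16)
print(lora_down.shape, lora_up.shape)  # torch.Size([0, 3072]) torch.Size([3072, 0])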
@@ -173,24 +164,25 @@ def _process_model_before_weight_loading(
         **kwargs,
     ):
         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
-
         if not isinstance(self.modules_to_not_convert, list):
             self.modules_to_not_convert = [self.modules_to_not_convert]
-
         self.modules_to_not_convert.extend(keep_in_fp32_modules)
-
-        # TODO: revisit
-        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
-        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
-        #     keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
-        #     self.modules_to_not_convert.extend(keys_on_cpu)
-
         # Purge `None`.
         # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
         # in case of diffusion transformer models. For language models and others alike, `lm_head`
         # and tied modules are usually kept in FP32.
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
+        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+            keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+            self.modules_to_not_convert.extend(keys_on_cpu)
+
+        model = replace_with_nunchaku_linear(
+            model,
+            modules_to_not_convert=self.modules_to_not_convert,
+            quantization_config=self.quantization_config,
+        )
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model, **kwargs):
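A short toy trace of the `modules_to_not_convert` bookkeeping above: the config value is normalized to a list, extended with FP32-keep modules, purged of `None` entries, and then extended with keys the `device_map` offloads to cpu or disk. All module names below are invented for illustration.

modules_to_not_convert = None  # as it might arrive from the quantization config
if not isinstance(modules_to_not_convert, list):
    modules_to_not_convert = [modules_to_not_convert]

keep_in_fp32_modules = ["norm_out"]
modules_to_not_convert.extend(keep_in_fp32_modules)

# Purge `None`.
modules_to_not_convert = [m for m in modules_to_not_convert if m is not None]

# Extend with keys that the device_map offloads to cpu/disk.
device_map = {"transformer_blocks": 0, "proj_out": "cpu"}
if isinstance(device_map, dict) and len(device_map.keys()) > 1:
    modules_to_not_convert.extend(k for k, v in device_map.items() if v in ["disk", "cpu"])

print(modules_to_not_convert)  # ['norm_out', 'proj_out']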
src/diffusers/quantizers/nunchaku/utils.py (new file)

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+import torch.nn as nn
+
+from ...utils import is_accelerate_available, is_nunchaku_available, logging
+
+
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+
+
+logger = logging.get_logger(__name__)
+
+
+def _replace_with_nunchaku_linear(
+    model,
+    svdq_linear_cls,
+    modules_to_not_convert=None,
+    current_key_name=None,
+    quantization_config=None,
+    has_been_replaced=False,
+):
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+        current_key_name.append(name)
+
+        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+            # Check if the current key is not in the `modules_to_not_convert`
+            current_key_name_str = ".".join(current_key_name)
+            if not any(
+                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+            ):
+                with init_empty_weights():
+                    in_features = module.in_features
+                    out_features = module.out_features
+
+                    model._modules[name] = svdq_linear_cls(
+                        in_features,
+                        out_features,
+                        rank=quantization_config.rank,
+                        bias=module.bias is not None,
+                        torch_dtype=module.weight.dtype,
+                    )
+                    has_been_replaced = True
+                    # Store the module class in case we need to transpose the weight later
+                    model._modules[name].source_cls = type(module)
+                    # Force requires grad to False to avoid unexpected errors
+                    model._modules[name].requires_grad_(False)
+        if len(list(module.children())) > 0:
+            _, has_been_replaced = _replace_with_nunchaku_linear(
+                module,
+                svdq_linear_cls,
+                modules_to_not_convert,
+                current_key_name,
+                quantization_config,
+                has_been_replaced=has_been_replaced,
+            )
+        # Remove the last key for recursion
+        current_key_name.pop(-1)
+    return model, has_been_replaced
+
+
+def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+    if is_nunchaku_available():
+        from nunchaku.models.linear import SVDQW4A4Linear
+
+    model, _ = _replace_with_nunchaku_linear(
+        model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
+    )
+
+    has_been_replaced = any(
+        isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
+    )
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model in the SVDQuant method but no linear modules were found in your model."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )
+
+    return model
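The conversion skip rule in `_replace_with_nunchaku_linear` matches a module either exactly or by dotted-name prefix. A quick standalone check of that predicate; the module names are invented for illustration.

modules_to_not_convert = ["proj_out", "transformer_blocks.0.norm"]

def should_skip(current_key_name_str):
    return any(
        (key + "." in current_key_name_str) or (key == current_key_name_str)
        for key in modules_to_not_convert
    )

print(should_skip("proj_out"))                          # True: exact match
print(should_skip("transformer_blocks.0.norm.linear"))  # True: prefix match
print(should_skip("transformer_blocks.1.attn.to_q"))    # False: gets converted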

src/diffusers/quantizers/quantization_config.py

Lines changed: 36 additions & 11 deletions
@@ -733,36 +733,61 @@ class NunchakuConfig(QuantizationConfigMixin):
     loaded using `nunchaku`.
 
     Args:
-    TODO
+        TODO
         modules_to_not_convert (`list`, *optional*, default to `None`):
             The list of modules to not quantize, useful for quantizing models that explicitly require to have some
-            modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
+            modules left in their original precision (e.g. `norm` layers in Qwen-Image).
     """
 
-    group_size_map = {"int4": 64, "nvfp4": 16}
-
     def __init__(
         self,
-        precision: str = "int4",
+        method: str = "svdquant",
+        weight_dtype: str = "int4",
+        weight_scale_dtype: str = None,
+        weight_group_size: int = 64,
+        activation_dtype: str = "int4",
+        activation_scale_dtype: str = None,
+        activation_group_size: int = 64,
         rank: int = 32,
         modules_to_not_convert: Optional[List[str]] = None,
         **kwargs,
    ):
        self.quant_method = QuantizationMethod.NUNCHAKU
-        self.precision = precision
+        self.method = method
+        self.weight_dtype = weight_dtype
+        self.weight_scale_dtype = weight_scale_dtype
+        self.weight_group_size = weight_group_size
+        self.activation_dtype = activation_dtype
+        self.activation_scale_dtype = activation_scale_dtype
+        self.activation_group_size = activation_group_size
        self.rank = rank
-        self.group_size = self.group_size_map[precision]
        self.modules_to_not_convert = modules_to_not_convert
 
        self.post_init()
 
    def post_init(self):
        r"""
-        Safety checker that arguments are correct
+        Safety checker that arguments are correct. Hardware checks were largely adapted from the official `nunchaku`
+        codebase.
        """
-        accpeted_precision = ["int4", "nvfp4"]
-        if self.precision not in accpeted_precision:
-            raise ValueError(f"Only supported precision in {accpeted_precision} but found {self.precision}")
+        from ..utils.torch_utils import get_device
+
+        device = get_device()
+        if isinstance(device, str):
+            device = torch.device(device)
+        capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
+        sm = f"{capability[0]}{capability[1]}"
+        if sm == "120":  # you can only use the fp4 models
+            if self.weight_dtype != "fp4_e2m1_all":
+                raise ValueError('Please use "fp4" quantization for Blackwell GPUs.')
+        elif sm in ["75", "80", "86", "89"]:
+            if self.weight_dtype != "int4":
+                raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs.')
+        else:
+            raise ValueError(
+                f"Unsupported GPU architecture {sm} due to the lack of 4-bit tensorcores. "
+                "Please use a Turing, Ampere, Ada or Blackwell GPU for this quantization configuration."
+            )
 
     # TODO: should there be a check for rank?
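The hardware gate in `post_init` keys off the CUDA compute capability, concatenating the major and minor versions into an `sm` string: (8, 9) becomes "89" for Ada, (12, 0) becomes "120" for Blackwell. A minimal standalone check along those lines, mirroring the branches above:

import torch

if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability(0)
    sm = f"{capability[0]}{capability[1]}"
    if sm == "120":
        print(f"sm {sm}: Blackwell, use the fp4 variants")
    elif sm in ["75", "80", "86", "89"]:
        print(f"sm {sm}: Turing/Ampere/Ada, use the int4 variants")
    else:
        print(f"sm {sm}: no 4-bit tensor cores, unsupported")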
