 if is_torch_available():
     import torch
 
+if is_accelerate_available():
+    pass
+
+if is_nunchaku_available():
+    from .utils import replace_with_nunchaku_linear
 
 logger = logging.get_logger(__name__)
 
@@ -35,10 +40,6 @@ class NunchakuQuantizer(DiffusersQuantizer):
 
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        dtype_map = {"int4": torch.int8}
-        if is_fp8_available():
-            dtype_map = {"nvfp4": torch.float8_e4m3fn}
-        self.dtype_map = dtype_map
 
     def validate_environment(self, *args, **kwargs):
         if not torch.cuda.is_available():
@@ -74,14 +75,13 @@ def check_if_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ):
-        # TODO: revisit
-        # Check if the param_name is not in self.modules_to_not_convert
-        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
-            return False
-        else:
-            # We only quantize the weight of nn.Linear
-            module, _ = get_module_from_name(model, param_name)
-            return isinstance(module, torch.nn.Linear)
+        from nunchaku.models.linear import SVDQW4A4Linear
+
+        module, _ = get_module_from_name(model, param_name)
+        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
+            return True
+
+        return False
 
     def create_quantized_param(
         self,
@@ -98,42 +98,33 @@ def create_quantized_param(
         from nunchaku.models.linear import SVDQW4A4Linear
 
         module, tensor_name = get_module_from_name(model, param_name)
+        state_dict = args[0]
         if tensor_name not in module._parameters and tensor_name not in module._buffers:
             raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
 
-        if self.pre_quantized:
-            if tensor_name in module._parameters:
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
-            if tensor_name in module._buffers:
-                module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device))
-
-        elif isinstance(module, torch.nn.Linear):
-            # TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module.
-            # But we need to have a utility that can take a pretrained param value and quantize it. Not sure
-            # how to do that yet.
-            # Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better
-            # way to do it?
-            is_param = tensor_name in module._parameters
-            is_buffer = tensor_name in module._buffers
-            new_module = SVDQW4A4Linear.from_linear(
-                module, precision=self.quantization_config.precision, rank=self.quantization_config.rank
-            )
-            module_name = ".".join(param_name.split(".")[:-1])
-            if "." in module_name:
-                parent_name, leaf = module_name.rsplit(".", 1)
-                parent = model.get_submodule(parent_name)
+        if isinstance(module, SVDQW4A4Linear):
+            if param_value.ndim == 1:
+                module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(
+                    target_device
+                )
+            elif tensor_name == "qweight":
+                module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(
+                    target_device
+                )
+                # If the tensor has `qweight` but does not have a low-rank branch, we need to add some artificial tensors.
+                for t in ["lora_up", "lora_down"]:
+                    # Need to check at the state-dict level for this.
+                    new_tensor_name = param_name.replace(".qweight", f".{t}")
+                    if new_tensor_name not in state_dict:
+                        oc, ic = param_value.shape
+                        ic = ic * 2  # v is packed into INT8, so we need to double the size
+                        module._parameters[t] = torch.zeros(
+                            (0, ic) if t == "lora_down" else (oc, 0), device=param_value.device, dtype=torch.bfloat16
+                        )
             else:
-                parent, leaf = model, module_name
-
-            # rebind
-            # this will result into
-            # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'.
-            if is_param:
-                new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-            elif is_buffer:
-                new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-
-            setattr(parent, leaf, new_module)
+                module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(
+                    target_device
+                )
 
     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
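Note (illustration, not part of the diff): when a pre-quantized checkpoint provides `qweight` but no low-rank branch, the `qweight` branch above registers empty placeholder tensors, with the input dimension doubled per the packing note in the code comment. A minimal sketch with made-up shapes:

```python
# Hypothetical shapes -- purely illustrative, not taken from a real checkpoint.
import torch

oc, ic_packed = 3072, 1536  # packed qweight shape: (out_features, packed in_features)
ic = ic_packed * 2          # doubled, per the packing note in the diff

# Zero-rank placeholders mirroring the `(0, ic)` / `(oc, 0)` shapes used above.
lora_down = torch.zeros((0, ic), dtype=torch.bfloat16)
lora_up = torch.zeros((oc, 0), dtype=torch.bfloat16)

print(lora_down.shape, lora_up.shape)  # torch.Size([0, 3072]) torch.Size([3072, 0])
```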
@@ -173,24 +164,25 @@ def _process_model_before_weight_loading(
         **kwargs,
     ):
         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
-
         if not isinstance(self.modules_to_not_convert, list):
             self.modules_to_not_convert = [self.modules_to_not_convert]
-
         self.modules_to_not_convert.extend(keep_in_fp32_modules)
-
-        # TODO: revisit
-        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
-        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
-        #     keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
-        #     self.modules_to_not_convert.extend(keys_on_cpu)
-
         # Purge `None`.
         # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
         # in case of diffusion transformer models. For language models and others alike, `lm_head`
         # and tied modules are usually kept in FP32.
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
+        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+            keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+            self.modules_to_not_convert.extend(keys_on_cpu)
+
+        model = replace_with_nunchaku_linear(
+            model,
+            modules_to_not_convert=self.modules_to_not_convert,
+            quantization_config=self.quantization_config,
+        )
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model, **kwargs):
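Note (usage sketch, not part of the diff): the pre-loading hook swaps eligible `nn.Linear` layers for `SVDQW4A4Linear` via `replace_with_nunchaku_linear`, and the per-parameter hooks then load the pre-quantized tensors. A rough sketch of how this might be exercised from user code; the config class name `NunchakuConfig`, its import path, the model class, and the checkpoint id are assumptions, not confirmed by this diff. Only `precision`, `rank`, and `modules_to_not_convert` appear in the code above.

```python
# Hypothetical usage -- names below are placeholders/assumptions.
import torch
from diffusers import FluxTransformer2DModel  # any supported transformer model

# Assumed config class exposing the fields referenced in this diff:
# precision, rank, modules_to_not_convert.
from diffusers.quantizers.quantization_config import NunchakuConfig

quant_config = NunchakuConfig(
    precision="int4",                     # or "nvfp4" for FP4 checkpoints
    rank=32,                              # rank of the SVD low-rank branch
    modules_to_not_convert=["proj_out"],  # layers to keep unquantized
)

model = FluxTransformer2DModel.from_pretrained(
    "org/nunchaku-quantized-checkpoint",  # placeholder repo id
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```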