 if is_torch_available():
     import torch
 
-    if is_accelerate_available():
-        pass
-
-    if is_nunchaku_available():
-        from .utils import replace_with_nunchaku_linear
 
 logger = logging.get_logger(__name__)
 
@@ -79,13 +74,14 @@ def check_if_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ):
-        from nunchaku.models.linear import SVDQW4A4Linear
-
-        module, tensor_name = get_module_from_name(model, param_name)
-        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
-            return True
-
-        return False
+        # TODO: revisit
+        # Check if the param_name is not in self.modules_to_not_convert
+        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
+            return False
+        else:
+            # We only quantize the weight of nn.Linear
+            module, _ = get_module_from_name(model, param_name)
+            return isinstance(module, torch.nn.Linear)
 
     def create_quantized_param(
         self,
@@ -112,13 +108,32 @@ def create_quantized_param(
                 module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device))
 
         elif isinstance(module, torch.nn.Linear):
-            if tensor_name in module._parameters:
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-            if tensor_name in module._buffers:
-                module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(target_device)
-
-            new_module = SVDQW4A4Linear.from_linear(module)
-            setattr(model, param_name, new_module)
+            # TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module.
+            # But we need a utility that can take a pretrained param value and quantize it. Not sure
+            # how to do that yet.
+            # Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better
+            # way to do it?
+            is_param = tensor_name in module._parameters
+            is_buffer = tensor_name in module._buffers
+            new_module = SVDQW4A4Linear.from_linear(
+                module, precision=self.quantization_config.precision, rank=self.quantization_config.rank
+            )
+            module_name = ".".join(param_name.split(".")[:-1])
+            if "." in module_name:
+                parent_name, leaf = module_name.rsplit(".", 1)
+                parent = model.get_submodule(parent_name)
+            else:
+                parent, leaf = model, module_name
+
+            # Rebind the pretrained value on the new module.
+            # Note: this currently results in
+            # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'?
+            if is_param:
+                new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+            elif is_buffer:
+                new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+
+            setattr(parent, leaf, new_module)
 
     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
@@ -157,24 +172,25 @@ def _process_model_before_weight_loading(
         keep_in_fp32_modules: List[str] = [],
         **kwargs,
     ):
-        # TODO: deal with `device_map`
         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
 
         if not isinstance(self.modules_to_not_convert, list):
             self.modules_to_not_convert = [self.modules_to_not_convert]
 
         self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        # TODO: revisit
+        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+        #     keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+        #     self.modules_to_not_convert.extend(keys_on_cpu)
+
         # Purge `None`.
         # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
         # in case of diffusion transformer models. For language models and others alike, `lm_head`
         # and tied modules are usually kept in FP32.
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
-        model = replace_with_nunchaku_linear(
-            model,
-            modules_to_not_convert=self.modules_to_not_convert,
-            quantization_config=self.quantization_config,
-        )
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model, **kwargs):
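
A minimal, standalone sketch of the parent-lookup-and-rebind pattern used in `create_quantized_param` above. The helper name `rebind_linear` is hypothetical, and a plain `nn.Linear` stands in for the `SVDQW4A4Linear` layer that the quantizer actually builds via `SVDQW4A4Linear.from_linear`:

import torch


def rebind_linear(model: torch.nn.Module, param_name: str, new_module: torch.nn.Module) -> None:
    # Derive the module path from the fully qualified parameter name,
    # e.g. "blocks.0.attn.to_q.weight" -> "blocks.0.attn.to_q".
    module_name = ".".join(param_name.split(".")[:-1])
    if "." in module_name:
        parent_name, leaf = module_name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, leaf = model, module_name
    # Re-registering the child under the same attribute name swaps the layer in place.
    setattr(parent, leaf, new_module)


model = torch.nn.Sequential(torch.nn.Sequential(torch.nn.Linear(4, 4)))
rebind_linear(model, "0.0.weight", torch.nn.Linear(4, 4, bias=False))
print(model)  # the inner Linear is now the bias-free replacement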