@@ -15,7 +15,7 @@
 import logging
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Dict, List, Optional
+from typing import List
 from typing import OrderedDict as OrderedDictType
 from typing import Union
 

@@ -36,8 +36,6 @@
 )
 from compressed_tensors.utils.helpers import replace_module
 from compressed_tensors.utils.match import match_named_modules, match_targets
-from compressed_tensors.utils.offload import update_parameter_data
-from safetensors import safe_open
 from torch.nn import Module
 
 

@@ -149,50 +147,6 @@ def process_kv_cache_config(
     return config
 
 
-def _load_quant_args_from_mapping(
-    base_name: str, module_name: str, module: Module, mapping: Dict
-):
-    # TODO: skip update and just register here, don't do it in initialize
-    """
-    Loads quantization scale, zero point, and g_idx tensors from disk into
-    the specified module
-
-    :param base_name: quantization target, one of: weights, input_activations,
-        or output_activations
-    :param module_name: pytorch module name to look up in state_dict
-    :param module: pytorch module associated with module_name
-    :param mapping: mapping from parameter names to their file paths on disk
-    """
-    scale_name = f"{base_name}_scale"
-    zp_name = f"{base_name}_zero_point"
-    g_idx_name = f"{base_name}_g_idx"
-
-    state_dict_scale_path = mapping.get(f"{module_name}.{scale_name}", None)
-    state_dict_zp_path = mapping.get(f"{module_name}.{zp_name}", None)
-    state_dict_g_idx_path = mapping.get(f"{module_name}.{g_idx_name}", None)
-
-    if state_dict_g_idx_path is not None:
-        with safe_open(state_dict_g_idx_path, framework="pt", device="cpu") as f:
-            state_dict_g_idx = f.get_tensor(f"{module_name}.{g_idx_name}")
-
-        update_parameter_data(module, state_dict_g_idx, g_idx_name)
-
-    if state_dict_scale_path is not None:
-        # module is quantized
-        with safe_open(state_dict_scale_path, framework="pt", device="cpu") as f:
-            state_dict_scale = f.get_tensor(f"{module_name}.{scale_name}")
-
-        update_parameter_data(module, state_dict_scale, scale_name)
-
-        if state_dict_zp_path is None:
-            # fill in zero point for symmetric quantization
-            state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
-        else:
-            with safe_open(state_dict_zp_path, framework="pt", device="cpu") as f:
-                state_dict_zp = f.get_tensor(f"{module_name}.{zp_name}")
-
-        update_parameter_data(module, state_dict_zp, zp_name)
-
-
 def _scheme_from_targets(
     target_to_scheme: OrderedDictType[str, QuantizationScheme],
     targets: List[str],
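
For reference, the deleted helper follows the standard safetensors random-access pattern: open a shard, fetch a single named tensor, and fall back to zeros when a symmetric scheme stores no zero point. A minimal, self-contained sketch of that pattern, assuming nothing beyond the safetensors and torch APIs (the file path, tensor names, and the load_quant_param helper are illustrative placeholders, not part of this repository):

import torch
from typing import Optional
from safetensors import safe_open

def load_quant_param(path: Optional[str], tensor_name: str, like=None):
    # No path recorded for this parameter: mirror the symmetric-quantization
    # fallback by materializing zeros shaped like a reference tensor.
    if path is None:
        return torch.zeros_like(like, device="cpu")
    # safe_open lazily reads only the requested tensor from the shard.
    with safe_open(path, framework="pt", device="cpu") as f:
        return f.get_tensor(tensor_name)

# Usage: load a weight scale from a shard, then default its zero point.
scale = load_quant_param(
    "shard-00001.safetensors", "model.layers.0.mlp.down_proj.weight_scale"
)
zero_point = load_quant_param(None, "unused", like=scale)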