@@ -15,7 +15,7 @@
 import logging
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Dict, Iterable, List, Optional
+from typing import Iterable, List
 from typing import OrderedDict as OrderedDictType
 from typing import Union
 
@@ -40,8 +40,6 @@
 )
 from compressed_tensors.utils.helpers import deprecated, replace_module
 from compressed_tensors.utils.match import match_named_modules, match_targets
-from compressed_tensors.utils.offload import update_parameter_data
-from safetensors import safe_open
 from torch.nn import Module
 
 
@@ -196,58 +194,6 @@ def find_name_or_class_matches(
     return match_targets(name, module, targets)
 
 
-def _infer_status(model: Module) -> Optional[QuantizationStatus]:
-    for module in model.modules():
-        status = getattr(module, "quantization_status", None)
-        if status is not None:
-            return status
-    return None
-
-
-def _load_quant_args_from_mapping(
-    base_name: str, module_name: str, module: Module, mapping: Dict
-):
-    # TODO: skip update and just register here, don't do it in initialize
-    """
-    Loads scale and zero point from a state_dict into the specified module
-
-    :param base_name: quantization target, one of: weights, input_activations or
-    output_activations
-    :param module_name: pytorch module name to look up in state_dict
-    :module: pytorch module associated with module_name
-    :mapping: mapping to search fetch paths on disk for a given parameter
-    """
-    scale_name = f"{base_name}_scale"
-    zp_name = f"{base_name}_zero_point"
-    g_idx_name = f"{base_name}_g_idx"
-
-    state_dict_scale_path = mapping.get(f"{module_name}.{scale_name}", None)
-    state_dict_zp_path = mapping.get(f"{module_name}.{zp_name}", None)
-    state_dict_g_idx_path = mapping.get(f"{module_name}.{g_idx_name}", None)
-
-    if state_dict_g_idx_path is not None:
-        with safe_open(state_dict_g_idx_path, framework="pt", device="cpu") as f:
-            state_dict_g_idx = f.get_tensor(f"{module_name}.{g_idx_name}")
-
-        update_parameter_data(module, state_dict_g_idx, g_idx_name)
-
-    if state_dict_scale_path is not None:
-        # module is quantized
-        with safe_open(state_dict_scale_path, framework="pt", device="cpu") as f:
-            state_dict_scale = f.get_tensor(f"{module_name}.{scale_name}")
-
-        update_parameter_data(module, state_dict_scale, scale_name)
-
-        if state_dict_zp_path is None:
-            # fill in zero point for symmetric quantization
-            state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
-        else:
-            with safe_open(state_dict_zp_path, framework="pt", device="cpu") as f:
-                state_dict_zp = f.get_tensor(f"{module_name}.{zp_name}")
-
-        update_parameter_data(module, state_dict_zp, zp_name)
-
-
 def _scheme_from_targets(
     target_to_scheme: OrderedDictType[str, QuantizationScheme],
     targets: List[str],