@@ -18,6 +18,14 @@
 import torch
 import torch.nn.utils.parametrize as P
 import tqdm
+from compressed_tensors.modeling.attention import (
+    initialize_hooked_attention,
+    register_query_hook,
+)
+from compressed_tensors.modeling.kvcache import (
+    initialize_hooked_kv_cache,
+    register_key_hook,
+)
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -36,6 +44,7 @@
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
+from transformers import PreTrainedModel
 
 
 __all__ = ["TransformFactory", "TransformBase"]
@@ -97,12 +106,13 @@ def apply_to_model(self, model: Module, use_tqdm=True):
 
         desc = f"Applying {self.name} transforms"
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
-            self._apply_to_module(module, arg)
+            self._apply_to_module(model, module, arg)
 
-    def _apply_to_module(self, module: Module, args: TransformArgs):
+    def _apply_to_module(self, model: Module, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
 
+        :param model: model which the module belongs to
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
@@ -156,7 +166,28 @@ def output_hook(_, _input, output):
 
             module.register_forward_hook(output_hook)
 
-        # other locations such as q_attn and k_attn have not been implemented
+        # register query hook on attention
+        elif args.location == TransformLocation.Q_ATTN:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def query_hook(_, query_states):
+                return transform(query_states)
+
+            initialize_hooked_attention(model, module)
+            register_query_hook(module, query_hook)
+
+        # register key hook on kv cache
+        elif args.location == TransformLocation.K_CACHE:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook kv cache of model: {model}")
+
+            def key_hook(_, key_states):
+                return transform(key_states)
+
+            initialize_hooked_kv_cache(model, module)
+            register_key_hook(module, key_hook)
+
         else:
             raise NotImplementedError()
 
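For context, below is a minimal usage sketch (not part of this diff) of the hook API that the new Q_ATTN and K_CACHE branches build on, using only the calls visible above. The checkpoint name and the `model.model.layers[*].self_attn` module path are illustrative assumptions for a Llama-style model, and the identity hooks stand in for the factory's real `transform(...)` modules:

# Sketch: manually wire query/key hooks the way _apply_to_module now does
from compressed_tensors.modeling.attention import (
    initialize_hooked_attention,
    register_query_hook,
)
from compressed_tensors.modeling.kvcache import (
    initialize_hooked_kv_cache,
    register_key_hook,
)
from transformers import AutoModelForCausalLM

# illustrative checkpoint; any PreTrainedModel works
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

def query_hook(_module, query_states):
    # the factory returns transform(query_states) here; identity for the sketch
    return query_states

def key_hook(_module, key_states):
    # the factory returns transform(key_states) here; identity for the sketch
    return key_states

for layer in model.model.layers:  # assumed Llama-style layer layout
    attn = layer.self_attn
    # swap in the hooked attention / kv-cache implementations, then attach
    # the hooks, mirroring the Q_ATTN and K_CACHE branches above
    initialize_hooked_attention(model, attn)
    register_query_hook(attn, query_hook)
    initialize_hooked_kv_cache(model, attn)
    register_key_hook(attn, key_hook)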