@@ -18,6 +18,14 @@
 import torch
 import torch.nn.utils.parametrize as P
 import tqdm
+from compressed_tensors.modeling.attention import (
+    initialize_hooked_attention,
+    register_query_hook,
+)
+from compressed_tensors.modeling.kvcache import (
+    initialize_hooked_kv_cache,
+    register_key_hook,
+)
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -36,6 +44,7 @@
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
+from transformers import PreTrainedModel


 __all__ = ["TransformFactory", "TransformBase"]
@@ -97,12 +106,13 @@ def apply_to_model(self, model: Module, use_tqdm=True):

         desc = f"Applying {self.name} transforms"
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
-            self._apply_to_module(module, arg)
+            self._apply_to_module(model, module, arg)

-    def _apply_to_module(self, module: Module, args: TransformArgs):
+    def _apply_to_module(self, model: Module, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module

+        :param model: model which the module belongs to
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
@@ -156,7 +166,28 @@ def output_hook(_, _input, output):

             module.register_forward_hook(output_hook)

-        # other locations such as q_attn and k_attn have not been implemented
+        # register query hook to attention
+        elif args.location == TransformLocation.Q_ATTN:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def query_hook(_, query_states):
+                return transform(query_states)
+
+            initialize_hooked_attention(model, module, quantize=False)
+            register_query_hook(module, query_hook)
+
+        # register key hook to kv cache
+        elif args.location == TransformLocation.K_CACHE:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def key_hook(_, key_states):
+                return transform(key_states)
+
+            initialize_hooked_kv_cache(model, module, quantize=False)
+            register_key_hook(module, key_hook)
+
         else:
             raise NotImplementedError()

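For reference, below is a minimal sketch of how the new Q_ATTN and K_CACHE paths could be driven outside the factory, using only the hook helpers and the (module, states) -> states hook signature shown in the diff above. The checkpoint name, the model.model.layers[*].self_attn attribute path, and the identity "rotation" are illustrative assumptions, not part of this change.

import torch
from compressed_tensors.modeling.attention import (
    initialize_hooked_attention,
    register_query_hook,
)
from compressed_tensors.modeling.kvcache import (
    initialize_hooked_kv_cache,
    register_key_hook,
)
from transformers import AutoModelForCausalLM

# example checkpoint; any HF PreTrainedModel with standard attention layers should do
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# head_dim derived by the usual convention; some configs expose it directly
head_dim = model.config.hidden_size // model.config.num_attention_heads
rotation = torch.eye(head_dim)  # identity stand-in for e.g. a Hadamard rotation

for layer in model.model.layers:  # assumed attribute path for Llama-style models
    attention = layer.self_attn

    # mirrors the Q_ATTN branch: transform query states before attention
    initialize_hooked_attention(model, attention, quantize=False)
    register_query_hook(attention, lambda _module, query_states: query_states @ rotation)

    # mirrors the K_CACHE branch: transform key states before they enter the kv cache
    initialize_hooked_kv_cache(model, attention, quantize=False)
    register_key_hook(attention, lambda _module, key_states: key_states @ rotation)

As in the factory's own query_hook and key_hook, returning a tensor from the hook replaces the query/key states, which is how the transform takes effect.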
 | 