
Commit 9ead292

attention quant
Signed-off-by: Kyle Sayers <[email protected]>
1 parent: df6fd15

File tree

12 files changed: +821 −211 lines

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
# isort: off
from .kvcache import *
from .attention import *
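
The two star imports above flatten the hooked attention and kv cache APIs into a single namespace, with `# isort: off` preserving the hand-written order (`kvcache` before `attention`). A minimal sketch of the resulting import surface, assuming the package path is `compressed_tensors.modeling` (inferred from the absolute imports in the files below, not stated in this diff):

    # a sketch of the re-exported names; each comes from a submodule's __all__
    from compressed_tensors.modeling import (
        initialize_hooked_attention,  # re-exported from .attention
        initialize_hooked_kv_cache,   # re-exported from .kvcache
        register_query_hook,
        register_key_hook,
        register_value_hook,
    )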
Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional
from weakref import ref

from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


__all__ = [
    "QuantizedAttentionImpl",
    "initialize_hooked_attention",
    "register_query_hook",
    "IMPL_ATTR",
]


IMPL_ATTR = "impl"
HOOKED_ATTENTION_NAME = "ct_hooked_attention"

# name of the original attention implementation; "eager" until overwritten by
# `initialize_hooked_attention`. Module-level, so the bare references in
# `forward` and `_ct_hooked_attention` resolve to the same binding as the
# `global` statement below
_original_impl = "eager"


class QuantizedAttentionImpl(InternalModule):
    """
    QuantizedAttentionImpl module which wraps the functionality of the original
    attention implementation. Unlike the original attention function, this
    implementation is a `torch.nn.Module` which can be hooked to trigger
    transforms and calibration hooks.

    This module works by being registered as a submodule of attention modules via
    `initialize_hooked_attention`, registering a new attention implementation
    function which calls this module, then setting the model attention
    implementation to the new function. After triggering hooks and quantization,
    this module calls the original attention implementation function.

    :param config: config of the parent model
    :param attn_module: parent attention module
    """

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        self.attn_module = ref(attn_module)  # avoid circular references

    def forward(
        self,
        module: Module,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *args,
        **kwargs,
    ):
        # quantization
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled:
            query = forward_quantize(module, query, "q", quant_args)

        # original attention
        return ALL_ATTENTION_FUNCTIONS[_original_impl](
            module,
            query,
            key,
            value,
            *args,
            **kwargs,
        )


# ----- initialize ----- #


def _ct_hooked_attention(module: Module, *args, **kwargs):
    if hasattr(module, IMPL_ATTR):
        return module.impl(module, *args, **kwargs)
    else:
        return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs)


def initialize_hooked_attention(model: PreTrainedModel, module: Module):
    """
    Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
    attached to an attention module

    :param model: parent model of the attention module
    :param module: attention module to initialize
    """
    if not hasattr(module, IMPL_ATTR):
        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config, module))
    if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
        # assumes only one model at a time
        global _original_impl
        _original_impl = model.config._attn_implementation

        AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention)
        model.config._attn_implementation = HOOKED_ATTENTION_NAME

    initialize_hooked_kv_cache(model, module)


# ----- hooks ----- #


def register_query_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes post-rope query states as an argument and
    returns the modified query states or `None`

    :param module: attention module to add the hook to
    :param hook: query hook function
    :return: handle which can be used to remove the hook
    """
    impl = getattr(module, IMPL_ATTR)

    def _hook(impl: QuantizedAttentionImpl, args, kwargs):
        bound = inspect.signature(impl.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["query"])
        if value is not None:
            bound.arguments["query"] = value

        return bound.args, bound.kwargs

    return impl.register_forward_pre_hook(_hook, with_kwargs=True)
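
For context, a minimal calibration-style sketch of how this file's API might be wired up. The checkpoint path, the `self_attn` name-matching heuristic, and the max-abs observer logic are illustrative assumptions, not part of this commit:

    # a hedged usage sketch; model path and observer logic are assumptions
    import torch
    from transformers import AutoModelForCausalLM

    from compressed_tensors.modeling.attention import (
        initialize_hooked_attention,
        register_query_hook,
    )

    # stand-in checkpoint; any Llama-style causal LM with `self_attn` submodules
    model = AutoModelForCausalLM.from_pretrained("path/to/llama-style-model")

    query_absmax = {}  # running max-abs of post-rope query states, per module

    def make_observer(name):
        def observer(module, query):
            prev = query_absmax.get(name, 0.0)
            query_absmax[name] = max(prev, query.abs().max().item())
            return None  # returning None leaves the query states unmodified
        return observer

    for name, submodule in model.named_modules():
        if name.endswith("self_attn"):  # illustrative matching heuristic
            initialize_hooked_attention(model, submodule)
            register_query_hook(submodule, make_observer(name))

    with torch.no_grad():
        model(torch.tensor([[1, 2, 3]]))  # hooks fire during the forward pass

Because the hooked implementation is registered once via `AttentionInterface` and dispatched per module through the `impl` attribute, repeated `initialize_hooked_attention` calls on further modules are cheap no-ops beyond the first registration.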
Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional, Tuple
from weakref import ref

from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import Cache, PretrainedConfig, PreTrainedModel


__all__ = [
    "QuantizedKVCache",
    "initialize_hooked_kv_cache",
    "register_key_hook",
    "register_value_hook",
    "KV_CACHE_ATTR",
]


KV_CACHE_ATTR = "kv_cache"


class QuantizedKVCache(InternalModule):
    """
    QuantizedKVCache module which wraps the functionality of any existing kv
    cache. Unlike `transformers` `Cache` instances, this cache is a
    `torch.nn.Module` which can be hooked to trigger transforms and calibration
    hooks.

    This module works by being registered as a submodule of attention modules via
    `initialize_hooked_kv_cache`, then adding a hook which replaces the
    `past_key_values` kwarg with this module. This module adopts the
    functionality of the replaced cache, preserving caching features such as
    sliding window attention, etc.

    :param config: config of the parent model
    :param attn_module: parent attention module
    """

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        self.attn_module = ref(attn_module)  # avoid circular references
        self.past_key_values: Optional[Cache] = None

    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        return self(*args, **kwargs)

    def forward(
        self,
        key_states: Tensor,
        value_states: Tensor,
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        # quantization
        module = self.attn_module()
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled:
            key_states = forward_quantize(module, key_states, "k", quant_args)
            value_states = forward_quantize(module, value_states, "v", quant_args)

        # original cache
        if self.past_key_values is not None:
            ret = self.past_key_values.update(key_states, value_states, *args, **kwargs)
        else:
            ret = (key_states, value_states)

        self.past_key_values = None
        return ret


# ----- initialize ----- #


def _kv_cache_attention_hook(module: Module, args, kwargs):
    # swap the cache kwarg for this module, remembering the original cache so
    # that its `update` can be delegated to
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    _past_kv_name = (
        "past_key_values"  # transformers#39956
        if "past_key_values" in inspect.signature(module.forward).parameters
        else "past_key_value"
    )
    kv_cache.past_key_values = kwargs.get(_past_kv_name, None)
    kwargs[_past_kv_name] = kv_cache

    return args, kwargs


def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module):
    """
    Initialize a `QuantizedKVCache` instance attached to an attention module

    :param model: parent model of the attention module
    :param module: attention module to initialize
    """
    if not hasattr(module, KV_CACHE_ATTR):
        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
        module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)


# ----- hooks ----- #


def register_key_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes post-rope key states as an argument and
    returns the modified key states or `None`

    :param module: attention module to add the hook to
    :param hook: key hook function
    :return: handle which can be used to remove the hook
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["key_states"])
        if value is not None:
            bound.arguments["key_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


def register_value_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes value states as an argument and
    returns the modified value states or `None`

    :param module: attention module to add the hook to
    :param hook: value hook function
    :return: handle which can be used to remove the hook
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["value_states"])
        if value is not None:
            bound.arguments["value_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
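
A companion sketch for the kv cache hooks. The checkpoint path, the `model.model.layers[0].self_attn` module path, and the clipping threshold are illustrative assumptions; since the registered hooks run before `QuantizedKVCache.forward`, a returned tensor is what quantization (when enabled) and the wrapped cache subsequently see:

    # a hedged usage sketch; model path, module path, threshold are assumptions
    import torch
    from transformers import AutoModelForCausalLM

    from compressed_tensors.modeling.kvcache import (
        initialize_hooked_kv_cache,
        register_key_hook,
        register_value_hook,
    )

    model = AutoModelForCausalLM.from_pretrained("path/to/llama-style-model")
    attn = model.model.layers[0].self_attn  # Llama-style path; an assumption

    initialize_hooked_kv_cache(model, attn)

    def clip_keys(module, key_states):
        # returning a tensor replaces the key states downstream
        return key_states.clamp(-30.0, 30.0)  # illustrative threshold

    def observe_values(module, value_states):
        print(value_states.abs().amax().item())
        return None  # returning None observes without modifying

    key_handle = register_key_hook(attn, clip_keys)
    value_handle = register_value_hook(attn, observe_values)

    with torch.no_grad():
        model(torch.tensor([[1, 2, 3]]))  # hooks fire during the forward pass

    # hooks return RemovableHandle, so they can be detached after calibration
    key_handle.remove()
    value_handle.remove()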
