Commit 52792be

[Attention] R3 Attention Transform (#485)
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 5c62cb8 commit 52792be

7 files changed: +117 -33 lines changed
src/compressed_tensors/transform/factory/base.py

Lines changed: 34 additions & 3 deletions
@@ -18,6 +18,14 @@
 import torch
 import torch.nn.utils.parametrize as P
 import tqdm
+from compressed_tensors.modeling.attention import (
+    initialize_hooked_attention,
+    register_query_hook,
+)
+from compressed_tensors.modeling.kvcache import (
+    initialize_hooked_kv_cache,
+    register_key_hook,
+)
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -36,6 +44,7 @@
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
+from transformers import PreTrainedModel


 __all__ = ["TransformFactory", "TransformBase"]
@@ -97,12 +106,13 @@ def apply_to_model(self, model: Module, use_tqdm=True):

         desc = f"Applying {self.name} transforms"
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
-            self._apply_to_module(module, arg)
+            self._apply_to_module(model, module, arg)

-    def _apply_to_module(self, module: Module, args: TransformArgs):
+    def _apply_to_module(self, model: Module, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module

+        :param model: model which module belongs to
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
@@ -156,7 +166,28 @@ def output_hook(_, _input, output):

             module.register_forward_hook(output_hook)

-        # other locations such as q_attn and k_attn have not been implemented
+        # register query hook to attention
+        elif args.location == TransformLocation.Q_ATTN:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def query_hook(_, query_states):
+                return transform(query_states)
+
+            initialize_hooked_attention(model, module)
+            register_query_hook(module, query_hook)
+
+        # register key hook to kvcache
+        elif args.location == TransformLocation.K_CACHE:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def key_hook(_, key_states):
+                return transform(key_states)
+
+            initialize_hooked_kv_cache(model, module)
+            register_key_hook(module, key_hook)
+
         else:
             raise NotImplementedError()
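
For reference, the two new branches can also be exercised outside the factory. Below is a minimal, hedged sketch (not part of this commit) that attaches a custom query-side rotation through the same initialize_hooked_attention / register_query_hook calls used by the Q_ATTN branch; model (a loaded PreTrainedModel), its attention submodule attn, and head_dim are assumed to already exist.

import torch
from compressed_tensors.modeling.attention import (
    initialize_hooked_attention,
    register_query_hook,
)

# any orthonormal (head_dim, head_dim) matrix works; QR of a random matrix is one option
rotation, _ = torch.linalg.qr(torch.randn(head_dim, head_dim))

def query_hook(_module, query_states):
    # query_states is shaped (batch, num_heads, seq_len, head_dim)
    return query_states @ rotation.to(query_states.device, query_states.dtype)

# same two calls the Q_ATTN branch above makes
initialize_hooked_attention(model, attn)
register_query_hook(attn, query_hook)

Registering the same rotation on keys via initialize_hooked_kv_cache / register_key_hook mirrors the K_CACHE branch; because the matrix is orthonormal, attention scores are unchanged (see the new test below).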

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 0 additions & 1 deletion
@@ -51,7 +51,6 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         exec_device = get_execution_device(module)
         device = get_offloaded_device(module)

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 0 additions & 1 deletion
@@ -50,7 +50,6 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64

src/compressed_tensors/transform/transform_args.py

Lines changed: 11 additions & 4 deletions
@@ -45,6 +45,16 @@ class TransformLocation(str, Enum):
     K_CACHE = "k_cache"
     Q_ATTN = "q_attn"

+    def is_online(self) -> bool:
+        """
+        Returns True if the transform location is online
+        (applied at runtime), False otherwise
+        """
+        return self not in (
+            TransformLocation.WEIGHT_INPUT,
+            TransformLocation.WEIGHT_OUTPUT,
+        )
+

 class TransformArgs(BaseModel, use_enum_values=True):
     """
@@ -70,9 +80,6 @@ def wrap_singleton(cls, value):
         return value

     def is_online(self) -> bool:
-        return self.location not in (
-            TransformLocation.WEIGHT_INPUT,
-            TransformLocation.WEIGHT_OUTPUT,
-        )
+        return TransformLocation(self.location).is_online()

     model_config = ConfigDict(extra="forbid")
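
Moving is_online onto the enum lets callers that only have a TransformLocation (such as apply_transform_weight below) make the online/offline distinction without constructing a TransformArgs. A small usage sketch, not part of this commit, assuming both names are re-exported from compressed_tensors.transform as they are elsewhere in this diff:

from compressed_tensors.transform import TransformArgs, TransformLocation

# online locations are applied at runtime to activations or attention states
assert TransformLocation.Q_ATTN.is_online()
assert TransformLocation.K_CACHE.is_online()
# weight locations are folded into the parameters ahead of time
assert not TransformLocation.WEIGHT_INPUT.is_online()

# TransformArgs.is_online() now simply delegates to the enum
args = TransformArgs(targets="self_attn", location="q_attn")
assert args.is_online()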

src/compressed_tensors/transform/utils/matrix.py

Lines changed: 13 additions & 21 deletions
@@ -34,6 +34,8 @@ def get_transform_size(
     :param head_dim: size of head when transform is applied to mha
     :return: size of matrix
     """
+    size = None
+
     if isinstance(module, torch.nn.Linear):
         if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
             size = module.in_features
@@ -44,11 +46,13 @@
             size = module.num_embeddings
         else:
             size = module.embedding_dim
-    else:
-        raise NotImplementedError(f"Transforms on {type(module)} are not supported")
+    elif head_dim is None:
+        raise NotImplementedError(
+            f"Transforms on {type(module)} are not supported without head_dim"
+        )

     if head_dim is not None:
-        if size % head_dim != 0:
+        if size is not None and size % head_dim != 0:
             raise ValueError(
                 f"{head_dim} must divide {size} for {type(module)} at {location}"
             )
@@ -105,38 +109,26 @@ apply_transform_weight(

     assert transform_weight.shape[0] == transform_weight.shape[1]

-    if module_type == torch.nn.Linear:
-        if location == TransformLocation.INPUT:
-            return _multihead_matmul(value, transform_weight)
+    if TransformLocation(location).is_online():
+        return _multihead_matmul(value, transform_weight)

-        elif location == TransformLocation.WEIGHT_INPUT:
+    if module_type == torch.nn.Linear:
+        if location == TransformLocation.WEIGHT_INPUT:
             # equivalent to (transform_weight @ value.T).T
             return _multihead_matmul(value, transform_weight.T)

         elif location == TransformLocation.WEIGHT_OUTPUT:
             # equivalent to (value.T @ transform_weight).T
             return _multihead_matmul(transform_weight.T, value)

-        elif location == TransformLocation.OUTPUT:
-            return _multihead_matmul(value, transform_weight)
-
     # similar derivation to torch.nn.Linear, but `y = (x W)`
     elif module_type == torch.nn.Embedding:
-        if location == TransformLocation.INPUT:
-            return _multihead_matmul(value, transform_weight)
-
-        elif location == TransformLocation.WEIGHT_INPUT:
-            return _multihead_matmul(
-                transform_weight,
-                value,
-            )
+        if location == TransformLocation.WEIGHT_INPUT:
+            return _multihead_matmul(transform_weight, value)

         elif location == TransformLocation.WEIGHT_OUTPUT:
             return _multihead_matmul(value, transform_weight)

-        elif location == TransformLocation.OUTPUT:
-            return _multihead_matmul(value, transform_weight)
-
     raise NotImplementedError(
         f"Applying transforms to {module_type} {location} is not supported"
     )
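
With the enum helper, every online location now takes the single _multihead_matmul(value, transform_weight) path, and only the offline weight locations keep per-module-type handling. The "# equivalent to ..." comments retained above are plain transpose identities; a quick numerical check (not part of this commit) for the single-head, 2-D case:

import torch

value = torch.randn(8, 8)             # stand-in for a weight matrix
transform_weight = torch.randn(8, 8)  # square transform, as asserted above

# WEIGHT_INPUT: (transform_weight @ value.T).T == value @ transform_weight.T
assert torch.allclose((transform_weight @ value.T).T, value @ transform_weight.T)

# WEIGHT_OUTPUT: (value.T @ transform_weight).T == transform_weight.T @ value
assert torch.allclose((value.T @ transform_weight).T, transform_weight.T @ value)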

tests/test_transform/conftest.py

Lines changed: 21 additions & 1 deletion
@@ -62,14 +62,19 @@ def __init__(
             num_attention_heads * self.head_dim, hidden_size, bias=False
         )

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, hidden_states: torch.Tensor, past_key_values=None
+    ) -> torch.Tensor:
         batch_size, seq_len, hidden_size = hidden_states.shape
         hidden_shape = (batch_size, seq_len, -1, self.head_dim)

         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

+        if past_key_values is not None:
+            past_key_values.update(key_states, value_states, 0, {})
+
         key_states = self.repeat_kv(key_states, self.num_key_value_groups)
         value_states = self.repeat_kv(value_states, self.num_key_value_groups)

@@ -97,6 +102,21 @@ def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
         return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


+class MockAttentionModel(PreTrainedModel):
+    config_class = PretrainedConfig
+
+    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
+        super().__init__(PretrainedConfig())
+        self.self_attn = MockAttention(
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
+        )
+
+    def forward(self, x):
+        return self.self_attn(x)
+
+
 @pytest.fixture(scope="function")
 def model_apply():
     model = TransformableModel(2, 4, 8, 16, 32, 64)

tests/test_transform/factory/test_correctness.py

Lines changed: 38 additions & 2 deletions
@@ -22,7 +22,7 @@
     apply_transform_config,
 )
 from compressed_tensors.utils import offloaded_dispatch
-from tests.test_transform.conftest import MockAttention
+from tests.test_transform.conftest import MockAttention, MockAttentionModel
 from tests.testing_utils import requires_accelerate, requires_gpu


@@ -147,7 +147,7 @@ def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size

     config = TransformConfig(
         config_groups={
-            "": TransformScheme(
+            "R2": TransformScheme(
                 type=type,
                 randomize=randomize,
                 head_dim=head_dim,
@@ -166,6 +166,42 @@ def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomize", (True, False))
+@pytest.mark.parametrize("head_dim", (4, 8))
+@pytest.mark.parametrize("input_batch_size", (1, 5, 17))
+def test_correctness_query_key_locations(type, randomize, head_dim, input_batch_size):
+    hidden_size = 64
+    num_attention_heads = 8
+
+    model = MockAttentionModel(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_key_value_heads=head_dim,
+    )
+
+    input = torch.rand(input_batch_size, 5, hidden_size)
+    true_output = model(input)
+
+    config = TransformConfig(
+        config_groups={
+            "R3": TransformScheme(
+                type=type,
+                randomize=randomize,
+                head_dim=head_dim,
+                apply=[
+                    TransformArgs(targets="self_attn", location="q_attn"),
+                    TransformArgs(targets="self_attn", location="k_cache"),
+                ],
+            )
+        }
+    )
+    apply_transform_config(model, config)
+
+    output = model(input)
+    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
+
+
 @requires_gpu
 @pytest.mark.parametrize("cuda_default", (True, False))
 def test_random_matrix_device_handling(cuda_default):
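
The new q_attn/k_cache test expects the output to be unchanged because the same orthonormal rotation R is applied to both query and key states, so it cancels inside the attention scores: Q R (K R)^T = Q R R^T K^T = Q K^T. A standalone sanity check of that identity (not part of the test suite):

import torch

head_dim = 8
Q = torch.randn(2, 4, 5, head_dim)  # (batch, heads, seq_len, head_dim)
K = torch.randn(2, 4, 5, head_dim)
R, _ = torch.linalg.qr(torch.randn(head_dim, head_dim))  # any orthonormal matrix

scores = Q @ K.transpose(-1, -2)
rotated_scores = (Q @ R) @ (K @ R).transpose(-1, -2)
assert torch.allclose(scores, rotated_scores, atol=1e-5)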
