37 changes: 34 additions & 3 deletions src/compressed_tensors/transform/factory/base.py
@@ -18,6 +14,14 @@
import torch
import torch.nn.utils.parametrize as P
import tqdm
from compressed_tensors.modeling.attention import (
initialize_hooked_attention,
register_query_hook,
)
from compressed_tensors.modeling.kvcache import (
initialize_hooked_kv_cache,
register_key_hook,
)
from compressed_tensors.registry.registry import RegistryMixin, T
from compressed_tensors.transform import (
TransformArgs,
@@ -36,6 +44,7 @@
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module, Parameter
from transformers import PreTrainedModel


__all__ = ["TransformFactory", "TransformBase"]
@@ -97,12 +106,13 @@ def apply_to_model(self, model: Module, use_tqdm=True):

desc = f"Applying {self.name} transforms"
for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
self._apply_to_module(module, arg)
self._apply_to_module(model, module, arg)

def _apply_to_module(self, module: Module, args: TransformArgs):
def _apply_to_module(self, model: Module, module: Module, args: TransformArgs):
"""
Create transforms and apply them to the module

:param model: model which module belongs to
:param module: target module to apply transforms to
:param args: defines how the transform will be applied to the target module
"""
Expand Down Expand Up @@ -156,7 +166,28 @@ def output_hook(_, _input, output):

module.register_forward_hook(output_hook)

# other locations such as q_attn and k_attn have not been implemented
# register query hook to attention
elif args.location == TransformLocation.Q_ATTN:
if not isinstance(model, PreTrainedModel):
raise ValueError(f"Cannot hook attention of model: {model}")

def query_hook(_, query_states):
return transform(query_states)

initialize_hooked_attention(model, module)
register_query_hook(module, query_hook)

# register key hook to kvcache
elif args.location == TransformLocation.K_CACHE:
if not isinstance(model, PreTrainedModel):
raise ValueError(f"Cannot hook attention of model: {model}")

def key_hook(_, key_states):
return transform(key_states)

initialize_hooked_kv_cache(model, module)
register_key_hook(module, key_hook)

else:
raise NotImplementedError()

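With the new Q_ATTN and K_CACHE branches, a transform config can target attention modules directly and the factory will install the query/key hooks shown above. A minimal sketch of how this might be driven from a config, assuming `model` is a `transformers.PreTrainedModel`; the scheme name, `head_dim`, and `targets` value are illustrative and mirror the new correctness test rather than a required API.

# Illustrative sketch, not part of the diff.
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

# One scheme that rotates query states and cached key states with the same
# Hadamard transform; apply_to_model routes each TransformArgs through
# _apply_to_module, which now registers the query/key hooks for these locations.
config = TransformConfig(
    config_groups={
        "r3": TransformScheme(
            type="hadamard",
            head_dim=64,  # illustrative head size; must divide the attention width
            apply=[
                TransformArgs(targets="self_attn", location="q_attn"),
                TransformArgs(targets="self_attn", location="k_cache"),
            ],
        )
    }
)
apply_transform_config(model, config)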
1 change: 0 additions & 1 deletion src/compressed_tensors/transform/factory/hadamard.py
@@ -51,7 +51,6 @@ def create_transform(self, module: Module, args: TransformArgs):
:param module: parent module that transform will be applied to
:param args: defines how the transform will be applied to the module
"""
assert hasattr(module, "weight")
size = get_transform_size(module, args.location, self.scheme.head_dim)
exec_device = get_execution_device(module)
device = get_offloaded_device(module)
1 change: 0 additions & 1 deletion (second factory file; name not captured in this view)
@@ -50,7 +50,6 @@ def create_transform(self, module: Module, args: TransformArgs):
:param module: parent module that transform will be applied to
:param args: defines how the transform will be applied to the module
"""
assert hasattr(module, "weight")
size = get_transform_size(module, args.location, self.scheme.head_dim)
device = get_offloaded_device(module)
precision = self.scheme.precision if args.is_online() else torch.float64
15 changes: 11 additions & 4 deletions src/compressed_tensors/transform/transform_args.py
@@ -45,6 +45,16 @@ class TransformLocation(str, Enum):
K_CACHE = "k_cache"
Q_ATTN = "q_attn"

def is_online(self) -> bool:
"""
Returns True if the transform location is online
(applied at runtime), False otherwise
"""
return self not in (
TransformLocation.WEIGHT_INPUT,
TransformLocation.WEIGHT_OUTPUT,
)


class TransformArgs(BaseModel, use_enum_values=True):
"""
@@ -70,9 +80,6 @@ def wrap_singleton(cls, value):
return value

def is_online(self) -> bool:
return self.location not in (
TransformLocation.WEIGHT_INPUT,
TransformLocation.WEIGHT_OUTPUT,
)
return TransformLocation(self.location).is_online()

model_config = ConfigDict(extra="forbid")
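The online/offline distinction now lives on the enum itself, so callers can ask a bare TransformLocation whether it runs at inference time. A minimal sketch of the resulting behavior; the assert lines are illustrative and not part of the PR.

# Illustrative sketch, not part of the diff.
from compressed_tensors.transform import TransformArgs, TransformLocation

# Offline locations are fused into the weights ahead of time; everything else
# (input, output, q_attn, k_cache) is applied at runtime.
assert not TransformLocation.WEIGHT_INPUT.is_online()
assert TransformLocation.K_CACHE.is_online()

# TransformArgs.is_online delegates to the enum. Because the model uses
# use_enum_values=True, self.location is stored as a plain string, hence the
# TransformLocation(self.location) cast in the new implementation.
args = TransformArgs(targets="self_attn", location="q_attn")
assert args.is_online()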
34 changes: 13 additions & 21 deletions src/compressed_tensors/transform/utils/matrix.py
@@ -34,6 +34,8 @@ def get_transform_size(
:param head_dim: size of head when transform is applied to mha
:return: size of matrix
"""
size = None

if isinstance(module, torch.nn.Linear):
if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
size = module.in_features
@@ -44,11 +46,13 @@
size = module.num_embeddings
else:
size = module.embedding_dim
else:
raise NotImplementedError(f"Transforms on {type(module)} are not supported")
elif head_dim is None:
raise NotImplementedError(
f"Transforms on {type(module)} are not supported without head_dim"
)

if head_dim is not None:
if size % head_dim != 0:
if size is not None and size % head_dim != 0:
raise ValueError(
f"{head_dim} must divide {size} for {type(module)} at {location}"
)
@@ -105,38 +109,26 @@ def apply_transform_weight(

assert transform_weight.shape[0] == transform_weight.shape[1]

if module_type == torch.nn.Linear:
if location == TransformLocation.INPUT:
return _multihead_matmul(value, transform_weight)
if TransformLocation(location).is_online():
return _multihead_matmul(value, transform_weight)

elif location == TransformLocation.WEIGHT_INPUT:
if module_type == torch.nn.Linear:
if location == TransformLocation.WEIGHT_INPUT:
# equivalent to (transform_weight @ value.T).T
return _multihead_matmul(value, transform_weight.T)

elif location == TransformLocation.WEIGHT_OUTPUT:
# equivalent to (value.T @ transform_weight).T
return _multihead_matmul(transform_weight.T, value)

elif location == TransformLocation.OUTPUT:
return _multihead_matmul(value, transform_weight)

# similar derivation to torch.nn.Linear, but `y = (x W)`
elif module_type == torch.nn.Embedding:
if location == TransformLocation.INPUT:
return _multihead_matmul(value, transform_weight)

elif location == TransformLocation.WEIGHT_INPUT:
return _multihead_matmul(
transform_weight,
value,
)
if location == TransformLocation.WEIGHT_INPUT:
return _multihead_matmul(transform_weight, value)

elif location == TransformLocation.WEIGHT_OUTPUT:
return _multihead_matmul(value, transform_weight)

elif location == TransformLocation.OUTPUT:
return _multihead_matmul(value, transform_weight)

raise NotImplementedError(
f"Applying transforms to {module_type} {location} is not supported"
)
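The refactor folds every online location into a single value @ transform_weight path and keeps separate fused-weight cases per module type. A small numerical check of the Linear WEIGHT_INPUT identity referenced in the comment above; this is a sketch assuming a plain nn.Linear with a single head, not code from the PR.

# Illustrative sketch, not part of the diff.
import torch

x = torch.randn(2, 8)
linear = torch.nn.Linear(8, 8, bias=False)
r = torch.linalg.qr(torch.randn(8, 8)).Q  # arbitrary orthogonal transform

# Applying the transform online at the input ...
online = linear(x @ r)
# ... matches fusing it into the weight as value @ transform_weight.T,
# i.e. the (transform_weight @ value.T).T form noted in the diff comment.
fused = x @ (linear.weight @ r.T).T
assert torch.allclose(online, fused, atol=1e-5)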
22 changes: 21 additions & 1 deletion tests/test_transform/conftest.py
@@ -62,14 +62,19 @@ def __init__(
num_attention_heads * self.head_dim, hidden_size, bias=False
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self, hidden_states: torch.Tensor, past_key_values=None
) -> torch.Tensor:
batch_size, seq_len, hidden_size = hidden_states.shape
hidden_shape = (batch_size, seq_len, -1, self.head_dim)

query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

if past_key_values is not None:
past_key_values.update(key_states, value_states, 0, {})

key_states = self.repeat_kv(key_states, self.num_key_value_groups)
value_states = self.repeat_kv(value_states, self.num_key_value_groups)

@@ -97,6 +102,21 @@ def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class MockAttentionModel(PreTrainedModel):
config_class = PretrainedConfig

def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
super().__init__(PretrainedConfig())
self.self_attn = MockAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
)

def forward(self, x):
return self.self_attn(x)


@pytest.fixture(scope="function")
def model_apply():
model = TransformableModel(2, 4, 8, 16, 32, 64)
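Because MockAttention.forward now accepts past_key_values and writes into it, the mock can exercise the K_CACHE path against a real cache object. A sketch under the assumptions that transformers' DynamicCache is available and that MockAttention takes the same constructor arguments MockAttentionModel passes.

# Illustrative sketch, not part of the diff.
import torch
from transformers import DynamicCache

attn = MockAttention(hidden_size=64, num_attention_heads=8, num_key_value_heads=4)
cache = DynamicCache()

# forward() calls cache.update(key_states, value_states, 0, {}) before repeating KV,
# so the cache ends up holding the (possibly transformed) key states for layer 0.
out = attn(torch.rand(1, 5, 64), past_key_values=cache)
assert cache.get_seq_length(0) == 5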
40 changes: 38 additions & 2 deletions tests/test_transform/factory/test_correctness.py
@@ -22,7 +22,7 @@
apply_transform_config,
)
from compressed_tensors.utils import offloaded_dispatch
from tests.test_transform.conftest import MockAttention
from tests.test_transform.conftest import MockAttention, MockAttentionModel
from tests.testing_utils import requires_accelerate, requires_gpu


@@ -147,7 +147,7 @@ def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size

config = TransformConfig(
config_groups={
"": TransformScheme(
"R2": TransformScheme(
type=type,
randomize=randomize,
head_dim=head_dim,
@@ -166,6 +166,42 @@
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
@pytest.mark.parametrize("randomize", (True, False))
@pytest.mark.parametrize("head_dim", (4, 8))
@pytest.mark.parametrize("input_batch_size", (1, 5, 17))
def test_correctness_query_key_locations(type, randomize, head_dim, input_batch_size):
hidden_size = 64
num_attention_heads = 8

model = MockAttentionModel(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=head_dim,
)

input = torch.rand(input_batch_size, 5, hidden_size)
true_output = model(input)

config = TransformConfig(
config_groups={
"R3": TransformScheme(
type=type,
randomize=randomize,
head_dim=head_dim,
apply=[
TransformArgs(targets="self_attn", location="q_attn"),
TransformArgs(targets="self_attn", location="k_cache"),
],
)
}
)
apply_transform_config(model, config)

output = model(input)
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


@requires_gpu
@pytest.mark.parametrize("cuda_default", (True, False))
def test_random_matrix_device_handling(cuda_default):