diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py
index 34d609e74..96f15c9da 100644
--- a/src/compressed_tensors/transform/factory/base.py
+++ b/src/compressed_tensors/transform/factory/base.py
@@ -18,6 +18,14 @@
 import torch
 import torch.nn.utils.parametrize as P
 import tqdm
+from compressed_tensors.modeling.attention import (
+    initialize_hooked_attention,
+    register_query_hook,
+)
+from compressed_tensors.modeling.kvcache import (
+    initialize_hooked_kv_cache,
+    register_key_hook,
+)
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -36,6 +44,7 @@
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
+from transformers import PreTrainedModel


 __all__ = ["TransformFactory", "TransformBase"]
@@ -97,12 +106,13 @@ def apply_to_model(self, model: Module, use_tqdm=True):

         desc = f"Applying {self.name} transforms"
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
-            self._apply_to_module(module, arg)
+            self._apply_to_module(model, module, arg)

-    def _apply_to_module(self, module: Module, args: TransformArgs):
+    def _apply_to_module(self, model: Module, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module

+        :param model: model which module belongs to
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
@@ -156,7 +166,28 @@ def output_hook(_, _input, output):

             module.register_forward_hook(output_hook)

-        # other locations such as q_attn and k_attn have not been implemented
+        # register query hook to attention
+        elif args.location == TransformLocation.Q_ATTN:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def query_hook(_, query_states):
+                return transform(query_states)
+
+            initialize_hooked_attention(model, module)
+            register_query_hook(module, query_hook)
+
+        # register key hook to kvcache
+        elif args.location == TransformLocation.K_CACHE:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def key_hook(_, key_states):
+                return transform(key_states)
+
+            initialize_hooked_kv_cache(model, module)
+            register_key_hook(module, key_hook)
+
         else:
             raise NotImplementedError()

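The Q_ATTN and K_CACHE branches above hand a callback to initialize_hooked_attention / initialize_hooked_kv_cache instead of wrapping a weight. Below is a minimal, self-contained sketch of that callback contract using a toy attention module; ToyAttention and its register_query_hook method are stand-ins invented for illustration and are not the compressed_tensors.modeling API.

# Stand-in demonstrating the hook contract `hook(module, query_states) -> Tensor`
# used by the Q_ATTN branch above. Only the callback shape is mirrored here.
import torch


class ToyAttention(torch.nn.Module):
    """Toy attention block that exposes its query states to registered hooks."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self._query_hooks = []

    def register_query_hook(self, hook):
        # mirrors register_query_hook(module, query_hook) from the diff above
        self._query_hooks.append(hook)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        query_states = self.q_proj(hidden_states)
        for hook in self._query_hooks:
            out = hook(self, query_states)
            if out is not None:
                query_states = out
        return query_states


attn = ToyAttention(hidden_size=8)
rotation = torch.linalg.qr(torch.randn(8, 8)).Q  # any orthogonal "transform"


def query_hook(_, query_states):
    # same shape-preserving contract as the hook registered in base.py
    return query_states @ rotation


attn.register_query_hook(query_hook)
print(attn(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])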
diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py
index a843e2728..3b78dd25e 100644
--- a/src/compressed_tensors/transform/factory/hadamard.py
+++ b/src/compressed_tensors/transform/factory/hadamard.py
@@ -51,7 +51,6 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         exec_device = get_execution_device(module)
         device = get_offloaded_device(module)
diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py
index 4f1b52762..a135dfab0 100644
--- a/src/compressed_tensors/transform/factory/matrix_multiply.py
+++ b/src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -50,7 +50,6 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64
diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py
index d3f469579..3967d4616 100644
--- a/src/compressed_tensors/transform/transform_args.py
+++ b/src/compressed_tensors/transform/transform_args.py
@@ -45,6 +45,16 @@ class TransformLocation(str, Enum):
     K_CACHE = "k_cache"
     Q_ATTN = "q_attn"

+    def is_online(self) -> bool:
+        """
+        Returns True if the transform location is online
+        (applied at runtime), False otherwise
+        """
+        return self not in (
+            TransformLocation.WEIGHT_INPUT,
+            TransformLocation.WEIGHT_OUTPUT,
+        )
+

 class TransformArgs(BaseModel, use_enum_values=True):
     """
@@ -70,9 +80,6 @@ def wrap_singleton(cls, value):
         return value

     def is_online(self) -> bool:
-        return self.location not in (
-            TransformLocation.WEIGHT_INPUT,
-            TransformLocation.WEIGHT_OUTPUT,
-        )
+        return TransformLocation(self.location).is_online()

     model_config = ConfigDict(extra="forbid")
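The new TransformLocation.is_online() helper above separates locations folded into weights ahead of time (WEIGHT_INPUT, WEIGHT_OUTPUT) from locations applied to activations at runtime, which now include the q_attn and k_cache hooks. The snippet below is a stand-alone mirror of that logic, runnable without installing compressed_tensors; the string values of the members other than k_cache and q_attn are assumptions made for illustration.

# Mirror of the is_online() split added above; not an import of the library enum.
from enum import Enum


class TransformLocation(str, Enum):
    INPUT = "input"
    WEIGHT_INPUT = "weight_input"
    WEIGHT_OUTPUT = "weight_output"
    OUTPUT = "output"
    K_CACHE = "k_cache"
    Q_ATTN = "q_attn"

    def is_online(self) -> bool:
        # only the two weight locations are folded in ahead of time;
        # everything else is applied to activations at runtime
        return self not in (
            TransformLocation.WEIGHT_INPUT,
            TransformLocation.WEIGHT_OUTPUT,
        )


assert TransformLocation("q_attn").is_online()
assert TransformLocation("k_cache").is_online()
assert not TransformLocation("weight_input").is_online()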
diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py
index 920728571..0414e3f69 100644
--- a/src/compressed_tensors/transform/utils/matrix.py
+++ b/src/compressed_tensors/transform/utils/matrix.py
@@ -34,6 +34,8 @@ def get_transform_size(
     :param head_dim: size of head when transform is applied to mha
     :return: size of matrix
     """
+    size = None
+
     if isinstance(module, torch.nn.Linear):
         if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
             size = module.in_features
@@ -44,11 +46,13 @@
             size = module.num_embeddings
         else:
             size = module.embedding_dim
-    else:
-        raise NotImplementedError(f"Transforms on {type(module)} are not supported")
+    elif head_dim is None:
+        raise NotImplementedError(
+            f"Transforms on {type(module)} are not supported without head_dim"
+        )

     if head_dim is not None:
-        if size % head_dim != 0:
+        if size is not None and size % head_dim != 0:
             raise ValueError(
                 f"{head_dim} must divide {size} for {type(module)} at {location}"
             )
@@ -105,11 +109,11 @@

     assert transform_weight.shape[0] == transform_weight.shape[1]

-    if module_type == torch.nn.Linear:
-        if location == TransformLocation.INPUT:
-            return _multihead_matmul(value, transform_weight)
+    if TransformLocation(location).is_online():
+        return _multihead_matmul(value, transform_weight)

-        elif location == TransformLocation.WEIGHT_INPUT:
+    if module_type == torch.nn.Linear:
+        if location == TransformLocation.WEIGHT_INPUT:
             # equivalent to (transform_weight @ value.T).T
             return _multihead_matmul(value, transform_weight.T)

@@ -117,26 +121,14 @@
             # equivalent to (value.T @ transform_weight).T
             return _multihead_matmul(transform_weight.T, value)

-        elif location == TransformLocation.OUTPUT:
-            return _multihead_matmul(value, transform_weight)
-
     # similar derivation to torch.nn.Linear, but `y = (x W)`
     elif module_type == torch.nn.Embedding:
-        if location == TransformLocation.INPUT:
-            return _multihead_matmul(value, transform_weight)
-
-        elif location == TransformLocation.WEIGHT_INPUT:
-            return _multihead_matmul(
-                transform_weight,
-                value,
-            )
+        if location == TransformLocation.WEIGHT_INPUT:
+            return _multihead_matmul(transform_weight, value)

         elif location == TransformLocation.WEIGHT_OUTPUT:
             return _multihead_matmul(value, transform_weight)

-        elif location == TransformLocation.OUTPUT:
-            return _multihead_matmul(value, transform_weight)
-
     raise NotImplementedError(
         f"Applying transforms to {module_type} {location} is not supported"
     )
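The reshuffled apply_transform_weight above routes every online location through a single right-multiplication of the activation and keeps only the weight-location algebra, whose equivalences are noted in the comments. A quick stand-alone check of those identities, with arbitrary small shapes chosen for illustration:

# Numeric check: value @ W.T == (W @ value.T).T (the WEIGHT_INPUT comment above),
# and online locations reduce to a plain right-multiplication of the activation.
import torch

value = torch.randn(16, 8)            # e.g. a weight slice or activation
transform_weight = torch.randn(8, 8)  # square transform, as asserted above

assert torch.allclose(
    value @ transform_weight.T,
    (transform_weight @ value.T).T,
    atol=1e-5,
)

activation = torch.randn(2, 5, 8)
out = activation @ transform_weight   # online case: shape is preserved
assert out.shape == activation.shape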
diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py
index 824c06bd3..0ab5093c6 100644
--- a/tests/test_transform/conftest.py
+++ b/tests/test_transform/conftest.py
@@ -62,7 +62,9 @@ def __init__(
             num_attention_heads * self.head_dim, hidden_size, bias=False
         )

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, hidden_states: torch.Tensor, past_key_values=None
+    ) -> torch.Tensor:
         batch_size, seq_len, hidden_size = hidden_states.shape
         hidden_shape = (batch_size, seq_len, -1, self.head_dim)

@@ -70,6 +72,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

+        if past_key_values is not None:
+            past_key_values.update(key_states, value_states, 0, {})
+
         key_states = self.repeat_kv(key_states, self.num_key_value_groups)
         value_states = self.repeat_kv(value_states, self.num_key_value_groups)

@@ -97,6 +102,21 @@ def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
         return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


+class MockAttentionModel(PreTrainedModel):
+    config_class = PretrainedConfig
+
+    def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
+        super().__init__(PretrainedConfig())
+        self.self_attn = MockAttention(
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
+        )
+
+    def forward(self, x):
+        return self.self_attn(x)
+
+
 @pytest.fixture(scope="function")
 def model_apply():
     model = TransformableModel(2, 4, 8, 16, 32, 64)
diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py
index ae4bf12e8..1fdbc3a00 100644
--- a/tests/test_transform/factory/test_correctness.py
+++ b/tests/test_transform/factory/test_correctness.py
@@ -22,7 +22,7 @@
     apply_transform_config,
 )
 from compressed_tensors.utils import offloaded_dispatch
-from tests.test_transform.conftest import MockAttention
+from tests.test_transform.conftest import MockAttention, MockAttentionModel
 from tests.testing_utils import requires_accelerate, requires_gpu


@@ -147,7 +147,7 @@ def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size

     config = TransformConfig(
         config_groups={
-            "": TransformScheme(
+            "R2": TransformScheme(
                 type=type,
                 randomize=randomize,
                 head_dim=head_dim,
@@ -166,6 +166,42 @@
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


+@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
+@pytest.mark.parametrize("randomize", (True, False))
+@pytest.mark.parametrize("head_dim", (4, 8))
+@pytest.mark.parametrize("input_batch_size", (1, 5, 17))
+def test_correctness_query_key_locations(type, randomize, head_dim, input_batch_size):
+    hidden_size = 64
+    num_attention_heads = 8
+
+    model = MockAttentionModel(
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_key_value_heads=head_dim,
+    )
+
+    input = torch.rand(input_batch_size, 5, hidden_size)
+    true_output = model(input)
+
+    config = TransformConfig(
+        config_groups={
+            "R3": TransformScheme(
+                type=type,
+                randomize=randomize,
+                head_dim=head_dim,
+                apply=[
+                    TransformArgs(targets="self_attn", location="q_attn"),
+                    TransformArgs(targets="self_attn", location="k_cache"),
+                ],
+            )
+        }
+    )
+    apply_transform_config(model, config)
+
+    output = model(input)
+    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
+
+
 @requires_gpu
 @pytest.mark.parametrize("cuda_default", (True, False))
 def test_random_matrix_device_handling(cuda_default):
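The new test_correctness_query_key_locations test can expect the hooked model to reproduce the original outputs because the same (random-)Hadamard rotation is applied to both the query states and the cached key states, and orthogonal rotations cancel inside the attention scores. A stand-alone sketch of that invariant, mirroring the test's 1e-5 tolerance, using a QR-derived orthogonal matrix as a stand-in for the Hadamard transform:

# If Q and K are both rotated by the same orthogonal R, then
# (Q R)(K R)^T = Q R R^T K^T = Q K^T, so attention scores are unchanged.
import torch

head_dim = 8
q = torch.randn(2, 4, 5, head_dim)   # (batch, heads, seq, head_dim)
k = torch.randn(2, 4, 5, head_dim)
rot = torch.linalg.qr(torch.randn(head_dim, head_dim)).Q  # orthogonal R

scores = q @ k.transpose(-1, -2)
rotated_scores = (q @ rot) @ (k @ rot).transpose(-1, -2)

assert torch.allclose(scores, rotated_scores, atol=1e-5)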