From 440212113905eac9a5ed36530740b0c0a204b4ec Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Fri, 7 Feb 2025 20:19:08 +0000
Subject: [PATCH] implement sequential pipeline hack

Signed-off-by: Kyle Sayers
---
 .../modifiers/obcq/sgpt_mixin.py              |  7 +++---
 .../modifiers/pruning/wanda/base.py           |  2 +-
 src/llmcompressor/pipelines/basic/pipeline.py |  6 ++---
 .../pipelines/layer_sequential/helpers.py     | 23 ++++++++++++++++++-
 .../pipelines/layer_sequential/pipeline.py    |  7 +++++-
 .../oneshot_configs/recipes/recipe.yaml       |  1 -
 .../oneshot_configs/tiny_stories_conf1.yaml   |  1 -
 7 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py
index 721ef3fd4..6d002cb29 100644
--- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py
+++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py
@@ -20,6 +20,7 @@
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,
+    get_prunable_layers,
     match_layers_params,
 )
 
@@ -129,7 +130,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
             )
 
         # register hooks
-        #target_modules = match_layers_params(self.targets, model)
+        target_modules = match_layers_params(self.targets, model)
         for index, (layer_name, layer) in enumerate(layers.items()):
             if isinstance(self.sparsity, dict):
                 layer_sparsity = self.sparsity[layer_name]
@@ -139,9 +140,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
                 layer_sparsity = self.sparsity
 
             # TODO: match module or param
-            for name, module in layer.named_modules(prefix=layer_name):
+            for name, module in get_prunable_layers(layer).items():
                 if module in target_modules.values():
-                    self._module_names[module] = name
+                    self._module_names[module] = f"{layer_name}.{name}"
                     self._module_sparsities[module] = layer_sparsity
                     self.register_hook(module, self.calibrate_module, "forward")
 
diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py
index 291e5ae48..f4ef5e224 100644
--- a/src/llmcompressor/modifiers/pruning/wanda/base.py
+++ b/src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -7,7 +7,7 @@
     update_offload_parameter,
 )
 from loguru import logger
-from pydantic import PrivateAttr, Field
+from pydantic import Field, PrivateAttr
 
 from llmcompressor.core import State
 from llmcompressor.modifiers import Modifier
diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py
index 61d6e28ce..13a1c9454 100644
--- a/src/llmcompressor/pipelines/basic/pipeline.py
+++ b/src/llmcompressor/pipelines/basic/pipeline.py
@@ -40,6 +40,6 @@ def run_pipeline(
             batch = tensors_to_device(batch, model_device)
             model(**batch)
 
-        # TODO: replace with a lifecycle event
-        if callback_modifier:
-            callback_modifier.on_sequential_batch_end()
+    # TODO: replace with a lifecycle event
+    if callback_modifier:
+        callback_modifier.on_sequential_batch_end()
diff --git a/src/llmcompressor/pipelines/layer_sequential/helpers.py b/src/llmcompressor/pipelines/layer_sequential/helpers.py
index 06e7e5b3b..2bf943fcc 100644
--- a/src/llmcompressor/pipelines/layer_sequential/helpers.py
+++ b/src/llmcompressor/pipelines/layer_sequential/helpers.py
@@ -15,7 +15,12 @@
 from llmcompressor.pytorch.utils.helpers import tensors_to_device
 from llmcompressor.utils.helpers import calibration_forward_context
 
-__all__ = ["match_modules", "capture_first_layer_intermediates", "to_next_layer_kwargs"]
+__all__ = [
+    "match_modules",
+    "capture_first_layer_intermediates",
+    "to_next_layer_kwargs",
+    "maybe_inject_pos_embeddings",
+]
 
 
 def match_modules(model: Module, target_names: List[str]) -> List[Module]:
@@ -126,3 +131,19 @@ class EarlyStopException(Exception):
 
     _args: Tuple[Any, ...]
     _kwargs: Dict[str, Any]
+
+
+def maybe_inject_pos_embeddings(
+    output: Dict[str, Any],
+    next_layer: Module,
+    inputs: Dict[str, Any],
+) -> Dict[str, Any]:
+    signature = inspect.signature(next_layer.forward)
+    if (
+        "position_embeddings" in signature.parameters.keys()
+        and "position_embeddings" in inputs
+        and "position_embeddings" not in output
+    ):
+        output["position_embeddings"] = inputs["position_embeddings"]
+
+    return output
diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
index cef100e2f..c0ae0b620 100644
--- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -9,6 +9,7 @@
 from llmcompressor.pipelines.layer_sequential.helpers import (
     capture_first_layer_intermediates,
     match_modules,
+    maybe_inject_pos_embeddings,
     to_next_layer_kwargs,
 )
 from llmcompressor.utils.helpers import calibration_forward_context
@@ -79,6 +80,10 @@ def run_pipeline(
                 output = layer(**inputs)
 
                 if layer_index < num_layers - 1:
-                    output = to_next_layer_kwargs(output, layers[layer_index + 1])
+                    next_layer = layers[layer_index + 1]
+                    output = to_next_layer_kwargs(output, next_layer)
+                    # HACK: accommodate models which pass position embeddings
+                    output = maybe_inject_pos_embeddings(output, next_layer, inputs)
+
                     intermediates.delete(batch_index)
                     intermediates.update(batch_index, output)
diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml
index c5bf782d4..54239b3b4 100644
--- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml
+++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml
@@ -3,7 +3,6 @@ test_stage:
     SparseGPTModifier:
       sparsity: 0.5
       block_size: 128
-      sequential_update: False
      targets: [
        're:model.layers.3.mlp.gate_proj.weight'
      ]
\ No newline at end of file
diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml
index 39f9d6576..7b795ba8e 100644
--- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml
+++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml
@@ -9,7 +9,6 @@ recipe: |
       SparseGPTModifier:
         sparsity: 0.5
         block_size: 128
-        sequential_update: False
        targets: [
          're:model.layers.3.mlp.gate_proj.weight'
        ]
\ No newline at end of file