implement sequential pipeline hack
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Feb 7, 2025
1 parent b9bea3c commit 4402121
Showing 7 changed files with 36 additions and 11 deletions.
7 changes: 4 additions & 3 deletions src/llmcompressor/modifiers/obcq/sgpt_mixin.py
@@ -20,6 +20,7 @@
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,
+    get_prunable_layers,
     match_layers_params,
 )

@@ -129,7 +130,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
         )
 
         # register hooks
-        #target_modules = match_layers_params(self.targets, model)
+        target_modules = match_layers_params(self.targets, model)
         for index, (layer_name, layer) in enumerate(layers.items()):
             if isinstance(self.sparsity, dict):
                 layer_sparsity = self.sparsity[layer_name]
@@ -139,9 +140,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
                 layer_sparsity = self.sparsity
 
             # TODO: match module or param
-            for name, module in layer.named_modules(prefix=layer_name):
+            for name, module in get_prunable_layers(layer).items():
                 if module in target_modules.values():
-                    self._module_names[module] = name
+                    self._module_names[module] = f"{layer_name}.{name}"
                     self._module_sparsities[module] = layer_sparsity
                     self.register_hook(module, self.calibrate_module, "forward")
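The hook-registration change above narrows matching to prunable submodules and rebuilds fully qualified module names by prefixing the parent layer's name. A toy sketch of that naming behavior, assuming a local stand-in for the library's get_prunable_layers helper and a made-up layer prefix:

import torch.nn as nn

def get_prunable_layers_sketch(layer: nn.Module) -> dict:
    # stand-in for the library helper: collect weight-bearing submodules by relative name
    return {
        name: module
        for name, module in layer.named_modules()
        if isinstance(module, (nn.Linear, nn.Conv2d))
    }

# toy decoder layer; "model.layers.0" is a made-up fully qualified prefix
decoder_layer = nn.ModuleDict({"mlp": nn.ModuleDict({"gate_proj": nn.Linear(8, 8)})})
layer_name = "model.layers.0"

for name, module in get_prunable_layers_sketch(decoder_layer).items():
    qualified = f"{layer_name}.{name}"
    print(qualified)  # model.layers.0.mlp.gate_proj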
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -7,7 +7,7 @@
     update_offload_parameter,
 )
 from loguru import logger
-from pydantic import PrivateAttr, Field
+from pydantic import Field, PrivateAttr
 
 from llmcompressor.core import State
 from llmcompressor.modifiers import Modifier
6 changes: 3 additions & 3 deletions src/llmcompressor/pipelines/basic/pipeline.py
@@ -40,6 +40,6 @@ def run_pipeline(
             batch = tensors_to_device(batch, model_device)
             model(**batch)
 
-            # TODO: replace with a lifecycle event
-            if callback_modifier:
-                callback_modifier.on_sequential_batch_end()
+        # TODO: replace with a lifecycle event
+        if callback_modifier:
+            callback_modifier.on_sequential_batch_end()
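For illustration, a minimal sketch of the callback_modifier hack this hunk adjusts: forward-pass hooks accumulate calibration statistics, and the pipeline invokes the modifier's on_sequential_batch_end() once calibration data has been seen. The ToyPruningModifier below is a made-up stand-in, not the SparseGPT/Wanda implementation:

import torch
import torch.nn as nn

class ToyPruningModifier:
    # made-up stand-in for a modifier exposing the on_sequential_batch_end() hack
    def __init__(self, module: nn.Linear):
        self.module = module
        self.num_samples = 0
        self.handle = module.register_forward_hook(self._observe)

    def _observe(self, module, args, output):
        self.num_samples += args[0].shape[0]  # accumulate calibration statistics

    def on_sequential_batch_end(self):
        # in the real modifiers this is where pruning/quantization is applied
        print(f"compressing after {self.num_samples} calibration samples")
        self.handle.remove()

model = nn.Linear(4, 4)
callback_modifier = ToyPruningModifier(model)
for _ in range(3):                           # stand-in calibration loop
    model(torch.randn(2, 4))
callback_modifier.on_sequential_batch_end()  # invoked by the pipeline after calibration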
23 changes: 22 additions & 1 deletion src/llmcompressor/pipelines/layer_sequential/helpers.py
@@ -15,7 +15,12 @@
 from llmcompressor.pytorch.utils.helpers import tensors_to_device
 from llmcompressor.utils.helpers import calibration_forward_context
 
-__all__ = ["match_modules", "capture_first_layer_intermediates", "to_next_layer_kwargs"]
+__all__ = [
+    "match_modules",
+    "capture_first_layer_intermediates",
+    "to_next_layer_kwargs",
+    "maybe_inject_pos_embeddings",
+]
 
 
 def match_modules(model: Module, target_names: List[str]) -> List[Module]:
@@ -126,3 +131,19 @@ class EarlyStopException(Exception):
 
     _args: Tuple[Any, ...]
     _kwargs: Dict[str, Any]
+
+
+def maybe_inject_pos_embeddings(
+    output: Dict[str, Any],
+    next_layer: Module,
+    inputs: Dict[str, Any],
+) -> Dict[str, Any]:
+    signature = inspect.signature(next_layer.forward)
+    if (
+        "position_embeddings" in signature.parameters.keys()
+        and "position_embeddings" in inputs
+        and "position_embeddings" not in output
+    ):
+        output["position_embeddings"] = inputs["position_embeddings"]
+
+    return output
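maybe_inject_pos_embeddings works around decoder layers that expect position embeddings as a keyword argument (for example, Llama-style models where rotary embeddings are computed once at the model level and passed to every layer) but whose outputs do not include them, so the value has to be carried forward from the layer's inputs. A self-contained sketch of the same signature-gated injection, using a made-up toy layer rather than a real model:

import inspect
from typing import Any, Dict

import torch
from torch.nn import Module

class ToyDecoderLayer(Module):
    # made-up layer whose forward accepts position embeddings as a kwarg
    def forward(self, hidden_states: torch.Tensor, position_embeddings=None) -> Dict[str, Any]:
        return {"hidden_states": hidden_states}

def inject_if_needed(output: Dict[str, Any], next_layer: Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
    # same gating idea as maybe_inject_pos_embeddings: only carry the kwarg forward
    # when the next layer asks for it and the previous layer's output lacks it
    signature = inspect.signature(next_layer.forward)
    if "position_embeddings" in signature.parameters and "position_embeddings" in inputs:
        output.setdefault("position_embeddings", inputs["position_embeddings"])
    return output

inputs = {
    "hidden_states": torch.randn(1, 4),
    "position_embeddings": (torch.ones(1, 4), torch.zeros(1, 4)),  # (cos, sin) style tuple
}
output = {"hidden_states": torch.randn(1, 4)}
output = inject_if_needed(output, ToyDecoderLayer(), inputs)
print("position_embeddings" in output)  # True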
7 changes: 6 additions & 1 deletion src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -9,6 +9,7 @@
 from llmcompressor.pipelines.layer_sequential.helpers import (
     capture_first_layer_intermediates,
     match_modules,
+    maybe_inject_pos_embeddings,
     to_next_layer_kwargs,
 )
 from llmcompressor.utils.helpers import calibration_forward_context
@@ -79,6 +80,10 @@ def run_pipeline(
                 output = layer(**inputs)
 
                 if layer_index < num_layers - 1:
-                    output = to_next_layer_kwargs(output, layers[layer_index + 1])
+                    next_layer = layers[layer_index + 1]
+                    output = to_next_layer_kwargs(output, next_layer)
+                    # HACK: accommodate models which pass position embeddings
+                    output = maybe_inject_pos_embeddings(output, next_layer, inputs)
 
                     intermediates.delete(batch_index)
                     intermediates.update(batch_index, output)
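The delete/update pair above keeps only one layer's worth of propagated activations per batch: once a batch's outputs are converted into the next layer's inputs, the previous entry is discarded. A toy sketch of that cache discipline, with a plain dict standing in for the pipeline's intermediates cache:

import torch
from torch import nn

layers = [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)]

# plain dict standing in for the intermediates cache: batch_index -> next layer's kwargs
cache = {0: {"x": torch.randn(1, 8)}, 1: {"x": torch.randn(1, 8)}}

for layer_index, layer in enumerate(layers):
    for batch_index in list(cache):
        inputs = cache[batch_index]
        output = {"x": layer(inputs["x"])}

        if layer_index < len(layers) - 1:
            # analogous to intermediates.delete(...) followed by intermediates.update(...)
            del cache[batch_index]
            cache[batch_index] = output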
1 change: 0 additions & 1 deletion
@@ -3,7 +3,6 @@ test_stage:
     SparseGPTModifier:
       sparsity: 0.5
       block_size: 128
-      sequential_update: False
       targets: [
         're:model.layers.3.mlp.gate_proj.weight'
       ]
1 change: 0 additions & 1 deletion
@@ -9,7 +9,6 @@ recipe: |
       SparseGPTModifier:
         sparsity: 0.5
         block_size: 128
-        sequential_update: False
         targets: [
           're:model.layers.3.mlp.gate_proj.weight'
         ]
