
GPTQ Activation Ordering (#94)
kylesayrs authored Aug 28, 2024
1 parent e64c74d commit 6ad6e05
Showing 4 changed files with 59 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -59,6 +59,7 @@ class GPTQModifier(Modifier):
  |   symmetric: true
  |   strategy: "tensor"
  |   group_size: 128
+ |   actorder: False
  :param sequential_update: Whether or not to update weights sequentially by layer,
@@ -122,11 +122,23 @@ def compress(

         tick = time.time()

-        # update quantization parameters for activation ordering
-        observer = MemorylessObserver(weight_quant_args)
-        scale, zero_point = observer(W)
-        update_parameter_data(self.layer, scale, "weight_scale")
-        update_parameter_data(self.layer, zero_point, "weight_zero_point")
+        # consider activation ordering
+        if weight_quant_args.actorder:
+            # use hessian to create a permutation of weights
+            perm = torch.argsort(torch.diag(self.H), descending=True)
+
+            # permute weight and hessian
+            W = W[:, perm]
+            self.H = self.H[perm][:, perm]
+
+            # update quantization parameters for activation ordering
+            observer = MemorylessObserver(weight_quant_args)
+            _scale, _zero_point = observer(W)
+            update_parameter_data(self.layer, _scale, "weight_scale")
+            update_parameter_data(self.layer, _zero_point, "weight_zero_point")
+
+        scale = self.layer.weight_scale
+        zero_point = self.layer.weight_zero_point

         # mask dead hessian values
         dead = torch.diag(self.H) == 0
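
A note on the added block: sorting columns by the Hessian diagonal quantizes the most salient weight columns first, and the permutation is undone later with its inverse. A minimal standalone sketch of that round trip, using torch only and a toy weight/Hessian pair (not the repository's code):

import torch

# toy weight [out_features, in_features] and diagonal Hessian [in, in]
W = torch.randn(8, 16)
H = torch.diag(torch.rand(16))

# activation ordering: columns with the largest Hessian diagonal come first
perm = torch.argsort(torch.diag(H), descending=True)
W_perm = W[:, perm]
H_perm = H[perm][:, perm]

# the inverse permutation restores the original column order
invperm = torch.argsort(perm)
assert torch.equal(W_perm[:, invperm], W)
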
@@ -135,6 +147,7 @@ def compress(

         Losses = torch.zeros(self.rows, device=self.dev)

+        # compute inverse hessian in place to save memory
         damp = percdamp * torch.mean(torch.diag(self.H))
         diag = torch.arange(self.columns, device=self.dev)
         self.H[diag, diag] += damp
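
The new comment marks the damping step: a fraction percdamp of the mean Hessian diagonal is added to every diagonal entry so the subsequent Cholesky-based inversion stays well conditioned. A hedged, self-contained illustration of the idea (toy data; not the repository's exact inversion routine):

import torch

percdamp, columns = 0.01, 16

# symmetric positive semi-definite toy Hessian
X = torch.randn(64, columns)
H = X.T @ X / 64

# dampen the diagonal before inverting, mirroring the lines above
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(columns)
H[diag, diag] += damp

# one common way to invert an SPD matrix via its Cholesky factor
L = torch.linalg.cholesky(H)
Hinv = torch.cholesky_inverse(L)
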
@@ -224,12 +237,26 @@ def compress(
         if "METRIC" in logger._core.levels.keys():
             self.log_metrics(tick, Losses)

+        if weight_quant_args.actorder:
+            # restore original permutation
+            invperm = torch.argsort(perm)
+            W = W[:, invperm]
+
+            # g_idx describes the group index of the permuted weight
+            g_idx = torch.tensor(
+                [i // weight_quant_args.group_size for i in range(self.columns)],
+                dtype=torch.int,
+            ).to(device=invperm.device)
+
+            # invert to get the group index of the unpermuted weight
+            update_parameter_data(self.layer, g_idx[invperm], "weight_g_idx")
+
         if isinstance(self.layer, transformers.Conv1D):
             W.transpose_(0, 1)
         W = W.reshape(final_shape).to(final_dtype)

-        # This is a bit hacky, but FSDP updates only work if we change the weight in
-        # place, clone() or direct assignment won't work
+        # This is a bit hacky, but FSDP updates only work if we change
+        # the weight in place, clone() or direct assignment won't work
         self.layer.weight -= self.layer.weight
         self.layer.weight += W

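To see why g_idx[invperm] is stored rather than g_idx itself: g_idx[j] is the quantization group of permuted column j, and original column i sits at permuted position invperm[i], so g_idx[invperm] reads off the group of every original (unpermuted) column. A small standalone sketch with illustrative values (eight columns, group size four):

import torch

columns, group_size = 8, 4

# pretend Hessian diagonal -> permutation (illustrative values only)
hessian_diag = torch.tensor([0.1, 3.0, 0.5, 2.0, 0.2, 1.0, 0.05, 4.0])
perm = torch.argsort(hessian_diag, descending=True)  # [7, 1, 3, 5, 2, 4, 0, 6]
invperm = torch.argsort(perm)

# group index of each *permuted* column
g_idx = torch.tensor([i // group_size for i in range(columns)], dtype=torch.int)

# group index of each *original* column, as stored in weight_g_idx
print(g_idx[invperm].tolist())  # [1, 0, 1, 0, 1, 0, 1, 0]

The weight itself is written back with the in-place "weight -= weight; weight += W" idiom because, as the reworded comment notes, FSDP only picks up in-place updates; clone() or direct assignment would not propagate.
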
@@ -0,0 +1,5 @@
+cadence: "nightly"
+test_type: "regression"
+model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_actorder.yaml"
+ppl_threshold: 20
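
This config drives a nightly regression test: the TinyLlama stub is quantized with the actorder recipe referenced by new_recipe, and its perplexity must land under ppl_threshold. As a rough sketch of what such a perplexity check measures (illustrative evaluation text and loop, not the repository's test harness):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

text = "Activation ordering permutes weight columns by their Hessian diagonal."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    # passing labels makes the model return the mean cross-entropy loss
    loss = model(**inputs, labels=inputs["input_ids"]).loss

perplexity = torch.exp(loss).item()
print(perplexity)  # the regression test compares a value like this to ppl_threshold: 20
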
@@ -0,0 +1,19 @@
+test_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      ignore: ["lm_head", "model.layers.0.mlp.down_proj"]
+      config_groups:
+        group_0:
+          weights:
+            num_bits: 4
+            type: "int"
+            symmetric: False
+            strategy: "group"
+            group_size: 128
+            actorder: True
+          input_activations: null
+          output_activations: null
+          targets: ["Linear"]
+    GPTQModifier:
+      block_size: 128
+      sequential_update: False
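
Given the recipe settings (int4, asymmetric, strategy "group" with group_size 128, actorder enabled), every run of 128 permuted columns shares one scale and zero point. A minimal sketch of asymmetric group quantization for a single group, assuming a simple min/max observer rather than the MemorylessObserver used in the wrapper above:

import torch

num_bits, group_size = 4, 128
qmin, qmax = 0, 2**num_bits - 1  # asymmetric int4 range

# one group of 128 (permuted) weight values from a single output row
w_group = torch.randn(group_size)

# min/max-style asymmetric quantization parameters for this group
scale = (w_group.max() - w_group.min()) / (qmax - qmin)
zero_point = torch.round(-w_group.min() / scale).clamp(qmin, qmax)

q = torch.clamp(torch.round(w_group / scale) + zero_point, qmin, qmax)
w_dequant = (q - zero_point) * scale  # what the compressed layer effectively uses

With actorder: True the quantization math is unchanged; what changes is which columns share a group, since grouping happens after the Hessian permutation, and weight_g_idx records that mapping for each original column.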
