Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iteratively update quantization parameters in GPTQ #178

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,19 +127,24 @@ def compress(

if actorder == ActivationOrdering.GROUP:
# permute by activation order first, then update groups
W, self.H, perm = self._apply_activation_ordering(W, self.H)
W, self.H, perm, invperm = self._apply_activation_ordering(W, self.H)
self._update_quantization_parameters(weight_quant_args, W)

# use identity g_idx (invert permutation later)

elif actorder == ActivationOrdering.WEIGHT:
# update groups first, then permute by activation order
self._update_quantization_parameters(weight_quant_args, W)
W, self.H, perm = self._apply_activation_ordering(W, self.H)
W, self.H, perm, invperm = self._apply_activation_ordering(W, self.H)

# permute g_idx to maintain identity mapping after unpermutation
g_idx = g_idx[perm]

else:
# ensure that quantization parameters are calculated using the same
# floating point data type, regardless of quantization strategy
self._update_quantization_parameters(weight_quant_args, W)

scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point

Expand Down Expand Up @@ -235,6 +240,14 @@ def compress(
W1[:, i:] -= w1_err
Err1[:, i] = err1

# the weight values have been updated, so the quantization parameters must be updated as well
if actorder == ActivationOrdering.WEIGHT:
W = W[:, invperm]
self._update_quantization_parameters(weight_quant_args, W)
W = W[:, perm]
else:
self._update_quantization_parameters(weight_quant_args, W)

# propagate block error
W[:, i1:i2] = Q1
Losses += torch.sum(Losses1, 1) / 2
Expand All @@ -251,7 +264,6 @@ def compress(
if strategy == QuantizationStrategy.GROUP:
if actorder == ActivationOrdering.WEIGHT:
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]

elif actorder == ActivationOrdering.GROUP:
Expand Down Expand Up @@ -298,14 +310,15 @@ def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tenso

def _apply_activation_ordering(
self, W: torch.Tensor, H: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Permute weight and hessian in order of greatest output activations

:param W: weight to permute
"""
perm = torch.argsort(torch.diag(H), descending=True)
return W[:, perm], H[perm][:, perm], perm
invperm = torch.argsort(perm)
return W[:, perm], H[perm][:, perm], perm, invperm

def _log_metrics(self, start_tick: float, losses: torch.Tensor):
"""
Expand Down
Loading