rename smoothquant private vars
Signed-off-by: Brian Dellabetta <[email protected]>
brian-dellabetta committed Mar 10, 2025
1 parent b03124a commit 4488a8c
Showing 3 changed files with 24 additions and 24 deletions.
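The change replaces the trailing-underscore names (resolved_mappings_, scales_) with leading-underscore names (_resolved_mappings, _scales), the usual Python convention for internal, non-user-facing state. As a rough illustration of why the leading underscore matters for a config-style modifier class, here is a minimal sketch assuming a Pydantic v2 BaseModel (llmcompressor's Modifier appears to be Pydantic-based, but the class and field names below are hypothetical): underscore-prefixed attributes become private attributes that are excluded from the declared fields and from serialization.

from typing import Optional

from pydantic import BaseModel


class ExampleModifier(BaseModel):
    # Public, user-configurable field: part of the schema and of model_dump().
    smoothing_strength: float = 0.5

    # Leading-underscore attribute: Pydantic v2 treats this as a private
    # attribute, so it never shows up in model_fields or model_dump().
    _scales: Optional[dict] = None


modifier = ExampleModifier()
modifier._scales = {"layer.0": (0.1, 2.3)}  # internal state set at runtime

print(list(ExampleModifier.model_fields))  # ['smoothing_strength']
print(modifier.model_dump())               # {'smoothing_strength': 0.5}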
36 changes: 18 additions & 18 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -109,8 +109,8 @@ class SmoothQuantModifier(Modifier):
     num_calibration_steps: Optional[int] = None
     calibration_function: Optional[Callable] = None

-    resolved_mappings_: Optional[List[SmoothQuantMapping]] = None
-    scales_: Optional[Dict] = None
+    _resolved_mappings: Optional[List[SmoothQuantMapping]] = None
+    _scales: Optional[Dict] = None

     def on_initialize(self, state: State, **kwargs) -> bool:
         """
@@ -132,8 +132,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:

         self.ignore = [] if not self.ignore else self.ignore
         self.mappings = self._infer_mappings_from_model(state.model)
-        self.resolved_mappings_ = self._resolve_mappings(state.model)
-        self.scales_ = {}
+        self._resolved_mappings = self._resolve_mappings(state.model)
+        self._scales = {}

         calibration_dataloader = state.data.calib

@@ -150,10 +150,10 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         :param state: unused
         :return: True
         """
-        if self.scales_ is not None:
-            self.scales_.clear()
-        if self.resolved_mappings_ is not None:
-            self.resolved_mappings_.clear()
+        if self._scales is not None:
+            self._scales.clear()
+        if self._resolved_mappings is not None:
+            self._resolved_mappings.clear()

         return True

@@ -219,21 +219,21 @@ def hook_fn(module, inp, out):
                 latest_mins = torch.min(out, dim=0)[0]
                 latest_maxes = torch.max(out, dim=0)[0]

-                if layer_name in self.scales_:
-                    self.scales_[layer_name].min_channel_vals = torch.minimum(
-                        self.scales_[layer_name].min_channel_vals, latest_mins
+                if layer_name in self._scales:
+                    self._scales[layer_name].min_channel_vals = torch.minimum(
+                        self._scales[layer_name].min_channel_vals, latest_mins
                     )
-                    self.scales_[layer_name].max_channel_vals = torch.maximum(
-                        self.scales_[layer_name].max_channel_vals, latest_maxes
+                    self._scales[layer_name].max_channel_vals = torch.maximum(
+                        self._scales[layer_name].max_channel_vals, latest_maxes
                     )
                 else:
-                    self.scales_[layer_name] = SmoothQuantScale(
+                    self._scales[layer_name] = SmoothQuantScale(
                         min_channel_vals=latest_mins, max_channel_vals=latest_maxes
                     )

             return hook_fn

-        for mapping in self.resolved_mappings_:
+        for mapping in self._resolved_mappings:
             name = mapping.smooth_name
             layer = mapping.smooth_layer
             self.register_hook(layer, create_hook_fn(name), "forward")
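For context on what the hook above accumulates: each hooked layer keeps a running per-output-channel minimum and maximum across all calibration batches, and the dynamic range is later taken as max minus min. A standalone sketch of that pattern, with illustrative module and variable names rather than the repository's:

import torch
from torch import nn

# Running per-channel extrema, keyed by layer name (illustrative).
channel_mins: dict[str, torch.Tensor] = {}
channel_maxes: dict[str, torch.Tensor] = {}


def make_minmax_hook(layer_name: str):
    def hook(module, inputs, output):
        # Collapse every leading dimension so each row is one observation
        # of the layer's output channels.
        flat = output.reshape(-1, output.shape[-1])
        batch_min = flat.min(dim=0).values
        batch_max = flat.max(dim=0).values
        if layer_name in channel_mins:
            channel_mins[layer_name] = torch.minimum(channel_mins[layer_name], batch_min)
            channel_maxes[layer_name] = torch.maximum(channel_maxes[layer_name], batch_max)
        else:
            channel_mins[layer_name] = batch_min
            channel_maxes[layer_name] = batch_max

    return hook


layer = nn.Linear(16, 8)
handle = layer.register_forward_hook(make_minmax_hook("proj"))

for _ in range(4):  # stand-in for a calibration dataloader
    layer(torch.randn(32, 16))

handle.remove()
dynamic_range = channel_maxes["proj"] - channel_mins["proj"]  # shape: (8,)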
@@ -278,10 +278,10 @@ def _apply_smoothing(self, model: Module):
         This modifies the weights of the model in-place.
         """
         logger.info("Smoothing activation scales...")
-        for mapping in self.resolved_mappings_:
+        for mapping in self._resolved_mappings:
             activation_scales = (  # get dynamic range for each activation channel
-                self.scales_[mapping.smooth_name].max_channel_vals
-                - self.scales_[mapping.smooth_name].min_channel_vals
+                self._scales[mapping.smooth_name].max_channel_vals
+                - self._scales[mapping.smooth_name].min_channel_vals
             )
             smooth_layer = mapping.smooth_layer
             balance_layers = mapping.balance_layers
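The _apply_smoothing hunk above turns those running extrema into per-channel dynamic ranges (max minus min). In the SmoothQuant formulation these activation ranges are combined with per-input-channel weight magnitudes to produce smoothing factors, roughly s = act_range**alpha / weight_max**(1 - alpha); the preceding (smooth) layer's outputs are divided by s while the balance layers' corresponding input-channel weights are multiplied by s, leaving the product unchanged. A hedged sketch of that calculation, not the repository's exact code (alpha plays the role of the modifier's smoothing strength):

import torch


def compute_smoothing_scales(
    activation_ranges: torch.Tensor,  # per-channel (max - min) from calibration
    weight: torch.Tensor,             # balance-layer weight, shape (out, in)
    alpha: float = 0.5,               # smoothing strength
    eps: float = 1e-5,
) -> torch.Tensor:
    # Per-input-channel weight magnitude.
    weight_max = weight.abs().max(dim=0).values
    # Migrate quantization difficulty from activations to weights.
    return activation_ranges.clamp(min=eps).pow(alpha) / weight_max.clamp(
        min=eps
    ).pow(1 - alpha)


act_ranges = torch.rand(16) * 4  # stand-in for collected dynamic ranges
w = torch.randn(8, 16)           # stand-in balance-layer weight
s = compute_smoothing_scales(act_ranges, w, alpha=0.5)
# smooth layer outputs are divided by s; balance layer input-channel weights are multiplied by s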
@@ -21,11 +21,11 @@ def test_successful_map(self):
         modifier = LogarithmicEqualizationModifier(mappings=mappings)

         modifier.ignore = []
-        modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model)
+        modifier._resolved_mappings = modifier._resolve_mappings(self.state.model)

-        self.assertEqual(len(modifier.resolved_mappings_), len(mappings))
+        self.assertEqual(len(modifier._resolved_mappings), len(mappings))

-        mapping = modifier.resolved_mappings_[0]
+        mapping = modifier._resolved_mappings[0]
         self.assertEqual(mapping.smooth_name, mappings[0][1])
         self.assertIsInstance(mapping.smooth_layer, Linear)
         self.assertIsInstance(mapping.balance_layers[0], Linear)
@@ -19,11 +19,11 @@ def test_successful_map(self):
         modifier = SmoothQuantModifier(mappings=mappings)

         modifier.ignore = []
-        modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model)
+        modifier._resolved_mappings = modifier._resolve_mappings(self.state.model)

-        self.assertEqual(len(modifier.resolved_mappings_), len(mappings))
+        self.assertEqual(len(modifier._resolved_mappings), len(mappings))

-        mapping = modifier.resolved_mappings_[0]
+        mapping = modifier._resolved_mappings[0]
         self.assertEqual(mapping.smooth_name, mappings[0][1])
         self.assertIsInstance(mapping.smooth_layer, Linear)
         self.assertIsInstance(mapping.balance_layers[0], Linear)
