Skip to content

Commit

Permalink
fixes per code review
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Oct 6, 2024
1 parent ba6433c commit a329151
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ def compute_trackable_per_sample_hessian(self,

hessian_score_by_image_hash = {}

if not isinstance(inputs_batch, list):
raise TypeError('Expected a list of inputs') # pragma: no cover
if not inputs_batch or not isinstance(inputs_batch, list):
raise TypeError('Expected a non-empty list of inputs') # pragma: no cover
if len(inputs_batch) > 1:
raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs') # pragma: no cover

Expand All @@ -261,17 +261,27 @@ def compute_trackable_per_sample_hessian(self,
hessian_scores_request=hessian_scores_request,
num_iterations_for_approximation=self.num_iterations_for_approximation)
hessian_scores = fw_hessian_calculator.compute()
for b in range(inputs_batch[0].shape[0]):
img_hash = self.calc_image_hash(inputs_batch[0][b])
for i in range(inputs_batch[0].shape[0]):
img_hash = self.calc_image_hash(inputs_batch[0][i])
hessian_score_by_image_hash[img_hash] = {
node: score[b] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores)
node: score[i] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores)
}

return hessian_score_by_image_hash

@staticmethod
def calc_image_hash(image):
if len(image.shape) != 3: # pragma: no cover
"""
Calculates hash for an input image.
Args:
image: input 3d image (without batch).
Returns:
Image hash.
"""
if not len(image.shape) == 3: # pragma: no cover
raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}')
image_bytes = image.astype(np.float32).tobytes()
return hashlib.md5(image_bytes).hexdigest()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def forward_pass(self):
output = self.concat_tensors(output_tensors)
return output, target_activation_tensors

def _generate_random_vectors_batch(self, shape, distribution: HessianEstimationDistribution, device) -> torch.Tensor:
def _generate_random_vectors_batch(self, shape: tuple, distribution: HessianEstimationDistribution,
device: torch.device) -> torch.Tensor:
"""
Generate a batch of random vectors for Hutchinson estimation
Expand Down
17 changes: 10 additions & 7 deletions model_compression_toolkit/gptq/pytorch/gptq_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,9 @@ def _get_total_grad_steps():
self.hessian_score_per_layer = None # for fixed layer weights
self.hessian_score_per_image_per_layer = None # for sample-layer attention
if self.use_sample_layer_attention:
assert (hessian_cfg.norm_scores is False and hessian_cfg.log_norm is False and
hessian_cfg.scale_log_norm is False), hessian_cfg
# normalization is currently not supported, make sure the config reflects it.
if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
raise NotImplementedError()
# Per sample hessian scores are calculated on-demand during the training loop
self.hessian_score_per_image_per_layer = {}
else:
Expand Down Expand Up @@ -308,18 +309,20 @@ def _get_loss_weights(self, input_tensors: List[torch.Tensor]) -> Tuple[torch.Te
if len(input_tensors) > 1:
raise NotImplementedError('Sample-Layer attention is not currently supported for networks with multiple inputs')

scores = []
image_scores = []
batch = input_tensors[0]
img_hashes = [self.hessian_service.calc_image_hash(img) for img in batch]
for img_hash in img_hashes:
# If sample-layer attention score for the image is not found, compute and store it for the whole batch.
if img_hash not in self.hessian_score_per_image_per_layer:
score_per_image_layer_per = self._compute_sample_layer_attention_scores(input_tensors)
self.hessian_score_per_image_per_layer.update(score_per_image_layer_per)
score_per_image_per_layer = self._compute_sample_layer_attention_scores(input_tensors)
self.hessian_score_per_image_per_layer.update(score_per_image_per_layer)
img_scores_per_layer: Dict[BaseNode, np.ndarray] = self.hessian_score_per_image_per_layer[img_hash]
# fetch image scores for all layers and combine them into a single tensor
img_scores = np.stack(list(img_scores_per_layer.values()), axis=0)
scores.append(img_scores)
image_scores.append(img_scores)

layer_sample_weights = np.stack(scores, axis=1) # layers X images
layer_sample_weights = np.stack(image_scores, axis=1) # layers X images
layer_weights = layer_sample_weights.mean(axis=1)
return layer_sample_weights, layer_weights

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,19 @@ def __init__(self, beta_scheduler: Callable[[int], float]):

self.count_iter = 0

def __call__(self, model: nn.Module, entropy_reg: float, layer_weights: torch.Tensor = None):
def __call__(self, model: nn.Module, entropy_reg: float, layer_weights: torch.Tensor):
"""
Returns the soft quantizer regularization value for SoftRounding.
Args:
model: A model to be quantized with SoftRounding.
entropy_reg: Entropy value to scale the quantizer regularization.
layer_weights: a vector of layer weights. If None, each layers has a weight of 1.
layer_weights: a vector of layer weights.
Returns: Regularization value.
"""
layers = [m for m in model.modules() if isinstance(m, PytorchQuantizationWrapper)]

if layer_weights is None:
layer_weights = torch.ones((len(layers),))
if len(layer_weights.shape) != 1 or layer_weights.shape[0] != len(layers):
raise ValueError(f'Expected weights to be a vector of length {len(layers)}, received {layer_weights.shape}.') # pragma: no cover
max_w = layer_weights.max()
Expand Down

0 comments on commit a329151

Please sign in to comment.