Skip to content

Commit

Permalink
Fix Hessian batch sample method to use remaining samples correctly
Browse files (browse the repository at this point in the history)
  • Loading branch information
Ofir Gordon authored and Ofir Gordon committed Jun 6, 2024
1 parent d13319f commit a44a060
Showing 1 changed file with 42 additions and 26 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
from functools import partial
from tqdm import tqdm
from typing import Callable, List, Dict, Any, Tuple

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self,
self.trace_hessian_request_to_score_list = {}

def _sample_batch_representative_dataset(self,
representative_dataset: Iterable,
representative_dataset: Any,
num_hessian_samples: int,
num_inputs: int,
last_iter_remain_samples: List[List[np.ndarray]] = None
Expand All @@ -86,46 +87,59 @@ def _sample_batch_representative_dataset(self,
f"but given {num_inputs}.")

all_inp_hessian_samples = [[] for _ in range(num_inputs)]
all_inp_remaining_samples = [[] for _ in range(num_inputs)]

# Collect the requested number of samples from the representative dataset
for batch in representative_dataset:
if not isinstance(batch, list):
# In case there are samples left from previous iterations, we use them first
# otherwise, we take a batch from the representative dataset generator
while len(all_inp_hessian_samples[0]) < num_hessian_samples:
batch = None
sampling_from_repr = True
if last_iter_remain_samples is not None and len(last_iter_remain_samples[0]) >= num_hessian_samples:
batch = last_iter_remain_samples
sampling_from_repr = False
else:
try:
batch = next(representative_dataset)
except StopIteration:
Logger.critical(
f"Not enough samples in the provided representative dataset to compute Hessian approximation on "
f"{num_hessian_samples} samples.")

if batch is not None and not isinstance(batch, list):
Logger.critical(f'Expected batch to be a list; found type: {type(batch)}.') # pragma: no cover
all_inp_remaining_samples = [[] for _ in range(num_inputs)]
for inp_idx in range(len(batch)):
inp_batch = batch[inp_idx]

if last_iter_remain_samples is not None and len(last_iter_remain_samples[inp_idx]):
# some samples remained from last batch of last computation iteration -
# include them in the current batch
inp_batch = np.concatenate((inp_batch, last_iter_remain_samples[inp_idx]))
for inp_idx in range(len(batch)):
inp_batch = batch[inp_idx] if sampling_from_repr else np.stack(batch[inp_idx], axis=0)
if not sampling_from_repr:
last_iter_remain_samples[inp_idx] = []

# Compute number of missing samples to get to the requested amount from the current batch
num_missing = min(num_hessian_samples - len(all_inp_hessian_samples[inp_idx]), inp_batch.shape[0])

# Append each sample separately
samples = [s for s in inp_batch[0:num_missing, ...]]
remaining_samples = [s for s in inp_batch[num_missing:, ...]]

all_inp_hessian_samples[inp_idx] += [sample.reshape(1, *sample.shape) for sample in samples]
# This list would can only get filled on the last batch iteration
all_inp_remaining_samples[inp_idx] += (remaining_samples)

# This list can only get filled on the last batch iteration
all_inp_remaining_samples[inp_idx] += remaining_samples

if len(all_inp_hessian_samples[0]) > num_hessian_samples:
Logger.critical(f"Requested {num_hessian_samples} samples for computing Hessian approximation but "
f"{len(all_inp_hessian_samples[0])} were collected.") # pragma: no cover
elif len(all_inp_hessian_samples[0]) == num_hessian_samples:
# Collected enough samples, constructing a dataset with the requested batch size
hessian_samples_for_input = []
for inp_samples in all_inp_hessian_samples:
inp_samples = np.concatenate(inp_samples, axis=0)
num_collected_samples = inp_samples.shape[0]
inp_samples = np.split(inp_samples,
num_collected_samples // min(num_collected_samples, num_hessian_samples))
hessian_samples_for_input.append(inp_samples[0])

return hessian_samples_for_input, all_inp_remaining_samples
Logger.critical(
f"Not enough samples in the provided representative dataset to compute Hessian approximation on "
f"{num_hessian_samples} samples.")

# Collected enough samples, constructing a dataset with the requested batch size
hessian_samples_for_input = []
for inp_samples in all_inp_hessian_samples:
inp_samples = np.concatenate(inp_samples, axis=0)
num_collected_samples = inp_samples.shape[0]
inp_samples = np.split(inp_samples,
num_collected_samples // min(num_collected_samples, num_hessian_samples))
hessian_samples_for_input.append(inp_samples[0])

return hessian_samples_for_input, all_inp_remaining_samples

def _clear_saved_hessian_info(self):
"""Clears the saved info approximations."""
Expand Down Expand Up @@ -302,9 +316,11 @@ def _populate_saved_info_to_size(self,
representative_dataset=self.representative_dataset_gen())

next_iter_remaining_samples = None
pbar = tqdm(desc="Computing Hessian approximations...", total=None)
while max_remaining_hessians > 0:
# If batch_size < max_remaining_hessians then we run each computation on a batch_size of images.
# This way, we always run a computation for a single batch.
pbar.update(1)
size_to_compute = min(max_remaining_hessians, batch_size)
next_iter_remaining_samples = (
self.compute(trace_hessian_request, hessian_representative_dataset, size_to_compute,
Expand Down

0 comments on commit a44a060

Please sign in to comment.