Skip to content

Commit

Permalink
PR fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Ofir Gordon authored and Ofir Gordon committed Jun 5, 2024
1 parent ff21edf commit a9c83e4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def _sample_batch_representative_dataset(self,
(2) A list of remaining samples - for each input layer.
"""

if num_inputs < 0: # pragma: no cover
Logger.critical(f"Number of images to compute Hessian approximation must be positive, "
f"but given {num_inputs}.")

all_inp_hessian_samples = [[] for _ in range(num_inputs)]
# Collect the requested number of samples from the representative dataset
for batch in representative_dataset:
Expand All @@ -99,7 +103,6 @@ def _sample_batch_representative_dataset(self,
num_missing = min(num_hessian_samples - len(all_inp_hessian_samples[inp_idx]), inp_batch.shape[0])
# Append each sample separately
samples = [s for s in inp_batch[0:num_missing, ...]]
# hessian_samples += [sample.reshape(1, *sample.shape) for sample in samples]
remaining_samples = [s for s in inp_batch[num_missing:, ...]]

all_inp_hessian_samples[inp_idx] += [sample.reshape(1, *sample.shape) for sample in samples]
Expand Down Expand Up @@ -341,7 +344,7 @@ def _collect_saved_hessians_for_request(self, trace_hessian_request: TraceHessia
return collected_results

@staticmethod
def _construct_single_node_request(mode: HessianMode, granularity: HessianInfoGranularity, target_nodes
def _construct_single_node_request(mode: HessianMode, granularity: HessianInfoGranularity, target_nodes: List
) -> TraceHessianRequest:
"""
Constructs a Hessian request with for a single node. Used for retrieving and maintaining cached results.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def compute(self) -> List[np.ndarray]:

# Loop through each interest point activation tensor
prev_mean_results = None
for j in tqdm(range(20)): # Approximation iterations
for j in tqdm(range(self.num_iterations_for_approximation)): # Approximation iterations
# Getting a random vector with normal distribution
v = tf.random.normal(shape=output.shape, dtype=output.dtype)
f_v = tf.reduce_sum(v * output)
Expand Down

0 comments on commit a9c83e4

Please sign in to comment.