Fix KL-Divergence similarity function calculation
Ofir Gordon authored and Ofir Gordon committed Nov 30, 2023
1 parent e76110a commit df6d359
Showing 2 changed files with 6 additions and 2 deletions.
5 changes: 4 additions & 1 deletion model_compression_toolkit/core/common/similarity_analyzer.py
@@ -241,4 +241,7 @@ def compute_kl_divergence(float_tensor: np.ndarray, fxp_tensor: np.ndarray, batc
non_zero_fxp_tensor = fxp_flat.copy()
non_zero_fxp_tensor[non_zero_fxp_tensor == 0] = EPS

return np.mean(np.sum(np.where(float_flat != 0, float_flat * np.log(float_flat / non_zero_fxp_tensor), 0), axis=-1))
prob_distance = np.where(float_flat != 0, float_flat * np.log(float_flat / non_zero_fxp_tensor), 0)
# The sum is part of the KL-Divergence function.
# The mean aggregates the distances computed for each output probability vector.
return np.mean(np.sum(prob_distance, axis=-1), axis=-1)
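
For context, a minimal standalone NumPy sketch of the fixed computation (the EPS value and the flattened (batch, n) input shapes are assumptions for illustration; the toolkit's compute_kl_divergence handles flattening and the batch flag itself):

import numpy as np

EPS = 1e-8  # assumed epsilon; the toolkit defines its own EPS constant

def kl_divergence_batch(float_flat: np.ndarray, fxp_flat: np.ndarray) -> float:
    # float_flat, fxp_flat: (batch, n) probability vectors (each row sums to 1).
    non_zero_fxp = fxp_flat.copy()
    non_zero_fxp[non_zero_fxp == 0] = EPS  # avoid division by zero inside the log
    # Per-element contribution p * log(p / q), treating 0 * log(0 / q) as 0.
    prob_distance = np.where(float_flat != 0,
                             float_flat * np.log(float_flat / non_zero_fxp), 0)
    # Summing over the probability axis gives KL(p || q) per sample;
    # the mean then aggregates the per-sample divergences into one scalar.
    return np.mean(np.sum(prob_distance, axis=-1), axis=-1)

p = np.array([[0.7, 0.2, 0.1], [0.5, 0.3, 0.2]])
q = np.array([[0.6, 0.3, 0.1], [0.4, 0.4, 0.2]])
print(kl_divergence_batch(p, q))  # small positive scalar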
@@ -85,6 +85,7 @@ def softmax_model(input_shape):
return model



class TestSensitivityMetricInterestPoints(unittest.TestCase):

def test_filtered_interest_points_set(self):
@@ -148,7 +149,7 @@ def test_softmax_interest_point(self):
distance_per_softmax_axis = distance_fn(t1, t2, batch=True, axis=axis)
distance_global = distance_fn(t1, t2, batch=True, axis=None)

self.assertFalse(np.isclose(distance_per_softmax_axis, distance_global),
self.assertFalse(np.isclose(np.mean(distance_per_softmax_axis), distance_global),
f"Computing distance for softmax node on softmax activation axis should be different than "
f"on than computing on the entire tensor.")

