-
Notifications
You must be signed in to change notification settings - Fork 11
feat: implementing Kernel Inception Metric #65 : #81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
2af982a
f23efa4
c4bea1b
85f88a8
87b736a
f13d26e
c1c30be
24b16d0
88f9ec6
6bdd189
4cb3d65
f9d5954
965eab5
075e68d
878876c
d40ee0c
d3df741
8969ce9
74cf343
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,189 @@ | ||
| # Copyright 2024 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| ## Credits to https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/image/kid.py for the reference implementation. | ||
|
|
||
|
|
||
| import jax.numpy as jnp | ||
| from jax import random | ||
| import flax | ||
| import jax | ||
| from clu import metrics as clu_metrics | ||
|
|
||
# Default hyperparameters for the Kernel Inception Distance metric.
# These mirror the torchmetrics KernelInceptionDistance defaults.
KID_DEFAULT_SUBSETS = 100  # number of random subsets averaged over in compute()
KID_DEFAULT_SUBSET_SIZE = 1000  # samples drawn (without replacement) per subset
KID_DEFAULT_DEGREE = 3  # degree of the polynomial kernel
KID_DEFAULT_GAMMA = None  # kernel coefficient; None -> 1 / feature_dim at call time
KID_DEFAULT_COEF = 1.0  # independent (bias) term of the polynomial kernel
|
|
||
|
|
||
@flax.struct.dataclass
class KernelInceptionDistanceMetric(clu_metrics.Metric):
  r"""Computes Kernel Inception Distance (KID) to assess quality of generated images.

  KID evaluates generated images by comparing the distribution of their
  extracted features to the distribution of real-image features. It is the
  squared Maximum Mean Discrepancy (MMD) between the two feature sets:

  .. math::
      KID = MMD(f_{real}, f_{fake})^2

  where the MMD is computed with a polynomial kernel

  .. math::
      k(x, y) = (\gamma * x^T y + coef)^{degree}

  Reference implementation:
  https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/image/kid.py

  Attributes:
    subsets: Number of random subsets to average the MMD over.
    subset_size: Number of samples drawn (without replacement) per subset.
    degree: Degree of the polynomial kernel.
    gamma: Kernel coefficient; None means 1 / feature_dim (see polynomial_kernel).
    coef: Independent (bias) term of the polynomial kernel.
    real_features: Accumulated features extracted from real images.
    fake_features: Accumulated features extracted from generated images.
  """

  subsets: int
  subset_size: int
  degree: int
  # NOTE: despite the annotation, gamma may be None (the default);
  # polynomial_kernel then falls back to 1 / feature_dim.
  gamma: float
  coef: float

  real_features: jax.Array = flax.struct.field(
      default_factory=lambda: jnp.array([], dtype=jnp.float32)
  )
  fake_features: jax.Array = flax.struct.field(
      default_factory=lambda: jnp.array([], dtype=jnp.float32)
  )

  @classmethod
  def from_model_output(
      cls,
      real_features: jax.Array,
      fake_features: jax.Array,
      subsets: int = KID_DEFAULT_SUBSETS,
      subset_size: int = KID_DEFAULT_SUBSET_SIZE,
      degree: int = KID_DEFAULT_DEGREE,
      gamma: float = KID_DEFAULT_GAMMA,
      coef: float = KID_DEFAULT_COEF,
  ) -> "KernelInceptionDistanceMetric":
    """Creates a metric instance from batches of extracted features.

    Args:
      real_features: Real-image features, shape (num_samples, feature_dim).
      fake_features: Generated-image features with the same feature_dim.
      subsets: Number of random subsets used by compute().
      subset_size: Samples per subset; compute() requires at least this many
        accumulated samples on each side.
      degree: Polynomial kernel degree.
      gamma: Kernel coefficient; None means 1 / feature_dim.
      coef: Independent term of the polynomial kernel.

    Returns:
      A new KernelInceptionDistanceMetric holding the given features.

    Raises:
      ValueError: If any hyperparameter is non-positive.
    """
    # Validate hyperparameters up front; gamma=None is allowed (auto-scale).
    if (
        subsets <= 0
        or subset_size <= 0
        or degree <= 0
        or (gamma is not None and gamma <= 0)
        or coef <= 0
    ):
      raise ValueError("All parameters must be positive and non-zero.")
    return cls(
        subsets=subsets,
        subset_size=subset_size,
        degree=degree,
        gamma=gamma,
        coef=coef,
        real_features=real_features,
        fake_features=fake_features,
    )

  @classmethod
  def empty(cls) -> "KernelInceptionDistanceMetric":
    """Creates an instance with default parameters and no accumulated features."""
    return cls(
        subsets=KID_DEFAULT_SUBSETS,
        subset_size=KID_DEFAULT_SUBSET_SIZE,
        degree=KID_DEFAULT_DEGREE,
        gamma=KID_DEFAULT_GAMMA,
        coef=KID_DEFAULT_COEF,
        # Feature dim 2048 presumably matches InceptionV3 pool features
        # — TODO confirm against the upstream feature extractor.
        real_features=jnp.empty((0, 2048), dtype=jnp.float32),
        fake_features=jnp.empty((0, 2048), dtype=jnp.float32),
    )

  def compute_mmd(self, f_real: jax.Array, f_fake: jax.Array) -> jax.Array:
    """Computes the squared Maximum Mean Discrepancy with a polynomial kernel.

    Diagonal (self-similarity) terms are excluded from the within-set sums,
    so the estimate may be slightly negative for very similar distributions.

    Args:
      f_real: Subset of real-image features, shape (m, feature_dim).
      f_fake: Subset of fake-image features, shape (m, feature_dim).

    Returns:
      Scalar MMD^2 estimate as a 0-d jax.Array.
    """
    k_11 = self.polynomial_kernel(f_real, f_real)
    k_22 = self.polynomial_kernel(f_fake, f_fake)
    k_12 = self.polynomial_kernel(f_real, f_fake)

    m = f_real.shape[0]
    # Diagonals hold k(x, x); subtract them for the unbiased within-set sums.
    diag_x = jnp.diag(k_11)
    diag_y = jnp.diag(k_22)

    kt_xx_sum = jnp.sum(k_11, axis=-1) - diag_x
    kt_yy_sum = jnp.sum(k_22, axis=-1) - diag_y
    k_xy_sum = jnp.sum(k_12, axis=0)

    value = (jnp.sum(kt_xx_sum) + jnp.sum(kt_yy_sum)) / (m * (m - 1))
    value -= 2 * jnp.sum(k_xy_sum) / (m**2)
    return value

  def polynomial_kernel(self, x: jax.Array, y: jax.Array) -> jax.Array:
    """Computes the polynomial kernel matrix between two feature sets.

    Args:
      x: First feature set, shape (n, feature_dim).
      y: Second feature set, shape (p, feature_dim).

    Returns:
      Kernel matrix of shape (n, p): (gamma * x @ y.T + coef) ** degree.
    """
    # When gamma was not supplied, scale by 1 / feature_dim.
    gamma = self.gamma if self.gamma is not None else 1.0 / x.shape[1]
    return (jnp.dot(x, y.T) * gamma + self.coef) ** self.degree

  def compute(self) -> jax.Array:
    """Computes the KID mean and standard deviation over random subsets.

    Returns:
      Array of shape (2,): [kid_mean, kid_std].

    Raises:
      ValueError: If fewer than `subset_size` samples have been accumulated
        on either side.
    """
    if (
        self.real_features.shape[0] < self.subset_size
        or self.fake_features.shape[0] < self.subset_size
    ):
      # Message fix: equality with the sample count is allowed by the guard
      # above, so "must not exceed" (the old text said "must be smaller").
      raise ValueError("Subset size must not exceed the number of samples.")

    # Fixed seed keeps compute() deterministic across calls.
    master_key = random.PRNGKey(42)

    kid_scores = []
    for i in range(self.subsets):
      # Fold in the iteration index so each subset draws different samples.
      key_real, key_fake = random.split(random.fold_in(master_key, i))
      real_indices = random.choice(
          key_real, self.real_features.shape[0], (self.subset_size,), replace=False
      )
      fake_indices = random.choice(
          key_fake, self.fake_features.shape[0], (self.subset_size,), replace=False
      )

      f_real_subset = self.real_features[real_indices]
      f_fake_subset = self.fake_features[fake_indices]

      kid_scores.append(self.compute_mmd(f_real_subset, f_fake_subset))

    # Materialize the score list once and reuse it for both statistics
    # (the original converted it to an array twice).
    scores = jnp.array(kid_scores)
    return jnp.array([jnp.mean(scores), jnp.std(scores)])

  def merge(self, other: "KernelInceptionDistanceMetric") -> "KernelInceptionDistanceMetric":
    """Merges accumulated features from another instance.

    Hyperparameters are taken from `self`; the feature arrays of both
    instances are concatenated.

    Args:
      other: Another KernelInceptionDistanceMetric.

    Returns:
      A new instance with the combined feature sets.
    """
    return type(self)(
        subsets=self.subsets,
        subset_size=self.subset_size,
        degree=self.degree,
        gamma=self.gamma,
        coef=self.coef,
        real_features=jnp.concatenate([self.real_features, other.real_features]),
        fake_features=jnp.concatenate([self.fake_features, other.fake_features]),
    )
dhruvmalik007 marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| from absl.testing import absltest | ||
| import jax.numpy as jnp | ||
| from jax import random | ||
| from . import image_metrics | ||
|
|
||
|
|
||
class KernelImageMetricsTest(absltest.TestCase):
  """Unit tests for image_metrics.KernelInceptionDistanceMetric."""

  def test_kernel_inception_distance_empty_and_merge(self):
    """Empty instantiation and merging (empty and non-empty) work."""
    empty1 = image_metrics.KernelInceptionDistanceMetric.empty()
    empty2 = image_metrics.KernelInceptionDistanceMetric.empty()
    merged = empty1.merge(empty2)
    # Should still be empty and not error.
    self.assertEqual(merged.real_features.shape[0], 0)
    self.assertEqual(merged.fake_features.shape[0], 0)

    # Now merge with non-empty.
    key1, key2 = random.split(random.PRNGKey(99))
    real_features = random.normal(key1, shape=(10, 2048))
    fake_features = random.normal(key2, shape=(10, 2048))
    kid_nonempty = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        real_features, fake_features, subset_size=5
    )
    merged2 = kid_nonempty.merge(empty1)
    self.assertEqual(merged2.real_features.shape[0], 10)
    self.assertEqual(merged2.fake_features.shape[0], 10)

  def test_kernel_inception_distance_default_params(self):
    """KID on random features returns a (mean, std) pair with mean >= 0."""
    key1, key2 = random.split(random.PRNGKey(42))
    real_features = random.normal(key1, shape=(100, 2048))
    fake_features = random.normal(key2, shape=(100, 2048))

    kid = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        real_features,
        fake_features,
        subset_size=50,  # smaller subset size keeps the test fast
    )

    result = kid.compute()
    self.assertEqual(result.shape, (2,))
    self.assertTrue(result[0] >= 0)

  def test_kernel_inception_distance_invalid_params(self):
    """Non-positive hyperparameters raise ValueError."""
    key1, key2 = random.split(random.PRNGKey(44))
    real_features = random.normal(key1, shape=(100, 2048))
    fake_features = random.normal(key2, shape=(100, 2048))

    with self.assertRaises(ValueError):
      image_metrics.KernelInceptionDistanceMetric.from_model_output(
          real_features,
          fake_features,
          subsets=-1,  # invalid
      )

    with self.assertRaises(ValueError):
      image_metrics.KernelInceptionDistanceMetric.from_model_output(
          real_features,
          fake_features,
          subset_size=0,  # invalid
      )

  def test_kernel_inception_distance_small_sample_size(self):
    """KID still computes with very small sample and subset sizes."""
    key1, key2 = random.split(random.PRNGKey(45))
    real_features = random.normal(key1, shape=(10, 2048))
    fake_features = random.normal(key2, shape=(10, 2048))

    kid = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        real_features,
        fake_features,
        subset_size=5,
    )
    result = kid.compute()
    self.assertEqual(result.shape, (2,))

  def test_kernel_inception_distance_identical_sets(self):
    """Identical real/fake feature sets yield KID close to zero."""
    key = random.PRNGKey(46)
    features = random.normal(key, shape=(100, 2048))

    kid = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        features,
        features,
        subsets=50,
        subset_size=50,
    )
    result = kid.compute()
    self.assertTrue(result[0] < 1e-3, f"Expected KID close to zero, got {result[0]}")

  def test_kernel_inception_distance_mode_collapse(self):
    """KID is positive when fake features collapse to a single mode."""
    key1, key2 = random.split(random.PRNGKey(47))
    real_features = random.normal(key1, shape=(100, 2048))

    # Bug fix: the original reused key2 for both the base feature and the
    # noise, making the "noise" correlated with the base; use independent
    # subkeys instead.
    key_base, key_noise = random.split(key2)
    base_feature = random.normal(key_base, shape=(1, 2048))
    repeated_base = jnp.repeat(base_feature, 100, axis=0)
    small_noise = random.normal(key_noise, shape=(100, 2048)) * 0.01
    fake_features = repeated_base + small_noise

    kid = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        real_features,
        fake_features,
        subset_size=50,
    )
    result = kid.compute()
    self.assertTrue(result[0] > 0.0)

  def test_kernel_inception_distance_outliers(self):
    """KID is sensitive to outliers injected into the fake distribution."""
    key1, key2, key3 = random.split(random.PRNGKey(48), 3)
    real_features = random.normal(key1, shape=(100, 2048))
    fake_features = random.normal(key2, shape=(100, 2048))

    outliers = random.normal(key3, shape=(10, 2048)) * 10.0
    fake_features_with_outliers = fake_features.at[:10].set(outliers)

    kid_normal = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        real_features, fake_features, subset_size=50
    )
    kid_with_outliers = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        real_features, fake_features_with_outliers, subset_size=50
    )

    result_normal = kid_normal.compute()
    result_with_outliers = kid_with_outliers.compute()

    self.assertNotEqual(result_normal[0], result_with_outliers[0])

  def test_kernel_inception_distance_different_subset_sizes(self):
    """KID remains well-formed across different subset configurations."""
    key1, key2 = random.split(random.PRNGKey(49))
    real_features = random.normal(key1, shape=(200, 2048))
    fake_features = random.normal(key2, shape=(200, 2048))

    kid_small_subsets = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        real_features, fake_features, subsets=10, subset_size=10
    )
    kid_large_subsets = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        real_features, fake_features, subsets=5, subset_size=100
    )

    result_small = kid_small_subsets.compute()
    result_large = kid_large_subsets.compute()

    self.assertEqual(result_small.shape, (2,))
    self.assertEqual(result_large.shape, (2,))

  def test_kernel_inception_distance_different_distributions(self):
    """KID separates shifted/scaled distributions from same-distribution pairs."""
    key1, key2 = random.split(random.PRNGKey(50))

    real_features = random.normal(key1, shape=(100, 2048))

    # Fake features drawn from a shifted, wider Gaussian.
    mean = 0.5
    std = 2.0
    fake_features = mean + std * random.normal(key2, shape=(100, 2048))

    kid = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        real_features, fake_features, subset_size=50
    )
    result = kid.compute()

    self.assertTrue(result[0] > 0.0)
    key3 = random.PRNGKey(51)
    another_real_features = random.normal(key3, shape=(100, 2048))

    kid_same_dist = image_metrics.KernelInceptionDistanceMetric.from_model_output(
        real_features, another_real_features, subset_size=50
    )
    result_same_dist = kid_same_dist.compute()

    # Dissimilar distributions must score higher than two draws from the
    # same distribution.
    self.assertTrue(result[0] > result_same_dist[0])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's remove this line since image_metrics was added as part of previous PR in line 17.