Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2af982a
feat: implementing #65 :
dhruvmalik007 May 6, 2025
f23efa4
feat: add Kernel Inception Distance Metric (KIDM) to metrax :
dhruvmalik007 May 7, 2025
c4bea1b
oops: forgot the contribution of the merge and empty methods
dhruvmalik007 May 7, 2025
85f88a8
refactor: adapted the PR changes asked by @jshin1394
dhruvmalik007 May 8, 2025
87b736a
removing the logging (added to test the errors before)
dhruvmalik007 May 8, 2025
f13d26e
Merge branch 'main' into dhruvmalik007/add-kid-metric
jshin1394 May 8, 2025
c1c30be
stashing local changes before pulling the origin branch changesinto…
dhruvmalik007 May 8, 2025
24b16d0
Merge branch 'dhruvmalik007/add-kid-metric' of https://github.com/dhr…
dhruvmalik007 May 8, 2025
88f9ec6
refactor:
dhruvmalik007 May 8, 2025
6bdd189
all test passing
dhruvmalik007 May 8, 2025
4cb3d65
refactor: test uncluding ersthwhile empty and merge functions removed
dhruvmalik007 May 9, 2025
f9d5954
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 May 9, 2025
965eab5
add the change suggestions @jshin1394 .
dhruvmalik007 May 12, 2025
075e68d
minor refactor:
dhruvmalik007 May 15, 2025
878876c
merging from the remote to the local
dhruvmalik007 May 15, 2025
d40ee0c
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 May 17, 2025
d3df741
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 Jun 2, 2025
8969ce9
feat: replacing the KID metrics computation on actual and fake derive…
dhruvmalik007 Jun 3, 2025
74cf343
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 Oct 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/metrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from metrax import nlp_metrics
from metrax import ranking_metrics
from metrax import regression_metrics

from metrax import image_metrics
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's remove this line since image_metrics was added as part of previous PR in line 17.

AUCPR = classification_metrics.AUCPR
AUCROC = classification_metrics.AUCROC
Accuracy = classification_metrics.Accuracy
Expand All @@ -39,6 +39,7 @@
RougeL = nlp_metrics.RougeL
RougeN = nlp_metrics.RougeN
WER = nlp_metrics.WER
KIDM = image_metrics.KernelInceptionDistanceMetric


__all__ = [
Expand All @@ -63,4 +64,5 @@
"RougeL",
"RougeN",
"WER",
"KIDM",
]
189 changes: 189 additions & 0 deletions src/metrax/image_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
## Credits to https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/image/kid.py for the reference implementation.


import jax.numpy as jnp
from jax import random
import flax
import jax
from clu import metrics as clu_metrics

KID_DEFAULT_SUBSETS = 100
KID_DEFAULT_SUBSET_SIZE = 1000
KID_DEFAULT_DEGREE = 3
KID_DEFAULT_GAMMA = None
KID_DEFAULT_COEF = 1.0


@flax.struct.dataclass
class KernelInceptionDistanceMetric(clu_metrics.Metric):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are your thoughts on having this inherit base.Average from the metrax package? It will keep track of the sum of total KID and the sum of total count and compute will yield the average KID of the data.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, I have done that. During my prompting, I realised it should be implemented in the following function:

    @classmethod
    def from_model_output(
        cls,
        real_features: jax.Array,
        fake_features: jax.Array,
        subsets: int = KID_DEFAULT_SUBSETS,
        subset_size: int = KID_DEFAULT_SUBSET_SIZE,
        degree: int = KID_DEFAULT_DEGREE,
        gamma: float = KID_DEFAULT_GAMMA,
        coef: float = KID_DEFAULT_COEF,
    ):
        """Build a metric holding the mean KID score of one batch of features.

        ``subsets`` random subsets of ``subset_size`` rows are drawn without
        replacement from each feature array; every pair of subsets is scored
        and the scores are averaged. Sampling keys are derived from the fixed
        seed 42, so identical inputs give identical results.

        Args:
            real_features: Features of real images, indexed by sample on axis 0.
            fake_features: Features of generated images, indexed by sample on
                axis 0.
            subsets: Number of random subsets to score.
            subset_size: Number of rows drawn per subset.
            degree: Kernel degree forwarded to the scoring helper.
            gamma: Kernel scale forwarded to the scoring helper; may be ``None``.
            coef: Kernel offset forwarded to the scoring helper.

        Returns:
            An instance with ``total`` set to the mean score and ``count`` set
            to ``1.0``, plus the configuration values passed in.

        Raises:
            ValueError: If a numeric parameter is not strictly positive, or if
                either feature array has fewer than ``subset_size`` rows.
        """
        # Reject non-positive configuration values (gamma may be left as None).
        if (
            subsets <= 0
            or subset_size <= 0
            or degree <= 0
            or (gamma is not None and gamma <= 0)
            or coef <= 0
        ):
            raise ValueError("All parameters must be positive and non-zero.")

        n_real = real_features.shape[0]
        n_fake = fake_features.shape[0]
        if n_real < subset_size or n_fake < subset_size:
            raise ValueError("Subset size must be smaller than the number of samples.")

        base_key = random.PRNGKey(42)

        def _subset_score(index):
            # One independent key pair per subset, derived from the subset index.
            real_key, fake_key = random.split(random.fold_in(base_key, index))
            real_rows = random.choice(real_key, n_real, (subset_size,), replace=False)
            fake_rows = random.choice(fake_key, n_fake, (subset_size,), replace=False)
            # NOTE(review): __compute_mmd_static is not shown here; presumably it
            # returns the MMD-based score for the two subsets -- confirm.
            return cls.__compute_mmd_static(
                real_features[real_rows], fake_features[fake_rows], degree, gamma, coef
            )

        scores = jnp.array([_subset_score(i) for i in range(subsets)])

        # A single batch contributes its mean score with a count of one.
        return cls(
            total=jnp.mean(scores),
            count=1.0,
            subsets=subsets,
            subset_size=subset_size,
            degree=degree,
            gamma=gamma,
            coef=coef,
        )

r"""Computes Kernel Inception Distance (KID) to assess the quality of generated images.
KID is a metric used to evaluate the quality of generated images by comparing
the distribution of generated images to the distribution of real images.
It operates on feature representations of the images (typically extracted from
an Inception network) and uses a kernelized version of the
Maximum Mean Discrepancy (MMD) to measure the distance between two
distributions.
The KID is computed as follows:
.. math::
KID = MMD(f_{real}, f_{fake})^2
Where :math:`MMD` is the maximum mean discrepancy and :math:`f_{real}, f_{fake}` are features extracted
from real and fake images; see Bińkowski et al., "Demystifying MMD GANs" (2018), for more details.
In particular, calculating the MMD requires the evaluation of a polynomial kernel function :math:`k`.
.. math::
k(x,y) = (\gamma * x^T y + coef)^{degree}
Args:
subsets: Number of subsets to use for KID calculation.
subset_size: Number of samples in each subset.
degree: Degree of the polynomial kernel.
gamma: Kernel coefficient for the polynomial kernel. If ``None``, the reciprocal of the feature dimension is used.
coef: Independent term in the polynomial kernel.
"""

# Sampling and kernel configuration stored alongside the accumulated features.
subsets: int  # number of random subsets scored by compute()
subset_size: int  # rows drawn (without replacement) per subset
degree: int  # exponent applied by polynomial_kernel()
gamma: float  # kernel scale; None makes polynomial_kernel() use 1 / x.shape[1]
coef: float  # constant added inside polynomial_kernel() before exponentiation

# Accumulated features; when omitted, each defaults to an empty 1-D float32 array.
real_features: jax.Array = flax.struct.field(default_factory=lambda: jnp.array([], dtype=jnp.float32))
fake_features: jax.Array = flax.struct.field(default_factory=lambda: jnp.array([], dtype=jnp.float32))


@classmethod
def from_model_output(
cls,
real_features: jax.Array,
fake_features: jax.Array,
subsets: int = KID_DEFAULT_SUBSETS,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we not define member variables and just replace with default values here?
subsets: int = 1000,
subset_size: int = 100,
degree: int = 3,
gamma: float = 1.0,
coef: float = 1.0,

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. It was just in case some model_eval.py evaluation script needed to import these parameters.

subset_size: int = KID_DEFAULT_SUBSET_SIZE,
degree: int = KID_DEFAULT_DEGREE,
gamma: float = KID_DEFAULT_GAMMA,
coef: float = KID_DEFAULT_COEF,
):
# checks for the valid inputs
if subsets <= 0 or subset_size <= 0 or degree <= 0 or (gamma is not None and gamma <= 0) or coef <= 0:
raise ValueError("All parameters must be positive and non-zero.")
return cls(
subsets=subsets,
subset_size=subset_size,
degree=degree,
gamma=gamma,
coef=coef,
real_features=real_features,
fake_features=fake_features,
)

@classmethod
def empty(cls) -> "KernelInceptionDistanceMetric":
    """Return a metric with the default configuration and no features.

    Both feature buffers are float32 arrays of shape ``(0, 2048)``; the
    remaining fields take the module-level ``KID_DEFAULT_*`` values.
    """
    no_rows = (0, 2048)
    return cls(
        real_features=jnp.empty(no_rows, dtype=jnp.float32),
        fake_features=jnp.empty(no_rows, dtype=jnp.float32),
        subsets=KID_DEFAULT_SUBSETS,
        subset_size=KID_DEFAULT_SUBSET_SIZE,
        degree=KID_DEFAULT_DEGREE,
        gamma=KID_DEFAULT_GAMMA,
        coef=KID_DEFAULT_COEF,
    )

def compute_mmd(self, f_real: jax.Array, f_fake: jax.Array) -> float:
    """Estimate the squared MMD between two feature sets.

    The within-set kernel sums exclude the diagonal (self-similarity) terms
    and are normalised by ``m * (m - 1)``; the cross term is normalised by
    ``m ** 2``. ``m`` is the number of rows of ``f_real`` and is used for
    both sets.

    Args:
        f_real: Features from real images, one row per sample.
        f_fake: Features from fake images, one row per sample.

    Returns:
        The MMD value used as the KID score of this pair of subsets.
    """
    gram_real = self.polynomial_kernel(f_real, f_real)
    gram_fake = self.polynomial_kernel(f_fake, f_fake)
    gram_cross = self.polynomial_kernel(f_real, f_fake)

    n_rows = f_real.shape[0]

    # Per-row kernel sums with each sample's similarity to itself removed.
    real_off_diag = jnp.sum(gram_real, axis=-1) - jnp.diag(gram_real)
    fake_off_diag = jnp.sum(gram_fake, axis=-1) - jnp.diag(gram_fake)
    cross_col_sums = jnp.sum(gram_cross, axis=0)

    within = (jnp.sum(real_off_diag) + jnp.sum(fake_off_diag)) / (n_rows * (n_rows - 1))
    between = 2 * jnp.sum(cross_col_sums) / (n_rows**2)
    return within - between

def polynomial_kernel(self, x: jax.Array, y: jax.Array) -> jax.Array:
    """Evaluate the polynomial kernel between every row of ``x`` and ``y``.

    Computes ``(gamma * x @ y.T + coef) ** degree`` using the instance's
    ``gamma``, ``coef`` and ``degree``.

    Args:
        x: First set of features, one row per sample.
        y: Second set of features, one row per sample.

    Returns:
        The kernel matrix as an array.
    """
    scale = self.gamma
    if scale is None:
        # No explicit gamma: fall back to the reciprocal of x's second dimension.
        scale = 1.0 / x.shape[1]
    inner_products = jnp.dot(x, y.T)
    return (inner_products * scale + self.coef) ** self.degree

def compute(self) -> jax.Array:
    """Compute the KID mean and standard deviation from accumulated features.

    ``self.subsets`` random subsets of ``self.subset_size`` rows are drawn
    without replacement from the real and the fake features; each pair of
    subsets is scored with ``compute_mmd``. Sampling keys are derived from
    the fixed seed 42, so repeated calls on the same features return the
    same result.

    Returns:
        An array of shape ``(2,)`` holding ``[mean, std]`` of the per-subset
        scores.

    Raises:
        ValueError: If either feature set has fewer than ``subset_size`` rows.
    """
    n_real = self.real_features.shape[0]
    n_fake = self.fake_features.shape[0]
    # A subset as large as the whole sample is allowed; only larger is an error.
    if n_real < self.subset_size or n_fake < self.subset_size:
        raise ValueError(
            f"Subset size ({self.subset_size}) must not exceed the number of "
            f"samples (real: {n_real}, fake: {n_fake})."
        )

    master_key = random.PRNGKey(42)

    kid_scores = []
    for i in range(self.subsets):
        # Derive a fresh, independent key pair for every subset.
        key_real, key_fake = random.split(random.fold_in(master_key, i))
        real_indices = random.choice(key_real, n_real, (self.subset_size,), replace=False)
        fake_indices = random.choice(key_fake, n_fake, (self.subset_size,), replace=False)

        kid_scores.append(
            self.compute_mmd(self.real_features[real_indices], self.fake_features[fake_indices])
        )

    # Stack once and reuse for both statistics.
    scores = jnp.array(kid_scores)
    return jnp.array([jnp.mean(scores), jnp.std(scores)])


def merge(self, other: "KernelInceptionDistanceMetric") -> "KernelInceptionDistanceMetric":
    """Merge two KernelInceptionDistanceMetric instances.

    Feature rows of ``other`` are appended to those of ``self``. An operand
    with zero rows is skipped instead of concatenated, so an empty
    placeholder merges cleanly with features of any width.

    When ``self`` holds no features at all, the configuration (``subsets``,
    ``subset_size``, ``degree``, ``gamma``, ``coef``) is taken from
    ``other``; this keeps ``empty().merge(metric)`` from discarding the
    settings ``metric`` was built with. Otherwise ``self``'s configuration
    is kept.

    Args:
        other: Another instance of KernelInceptionDistanceMetric.

    Returns:
        A new instance of KernelInceptionDistanceMetric with combined features.
    """

    def _join(mine, theirs):
        # Concatenating a zero-row array adds no data but fails when its
        # trailing shape differs, so return the populated side unchanged.
        if mine.shape[0] == 0:
            return theirs
        if theirs.shape[0] == 0:
            return mine
        return jnp.concatenate([mine, theirs])

    self_is_empty = self.real_features.shape[0] == 0 and self.fake_features.shape[0] == 0
    settings = other if self_is_empty else self

    return type(self)(
        subsets=settings.subsets,
        subset_size=settings.subset_size,
        degree=settings.degree,
        gamma=settings.gamma,
        coef=settings.coef,
        real_features=_join(self.real_features, other.real_features),
        fake_features=_join(self.fake_features, other.fake_features),
    )
175 changes: 175 additions & 0 deletions src/metrax/image_metrics_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from absl.testing import absltest
import jax.numpy as jnp
from jax import random
from . import image_metrics


class KernelImageMetricsTest(absltest.TestCase):
    """Unit tests for ``image_metrics.KernelInceptionDistanceMetric``."""

    def test_kernel_inception_distance_empty_and_merge(self):
        """Empty metrics merge with each other and with a populated metric."""
        empty1 = image_metrics.KernelInceptionDistanceMetric.empty()
        empty2 = image_metrics.KernelInceptionDistanceMetric.empty()
        merged = empty1.merge(empty2)
        # Should still be empty and not error
        self.assertEqual(merged.real_features.shape[0], 0)
        self.assertEqual(merged.fake_features.shape[0], 0)

        # Now merge with non-empty
        key1, key2 = random.split(random.PRNGKey(99))
        real_features = random.normal(key1, shape=(10, 2048))
        fake_features = random.normal(key2, shape=(10, 2048))
        kid_nonempty = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            real_features, fake_features, subset_size=5
        )
        merged2 = kid_nonempty.merge(empty1)
        self.assertEqual(merged2.real_features.shape[0], 10)
        self.assertEqual(merged2.fake_features.shape[0], 10)

    def test_kernel_inception_distance_default_params(self):
        """KID on random features yields a ``[mean, std]`` pair."""
        key1, key2 = random.split(random.PRNGKey(42))
        real_features = random.normal(key1, shape=(100, 2048))
        fake_features = random.normal(key2, shape=(100, 2048))

        kid = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            real_features,
            fake_features,
            subset_size=50,  # Using smaller subset size for testing
        )

        result = kid.compute()
        self.assertEqual(result.shape, (2,))
        self.assertGreaterEqual(result[0], 0)

    def test_kernel_inception_distance_invalid_params(self):
        """Non-positive configuration values raise ``ValueError``."""
        key1, key2 = random.split(random.PRNGKey(44))
        real_features = random.normal(key1, shape=(100, 2048))
        fake_features = random.normal(key2, shape=(100, 2048))

        with self.assertRaises(ValueError):
            image_metrics.KernelInceptionDistanceMetric.from_model_output(
                real_features,
                fake_features,
                subsets=-1,  # Invalid
            )

        with self.assertRaises(ValueError):
            image_metrics.KernelInceptionDistanceMetric.from_model_output(
                real_features,
                fake_features,
                subset_size=0,  # Invalid
            )

    def test_kernel_inception_distance_small_sample_size(self):
        """KID still computes with very small sample and subset sizes."""
        key1, key2 = random.split(random.PRNGKey(45))
        real_features = random.normal(key1, shape=(10, 2048))
        fake_features = random.normal(key2, shape=(10, 2048))

        kid = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            real_features,
            fake_features,
            subset_size=5,
        )
        result = kid.compute()
        self.assertEqual(result.shape, (2,))

    def test_kernel_inception_distance_identical_sets(self):
        """Identical feature sets produce a KID mean close to zero."""
        key = random.PRNGKey(46)
        features = random.normal(key, shape=(100, 2048))

        kid = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            features,
            features,
            subsets=50,
            subset_size=50,
        )
        result = kid.compute()
        self.assertLess(result[0], 1e-3, f"Expected KID close to zero, got {result[0]}")

    def test_kernel_inception_distance_mode_collapse(self):
        """Low-variance (mode-collapsed) fake features give a positive KID."""
        key1, key2 = random.split(random.PRNGKey(47))
        real_features = random.normal(key1, shape=(100, 2048))

        # Use separate keys for the shared base vector and the per-sample
        # noise; reusing one key would make the two draws correlated.
        base_key, noise_key = random.split(key2)
        base_feature = random.normal(base_key, shape=(1, 2048))
        repeated_base = jnp.repeat(base_feature, 100, axis=0)
        small_noise = random.normal(noise_key, shape=(100, 2048)) * 0.01
        fake_features = repeated_base + small_noise

        kid = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            real_features,
            fake_features,
            subset_size=50,
        )
        result = kid.compute()
        self.assertGreater(result[0], 0.0)

    def test_kernel_inception_distance_outliers(self):
        """Injecting outliers into the fake features changes the KID mean."""
        key1, key2, key3 = random.split(random.PRNGKey(48), 3)
        real_features = random.normal(key1, shape=(100, 2048))
        fake_features = random.normal(key2, shape=(100, 2048))

        outliers = random.normal(key3, shape=(10, 2048)) * 10.0
        fake_features_with_outliers = fake_features.at[:10].set(outliers)

        kid_normal = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            real_features, fake_features, subset_size=50  # Using smaller subset size for testing
        )
        kid_with_outliers = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            real_features, fake_features_with_outliers, subset_size=50  # Using smaller subset size for testing
        )

        result_normal = kid_normal.compute()
        result_with_outliers = kid_with_outliers.compute()

        self.assertNotEqual(result_normal[0], result_with_outliers[0])

    def test_kernel_inception_distance_different_subset_sizes(self):
        """KID computes for both many-small and few-large subset settings."""
        key1, key2 = random.split(random.PRNGKey(49))
        real_features = random.normal(key1, shape=(200, 2048))
        fake_features = random.normal(key2, shape=(200, 2048))

        kid_small_subsets = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            real_features, fake_features, subsets=10, subset_size=10
        )
        kid_large_subsets = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            real_features, fake_features, subsets=5, subset_size=100
        )

        result_small = kid_small_subsets.compute()
        result_large = kid_large_subsets.compute()

        self.assertEqual(result_small.shape, (2,))
        self.assertEqual(result_large.shape, (2,))

    def test_kernel_inception_distance_different_distributions(self):
        """A shifted, rescaled distribution scores higher than a matching one."""
        key1, key2 = random.split(random.PRNGKey(50))

        real_features = random.normal(key1, shape=(100, 2048))

        mean = 0.5
        std = 2.0
        fake_features = mean + std * random.normal(key2, shape=(100, 2048))

        kid = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            real_features, fake_features, subset_size=50  # Using smaller subset size for testing
        )
        result = kid.compute()

        self.assertGreater(result[0], 0.0)
        key3 = random.PRNGKey(51)
        another_real_features = random.normal(key3, shape=(100, 2048))

        kid_same_dist = image_metrics.KernelInceptionDistanceMetric.from_model_output(
            real_features, another_real_features, subset_size=50  # Using smaller subset size for testing
        )
        result_same_dist = kid_same_dist.compute()

        self.assertGreater(result[0], result_same_dist[0])
Loading