Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2af982a
feat: implementing #65 :
dhruvmalik007 May 6, 2025
f23efa4
feat: add Kernel Inception Distance Metric (KIDM) to metrax :
dhruvmalik007 May 7, 2025
c4bea1b
oops: forgot the contribution of the merge and empty methods
dhruvmalik007 May 7, 2025
85f88a8
refactor: adapted the PR changes asked by @jshin1394
dhruvmalik007 May 8, 2025
87b736a
removing the logging (added to test the errors before)
dhruvmalik007 May 8, 2025
f13d26e
Merge branch 'main' into dhruvmalik007/add-kid-metric
jshin1394 May 8, 2025
c1c30be
stashing local changes before pulling the origin branch changesinto…
dhruvmalik007 May 8, 2025
24b16d0
Merge branch 'dhruvmalik007/add-kid-metric' of https://github.com/dhr…
dhruvmalik007 May 8, 2025
88f9ec6
refactor:
dhruvmalik007 May 8, 2025
6bdd189
all test passing
dhruvmalik007 May 8, 2025
4cb3d65
refactor: test uncluding ersthwhile empty and merge functions removed
dhruvmalik007 May 9, 2025
f9d5954
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 May 9, 2025
965eab5
add the change suggestions @jshin1394 .
dhruvmalik007 May 12, 2025
075e68d
minor refactor:
dhruvmalik007 May 15, 2025
878876c
merging from the remote to the local
dhruvmalik007 May 15, 2025
d40ee0c
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 May 17, 2025
d3df741
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 Jun 2, 2025
8969ce9
feat: replacing the KID metrics computation on actual and fake derive…
dhruvmalik007 Jun 3, 2025
74cf343
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 Oct 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ keras-rs
pytest
rouge-score
scikit-learn
tensorflow
tensorflow
torchmetrics
torch-fidelity
4 changes: 2 additions & 2 deletions src/metrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from metrax import nlp_metrics
from metrax import ranking_metrics
from metrax import regression_metrics

AUCPR = classification_metrics.AUCPR
AUCROC = classification_metrics.AUCROC
Accuracy = classification_metrics.Accuracy
Expand All @@ -42,7 +41,7 @@
RougeN = nlp_metrics.RougeN
SSIM = image_metrics.SSIM
WER = nlp_metrics.WER

KID = image_metrics.KID
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: let's keep alphabetical order.


__all__ = [
"AUCPR",
Expand All @@ -68,4 +67,5 @@
"RougeN",
"SSIM",
"WER",
"KID",
]
124 changes: 122 additions & 2 deletions src/metrax/image_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@

"""A collection of different metrics for image models."""

import jax.numpy as jnp
from jax import random, lax
import flax
import jax
from jax import lax
import jax.numpy as jnp
from clu import metrics as clu_metrics
from metrax import base

KID_DEFAULT_SUBSETS = 100
KID_DEFAULT_SUBSET_SIZE = 1000
KID_DEFAULT_DEGREE = 3
KID_DEFAULT_GAMMA = None
KID_DEFAULT_COEF = 1.0


def _gaussian_kernel1d(sigma, radius):
r"""Generates a 1D normalized Gaussian kernel.
Expand Down Expand Up @@ -53,6 +60,119 @@ def _gaussian_kernel1d(sigma, radius):
return phi_x


def _polynomial_kernel(x: jax.Array, y: jax.Array, degree: int, gamma: float, coef: float) -> jax.Array:
"""
Compute the polynomial kernel between two sets of features.
Args:
x: First set of features.
y: Another set of features to be computed with.
degree: Degree of the polynomial kernel.
gamma: Kernel coefficient for the polynomial kernel. If None, uses 1 / x.shape[1].
coef: Independent term in the polynomial kernel.
Returns:
Polynomial kernel value of Array type.
"""
if gamma is None:
gamma = 1.0 / x.shape[1]
return (jnp.dot(x, y.T) * gamma + coef) ** degree


@flax.struct.dataclass
class KID(base.Average):
r"""Computes Kernel Inception Distance (KID) for asses quality of generated images.
KID is a metric used to evaluate the quality of generated images by comparing
the distribution of generated images to the distribution of real images.
It is based on the Inception Score (IS) and uses a kernelized version of the
Maximum Mean Discrepancy (MMD) to measure the distance between two
distributions.

The KID is computed as follows:

.. math::
KID = MMD(f_{real}, f_{fake})^2

Where :math:`MMD` is the maximum mean discrepancy and :math:`I_{real}, I_{fake}` are extracted features
from real and fake images, see `kid ref1`_ for more details. In particular, calculating the MMD requires the
evaluation of a polynomial kernel function :math:`k`.

.. math::
k(x,y) = (\gamma * x^T y + coef)^{degree}

Args:
subsets: Number of subsets to use for KID calculation.
subset_size: Number of samples in each subset.
degree: Degree of the polynomial kernel.
gamma: Kernel coefficient for the polynomial kernel.
coef: Independent term in the polynomial kernel.
"""

subsets: int = KID_DEFAULT_SUBSETS
subset_size: int = KID_DEFAULT_SUBSET_SIZE
degree: int = KID_DEFAULT_DEGREE
gamma: float = KID_DEFAULT_GAMMA
coef: float = KID_DEFAULT_COEF

@classmethod
def from_model_output(
cls,
real_features: jax.Array,
fake_features: jax.Array,
subsets: int = KID_DEFAULT_SUBSETS,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we not define member variables and just replace with default values here?
subsets: int = 1000,
subset_size: int = 100,
degree: int = 3,
gamma: float = 1.0,
coef: float = 1.0,

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. It was just in case if some model_eval.py evaluation script might import this parameter .

subset_size: int = KID_DEFAULT_SUBSET_SIZE,
degree: int = KID_DEFAULT_DEGREE,
gamma: float = KID_DEFAULT_GAMMA,
coef: float = KID_DEFAULT_COEF,
):
"""
Create a KID instance from model output.
also it computes average output and then store it in the instance.
"""
if subsets <= 0 or subset_size <= 0 or degree <= 0 or (gamma is not None and gamma <= 0) or coef <= 0:
raise ValueError("All parameters must be positive and non-zero.")
# Compute KID for this batch and then store the aggregated response.
if real_features.shape[0] < subset_size or fake_features.shape[0] < subset_size:
raise ValueError("Subset size must be smaller than the number of samples.")
master_key = random.PRNGKey(42)
kid_scores = []
for i in range(subsets):
key_real, key_fake = random.split(random.fold_in(master_key, i))
real_indices = random.choice(key_real, real_features.shape[0], (subset_size,), replace=False)
fake_indices = random.choice(key_fake, fake_features.shape[0], (subset_size,), replace=False)
f_real_subset = real_features[real_indices]
f_fake_subset = fake_features[fake_indices]
kid = cls._compute_mmd_static(f_real_subset, f_fake_subset, degree, gamma, coef)
kid_scores.append(kid)
kid_mean = jnp.mean(jnp.array(kid_scores))

return cls(
total=kid_mean,
count=1.0,
subsets=subsets,
subset_size=subset_size,
degree=degree,
gamma=gamma,
coef=coef,
)


@staticmethod
def _compute_mmd_static(f_real: jax.Array, f_fake: jax.Array, degree: int, gamma: float, coef: float) -> float:
k_11 = _polynomial_kernel(f_real, f_real, degree, gamma, coef)
k_22 = _polynomial_kernel(f_fake, f_fake, degree, gamma, coef)
k_12 = _polynomial_kernel(f_real, f_fake, degree, gamma, coef)

m = f_real.shape[0]
diag_x = jnp.diag(k_11)
diag_y = jnp.diag(k_22)

kt_xx_sum = jnp.sum(k_11, axis=-1) - diag_x
kt_yy_sum = jnp.sum(k_22, axis=-1) - diag_y
k_xy_sum = jnp.sum(k_12, axis=0)

value = (jnp.sum(kt_xx_sum) + jnp.sum(kt_yy_sum)) / (m * (m - 1))
value -= 2 * jnp.sum(k_xy_sum) / (m**2)
return value

@flax.struct.dataclass
class SSIM(base.Average):
r"""SSIM (Structural Similarity Index Measure) Metric.
Expand Down
Loading