Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2af982a
feat: implementing #65 :
dhruvmalik007 May 6, 2025
f23efa4
feat: add Kernel Inception Distance Metric (KIDM) to metrax :
dhruvmalik007 May 7, 2025
c4bea1b
oops: forgot the contribution of the merge and empty methods
dhruvmalik007 May 7, 2025
85f88a8
refactor: adapted the PR changes asked by @jshin1394
dhruvmalik007 May 8, 2025
87b736a
removing the logging (added to test the errors before)
dhruvmalik007 May 8, 2025
f13d26e
Merge branch 'main' into dhruvmalik007/add-kid-metric
jshin1394 May 8, 2025
c1c30be
stashing local changes before pulling the origin branch changesinto…
dhruvmalik007 May 8, 2025
24b16d0
Merge branch 'dhruvmalik007/add-kid-metric' of https://github.com/dhr…
dhruvmalik007 May 8, 2025
88f9ec6
refactor:
dhruvmalik007 May 8, 2025
6bdd189
all test passing
dhruvmalik007 May 8, 2025
4cb3d65
refactor: tests including erstwhile empty and merge functions removed
dhruvmalik007 May 9, 2025
f9d5954
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 May 9, 2025
965eab5
add the change suggestions @jshin1394 .
dhruvmalik007 May 12, 2025
075e68d
minor refactor:
dhruvmalik007 May 15, 2025
878876c
merging from the remote to the local
dhruvmalik007 May 15, 2025
d40ee0c
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 May 17, 2025
d3df741
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 Jun 2, 2025
8969ce9
feat: replacing the KID metrics computation on actual and fake derive…
dhruvmalik007 Jun 3, 2025
74cf343
Merge branch 'main' into dhruvmalik007/add-kid-metric
dhruvmalik007 Oct 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ rouge-score
scikit-learn
tensorflow
torchmetrics
torch-fidelity
4 changes: 2 additions & 2 deletions src/metrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from metrax import nlp_metrics
from metrax import ranking_metrics
from metrax import regression_metrics

AUCPR = classification_metrics.AUCPR
AUCROC = classification_metrics.AUCROC
Accuracy = classification_metrics.Accuracy
Expand All @@ -30,6 +29,7 @@
Dice = image_metrics.Dice
FBetaScore = classification_metrics.FBetaScore
IoU = image_metrics.IoU
KID = image_metrics.KID
MAE = regression_metrics.MAE
MRR = ranking_metrics.MRR
MSE = regression_metrics.MSE
Expand All @@ -48,7 +48,6 @@
SSIM = image_metrics.SSIM
WER = nlp_metrics.WER


__all__ = [
"AUCPR",
"AUCROC",
Expand Down Expand Up @@ -77,4 +76,5 @@
"SNR",
"SSIM",
"WER",
"KID",
]
297 changes: 291 additions & 6 deletions src/metrax/image_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.

"""A collection of different metrics for image models."""

from clu import metrics as clu_metrics
import jax.numpy as jnp
from jax import random, lax
import flax
import jax
from jax import lax
import jax.numpy as jnp
from metrax import base

from clu import metrics as clu_metrics

def _gaussian_kernel1d(sigma, radius):
r"""Generates a 1D normalized Gaussian kernel.
Expand Down Expand Up @@ -54,6 +52,261 @@ def _gaussian_kernel1d(sigma, radius):
return phi_x


def _polynomial_kernel(x: jax.Array, y: jax.Array, degree: int, gamma: float, coef: float) -> jax.Array:
"""
Compute the polynomial kernel between two sets of features.
Args:
x: First set of features.
y: Another set of features to be computed with.
degree: Degree of the polynomial kernel.
gamma: Kernel coefficient for the polynomial kernel. If None, uses 1 / x.shape[1].
coef: Independent term in the polynomial kernel.
Returns:
Polynomial kernel value of Array type.
"""
if gamma is None:
gamma = 1.0 / x.shape[1]

# Compute dot product and apply kernel transformation
dot_product = jnp.dot(x, y.T)

# Normalize the dot product to prevent overflow
dot_product = dot_product / jnp.sqrt(x.shape[1]) # Normalize by sqrt of feature dimension

# Apply gamma scaling with clipping to prevent extreme values
scaled_product = jnp.clip(dot_product * gamma + coef, -10.0, 10.0)
kernel_value = scaled_product ** degree

# Handle potential numerical issues
kernel_value = jnp.where(jnp.isfinite(kernel_value), kernel_value, 0.0)

return kernel_value


@flax.struct.dataclass
class KID(base.Average):
  r"""Kernel Inception Distance (KID) for assessing generated-image quality.

  KID compares the distribution of features extracted from generated images
  with the distribution of features extracted from real images, using the
  squared Maximum Mean Discrepancy (MMD) between the two feature sets:

  .. math::
      KID = MMD(f_{real}, f_{fake})^2

  where :math:`f_{real}, f_{fake}` are the features extracted from real and
  fake images. Calculating the MMD requires the evaluation of a polynomial
  kernel function :math:`k`:

  .. math::
      k(x,y) = (\gamma * x^T y + coef)^{degree}

  The MMD is estimated on ``subsets`` random subsets of ``subset_size``
  samples each, drawn without replacement from a fixed PRNG seed (so repeated
  calls on the same inputs give the same result). The mean over subsets is
  stored in ``total`` with ``count`` set to 1; the standard deviation over
  subsets is stored in ``kid_std``.

  NOTE(review): whether the extra fields (in particular ``kid_std``) survive
  ``merge`` depends on ``base.Average``, which is not defined in this file --
  confirm before relying on merged instances.

  Attributes:
    subsets: Number of subsets to use for KID calculation.
    subset_size: Number of samples in each subset.
    degree: Degree of the polynomial kernel.
    gamma: Kernel coefficient for the polynomial kernel. ``None`` means
      ``1 / feature_dim``.
    coef: Independent term in the polynomial kernel.
    kid_std: The computed KID standard deviation.
  """

  subsets: int = 100
  subset_size: int = 1000
  degree: int = 3
  gamma: float | None = None
  coef: float = 1.0
  kid_std: jax.Array | float = 0.0

  @staticmethod
  def _compute_mmd_static(
      f_real: jax.Array,
      f_fake: jax.Array,
      degree: int,
      gamma: float | None,
      coef: float,
  ) -> jax.Array:
    """Estimates the squared MMD between two equally sized feature sets.

    Args:
      f_real: Features of real samples, one sample per row.
      f_fake: Features of fake samples, one sample per row. The number of rows
        of ``f_real`` is used for normalising every term, so both sets are
        expected to have the same number of rows.
      degree: Degree of the polynomial kernel.
      gamma: Kernel coefficient; ``None`` lets the kernel pick its default.
      coef: Independent term in the polynomial kernel.

    Returns:
      A scalar array with the MMD estimate; ``0.0`` if there is at most one
      sample or if the estimate is not finite.
    """
    m = f_real.shape[0]

    # The within-set terms are divided by m * (m - 1), which is zero for a
    # single sample, so there is nothing meaningful to estimate.
    if m <= 1:
      return jnp.array(0.0)

    k_11 = _polynomial_kernel(f_real, f_real, degree, gamma, coef)
    k_22 = _polynomial_kernel(f_fake, f_fake, degree, gamma, coef)
    k_12 = _polynomial_kernel(f_real, f_fake, degree, gamma, coef)

    # Within-set sums exclude the diagonal (each sample paired with itself).
    kt_xx_sum = jnp.sum(k_11, axis=-1) - jnp.diag(k_11)
    kt_yy_sum = jnp.sum(k_22, axis=-1) - jnp.diag(k_22)
    k_xy_sum = jnp.sum(k_12, axis=0)

    value = (jnp.sum(kt_xx_sum) + jnp.sum(kt_yy_sum)) / (m * (m - 1))
    value -= 2 * jnp.sum(k_xy_sum) / (m**2)

    # Ensure the result is finite.
    return jnp.where(jnp.isfinite(value), value, 0.0)

  @staticmethod
  def _validate_hyperparameters(
      subsets: int,
      subset_size: int,
      degree: int,
      gamma: float | None,
      coef: float,
  ) -> None:
    """Raises ValueError unless every given hyperparameter is positive."""
    if (
        subsets <= 0
        or subset_size <= 0
        or degree <= 0
        or (gamma is not None and gamma <= 0)
        or coef <= 0
    ):
      raise ValueError("All parameters must be positive and non-zero.")

  @classmethod
  def _from_feature_matrices(
      cls,
      real_features: jax.Array,
      fake_features: jax.Array,
      subsets: int,
      subset_size: int,
      degree: int,
      gamma: float | None,
      coef: float,
  ) -> "KID":
    """Builds a KID instance from feature matrices; shared by the public constructors."""
    if (
        real_features.shape[0] < subset_size
        or fake_features.shape[0] < subset_size
    ):
      raise ValueError("Subset size must be smaller than the number of samples.")

    # Fixed seed: the subset draws, and hence the metric, are deterministic.
    master_key = jax.random.PRNGKey(42)
    kid_scores = []
    for i in range(subsets):
      # Independent keys per subset and per side, all derived from the seed.
      key_real, key_fake = jax.random.split(jax.random.fold_in(master_key, i))
      real_indices = jax.random.choice(
          key_real, real_features.shape[0], (subset_size,), replace=False
      )
      fake_indices = jax.random.choice(
          key_fake, fake_features.shape[0], (subset_size,), replace=False
      )
      kid_scores.append(
          cls._compute_mmd_static(
              real_features[real_indices],
              fake_features[fake_indices],
              degree,
              gamma,
              coef,
          )
      )
    kid_scores = jnp.array(kid_scores)
    kid_mean = jnp.mean(kid_scores)
    # Handle edge case where std could be inf/nan.
    kid_std = jnp.std(kid_scores)
    kid_std = jnp.where(jnp.isfinite(kid_std), kid_std, 0.0)

    return cls(
        total=kid_mean,
        count=1.0,
        subsets=subsets,
        subset_size=subset_size,
        degree=degree,
        gamma=gamma,
        coef=coef,
        kid_std=kid_std,
    )

  @classmethod
  def from_model_output(
      cls,
      real_image: jax.Array,
      fake_image: jax.Array,
      subsets: int = 100,
      subset_size: int = 1000,
      degree: int = 3,
      gamma: float | None = None,
      coef: float = 1.0,
  ) -> "KID":
    """Creates a KID instance from real and generated images.

    Features are obtained by passing both inputs through
    ``_extract_image_features``; the mean and standard deviation of the MMD
    over random subsets are then stored in the returned instance.

    Args:
      real_image: Real images. Shape: (N, H, W, C), where N is the number of
        real images, H and W are height and width, and C is the number of
        channels.
      fake_image: Generated (fake) images. Shape: (M, H, W, C), where M is the
        number of fake images.
      subsets: Number of random subsets to use for KID calculation.
      subset_size: Number of samples in each subset. Must be <= min(N, M).
      degree: Degree of the polynomial kernel.
      gamma: Kernel coefficient for the polynomial kernel. If None, uses
        1 / feature_dim.
      coef: Independent term in the polynomial kernel.

    Returns:
      An instance of the KID metric with the computed mean KID value for the
      given images.

    Raises:
      ValueError: If any parameter is non-positive, or if subset_size is
        greater than the number of samples in real or fake images.
    """
    # Validate before extracting features so that hyperparameter errors are
    # reported ahead of any input-shape errors from the extractor.
    cls._validate_hyperparameters(subsets, subset_size, degree, gamma, coef)

    real_features = _extract_image_features(real_image)
    fake_features = _extract_image_features(fake_image)

    return cls._from_feature_matrices(
        real_features, fake_features, subsets, subset_size, degree, gamma, coef
    )

  @classmethod
  def from_features(
      cls,
      real_features: jax.Array,
      fake_features: jax.Array,
      subsets: int = 100,
      subset_size: int = 1000,
      degree: int = 3,
      gamma: float | None = None,
      coef: float = 1.0,
  ) -> "KID":
    """Creates a KID instance from pre-extracted features.

    Args:
      real_features: Feature representations of real images. Shape: (N, D),
        where N is the number of real images and D is the feature dimension.
      fake_features: Feature representations of generated (fake) images.
        Shape: (M, D), where M is the number of fake images.
      subsets: Number of random subsets to use for KID calculation.
      subset_size: Number of samples in each subset. Must be <= min(N, M).
      degree: Degree of the polynomial kernel.
      gamma: Kernel coefficient for the polynomial kernel. If None, uses
        1 / D.
      coef: Independent term in the polynomial kernel.

    Returns:
      An instance of the KID metric with the computed mean KID value for the
      given features.

    Raises:
      ValueError: If any parameter is non-positive, or if subset_size is
        greater than the number of samples in real or fake features.
    """
    cls._validate_hyperparameters(subsets, subset_size, degree, gamma, coef)
    return cls._from_feature_matrices(
        real_features, fake_features, subsets, subset_size, degree, gamma, coef
    )

  def compute(self) -> jax.Array:
    """Returns the mean KID value, ``total / count``, or 0 when ``count`` is 0."""
    count = jnp.asarray(self.count)
    has_samples = count > 0
    # Select with jnp.where rather than a Python ``if`` so this also works
    # when ``count`` is a traced value; the denominator is swapped for 1 in
    # the empty case so that branch never divides by zero.
    safe_count = jnp.where(has_samples, count, 1)
    return jnp.where(has_samples, self.total / safe_count, 0.0)

  def compute_mean_std(self) -> tuple[jax.Array, jax.Array | float]:
    """Returns ``(kid_mean, kid_std)``: the mean KID value and the stored std."""
    return (self.compute(), self.kid_std)


@flax.struct.dataclass
class SSIM(base.Average):
r"""SSIM (Structural Similarity Index Measure) Metric.
Expand Down Expand Up @@ -362,7 +615,6 @@ def from_model_output(
)
return super().from_model_output(values=batch_ssim_values)


@flax.struct.dataclass
class IoU(base.Average):
r"""Measures Intersection over Union (IoU) for semantic segmentation.
Expand Down Expand Up @@ -666,3 +918,36 @@ def compute(self) -> jax.Array:
"""Returns the final Dice coefficient."""
epsilon = 1e-7
return (2.0 * self.intersection) / (self.sum_pred + self.sum_true + epsilon)

def _extract_image_features(images: jax.Array) -> jax.Array:
  """Turns a batch of images into per-image feature vectors for KID.

  This is a deliberately simple placeholder extractor; a pre-trained network
  such as InceptionV3 would normally be used instead.

  Args:
    images: Either images of shape (N, H, W, C) or already-extracted features
      of shape (N, D).

  Returns:
    For 2D input, the input itself. For 4D input, an array of shape (N, 2 * C)
    holding the per-channel spatial means followed by their squares.

  Raises:
    ValueError: If ``images`` is neither 2D nor 4D.
  """
  rank = images.ndim

  # Two-dimensional input is treated as ready-made features.
  if rank == 2:
    return images

  if rank != 4:
    raise ValueError(f"Expected 4D input (N, H, W, C) or 2D features (N, D), got {images.ndim}D")

  # Global average pooling over the two spatial axes -> (N, C).
  pooled = jnp.mean(images, axis=(1, 2))

  # Append the squared means as a second group of features -> (N, 2 * C).
  return jnp.concatenate([pooled, pooled ** 2], axis=1)
Loading