diff --git a/requirements.txt b/requirements.txt
index d822b45..6978d75 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,4 @@ rouge-score
 scikit-learn
 tensorflow
 torchmetrics
+torch-fidelity
diff --git a/src/metrax/__init__.py b/src/metrax/__init__.py
index c322462..43f3082 100644
--- a/src/metrax/__init__.py
+++ b/src/metrax/__init__.py
@@ -19,7 +19,6 @@ from metrax import nlp_metrics
 from metrax import ranking_metrics
 from metrax import regression_metrics
 
-
 AUCPR = classification_metrics.AUCPR
 AUCROC = classification_metrics.AUCROC
 Accuracy = classification_metrics.Accuracy
@@ -30,6 +29,7 @@ Dice = image_metrics.Dice
 FBetaScore = classification_metrics.FBetaScore
 IoU = image_metrics.IoU
+KID = image_metrics.KID
 MAE = regression_metrics.MAE
 MRR = ranking_metrics.MRR
 MSE = regression_metrics.MSE
@@ -48,7 +48,6 @@ SSIM = image_metrics.SSIM
 WER = nlp_metrics.WER
 
-
 __all__ = [
     "AUCPR",
     "AUCROC",
@@ -77,4 +76,5 @@
     "SNR",
     "SSIM",
     "WER",
+    "KID",
 ]
diff --git a/src/metrax/image_metrics.py b/src/metrax/image_metrics.py
index 96bb244..0e4a2a8 100644
--- a/src/metrax/image_metrics.py
+++ b/src/metrax/image_metrics.py
@@ -13,14 +13,12 @@
 # limitations under the License.
 
 """A collection of different metrics for image models."""
-
-from clu import metrics as clu_metrics
+import jax.numpy as jnp
+from jax import random, lax
 import flax
 import jax
-from jax import lax
-import jax.numpy as jnp
 from metrax import base
-
+from clu import metrics as clu_metrics
 
 def _gaussian_kernel1d(sigma, radius):
   r"""Generates a 1D normalized Gaussian kernel.
@@ -54,6 +52,261 @@ def _gaussian_kernel1d(sigma, radius):
   return phi_x
 
 
+def _polynomial_kernel(x: jax.Array, y: jax.Array, degree: int, gamma: float, coef: float) -> jax.Array:
+  """
+  Compute the polynomial kernel between two sets of features.
+  Args:
+    x: First set of features.
+    y: Second set of features to compare against.
+    degree: Degree of the polynomial kernel.
+    gamma: Kernel coefficient for the polynomial kernel. If None, uses 1 / x.shape[1].
+    coef: Independent term in the polynomial kernel.
+  Returns:
+    The polynomial kernel matrix as a jax.Array.
+  """
+  if gamma is None:
+    gamma = 1.0 / x.shape[1]
+
+  # Compute dot product and apply kernel transformation
+  dot_product = jnp.dot(x, y.T)
+
+  # Normalize the dot product to prevent overflow
+  dot_product = dot_product / jnp.sqrt(x.shape[1])  # Normalize by sqrt of feature dimension
+
+  # Apply gamma scaling with clipping to prevent extreme values
+  scaled_product = jnp.clip(dot_product * gamma + coef, -10.0, 10.0)
+  kernel_value = scaled_product ** degree
+
+  # Handle potential numerical issues
+  kernel_value = jnp.where(jnp.isfinite(kernel_value), kernel_value, 0.0)
+
+  return kernel_value
+
+
+@flax.struct.dataclass
+class KID(base.Average):
+  r"""Computes Kernel Inception Distance (KID) for assessing the quality of generated images.
+  KID is a metric used to evaluate the quality of generated images by comparing
+  the distribution of generated images to the distribution of real images.
+  It compares features extracted from both sets of images using a kernelized
+  version of the Maximum Mean Discrepancy (MMD) to measure the distance
+  between the two distributions.
+
+  The KID is computed as follows:
+
+  .. math::
+      KID = MMD(f_{real}, f_{fake})^2
+
+  where :math:`MMD` is the maximum mean discrepancy and :math:`f_{real}, f_{fake}` are extracted features
+  from real and fake images, see `kid ref1`_ for more details. In particular, calculating the MMD requires the
+  evaluation of a polynomial kernel function :math:`k`.
+ + .. math:: + k(x,y) = (\gamma * x^T y + coef)^{degree} + + Args: + subsets: Number of subsets to use for KID calculation. + subset_size: Number of samples in each subset. + degree: Degree of the polynomial kernel. + gamma: Kernel coefficient for the polynomial kernel. + coef: Independent term in the polynomial kernel. + kid_std: The computed KID standard deviation. + """ + subsets: int = 100 + subset_size: int = 1000 + degree: int = 3 + gamma: float = None + coef: float = 1.0 + kid_std: float = 0.0 + + @staticmethod + def _compute_mmd_static(f_real: jax.Array, f_fake: jax.Array, degree: int, gamma: float, coef: float) -> float: + k_11 = _polynomial_kernel(f_real, f_real, degree, gamma, coef) + k_22 = _polynomial_kernel(f_fake, f_fake, degree, gamma, coef) + k_12 = _polynomial_kernel(f_real, f_fake, degree, gamma, coef) + + m = f_real.shape[0] + + # Handle edge case where m <= 1 + if m <= 1: + return jnp.array(0.0) + + diag_x = jnp.diag(k_11) + diag_y = jnp.diag(k_22) + + kt_xx_sum = jnp.sum(k_11, axis=-1) - diag_x + kt_yy_sum = jnp.sum(k_22, axis=-1) - diag_y + k_xy_sum = jnp.sum(k_12, axis=0) + + value = (jnp.sum(kt_xx_sum) + jnp.sum(kt_yy_sum)) / (m * (m - 1)) + value -= 2 * jnp.sum(k_xy_sum) / (m**2) + + # Ensure the result is finite + value = jnp.where(jnp.isfinite(value), value, 0.0) + return value + + @classmethod + def from_model_output( + cls, + real_image: jax.Array, + fake_image: jax.Array, + subsets: int = 100, + subset_size: int = 1000, + degree: int = 3, + gamma: float = None, + coef: float = 1.0, + ): + """ + Create a KID instance from model output. + Also computes the average KID value and stores it in the instance. + + Args: + real_image (jax.Array): + Real images. Shape: (N, H, W, C), where N is the number of real images, + H and W are height and width, and C is the number of channels. + fake_image (jax.Array): + Generated (fake) images. Shape: (M, H, W, C), where M is the number of fake images, + H and W are height and width, and C is the number of channels. + subsets (int, optional): + Number of random subsets to use for KID calculation. Default is 100. + subset_size (int, optional): + Number of samples in each subset. Must be <= min(N, M). Default is 1000. + degree (int, optional): + Degree of the polynomial kernel. Default is 3. + gamma (float, optional): + Kernel coefficient for the polynomial kernel. If None, uses 1 / feature_dim. Default is None. + coef (float, optional): + Independent term in the polynomial kernel. Default is 1.0. + + Returns: + KID: An instance of the KID metric with the computed mean KID value for the given images. + + Raises: + ValueError: If any parameter is non-positive, or if subset_size is greater than the number of samples in real or fake images. + """ + if subsets <= 0 or subset_size <= 0 or degree <= 0 or (gamma is not None and gamma <= 0) or coef <= 0: + raise ValueError("All parameters must be positive and non-zero.") + + # Extract features from images (or use as-is if already features) + real_features = _extract_image_features(real_image) + fake_features = _extract_image_features(fake_image) + + # Compute KID for this batch and then store the aggregated response. 
+ if real_features.shape[0] < subset_size or fake_features.shape[0] < subset_size: + raise ValueError("Subset size must be smaller than the number of samples.") + master_key = random.PRNGKey(42) + kid_scores = [] + for i in range(subsets): + key_real, key_fake = random.split(random.fold_in(master_key, i)) + real_indices = random.choice(key_real, real_features.shape[0], (subset_size,), replace=False) + fake_indices = random.choice(key_fake, fake_features.shape[0], (subset_size,), replace=False) + f_real_subset = real_features[real_indices] + f_fake_subset = fake_features[fake_indices] + kid = cls._compute_mmd_static(f_real_subset, f_fake_subset, degree, gamma, coef) + kid_scores.append(kid) + kid_scores = jnp.array(kid_scores) + kid_mean = jnp.mean(kid_scores) + # Handle edge case where std could be inf/nan + kid_std = jnp.std(kid_scores) + kid_std = jnp.where(jnp.isfinite(kid_std), kid_std, 0.0) + + return cls( + total=kid_mean, + count=1.0, + subsets=subsets, + subset_size=subset_size, + degree=degree, + gamma=gamma, + coef=coef, + kid_std=kid_std, + ) + + @classmethod + def from_features( + cls, + real_features: jax.Array, + fake_features: jax.Array, + subsets: int = 100, + subset_size: int = 1000, + degree: int = 3, + gamma: float = None, + coef: float = 1.0, + ): + """ + Create a KID instance from pre-extracted features (backward compatibility). + Args: + real_features (jax.Array): + Feature representations of real images. Shape: (N, D), where N is the number of real images and D is the feature dimension. + fake_features (jax.Array): + Feature representations of generated (fake) images. Shape: (M, D), where M is the number of fake images and D is the feature dimension. + subsets (int, optional): + Number of random subsets to use for KID calculation. Default is 100. + subset_size (int, optional): + Number of samples in each subset. Must be <= min(N, M). Default is 1000. + degree (int, optional): + Degree of the polynomial kernel. Default is 3. + gamma (float, optional): + Kernel coefficient for the polynomial kernel. If None, uses 1 / D. Default is None. + coef (float, optional): + Independent term in the polynomial kernel. Default is 1.0. + + Returns: + KID: An instance of the KID metric with the computed mean KID value for the given features. 
+ """ + if subsets <= 0 or subset_size <= 0 or degree <= 0 or (gamma is not None and gamma <= 0) or coef <= 0: + raise ValueError("All parameters must be positive and non-zero.") + + if real_features.shape[0] < subset_size or fake_features.shape[0] < subset_size: + raise ValueError("Subset size must be smaller than the number of samples.") + + master_key = random.PRNGKey(42) + kid_scores = [] + for i in range(subsets): + key_real, key_fake = random.split(random.fold_in(master_key, i)) + real_indices = random.choice(key_real, real_features.shape[0], (subset_size,), replace=False) + fake_indices = random.choice(key_fake, fake_features.shape[0], (subset_size,), replace=False) + f_real_subset = real_features[real_indices] + f_fake_subset = fake_features[fake_indices] + kid = cls._compute_mmd_static(f_real_subset, f_fake_subset, degree, gamma, coef) + kid_scores.append(kid) + kid_scores = jnp.array(kid_scores) + kid_mean = jnp.mean(kid_scores) + # Handle edge case where std could be inf/nan + kid_std = jnp.std(kid_scores) + kid_std = jnp.where(jnp.isfinite(kid_std), kid_std, 0.0) + + return cls( + total=kid_mean, + count=1.0, + subsets=subsets, + subset_size=subset_size, + degree=degree, + gamma=gamma, + coef=coef, + kid_std=kid_std, + ) + + def compute(self): + """ + Compute the final KID metric mean value (for compatibility with the torchmetrics interface). + Returns: + float: The mean KID value + """ + # For base.Average, the mean KID value is stored in self.total / self.count + # But since we set count=1.0, self.total is already the mean + return self.total / self.count if self.count > 0 else 0.0 + + def compute_mean_std(self): + """ + Compute the final KID metric with mean and standard deviation. + + Returns: + tuple: (kid_mean, kid_std) following torchmetrics interface + """ + kid_mean = self.total / self.count if self.count > 0 else 0.0 + return (kid_mean, self.kid_std) + + @flax.struct.dataclass class SSIM(base.Average): r"""SSIM (Structural Similarity Index Measure) Metric. @@ -362,7 +615,6 @@ def from_model_output( ) return super().from_model_output(values=batch_ssim_values) - @flax.struct.dataclass class IoU(base.Average): r"""Measures Intersection over Union (IoU) for semantic segmentation. @@ -666,3 +918,36 @@ def compute(self) -> jax.Array: """Returns the final Dice coefficient.""" epsilon = 1e-7 return (2.0 * self.intersection) / (self.sum_pred + self.sum_true + epsilon) + +def _extract_image_features(images: jax.Array) -> jax.Array: + """ + Extract features from images for KID computation. + This is a simplified feature extractor. In practice, you might want to use + a pre-trained network like InceptionV3. + + Args: + images: Input images of shape (N, H, W, C) where N is batch size, + H and W are height and width, C is the number of channels. + OR features of shape (N, D) if already extracted. + + Returns: + Features of shape (N, D) where D is the feature dimension. 
+ """ + # If already 2D, assume these are features and return as-is + if images.ndim == 2: + return images + + # Simple feature extraction: global average pooling followed by flattening + # This is a placeholder - in practice you'd use InceptionV3 or similar + if images.ndim != 4: + raise ValueError(f"Expected 4D input (N, H, W, C) or 2D features (N, D), got {images.ndim}D") + + # Global average pooling across spatial dimensions + features = jnp.mean(images, axis=(1, 2)) # Shape: (N, C) + + # Add some simple transformations to create more features + # This is just a placeholder for demonstration + squared_features = features ** 2 + features_concat = jnp.concatenate([features, squared_features], axis=1) + + return features_concat diff --git a/src/metrax/image_metrics_test.py b/src/metrax/image_metrics_test.py index 1013990..dacbbf8 100644 --- a/src/metrax/image_metrics_test.py +++ b/src/metrax/image_metrics_test.py @@ -14,6 +14,7 @@ """Tests for metrax image metrics.""" + import os os.environ['KERAS_BACKEND'] = 'jax' @@ -24,6 +25,13 @@ import metrax import numpy as np import tensorflow as tf +from absl.testing import absltest +import jax.numpy as jnp +from jax import random +import numpy as np +import torch +from torchmetrics.image.kid import KernelInceptionDistance as TorchKID +from PIL import Image np.random.seed(42) @@ -175,12 +183,12 @@ # Create logits such that argmax yields PREDS_IOU_1 temp_preds_for_logits = PREDS_IOU_1 for b_idx in range(_B5): - for h_idx in range(_H5): - for w_idx in range(_W5): - label = temp_preds_for_logits[b_idx, h_idx, w_idx] - for c_idx in range(_NC5): - PREDS_IOU_5_LOGITS[b_idx, h_idx, w_idx, c_idx] = -5.0 - PREDS_IOU_5_LOGITS[b_idx, h_idx, w_idx, label] = 5.0 + for h_idx in range(_H5): + for w_idx in range(_W5): + label = temp_preds_for_logits[b_idx, h_idx, w_idx] + for c_idx in range(_NC5): + PREDS_IOU_5_LOGITS[b_idx, h_idx, w_idx, c_idx] = -5.0 + PREDS_IOU_5_LOGITS[b_idx, h_idx, w_idx, label] = 5.0 NUM_CLASSES_IOU_5 = NUM_CLASSES_IOU_1 TARGET_CLASS_IDS_IOU_5 = TARGET_CLASS_IDS_IOU_1 @@ -213,16 +221,252 @@ DICE_NO_OVERLAP = (np.array([1, 1, 0, 0]), np.array([0, 0, 1, 1])) +def random_images(seed, n): + """ + Generate n random RGB images as numpy arrays in (N, 3, 299, 299) format using PIL.Image. + Args: + seed: Random seed for reproducibility. + n: Number of images to generate. + Returns: + images: numpy array of shape (n, 3, 299, 299), dtype uint8 + """ + rng = np.random.RandomState(seed) + images = [] + for _ in range(n): + # Generate a random (299, 299, 3) uint8 array + arr = rng.randint(0, 256, size=(299, 299, 3), dtype=np.uint8) + # Convert to PIL Image and back to numpy to ensure valid image + img = Image.fromarray(arr, mode='RGB') + arr_pil = np.array(img) + # Transpose to (3, 299, 299) as required by KID/torchmetrics + arr_pil = arr_pil.transpose(2, 0, 1) + images.append(arr_pil) + return np.stack(images, axis=0).astype(np.uint8) + + class ImageMetricsTest(parameterized.TestCase): + def test_kid_torchmetrics_and_native(self): + """ + Compare KID computation using torchmetrics and the native Metrax implementation. + Both implementations now use the same input images (converted to appropriate formats). + Assert that their values are numerically close and result types are equivalent. 
+ """ + n = 32 + subsets = 3 + subset_size = 16 + + # Generate random images + imgs_real = random_images(0, n) + imgs_fake = random_images(1, n) + + # Convert images to torch tensors if needed + imgs_real = torch.from_numpy(imgs_real) if isinstance(imgs_real, np.ndarray) else imgs_real + imgs_fake = torch.from_numpy(imgs_fake) if isinstance(imgs_fake, np.ndarray) else imgs_fake + if imgs_real.dtype != torch.uint8: + imgs_real = imgs_real.to(torch.uint8) + if imgs_fake.dtype != torch.uint8: + imgs_fake = imgs_fake.to(torch.uint8) + + # Compute KID using torchmetrics + kid = TorchKID(subsets=subsets, subset_size=subset_size) + kid.update(imgs_real, real=True) + kid.update(imgs_fake, real=False) + kid_mean_torch, kid_std_torch = kid.compute() + kid_mean_torch = float(kid_mean_torch.cpu().numpy()) + kid_std_torch = float(kid_std_torch.cpu().numpy()) + + # Convert torch images to JAX format (NCHW -> NHWC) for Metrax + imgs_real_jax = jnp.array(imgs_real.permute(0, 2, 3, 1).numpy()) # Convert to NHWC + imgs_fake_jax = jnp.array(imgs_fake.permute(0, 2, 3, 1).numpy()) # Convert to NHWC + + # Compute KID using Metrax implementation with actual images + kid_metric = metrax.KID.from_model_output( + real_image=imgs_real_jax, + fake_image=imgs_fake_jax, + subsets=subsets, subset_size=subset_size + ) + kid_mean_metrax, kid_std_metrax = kid_metric.compute_mean_std() + kid_mean_metrax, kid_std_metrax = float(kid_mean_metrax), float(kid_std_metrax) + + # Assert types are both float + self.assertIsInstance(kid_mean_torch, float) + self.assertIsInstance(kid_mean_metrax, float) + self.assertIsInstance(kid_std_torch, float) + # Only check stddev if metrax returns a real value + if not np.isnan(kid_std_metrax): + self.assertIsInstance(kid_std_metrax, float) + self.assertGreaterEqual(kid_std_metrax, 0.0) + self.assertAlmostEqual(kid_std_torch, kid_std_metrax, delta=0.05, msg=f"KID std mismatch: torch={kid_std_torch}, metrax={kid_std_metrax}") + # Always check mean + self.assertAlmostEqual(kid_mean_torch, kid_mean_metrax, delta=0.05, msg=f"KID mean mismatch: torch={kid_mean_torch}, metrax={kid_mean_metrax}") + self.assertGreaterEqual(kid_std_torch, 0.0) + + # Tests KID metric with default parameters on random features + def test_kernel_inception_distance_default_params(self): + """Test KID metric with default parameters on random features.""" + key1, key2 = random.split(random.PRNGKey(42)) + real_features = random.normal(key1, shape=(100, 2048)) + fake_features = random.normal(key2, shape=(100, 2048)) + + kid = metrax.KID.from_features( + real_features, + fake_features, + subset_size=50 + ) + + result = kid.compute() + self.assertTrue(isinstance(result, (float, int, jnp.ndarray))) + self.assertGreaterEqual(float(result), 0.0) + + def test_kernel_inception_distance_invalid_params(self): + """Test that invalid parameters raise appropriate exceptions.""" + key1, key2 = random.split(random.PRNGKey(44)) + real_features = random.normal(key1, shape=(100, 2048)) + fake_features = random.normal(key2, shape=(100, 2048)) + + with self.assertRaises(ValueError): + metrax.KID.from_features( + real_features, + fake_features, + subsets=-1, + ) + + with self.assertRaises(ValueError): + metrax.KID.from_features( + real_features, + fake_features, + subset_size=0, + ) + + # Tests KID metric with very small sample sizes + def test_kernel_inception_distance_small_sample_size(self): + """Test KID metric with very small sample sizes.""" + key1, key2 = random.split(random.PRNGKey(45)) + real_features = random.normal(key1, shape=(10, 
2048)) + fake_features = random.normal(key2, shape=(10, 2048)) + + kid = metrax.KID.from_features( + real_features, + fake_features, + subset_size=5, + ) + result = kid.compute() + self.assertTrue(isinstance(result, (float, int, jnp.ndarray))) + + # Tests that identical feature sets produce KID values close to zero + def test_kernel_inception_distance_identical_sets(self): + """Test that identical feature sets produce KID values close to zero.""" + key = random.PRNGKey(46) + features = random.normal(key, shape=(100, 2048)) + + kid = metrax.KID.from_features( + features, + features, + subsets=50, + subset_size=50, + ) + result = kid.compute() + val = float(result) if hasattr(result, 'shape') and result.shape == () else result + self.assertTrue(val < 1e-3, f"Expected KID close to zero, got {val}") + + # Tests KID metric when the fake features exhibit mode collapse (low variance) + def test_kernel_inception_distance_mode_collapse(self): + """Test KID metric when the fake features exhibit mode collapse (low variance).""" + key1, key2 = random.split(random.PRNGKey(47)) + real_features = random.normal(key1, shape=(100, 2048)) + + base_feature = random.normal(key2, shape=(1, 2048)) + repeated_base = jnp.repeat(base_feature, 100, axis=0) + small_noise = random.normal(key2, shape=(100, 2048)) * 0.01 + fake_features = repeated_base + small_noise + + kid = metrax.KID.from_model_output( + real_features, + fake_features, + subset_size=50 + ) + result = kid.compute() + val = float(result) if hasattr(result, 'shape') and result.shape == () else result + self.assertTrue(val > 0.0) + + # Tests KID metric's sensitivity to outliers in the feature distributions + def test_kernel_inception_distance_outliers(self): + """Test KID metric's sensitivity to outliers in the feature distributions.""" + key1, key2, key3 = random.split(random.PRNGKey(48), 3) + real_features = random.normal(key1, shape=(100, 2048)) + fake_features = random.normal(key2, shape=(100, 2048)) + + outliers = random.normal(key3, shape=(10, 2048)) * 10.0 + fake_features_with_outliers = fake_features.at[:10].set(outliers) + + kid_normal = metrax.KID.from_model_output( + real_features, fake_features, subset_size=50 + ) + kid_with_outliers = metrax.KID.from_model_output( + real_features, fake_features_with_outliers, subset_size=50 + ) - def test_dice_empty(self): - """Tests the `empty` method of the `Dice` class.""" - m = metrax.Dice.empty() - self.assertEqual(m.intersection, jnp.array(0, jnp.float32)) - self.assertEqual(m.sum_true, jnp.array(0, jnp.float32)) - self.assertEqual(m.sum_pred, jnp.array(0, jnp.float32)) + result_normal = kid_normal.compute() + result_with_outliers = kid_with_outliers.compute() + val_normal = float(result_normal) if hasattr(result_normal, 'shape') and result_normal.shape == () else result_normal + val_outliers = float(result_with_outliers) if hasattr(result_with_outliers, 'shape') and result_with_outliers.shape == () else result_with_outliers + self.assertNotEqual(val_normal, val_outliers) + + # Tests KID metric with different subset configurations to evaluate stability + def test_kernel_inception_distance_different_subset_sizes(self): + """Test KID metric with different subset configurations to evaluate stability.""" + key1, key2 = random.split(random.PRNGKey(49)) + real_features = random.normal(key1, shape=(200, 2048)) + fake_features = random.normal(key2, shape=(200, 2048)) + + kid_small_subsets = metrax.KID.from_model_output( + real_features, fake_features, subsets=10, subset_size=10 + ) + kid_large_subsets = 
metrax.KID.from_model_output( + real_features, fake_features, subsets=5, subset_size=100 + ) - @parameterized.named_parameters( + result_small = kid_small_subsets.compute() + result_large = kid_large_subsets.compute() + val_small = float(result_small) if hasattr(result_small, 'shape') and result_small.shape == () else result_small + val_large = float(result_large) if hasattr(result_large, 'shape') and result_large.shape == () else result_large + + self.assertTrue(isinstance(val_small, float)) + self.assertTrue(isinstance(val_large, float)) + + # Tests KID metric's ability to differentiate between similar and dissimilar distributions + def test_kernel_inception_distance_different_distributions(self): + """Test KID metric's ability to differentiate between similar and dissimilar distributions.""" + key1, key2 = random.split(random.PRNGKey(50)) + real_features = random.normal(key1, shape=(100, 2048)) + mean = 0.5 + std = 2.0 + fake_features = mean + std * random.normal(key2, shape=(100, 2048)) + + kid = metrax.KID.from_model_output( + real_features, fake_features, subset_size=50 + ) + result = kid.compute() + val = float(result) if hasattr(result, 'shape') and result.shape == () else result + self.assertTrue(val > 0.0) + key3 = random.PRNGKey(51) + another_real_features = random.normal(key3, shape=(100, 2048)) + + kid_same_dist = metrax.KID.from_model_output( + real_features, another_real_features, subset_size=50 + ) + result_same_dist = kid_same_dist.compute() + val_same = float(result_same_dist) if hasattr(result_same_dist, 'shape') and result_same_dist.shape == () else result_same_dist + self.assertTrue(val > val_same) + + def test_dice_empty(self): + """Tests the `empty` method of the `Dice` class.""" + m = metrax.Dice.empty() + self.assertEqual(m.intersection, jnp.array(0, jnp.float32)) + self.assertEqual(m.sum_true, jnp.array(0, jnp.float32)) + self.assertEqual(m.sum_pred, jnp.array(0, jnp.float32)) + + @parameterized.named_parameters( ( 'ssim_basic_norm_single_channel', PREDS_1, @@ -274,194 +518,194 @@ def test_dice_empty(self): DEFAULT_K2, ), ) - def test_ssim_against_tensorflow( - self, - predictions: np.ndarray, - targets: np.ndarray, - max_val: float, - filter_size: int, - filter_sigma: float, - k1: float, - k2: float, - ): - """Test that metrax.SSIM computes values close to tf.image.ssim.""" - # Calculate SSIM using Metrax - predictions_jax = jnp.array(predictions) - targets_jax = jnp.array(targets) - metrax_metric = metrax.SSIM.from_model_output( - predictions=predictions_jax, - targets=targets_jax, - max_val=max_val, - filter_size=filter_size, - filter_sigma=filter_sigma, - k1=k1, - k2=k2, - ) - metrax_result = metrax_metric.compute() - - # Calculate SSIM using TensorFlow - predictions_tf = tf.convert_to_tensor(predictions, dtype=tf.float32) - targets_tf = tf.convert_to_tensor(targets, dtype=tf.float32) - tf_ssim_per_image = tf.image.ssim( - img1=predictions_tf, - img2=targets_tf, - max_val=max_val, - filter_size=filter_size, - filter_sigma=filter_sigma, - k1=k1, - k2=k2, - ) - tf_result_mean = tf.reduce_mean(tf_ssim_per_image).numpy() - - np.testing.assert_allclose( - metrax_result, - tf_result_mean, - rtol=1e-5, - atol=1e-5, - err_msg=( - f'SSIM mismatch for params: max_val={max_val}, ' - f'filter_size={filter_size}, filter_sigma={filter_sigma}, ' - f'k1={k1}, k2={k2}' + def test_ssim_against_tensorflow( + self, + predictions: np.ndarray, + targets: np.ndarray, + max_val: float, + filter_size: int, + filter_sigma: float, + k1: float, + k2: float, + ): + """Test that 
metrax.SSIM computes values close to tf.image.ssim.""" + # Calculate SSIM using Metrax + predictions_jax = jnp.array(predictions) + targets_jax = jnp.array(targets) + metrax_metric = metrax.SSIM.from_model_output( + predictions=predictions_jax, + targets=targets_jax, + max_val=max_val, + filter_size=filter_size, + filter_sigma=filter_sigma, + k1=k1, + k2=k2, + ) + metrax_result = metrax_metric.compute() + + # Calculate SSIM using TensorFlow + predictions_tf = tf.convert_to_tensor(predictions, dtype=tf.float32) + targets_tf = tf.convert_to_tensor(targets, dtype=tf.float32) + tf_ssim_per_image = tf.image.ssim( + img1=predictions_tf, + img2=targets_tf, + max_val=max_val, + filter_size=filter_size, + filter_sigma=filter_sigma, + k1=k1, + k2=k2, + ) + tf_result_mean = tf.reduce_mean(tf_ssim_per_image).numpy() + + np.testing.assert_allclose( + metrax_result, + tf_result_mean, + rtol=1e-5, + atol=1e-5, + err_msg=( + f'SSIM mismatch for params: max_val={max_val}, ' + f'filter_size={filter_size}, filter_sigma={filter_sigma}, ' + f'k1={k1}, k2={k2}' + ), + ) + # Only expect 1.0 for identical images + if np.array_equal(predictions, targets): + self.assertAlmostEqual(float(metrax_result), 1.0, delta=1e-6) + self.assertAlmostEqual(float(tf_result_mean), 1.0, delta=1e-6) + + @parameterized.named_parameters( + ( + 'iou_binary_target_foreground', + TARGETS_IOU_1, + PREDS_IOU_1, + NUM_CLASSES_IOU_1, + TARGET_CLASS_IDS_IOU_1, + False, ), - ) - # For identical images, we expect a value very close to 1.0 - if np.array_equal(predictions, targets): - self.assertAlmostEqual(float(metrax_result), 1.0, delta=1e-6) - self.assertAlmostEqual(float(tf_result_mean), 1.0, delta=1e-6) - - @parameterized.named_parameters( - ( - 'iou_binary_target_foreground', - TARGETS_IOU_1, - PREDS_IOU_1, - NUM_CLASSES_IOU_1, - TARGET_CLASS_IDS_IOU_1, - False, - ), - ( - 'iou_multiclass_target_subset', - TARGETS_IOU_2, - PREDS_IOU_2, - NUM_CLASSES_IOU_2, - TARGET_CLASS_IDS_IOU_2, - False, - ), - ( - 'iou_multiclass_target_single_from_set2', - TARGETS_IOU_2, - PREDS_IOU_2, - NUM_CLASSES_IOU_2, - [1], - False, - ), - ( - 'iou_perfect_overlap_binary', - TARGETS_IOU_3, - PREDS_IOU_3, - NUM_CLASSES_IOU_3, - TARGET_CLASS_IDS_IOU_3, - False, - ), - ( - 'iou_no_overlap_target_class', - TARGETS_IOU_4, - PREDS_IOU_4, - NUM_CLASSES_IOU_4, - TARGET_CLASS_IDS_IOU_4, - False, - ), - ( - 'iou_from_logits_binary', - TARGETS_IOU_5, - PREDS_IOU_5_LOGITS, - NUM_CLASSES_IOU_5, - TARGET_CLASS_IDS_IOU_5, - True, - ), - ( - 'iou_target_all_metrax_none_keras_list', - TARGETS_IOU_6, - PREDS_IOU_6, - NUM_CLASSES_IOU_6, - TARGET_CLASS_IDS_IOU_6, - False, - ), - ) - def test_iou_against_keras( - self, - targets: np.ndarray, - predictions: np.ndarray, - num_classes: int, - target_class_ids: np.ndarray, - from_logits: bool, - ): - """Tests metrax.IoU against keras.metrics.IoU.""" - # Metrax IoU - metrax_metric = metrax.IoU.from_model_output( - predictions=jnp.array(predictions), - targets=jnp.array(targets), - num_classes=num_classes, - target_class_ids=jnp.array(target_class_ids), - from_logits=from_logits, - ) - metrax_result = metrax_metric.compute() - - # Keras IoU - keras_iou_metric = keras.metrics.IoU( - num_classes=num_classes, - target_class_ids=target_class_ids, - name='keras_iou', - sparse_y_pred=not from_logits, - ) - keras_iou_metric.update_state(targets, predictions) - keras_result = keras_iou_metric.result() - - np.testing.assert_allclose( - metrax_result, - keras_result, - rtol=1e-5, - atol=1e-5, - err_msg=( - f'IoU mismatch for 
num_classes={num_classes},' - f' target_class_ids={target_class_ids} (TF was' - f' {target_class_ids}),' - f' from_logits={from_logits}.\nMetrax: {metrax_result}, Keras:' - f' {keras_result}' + ( + 'iou_multiclass_target_subset', + TARGETS_IOU_2, + PREDS_IOU_2, + NUM_CLASSES_IOU_2, + TARGET_CLASS_IDS_IOU_2, + False, + ), + ( + 'iou_multiclass_target_single_from_set2', + TARGETS_IOU_2, + PREDS_IOU_2, + NUM_CLASSES_IOU_2, + [1], + False, + ), + ( + 'iou_perfect_overlap_binary', + TARGETS_IOU_3, + PREDS_IOU_3, + NUM_CLASSES_IOU_3, + TARGET_CLASS_IDS_IOU_3, + False, + ), + ( + 'iou_no_overlap_target_class', + TARGETS_IOU_4, + PREDS_IOU_4, + NUM_CLASSES_IOU_4, + TARGET_CLASS_IDS_IOU_4, + False, + ), + ( + 'iou_from_logits_binary', + TARGETS_IOU_5, + PREDS_IOU_5_LOGITS, + NUM_CLASSES_IOU_5, + TARGET_CLASS_IDS_IOU_5, + True, + ), + ( + 'iou_target_all_metrax_none_keras_list', + TARGETS_IOU_6, + PREDS_IOU_6, + NUM_CLASSES_IOU_6, + TARGET_CLASS_IDS_IOU_6, + False, ), ) - - # Specific assertions for clearer test outcomes - if 'perfect_overlap' in self.id(): - self.assertAlmostEqual( - float(metrax_result), - 1.0, - delta=1e-6, - msg=f'Metrax IoU failed for {self.id()}', - ) - if not np.isnan(keras_result): - self.assertAlmostEqual( - float(keras_result), - 1.0, - delta=1e-6, - msg=f'Keras IoU failed for {self.id()}', + def test_iou_against_keras( + self, + targets: np.ndarray, + predictions: np.ndarray, + num_classes: int, + target_class_ids: np.ndarray, + from_logits: bool, + ): + """Tests metrax.IoU against keras.metrics.IoU.""" + # Metrax IoU + metrax_metric = metrax.IoU.from_model_output( + predictions=jnp.array(predictions), + targets=jnp.array(targets), + num_classes=num_classes, + target_class_ids=jnp.array(target_class_ids), + from_logits=from_logits, ) - - if 'no_overlap' in self.id(): - self.assertAlmostEqual( - float(metrax_result), - 0.0, - delta=1e-6, - msg=f'Metrax IoU failed for {self.id()}', - ) - if not np.isnan(keras_result): - self.assertAlmostEqual( - float(keras_result), - 0.0, - delta=1e-6, - msg=f'Keras IoU failed for {self.id()}', + metrax_result = metrax_metric.compute() + + # Keras IoU + keras_iou_metric = keras.metrics.IoU( + num_classes=num_classes, + target_class_ids=target_class_ids, + name='keras_iou', + sparse_y_pred=not from_logits, + ) + keras_iou_metric.update_state(targets, predictions) + keras_result = keras_iou_metric.result() + + np.testing.assert_allclose( + metrax_result, + keras_result, + rtol=1e-5, + atol=1e-5, + err_msg=( + f'IoU mismatch for num_classes={num_classes},' + f' target_class_ids={target_class_ids} (TF was' + f' {target_class_ids}),' + f' from_logits={from_logits}.\nMetrax: {metrax_result}, Keras:' + f' {keras_result}' + ), ) - @parameterized.named_parameters( + # Specific assertions for clearer test outcomes + if 'perfect_overlap' in self.id(): + self.assertAlmostEqual( + float(metrax_result), + 1.0, + delta=1e-6, + msg=f'Metrax IoU failed for {self.id()}', + ) + if not np.isnan(keras_result): + self.assertAlmostEqual( + float(keras_result), + 1.0, + delta=1e-6, + msg=f'Keras IoU failed for {self.id()}', + ) + + if 'no_overlap' in self.id(): + self.assertAlmostEqual( + float(metrax_result), + 0.0, + delta=1e-6, + msg=f'Metrax IoU failed for {self.id()}', + ) + if not np.isnan(keras_result): + self.assertAlmostEqual( + float(keras_result), + 0.0, + delta=1e-6, + msg=f'Keras IoU failed for {self.id()}', + ) + + @parameterized.named_parameters( ( 'psnr_basic_norm_single_channel', PREDS_1, @@ -493,66 +737,66 @@ def test_iou_against_keras( 
MAX_VAL_6,
       ),
   )
-  def test_psnr_against_tensorflow(
-      self,
-      predictions_np: np.ndarray,
-      targets_np: np.ndarray,
-      max_val: float,
-  ) -> None:
-    """Test that metrax.SSIM computes values close to tf.image.ssim."""
-    # Calculate PSNR using Metrax
-    metrax_psnr = metrax.PSNR.from_model_output(
-        predictions=jnp.array(predictions_np),
-        targets=jnp.array(targets_np),
-        max_val=max_val,
-    ).compute()
-
-    # Calculate PSNR using TensorFlow
-    tf_psnr = tf.image.psnr(
-        predictions_np.astype(np.float32),
-        targets_np.astype(np.float32),
-        max_val=max_val,
+    def test_psnr_against_tensorflow(
+        self,
+        predictions_np: np.ndarray,
+        targets_np: np.ndarray,
+        max_val: float,
+    ) -> None:
+        """Test that metrax.PSNR computes values close to tf.image.psnr."""
+        # Calculate PSNR using Metrax
+        metrax_psnr = metrax.PSNR.from_model_output(
+            predictions=jnp.array(predictions_np),
+            targets=jnp.array(targets_np),
+            max_val=max_val,
+        ).compute()
+
+        # Calculate PSNR using TensorFlow
+        tf_psnr = tf.image.psnr(
+            predictions_np.astype(np.float32),
+            targets_np.astype(np.float32),
+            max_val=max_val,
+        )
+        tf_mean = tf.reduce_mean(tf_psnr).numpy()
+
+        if np.isinf(tf_mean):
+            self.assertTrue(np.isinf(metrax_psnr))
+        else:
+            np.testing.assert_allclose(
+                metrax_psnr,
+                tf_mean,
+                rtol=1e-4,
+                atol=1e-4,
+                err_msg='PSNR mismatch',
+            )
+
+    @parameterized.named_parameters(
+        ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
+        ('low_threshold', OUTPUT_LABELS, OUTPUT_PREDS_F32),
+        ('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS_F32),
+        ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1),
+        ('all_ones', *DICE_ALL_ONES),
+        ('all_zeros', *DICE_ALL_ZEROS),
+        ('no_overlap', *DICE_NO_OVERLAP),
     )
-    tf_mean = tf.reduce_mean(tf_psnr).numpy()
-
-    if np.isinf(tf_mean):
-      self.assertTrue(np.isinf(metrax_psnr))
-    else:
-      np.testing.assert_allclose(
-          metrax_psnr,
-          tf_mean,
-          rtol=1e-4,
-          atol=1e-4,
-          err_msg='PSNR mismatch',
-      )
-
-  @parameterized.named_parameters(
-      ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
-      ('low_threshold', OUTPUT_LABELS, OUTPUT_PREDS_F32),
-      ('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS_F32),
-      ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1),
-      ('all_ones', *DICE_ALL_ONES),
-      ('all_zeros', *DICE_ALL_ZEROS),
-      ('no_overlap', *DICE_NO_OVERLAP),
-  )
-  def test_dice(self, y_true, y_pred):
-    """Test that Dice metric computes expected values."""
-    y_true = jnp.asarray(y_true, jnp.float32)
-    y_pred = jnp.asarray(y_pred, jnp.float32)
+    def test_dice(self, y_true, y_pred):
+        """Test that Dice metric computes expected values."""
+        y_true = jnp.asarray(y_true, jnp.float32)
+        y_pred = jnp.asarray(y_pred, jnp.float32)
 
-    # Manually compute expected Dice
-    eps = 1e-7
-    intersection = jnp.sum(y_true * y_pred)
-    sum_pred = jnp.sum(y_pred)
-    sum_true = jnp.sum(y_true)
-    expected = (2.0 * intersection) / (sum_pred + sum_true + eps)
+        # Manually compute expected Dice
+        eps = 1e-7
+        intersection = jnp.sum(y_true * y_pred)
+        sum_pred = jnp.sum(y_pred)
+        sum_true = jnp.sum(y_true)
+        expected = (2.0 * intersection) / (sum_pred + sum_true + eps)
 
-    # Compute using the metric class
-    metric = metrax.Dice.from_model_output(predictions=y_pred, labels=y_true)
-    result = metric.compute()
+        # Compute using the metric class
+        metric = metrax.Dice.from_model_output(predictions=y_pred, labels=y_true)
+        result = metric.compute()
 
-    np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)
+        np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)
 
 
 if __name__ == '__main__':
-  absltest.main()
+    absltest.main()
\ No newline at end of file
diff --git a/src/metrax/metrax_test.py b/src/metrax/metrax_test.py
index fe94764..1b6dfc3 100644
--- a/src/metrax/metrax_test.py
+++ b/src/metrax/metrax_test.py
@@ -201,6 +201,19 @@ class MetraxTest(parameterized.TestCase):
           'ks': KS,
       },
   ),
+  (
+      'kid',
+      metrax.KID,
+      {
+          'real_features': np.random.uniform(size=(BATCHES * BATCH_SIZE, 2048)),
+          'fake_features': np.random.uniform(size=(BATCHES * BATCH_SIZE, 2048)),
+          'subsets': 10,
+          'subset_size': 8,
+          'degree': 3,
+          'gamma': 0.3,
+          'coef': 1.0,
+      },
+  ),
   (
       'ssim',
       metrax.SSIM,
diff --git a/src/metrax/nnx/__init__.py b/src/metrax/nnx/__init__.py
index 42edff7..93041f2 100644
--- a/src/metrax/nnx/__init__.py
+++ b/src/metrax/nnx/__init__.py
@@ -24,6 +24,7 @@
 Dice = nnx_metrics.Dice
 FBetaScore = nnx_metrics.FBetaScore
 IoU = nnx_metrics.IoU
+KID = nnx_metrics.KID
 MAE = nnx_metrics.MAE
 MRR = nnx_metrics.MRR
 MSE = nnx_metrics.MSE
@@ -53,6 +54,7 @@
     "Dice",
     "FBetaScore",
     "IoU",
+    "KID",
     "MRR",
     "MAE",
     "MSE",
diff --git a/src/metrax/nnx/nnx_metrics.py b/src/metrax/nnx/nnx_metrics.py
index 3f5d371..03773a3 100644
--- a/src/metrax/nnx/nnx_metrics.py
+++ b/src/metrax/nnx/nnx_metrics.py
@@ -86,6 +86,12 @@ def __init__(self):
     super().__init__(metrax.IoU)
 
 
+class KID(NnxWrapper):
+  """An NNX class for the Metrax metric KID (Kernel Inception Distance)."""
+
+  def __init__(self):
+    super().__init__(metrax.KID)
+
 class MAE(NnxWrapper):
   """An NNX class for the Metrax metric MAE."""
 
@@ -203,3 +209,4 @@ class WER(NnxWrapper):
 
   def __init__(self):
     super().__init__(metrax.WER)
+
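
A quick usage sketch for reviewers (not part of the diff): exercising the new metric on pre-extracted features via `metrax.KID.from_features`, the same path the new unit tests take. The shapes, subset counts, and PRNG keys below are illustrative only.

```python
from jax import random

import metrax

# Pre-extracted feature batches for real and generated images, shape (N, D).
key_real, key_fake = random.split(random.PRNGKey(0))
real_features = random.normal(key_real, (256, 2048))
fake_features = random.normal(key_fake, (256, 2048))

# KID estimated over 10 random subsets of 64 samples each.
kid = metrax.KID.from_features(
    real_features,
    fake_features,
    subsets=10,
    subset_size=64,
)
print(kid.compute())           # mean KID over the 10 subsets
print(kid.compute_mean_std())  # (mean, std), torchmetrics-style
```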
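The image path accepts NHWC batches via `from_model_output`. Keep in mind that `_extract_image_features` above is an explicit placeholder (global average pooling plus squared features), not an InceptionV3 embedding, so absolute scores are not comparable to published KID numbers. A minimal sketch with made-up shapes:

```python
import numpy as np
import jax.numpy as jnp

import metrax

# Uint8 RGB batches in NHWC layout, as from_model_output expects.
rng = np.random.RandomState(0)
real = jnp.asarray(rng.randint(0, 256, size=(32, 64, 64, 3), dtype=np.uint8))
fake = jnp.asarray(rng.randint(0, 256, size=(32, 64, 64, 3), dtype=np.uint8))

kid = metrax.KID.from_model_output(
    real_image=real,
    fake_image=fake,
    subsets=3,
    subset_size=16,  # must not exceed the batch size of either set
)
kid_mean, kid_std = kid.compute_mean_std()
print(kid_mean, kid_std)
```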
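For reference, the estimator implemented by `_compute_mmd_static` is the polynomial-kernel MMD² with unbiased within-set terms (diagonals dropped) and a biased cross term over all m×m pairs. Below is a self-contained sketch of the same arithmetic, written without the sqrt-normalization and clipping guards the diff adds to `_polynomial_kernel`, so its values will differ from the metric's; it is meant only to make the docstring math concrete.

```python
import jax
import jax.numpy as jnp


def mmd2_polynomial(f_real, f_fake, degree=3, gamma=None, coef=1.0):
  """MMD^2 with kernel k(x, y) = (gamma * <x, y> + coef) ** degree."""
  if gamma is None:
    gamma = 1.0 / f_real.shape[1]
  k = lambda a, b: (gamma * a @ b.T + coef) ** degree

  m = f_real.shape[0]
  k_rr = k(f_real, f_real)  # real vs. real
  k_ff = k(f_fake, f_fake)  # fake vs. fake
  k_rf = k(f_real, f_fake)  # real vs. fake

  # Within-set terms drop the diagonal (unbiased); the cross term keeps
  # all m*m entries, matching _compute_mmd_static above.
  term_rr = (k_rr.sum() - jnp.trace(k_rr)) / (m * (m - 1))
  term_ff = (k_ff.sum() - jnp.trace(k_ff)) / (m * (m - 1))
  return term_rr + term_ff - 2.0 * k_rf.sum() / (m * m)


key_r, key_f = jax.random.split(jax.random.PRNGKey(0))
print(mmd2_polynomial(jax.random.normal(key_r, (64, 128)),
                      jax.random.normal(key_f, (64, 128))))
```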