diff --git a/src/metrax/__init__.py b/src/metrax/__init__.py
index e43e827..e1a2fe3 100644
--- a/src/metrax/__init__.py
+++ b/src/metrax/__init__.py
@@ -45,6 +45,7 @@
 RougeN = nlp_metrics.RougeN
 SNR = audio_metrics.SNR
 SSIM = image_metrics.SSIM
+TotalVariation = image_metrics.TotalVariation
 WER = nlp_metrics.WER
 
 
diff --git a/src/metrax/image_metrics.py b/src/metrax/image_metrics.py
index 96bb244..527bb2d 100644
--- a/src/metrax/image_metrics.py
+++ b/src/metrax/image_metrics.py
@@ -666,3 +666,76 @@ def compute(self) -> jax.Array:
     """Returns the final Dice coefficient."""
     epsilon = 1e-7
     return (2.0 * self.intersection) / (self.sum_pred + self.sum_true + epsilon)
+
+
+@flax.struct.dataclass
+class TotalVariation(base.Average):
+  r"""Calculates and returns the Total Variation (TV) for one or more images.
+
+  The total variation is the sum of the absolute differences for neighboring
+  pixel-values in the input images. This measures how much noise is in the
+  images.
+
+  This implements the anisotropic 2-D version of the formula described here:
+
+  https://en.wikipedia.org/wiki/Total_variation_denoising
+  """
+
+  @staticmethod
+  def _calculate_total_variation(
+      images: jax.Array,
+  ) -> jax.Array:
+    """Computes Total Variation values.
+
+    Args:
+      images: 4-D Array of shape ``(batch, height, width, channels)`` or
+        3-D Array of shape ``(height, width, channels)``.
+
+    Returns:
+      Total variation of 'images'.
+
+      If `images` was 4-D, return a 1-D float Array of shape `[batch]` with the
+      total variation for each image in the batch.
+      If `images` was 3-D, return a scalar float with the total variation for
+      that image.
+
+    Raises:
+      ValueError: If `images` is neither 3-D nor 4-D.
+    """
+    ndims = images.ndim
+    if ndims == 3:  # (height, width, channels)
+      # Shift images by one pixel along the height and width.
+      pixel_dif1 = jnp.abs(images[1:, :, :] - images[:-1, :, :])
+      pixel_dif2 = jnp.abs(images[:, 1:, :] - images[:, :-1, :])
+      sum_axis = None
+    elif ndims == 4:  # (batch, height, width, channels)
+      # Shift images by one pixel along the height and width.
+      pixel_dif1 = jnp.abs(images[:, 1:, :, :] - images[:, :-1, :, :])
+      pixel_dif2 = jnp.abs(images[:, :, 1:, :] - images[:, :, :-1, :])
+      # Sum over every axis except batch so each image keeps its own value.
+      sum_axis = (1, 2, 3)
+    else:
+      raise ValueError(
+          'Input images must be either 3 or 4-dimensional, got'
+          f' {ndims} dimensions instead.'
+      )
+
+    return jnp.sum(pixel_dif1, axis=sum_axis) + jnp.sum(pixel_dif2, axis=sum_axis)
+
+  @classmethod
+  def from_model_output(
+      cls,
+      predictions: jax.Array,
+  ) -> 'TotalVariation':
+    """Computes the Total Variation of the given image(s).
+
+    Args:
+      predictions: A JAX array of predicted images, with shape
+        ``(batch, H, W, C)``, or a single image of shape ``(H, W, C)``.
+
+    Returns:
+      A ``TotalVariation`` instance accumulating the per-image total variation
+      values.
+    """
+    total_variation = cls._calculate_total_variation(predictions)
+    return super().from_model_output(values=total_variation)
diff --git a/src/metrax/image_metrics_test.py b/src/metrax/image_metrics_test.py
index 1013990..f1d2db5 100644
--- a/src/metrax/image_metrics_test.py
+++ b/src/metrax/image_metrics_test.py
@@ -212,6 +212,27 @@
 DICE_ALL_ZEROS = (np.array([0, 0, 0, 0]), np.array([0, 0, 0, 0]))
 DICE_NO_OVERLAP = (np.array([1, 1, 0, 0]), np.array([0, 0, 1, 1]))
 
+# Test data for TotalVariation
+# Case 1: Basic, float normalized [0,1], single channel (3D)
+TV_IMG_SHAPE_1 = (16, 16, 1)  # height, width, channels
+TV_IMG_1 = np.random.rand(*TV_IMG_SHAPE_1).astype(np.float32)
+
+# Case 2: Multi-channel (3), float normalized [0,1] (3D)
+TV_IMG_SHAPE_2 = (32, 32, 3)
+TV_IMG_2 = np.random.rand(*TV_IMG_SHAPE_2).astype(np.float32)
+
+# Case 3: Batch of single channel images (4D)
+TV_IMG_SHAPE_3 = (4, 16, 16, 1)  # batch, height, width, channels
+TV_IMG_3 = np.random.rand(*TV_IMG_SHAPE_3).astype(np.float32)
+
+# Case 4: Batch of multi-channel images (4D)
+TV_IMG_SHAPE_4 = (4, 32, 32, 3)  # batch, height, width, channels
+TV_IMG_4 = np.random.rand(*TV_IMG_SHAPE_4).astype(np.float32)
+
+# Case 5: Constant image (should have zero variation)
+TV_IMG_SHAPE_5 = (16, 16, 1)
+TV_IMG_5 = np.ones(TV_IMG_SHAPE_5, dtype=np.float32)
+
 
 class ImageMetricsTest(parameterized.TestCase):
 
@@ -553,6 +574,78 @@ def test_dice(self, y_true, y_pred):
 
     np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)
 
+  @parameterized.named_parameters(
+      (
+          'tv_single_channel_3d',
+          TV_IMG_1,
+      ),
+      (
+          'tv_multichannel_3d',
+          TV_IMG_2,
+      ),
+      (
+          'tv_batch_single_channel_4d',
+          TV_IMG_3,
+      ),
+      (
+          'tv_batch_multichannel_4d',
+          TV_IMG_4,
+      ),
+      (
+          'tv_constant_image',
+          TV_IMG_5,
+      ),
+  )
+  def test_total_variation_against_tensorflow(
+      self,
+      images_np: np.ndarray,
+  ) -> None:
+    """Tests that TotalVariation matches tf.image.total_variation."""
+    # Convert to a uniform [B, H, W, C] layout; otherwise iterating over a 3-D
+    # input below would yield 2-D slices instead of whole images.
+    if images_np.ndim == 3:
+      images_np = np.expand_dims(images_np, axis=0)
+
+    # Calculate TV using Metrax, one image per update (3-D code path).
+    metric = None
+    for image in images_np:
+      update = metrax.TotalVariation.from_model_output(
+          predictions=jnp.array(image)
+      )
+      metric = update if metric is None else metric.merge(update)
+    metrax_tv = metric.compute()
+
+    # Calculate TV using Metrax on the whole batch at once (4-D code path).
+    metrax_batched_tv = metrax.TotalVariation.from_model_output(
+        predictions=jnp.array(images_np)
+    ).compute()
+
+    # Calculate TV using TensorFlow.
+    tf_tv = tf.image.total_variation(tf.convert_to_tensor(images_np))
+    tf_mean = tf.reduce_mean(tf_tv).numpy()
+
+    # A constant image has no neighboring-pixel differences, so TV must be 0.
+    # Constancy is checked on the pixel values: `images_np` now has a batch
+    # axis, so np.array_equal against the 3-D TV_IMG_5 would never be True.
+    if np.ptp(images_np) == 0:
+      np.testing.assert_allclose(metrax_tv, 0.0, rtol=1e-5, atol=1e-5)
+      np.testing.assert_allclose(tf_mean, 0.0, rtol=1e-5, atol=1e-5)
+
+    np.testing.assert_allclose(
+        metrax_tv,
+        tf_mean,
+        rtol=1e-5,
+        atol=1e-5,
+        err_msg='Total Variation mismatch',
+    )
+    np.testing.assert_allclose(
+        metrax_batched_tv,
+        tf_mean,
+        rtol=1e-5,
+        atol=1e-5,
+        err_msg='Batched Total Variation mismatch',
+    )
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/src/metrax/metrax_test.py b/src/metrax/metrax_test.py
index 0328a46..7882b90 100644
--- a/src/metrax/metrax_test.py
+++ b/src/metrax/metrax_test.py
@@ -41,7 +41,7 @@
     'the quick brown fox jumps over the lazy dog',
     'hello beautiful world',
 ]
-# For image_metrics.SSIM and image_metrics.PSNR.
+# For image_metrics.SSIM, image_metrics.PSNR and image_metrics.TotalVariation.
 IMG_SHAPE = (4, 32, 32, 3)
 PRED_IMGS = np.random.rand(*IMG_SHAPE).astype(np.float32)
 TARGET_IMGS = np.random.rand(*IMG_SHAPE).astype(np.float32)
@@ -214,6 +214,13 @@ class MetraxTest(parameterized.TestCase):
               'zero_mean': False,
           },
       ),
+      (
+          'total_variation',
+          metrax.TotalVariation,
+          {
+              'predictions': PRED_IMGS,
+          },
+      ),
   )
   def test_metrics_jittable(self, metric, kwargs):
     """Tests that jitted metrax metric yields the same result as non-jitted metric."""
diff --git a/src/metrax/nnx/__init__.py b/src/metrax/nnx/__init__.py
index 495a9ce..6a5fd11 100644
--- a/src/metrax/nnx/__init__.py
+++ b/src/metrax/nnx/__init__.py
@@ -39,6 +39,7 @@
 RougeN = nnx_metrics.RougeN
 SNR = nnx_metrics.SNR
 SSIM = nnx_metrics.SSIM
+TotalVariation = nnx_metrics.TotalVariation
 WER = nnx_metrics.WER
 
 
diff --git a/src/metrax/nnx/nnx_metrics.py b/src/metrax/nnx/nnx_metrics.py
index 842794a..115c6a6 100644
--- a/src/metrax/nnx/nnx_metrics.py
+++ b/src/metrax/nnx/nnx_metrics.py
@@ -193,6 +193,13 @@ def __init__(self):
     super().__init__(metrax.SSIM)
 
 
+class TotalVariation(NnxWrapper):
+  """An NNX class for the Metrax metric TotalVariation."""
+
+  def __init__(self):
+    super().__init__(metrax.TotalVariation)
+
+
 class WER(NnxWrapper):
   """An NNX class for the Metrax metric WER."""
 