1 change: 1 addition & 0 deletions src/metrax/__init__.py
@@ -45,6 +45,7 @@
RougeN = nlp_metrics.RougeN
SNR = audio_metrics.SNR
SSIM = image_metrics.SSIM
TotalVariation = image_metrics.TotalVariation
WER = nlp_metrics.WER


66 changes: 66 additions & 0 deletions src/metrax/image_metrics.py
@@ -666,3 +666,69 @@ def compute(self) -> jax.Array:
"""Returns the final Dice coefficient."""
epsilon = 1e-7
return (2.0 * self.intersection) / (self.sum_pred + self.sum_true + epsilon)

@flax.struct.dataclass
class TotalVariation(base.Average):
r"""Calculates and returns the Total Variation (TV) for one or more images.

The total variation is the sum of the absolute differences for neighboring
pixel-values in the input images. This measures how much noise is in the
images.

This implements the anisotropic 2-D version of the formula described here:

https://en.wikipedia.org/wiki/Total_variation_denoising
"""

@staticmethod
def _calculate_total_variation(
images: jax.Array,
) -> jax.Array:
"""Computes Total Variation values.

Args:
images: 4-D Array of shape ``(batch, height, width, channels)`` or
3-D Array of shape ``(height, width, channels)``.

    Returns:
      If ``images`` is 4-D, a 1-D float array of shape ``(batch,)`` containing
      the total variation of each image in the batch. If ``images`` is 3-D, a
      scalar float with the total variation of that image.
"""
ndims = images.ndim
if ndims == 3: # (height, width, channels)
# Shift images by one pixel along the height and width.
pixel_dif1 = jnp.abs(images[1:, :, :] - images[:-1, :, :])
pixel_dif2 = jnp.abs(images[:, 1:, :] - images[:, :-1, :])
sum_axis = None
elif ndims == 4: # (batch, height, width, channels)
# Shift images by one pixel along the height and width.
pixel_dif1 = jnp.abs(images[:, 1:, :, :] - images[:, :-1, :, :])
pixel_dif2 = jnp.abs(images[:, :, 1:, :] - images[:, :, :-1, :])
sum_axis = [1, 2, 3]
else:
raise ValueError(
          f'Input images must be 3- or 4-dimensional, got {ndims} dimensions.'
)

    return jnp.sum(pixel_dif1, axis=sum_axis) + jnp.sum(
        pixel_dif2, axis=sum_axis
    )

  @classmethod
def from_model_output(
cls,
predictions: jax.Array
) -> 'TotalVariation':
"""Computes the Total Variation for a batch of images and creates a TotalVariation metric instance.

Args:
      predictions: A JAX array of images, of shape ``(batch, height, width,
        channels)`` or ``(height, width, channels)``.

Returns:
      A ``TotalVariation`` metric averaging the per-image total variation.
"""
total_variation = cls._calculate_total_variation(predictions)
return super().from_model_output(values=total_variation)
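For reference, a minimal usage sketch of the new metric (not part of the diff). It follows the call pattern exercised in the tests below; the array shapes, values, and the hand-written comparison are illustrative only.

```python
import jax.numpy as jnp
import numpy as np

import metrax

# A batch of four 16x16 single-channel images in [0, 1].
images = jnp.array(np.random.rand(4, 16, 16, 1).astype(np.float32))

# TotalVariation extends base.Average, so compute() returns the mean
# per-image total variation over everything passed to from_model_output.
metric = metrax.TotalVariation.from_model_output(predictions=images)
print(metric.compute())

# The same anisotropic quantity for a single image, written out by hand:
img = images[0]
tv_manual = (
    jnp.abs(img[1:, :, :] - img[:-1, :, :]).sum()
    + jnp.abs(img[:, 1:, :] - img[:, :-1, :]).sum()
)
```

For a 3-D input the metric reduces to that single image's total variation; for a 4-D input it averages over the batch.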
75 changes: 75 additions & 0 deletions src/metrax/image_metrics_test.py
@@ -212,6 +212,27 @@
DICE_ALL_ZEROS = (np.array([0, 0, 0, 0]), np.array([0, 0, 0, 0]))
DICE_NO_OVERLAP = (np.array([1, 1, 0, 0]), np.array([0, 0, 1, 1]))

# Test data for TotalVariation
# Case 1: Basic, float normalized [0,1], single channel (3D)
TV_IMG_SHAPE_1 = (16, 16, 1) # height, width, channels
TV_IMG_1 = np.random.rand(*TV_IMG_SHAPE_1).astype(np.float32)

# Case 2: Multi-channel (3), float normalized [0,1] (3D)
TV_IMG_SHAPE_2 = (32, 32, 3)
TV_IMG_2 = np.random.rand(*TV_IMG_SHAPE_2).astype(np.float32)

# Case 3: Batch of single channel images (4D)
TV_IMG_SHAPE_3 = (4, 16, 16, 1) # batch, height, width, channels
TV_IMG_3 = np.random.rand(*TV_IMG_SHAPE_3).astype(np.float32)

# Case 4: Batch of multi-channel images (4D)
TV_IMG_SHAPE_4 = (4, 32, 32, 3) # batch, height, width, channels
TV_IMG_4 = np.random.rand(*TV_IMG_SHAPE_4).astype(np.float32)

# Case 5: Constant image (should have zero variation)
TV_IMG_SHAPE_5 = (16, 16, 1)
TV_IMG_5 = np.ones(TV_IMG_SHAPE_5, dtype=np.float32)


class ImageMetricsTest(parameterized.TestCase):

@@ -553,6 +574,60 @@ def test_dice(self, y_true, y_pred):

np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)

  @parameterized.named_parameters(
      ('tv_single_channel_3d', TV_IMG_1),
      ('tv_multichannel_3d', TV_IMG_2),
      ('tv_batch_single_channel_4d', TV_IMG_3),
      ('tv_batch_multichannel_4d', TV_IMG_4),
      ('tv_constant_image', TV_IMG_5),
  )
  def test_total_variation_against_tensorflow(self, images_np: np.ndarray):
    """Tests that TotalVariation matches tf.image.total_variation."""
# Calculate TV using Metrax
metrax_tv = metrax.TotalVariation.from_model_output(
predictions=jnp.array(images_np)
).compute()

# Calculate TV using TensorFlow
tf_tv = tf.image.total_variation(tf.convert_to_tensor(images_np))
tf_mean = tf.reduce_mean(tf_tv).numpy()

# For constant image, TV should be 0
if np.array_equal(images_np, TV_IMG_5):
np.testing.assert_allclose(metrax_tv, 0.0, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(tf_mean, 0.0, rtol=1e-5, atol=1e-5)
else:
np.testing.assert_allclose(
metrax_tv,
tf_mean,
rtol=1e-5,
atol=1e-5,
err_msg='Total Variation mismatch',
)

if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions src/metrax/nnx/__init__.py
@@ -39,6 +39,7 @@
RougeN = nnx_metrics.RougeN
SNR = nnx_metrics.SNR
SSIM = nnx_metrics.SSIM
TotalVariation = nnx_metrics.TotalVariation
WER = nnx_metrics.WER


7 changes: 7 additions & 0 deletions src/metrax/nnx/nnx_metrics.py
@@ -193,6 +193,13 @@ def __init__(self):
super().__init__(metrax.SSIM)


class TotalVariation(NnxWrapper):
"""An NNX class for the Metrax metric TotalVariation."""

def __init__(self):
super().__init__(metrax.TotalVariation)


class WER(NnxWrapper):
"""An NNX class for the Metrax metric WER."""

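Finally, a hedged sketch of driving the NNX wrapper. It assumes `NnxWrapper` follows the usual `flax.nnx` metrics interface, with `update(...)` forwarding its keyword arguments to `metrax.TotalVariation.from_model_output(...)` and `compute()` returning the accumulated average; check `NnxWrapper` for the exact signatures.

```python
# Hedged sketch: assumes the NnxWrapper update()/compute() interface used by
# the other metrax.nnx metrics; verify against nnx_metrics.NnxWrapper.
import jax.numpy as jnp
import numpy as np

import metrax.nnx as metrax_nnx

metric = metrax_nnx.TotalVariation()

# Accumulate over two batches; `predictions` is the keyword expected by
# metrax.TotalVariation.from_model_output.
for _ in range(2):
  batch = jnp.array(np.random.rand(4, 16, 16, 1).astype(np.float32))
  metric.update(predictions=batch)

print(metric.compute())  # running mean total variation per image
```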