
Conversation


@dhruvmalik007 dhruvmalik007 commented May 6, 2025

Author:

Dhruv Malik (@dhruvmalik007)

Intro

This pull request introduces a new implementation of the Kernel Inception Distance (KID) metric for assessing the quality of generated images, along with comprehensive unit tests to validate its functionality. The most important changes are the addition of the KernelInceptionDistanceMetric class, which computes KID using a polynomial kernel, and the creation of a robust test suite to verify the metric's correctness, stability, and sensitivity across a variety of scenarios.
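For background, KID is the squared maximum mean discrepancy (MMD) between the real and generated feature sets under a polynomial kernel. A minimal NumPy sketch of the core estimator (function names here are illustrative, not the PR's API; the PR itself works in JAX):

```python
import numpy as np

def poly_kernel(x, y, degree=3, gamma=None, coef=1.0):
    """Polynomial kernel matrix: k(x_i, y_j) = (gamma * <x_i, y_j> + coef) ** degree."""
    if gamma is None:
        gamma = 1.0 / x.shape[1]  # common convention: 1 / feature dimension
    return (gamma * x @ y.T + coef) ** degree

def mmd2_unbiased(f_real, f_fake, degree=3, gamma=None, coef=1.0):
    """Unbiased squared-MMD estimate between two feature sets.

    The within-set terms exclude the kernel diagonal (self-similarities),
    which is what makes the estimator unbiased.
    """
    m, n = f_real.shape[0], f_fake.shape[0]
    k_rr = poly_kernel(f_real, f_real, degree, gamma, coef)
    k_ff = poly_kernel(f_fake, f_fake, degree, gamma, coef)
    k_rf = poly_kernel(f_real, f_fake, degree, gamma, coef)
    term_rr = (k_rr.sum() - np.trace(k_rr)) / (m * (m - 1))
    term_ff = (k_ff.sum() - np.trace(k_ff)) / (n * (n - 1))
    return term_rr + term_ff - 2.0 * k_rf.mean()
```

KID then averages this statistic over several randomly drawn subsets of the features, which is what the configurable subsets and subset_size parameters control.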

Changes:

  • image_metrics.py: Added the KernelInceptionDistanceMetric class, which computes the Kernel Inception Distance (KID) using a polynomial kernel. This includes methods for feature accumulation, MMD computation, and KID calculation with configurable parameters like subsets, subset size, kernel degree, gamma, and coefficient.

  • nnx_metrics.py: Defined a wrapper class so that NNX users can import the metric directly.

  • metrax_test.py: Added a randomized test case for the metric.

  • __init__.py: Added the import alias for the KID metric.

Unit Tests:

  • src/metrax/image_metrics_test.py: Added a comprehensive test suite for KernelInceptionMetric with the following test cases:
    • Validation of default parameter functionality.
    • Handling of invalid parameters and edge cases.
    • Behavior with small sample sizes and identical feature sets.
    • Sensitivity to mode collapse, outliers, and different subset configurations.
    • Differentiation between similar and dissimilar feature distributions.

- Wrote the initial version of the Kernel Inception Distance metric.

- Wrote the test cases to cover edge cases for classification.

google-cla bot commented May 6, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@jshin1394 (Collaborator)

Hi Dhruv, thank you so much for your contribution to the Metrax repo! The first iteration of your PR looks very promising! :)

A few things:

  1. Could you please sign the Google Contributor License Agreement (CLA) as suggested by the google-cla bot?
  2. Here is an example PR which adds a new metric to Metrax (https://github.com/google/metrax/pull/79/files). It also adds the metric to metrax_test in order to test its jittability, and adds the metric to metrax/nnx in order to support NNX users.
  3. The clu_metrics.Metric interface has a list of functions that must be implemented, and merge is one of them (https://github.com/google/metrax/pull/79/files).

Thank you so much! :)

@dhruvmalik007 changed the title from "feat: implementing #65" to "feat: implementing Kernel Inception Metric #65" on May 7, 2025
- Added the import modules and the name change
- Added the NNX class wrapper description
- Implemented the merge and empty methods required by the clu_metrics.Metric interface
- Updated the tests in metrax_test along with the unit tests
@jshin1394 (Collaborator) left a comment


Thank you so much Dhruv!



@flax.struct.dataclass
class KernelInceptionDistanceMetric(clu_metrics.Metric):
@jshin1394 (Collaborator) commented:

what are your thoughts on having this inherit base.Average from the metrax package? It will keep track of the sum of total KID and the sum of total count and compute will yield the average KID of the data.
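For readers unfamiliar with the suggestion, a toy stand-in for such an Average base class might look like the following (AverageSketch is a hypothetical name; the real metrax base.Average is a flax/clu-style metric, not a plain dataclass):

```python
import dataclasses

@dataclasses.dataclass
class AverageSketch:
    """Toy stand-in for an Average metric: accumulates a running sum and count."""
    total: float = 0.0
    count: float = 0.0

    @classmethod
    def empty(cls) -> "AverageSketch":
        # The identity element for merging: nothing accumulated yet.
        return cls(total=0.0, count=0.0)

    def merge(self, other: "AverageSketch") -> "AverageSketch":
        # Combining two accumulators just adds their sums and counts.
        return AverageSketch(self.total + other.total, self.count + other.count)

    def compute(self) -> float:
        # The final metric is the mean over everything accumulated so far.
        return self.total / self.count
```

Under this pattern, from_model_output would return an accumulator with total set to the batch's KID and count set to 1, and compute would yield the average KID across all merged batches.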

@dhruvmalik007 (Author) replied:

Indeed, I had done that, and while iterating I realized it should be implemented in the following function:

    @classmethod
    def from_model_output(
        cls,
        real_features: jax.Array,
        fake_features: jax.Array,
        subsets: int = KID_DEFAULT_SUBSETS,
        subset_size: int = KID_DEFAULT_SUBSET_SIZE,
        degree: int = KID_DEFAULT_DEGREE,
        gamma: float = KID_DEFAULT_GAMMA,
        coef: float = KID_DEFAULT_COEF,
    ):
        # checks for the valid inputs
        if subsets <= 0 or subset_size <= 0 or degree <= 0 or (gamma is not None and gamma <= 0) or coef <= 0:
            raise ValueError("All parameters must be positive and non-zero.")
        # Compute KID for this batch
        if real_features.shape[0] < subset_size or fake_features.shape[0] < subset_size:
            raise ValueError("Subset size must not exceed the number of samples.")
        master_key = random.PRNGKey(42)
        kid_scores = []
        for i in range(subsets):
            key_real, key_fake = random.split(random.fold_in(master_key, i))
            real_indices = random.choice(key_real, real_features.shape[0], (subset_size,), replace=False)
            fake_indices = random.choice(key_fake, fake_features.shape[0], (subset_size,), replace=False)
            f_real_subset = real_features[real_indices]
            f_fake_subset = fake_features[fake_indices]
            kid = cls.__compute_mmd_static(f_real_subset, f_fake_subset, degree, gamma, coef)
            kid_scores.append(kid)
        kid_mean = jnp.mean(jnp.array(kid_scores))
        # Accumulate sum and count for averaging
        return cls(
            total=kid_mean,
            count=1.0,
            subsets=subsets,
            subset_size=subset_size,
            degree=degree,
            gamma=gamma,
            coef=coef,
        )
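A side note on the loop above: since the example PR linked earlier also tests jittability, the per-subset Python loop could be batched. A NumPy sketch of one way to draw all subset indices at once (batched_subset_indices is a hypothetical name; a JAX version would use jax.vmap with per-subset jax.random.permutation to stay jittable):

```python
import numpy as np

def batched_subset_indices(rng, n_samples, subsets, subset_size):
    """Draw all subset index sets in one shot.

    Sorting a row of random keys yields a random permutation of that row's
    column indices, so keeping the first `subset_size` columns gives sampling
    without replacement within each subset. Returns a (subsets, subset_size)
    integer array.
    """
    keys = rng.random((subsets, n_samples))
    return np.argsort(keys, axis=1)[:, :subset_size]
```

The MMD computation can then be applied to each row of indices in a single vectorized pass instead of a Python loop.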

- `image_metrics.py` / `metrax_test.py`: renamed the class and defined the private functions and other details.
- Added test cases to compare against the torchmetrics implementation.

- Installed the requisite packages to support the torchmetrics comparison.
@dhruvmalik007 (Author)

Thank you so much Dhruv!

You're welcome, and I appreciate your help in resolving the issues.

@jshin1394 (Collaborator) left a comment


Thank you :)

from metrax import ranking_metrics
from metrax import regression_metrics

from metrax import image_metrics
@jshin1394 (Collaborator) commented:
let's remove this line since image_metrics was added as part of previous PR in line 17.

return result


def merge(self, other: "KernelInceptionDistance") -> "KernelInceptionDistance":
@jshin1394 (Collaborator) commented:
now that we are inheriting base.Average, we shouldn't need to implement compute and merge.

)

@classmethod
def empty(cls) -> "KernelInceptionDistance":
@jshin1394 (Collaborator) commented:
now that we are inheriting base.Average, we shouldn't need to implement empty.

…ocal:

- Renamed again to KID
- Restructured the test util function into image_metrics.py
- Minor function name changes
- Removed the redundant methods implemented for clu_metrics
- Merged @jshin1394's changes and then merged our work
@jshin1394 (Collaborator) left a comment


Thank you so much :)

SSIM = image_metrics.SSIM
WER = nlp_metrics.WER

KID = image_metrics.KID
@jshin1394 (Collaborator) commented:
nit: let's keep alphabetical order.

@jshin1394 (Collaborator) left a comment


Thank you so much :)

cls,
real_features: jax.Array,
fake_features: jax.Array,
subsets: int = KID_DEFAULT_SUBSETS,
@jshin1394 (Collaborator) commented:
can we not define member variables and just replace with default values here?
subsets: int = 1000,
subset_size: int = 100,
degree: int = 3,
gamma: float = 1.0,
coef: float = 1.0,

@dhruvmalik007 (Author) replied:
Sounds good. I had defined them just in case some evaluation script such as model_eval.py might import these parameters.

@dhruvmalik007 (Author)

hi @jshin1394, apologies for not notifying you on this branch. Let me know if there are any remaining issues to be resolved.

@jshin1394 (Collaborator)


Hi Dhruv, are we assuming that users will call metrax.KID with extracted features only? It seems like torchmetrics.KID takes the real image and the generated image as input, but metrax.KID takes real_features and generated_features.


dhruvmalik007 commented Jun 4, 2025


Yes, you're right. Initially I had taken I{Real} and I{Fake} as inputs, but the reference implementation here does indeed take the generated features, so I have made the changes accordingly.

@jshin1394 (Collaborator)


I think the original implementation takes images as input and calls this model (https://github.com/Lightning-AI/torchmetrics/blob/74bdb26b35515c36334f59d654f0f4d8a12d5b59/src/torchmetrics/image/kid.py#L211) to get the features. If the implementation does indeed require a pretrained model, I will go ahead and come up with a design for this.
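One possible shape for that design, sketched with entirely hypothetical names (dummy_extractor merely stands in for a real pretrained backbone such as InceptionV3; compute_kid is any feature-based KID function):

```python
import numpy as np

def dummy_extractor(images, seed=0):
    """Placeholder for a pretrained backbone such as InceptionV3.

    Flattens each image and applies a fixed random projection to
    2048-dimensional features, just to make the data flow concrete.
    """
    rng = np.random.default_rng(seed)
    flat = images.reshape(images.shape[0], -1)
    proj = rng.normal(size=(flat.shape[1], 2048))
    return flat @ proj

def kid_from_images(real_images, fake_images, compute_kid, extractor=dummy_extractor):
    """Hypothetical image-level wrapper: extract features for both image sets,
    then delegate to a feature-based KID function, mirroring the torchmetrics
    design where the metric owns the feature extractor."""
    return compute_kid(extractor(real_images), extractor(fake_images))
```

This split keeps the feature-based metric unchanged while an image-level wrapper owns the model dependency, so users who already have features can skip the extractor entirely.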
