- 
                Notifications
    You must be signed in to change notification settings 
- Fork 10
feat: implementing Kernel Inception Metric #65 : #81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
feat: implementing Kernel Inception Metric #65 : #81
Conversation
- writing initial version of the kernel incpetion disrance metric - writing the test cases to cover edge case for classification.
| Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. | 
| Hi Dhruv, thank you so much for your contribution to the Metrax repo! The first iteration of your PR looks very promising! :) A few things: 
 Thank you so much! :) | 
- adding the import modules and name change - NNX class wrapper description - implementing merge and empty method as required by the clu_metrics.Metrics method - and also updating the tests in the metrax_test along with the unittests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much Dhruv!
        
          
                src/metrax/image_metrics.py
              
                Outdated
          
        
      |  | ||
|  | ||
| @flax.struct.dataclass | ||
| class KernelInceptionDistanceMetric(clu_metrics.Metric): | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what are your thoughts on having this inherit base.Average from the metrax package? It will keep track of the sum of total KID and the sum of total count and compute will yield the average KID of the data.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indeed I have done that and during my prompting , I realised to implement it in the following function:
    @classmethod
    def from_model_output(
        cls,
        real_features: jax.Array,
        fake_features: jax.Array,
        subsets: int = KID_DEFAULT_SUBSETS,
        subset_size: int = KID_DEFAULT_SUBSET_SIZE,
        degree: int = KID_DEFAULT_DEGREE,
        gamma: float = KID_DEFAULT_GAMMA,
        coef: float = KID_DEFAULT_COEF,
    ):
        # checks for the valid inputs
        if subsets <= 0 or subset_size <= 0 or degree <= 0 or (gamma is not None and gamma <= 0) or coef <= 0:
            raise ValueError("All parameters must be positive and non-zero.")
        # Compute KID for this batch
        if real_features.shape[0] < subset_size or fake_features.shape[0] < subset_size:
            raise ValueError("Subset size must be smaller than the number of samples.")
        master_key = random.PRNGKey(42)
        kid_scores = []
        for i in range(subsets):
            key_real, key_fake = random.split(random.fold_in(master_key, i))
            real_indices = random.choice(key_real, real_features.shape[0], (subset_size,), replace=False)
            fake_indices = random.choice(key_fake, fake_features.shape[0], (subset_size,), replace=False)
            f_real_subset = real_features[real_indices]
            f_fake_subset = fake_features[fake_indices]
            kid = cls.__compute_mmd_static(f_real_subset, f_fake_subset, degree, gamma, coef)
            kid_scores.append(kid)
        kid_mean = jnp.mean(jnp.array(kid_scores))
        # Accumulate sum and count for averaging
        return cls(
            total=kid_mean,
            count=1.0,
            subsets=subsets,
            subset_size=subset_size,
            degree=degree,
            gamma=gamma,
            coef=coef,
        )- `image_metrics.py`/ `metrax_test.py`: changing name along with defining the private functions and other details. - test cases added to compare with the torchmetrics implementation. - installing the requisite packages to support torchmetrics comparison
| 
 your welcome and appreciated your help for resolving issues. | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you :)
        
          
                src/metrax/__init__.py
              
                Outdated
          
        
      | from metrax import ranking_metrics | ||
| from metrax import regression_metrics | ||
|  | ||
| from metrax import image_metrics | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's remove this line since image_metrics was added as part of previous PR in line 17.
        
          
                src/metrax/image_metrics.py
              
                Outdated
          
        
      | return result | ||
|  | ||
|  | ||
| def merge(self, other: "KernelInceptionDistance") -> "KernelInceptionDistance": | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
now that we are inheriting base.Average, we shouldn't need to implement compute and merge.
        
          
                src/metrax/image_metrics.py
              
                Outdated
          
        
      | ) | ||
|  | ||
| @classmethod | ||
| def empty(cls) -> "KernelInceptionDistance": | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
now that we are inheriting base.Average, we shouldn't need to implement empty
…ocal: - again renamed to KID - restructuring the test util function to image_metrics.py .
- minor function name changes . - removing the redundant methods implemented for clu_metrics - also merging the @jshin1394 changes and then merging our works.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much :)
        
          
                src/metrax/__init__.py
              
                Outdated
          
        
      | SSIM = image_metrics.SSIM | ||
| WER = nlp_metrics.WER | ||
|  | ||
| KID = image_metrics.KID | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: let's keep alphabetical order.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much :)
        
          
                src/metrax/image_metrics.py
              
                Outdated
          
        
      | cls, | ||
| real_features: jax.Array, | ||
| fake_features: jax.Array, | ||
| subsets: int = KID_DEFAULT_SUBSETS, | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we not define member variables and just replace with default values here?
subsets: int = 1000,
subset_size: int = 100,
degree: int = 3,
gamma: float = 1.0,
coef: float = 1.0,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. It was just in case if some model_eval.py evaluation script might import this parameter .
- Resolve KID metric position in nnx - Removing the constants - Adding doc-string for model_output with description.
| hi @jshin1394 , apologies of not notifying on this branch, let me know if are there any remaining issues to be resolved . | 
| 
 Hi Dhruv, are we assuming that users will call metrax.KID with extracted features only? It seems like torchmetrics.KID takes in as input the real image and the generate image but metrax.KID takes in real_features and generated_features. | 
| 
 Yes you're right . initially I have taken the I{Real} and I{Fake} as in the reference implementation here does indeed take the generated features . I have done the changes accordingly . | 
| 
 I think the original implementation takes in images as input and calls this model(https://github.com/Lightning-AI/torchmetrics/blob/74bdb26b35515c36334f59d654f0f4d8a12d5b59/src/torchmetrics/image/kid.py#L211) to get the features. If the implementation does indeed require pretrained model, I will go ahead and come up with a design for this. | 
Author:
Dhruv Malik (@dhruvmalik007 )
Intro
This pull request introduces a new implementation of the Kernel Inception Distance(KID) metric for assessing the quality of generated images and includes comprehensive unit tests to validate its functionality. The most important changes are the addition of the
KernelInceptionMetricclass, which computes the KID using a polynomial kernel, and the creation of a robust test suite to ensure the metric's correctness, stability, and sensitivity to various scenarios.Changes :
image_metrics.py: Added theKernelInceptionDistanceMetricclass, which computes the Kernel Inception Distance (KID) using a polynomial kernel. This includes methods for feature accumulation, MMD computation, and KID calculation with configurable parameters like subsets, subset size, kernel degree, gamma, and coefficient.nnx_metrics.py: defining the wrapper method for the support to the NNX library to import the method directlymetrax_test.py: defined the random test case.__init__.py: and adding the import alias for the KID metric.Unit Tests:
src/metrax/image_metrics_test.py: Added a comprehensive test suite forKernelInceptionMetricwith the following test cases: