-
Notifications
You must be signed in to change notification settings - Fork 69
Feat/simple mse metric #388 #425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
AnikethBhosale
wants to merge
3
commits into
PrunaAI:main
Choose a base branch
from
AnikethBhosale:feat/simple-MSE-Metric-#388
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
2072d10
Added image in contributing.md for "Pruna AI's Working Logic For Eas…
AnikethBhosale ca411a0
feat: implement Mean Squared Error metric for evaluation
AnikethBhosale b44939c
feat: implement Mean Squared Error (MSE) metric with comprehensive te…
AnikethBhosale File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,200 @@ | ||
| # MSE Metric Implementation Summary | ||
|
|
||
| ## ✅ Completion Status | ||
|
|
||
| All tasks have been successfully completed for the MSE (Mean Squared Error) metric implementation in Pruna AI. | ||
|
|
||
| --- | ||
|
|
||
| ## 📋 What Was Done | ||
|
|
||
| ### 1. ✅ Metric Implementation (`src/pruna/evaluation/metrics/metric_mse.py`) | ||
|
|
||
| **Features:** | ||
| - Inherits from `StatefulMetric` for proper batch accumulation | ||
| - Implements `update()` method with proper signature: `(x, gt, outputs)` | ||
| - Uses `metric_data_processor()` for consistent input handling | ||
| - Accumulates squared errors in a list of tensors (proper state management) | ||
| - Implements `compute()` method returning `MetricResult` | ||
| - Handles edge cases: None inputs, shape mismatches, empty state | ||
| - Device-aware: automatically moves tensors to correct device | ||
| - Properly documented with NumPy-style docstrings | ||
|
|
||
| **Key Implementation Details:** | ||
| ```python | ||
| @MetricRegistry.register(METRIC_MSE) | ||
| class MSEMetric(StatefulMetric): | ||
| default_call_type: str = "gt_y" | ||
| higher_is_better: bool = False | ||
| metric_name: str = METRIC_MSE | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ### 2. ✅ Registry Integration | ||
|
|
||
| **File:** `src/pruna/evaluation/metrics/__init__.py` | ||
|
|
||
| Added: | ||
| - Import: `from pruna.evaluation.metrics.metric_mse import MSEMetric` | ||
| - Export in `__all__`: `"MSEMetric"` | ||
|
|
||
| The metric is now discoverable by Pruna's evaluation framework and can be used with: | ||
| ```python | ||
| task = Task(metrics=["mse"]) | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ### 3. ✅ Comprehensive Tests (`tests/evaluation/test_mse.py`) | ||
|
|
||
| **15 Tests Created:** | ||
| 1. `test_mse_perfect_match` - Zero MSE for identical tensors | ||
| 2. `test_mse_known_value` - Verify calculation correctness | ||
| 3. `test_mse_multiple_batches` - Batch accumulation | ||
| 4. `test_mse_empty_state` - Returns NaN when no data | ||
| 5. `test_mse_multidimensional` - 2D tensor support | ||
| 6. `test_mse_3d_tensors` - Image-like data support | ||
| 7. `test_mse_reset` - State reset functionality | ||
| 8. `test_mse_mixed_values` - Positive and negative errors | ||
| 9. `test_mse_large_errors` - Large magnitude handling | ||
| 10. `test_mse_none_handling` - Graceful None handling | ||
| 11. `test_mse_cuda` - GPU support (skipped if CUDA unavailable) | ||
| 12. `test_mse_device_mismatch` - Device compatibility | ||
| 13. `test_mse_single_value` - Scalar input | ||
| 14. `test_mse_fractional_values` - Floating point precision | ||
| 15. `test_mse_batch_independence` - Consistent results regardless of batching | ||
|
|
||
| **Test Results:** | ||
| ``` | ||
| 14 passed, 1 skipped, 5 warnings | ||
| Coverage: 89% for metric_mse.py | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ### 4. ✅ Style Compliance | ||
|
|
||
| **Checks Passed:** | ||
| - ✅ `ty check` - No type errors | ||
| - ✅ `ruff check` - All linting checks passed | ||
| - ✅ Follows project code style (NumPy docstrings, PascalCase class names) | ||
| - ✅ Proper type hints throughout | ||
| - ✅ No unused imports | ||
|
|
||
| --- | ||
|
|
||
| ### 5. ✅ Documentation (`docs/user_manual/metrics/mse.md`) | ||
|
|
||
| **Documentation Includes:** | ||
| - **Overview** - What MSE is and when to use it | ||
| - **Mathematical Formula** - LaTeX equation | ||
| - **Properties** - Metric attributes and behavior | ||
| - **Usage Examples**: | ||
| - Basic standalone usage | ||
| - Integration with Pruna's evaluation framework | ||
| - Model comparison example | ||
| - **Use Cases** - When MSE is appropriate (✅) and considerations (⚠️) | ||
| - **Example Results** - Concrete examples with expected outputs | ||
| - **Technical Details** - State accumulation, device handling, shape flexibility | ||
| - **Related Metrics** - RMSE, MAE, PSNR, SSIM | ||
| - **References** - External resources and contribution guide | ||
|
|
||
| --- | ||
|
|
||
| ## 🎯 Acceptance Criteria Met | ||
|
|
||
| ### ✅ Style Guidelines | ||
| - Follows Pruna's coding conventions | ||
| - NumPy-style docstrings | ||
| - Proper type hints | ||
| - Clean, maintainable code | ||
|
|
||
| ### ✅ Documentation | ||
| - Comprehensive user documentation | ||
| - Code comments explaining key decisions | ||
| - Usage examples provided | ||
|
|
||
| ### ✅ Tests | ||
| - 15 tests covering various scenarios | ||
| - All tests pass successfully | ||
| - Edge cases handled | ||
|
|
||
| ### ✅ Integration | ||
| - Registered with `MetricRegistry` | ||
| - Exported from `__init__.py` | ||
| - Works with Pruna's evaluation framework | ||
| - Compatible with `Task` and `smash()` functions | ||
|
|
||
| --- | ||
|
|
||
| ## 📁 Files Created/Modified | ||
|
|
||
| ### Created: | ||
| 1. `src/pruna/evaluation/metrics/metric_mse.py` (122 lines) | ||
| 2. `tests/evaluation/test_mse.py` (247 lines) | ||
| 3. `docs/user_manual/metrics/mse.md` (full documentation) | ||
|
|
||
| ### Modified: | ||
| 1. `src/pruna/evaluation/metrics/__init__.py` (added MSEMetric import and export) | ||
|
|
||
| --- | ||
|
|
||
| ## 🚀 How to Use | ||
|
|
||
| ### Basic Usage: | ||
| ```python | ||
| from pruna.evaluation.metrics.metric_mse import MSEMetric | ||
|
|
||
| metric = MSEMetric() | ||
| # During evaluation loop: | ||
| metric.update(x, ground_truth, predictions) | ||
| result = metric.compute() | ||
| print(f"MSE: {result.result}") | ||
| ``` | ||
|
|
||
| ### With Pruna Framework: | ||
| ```python | ||
| from pruna import smash | ||
| from pruna.evaluation.task import Task | ||
|
|
||
| task = Task(metrics=["mse"]) | ||
| smashed_model = smash(model=your_model, eval_task=task) | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ## 🧪 Test Command | ||
|
|
||
| ```bash | ||
| pytest tests/evaluation/test_mse.py -v | ||
| ``` | ||
|
|
||
| **Expected Output:** | ||
| ``` | ||
| 14 passed, 1 skipped in 1.00s | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ## 📊 Code Quality Metrics | ||
|
|
||
| - **Test Coverage:** 89% (src/pruna/evaluation/metrics/metric_mse.py) | ||
| - **Lines of Code:** 122 (implementation) + 247 (tests) | ||
| - **Tests:** 15 comprehensive tests | ||
| - **Documentation:** Complete user guide with examples | ||
|
|
||
| --- | ||
|
|
||
| ## ✨ Summary | ||
|
|
||
| The MSE metric implementation is **production-ready** and follows all Pruna AI guidelines: | ||
|
|
||
| ✅ Correct implementation using `StatefulMetric` | ||
| ✅ Properly registered and exported | ||
| ✅ Comprehensive tests (all passing) | ||
| ✅ Style-compliant code | ||
| ✅ Full documentation with examples | ||
| ✅ Ready for integration into Pruna's evaluation framework | ||
|
|
||
| The metric can now be used by simply including `"mse"` in the metrics list when creating an evaluation task. | ||
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,177 @@ | ||
| # Mean Squared Error (MSE) Metric | ||
|
|
||
| ## Overview | ||
|
|
||
| The MSE (Mean Squared Error) metric computes the mean squared error between model predictions and ground truth values. It's a fundamental metric for evaluating regression models and image quality assessment. | ||
|
|
||
| ## Formula | ||
|
|
||
| $$ | ||
| MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 | ||
| $$ | ||
|
|
||
| Where: | ||
| - $y_i$ is the ground truth value | ||
| - $\hat{y}_i$ is the predicted value | ||
| - $n$ is the total number of samples | ||
|
|
||
| ## Properties | ||
|
|
||
| - **Metric Name**: `"mse"` | ||
| - **Higher is Better**: `False` (lower MSE indicates better performance) | ||
| - **Type**: `StatefulMetric` (accumulates across batches) | ||
| - **Default Call Type**: `"gt_y"` (compares ground truth vs outputs) | ||
|
|
||
| ## Usage | ||
|
|
||
| ### Basic Usage | ||
|
|
||
| ```python | ||
| from pruna.evaluation.metrics.metric_mse import MSEMetric | ||
|
|
||
| # Initialize the metric | ||
| mse_metric = MSEMetric() | ||
|
|
||
| # Update with batches (automatic during evaluation) | ||
| for x, gt in dataloader: | ||
| outputs = model(x) | ||
| mse_metric.update(x, gt, outputs) | ||
|
|
||
| # Compute final MSE | ||
| result = mse_metric.compute() | ||
| print(f"MSE: {result.result:.6f}") | ||
| ``` | ||
|
|
||
| ### With Pruna's Evaluation Framework | ||
|
|
||
| ```python | ||
| from pruna import smash | ||
| from pruna.evaluation.task import Task | ||
|
|
||
| # Create evaluation task | ||
| task = Task( | ||
| metrics=["mse"], # Simply include "mse" in your metrics list | ||
| # ... other task parameters | ||
| ) | ||
|
|
||
| # Smash and evaluate | ||
| smashed_model = smash( | ||
| model=your_model, | ||
| eval_task=task, | ||
| # ... other smash config | ||
| ) | ||
| ``` | ||
|
|
||
| ### Integration Example | ||
|
|
||
| ```python | ||
| import torch | ||
| from pruna.evaluation.metrics.metric_mse import MSEMetric | ||
|
|
||
| # Example: Compare two models | ||
| def evaluate_model(model, test_data): | ||
| metric = MSEMetric() | ||
|
|
||
| for batch in test_data: | ||
| x, ground_truth = batch | ||
| predictions = model(x) | ||
| metric.update(x, ground_truth, predictions) | ||
|
|
||
| result = metric.compute() | ||
| return result.result | ||
|
|
||
| # Usage | ||
| model_mse = evaluate_model(my_model, test_loader) | ||
| print(f"Model MSE: {model_mse:.4f}") | ||
| ``` | ||
|
|
||
| ## When to Use MSE | ||
|
|
||
| ### ✅ Good Use Cases | ||
|
|
||
| - **Regression Tasks**: Comparing continuous predictions to ground truth | ||
| - **Image Quality**: Measuring pixel-wise differences between images | ||
| - **Model Comparison**: Evaluating compression or quantization impact | ||
| - **Signal Processing**: Comparing reconstructed vs original signals | ||
|
|
||
| ### ⚠️ Considerations | ||
|
|
||
| - MSE is sensitive to outliers (large errors are squared) | ||
| - Assumes errors are normally distributed | ||
| - Same MSE can represent different error distributions | ||
| - Consider using RMSE (√MSE) for interpretability in original units | ||
|
|
||
| ## Example Results | ||
|
|
||
| ### Perfect Match | ||
| ```python | ||
| gt = torch.tensor([1.0, 2.0, 3.0, 4.0]) | ||
| outputs = torch.tensor([1.0, 2.0, 3.0, 4.0]) | ||
| # MSE = 0.0 | ||
| ``` | ||
|
|
||
| ### Constant Error | ||
| ```python | ||
| gt = torch.tensor([1.0, 2.0, 3.0, 4.0]) | ||
| outputs = torch.tensor([2.0, 3.0, 4.0, 5.0]) # +1 error each | ||
| # MSE = 1.0 | ||
| ``` | ||
|
|
||
| ### With Images | ||
| ```python | ||
| # Comparing two 64x64 RGB images | ||
| gt_image = torch.randn(1, 3, 64, 64) | ||
| pred_image = gt_image + torch.randn_like(gt_image) * 0.1 | ||
| # MSE ≈ 0.01 (depends on noise) | ||
| ``` | ||
|
|
||
| ## Technical Details | ||
|
|
||
| ### State Accumulation | ||
|
|
||
| The MSE metric accumulates squared errors across all batches in a list of tensors: | ||
|
|
||
| ```python | ||
| self.add_state("squared_errors", []) # List of tensors | ||
| ``` | ||
|
|
||
| ### Computation | ||
|
|
||
| The final MSE is computed by: | ||
| 1. Concatenating all squared error tensors | ||
| 2. Computing the mean across all elements | ||
|
|
||
| ```python | ||
| all_squared_errors = torch.cat(self.squared_errors) | ||
| mse_value = float(all_squared_errors.mean().item()) | ||
| ``` | ||
|
|
||
| ### Device Handling | ||
|
|
||
| The metric automatically handles device placement: | ||
| - Moves outputs to match ground truth device | ||
| - Works seamlessly with CPU and CUDA tensors | ||
|
|
||
| ### Shape Flexibility | ||
|
|
||
| The metric flattens tensors for computation, supporting: | ||
| - 1D tensors (scalars) | ||
| - 2D tensors (batches) | ||
| - 3D tensors (sequences) | ||
| - 4D tensors (images: batch × channels × height × width) | ||
|
|
||
| ## Related Metrics | ||
|
|
||
| - **RMSE** (Root Mean Squared Error): `√MSE` - same scale as original data | ||
| - **MAE** (Mean Absolute Error): Less sensitive to outliers | ||
| - **PSNR** (Peak Signal-to-Noise Ratio): `10 * log10(MAX²/MSE)` for images | ||
| - **SSIM** (Structural Similarity): Perceptual image quality metric | ||
|
|
||
| ## References | ||
|
|
||
| - [Mean Squared Error - Wikipedia](https://en.wikipedia.org/wiki/Mean_squared_error) | ||
| - Pruna AI Documentation: [Customize Metrics](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/customize_metric.html) | ||
|
|
||
| ## Contributing | ||
|
|
||
| Found a bug or want to improve the MSE metric? See our [Contributing Guide](../../CONTRIBUTING.md). |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you a lot for this detailed summary, are we planning on merging it to Pruna or is it more for giving information? I think this would be even more beneficial as the PR description