From 7c820a091da4dfabd505e20c6ac07fc7d67edf89 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Mar 2025 16:05:55 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/multimodal/multimodal.py | 2 +- pvnet/models/multimodal/multimodal_base.py | 4 ++-- pvnet/models/multimodal/unimodal_teacher.py | 2 +- tests/models/multimodal/test_multimodal.py | 16 +++++++--------- tests/models/multimodal/test_unimodal_teacher.py | 15 ++++++--------- 5 files changed, 17 insertions(+), 22 deletions(-) diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index aea95f0c..1a7668f5 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -356,7 +356,7 @@ def forward(self, x): # Fall back to legacy keys azimuth_key = f"{self._target_key}_solar_azimuth" elevation_key = f"{self._target_key}_solar_elevation" - + # Process the sun data if either key set is found if azimuth_key in x and elevation_key in x: sun = torch.cat( diff --git a/pvnet/models/multimodal/multimodal_base.py b/pvnet/models/multimodal/multimodal_base.py index 22b9a2a5..24941a9f 100644 --- a/pvnet/models/multimodal/multimodal_base.py +++ b/pvnet/models/multimodal/multimodal_base.py @@ -42,14 +42,14 @@ def _adapt_batch(self, batch): if self.include_sun: sun_len = self.forecast_len + self.history_len + 1 - + solar_position_keys = [] # Slife off end of solar coords for s in ["solar_azimuth", "solar_elevation"]: if s in batch.keys(): solar_position_keys.append(s) batch[s] = batch[s][:, :sun_len] - + # Check for legacy keys if new keys aren't found if not solar_position_keys: # Slice off the end of the legacy solar coords diff --git a/pvnet/models/multimodal/unimodal_teacher.py b/pvnet/models/multimodal/unimodal_teacher.py index 1141ff5b..ff697b9f 100644 --- a/pvnet/models/multimodal/unimodal_teacher.py +++ b/pvnet/models/multimodal/unimodal_teacher.py @@ -314,7 +314,7 @@ def forward(self, x, return_modes=False): # Fall back to legacy keys azimuth_key = f"{self._target_key}_solar_azimuth" elevation_key = f"{self._target_key}_solar_elevation" - + # Process the sun data if either key set is found if azimuth_key in x and elevation_key in x: sun = torch.cat( diff --git a/tests/models/multimodal/test_multimodal.py b/tests/models/multimodal/test_multimodal.py index 2a41c333..62892e4f 100644 --- a/tests/models/multimodal/test_multimodal.py +++ b/tests/models/multimodal/test_multimodal.py @@ -54,27 +54,25 @@ def test_weighted_quantile_model_forward(multimodal_quantile_model_ignore_minute ["gsp_solar_azimuth", "gsp_solar_elevation"], ], ) - - def test_model_with_solar_position_keys(multimodal_model, sample_batch, keys): """Test that the model works with both new and legacy solar position keys.""" - azimuth_key, elevation_key = keys + azimuth_key, elevation_key = keys batch_copy = sample_batch.copy() - + # Clear all solar keys and add just the ones we're testing - for key in ["solar_azimuth", "solar_elevation", - "gsp_solar_azimuth", "gsp_solar_elevation"]: + for key in ["solar_azimuth", "solar_elevation", "gsp_solar_azimuth", "gsp_solar_elevation"]: if key in batch_copy: del batch_copy[key] - + # Create solar position data if needed import torch + batch_size = sample_batch["gsp"].shape[0] seq_len = multimodal_model.forecast_len + multimodal_model.history_len + 1 batch_copy[azimuth_key] = torch.rand((batch_size, seq_len)) batch_copy[elevation_key] = torch.rand((batch_size, seq_len)) - + # Test forward and backward passes - y = multimodal_model(batch_copy) + y = multimodal_model(batch_copy) assert tuple(y.shape) == (2, 16), y.shape y.sum().backward() diff --git a/tests/models/multimodal/test_unimodal_teacher.py b/tests/models/multimodal/test_unimodal_teacher.py index a7530de6..689f3dfd 100644 --- a/tests/models/multimodal/test_unimodal_teacher.py +++ b/tests/models/multimodal/test_unimodal_teacher.py @@ -113,26 +113,23 @@ def test_model_conversion(unimodal_model_kwargs, sample_batch): ["gsp_solar_azimuth", "gsp_solar_elevation"], ], ) - - def test_unimodal_model_with_solar_position_keys(unimodal_teacher_model, sample_batch, keys): """Test that the unimodal teacher model works with both new and legacy solar position keys.""" - azimuth_key, elevation_key = keys + azimuth_key, elevation_key = keys batch_copy = sample_batch.copy() - + # Clear all solar keys and add just the ones we're testing - for key in ["solar_azimuth", "solar_elevation", - "gsp_solar_azimuth", "gsp_solar_elevation"]: + for key in ["solar_azimuth", "solar_elevation", "gsp_solar_azimuth", "gsp_solar_elevation"]: if key in batch_copy: del batch_copy[key] - + # Create solar position data batch_size = sample_batch["gsp"].shape[0] seq_len = unimodal_teacher_model.forecast_len + unimodal_teacher_model.history_len + 1 batch_copy[azimuth_key] = torch.rand((batch_size, seq_len)) batch_copy[elevation_key] = torch.rand((batch_size, seq_len)) - + # Test forward and backward passes - y = unimodal_teacher_model(batch_copy) + y = unimodal_teacher_model(batch_copy) assert tuple(y.shape) == (2, 16), y.shape y.sum().backward()