diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index 2ab3fc68..aea95f0c 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -347,15 +347,17 @@ def forward(self, x): modes["id"] = id_embedding if self.include_sun: - - azimuth_key = f"solar_position_{self._target_key}_azimuth" - elevation_key = f"solar_position_{self._target_key}_elevation" - - # Fall back to legacy keys - if azimuth_key not in x: + # Determine which keys to use + if "solar_azimuth" in x and "solar_elevation" in x: + # Use new standalone keys + azimuth_key = "solar_azimuth" + elevation_key = "solar_elevation" + else: + # Fall back to legacy keys azimuth_key = f"{self._target_key}_solar_azimuth" elevation_key = f"{self._target_key}_solar_elevation" + # Process the sun data if either key set is found if azimuth_key in x and elevation_key in x: sun = torch.cat( ( diff --git a/pvnet/models/multimodal/multimodal_base.py b/pvnet/models/multimodal/multimodal_base.py index 22b68343..7d9d172f 100644 --- a/pvnet/models/multimodal/multimodal_base.py +++ b/pvnet/models/multimodal/multimodal_base.py @@ -40,29 +40,19 @@ def _adapt_batch(self, batch): output_size=self.nwp_encoders_dict[nwp_source].image_size_pixels, )[:, : self.nwp_encoders_dict[nwp_source].sequence_length] - if self.include_sun: - # Slice off the end of the solar coords data - for s in ["solar_azimuth", "solar_elevation"]: - key = f"{self._target_key}_{s}" - if key in batch.keys(): - sun_len = self.forecast_len + self.history_len + 1 - batch[key] = batch[key][:, :sun_len] - if self.include_sun: sun_len = self.forecast_len + self.history_len + 1 - # Check for solar position keys first + # Check for new standalone solar position keys first solar_position_keys = [] - # Slice off the end of the solar coords data - for s in ["azimuth", "elevation"]: - key = f"solar_position_{self._target_key}_{s}" - if key in batch.keys(): - solar_position_keys.append(key) - batch[key] = batch[key][:, :sun_len] + for s in ["solar_azimuth", "solar_elevation"]: + if s in batch.keys(): + solar_position_keys.append(s) + batch[s] = batch[s][:, :sun_len] - # Check for legacy keys + # Check for legacy keys if new keys aren't found if not solar_position_keys: - # Slice off the end of the solar coords data + # Slice off the end of the legacy solar coords data for s in ["solar_azimuth", "solar_elevation"]: key = f"{self._target_key}_{s}" if key in batch.keys(): diff --git a/pvnet/models/multimodal/unimodal_teacher.py b/pvnet/models/multimodal/unimodal_teacher.py index b7aa25fb..1141ff5b 100644 --- a/pvnet/models/multimodal/unimodal_teacher.py +++ b/pvnet/models/multimodal/unimodal_teacher.py @@ -305,15 +305,27 @@ def forward(self, x, return_modes=False): modes["id"] = id_embedding if self.include_sun: - sun = torch.cat( - ( - x["gsp_solar_azimuth"], - x["gsp_solar_elevation"], - ), - dim=1, - ).float() - sun = self.sun_fc1(sun) - modes["sun"] = sun + # Determine which keys to use + if "solar_azimuth" in x and "solar_elevation" in x: + # Use new standalone keys + azimuth_key = "solar_azimuth" + elevation_key = "solar_elevation" + else: + # Fall back to legacy keys + azimuth_key = f"{self._target_key}_solar_azimuth" + elevation_key = f"{self._target_key}_solar_elevation" + + # Process the sun data if either key set is found + if azimuth_key in x and elevation_key in x: + sun = torch.cat( + ( + x[azimuth_key], + x[elevation_key], + ), + dim=1, + ).float() + sun = self.sun_fc1(sun) + modes["sun"] = sun out = self.output_network(modes) diff --git a/pyproject.toml b/pyproject.toml index 1a4990dc..c96b9011 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ dynamic = ["version", "readme"] license={file="LICENCE"} dependencies = [ - "ocf_data_sampler==0.1.7", + "ocf-data-sampler>=0.1.13", "ocf_datapipes>=3.3.34", "ocf_ml_metrics>=0.0.11", "numpy", diff --git a/tests/models/multimodal/test_multimodal.py b/tests/models/multimodal/test_multimodal.py index b7c28f73..2a41c333 100644 --- a/tests/models/multimodal/test_multimodal.py +++ b/tests/models/multimodal/test_multimodal.py @@ -45,3 +45,36 @@ def test_weighted_quantile_model_forward(multimodal_quantile_model_ignore_minute # Backwards on sum drives sum to zero y_quantiles.sum().backward() + + +@pytest.mark.parametrize( + "keys", + [ + ["solar_azimuth", "solar_elevation"], + ["gsp_solar_azimuth", "gsp_solar_elevation"], + ], +) + + +def test_model_with_solar_position_keys(multimodal_model, sample_batch, keys): + """Test that the model works with both new and legacy solar position keys.""" + azimuth_key, elevation_key = keys + batch_copy = sample_batch.copy() + + # Clear all solar keys and add just the ones we're testing + for key in ["solar_azimuth", "solar_elevation", + "gsp_solar_azimuth", "gsp_solar_elevation"]: + if key in batch_copy: + del batch_copy[key] + + # Create solar position data if needed + import torch + batch_size = sample_batch["gsp"].shape[0] + seq_len = multimodal_model.forecast_len + multimodal_model.history_len + 1 + batch_copy[azimuth_key] = torch.rand((batch_size, seq_len)) + batch_copy[elevation_key] = torch.rand((batch_size, seq_len)) + + # Test forward and backward passes + y = multimodal_model(batch_copy) + assert tuple(y.shape) == (2, 16), y.shape + y.sum().backward() diff --git a/tests/models/multimodal/test_unimodal_teacher.py b/tests/models/multimodal/test_unimodal_teacher.py index fbed5e92..a7530de6 100644 --- a/tests/models/multimodal/test_unimodal_teacher.py +++ b/tests/models/multimodal/test_unimodal_teacher.py @@ -104,3 +104,35 @@ def test_model_conversion(unimodal_model_kwargs, sample_batch): y_mm = mm_model(sample_batch) assert (y_um == y_mm).all() + + +@pytest.mark.parametrize( + "keys", + [ + ["solar_azimuth", "solar_elevation"], + ["gsp_solar_azimuth", "gsp_solar_elevation"], + ], +) + + +def test_unimodal_model_with_solar_position_keys(unimodal_teacher_model, sample_batch, keys): + """Test that the unimodal teacher model works with both new and legacy solar position keys.""" + azimuth_key, elevation_key = keys + batch_copy = sample_batch.copy() + + # Clear all solar keys and add just the ones we're testing + for key in ["solar_azimuth", "solar_elevation", + "gsp_solar_azimuth", "gsp_solar_elevation"]: + if key in batch_copy: + del batch_copy[key] + + # Create solar position data + batch_size = sample_batch["gsp"].shape[0] + seq_len = unimodal_teacher_model.forecast_len + unimodal_teacher_model.history_len + 1 + batch_copy[azimuth_key] = torch.rand((batch_size, seq_len)) + batch_copy[elevation_key] = torch.rand((batch_size, seq_len)) + + # Test forward and backward passes + y = unimodal_teacher_model(batch_copy) + assert tuple(y.shape) == (2, 16), y.shape + y.sum().backward()