Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 11, 2025
1 parent 26ce76c commit 7c820a0
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 22 deletions.
2 changes: 1 addition & 1 deletion pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def forward(self, x):
# Fall back to legacy keys
azimuth_key = f"{self._target_key}_solar_azimuth"
elevation_key = f"{self._target_key}_solar_elevation"

# Process the sun data if either key set is found
if azimuth_key in x and elevation_key in x:
sun = torch.cat(
Expand Down
4 changes: 2 additions & 2 deletions pvnet/models/multimodal/multimodal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ def _adapt_batch(self, batch):

if self.include_sun:
sun_len = self.forecast_len + self.history_len + 1

solar_position_keys = []
# Slife off end of solar coords
for s in ["solar_azimuth", "solar_elevation"]:
if s in batch.keys():
solar_position_keys.append(s)
batch[s] = batch[s][:, :sun_len]

# Check for legacy keys if new keys aren't found
if not solar_position_keys:
# Slice off the end of the legacy solar coords
Expand Down
2 changes: 1 addition & 1 deletion pvnet/models/multimodal/unimodal_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def forward(self, x, return_modes=False):
# Fall back to legacy keys
azimuth_key = f"{self._target_key}_solar_azimuth"
elevation_key = f"{self._target_key}_solar_elevation"

# Process the sun data if either key set is found
if azimuth_key in x and elevation_key in x:
sun = torch.cat(
Expand Down
16 changes: 7 additions & 9 deletions tests/models/multimodal/test_multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,25 @@ def test_weighted_quantile_model_forward(multimodal_quantile_model_ignore_minute
["gsp_solar_azimuth", "gsp_solar_elevation"],
],
)


def test_model_with_solar_position_keys(multimodal_model, sample_batch, keys):
"""Test that the model works with both new and legacy solar position keys."""
azimuth_key, elevation_key = keys
azimuth_key, elevation_key = keys
batch_copy = sample_batch.copy()

# Clear all solar keys and add just the ones we're testing
for key in ["solar_azimuth", "solar_elevation",
"gsp_solar_azimuth", "gsp_solar_elevation"]:
for key in ["solar_azimuth", "solar_elevation", "gsp_solar_azimuth", "gsp_solar_elevation"]:
if key in batch_copy:
del batch_copy[key]

# Create solar position data if needed
import torch

batch_size = sample_batch["gsp"].shape[0]
seq_len = multimodal_model.forecast_len + multimodal_model.history_len + 1
batch_copy[azimuth_key] = torch.rand((batch_size, seq_len))
batch_copy[elevation_key] = torch.rand((batch_size, seq_len))

# Test forward and backward passes
y = multimodal_model(batch_copy)
y = multimodal_model(batch_copy)
assert tuple(y.shape) == (2, 16), y.shape
y.sum().backward()
15 changes: 6 additions & 9 deletions tests/models/multimodal/test_unimodal_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,26 +113,23 @@ def test_model_conversion(unimodal_model_kwargs, sample_batch):
["gsp_solar_azimuth", "gsp_solar_elevation"],
],
)


def test_unimodal_model_with_solar_position_keys(unimodal_teacher_model, sample_batch, keys):
"""Test that the unimodal teacher model works with both new and legacy solar position keys."""
azimuth_key, elevation_key = keys
azimuth_key, elevation_key = keys
batch_copy = sample_batch.copy()

# Clear all solar keys and add just the ones we're testing
for key in ["solar_azimuth", "solar_elevation",
"gsp_solar_azimuth", "gsp_solar_elevation"]:
for key in ["solar_azimuth", "solar_elevation", "gsp_solar_azimuth", "gsp_solar_elevation"]:
if key in batch_copy:
del batch_copy[key]

# Create solar position data
batch_size = sample_batch["gsp"].shape[0]
seq_len = unimodal_teacher_model.forecast_len + unimodal_teacher_model.history_len + 1
batch_copy[azimuth_key] = torch.rand((batch_size, seq_len))
batch_copy[elevation_key] = torch.rand((batch_size, seq_len))

# Test forward and backward passes
y = unimodal_teacher_model(batch_copy)
y = unimodal_teacher_model(batch_copy)
assert tuple(y.shape) == (2, 16), y.shape
y.sum().backward()

0 comments on commit 7c820a0

Please sign in to comment.