Skip to content

Commit

Permalink
Initial updates
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-e-h-p committed Mar 11, 2025
1 parent 2f0d9ac commit b842dad
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 33 deletions.
14 changes: 8 additions & 6 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,15 +347,17 @@ def forward(self, x):
modes["id"] = id_embedding

if self.include_sun:

azimuth_key = f"solar_position_{self._target_key}_azimuth"
elevation_key = f"solar_position_{self._target_key}_elevation"

# Fall back to legacy keys
if azimuth_key not in x:
# Determine which keys to use
if "solar_azimuth" in x and "solar_elevation" in x:
# Use new standalone keys
azimuth_key = "solar_azimuth"
elevation_key = "solar_elevation"
else:
# Fall back to legacy keys
azimuth_key = f"{self._target_key}_solar_azimuth"
elevation_key = f"{self._target_key}_solar_elevation"

# Process the sun data if either key set is found
if azimuth_key in x and elevation_key in x:
sun = torch.cat(
(
Expand Down
24 changes: 7 additions & 17 deletions pvnet/models/multimodal/multimodal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,19 @@ def _adapt_batch(self, batch):
output_size=self.nwp_encoders_dict[nwp_source].image_size_pixels,
)[:, : self.nwp_encoders_dict[nwp_source].sequence_length]

if self.include_sun:
# Slice off the end of the solar coords data
for s in ["solar_azimuth", "solar_elevation"]:
key = f"{self._target_key}_{s}"
if key in batch.keys():
sun_len = self.forecast_len + self.history_len + 1
batch[key] = batch[key][:, :sun_len]

if self.include_sun:
sun_len = self.forecast_len + self.history_len + 1

# Check for solar position keys first
# Check for new standalone solar position keys first
solar_position_keys = []
# Slice off the end of the solar coords data
for s in ["azimuth", "elevation"]:
key = f"solar_position_{self._target_key}_{s}"
if key in batch.keys():
solar_position_keys.append(key)
batch[key] = batch[key][:, :sun_len]
for s in ["solar_azimuth", "solar_elevation"]:
if s in batch.keys():
solar_position_keys.append(s)
batch[s] = batch[s][:, :sun_len]

# Check for legacy keys
# Check for legacy keys if new keys aren't found
if not solar_position_keys:
# Slice off the end of the solar coords data
# Slice off the end of the legacy solar coords data
for s in ["solar_azimuth", "solar_elevation"]:
key = f"{self._target_key}_{s}"
if key in batch.keys():
Expand Down
30 changes: 21 additions & 9 deletions pvnet/models/multimodal/unimodal_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,27 @@ def forward(self, x, return_modes=False):
modes["id"] = id_embedding

if self.include_sun:
sun = torch.cat(
(
x["gsp_solar_azimuth"],
x["gsp_solar_elevation"],
),
dim=1,
).float()
sun = self.sun_fc1(sun)
modes["sun"] = sun
# Determine which keys to use
if "solar_azimuth" in x and "solar_elevation" in x:
# Use new standalone keys
azimuth_key = "solar_azimuth"
elevation_key = "solar_elevation"
else:
# Fall back to legacy keys
azimuth_key = f"{self._target_key}_solar_azimuth"
elevation_key = f"{self._target_key}_solar_elevation"

# Process the sun data if either key set is found
if azimuth_key in x and elevation_key in x:
sun = torch.cat(
(
x[azimuth_key],
x[elevation_key],
),
dim=1,
).float()
sun = self.sun_fc1(sun)
modes["sun"] = sun

out = self.output_network(modes)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dynamic = ["version", "readme"]
license={file="LICENCE"}

dependencies = [
"ocf_data_sampler==0.1.7",
"ocf-data-sampler>=0.1.13",
"ocf_datapipes>=3.3.34",
"ocf_ml_metrics>=0.0.11",
"numpy",
Expand Down
33 changes: 33 additions & 0 deletions tests/models/multimodal/test_multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,36 @@ def test_weighted_quantile_model_forward(multimodal_quantile_model_ignore_minute

# Backwards on sum drives sum to zero
y_quantiles.sum().backward()


@pytest.mark.parametrize(
"keys",
[
["solar_azimuth", "solar_elevation"],
["gsp_solar_azimuth", "gsp_solar_elevation"],
],
)


def test_model_with_solar_position_keys(multimodal_model, sample_batch, keys):
"""Test that the model works with both new and legacy solar position keys."""
azimuth_key, elevation_key = keys
batch_copy = sample_batch.copy()

# Clear all solar keys and add just the ones we're testing
for key in ["solar_azimuth", "solar_elevation",
"gsp_solar_azimuth", "gsp_solar_elevation"]:
if key in batch_copy:
del batch_copy[key]

# Create solar position data if needed
import torch
batch_size = sample_batch["gsp"].shape[0]
seq_len = multimodal_model.forecast_len + multimodal_model.history_len + 1
batch_copy[azimuth_key] = torch.rand((batch_size, seq_len))
batch_copy[elevation_key] = torch.rand((batch_size, seq_len))

# Test forward and backward passes
y = multimodal_model(batch_copy)
assert tuple(y.shape) == (2, 16), y.shape
y.sum().backward()
32 changes: 32 additions & 0 deletions tests/models/multimodal/test_unimodal_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,35 @@ def test_model_conversion(unimodal_model_kwargs, sample_batch):
y_mm = mm_model(sample_batch)

assert (y_um == y_mm).all()


@pytest.mark.parametrize(
"keys",
[
["solar_azimuth", "solar_elevation"],
["gsp_solar_azimuth", "gsp_solar_elevation"],
],
)


def test_unimodal_model_with_solar_position_keys(unimodal_teacher_model, sample_batch, keys):
"""Test that the unimodal teacher model works with both new and legacy solar position keys."""
azimuth_key, elevation_key = keys
batch_copy = sample_batch.copy()

# Clear all solar keys and add just the ones we're testing
for key in ["solar_azimuth", "solar_elevation",
"gsp_solar_azimuth", "gsp_solar_elevation"]:
if key in batch_copy:
del batch_copy[key]

# Create solar position data
batch_size = sample_batch["gsp"].shape[0]
seq_len = unimodal_teacher_model.forecast_len + unimodal_teacher_model.history_len + 1
batch_copy[azimuth_key] = torch.rand((batch_size, seq_len))
batch_copy[elevation_key] = torch.rand((batch_size, seq_len))

# Test forward and backward passes
y = unimodal_teacher_model(batch_copy)
assert tuple(y.shape) == (2, 16), y.shape
y.sum().backward()

0 comments on commit b842dad

Please sign in to comment.