Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solar decoupling update #335

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
timestep_intervals_to_plot: Optional[list[int]] = None,
adapt_batches: Optional[bool] = False,
forecast_minutes_ignore: Optional[int] = 0,
solar_position_config: Optional[dict] = None,
):
"""Neural network which combines information from different sources.

Expand Down Expand Up @@ -141,6 +142,7 @@ def __init__(
self.interval_minutes = interval_minutes
self.min_sat_delay_minutes = min_sat_delay_minutes
self.adapt_batches = adapt_batches
self.solar_position_config = solar_position_config

super().__init__(
history_minutes=history_minutes,
Expand Down Expand Up @@ -345,15 +347,27 @@ def forward(self, x):
modes["id"] = id_embedding

if self.include_sun:
sun = torch.cat(
(
x[f"{self._target_key}_solar_azimuth"],
x[f"{self._target_key}_solar_elevation"],
),
dim=1,
).float()
sun = self.sun_fc1(sun)
modes["sun"] = sun
# Determine which keys to use
if "solar_azimuth" in x and "solar_elevation" in x:
# Use new standalone keys
azimuth_key = "solar_azimuth"
elevation_key = "solar_elevation"
else:
# Fall back to legacy keys
azimuth_key = f"{self._target_key}_solar_azimuth"
elevation_key = f"{self._target_key}_solar_elevation"

# Process the sun data if either key set is found
if azimuth_key in x and elevation_key in x:
sun = torch.cat(
(
x[azimuth_key],
x[elevation_key],
),
dim=1,
).float()
sun = self.sun_fc1(sun)
modes["sun"] = sun

if self.include_time:
time = torch.cat(
Expand Down
20 changes: 15 additions & 5 deletions pvnet/models/multimodal/multimodal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,21 @@ def _adapt_batch(self, batch):
)[:, : self.nwp_encoders_dict[nwp_source].sequence_length]

if self.include_sun:
# Slice off the end of the solar coords data
sun_len = self.forecast_len + self.history_len + 1

solar_position_keys = []
# Slife off end of solar coords
for s in ["solar_azimuth", "solar_elevation"]:
key = f"{self._target_key}_{s}"
if key in batch.keys():
sun_len = self.forecast_len + self.history_len + 1
batch[key] = batch[key][:, :sun_len]
if s in batch.keys():
solar_position_keys.append(s)
batch[s] = batch[s][:, :sun_len]

# Check for legacy keys if new keys aren't found
if not solar_position_keys:
# Slice off the end of the legacy solar coords
for s in ["solar_azimuth", "solar_elevation"]:
key = f"{self._target_key}_{s}"
if key in batch.keys():
batch[key] = batch[key][:, :sun_len]

return batch
30 changes: 21 additions & 9 deletions pvnet/models/multimodal/unimodal_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,27 @@ def forward(self, x, return_modes=False):
modes["id"] = id_embedding

if self.include_sun:
sun = torch.cat(
(
x["gsp_solar_azimuth"],
x["gsp_solar_elevation"],
),
dim=1,
).float()
sun = self.sun_fc1(sun)
modes["sun"] = sun
# Determine which keys to use
if "solar_azimuth" in x and "solar_elevation" in x:
# Use new standalone keys
azimuth_key = "solar_azimuth"
elevation_key = "solar_elevation"
else:
# Fall back to legacy keys
azimuth_key = f"{self._target_key}_solar_azimuth"
elevation_key = f"{self._target_key}_solar_elevation"

# Process the sun data if either key set is found
if azimuth_key in x and elevation_key in x:
sun = torch.cat(
(
x[azimuth_key],
x[elevation_key],
),
dim=1,
).float()
sun = self.sun_fc1(sun)
modes["sun"] = sun

out = self.output_network(modes)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dynamic = ["version", "readme"]
license={file="LICENCE"}

dependencies = [
"ocf_data_sampler==0.1.7",
"ocf-data-sampler>=0.1.13",
"ocf_datapipes>=3.3.34",
"ocf_ml_metrics>=0.0.11",
"numpy",
Expand Down
31 changes: 31 additions & 0 deletions tests/models/multimodal/test_multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,34 @@ def test_weighted_quantile_model_forward(multimodal_quantile_model_ignore_minute

# Backwards on sum drives sum to zero
y_quantiles.sum().backward()


@pytest.mark.parametrize(
"keys",
[
["solar_azimuth", "solar_elevation"],
["gsp_solar_azimuth", "gsp_solar_elevation"],
],
)
def test_model_with_solar_position_keys(multimodal_model, sample_batch, keys):
"""Test that the model works with both new and legacy solar position keys."""
azimuth_key, elevation_key = keys
batch_copy = sample_batch.copy()

# Clear all solar keys and add just the ones we're testing
for key in ["solar_azimuth", "solar_elevation", "gsp_solar_azimuth", "gsp_solar_elevation"]:
if key in batch_copy:
del batch_copy[key]

# Create solar position data if needed
import torch

batch_size = sample_batch["gsp"].shape[0]
seq_len = multimodal_model.forecast_len + multimodal_model.history_len + 1
batch_copy[azimuth_key] = torch.rand((batch_size, seq_len))
batch_copy[elevation_key] = torch.rand((batch_size, seq_len))

# Test forward and backward passes
y = multimodal_model(batch_copy)
assert tuple(y.shape) == (2, 16), y.shape
y.sum().backward()
29 changes: 29 additions & 0 deletions tests/models/multimodal/test_unimodal_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,32 @@ def test_model_conversion(unimodal_model_kwargs, sample_batch):
y_mm = mm_model(sample_batch)

assert (y_um == y_mm).all()


@pytest.mark.parametrize(
"keys",
[
["solar_azimuth", "solar_elevation"],
["gsp_solar_azimuth", "gsp_solar_elevation"],
],
)
def test_unimodal_model_with_solar_position_keys(unimodal_teacher_model, sample_batch, keys):
"""Test that the unimodal teacher model works with both new and legacy solar position keys."""
azimuth_key, elevation_key = keys
batch_copy = sample_batch.copy()

# Clear all solar keys and add just the ones we're testing
for key in ["solar_azimuth", "solar_elevation", "gsp_solar_azimuth", "gsp_solar_elevation"]:
if key in batch_copy:
del batch_copy[key]

# Create solar position data
batch_size = sample_batch["gsp"].shape[0]
seq_len = unimodal_teacher_model.forecast_len + unimodal_teacher_model.history_len + 1
batch_copy[azimuth_key] = torch.rand((batch_size, seq_len))
batch_copy[elevation_key] = torch.rand((batch_size, seq_len))

# Test forward and backward passes
y = unimodal_teacher_model(batch_copy)
assert tuple(y.shape) == (2, 16), y.shape
y.sum().backward()