Skip to content

Commit

Permalink
Draft update - pending on data-sampler solar coords update confirmation
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-e-h-p committed Mar 9, 2025
1 parent e9837bb commit 2f0d9ac
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
30 changes: 21 additions & 9 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
timestep_intervals_to_plot: Optional[list[int]] = None,
adapt_batches: Optional[bool] = False,
forecast_minutes_ignore: Optional[int] = 0,
solar_position_config: Optional[dict] = None,
):
"""Neural network which combines information from different sources.
Expand Down Expand Up @@ -141,6 +142,7 @@ def __init__(
self.interval_minutes = interval_minutes
self.min_sat_delay_minutes = min_sat_delay_minutes
self.adapt_batches = adapt_batches
self.solar_position_config = solar_position_config

super().__init__(
history_minutes=history_minutes,
Expand Down Expand Up @@ -345,15 +347,25 @@ def forward(self, x):
modes["id"] = id_embedding

if self.include_sun:
sun = torch.cat(
(
x[f"{self._target_key}_solar_azimuth"],
x[f"{self._target_key}_solar_elevation"],
),
dim=1,
).float()
sun = self.sun_fc1(sun)
modes["sun"] = sun

azimuth_key = f"solar_position_{self._target_key}_azimuth"
elevation_key = f"solar_position_{self._target_key}_elevation"

# Fall back to legacy keys
if azimuth_key not in x:
azimuth_key = f"{self._target_key}_solar_azimuth"
elevation_key = f"{self._target_key}_solar_elevation"

if azimuth_key in x and elevation_key in x:
sun = torch.cat(
(
x[azimuth_key],
x[elevation_key],
),
dim=1,
).float()
sun = self.sun_fc1(sun)
modes["sun"] = sun

if self.include_time:
time = torch.cat(
Expand Down
20 changes: 20 additions & 0 deletions pvnet/models/multimodal/multimodal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,24 @@ def _adapt_batch(self, batch):
sun_len = self.forecast_len + self.history_len + 1
batch[key] = batch[key][:, :sun_len]

if self.include_sun:
sun_len = self.forecast_len + self.history_len + 1

# Check for solar position keys first
solar_position_keys = []
# Slice off the end of the solar coords data
for s in ["azimuth", "elevation"]:
key = f"solar_position_{self._target_key}_{s}"
if key in batch.keys():
solar_position_keys.append(key)
batch[key] = batch[key][:, :sun_len]

# Check for legacy keys
if not solar_position_keys:
# Slice off the end of the solar coords data
for s in ["solar_azimuth", "solar_elevation"]:
key = f"{self._target_key}_{s}"
if key in batch.keys():
batch[key] = batch[key][:, :sun_len]

return batch

0 comments on commit 2f0d9ac

Please sign in to comment.