
Commit

Fix BatchKeys
Sukh-P committed Jan 20, 2025
1 parent bdd8732 commit 9b9ff27
Showing 35 changed files with 313 additions and 106 deletions.
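The common change across the files below is that dict batches are indexed with plain string keys instead of the BatchKey / NWPBatchKey enum members previously imported from ocf_datapipes. A minimal sketch of the access pattern before and after, using a made-up batch (not repository code):

import torch

# Hypothetical example batch; keys follow the new plain-string convention
batch = {"gsp": torch.rand(8, 42, 1), "gsp_id": torch.arange(8)}

# Before (enum keys from ocf_datapipes.batch):
#   gsp_yield = batch[BatchKey.gsp]
#   gsp_id = batch[BatchKey.gsp_id].int()

# After (plain string keys, as in this commit):
gsp_yield = batch["gsp"]
gsp_id = batch["gsp_id"].int()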
61 changes: 8 additions & 53 deletions pvnet/data/uk_regional_datamodule.py
@@ -2,15 +2,9 @@
from glob import glob

import torch
from lightning.pytorch import LightningDataModule
from pvnet.data.base_datamodule import BaseDataModule
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from ocf_datapipes.batch import (
NumpyBatch,
TensorBatch,
batch_to_tensor,
stack_np_examples_into_batch,
)
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Dataset


class NumpybatchPremadeSamplesDataset(Dataset):
@@ -31,12 +25,7 @@ def __getitem__(self, idx):
return torch.load(self.sample_paths[idx])


def collate_fn(samples: list[NumpyBatch]) -> TensorBatch:
"""Convert a list of NumpyBatch samples to a tensor batch"""
return batch_to_tensor(stack_np_examples_into_batch(samples))


class DataModule(LightningDataModule):
class DataModule(BaseDataModule):
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""

def __init__(
@@ -64,32 +53,14 @@ def __init__(
val_period: Date range filter for val dataloader.
"""
super().__init__()

if not ((sample_dir is not None) ^ (configuration is not None)):
raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.")

if sample_dir is not None:
if any([period != [None, None] for period in [train_period, val_period]]):
raise ValueError("Cannot set `(train/val)_period` with presaved samples")

self.configuration = configuration
self.sample_dir = sample_dir
self.train_period = train_period
self.val_period = val_period

self._common_dataloader_kwargs = dict(
super().__init__(
configuration=configuration,
sample_dir=sample_dir,
batch_size=batch_size,
sampler=None,
batch_sampler=None,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=prefetch_factor,
persistent_workers=False,
train_period=train_period,
val_period=val_period,
)

def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
@@ -98,19 +69,3 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
def _get_premade_samples_dataset(self, subdir) -> Dataset:
split_dir = f"{self.sample_dir}/{subdir}"
return NumpybatchPremadeSamplesDataset(split_dir)

def train_dataloader(self) -> DataLoader:
"""Construct train dataloader"""
if self.sample_dir is not None:
dataset = self._get_premade_samples_dataset("train")
else:
dataset = self._get_streamed_samples_dataset(*self.train_period)
return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)

def val_dataloader(self) -> DataLoader:
"""Construct val dataloader"""
if self.sample_dir is not None:
dataset = self._get_premade_samples_dataset("val")
else:
dataset = self._get_streamed_samples_dataset(*self.val_period)
return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
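A rough usage sketch of the slimmed-down DataModule after this refactor. Paths and values are hypothetical, and the checks removed above (exactly one of sample_dir / configuration set, no period filters with presaved samples) are presumably inherited from BaseDataModule:

from pvnet.data.uk_regional_datamodule import DataModule

# Streamed samples built from a data configuration (hypothetical path and periods)
dm_streamed = DataModule(
    configuration="configs/datamodule/example_config.yaml",
    sample_dir=None,
    batch_size=8,
    num_workers=4,
    prefetch_factor=2,
    train_period=["2020-01-01", "2021-12-31"],
    val_period=["2022-01-01", "2022-06-30"],
)

# Presaved samples read from disk (period filters left unset)
dm_presaved = DataModule(
    configuration=None,
    sample_dir="/path/to/presaved/samples",
    batch_size=8,
    num_workers=4,
    prefetch_factor=2,
    train_period=[None, None],
    val_period=[None, None],
)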
2 changes: 1 addition & 1 deletion pvnet/models/baseline/last_value.py
@@ -32,7 +32,7 @@ def __init__(
def forward(self, x: dict):
"""Run model forward on dict batch of data"""
# Shape: batch_size, seq_length, n_sites
gsp_yield = x[BatchKey.gsp]
gsp_yield = x["gsp"]

# take the last non-forecast value and the first in the PV yield
# (this is the PV site we are predicting for)
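The rest of the forward pass is collapsed above; as a rough, self-contained sketch of the persistence idea the comment describes (assumed shapes and slicing, not the repository implementation):

import torch

batch_size, seq_length, forecast_len = 8, 37, 16
gsp_yield = torch.rand(batch_size, seq_length)       # stands in for x["gsp"] squeezed to 2D

history = gsp_yield[:, : seq_length - forecast_len]  # non-forecast (historical) steps
y_hat = history[:, -1:].repeat(1, forecast_len)      # repeat the last observed value
print(y_hat.shape)  # torch.Size([8, 16])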
2 changes: 1 addition & 1 deletion pvnet/models/baseline/single_value.py
@@ -33,5 +33,5 @@ def __init__(
def forward(self, x: dict):
"""Run model forward on dict batch of data"""
# Returns a single value at all steps
y_hat = torch.zeros_like(x[BatchKey.gsp][:, : self.forecast_len]) + self._value
y_hat = torch.zeros_like(x["gsp"][:, : self.forecast_len]) + self._value
return y_hat
34 changes: 17 additions & 17 deletions pvnet/models/multimodal/multimodal.py
@@ -42,7 +42,7 @@ def __init__(
output_quantiles: Optional[list[float]] = None,
nwp_encoders_dict: Optional[dict[AbstractNWPSatelliteEncoder]] = None,
sat_encoder: Optional[AbstractNWPSatelliteEncoder] = None,
site_encoder: Optional[AbstractSitesEncoder] = None,
pv_encoder: Optional[AbstractSitesEncoder] = None,
sensor_encoder: Optional[AbstractSitesEncoder] = None,
add_image_embedding_channel: bool = False,
include_gsp_yield_history: bool = True,
@@ -55,14 +55,14 @@
min_sat_delay_minutes: Optional[int] = 30,
nwp_forecast_minutes: Optional[DictConfig] = None,
nwp_history_minutes: Optional[DictConfig] = None,
site_history_minutes: Optional[int] = None,
pv_history_minutes: Optional[int] = None,
sensor_history_minutes: Optional[int] = None,
sensor_forecast_minutes: Optional[int] = None,
optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
target_key: str = "gsp",
interval_minutes: int = 30,
nwp_interval_minutes: Optional[DictConfig] = None,
site_interval_minutes: int = 5,
pv_interval_minutes: int = 5,
sat_interval_minutes: int = 5,
sensor_interval_minutes: int = 30,
num_embeddings: Optional[int] = 318,
@@ -88,7 +88,7 @@ def __init__(
encode the NWP data from 4D into a 1D feature vector from different sources.
sat_encoder: A partially instantiated pytorch Module class used to encode the satellite
data from 4D into a 1D feature vector.
site_encoder: A partially instantiated pytorch Module class used to encode the site-level
pv_encoder: A partially instantiated pytorch Module class used to encode the site-level
PV data from 2D into a 1D feature vector.
add_image_embedding_channel: Add a channel to the NWP and satellite data with the
embedding of the GSP ID.
@@ -107,14 +107,14 @@ def __init__(
`forecast_minutes` if not provided.
nwp_history_minutes: Period of historical NWP forecast used as input. Defaults to
`history_minutes` if not provided.
site_history_minutes: Length of recent site-level PV data used as
pv_history_minutes: Length of recent site-level PV data used as
input. Defaults to `history_minutes` if not provided.
optimizer: Optimizer factory function used for network.
target_key: The key of the target variable in the batch.
interval_minutes: The interval between each sample of the target data
nwp_interval_minutes: Dictionary of the intervals between each sample of the NWP
data for each source
site_interval_minutes: The interval between each sample of the PV data
pv_interval_minutes: The interval between each sample of the PV data
sat_interval_minutes: The interval between each sample of the satellite data
sensor_interval_minutes: The interval between each sample of the sensor data
num_embeddings: The number of dimensions to use for the image embedding
@@ -133,7 +133,7 @@ def __init__(
self.include_gsp_yield_history = include_gsp_yield_history
self.include_sat = sat_encoder is not None
self.include_nwp = nwp_encoders_dict is not None and len(nwp_encoders_dict) != 0
self.include_site = site_encoder is not None
self.include_pv = pv_encoder is not None
self.include_sun = include_sun
self.include_time = include_time
self.include_sensor = sensor_encoder is not None
@@ -218,17 +218,17 @@ def __init__(
# Update num features
fusion_input_features += self.nwp_encoders_dict[nwp_source].out_features

if self.include_site:
assert site_history_minutes is not None
if self.include_pv:
assert pv_history_minutes is not None

self.site_encoder = site_encoder(
sequence_length=site_history_minutes // site_interval_minutes - 1,
self.pv_encoder = pv_encoder(
sequence_length=pv_history_minutes // pv_interval_minutes - 1,
target_key_to_use=self._target_key_name,
input_key_to_use="site",
)

# Update num features
fusion_input_features += self.site_encoder.out_features
fusion_input_features += self.pv_encoder.out_features

if self.include_sensor:
if sensor_history_minutes is None:
@@ -294,7 +294,7 @@ def forward(self, x):
# ******************* Satellite imagery *************************
if self.include_sat:
# Shape: batch_size, seq_length, channel, height, width
sat_data = x[BatchKey.satellite_actual][:, : self.sat_sequence_len]
sat_data = x["satellite_actual"][:, : self.sat_sequence_len]
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels

if self.add_image_embedding_channel:
@@ -322,20 +322,20 @@ def forward(self, x):

# *********************** Site Data *************************************
# Add site-level PV yield
if self.include_site:
if self.include_pv:
if self._target_key_name != "site":
modes["site"] = self.site_encoder(x)
modes["site"] = self.pv_encoder(x)
else:
# Target is PV, so only take the history
# Copy batch
x_tmp = x.copy()
x_tmp["site"] = x_tmp["site"][:, : self.history_len + 1]
modes["site"] = self.site_encoder(x_tmp)
modes["site"] = self.pv_encoder(x_tmp)

# *********************** GSP Data ************************************
# add gsp yield history
if self.include_gsp_yield_history:
gsp_history = x[BatchKey.gsp][:, : self.history_len].float()
gsp_history = x["gsp"][:, : self.history_len].float()
gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
modes["gsp"] = gsp_history

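A minimal, self-contained sketch of the late-fusion pattern this forward pass follows: each modality is encoded into a 1D feature vector, collected in a modes dict keyed by plain strings, and concatenated for a downstream fusion network. The encoders, shapes, and concatenation step here are illustrative assumptions, not the repository implementation:

import torch
import torch.nn as nn

batch_size, history_len, sat_sequence_len = 8, 21, 7

x = {
    "gsp": torch.rand(batch_size, 42, 1),                        # GSP yield: time x sites
    "satellite_actual": torch.rand(batch_size, 11, 11, 24, 24),  # time, channel, height, width
}

sat_encoder = nn.Sequential(nn.Flatten(), nn.LazyLinear(128))    # stand-in encoder

modes = {}

# Satellite: crop to the history window, swap time and channel axes, encode to 1D
sat_data = torch.swapaxes(x["satellite_actual"][:, :sat_sequence_len], 1, 2).float()
modes["sat"] = sat_encoder(sat_data)

# GSP: flatten the yield history into a feature vector
gsp_history = x["gsp"][:, :history_len].float()
modes["gsp"] = gsp_history.reshape(gsp_history.shape[0], -1)

fused = torch.cat([modes[k] for k in sorted(modes)], dim=1)
print(fused.shape)  # torch.Size([8, 149])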
5 changes: 3 additions & 2 deletions pvnet/models/multimodal/site_encoders/encoders.py
@@ -127,7 +127,7 @@ def __init__(
kv_res_block_layers: int = 2,
use_id_in_value: bool = False,
target_id_dim: int = 318,
target_key_to_use: str = "site",
target_key_to_use: str = "gsp",
input_key_to_use: str = "site",
num_channels: int = 1,
num_sites_in_inference: int = 1,
@@ -209,7 +209,8 @@ def _encode_inputs(self, x):
# Shape: [batch size, sequence length, number of sites] -> [8, 197, 1]
# Shape: [batch size, station_id, sequence length, channels] -> [8, 197, 26, 23]
input_data = x[f"{self.input_key_to_use}"]
input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
if len(input_data.shape) == 2: # one site per sample
input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
if len(input_data.shape) == 4: # Has multiple channels
input_data = input_data[:, :, : self.sequence_length]
input_data = einops.rearrange(input_data, "b id s c -> b (s c) id")
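The new branch above only unsqueezes genuinely 2D inputs; a standalone sketch of both shape paths with dummy tensors (shapes follow the comments in the diff, the rest is assumed):

import einops
import torch

sequence_length = 26

# One site per sample: [batch, seq] -> [batch, seq, 1]
single_site = torch.rand(8, 197)
if len(single_site.shape) == 2:
    single_site = single_site.unsqueeze(-1)
print(single_site.shape)  # torch.Size([8, 197, 1])

# Multi-channel input: [batch, station_id, seq, channels] -> [batch, seq * channels, station_id]
multi_channel = torch.rand(8, 197, 26, 23)
if len(multi_channel.shape) == 4:
    multi_channel = multi_channel[:, :, :sequence_length]
    multi_channel = einops.rearrange(multi_channel, "b id s c -> b (s c) id")
print(multi_channel.shape)  # torch.Size([8, 598, 197])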
28 changes: 14 additions & 14 deletions pvnet/models/multimodal/unimodal_teacher.py
@@ -215,11 +215,11 @@ def teacher_forward(self, x):
# ******************* Satellite imagery *************************
if mode == "sat":
# Shape: batch_size, seq_length, channel, height, width
sat_data = x[BatchKey.satellite_actual][:, : teacher_model.sat_sequence_len]
sat_data = x["satellite_actual"][:, : teacher_model.sat_sequence_len]
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels

if self.add_image_embedding_channel:
id = x[BatchKey.gsp_id].int()
id = x["gsp_id"].int()
sat_data = teacher_model.sat_embed(sat_data, id)

modes[mode] = teacher_model.sat_encoder(sat_data)
@@ -229,11 +229,11 @@ def teacher_forward(self, x):
nwp_source = mode.removeprefix("nwp/")

# shape: batch_size, seq_len, n_chans, height, width
nwp_data = x[BatchKey.nwp][nwp_source][NWPBatchKey.nwp].float()
nwp_data = x["nwp"][nwp_source]["nwp"].float()
nwp_data = torch.swapaxes(nwp_data, 1, 2) # switch time and channels
nwp_data = torch.clip(nwp_data, min=-50, max=50)
if teacher_model.add_image_embedding_channel:
id = x[BatchKey.gsp_id].int()
id = x["gsp_id"].int()
nwp_data = teacher_model.nwp_embed_dict[nwp_source](nwp_data, id)

nwp_out = teacher_model.nwp_encoders_dict[nwp_source](nwp_data)
@@ -256,11 +256,11 @@ def forward(self, x, return_modes=False):
# ******************* Satellite imagery *************************
if self.include_sat:
# Shape: batch_size, seq_length, channel, height, width
sat_data = x[BatchKey.satellite_actual][:, : self.sat_sequence_len]
sat_data = x["satellite_actual"][:, : self.sat_sequence_len]
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels

if self.add_image_embedding_channel:
id = x[BatchKey.gsp_id].int()
id = x["gsp_id"].int()
sat_data = self.sat_embed(sat_data, id)
modes["sat"] = self.sat_encoder(sat_data)

@@ -269,47 +269,47 @@ def forward(self, x, return_modes=False):
# Loop through potentially many NWPs
for nwp_source in self.nwp_encoders_dict:
# shape: batch_size, seq_len, n_chans, height, width
nwp_data = x[BatchKey.nwp][nwp_source][NWPBatchKey.nwp].float()
nwp_data = x["nwp"][nwp_source]["nwp"].float()
nwp_data = torch.swapaxes(nwp_data, 1, 2) # switch time and channels
# Some NWP variables can overflow into NaNs when normalised if they have extreme
# tails
nwp_data = torch.clip(nwp_data, min=-50, max=50)

if self.add_image_embedding_channel:
id = x[BatchKey.gsp_id].int()
id = x["gsp_id"].int()
nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, id)

nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data)
modes[f"nwp/{nwp_source}"] = nwp_out

# *********************** PV Data *************************************
# Add site-level PV yield
if self.include_site:
if self.include_pv:
if self._target_key_name != "site":
modes["site"] = self.site_encoder(x)
else:
# Target is PV, so only take the history
pv_history = x[BatchKey.pv][:, : self.history_len].float()
pv_history = x["pv"][:, : self.history_len].float()
modes["site"] = self.site_encoder(pv_history)

# *********************** GSP Data ************************************
# add gsp yield history
if self.include_gsp_yield_history:
gsp_history = x[BatchKey.gsp][:, : self.history_len].float()
gsp_history = x["gsp"][:, : self.history_len].float()
gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
modes["gsp"] = gsp_history

# ********************** Embedding of GSP ID ********************
if self.embedding_dim:
id = x[BatchKey.gsp_id].int()
id = x["gsp_id"].int()
id_embedding = self.embed(id)
modes["id"] = id_embedding

if self.include_sun:
sun = torch.cat(
(
x[BatchKey.gsp_solar_azimuth],
x[BatchKey.gsp_solar_elevation],
x["gsp_solar_azimuth"],
x["gsp_solar_elevation"],
),
dim=1,
).float()
