replaced datapipe with ocf-data-sampler #323

Open · wants to merge 8 commits into main
18 changes: 13 additions & 5 deletions scripts/backtest_sites.py
@@ -35,6 +35,7 @@
import xarray as xr
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from ocf_data_sampler.config import load_yaml_configuration
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
from omegaconf import DictConfig
from torch.utils.data import DataLoader
@@ -161,7 +162,7 @@ def get_sites_ds(config_path: str) -> xr.Dataset:
class ModelPipe:
"""A class to conveniently make and process predictions from batches"""

-    def __init__(self, model, ds_site: xr.Dataset):
+    def __init__(self, model, ds_site: xr.Dataset, config_path: str):
"""A class to conveniently make and process predictions from batches

Args:
@@ -170,6 +171,7 @@ def __init__(self, model, ds_site: xr.Dataset):
"""
self.model = model
self.ds_site = ds_site
self.config_path = config_path

def predict_batch(self, sample: dict) -> xr.Dataset:
"""Run the sample through the model and compile the predictions into an xarray DataArray
@@ -183,14 +185,20 @@ def predict_batch(self, sample: dict) -> xr.Dataset:
# Convert sample to tensor and move to device
sample_tensor = {k: torch.from_numpy(v).to(device) for k, v in sample.items()}

config = load_yaml_configuration(self.config_path)

interval_start = np.timedelta64(config.input_data.site.interval_start_minutes, "m")
interval_end = np.timedelta64(config.input_data.site.interval_end_minutes, "m")
time_resolution = np.timedelta64(config.input_data.site.time_resolution_minutes, "m")
Contributor:

Also, I would probably unpack the config in the init, or maybe even move the whole thing to main and pass the extracted interval_end (or whatever you end up using) to avoid doing it again and again for every batch.
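A minimal sketch of the suggestion above: parse the intervals once instead of on every `predict_batch` call. The `site_cfg` object here is a hypothetical stand-in for `load_yaml_configuration(config_path).input_data.site`, and the interval values are illustrative, not taken from the PR.

```python
from types import SimpleNamespace

import numpy as np


def unpack_site_intervals(site_cfg):
    """Convert the site interval settings to numpy timedeltas once, up front."""
    return (
        np.timedelta64(site_cfg.interval_start_minutes, "m"),
        np.timedelta64(site_cfg.interval_end_minutes, "m"),
        np.timedelta64(site_cfg.time_resolution_minutes, "m"),
    )


# Stand-in config: 60 min of history, 480 min of forecast, 30-min steps
site_cfg = SimpleNamespace(
    interval_start_minutes=-60,
    interval_end_minutes=480,
    time_resolution_minutes=30,
)

# Done once (e.g. in __init__); predict_batch would then reuse these
interval_start, interval_end, time_resolution = unpack_site_intervals(site_cfg)
```

Storing the three timedeltas as attributes in `__init__` (or computing them in `main` and passing them in) keeps the per-batch path free of YAML parsing.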


t0 = pd.Timestamp(sample["site_init_time_utc"][0])
site_id = sample["site_id"][0]

# Get valid times for this forecast
valid_times = pd.date_range(
-            start=t0 + pd.Timedelta(minutes=FREQ_MINS),
-            periods=len(sample["site_target_time_utc"]),
-            freq=f"{FREQ_MINS}min",
+            start=t0 + pd.Timedelta(interval_start),
+            end=t0 + pd.Timedelta(interval_end),
+            freq=f"{time_resolution.astype(int)}min",
Contributor:

Lifting from config is a good way to do it, but the sample ModelPipe is getting already has history, so the first time in it is not t0, it's t0 + interval_start. You can get t0 from the last time minus interval_end, for example, or the first time minus interval_start (you can see how that happens in this function).
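The two recovery routes mentioned above can be sketched quickly. All values here are hypothetical (a made-up t0 and interval settings), chosen only to show that both derivations land on the same t0 when the sample's time axis spans [t0 + interval_start, t0 + interval_end].

```python
import pandas as pd

interval_start = pd.Timedelta(minutes=-60)  # 60 min of history (negative)
interval_end = pd.Timedelta(minutes=480)    # 8 h forecast horizon

# Hypothetical sample time axis covering history and forecast
t0_true = pd.Timestamp("2025-02-26 12:00")
times = pd.date_range(
    t0_true + interval_start, t0_true + interval_end, freq="30min"
)

# Route 1: last time minus interval_end
t0_from_last = times[-1] - interval_end
# Route 2: first time minus interval_start (subtracting a negative adds)
t0_from_first = times[0] - interval_start
```

Either route works; using the last timestamp is slightly safer if the sample's history could ever be truncated at the front.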

Author @zaryab-ali (Feb 26, 2025):
let me know if the following logic is ok:

```python
last_time = pd.Timestamp(sample["site_target_time_utc"][-1])
t0 = last_time - pd.Timedelta(interval_end)

valid_times = pd.date_range(
    start=t0 + pd.Timedelta(interval_start),
    end=t0 + pd.Timedelta(interval_end),
    freq=f"{time_resolution.astype(int)}min",
)
```

Contributor:

Sorry, I should've explained better how time windows in our samples work. When a forecast is made at time t0 (which we also sometimes refer to as the init time), it creates a forecast for x minutes forward, which is what interval_end governs (in that case, interval_end = x). To do that, the model also needs information on what the generation was just before, so we supply it with y minutes of history, which is what interval_start governs (interval_start = -y in this example). So any sample contains y + x minutes of generation data, some for history and some to check the forecast against at training time. When running inference, we are only interested in the [t0, t0 + interval_end] period, so these are the dates we need to extract from the sample.
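A sketch of the window described above: recover t0 from the end of the sample's time axis, then build only the forecast-side valid times, leaving the history out. The timestamps and interval values are hypothetical, and `last_time` stands in for `sample["site_target_time_utc"][-1]`.

```python
import pandas as pd

interval_start = pd.Timedelta(minutes=-60)  # y = 60 min of history
interval_end = pd.Timedelta(minutes=480)    # x = 480 min of forecast
resolution = pd.Timedelta(minutes=30)

# Stand-in for sample["site_target_time_utc"][-1]
last_time = pd.Timestamp("2025-02-26 20:00")
t0 = last_time - interval_end  # 2025-02-26 12:00

# Forecast steps strictly after t0, up to and including t0 + interval_end;
# starting at t0 + interval_start would wrongly include the history window
valid_times = pd.date_range(
    start=t0 + resolution, end=t0 + interval_end, freq=resolution
)
```

With these numbers the range runs from 12:30 to 20:00 in 30-minute steps, i.e. x / resolution = 16 forecast timestamps and none of the 60 minutes of history.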

)

# Get capacity for this site
@@ -280,7 +288,7 @@ def main(config: DictConfig):
model = model.eval().to(device)

# Create object to make predictions
-    model_pipe = ModelPipe(model, ds_site)
+    model_pipe = ModelPipe(model, ds_site, config.datamodule.configuration)

# Loop through the samples
pbar = tqdm(total=len(dataset))