replaced datapipe with ocf-data-sampler #323
base: main
Changes from 4 commits
```diff
@@ -35,6 +35,7 @@
 import xarray as xr
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
+from ocf_data_sampler.config import load_yaml_configuration
 from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
@@ -161,7 +162,7 @@ def get_sites_ds(config_path: str) -> xr.Dataset:
 class ModelPipe:
     """A class to conveniently make and process predictions from batches"""

-    def __init__(self, model, ds_site: xr.Dataset):
+    def __init__(self, model, ds_site: xr.Dataset, config_path: str):
         """A class to conveniently make and process predictions from batches

         Args:
@@ -170,6 +171,7 @@ def __init__(self, model, ds_site: xr.Dataset):
         """
         self.model = model
         self.ds_site = ds_site
+        self.config_path = config_path

     def predict_batch(self, sample: dict) -> xr.Dataset:
         """Run the sample through the model and compile the predictions into an xarray DataArray
@@ -183,14 +185,20 @@ def predict_batch(self, sample: dict) -> xr.Dataset:
         # Convert sample to tensor and move to device
         sample_tensor = {k: torch.from_numpy(v).to(device) for k, v in sample.items()}

+        config = load_yaml_configuration(self.config_path)
+
+        interval_start = np.timedelta64(config.input_data.site.interval_start_minutes, "m")
+        interval_end = np.timedelta64(config.input_data.site.interval_end_minutes, "m")
+        time_resolution = np.timedelta64(config.input_data.site.time_resolution_minutes, "m")
+
         t0 = pd.Timestamp(sample["site_init_time_utc"][0])
         site_id = sample["site_id"][0]

         # Get valid times for this forecast
         valid_times = pd.date_range(
-            start=t0 + pd.Timedelta(minutes=FREQ_MINS),
-            periods=len(sample["site_target_time_utc"]),
-            freq=f"{FREQ_MINS}min",
+            start=t0 + pd.Timedelta(interval_start),
+            end=t0 + pd.Timedelta(interval_end),
+            freq=f"{time_resolution.astype(int)}min",
         )

         # Get capacity for this site
```

Review thread on the `valid_times` change:

Reviewer: Lifting these from the config is a good way to do it, but the sample that `ModelPipe` gets already has history, so the first time in it is not t0, it's t0 + interval_start. You can get t0 from the last time minus interval_end, for example, or from the first time minus interval_start (you can see how that happens in this function).

Author: Let me know if the following logic is OK: `valid_times = pd.date_range(` …

Reviewer: Sorry, I should have explained better how the time windows in our samples work. When a forecast is made at time t0 (which we also sometimes refer to as the init time), it will create a forecast for x minutes forward, which is what …
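A minimal sketch of the t0 recovery the reviewer describes above, not the merged change. It assumes, per the reviewer's description, that `sample["site_target_time_utc"]` holds the sample's full history-plus-forecast time window, and that `interval_start` / `interval_end` are the `np.timedelta64` values read from the config in the diff; the helper name `recover_t0` is made up for illustration.

```python
import numpy as np
import pandas as pd


def recover_t0(
    sample: dict,
    interval_start: np.timedelta64,
    interval_end: np.timedelta64,
) -> pd.Timestamp:
    """Recover the forecast init time t0 from a sample that already includes history."""
    times = sample["site_target_time_utc"]

    # The first time in the sample is t0 + interval_start (history is included,
    # so interval_start is typically negative); t0 therefore has to be derived
    # rather than read off the first element.
    t0 = pd.Timestamp(times[-1]) - pd.Timedelta(interval_end)

    # Equivalently, from the other end of the window:
    assert t0 == pd.Timestamp(times[0]) - pd.Timedelta(interval_start)
    return t0
```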
```diff
@@ -280,7 +288,7 @@ def main(config: DictConfig):
     model = model.eval().to(device)

     # Create object to make predictions
-    model_pipe = ModelPipe(model, ds_site)
+    model_pipe = ModelPipe(model, ds_site, config.datamodule.configuration)

     # Loop through the samples
     pbar = tqdm(total=len(dataset))
```
Also I would probably unpack the config in the init or maybe even move the whole thing to main and pass extracted interval_end or whatever you end up using to avoid doing it again and again for every batch