Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add collate function + test #101

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions ocf_data_sampler/numpy_batch/collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from ocf_data_sampler.numpy_batch import NWPBatchKey

import numpy as np
import logging
from typing import Union

logger = logging.getLogger(__name__)


def stack_np_examples_into_batch(dict_list):
"""
Stacks Numpy examples into a batch

See also: `unstack_np_batch_into_examples()` for opposite

Args:
dict_list: A list of dict-like Numpy examples to stack

Returns:
The stacked NumpyBatch object
"""
batch = {}

batch_keys = list(dict_list[0].keys())

for batch_key in batch_keys:
# NWP is nested so treat separately
if batch_key == "nwp":
nwp_batch: dict[str, NWPBatchKey] = {}

# Unpack source keys
nwp_sources = list(dict_list[0]["nwp"].keys())

for nwp_source in nwp_sources:
# Keys can be different for different NWPs
nwp_batch_keys = list(dict_list[0]["nwp"][nwp_source].keys())

nwp_source_batch = {}
for nwp_batch_key in nwp_batch_keys:
nwp_source_batch[nwp_batch_key] = stack_data_list(
[d["nwp"][nwp_source][nwp_batch_key] for d in dict_list],
nwp_batch_key,
)

nwp_batch[nwp_source] = nwp_source_batch

batch["nwp"] = nwp_batch

else:
batch[batch_key] = stack_data_list(
[d[batch_key] for d in dict_list],
batch_key,
)

return batch


def _key_is_constant(batch_key):
is_constant = batch_key.endswith("t0_idx") or batch_key == NWPBatchKey.channel_names
return is_constant


def stack_data_list(
data_list: list,
batch_key: Union[str, NWPBatchKey],
):
"""How to combine data entries for each key
"""
if _key_is_constant(batch_key):
# These are always the same for all examples.
return data_list[0]
try:
return np.stack(data_list)
except Exception as e:
logger.debug(f"Could not stack the following shapes together, ({batch_key})")
shapes = [example.shape for example in data_list]
logger.debug(shapes)
logger.error(e)
raise e
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tempfile

from ocf_data_sampler.config.model import Site
from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration

_top_test_directory = os.path.dirname(os.path.realpath(__file__))

Expand Down Expand Up @@ -269,3 +270,18 @@ def uk_gsp_zarr_path(ds_uk_gsp):
ds_uk_gsp.to_zarr(filename)
yield filename


@pytest.fixture()
def pvnet_config_filename(
tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_zarr_path, sat_zarr_path
):

# adjust config to point to the zarr file
config = load_yaml_configuration(config_filename)
config.input_data.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
config.input_data.satellite.zarr_path = sat_zarr_path
config.input_data.gsp.zarr_path = uk_gsp_zarr_path

filename = f"{tmp_path}/configuration.yaml"
save_yaml_configuration(config, filename)
return filename
26 changes: 26 additions & 0 deletions tests/numpy_batch/test_collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from ocf_data_sampler.numpy_batch import GSPBatchKey, SatelliteBatchKey
from ocf_data_sampler.numpy_batch.collate import stack_np_examples_into_batch
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset


def test_pvnet(pvnet_config_filename):

# Create dataset object
dataset = PVNetUKRegionalDataset(pvnet_config_filename)

assert len(dataset.locations) == 317
assert len(dataset.valid_t0_times) == 39
assert len(dataset) == 317 * 39

# Generate 2 samples
sample1 = dataset[0]
sample2 = dataset[1]

batch = stack_np_examples_into_batch([sample1, sample2])

assert isinstance(batch, dict)
assert "nwp" in batch
assert isinstance(batch["nwp"], dict)
assert "ukv" in batch["nwp"]
assert GSPBatchKey.gsp in batch
assert SatelliteBatchKey.satellite_actual in batch
13 changes: 0 additions & 13 deletions tests/torch_datasets/test_pvnet_uk_regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,6 @@
from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey


@pytest.fixture()
def pvnet_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_zarr_path, sat_zarr_path):

# adjust config to point to the zarr file
config = load_yaml_configuration(config_filename)
config.input_data.nwp['ukv'].zarr_path = nwp_ukv_zarr_path
config.input_data.satellite.zarr_path = sat_zarr_path
config.input_data.gsp.zarr_path = uk_gsp_zarr_path

filename = f"{tmp_path}/configuration.yaml"
save_yaml_configuration(config, filename)
return filename


def test_pvnet(pvnet_config_filename):

Expand Down
Loading