Skip to content

Commit

Permalink
Merge pull request #101 from openclimatefix/issue/100-colatte-function
Browse files Browse the repository at this point in the history
add collate function + test
  • Loading branch information
peterdudfield authored Dec 20, 2024
2 parents aae2af3 + 51d714e commit 7d67873
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 13 deletions.
79 changes: 79 additions & 0 deletions ocf_data_sampler/numpy_batch/collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from ocf_data_sampler.numpy_batch import NWPBatchKey

import numpy as np
import logging
from typing import Union

logger = logging.getLogger(__name__)


def stack_np_examples_into_batch(dict_list):
    """
    Stacks Numpy examples into a batch

    The set of keys is taken from the first example in the list; every
    example is then indexed with those keys.

    Args:
        dict_list: A list of dict-like Numpy examples to stack

    Returns:
        A dict with the same keys as the first example. The "nwp" entry is a
        nested dict (source -> key -> combined value); every other entry is
        the combined value produced by `stack_data_list`.
    """
    first_example = dict_list[0]
    stacked = {}

    for key in list(first_example.keys()):
        if key == "nwp":
            # NWP is nested two levels deep (source, then key), and each
            # source may carry its own set of keys, so walk it explicitly.
            reference_nwp = first_example["nwp"]
            stacked["nwp"] = {
                source: {
                    nwp_key: stack_data_list(
                        [example["nwp"][source][nwp_key] for example in dict_list],
                        nwp_key,
                    )
                    for nwp_key in list(reference_nwp[source].keys())
                }
                for source in list(reference_nwp.keys())
            }
        else:
            per_example_values = [example[key] for example in dict_list]
            stacked[key] = stack_data_list(per_example_values, key)

    return stacked


def _key_is_constant(batch_key):
is_constant = batch_key.endswith("t0_idx") or batch_key == NWPBatchKey.channel_names
return is_constant


def stack_data_list(
    data_list: list,
    batch_key: Union[str, NWPBatchKey],
):
    """How to combine data entries for each key

    Args:
        data_list: The values stored under `batch_key`, one per example.
        batch_key: The key the values came from; decides how they are combined.

    Returns:
        The first element of `data_list` when `_key_is_constant(batch_key)`
        is true, otherwise the result of `np.stack(data_list)`.

    Raises:
        Exception: Whatever `np.stack` raised, re-raised unchanged after the
            key, the entry shapes and the error have been logged.
    """
    if _key_is_constant(batch_key):
        # These are always the same for all examples.
        return data_list[0]
    try:
        return np.stack(data_list)
    except Exception as e:
        logger.debug(f"Could not stack the following shapes together, ({batch_key})")
        # Entries are not guaranteed to be arrays (np.stack also accepts lists
        # and scalars). Reading `.shape` directly could raise AttributeError
        # here and hide the real stacking error, so fall back to None.
        shapes = [getattr(example, "shape", None) for example in data_list]
        logger.debug(shapes)
        logger.error(e)
        # Bare raise keeps the original exception and traceback intact.
        raise
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tempfile

from ocf_data_sampler.config.model import Site
from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration

_top_test_directory = os.path.dirname(os.path.realpath(__file__))

Expand Down Expand Up @@ -269,3 +270,18 @@ def uk_gsp_zarr_path(ds_uk_gsp):
ds_uk_gsp.to_zarr(filename)
yield filename


@pytest.fixture()
def pvnet_config_filename(
    tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_zarr_path, sat_zarr_path
):
    """Write a copy of the base configuration that points at the test zarr stores.

    Loads the configuration from `config_filename`, replaces the "ukv" NWP,
    satellite and GSP zarr paths with the fixture-provided ones, saves the
    result under `tmp_path` and returns the saved file's path as a string.
    """
    test_config = load_yaml_configuration(config_filename)

    # Redirect every input source to its fixture-provided zarr path
    inputs = test_config.input_data
    inputs.nwp["ukv"].zarr_path = nwp_ukv_zarr_path
    inputs.satellite.zarr_path = sat_zarr_path
    inputs.gsp.zarr_path = uk_gsp_zarr_path

    saved_path = f"{tmp_path}/configuration.yaml"
    save_yaml_configuration(test_config, saved_path)
    return saved_path
26 changes: 26 additions & 0 deletions tests/numpy_batch/test_collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from ocf_data_sampler.numpy_batch import GSPBatchKey, SatelliteBatchKey
from ocf_data_sampler.numpy_batch.collate import stack_np_examples_into_batch
from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset


def test_pvnet(pvnet_config_filename):
    """Collate two dataset samples and check the resulting batch structure."""

    # Build the dataset from the fixture-written configuration
    pvnet_dataset = PVNetUKRegionalDataset(pvnet_config_filename)

    # Dataset length is expected to be locations x valid t0 times
    assert len(pvnet_dataset.locations) == 317
    assert len(pvnet_dataset.valid_t0_times) == 39
    assert len(pvnet_dataset) == 317 * 39

    # Draw the first two samples and collate them
    first_sample = pvnet_dataset[0]
    second_sample = pvnet_dataset[1]
    collated = stack_np_examples_into_batch([first_sample, second_sample])

    # Top-level container and the nested NWP entry
    assert isinstance(collated, dict)
    assert "nwp" in collated
    nwp_entry = collated["nwp"]
    assert isinstance(nwp_entry, dict)
    assert "ukv" in nwp_entry

    # Non-nested sources are present as well
    assert GSPBatchKey.gsp in collated
    assert SatelliteBatchKey.satellite_actual in collated
13 changes: 0 additions & 13 deletions tests/torch_datasets/test_pvnet_uk_regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,6 @@
from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey


@pytest.fixture()
def pvnet_config_filename(tmp_path, config_filename, nwp_ukv_zarr_path, uk_gsp_zarr_path, sat_zarr_path):
    """Save a configuration whose zarr paths point at the test fixtures.

    Returns the path of the saved YAML file (inside `tmp_path`) as a string.
    """
    loaded = load_yaml_configuration(config_filename)

    # Swap in the fixture-provided zarr locations for each input source
    input_data = loaded.input_data
    input_data.nwp['ukv'].zarr_path = nwp_ukv_zarr_path
    input_data.satellite.zarr_path = sat_zarr_path
    input_data.gsp.zarr_path = uk_gsp_zarr_path

    target = f"{tmp_path}/configuration.yaml"
    save_yaml_configuration(loaded, target)
    return target


def test_pvnet(pvnet_config_filename):

Expand Down

0 comments on commit 7d67873

Please sign in to comment.