diff --git a/modulus/datapipes/healpix/couplers.py b/modulus/datapipes/healpix/couplers.py index f73d09776a..42afb162ee 100644 --- a/modulus/datapipes/healpix/couplers.py +++ b/modulus/datapipes/healpix/couplers.py @@ -175,12 +175,12 @@ def set_coupled_fields(self, coupled_fields): :, :, :, self.coupled_channel_indices, :, : ].permute(0, 2, 3, 1, 4, 5) self.preset_coupled_fields = th.empty( - [self.coupled_integration_dim, self.batch_size, self.timevar_dim] + [self.coupled_integration_dim, coupled_fields.shape[0], self.timevar_dim] + list(self.spatial_dims) ) for i in range(len(self.preset_coupled_fields)): self.preset_coupled_fields[i, :, :, :, :, :] = coupled_fields[ - 0, 0, -1, :, :, : + :, -1, :, :, :, : ] # flag for construct integrated coupling method to use this array self.coupled_mode = True diff --git a/test/datapipes/test_healpix_couple.py b/test/datapipes/test_healpix_couple.py index 7ee24aacc8..55582c6ebe 100644 --- a/test/datapipes/test_healpix_couple.py +++ b/test/datapipes/test_healpix_couple.py @@ -22,6 +22,7 @@ import numpy as np import pandas as pd import pytest +import torch as th import xarray as xr from omegaconf import DictConfig, OmegaConf from pytest_utils import nfsdata_or_fail @@ -142,6 +143,25 @@ def test_ConstantCoupler(data_dir, dataset_name, scaling_dict, pytestconfig): expected = np.expand_dims(coupled_scaling["std"].to_numpy(), (0, 2, 3, 4)) assert np.array_equal(expected, coupler.coupled_scaling["std"]) + coupler.coupled_channel_indices = [0, 1] + coupled_fields_batch_size = 4 + coupled_fields_timedim = 2 + coupled_fields = th.rand( + coupled_fields_batch_size, + coupler.spatial_dims[0], + coupled_fields_timedim, + len(coupler.coupled_channel_indices), + coupler.spatial_dims[1], + coupler.spatial_dims[2], + ) + expected_shape = [ + coupler.coupled_integration_dim, + coupled_fields_batch_size, + coupler.timevar_dim, + ] + list(coupler.spatial_dims) + coupler.set_coupled_fields(coupled_fields) + assert list(coupler.preset_coupled_fields.shape) == expected_shape + DistributedManager.cleanup() @@ -200,6 +220,40 @@ def test_TrailingAverageCoupler(data_dir, dataset_name, scaling_dict, pytestconf expected = np.expand_dims(coupled_scaling["std"].to_numpy(), (0, 2, 3, 4)) assert np.array_equal(expected, coupler.coupled_scaling["std"]) + averaging_window_max_indices = [ + i // pd.Timedelta(data_time_step) for i in coupler.input_times + ] + di = averaging_window_max_indices[0] + averaging_slices = [] + for j in range(coupler.coupled_integration_dim): + averaging_slices.append([]) + for i, r in enumerate(averaging_window_max_indices): + averaging_slices[j].append( + slice( + coupler.input_time_dim * j * di + i * di, + coupler.input_time_dim * j * di + r, + ) + ) + coupler.averaging_slices = averaging_slices + coupler.coupled_channel_indices = [0, 1] + coupled_fields_batch_size = 4 + coupled_fields_timedim = 4 + coupled_fields = th.rand( + coupled_fields_batch_size, + coupler.spatial_dims[0], + coupled_fields_timedim, + len(coupler.coupled_channel_indices), + coupler.spatial_dims[1], + coupler.spatial_dims[2], + ) + expected_shape = [ + coupler.coupled_integration_dim, + coupled_fields_batch_size, + coupler.timevar_dim, + ] + list(coupler.spatial_dims) + coupler.set_coupled_fields(coupled_fields) + assert list(coupler.preset_coupled_fields.shape) == expected_shape + DistributedManager.cleanup()