[Regridder] Expand dims move to smmregrid#2690

Open
oloapinivad wants to merge 2 commits into main from remove-expand-dims

Conversation

@oloapinivad
Collaborator

PR description:

Cleaning up the regridder functionalities to align with the upcoming update in smmregrid (jhardenberg/smmregrid#59).
Specifically, this removes the expand_dims feature.
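
For context, `expand_dims` inserts a new size-1 dimension into an array; this is the behaviour being moved out of the AQUA regridder. A minimal illustration using the numpy analogue of xarray's `DataArray.expand_dims`, not the actual AQUA code:

```python
import numpy as np

# A 2D field without an explicit vertical dimension
field = np.zeros((3, 4))

# expand_dims inserts a new size-1 axis (e.g. a dummy vertical level),
# mirroring what xarray's DataArray.expand_dims does for named dims
expanded = np.expand_dims(field, axis=0)
print(expanded.shape)  # (1, 3, 4)
```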


  • Docstrings are updated if needed.
  • Changelog is updated.

@oloapinivad oloapinivad added the run tests Set this up to let test run label Feb 16, 2026
@codecov

codecov bot commented Feb 16, 2026

❌ 5 Tests Failed:

Tests completed: 454 | Failed: 5 | Passed: 449 | Skipped: 1
View the top 2 failed test(s) by shortest run time
tests/test_regrid.py::TestRegridder::test_basic_interpolation[reader_arguments6]
Stack Traces | 1.21s run time
self = <xarray.Dataset> Size: 15MB
Dimensions:              (src_grid_rank: 1, dst_grid_rank: 2,
                          sr...con,r180x90 /tmp/tmpn8u27fz1 /tmp/t...
    CDO:            Climate Data Operators version 2.4.4 (https://mpimet.mpg....
name = None

    def _construct_dataarray(self, name: Hashable) -> DataArray:
        """Construct a DataArray by indexing this dataset"""
        from xarray.core.dataarray import DataArray
    
        try:
>           variable = self._variables[name]
                       ^^^^^^^^^^^^^^^^^^^^^
E           KeyError: None

.../aqua/lib/python3.12.../xarray/core/dataset.py:1237: KeyError

During handling of the above exception, another exception occurred:

self = <xarray.Dataset> Size: 15MB
Dimensions:              (src_grid_rank: 1, dst_grid_rank: 2,
                          sr...con,r180x90 /tmp/tmpn8u27fz1 /tmp/t...
    CDO:            Climate Data Operators version 2.4.4 (https://mpimet.mpg....
key = None

    def __getitem__(
        self, key: Mapping[Any, Any] | Hashable | Iterable[Hashable]
    ) -> Self | DataArray:
        """Access variables or coordinates of this dataset as a
        :py:class:`~xarray.DataArray` or a subset of variables or a indexed dataset.
    
        Indexing with a list of names will return a new ``Dataset`` object.
        """
        from xarray.core.formatting import shorten_list_repr
    
        if utils.is_dict_like(key):
            return self.isel(**key)
        if utils.hashable(key):
            try:
>               return self._construct_dataarray(key)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.../aqua/lib/python3.12.../xarray/core/dataset.py:1344: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.../aqua/lib/python3.12.../xarray/core/dataset.py:1239: in _construct_dataarray
    _, name, variable = _get_virtual_variable(self._variables, name, self.sizes)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

variables = {'dst_address': <xarray.Variable (num_links: 107445)> Size: 430kB
array([  983,   984,   984, ..., 16200, 16200, 16200....Variable (dst_grid_size: 16200)> Size: 130kB
[16200 values with dtype=float64]
Attributes:
    units:    radians, ...}
key = None
dim_sizes = Frozen({'src_grid_rank': 1, 'dst_grid_rank': 2, 'src_grid_size': 120184, 'dst_grid_size': 16200, 'src_grid_corners': 4, 'dst_grid_corners': 4, 'num_links': 107445, 'num_wgts': 1})

    def _get_virtual_variable(
        variables, key: Hashable, dim_sizes: Mapping | None = None
    ) -> tuple[Hashable, Hashable, Variable]:
        """Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable
        objects (if possible)
    
        """
        from xarray.core.dataarray import DataArray
    
        if dim_sizes is None:
            dim_sizes = {}
    
        if key in dim_sizes:
            data = pd.Index(range(dim_sizes[key]), name=key)
            variable = IndexVariable((key,), data)
            return key, key, variable
    
        if not isinstance(key, str):
>           raise KeyError(key)
E           KeyError: None

.../aqua/lib/python3.12.../xarray/core/dataset_utils.py:75: KeyError

The above exception was the direct cause of the following exception:

self = <test_regrid.TestRegridder object at 0x7f5b6a7440b0>
reader_arguments = ('NEMO', 'test-eORCA1', 'long-2d', 'tos', 0.3379)

    def test_basic_interpolation(self, reader_arguments):
        """
        Test basic interpolation,
        checking output grid dimension and
        fraction of land (i.e. any missing points)
        """
        model, exp, source, variable, ratio = reader_arguments
    
        reader = Reader(model=model, exp=exp, source=source, regrid="r200",
                        fix=True, loglevel=LOGLEVEL)
        data = reader.retrieve()
>       rgd = reader.regrid(data[variable])
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

tests/test_regrid.py:130: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.../core/reader/reader.py:583: in regrid
    out = self.regridder.regrid(data)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../core/regridder/regridder.py:586: in regrid
    data = self._apply_regrid(data, shared_vars)
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../core/regridder/regridder.py:621: in _apply_regrid
    data = self.smmregridder[vertical].regrid(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../aqua/lib/python3.12........./site-packages/smmregrid/regrid.py:269: in regrid
    return self.regrid_array(source_data)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../aqua/lib/python3.12........./site-packages/smmregrid/regrid.py:310: in regrid_array
    return self.regrid3d(source_data, datagrids[0]).squeeze()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../aqua/lib/python3.12........./site-packages/smmregrid/regrid.py:391: in regrid3d
    mask_index = weights.coords[mask_dim].to_index()
                 ^^^^^^^^^^^^^^^^^^^^^^^^
.../aqua/lib/python3.12.../xarray/core/coordinates.py:927: in __getitem__
    return self._data[key]
           ^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <xarray.Dataset> Size: 15MB
Dimensions:              (src_grid_rank: 1, dst_grid_rank: 2,
                          sr...con,r180x90 /tmp/tmpn8u27fz1 /tmp/t...
    CDO:            Climate Data Operators version 2.4.4 (https://mpimet.mpg....
key = None

    def __getitem__(
        self, key: Mapping[Any, Any] | Hashable | Iterable[Hashable]
    ) -> Self | DataArray:
        """Access variables or coordinates of this dataset as a
        :py:class:`~xarray.DataArray` or a subset of variables or a indexed dataset.
    
        Indexing with a list of names will return a new ``Dataset`` object.
        """
        from xarray.core.formatting import shorten_list_repr
    
        if utils.is_dict_like(key):
            return self.isel(**key)
        if utils.hashable(key):
            try:
                return self._construct_dataarray(key)
            except KeyError as e:
                message = f"No variable named {key!r}."
    
                best_guess = utils.did_you_mean(key, self.variables.keys())
                if best_guess:
                    message += f" {best_guess}"
                else:
                    message += f" Variables on the dataset include {shorten_list_repr(list(self.variables.keys()), max_items=10)}"
    
                # If someone attempts `ds['foo' , 'bar']` instead of `ds[['foo', 'bar']]`
                if isinstance(key, tuple):
                    message += f"\nHint: use a list to select multiple variables, for example `ds[{list(key)}]`"
>               raise KeyError(message) from e
E               KeyError: "No variable named None. Variables on the dataset include ['src_grid_dims', 'dst_grid_dims', 'src_grid_center_lat', 'dst_grid_center_lat', 'src_grid_center_lon', ..., 'dst_grid_frac', 'src_address', 'dst_address', 'remap_matrix', 'dst_grid_masked']"

.../aqua/lib/python3.12.../xarray/core/dataset.py:1357: KeyError
tests/test_regrid.py::TestRegridder::test_levels_and_regrid
Stack Traces | 9.55s run time
self = <test_regrid.TestRegridder object at 0x7f89603b5ac0>

    def test_levels_and_regrid(self):
        """
        Test regridding selected levels.
        """
        reader = Reader(model='FESOM', exp='test-pi', source='original_3d', datamodel=False,
                        regrid='r100', loglevel=LOGLEVEL, rebuild=True)
        data = reader.retrieve()
    
        layers = [0, 2]
        val = data.aqua.regrid().isel(time=1, nz=2, nz1=layers).wo.aqua.fldmean().values
        #assert val == pytest.approx(8.6758228e-08) #smmregrid <= v0.1.3
        assert val == pytest.approx(7.00622013e-08, rel=APPROX_REL)
>       val = data.isel(time=1, nz=2, nz1=layers).aqua.regrid().wo.aqua.fldmean().values
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

tests/test_regrid.py:231: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
aqua/core/accessor.py:40: in regrid
    return self.instance.regrid(self._obj, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../core/reader/reader.py:583: in regrid
    out = self.regridder.regrid(data)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../core/regridder/regridder.py:586: in regrid
    data = self._apply_regrid(data, shared_vars)
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../core/regridder/regridder.py:611: in _apply_regrid
    datar.append(self.smmregridder[vertical].regrid(data[existing_vars]))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../aqua/lib/python3.12....../site-packages/smmregrid/regrid.py:262: in regrid
    out = source_data.map(self.regrid_array, keep_attrs=False)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../aqua/lib/python3.12.../xarray/core/dataset.py:6953: in map
    k: maybe_wrap_array(v, func(v, *args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^
.../aqua/lib/python3.12....../site-packages/smmregrid/regrid.py:295: in regrid_array
    datagrids = grid_inspect.get_gridtype()
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../aqua/lib/python3.12........./site-packages/smmregrid/gridinspector.py:106: in get_gridtype
    self._inspect_grids()
.../aqua/lib/python3.12........./site-packages/smmregrid/gridinspector.py:63: in _inspect_grids
    self._inspect_dataarray_grid(self.data)
.../aqua/lib/python3.12........./site-packages/smmregrid/gridinspector.py:78: in _inspect_dataarray_grid
    gridtype = GridType(dims=grid_key, extra_dims=self.extra_dims)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../aqua/lib/python3.12....../site-packages/smmregrid/gridtype.py:52: in __init__
    self.mask_dim = self._identify_dims('mask', dims, default_dims)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = GridType(horizontal_dims=['x']), axis = 'mask'
dims = ('nz', 'time', 'nz1', 'x')
default_dims = {'horizontal': ['j', 'longitude', 'rgrid', 'values', 'lat', 'y', ...], 'mask': ['nz', 'lev', 'depth', 'depth_full', 'nz1', 'depth_half'], 'time': ['time', 'time_counter']}

    def _identify_dims(self, axis, dims, default_dims):
        """
        Identify dimensions along a specified axis.
    
        Args:
            axis (str): The axis to check ('horizontal', 'mask', or 'time').
            default_dims (dict): The dictionary of default dimensions to check against.
    
        Returns:
            list or str: A list of identified dimensions or a single identified masked dimension.
                          Returns None if no dimensions are identified.
    
        Raises:
            ValueError: If more than one masked dimension is identified.
        """
    
        # Check if the axis is valid
        if axis not in ['horizontal', 'mask', 'time']:
            raise ValueError(f"Invalid axis '{axis}'. Must be one of 'horizontal', 'mask', or 'time'.")
    
        # Check if the axis is in the default dimensions
        if axis not in default_dims:
            return None
    
        # Identify dimensions based on the provided axis
        identified_dims = list(set(dims).intersection(default_dims[axis]))
        if axis == 'mask':
            if len(identified_dims) > 1:
>               raise ValueError(f'Only one masked dimension can be processed at the time: check {identified_dims}')
E               ValueError: Only one masked dimension can be processed at the time: check ['nz1', 'nz']

.../aqua/lib/python3.12....../site-packages/smmregrid/gridtype.py:158: ValueError
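
The root of this failure is visible in the quoted `_identify_dims` code: both `nz` and `nz1` survive the intersection with the default mask dimensions, so smmregrid cannot pick a single vertical axis. A standalone sketch of that intersection logic (names copied from the traceback, not the library itself):

```python
# Minimal reproduction of the mask-dimension check from the traceback above:
# two candidate vertical ("mask") dimensions survive the intersection,
# which is what makes smmregrid's GridType raise the ValueError.
default_mask_dims = ['nz', 'lev', 'depth', 'depth_full', 'nz1', 'depth_half']
dims = ('nz', 'time', 'nz1', 'x')  # dims of the FESOM test data

identified = sorted(set(dims).intersection(default_mask_dims))
print(identified)  # ['nz', 'nz1'] -> more than one, so GridType raises
```

With AQUA's `expand_dims` handling removed, the FESOM data arrives at smmregrid still carrying both `nz` and `nz1`, and the check above has no way to choose between them.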
View the full list of 3 ❄️ flaky test(s)
tests/test_drop.py::TestDROP::test_definitive_true[drop_arguments0-1]@dask_operations

Flake rate in main: 5.00% (Passed 19 times, Failed 1 times)

Stack Traces | 1.99s run time
self = <test_drop.TestDROP object at 0x7ff36bf097c0>
drop_arguments = {'exp': 'test-tco79', 'model': 'IFS', 'outdir': 'drop_test', 'source': 'long', ...}
tmp_path = PosixPath('.../pytest-0/popen-gw1/test_definitive_true_drop_argu0')
nworkers = 1

    @pytest.mark.parametrize("nworkers", [1, 2])
    def test_definitive_true(self, drop_arguments, tmp_path, nworkers):
        test = Drop(
            catalog='ci', **drop_arguments, tmpdir=str(tmp_path),
            nproc=nworkers, resolution='r100', frequency='monthly',
            definitive=True, loglevel=LOGLEVEL
        )
    
        test.retrieve()
        test.data = test.data.sel(time="2020-01")
>       test.drop_generator()

tests/test_drop.py:126: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.../core/drop/drop.py:349: in drop_generator
    self._write_var(self.var)
.../core/drop/drop.py:601: in _write_var
    self._write_var_catalog(var)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <aqua.core.drop.drop.Drop object at 0x7ff36c136f30>, var = '2t'

    def _write_var_catalog(self, var):
        """
        Write variable to file
    
        Args:
            var (str): variable name
        """
    
        self.logger.info('Processing variable %s...', var)
        temp_data = self.data[var]
    
        if self.frequency:
            temp_data = self.reader.timstat(temp_data, self.stat, freq=self.frequency,
                                            exclude_incomplete=self.exclude_incomplete)
    
        # temp_data could be empty after time statistics if everything was excluded
        if 'time' in temp_data.coords and len(temp_data.time) == 0:
            self.logger.warning('No data available for variable %s after time statistics, skipping...', var)
            return
    
        # regrid
        if self.resolution and self.resolution != 'native':
            temp_data = self.reader.regrid(temp_data)
            temp_data = self._remove_regridded(temp_data)
    
        if self.region:
            temp_data = self.reader.select_area(temp_data, lon=self.region['lon'], lat=self.region['lat'], drop=self.drop)
    
        # Splitting data into yearly files
>       years = sorted(set(temp_data.time.dt.year.values))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       TypeError: iteration over a 0-d array

.../core/drop/drop.py:644: TypeError
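
The drop failures appear to share one cause: `temp_data.time` ends up as a scalar coordinate (plausibly because the new smmregrid path squeezes size-1 dimensions), so `.dt.year.values` is a 0-d numpy array, which `set()` cannot iterate over. A minimal reproduction in plain numpy, not the AQUA code:

```python
import numpy as np

# A 0-d array, as produced by a scalar time coordinate after a squeeze
years_0d = np.array(2020)
try:
    sorted(set(years_0d))
except TypeError as err:
    print(err)  # iteration over a 0-d array

# np.atleast_1d is one way to make such code robust to scalar times
years = np.atleast_1d(years_0d)
print(sorted(set(years.tolist())))  # [2020]
```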
tests/test_drop.py::TestDROP::test_definitive_true[drop_arguments0-2]@dask_operations

Flake rate in main: 5.00% (Passed 19 times, Failed 1 times)

Stack Traces | 3.01s run time
self = <test_drop.TestDROP object at 0x7ff36bf0bf80>
drop_arguments = {'exp': 'test-tco79', 'model': 'IFS', 'outdir': 'drop_test', 'source': 'long', ...}
tmp_path = PosixPath('.../pytest-0/popen-gw1/test_definitive_true_drop_argu1')
nworkers = 2

    @pytest.mark.parametrize("nworkers", [1, 2])
    def test_definitive_true(self, drop_arguments, tmp_path, nworkers):
        test = Drop(
            catalog='ci', **drop_arguments, tmpdir=str(tmp_path),
            nproc=nworkers, resolution='r100', frequency='monthly',
            definitive=True, loglevel=LOGLEVEL
        )
    
        test.retrieve()
        test.data = test.data.sel(time="2020-01")
>       test.drop_generator()

tests/test_drop.py:126: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.../core/drop/drop.py:349: in drop_generator
    self._write_var(self.var)
.../core/drop/drop.py:601: in _write_var
    self._write_var_catalog(var)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <aqua.core.drop.drop.Drop object at 0x7ff372ad5010>, var = '2t'

    def _write_var_catalog(self, var):
        """
        Write variable to file
    
        Args:
            var (str): variable name
        """
    
        self.logger.info('Processing variable %s...', var)
        temp_data = self.data[var]
    
        if self.frequency:
            temp_data = self.reader.timstat(temp_data, self.stat, freq=self.frequency,
                                            exclude_incomplete=self.exclude_incomplete)
    
        # temp_data could be empty after time statistics if everything was excluded
        if 'time' in temp_data.coords and len(temp_data.time) == 0:
            self.logger.warning('No data available for variable %s after time statistics, skipping...', var)
            return
    
        # regrid
        if self.resolution and self.resolution != 'native':
            temp_data = self.reader.regrid(temp_data)
            temp_data = self._remove_regridded(temp_data)
    
        if self.region:
            temp_data = self.reader.select_area(temp_data, lon=self.region['lon'], lat=self.region['lat'], drop=self.drop)
    
        # Splitting data into yearly files
>       years = sorted(set(temp_data.time.dt.year.values))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       TypeError: iteration over a 0-d array

.../core/drop/drop.py:644: TypeError
tests/test_drop.py::TestDROP::test_regional_subset[drop_arguments0]@dask_operations

Flake rate in main: 5.00% (Passed 19 times, Failed 1 times)

Stack Traces | 1.89s run time
self = <test_drop.TestDROP object at 0x7ff36bf0bad0>
drop_arguments = {'exp': 'test-tco79', 'model': 'IFS', 'outdir': 'drop_test', 'source': 'long', ...}
tmp_path = PosixPath('.../pytest-0/popen-gw1/test_regional_subset_drop_argu0')

    def test_regional_subset(self, drop_arguments, tmp_path):
        """Test DROP with regional subset."""
        region = {'name': 'europe', 'lon': [-10, 30], 'lat': [35, 70]}
    
        test = Drop(
            catalog='ci', **drop_arguments, tmpdir=str(tmp_path),
            resolution='r100', frequency='daily', definitive=True,
            loglevel=LOGLEVEL, region=region
        )
    
        test.retrieve()
        test.data = test.data.sel(time="2020-01-20")
>       test.drop_generator()

tests/test_drop.py:149: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.../core/drop/drop.py:349: in drop_generator
    self._write_var(self.var)
.../core/drop/drop.py:601: in _write_var
    self._write_var_catalog(var)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <aqua.core.drop.drop.Drop object at 0x7ff3689a6b10>, var = '2t'

    def _write_var_catalog(self, var):
        """
        Write variable to file
    
        Args:
            var (str): variable name
        """
    
        self.logger.info('Processing variable %s...', var)
        temp_data = self.data[var]
    
        if self.frequency:
            temp_data = self.reader.timstat(temp_data, self.stat, freq=self.frequency,
                                            exclude_incomplete=self.exclude_incomplete)
    
        # temp_data could be empty after time statistics if everything was excluded
        if 'time' in temp_data.coords and len(temp_data.time) == 0:
            self.logger.warning('No data available for variable %s after time statistics, skipping...', var)
            return
    
        # regrid
        if self.resolution and self.resolution != 'native':
            temp_data = self.reader.regrid(temp_data)
            temp_data = self._remove_regridded(temp_data)
    
        if self.region:
            temp_data = self.reader.select_area(temp_data, lon=self.region['lon'], lat=self.region['lat'], drop=self.drop)
    
        # Splitting data into yearly files
>       years = sorted(set(temp_data.time.dt.year.values))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       TypeError: iteration over a 0-d array

.../core/drop/drop.py:644: TypeError


@oloapinivad oloapinivad added the on-hold Tasks that are on-hold at the moment label Feb 16, 2026
@oloapinivad
Collaborator Author

This hits a fundamental issue in smmregrid which we might not be able to solve; we might decide to keep the current implementation.

