Skip to content

Commit 935878f

Browse files
authored
Allow reusing weights saved in a pyramid during xESMF regridding (#34)
1 parent 9d21440 commit 935878f

3 files changed

Lines changed: 113 additions & 18 deletions

File tree

ci/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies:
1616
- rasterio
1717
- rioxarray
1818
- scipy
19+
- sparse>=0.13.0
1920
- xarray
2021
- xarray-datatree>=0.0.4
2122
- xesmf

ndpyramid/regrid.py

Lines changed: 87 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations # noqa: F401
22

33
import itertools
4-
import pathlib
54

65
import datatree as dt
76
import numpy as np
@@ -10,6 +9,35 @@
109
from .utils import add_metadata_and_zarr_encoding, get_version, multiscales_template
1110

1211

12+
def xesmf_weights_to_xarray(regridder) -> xr.Dataset:
    """Pack the weights of a regridder into a plain dataset

    Parameters
    ----------
    regridder
        Object exposing ``weights.data`` (with ``data`` and ``coords``
        members), ``n_in`` and ``n_out``. Callers in this module pass an
        ``xesmf.Regridder``; presumably the weights are a sparse COO matrix.

    Returns
    -------
    ds : xr.Dataset
        Variables ``S`` (weight values), ``col`` and ``row`` (indices shifted
        by +1) along the ``n_s`` dimension, with ``n_in`` and ``n_out`` stored
        as attributes.
    """
    sparse_weights = regridder.weights.data
    # first coords row feeds 'row', second feeds 'col'; both are stored
    # shifted by +1 (undone again in `_reconstruct_xesmf_weights`)
    row_index = sparse_weights.coords[0, :] + 1
    col_index = sparse_weights.coords[1, :] + 1
    data_vars = {
        'S': ('n_s', sparse_weights.data),
        'col': ('n_s', col_index),
        'row': ('n_s', row_index),
    }
    return xr.Dataset(
        data_vars,
        attrs={'n_in': regridder.n_in, 'n_out': regridder.n_out},
    )
24+
25+
26+
def _reconstruct_xesmf_weights(ds_w):
    """Reconstruct weights into format that xESMF understands

    Parameters
    ----------
    ds_w : xr.Dataset
        Dataset holding ``col``, ``row`` and ``S`` variables plus ``n_in`` and
        ``n_out`` attributes (the layout written by
        ``xesmf_weights_to_xarray``).

    Returns
    -------
    weights : xr.DataArray
        Array named ``weights`` with dims ``('out_dim', 'in_dim')`` wrapping a
        ``sparse.COO`` matrix of shape ``(n_out, n_in)``.
    """
    import sparse
    import xarray as xr

    # stored indices carry a +1 offset; shift them back before building COO
    col_idx, row_idx = (ds_w[name].values - 1 for name in ('col', 'row'))
    values = ds_w['S'].values
    shape = (ds_w.attrs['n_out'], ds_w.attrs['n_in'])
    coords = np.stack([row_idx, col_idx])
    matrix = sparse.COO(coords, values, shape)
    return xr.DataArray(matrix, dims=('out_dim', 'in_dim'), name='weights')
39+
40+
1341
def make_grid_ds(level: int, pixels_per_tile: int = 128) -> xr.Dataset:
1442
"""Make a dataset representing a target grid
1543
@@ -97,11 +125,52 @@ def make_grid_pyramid(levels: int = 6) -> dt.DataTree:
97125
return data
98126

99127

128+
def generate_weights_pyramid(
    ds_in: xr.Dataset, levels: int, method: str = 'bilinear', regridder_kws: dict = None
) -> dt.DataTree:
    """Build a pyramid of regridding weights, one node per level

    Parameters
    ----------
    ds_in : xr.Dataset
        Input dataset to regrid
    levels : int
        Number of levels in the pyramid
    method : str, optional
        Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear'
    regridder_kws : dict
        Keyword arguments to pass to :py:class:`~xesmf.Regridder`. ``periodic=True``
        is used unless overridden here.

    Returns
    -------
    weights : dt.DataTree
        Multiscale weights. Each level is stored under ``str(level)``; the root
        attributes record ``levels`` and ``regrid_method``.
    """
    import datatree
    import xesmf as xe

    # caller-supplied options take precedence over the `periodic=True` default
    user_kws = regridder_kws if regridder_kws is not None else {}
    merged_kws = {'periodic': True, **user_kws}

    tree = datatree.DataTree()
    for lvl in range(levels):
        target_grid = make_grid_ds(level=lvl)
        level_regridder = xe.Regridder(ds_in, target_grid, method, **merged_kws)
        tree[str(lvl)] = xesmf_weights_to_xarray(level_regridder)

    # record how the weights were produced on the root node
    tree.ds.attrs['levels'] = levels
    tree.ds.attrs['regrid_method'] = method

    return tree
167+
168+
100169
def pyramid_regrid(
101170
ds: xr.Dataset,
102171
target_pyramid: dt.DataTree = None,
103172
levels: int = None,
104-
weights_template: str = None,
173+
weights_pyramid: dt.DataTree = None,
105174
method: str = 'bilinear',
106175
regridder_kws: dict = None,
107176
regridder_apply_kws: dict = None,
@@ -118,8 +187,8 @@ def pyramid_regrid(
118187
Target grids, if not provided, they will be generated, by default None
119188
levels : int, optional
120189
Number of levels in pyramid, by default None
121-
weights_template : str, optional
122-
Filepath to write generated weights to, e.g. `'weights_{level}'`, by default None
190+
weights_pyramid : dt.DataTree, optional
191+
pyramid containing pregenerated weights
123192
method : str, optional
124193
Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear'
125194
regridder_kws : dict
@@ -147,14 +216,15 @@ def pyramid_regrid(
147216
if levels is None:
148217
levels = len(target_pyramid.keys()) # TODO: get levels from the pyramid metadata
149218

150-
if regridder_kws is None:
151-
regridder_kws = {'periodic': True}
219+
regridder_kws = {} if regridder_kws is None else regridder_kws
220+
regridder_kws = {'periodic': True, **regridder_kws}
152221

153222
# multiscales spec
154223
save_kwargs = locals()
155224
del save_kwargs['ds']
156225
del save_kwargs['target_pyramid']
157226
del save_kwargs['xe']
227+
del save_kwargs['weights_pyramid']
158228

159229
attrs = {
160230
'multiscales': multiscales_template(
@@ -173,21 +243,23 @@ def pyramid_regrid(
173243
# pyramid data
174244
for level in range(levels):
175245
grid = target_pyramid[str(level)].ds.load()
176-
177246
# get the regridder object
178-
if not weights_template:
247+
if weights_pyramid is None:
179248
regridder = xe.Regridder(ds, grid, method, **regridder_kws)
180249
else:
181-
fn = pathlib.PosixPath(weights_template.format(level=level))
182-
if not fn.exists():
183-
regridder = xe.Regridder(ds, grid, method, **regridder_kws)
184-
regridder.to_netcdf(filename=fn)
185-
else:
186-
regridder = xe.Regridder(ds, grid, method, weights=fn, **regridder_kws)
187-
250+
# Reconstruct weights into format that xESMF understands
251+
# this is a hack that assumes the weights were generated by
252+
# the `generate_weights_pyramid` function
253+
254+
ds_w = weights_pyramid[str(level)].ds
255+
weights = _reconstruct_xesmf_weights(ds_w)
256+
regridder = xe.Regridder(
257+
ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws
258+
)
188259
# regrid
189260
if regridder_apply_kws is None:
190261
regridder_apply_kws = {}
262+
regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws}
191263
pyramid[str(level)] = regridder(ds, **regridder_apply_kws)
192264

193265
pyramid = add_metadata_and_zarr_encoding(

tests/test_pyramids.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from zarr.storage import MemoryStore
55

66
from ndpyramid import pyramid_coarsen, pyramid_regrid, pyramid_reproject
7-
from ndpyramid.regrid import make_grid_ds
7+
from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds
88

99

1010
@pytest.fixture
@@ -32,7 +32,7 @@ def test_reprojected_pyramid(temperature):
3232
pyramid.to_zarr(MemoryStore())
3333

3434

35-
@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': True}])
35+
@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': False}])
3636
def test_regridded_pyramid(temperature, regridder_apply_kws):
3737
pytest.importorskip('xesmf')
3838
pyramid = pyramid_regrid(
@@ -41,16 +41,38 @@ def test_regridded_pyramid(temperature, regridder_apply_kws):
4141
assert pyramid.ds.attrs['multiscales']
4242
expected_attrs = (
4343
temperature['air'].attrs
44-
if regridder_apply_kws is not None and regridder_apply_kws['keep_attrs']
44+
if not regridder_apply_kws or regridder_apply_kws.get('keep_attrs')
4545
else {}
4646
)
4747
assert pyramid['0'].ds.air.attrs == expected_attrs
4848
assert pyramid['1'].ds.air.attrs == expected_attrs
4949
pyramid.to_zarr(MemoryStore())
5050

5151

52+
def test_regridded_pyramid_with_weights(temperature):
    """Regrid with weights pregenerated by ``generate_weights_pyramid``."""
    pytest.importorskip('xesmf')
    levels = 2
    # weights are generated from a single time step of the fixture
    first_step = temperature.isel(time=0)
    weights_pyramid = generate_weights_pyramid(first_step, levels)
    pyramid = pyramid_regrid(
        temperature,
        levels=levels,
        weights_pyramid=weights_pyramid,
        other_chunks={'time': 2},
    )
    multiscales = pyramid.ds.attrs['multiscales']
    assert multiscales
    assert len(multiscales[0]['datasets']) == levels
    pyramid.to_zarr(MemoryStore())
62+
63+
5264
def test_make_grid_ds():
    """Last row of ``lon_b`` minus the first stays below 0.001 on a level-0 grid."""
    lon_bounds = make_grid_ds(0, pixels_per_tile=8).lon_b.values
    first_row = lon_bounds[0, :]
    last_row = lon_bounds[-1, :]
    assert np.all((last_row - first_row) < 0.001)
69+
70+
71+
@pytest.mark.parametrize('levels', [1, 2])
@pytest.mark.parametrize('method', ['bilinear', 'conservative'])
def test_generate_weights_pyramid(temperature, levels, method):
    """Check the root metadata and level-0 layout of a generated weights pyramid."""
    # `generate_weights_pyramid` imports xesmf lazily; skip (rather than error)
    # when that optional dependency is missing, like the other xESMF tests here.
    pytest.importorskip('xesmf')
    weights_pyramid = generate_weights_pyramid(temperature.isel(time=0), levels, method=method)

    # root attributes record how the weights were produced
    assert weights_pyramid.ds.attrs['levels'] == levels
    assert weights_pyramid.ds.attrs['regrid_method'] == method

    # each level stores the weights as S/col/row plus the matrix shape in attrs
    level_zero = weights_pyramid['0'].ds
    assert set(level_zero.data_vars) == {'S', 'col', 'row'}
    assert 'n_in' in level_zero.attrs
    assert 'n_out' in level_zero.attrs

0 commit comments

Comments
 (0)