11from __future__ import annotations # noqa: F401
22
33import itertools
4- import pathlib
54
65import datatree as dt
76import numpy as np
109from .utils import add_metadata_and_zarr_encoding , get_version , multiscales_template
1110
1211
12+ def xesmf_weights_to_xarray (regridder ) -> xr .Dataset :
13+ w = regridder .weights .data
14+ dim = 'n_s'
15+ ds = xr .Dataset (
16+ {
17+ 'S' : (dim , w .data ),
18+ 'col' : (dim , w .coords [1 , :] + 1 ),
19+ 'row' : (dim , w .coords [0 , :] + 1 ),
20+ }
21+ )
22+ ds .attrs = {'n_in' : regridder .n_in , 'n_out' : regridder .n_out }
23+ return ds
24+
25+
26+ def _reconstruct_xesmf_weights (ds_w ):
27+ """Reconstruct weights into format that xESMF understands"""
28+ import sparse
29+ import xarray as xr
30+
31+ col = ds_w ['col' ].values - 1
32+ row = ds_w ['row' ].values - 1
33+ s = ds_w ['S' ].values
34+ n_out , n_in = ds_w .attrs ['n_out' ], ds_w .attrs ['n_in' ]
35+ crds = np .stack ([row , col ])
36+ return xr .DataArray (
37+ sparse .COO (crds , s , (n_out , n_in )), dims = ('out_dim' , 'in_dim' ), name = 'weights'
38+ )
39+
40+
1341def make_grid_ds (level : int , pixels_per_tile : int = 128 ) -> xr .Dataset :
1442 """Make a dataset representing a target grid
1543
@@ -97,11 +125,52 @@ def make_grid_pyramid(levels: int = 6) -> dt.DataTree:
97125 return data
98126
99127
128+ def generate_weights_pyramid (
129+ ds_in : xr .Dataset , levels : int , method : str = 'bilinear' , regridder_kws : dict = None
130+ ) -> dt .DataTree :
131+ """helper function to generate weights for a multiscale regridder
132+
133+ Parameters
134+ ----------
135+ ds_in : xr.Dataset
136+ Input dataset to regrid
137+ levels : int
138+ Number of levels in the pyramid
139+ method : str, optional
140+ Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear'
141+ regridder_kws : dict
142+ Keyword arguments to pass to :py:class:`~xesmf.Regridder`. Default is `{'periodic': True}`
143+
144+ Returns
145+ -------
146+ weights : dt.DataTree
147+ Multiscale weights
148+ """
149+ import datatree
150+ import xesmf as xe
151+
152+ regridder_kws = {} if regridder_kws is None else regridder_kws
153+ regridder_kws = {'periodic' : True , ** regridder_kws }
154+
155+ weights_pyramid = datatree .DataTree ()
156+ for level in range (levels ):
157+ ds_out = make_grid_ds (level = level )
158+ regridder = xe .Regridder (ds_in , ds_out , method , ** regridder_kws )
159+ ds = xesmf_weights_to_xarray (regridder )
160+
161+ weights_pyramid [str (level )] = ds
162+
163+ weights_pyramid .ds .attrs ['levels' ] = levels
164+ weights_pyramid .ds .attrs ['regrid_method' ] = method
165+
166+ return weights_pyramid
167+
168+
100169def pyramid_regrid (
101170 ds : xr .Dataset ,
102171 target_pyramid : dt .DataTree = None ,
103172 levels : int = None ,
104- weights_template : str = None ,
173+ weights_pyramid : dt . DataTree = None ,
105174 method : str = 'bilinear' ,
106175 regridder_kws : dict = None ,
107176 regridder_apply_kws : dict = None ,
@@ -118,8 +187,8 @@ def pyramid_regrid(
118187 Target grids, if not provided, they will be generated, by default None
119188 levels : int, optional
120189 Number of levels in pyramid, by default None
121- weights_template : str , optional
122- Filepath to write generated weights to, e.g. `'weights_{level}'`, by default None
190+ weights_pyramid : dt.DataTree , optional
191+ pyramid containing pregenerated weights
123192 method : str, optional
124193 Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear'
125194 regridder_kws : dict
@@ -147,14 +216,15 @@ def pyramid_regrid(
147216 if levels is None :
148217 levels = len (target_pyramid .keys ()) # TODO: get levels from the pyramid metadata
149218
150- if regridder_kws is None :
151- regridder_kws = {'periodic' : True }
219+ regridder_kws = {} if regridder_kws is None else regridder_kws
220+ regridder_kws = {'periodic' : True , ** regridder_kws }
152221
153222 # multiscales spec
154223 save_kwargs = locals ()
155224 del save_kwargs ['ds' ]
156225 del save_kwargs ['target_pyramid' ]
157226 del save_kwargs ['xe' ]
227+ del save_kwargs ['weights_pyramid' ]
158228
159229 attrs = {
160230 'multiscales' : multiscales_template (
@@ -173,21 +243,23 @@ def pyramid_regrid(
173243 # pyramid data
174244 for level in range (levels ):
175245 grid = target_pyramid [str (level )].ds .load ()
176-
177246 # get the regridder object
178- if not weights_template :
247+ if weights_pyramid is None :
179248 regridder = xe .Regridder (ds , grid , method , ** regridder_kws )
180249 else :
181- fn = pathlib .PosixPath (weights_template .format (level = level ))
182- if not fn .exists ():
183- regridder = xe .Regridder (ds , grid , method , ** regridder_kws )
184- regridder .to_netcdf (filename = fn )
185- else :
186- regridder = xe .Regridder (ds , grid , method , weights = fn , ** regridder_kws )
187-
250+ # Reconstruct weights into format that xESMF understands
251+ # this is a hack that assumes the weights were generated by
252+ # the `generate_weights_pyramid` function
253+
254+ ds_w = weights_pyramid [str (level )].ds
255+ weights = _reconstruct_xesmf_weights (ds_w )
256+ regridder = xe .Regridder (
257+ ds , grid , method , reuse_weights = True , weights = weights , ** regridder_kws
258+ )
188259 # regrid
189260 if regridder_apply_kws is None :
190261 regridder_apply_kws = {}
262+ regridder_apply_kws = {** {'keep_attrs' : True }, ** regridder_apply_kws }
191263 pyramid [str (level )] = regridder (ds , ** regridder_apply_kws )
192264
193265 pyramid = add_metadata_and_zarr_encoding (
0 commit comments