Skip to content

Commit 764494e

Browse files
Adds option for parallel weight generation with xESMF (#145)
Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com>
1 parent 4e87518 commit 764494e

3 files changed

Lines changed: 10 additions & 3 deletions

File tree

ndpyramid/regrid.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def pyramid_regrid(
196196
projection: typing.Literal["web-mercator", "equidistant-cylindrical"] = "web-mercator",
197197
target_pyramid: xr.DataTree = None,
198198
levels: int = None,
199+
parallel_weights: bool = True,
199200
weights_pyramid: xr.DataTree = None,
200201
method: str = "bilinear",
201202
regridder_kws: dict = None,
@@ -217,6 +218,8 @@ def pyramid_regrid(
217218
Number of levels in pyramid, by default None
218219
weights_pyramid : xr.DataTree, optional
219220
pyramid containing pregenerated weights
221+
parallel_weights : Bool
222+
Use dask to generate parallel weights
220223
method : str, optional
221224
Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear'
222225
regridder_kws : dict
@@ -285,7 +288,7 @@ def pyramid_regrid(
285288
grid = target_pyramid[str(level)].ds.load()
286289
# get the regridder object
287290
if weights_pyramid is None:
288-
regridder = xe.Regridder(ds, grid, method, **regridder_kws)
291+
regridder = xe.Regridder(ds, grid, method, parallel=parallel_weights, **regridder_kws)
289292
else:
290293
# Reconstruct weights into format that xESMF understands
291294
# this is a hack that assumes the weights were generated by

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ known-first-party = ["ndpyramid"]
151151

152152

153153
# Notebook ruff config
154-
[tool.ruff.per-file-ignores]
154+
[tool.ruff.lint.per-file-ignores]
155155
"*.ipynb" = [
156156
"D100",
157157
"E402",

tests/test_pyramid_regrid.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ def test_regridded_pyramid(temperature, regridder_apply_kws, benchmark):
1616
temperature = temperature.isel(time=slice(0, 5))
1717
pyramid = benchmark(
1818
lambda: pyramid_regrid(
19-
temperature, levels=2, regridder_apply_kws=regridder_apply_kws, other_chunks={"time": 2}
19+
temperature,
20+
levels=2,
21+
parallel_weights=False,
22+
regridder_apply_kws=regridder_apply_kws,
23+
other_chunks={"time": 2},
2024
)
2125
)
2226
verify_bounds(pyramid)

0 commit comments

Comments
 (0)