Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
8d025da
Start of implementing C-Grid interpolation
erikvansebille Aug 15, 2025
50d9d52
Moving C-grid velocity code from v3 to v4
erikvansebille Aug 15, 2025
3226306
use time_interval type to set default time
erikvansebille Aug 15, 2025
1bd7e0c
Adding nemo curvilinear test for C-grid
erikvansebille Aug 15, 2025
2da436a
Speeding up curvilinear search by dask loading lon and lat
erikvansebille Aug 15, 2025
0e1a221
Updating c-grid velocity test and algorithm
erikvansebille Aug 15, 2025
1ecab22
Fixing vector interpolation
erikvansebille Aug 18, 2025
8b1b830
Fixing CGrid_Velocity interpolation for multiple particles
erikvansebille Aug 18, 2025
cad34e0
Adding better error message handling for CGrid_velocity interpolation
erikvansebille Aug 18, 2025
26cffcc
Adding warning suppression for index_search
erikvansebille Aug 18, 2025
e2376e2
Fixing error when Grid does not have lon or lat
erikvansebille Aug 18, 2025
6529e13
Fixing to keep the maximum Error code in field
erikvansebille Aug 18, 2025
bbc2c30
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Aug 19, 2025
dda469d
Adding NEMO3D test
erikvansebille Aug 21, 2025
76d015b
Updating CGrid interpolation to not interpolate over depth for U and V
erikvansebille Aug 22, 2025
196c6e7
Fixing W interpolation for CGrid
erikvansebille Aug 22, 2025
1c948dd
Fixing stommel gyre CGrid interpolation test
erikvansebille Aug 22, 2025
df78b62
Updating failing unit test
erikvansebille Aug 22, 2025
cd69147
Further fixing unit test by dropping unused dimensions
erikvansebille Aug 22, 2025
1a64033
Temporary fix to spatialhash
erikvansebille Aug 22, 2025
7c1a87a
Adding TODO statement about spherical meshes
erikvansebille Aug 22, 2025
07a8238
Updating spherical mash hashmap creation
erikvansebille Aug 22, 2025
f77f5f4
merge
erikvansebille Aug 25, 2025
2924ee0
Fixing grid._mesh in interpolator
erikvansebille Aug 25, 2025
9264d8a
Merge branch 'feature/morton-hashing' into c-grid-interpolation
erikvansebille Aug 25, 2025
1a17911
Using is_dask_collection to check for dask in interpolation
erikvansebille Aug 25, 2025
6753e81
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Aug 27, 2025
5bdc874
Fixing vector_interp_method
erikvansebille Aug 28, 2025
a0962dc
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Sep 1, 2025
6659a1d
fixing merging bugs
erikvansebille Sep 1, 2025
f14e8cb
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Sep 1, 2025
2568feb
Using is_dask_collection for c-grid interpolator
erikvansebille Sep 1, 2025
525ec71
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Sep 3, 2025
cebda92
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Sep 8, 2025
0ac7ae0
Removing mesh types check for VectorField
erikvansebille Sep 9, 2025
872b75c
Setting lon and lat as coordinates
erikvansebille Sep 9, 2025
0be3404
Merge branch 'v4-dev' into c-grid-interpolation
erikvansebille Sep 9, 2025
38622fc
Merge branch 'v4-dev' into c-grid-interpolation
VeckoTheGecko Sep 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,16 @@ def _search_indices_curvilinear_2d(
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]

det2 = bb * bb - 4 * aa * cc
det = np.where(det2 > 0, np.sqrt(det2), eta)
eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))

xsi = np.where(
abs(a[1] + a[3] * eta) < 1e-12,
((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
(x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
)
with np.errstate(divide="ignore", invalid="ignore"):
det = np.where(det2 > 0, np.sqrt(det2), eta)

eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))

Comment on lines +86 to +90
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which scenarios would produce the errors in particular are we ignoring here? And is the rest of the code robust to the nans/inf values that ignoring the errors would produce?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way that np.where works is that it always computes both the True and False results (as far as I understand).
That means that it will for example take the np.sqrt(det2) even if det2 < 0, leading to a lot of warnings (have you not seen them in your test results?).
The with np.errstate() filters these warnings out
So yes, the code is robust to NaNs and Infs, because these would not be warnings but errors

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one option would be to do:

mask = det2 > 0
det[mask] = np.sqrt(det2[mask])

lets just merge for now and leave this 'unresolved' in the PR - minor thing

xsi = np.where(
abs(a[1] + a[3] * eta) < 1e-12,
((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
(x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
)

xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))
Expand Down
296 changes: 288 additions & 8 deletions parcels/application_kernels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,26 @@

from typing import TYPE_CHECKING

import dask.array as dask
import numpy as np
import xarray as xr
from dask import is_dask_collection

import parcels.tools.interpolation_utils as i_u

if TYPE_CHECKING:
from parcels.field import Field
from parcels.field import Field, VectorField
from parcels.uxgrid import _UXGRID_AXES
from parcels.xgrid import _XGRID_AXES

__all__ = [
"CGrid_Tracer",
"CGrid_Velocity",
"UXPiecewiseConstantFace",
"UXPiecewiseLinearNode",
"XLinear",
"XNearest",
"ZeroInterpolator",
"ZeroInterpolator_Vector",
]


Expand All @@ -36,6 +41,21 @@ def ZeroInterpolator(
return 0.0


def ZeroInterpolator_Vector(
vectorfield: VectorField,
ti: int,
position: dict[str, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
applyConversion: bool,
) -> np.float32 | np.float64:
"""Template function used for the signature check of the interpolation methods for velocity fields."""
return 0.0


def XLinear(
field: Field,
ti: int,
Expand All @@ -53,6 +73,7 @@ def XLinear(

axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
data = field.data
tdim, zdim, ydim, xdim = data.shape[0], data.shape[1], data.shape[2], data.shape[3]

lenT = 2 if np.any(tau > 0) else 1
lenZ = 2 if np.any(zeta > 0) else 1
Expand All @@ -61,22 +82,22 @@ def XLinear(
if lenT == 1:
ti = np.repeat(ti, lenZ * 4)
else:
ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
ti_1 = np.clip(ti + 1, 0, tdim - 1)
ti = np.concatenate([np.repeat(ti, lenZ * 4), np.repeat(ti_1, lenZ * 4)])

# Depth coordinates: 4 points at zi, 4 at zi+1, repeated for both time levels
if lenZ == 1:
zi = np.repeat(zi, lenT * 4)
else:
zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
zi_1 = np.clip(zi + 1, 0, zdim - 1)
zi = np.tile(np.array([zi, zi, zi, zi, zi_1, zi_1, zi_1, zi_1]).flatten(), lenT)

# Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
yi_1 = np.clip(yi + 1, 0, ydim - 1)
yi = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT) * (lenZ))

# X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
xi_1 = np.clip(xi + 1, 0, xdim - 1)
xi = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT) * (lenZ))

# Create DataArrays for indexing
Expand Down Expand Up @@ -109,7 +130,266 @@ def XLinear(
+ (1 - xsi) * eta * corner_data[:, 2]
+ xsi * eta * corner_data[:, 3]
)
return value.compute() if isinstance(value, dask.Array) else value
return value.compute() if is_dask_collection(value) else value


def CGrid_Velocity(
vectorfield: VectorField,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
applyConversion: bool,
):
"""
Interpolation kernel for velocity fields on a C-Grid.
Following Delandmeter and Van Sebille (2019), velocity fields should be interpolated
only in the direction of the grid cell faces.
"""
xi, xsi = position["X"]
yi, eta = position["Y"]
zi, zeta = position["Z"]

U = vectorfield.U.data
V = vectorfield.V.data
grid = vectorfield.grid
tdim, zdim, ydim, xdim = U.shape[0], U.shape[1], U.shape[2], U.shape[3]

if grid.lon.ndim == 1:
px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
py = np.array([grid.lat[yi], grid.lat[yi], grid.lat[yi + 1], grid.lat[yi + 1]])
else:
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])

if grid._mesh == "spherical":
px[0] = np.where(px[0] < x - 225, px[0] + 360, px[0])
px[0] = np.where(px[0] > x + 225, px[0] - 360, px[0])
px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
c1 = i_u._geodetic_distance(
py[0], py[1], px[0], px[1], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(0.0, xsi), py)
)
c2 = i_u._geodetic_distance(
py[1], py[2], px[1], px[2], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 1.0), py)
)
c3 = i_u._geodetic_distance(
py[2], py[3], px[2], px[3], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(1.0, xsi), py)
)
c4 = i_u._geodetic_distance(
py[3], py[0], px[3], px[0], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 0.0), py)
)

lenT = 2 if np.any(tau > 0) else 1

# Create arrays of corner points for xarray.isel
# TODO C grid may not need all xi and yi cornerpoints, so could speed up here?

# Time coordinates: 4 points at ti, then 4 points at ti+1
if lenT == 1:
ti_full = np.repeat(ti, 4)
else:
ti_1 = np.clip(ti + 1, 0, tdim - 1)
ti_full = np.concatenate([np.repeat(ti, 4), np.repeat(ti_1, 4)])

# Depth coordinates: 4 points at zi, repeated for both time levels
zi_full = np.repeat(zi, lenT * 4)

# Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
yi_1 = np.clip(yi + 1, 0, ydim - 1)
yi_full = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT))
# # TODO check why in some cases minus needed here!!!
# yi_minus_1 = np.clip(yi - 1, 0, ydim - 1)
# yi = np.tile(np.repeat(np.column_stack([yi_minus_1, yi]), 2), (lenT))

# X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
xi_1 = np.clip(xi + 1, 0, xdim - 1)
xi_full = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT))

for data in [U, V]:
axis_dim = grid.get_axis_dim_mapping(data.dims)

# Create DataArrays for indexing
selection_dict = {
axis_dim["X"]: xr.DataArray(xi_full, dims=("points")),
axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")),
}
if "Z" in axis_dim:
selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points"))
if "time" in data.dims:
selection_dict["time"] = xr.DataArray(ti_full, dims=("points"))

corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi), 4)

if lenT == 2:
tau_full = tau[:, np.newaxis]
corner_data = corner_data[0, :, :] * (1 - tau_full) + corner_data[1, :, :] * tau_full
else:
corner_data = corner_data[0, :, :]
# # See code below for v3 version
# # if self.gridindexingtype == "nemo":
# # U0 = self.U.data[ti, zi, yi + 1, xi] * c4
# # U1 = self.U.data[ti, zi, yi + 1, xi + 1] * c2
# # V0 = self.V.data[ti, zi, yi, xi + 1] * c1
# # V1 = self.V.data[ti, zi, yi + 1, xi + 1] * c3
# # elif self.gridindexingtype in ["mitgcm", "croco"]:
# # U0 = self.U.data[ti, zi, yi, xi] * c4
# # U1 = self.U.data[ti, zi, yi, xi + 1] * c2
# # V0 = self.V.data[ti, zi, yi, xi] * c1
# # V1 = self.V.data[ti, zi, yi + 1, xi] * c3
# # TODO Nick can you help use xgcm to fix this implementation?

# # CROCO and MITgcm grid indexing,
# if data is U:
# U0 = corner_data[:, 0] * c4
# U1 = corner_data[:, 1] * c2
# elif data is V:
# V0 = corner_data[:, 0] * c1
# V1 = corner_data[:, 2] * c3
# # NEMO grid indexing
if data is U:
U0 = corner_data[:, 2] * c4
U1 = corner_data[:, 3] * c2
elif data is V:
V0 = corner_data[:, 1] * c1
V1 = corner_data[:, 3] * c3

U = (1 - xsi) * U0 + xsi * U1
V = (1 - eta) * V0 + eta * V1

deg2m = 1852 * 60.0
if applyConversion:
meshJac = (deg2m * deg2m * np.cos(np.deg2rad(y))) if grid._mesh == "spherical" else 1
else:
meshJac = deg2m if grid._mesh == "spherical" else 1

jac = i_u._compute_jacobian_determinant(py, px, eta, xsi) * meshJac

u = (
(-(1 - eta) * U - (1 - xsi) * V) * px[0]
+ ((1 - eta) * U - xsi * V) * px[1]
+ (eta * U + xsi * V) * px[2]
+ (-eta * U + (1 - xsi) * V) * px[3]
) / jac
v = (
(-(1 - eta) * U - (1 - xsi) * V) * py[0]
+ ((1 - eta) * U - xsi * V) * py[1]
+ (eta * U + xsi * V) * py[2]
+ (-eta * U + (1 - xsi) * V) * py[3]
) / jac
if is_dask_collection(u):
u = u.compute()
v = v.compute()

# check whether the grid conversion has been applied correctly
xx = (1 - xsi) * (1 - eta) * px[0] + xsi * (1 - eta) * px[1] + xsi * eta * px[2] + (1 - xsi) * eta * px[3]
u = np.where(np.abs((xx - x) / x) > 1e-4, np.nan, u)

if vectorfield.W:
data = vectorfield.W.data
# Time coordinates: 2 points at ti, then 2 points at ti+1
if lenT == 1:
ti_full = np.repeat(ti, 2)
else:
ti_1 = np.clip(ti + 1, 0, tdim - 1)
ti_full = np.concatenate([np.repeat(ti, 2), np.repeat(ti_1, 2)])

# Depth coordinates: 1 points at zi, repeated for both time levels
zi_1 = np.clip(zi + 1, 0, zdim - 1)
zi_full = np.tile(np.array([zi, zi_1]).flatten(), lenT)

# Y coordinates: yi+1 for each spatial point, repeated for time/depth
yi_1 = np.clip(yi + 1, 0, ydim - 1)
yi_full = np.tile(yi_1, (lenT) * 2)

# X coordinates: xi+1 for each spatial point, repeated for time/depth
xi_1 = np.clip(xi + 1, 0, xdim - 1)
xi_full = np.tile(xi_1, (lenT) * 2)

axis_dim = grid.get_axis_dim_mapping(data.dims)

# Create DataArrays for indexing
selection_dict = {
axis_dim["X"]: xr.DataArray(xi_full, dims=("points")),
axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")),
axis_dim["Z"]: xr.DataArray(zi_full, dims=("points")),
}
if "time" in data.dims:
selection_dict["time"] = xr.DataArray(ti_full, dims=("points"))

corner_data = data.isel(selection_dict).data.reshape(lenT, 2, len(xsi))

if lenT == 2:
tau_full = tau[np.newaxis, :]
corner_data = corner_data[0, :, :] * (1 - tau_full) + corner_data[1, :, :] * tau_full
else:
corner_data = corner_data[0, :, :]

w = corner_data[0, :] * (1 - zeta) + corner_data[1, :] * zeta
if is_dask_collection(w):
w = w.compute()
else:
w = np.zeros_like(u)

return (u, v, w)


def CGrid_Tracer(
field: Field,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
):
"""Interpolation kernel for tracer fields on a C-Grid.

Following Delandmeter and Van Sebille (2019), tracer fields should be interpolated
constant over the grid cell
"""
xi, _ = position["X"]
yi, _ = position["Y"]
zi, _ = position["Z"]

axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
data = field.data

lenT = 2 if np.any(tau > 0) else 1

if lenT == 2:
ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
ti = np.concatenate([np.repeat(ti), np.repeat(ti_1)])
zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
zi = np.concatenate([np.repeat(zi), np.repeat(zi_1)])
yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
yi = np.concatenate([np.repeat(yi), np.repeat(yi_1)])
xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
xi = np.concatenate([np.repeat(xi), np.repeat(xi_1)])

# Create DataArrays for indexing
selection_dict = {
axis_dim["X"]: xr.DataArray(xi, dims=("points")),
axis_dim["Y"]: xr.DataArray(yi, dims=("points")),
}
if "Z" in axis_dim:
selection_dict[axis_dim["Z"]] = xr.DataArray(zi, dims=("points"))
if "time" in field.data.dims:
selection_dict["time"] = xr.DataArray(ti, dims=("points"))

value = data.isel(selection_dict).data.reshape(lenT, len(xi))

if lenT == 2:
tau = tau[:, np.newaxis]
value = value[0, :] * (1 - tau) + value[1, :] * tau
else:
value = value[0, :]

return value.compute() if is_dask_collection(value) else value


def XNearest(
Expand Down Expand Up @@ -172,7 +452,7 @@ def XNearest(
else:
value = corner_data[0, :]

return value.compute() if isinstance(value, dask.Array) else value
return value.compute() if is_dask_collection(value) else value


def UXPiecewiseConstantFace(
Expand Down
Loading
Loading