Skip to content

Commit 3a05647

Browse files
authored
Fix equivalent() for NumPy scalar NaN comparison (#10838)
1 parent 20d3773 commit 3a05647

File tree

3 files changed

+66
-3
lines changed

3 files changed

+66
-3
lines changed

xarray/core/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,11 @@ def equivalent(first: T, second: T) -> bool:
256256
if isinstance(first, list) or isinstance(second, list):
257257
return list_equiv(first, second) # type: ignore[arg-type]
258258

259+
# Check for NaN equivalence early (before equality comparison)
260+
# This handles both Python float NaN and NumPy scalar NaN (issue #10833)
261+
if pd.isnull(first) and pd.isnull(second): # type: ignore[call-overload]
262+
return True
263+
259264
# For non-array/list types, use == but require boolean result
260265
result = first == second
261266
if not isinstance(result, bool):
@@ -265,8 +270,7 @@ def equivalent(first: T, second: T) -> bool:
265270
# Reject any other non-boolean type (Dataset, Series, custom objects, etc.)
266271
return False
267272

268-
# Check for NaN equivalence
269-
return result or (pd.isnull(first) and pd.isnull(second)) # type: ignore[call-overload]
273+
return result
270274

271275

272276
def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool:

xarray/tests/test_concat.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@
99
import pandas as pd
1010
import pytest
1111

12-
from xarray import AlignmentError, DataArray, Dataset, Variable, concat, set_options
12+
from xarray import (
13+
AlignmentError,
14+
DataArray,
15+
Dataset,
16+
Variable,
17+
concat,
18+
open_dataset,
19+
set_options,
20+
)
1321
from xarray.core import dtypes, types
1422
from xarray.core.coordinates import Coordinates
1523
from xarray.core.indexes import PandasIndex
@@ -23,6 +31,7 @@
2331
assert_identical,
2432
requires_dask,
2533
requires_pyarrow,
34+
requires_scipy_or_netCDF4,
2635
)
2736
from xarray.tests.indexes import XYIndex
2837
from xarray.tests.test_dataset import create_test_data
@@ -1119,6 +1128,47 @@ def test_concat_promote_shape_without_creating_new_index(self) -> None:
11191128
assert_identical(actual, expected, check_default_indexes=False)
11201129
assert actual.indexes == {}
11211130

1131+
@requires_scipy_or_netCDF4
1132+
def test_concat_combine_attrs_nan_after_netcdf_roundtrip(self, tmp_path) -> None:
1133+
# Test for issue #10833: NaN attributes should be preserved
1134+
# with combine_attrs="drop_conflicts" after NetCDF roundtrip
1135+
import numpy as np
1136+
1137+
# Create arrays with matching NaN fill_value attribute
1138+
ds1 = Dataset(
1139+
{"a": ("x", [0, 1])},
1140+
attrs={"fill_value": np.nan, "sensor": "G18", "field": "CTH"},
1141+
)
1142+
ds2 = Dataset(
1143+
{"a": ("x", [2, 3])},
1144+
attrs={"fill_value": np.nan, "sensor": "G16", "field": "CTH"},
1145+
)
1146+
1147+
# Save to NetCDF and reload (converts Python float NaN to NumPy scalar NaN)
1148+
path1 = tmp_path / "ds1.nc"
1149+
path2 = tmp_path / "ds2.nc"
1150+
ds1.to_netcdf(path1)
1151+
ds2.to_netcdf(path2)
1152+
1153+
ds1_loaded = open_dataset(path1)
1154+
ds2_loaded = open_dataset(path2)
1155+
1156+
# Verify that NaN attributes are preserved after concat
1157+
actual = concat(
1158+
[ds1_loaded, ds2_loaded], dim="y", combine_attrs="drop_conflicts"
1159+
)
1160+
1161+
# fill_value should be preserved (not dropped) since both have NaN
1162+
assert "fill_value" in actual.attrs
1163+
assert np.isnan(actual.attrs["fill_value"])
1164+
# field should be preserved (identical in both)
1165+
assert actual.attrs["field"] == "CTH"
1166+
# sensor should be dropped (conflicts)
1167+
assert "sensor" not in actual.attrs
1168+
1169+
ds1_loaded.close()
1170+
ds2_loaded.close()
1171+
11221172

11231173
class TestConcatDataArray:
11241174
def test_concat(self) -> None:

xarray/tests/test_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,15 @@ def test_equivalent(self):
8080
assert utils.equivalent(np.array([0]), [0])
8181
assert utils.equivalent(np.arange(3), 1.0 * np.arange(3))
8282
assert not utils.equivalent(0, np.zeros(3))
83+
# Test NaN comparisons (issue #10833)
84+
# Python float NaN
85+
assert utils.equivalent(float("nan"), float("nan"))
86+
# NumPy scalar NaN (various dtypes)
87+
assert utils.equivalent(np.float64(np.nan), np.float64(np.nan))
88+
assert utils.equivalent(np.float32(np.nan), np.float32(np.nan))
89+
# Mixed: Python float NaN vs NumPy scalar NaN
90+
assert utils.equivalent(float("nan"), np.float64(np.nan))
91+
assert utils.equivalent(np.float64(np.nan), float("nan"))
8392

8493
def test_safe(self):
8594
# should not raise exception:

0 commit comments

Comments
 (0)