Skip to content

Commit 3572f4e

Browse files
max-sixtyclaude
andauthored
Fix Dataset.map to handle non-DataArray outputs (#10839)
After PR #10602, Dataset.map started failing when functions returned non-DataArray values (e.g., scalars), raising AttributeError when trying to access .coords on the returned values. This restores backward compatibility by converting non-DataArray outputs to DataArrays, which was the behavior before the regression. Fixes #10835 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude <[email protected]>
1 parent 3a05647 commit 3572f4e

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

xarray/core/dataset.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6929,12 +6929,19 @@ def map(
69296929
foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 0.9773
69306930
bar (x) float64 16B 1.0 2.0
69316931
"""
6932+
from xarray.core.dataarray import DataArray
6933+
69326934
if keep_attrs is None:
69336935
keep_attrs = _get_keep_attrs(default=True)
69346936
variables = {
69356937
k: maybe_wrap_array(v, func(v, *args, **kwargs))
69366938
for k, v in self.data_vars.items()
69376939
}
6940+
# Convert non-DataArray values to DataArrays
6941+
variables = {
6942+
k: v if isinstance(v, DataArray) else DataArray(v)
6943+
for k, v in variables.items()
6944+
}
69386945
coord_vars, indexes = merge_coordinates_without_align(
69396946
[v.coords for v in variables.values()]
69406947
)

xarray/tests/test_dataset.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6309,6 +6309,31 @@ def func(arr):
63096309
ds["x"].attrs["y"] = "x"
63106310
assert ds["x"].attrs != actual["x"].attrs
63116311

6312+
def test_map_non_dataarray_outputs(self) -> None:
6313+
# Test that map handles non-DataArray outputs by converting them
6314+
# Regression test for GH10835
6315+
ds = xr.Dataset({"foo": ("x", [1, 2, 3]), "bar": ("y", [4, 5])})
6316+
6317+
# Scalar output
6318+
result = ds.map(lambda x: 1)
6319+
expected = xr.Dataset({"foo": 1, "bar": 1})
6320+
assert_identical(result, expected)
6321+
6322+
# Numpy array output with same shape
6323+
result = ds.map(lambda x: x.values)
6324+
expected = ds.copy()
6325+
assert_identical(result, expected)
6326+
6327+
# Mixed: some return scalars, some return arrays
6328+
def mixed_func(x):
6329+
if "x" in x.dims:
6330+
return 42
6331+
return x
6332+
6333+
result = ds.map(mixed_func)
6334+
expected = xr.Dataset({"foo": 42, "bar": ("y", [4, 5])})
6335+
assert_identical(result, expected)
6336+
63126337
def test_apply_pending_deprecated_map(self) -> None:
63136338
data = create_test_data()
63146339
data.attrs["foo"] = "bar"

0 commit comments

Comments
 (0)