Skip to content

Commit 5176113

Browse files
authored
Fix null conversion (#151)
1 parent 1382e91 commit 5176113

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

tests/test_sql.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""SQL functionality tests for xarray-sql using pytest."""
22

33
import numpy as np
4+
import pandas as pd
45
import pytest
56
import xarray as xr
67

@@ -146,3 +147,44 @@ def test_string_coordinates():
146147
assert "student" in result.columns
147148
assert "subject" in result.columns
148149
assert "score" in result.columns
150+
151+
152+
class TestNanAsNull:
    """NaN in float columns should become Arrow nulls so SQL aggregates work."""

    @pytest.fixture
    def nan_ds(self):
        # Eight values laid out flat, then shaped into (time=2, x=2, y=2);
        # two of them are NaN so the null-handling paths are exercised.
        flat = [1.0, 2.0, np.nan, 4.0, 5.0, np.nan, 7.0, 8.0]
        cube = np.asarray(flat).reshape(2, 2, 2)
        ds = xr.Dataset(
            {"temp": (["time", "x", "y"], cube)},
            coords={
                "time": pd.date_range("2020-01-01", periods=2),
                "x": [0, 1],
                "y": [0, 1],
            },
        )
        # Chunk along time so the dataset goes through the dask/record-batch path.
        return ds.chunk({"time": 1})

    def test_nan_aggregates(self, nan_ds):
        ctx = XarrayContext()
        ctx.from_dataset("data", nan_ds)

        # Test multiple aggregates at once:
        # MAX/MIN/AVG should ignore NaN, COUNT(col) should exclude NaN,
        # and WHERE col IS NULL should match NaN.
        query = """
        SELECT
          MAX(temp) AS mx,
          MIN(temp) AS mn,
          AVG(temp) AS avg,
          COUNT(temp) AS cnt,
          COUNT(*) FILTER (WHERE temp IS NULL) AS null_cnt
        FROM data
        """
        row = ctx.sql(query).to_pandas().iloc[0]

        assert row["mx"] == 8.0
        assert row["mn"] == 1.0
        expected = np.nanmean([1.0, 2.0, 4.0, 5.0, 7.0, 8.0])
        assert abs(row["avg"] - expected) < 1e-6
        assert row["cnt"] == 6
        assert row["null_cnt"] == 2

xarray_sql/df.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,11 @@ def dataset_to_record_batch(
207207
arrays.append(pa.array(arr, type=field.type))
208208
else:
209209
# Data variable: ravel to 1-D (zero-copy for C-contiguous arrays).
210-
arrays.append(pa.array(ds[name].values.ravel(), type=field.type))
210+
# from_pandas=True maps NaN → Arrow null inside the C++ copy kernel,
211+
# so SQL aggregates (MAX, MIN, AVG) skip missing values correctly.
212+
arrays.append(
213+
pa.array(ds[name].values.ravel(), type=field.type, from_pandas=True)
214+
)
211215

212216
return pa.RecordBatch.from_arrays(arrays, schema=schema)
213217

@@ -282,7 +286,11 @@ def iter_record_batches(
282286
arrays.append(pa.array(coord_values[name][coord_idx], type=field.type))
283287
else:
284288
arrays.append(
285-
pa.array(data_arrays[name][row_start:row_end], type=field.type)
289+
pa.array(
290+
data_arrays[name][row_start:row_end],
291+
type=field.type,
292+
from_pandas=True,
293+
)
286294
)
287295

288296
yield pa.RecordBatch.from_arrays(arrays, schema=schema)

0 commit comments

Comments
 (0)