|
1 | 1 | """SQL functionality tests for xarray-sql using pytest.""" |
2 | 2 |
|
3 | 3 | import numpy as np |
| 4 | +import pandas as pd |
4 | 5 | import pytest |
5 | 6 | import xarray as xr |
6 | 7 |
|
@@ -146,3 +147,44 @@ def test_string_coordinates(): |
146 | 147 | assert "student" in result.columns |
147 | 148 | assert "subject" in result.columns |
148 | 149 | assert "score" in result.columns |
| 150 | + |
| 151 | + |
class TestNanAsNull:
    """Verify that NaN in float columns is exposed as Arrow null to SQL.

    With NaN mapped to null, SQL aggregates skip the missing cells and
    ``IS NULL`` predicates match them.
    """

    @pytest.fixture
    def nan_ds(self):
        """Return a chunked 2x2x2 dataset whose ``temp`` variable has two NaNs."""
        values = np.array(
            [
                [[1.0, 2.0], [np.nan, 4.0]],
                [[5.0, np.nan], [7.0, 8.0]],
            ]
        )
        coordinates = {
            "time": pd.date_range("2020-01-01", periods=2),
            "x": [0, 1],
            "y": [0, 1],
        }
        dataset = xr.Dataset(
            {"temp": (["time", "x", "y"], values)},
            coords=coordinates,
        )
        # Chunk size 1 along ``time`` so each time step is its own chunk.
        return dataset.chunk({"time": 1})

    def test_nan_aggregates(self, nan_ds):
        """Run several aggregates in one query and check NaN is treated as NULL."""
        context = XarrayContext()
        context.from_dataset("data", nan_ds)

        # A single query covers every expectation:
        #   - MAX / MIN / AVG skip the NaN cells,
        #   - COUNT(col) leaves the NaN cells out,
        #   - ``IS NULL`` selects exactly the NaN cells.
        query = """
            SELECT
                MAX(temp) AS mx,
                MIN(temp) AS mn,
                AVG(temp) AS avg,
                COUNT(temp) AS cnt,
                COUNT(*) FILTER (WHERE temp IS NULL) AS null_cnt
            FROM data
        """
        frame = context.sql(query).to_pandas()
        row = frame.iloc[0]

        # The six finite values present in the fixture data.
        finite_values = [1.0, 2.0, 4.0, 5.0, 7.0, 8.0]
        expected_avg = np.nanmean(finite_values)

        assert row["mx"] == 8.0
        assert row["mn"] == 1.0
        assert abs(row["avg"] - expected_avg) < 1e-6
        assert row["cnt"] == 6
        assert row["null_cnt"] == 2
0 commit comments