@@ -166,48 +166,29 @@ def nan_ds(self):
166166 },
167167 ).chunk ({"time" : 1 })
168168
169- def test_max_ignores_nan (self , nan_ds ):
170- ctx = XarrayContext ()
171- ctx .from_dataset ("data" , nan_ds )
172- result = ctx .sql ("SELECT MAX(temp) AS mx FROM data" ).to_pandas ()
173- assert result ["mx" ].iloc [0 ] == 8.0
174-
175- def test_min_ignores_nan (self , nan_ds ):
176- ctx = XarrayContext ()
177- ctx .from_dataset ("data" , nan_ds )
178- result = ctx .sql ("SELECT MIN(temp) AS mn FROM data" ).to_pandas ()
179- assert result ["mn" ].iloc [0 ] == 1.0
180-
181- def test_avg_ignores_nan (self , nan_ds ):
169+ def test_nan_aggregates (self , nan_ds ):
182170 import numpy as np
183171
184172 ctx = XarrayContext ()
185173 ctx .from_dataset ("data" , nan_ds )
186- result = ctx .sql ("SELECT AVG(temp) AS avg FROM data" ).to_pandas ()
187- expected = np .nanmean ([1.0 , 2.0 , 4.0 , 5.0 , 7.0 , 8.0 ])
188- assert abs (result ["avg" ].iloc [0 ] - expected ) < 1e-6
189174
190- def test_count_excludes_nan (self , nan_ds ):
191- ctx = XarrayContext ()
192- ctx .from_dataset ("data" , nan_ds )
193- result = ctx .sql ("SELECT COUNT(temp) AS cnt FROM data" ).to_pandas ()
194- # 8 total cells, 2 are NaN → 6 non-null
195- assert result ["cnt" ].iloc [0 ] == 6
196-
197- def test_is_null_matches_nan (self , nan_ds ):
198- ctx = XarrayContext ()
199- ctx .from_dataset ("data" , nan_ds )
200- result = ctx .sql (
201- "SELECT COUNT(*) AS cnt FROM data WHERE temp IS NULL"
202- ).to_pandas ()
203- assert result ["cnt" ].iloc [0 ] == 2
204-
205- def test_from_dataset_returns_self (self ):
206- ds = (
207- xr .tutorial .open_dataset ("air_temperature" )
208- .isel (time = slice (0 , 2 ), lat = slice (0 , 2 ), lon = slice (0 , 2 ))
209- .chunk ({"time" : 1 })
210- )
211- ctx = XarrayContext ()
212- ret = ctx .from_dataset ("air" , ds )
213- assert ret is ctx
175+ # Test multiple aggregates at once:
176+ # MAX/MIN/AVG should ignore NaN, COUNT(col) should exclude NaN,
177+ # and WHERE col IS NULL should match NaN.
178+ query = """
179+ SELECT
180+ MAX(temp) AS mx,
181+ MIN(temp) AS mn,
182+ AVG(temp) AS avg,
183+ COUNT(temp) AS cnt,
184+ COUNT(*) FILTER (WHERE temp IS NULL) AS null_cnt
185+ FROM data
186+ """
187+ result = ctx .sql (query ).to_pandas ().iloc [0 ]
188+
189+ assert result ["mx" ] == 8.0
190+ assert result ["mn" ] == 1.0
191+ expected_avg = np .nanmean ([1.0 , 2.0 , 4.0 , 5.0 , 7.0 , 8.0 ])
192+ assert abs (result ["avg" ] - expected_avg ) < 1e-6
193+ assert result ["cnt" ] == 6
194+ assert result ["null_cnt" ] == 2
0 commit comments