Skip to content

Commit 624c3af

Browse files
committed
Tests for semantic equivalence to xarray operations.
1 parent 7a5700a commit 624c3af

File tree

1 file changed

+347
-0
lines changed

1 file changed

+347
-0
lines changed

xarray_sql/sql_test.py

Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,214 @@ def test_ordering_all_datasets(self):
495495
f'Values not in descending order for {dataset_name}')
496496

497497

498+
class SqlCorrectnessTestCase(XarrayTestBase):
    """Test SQL results against xarray ground truth for semantic correctness.

    Each test runs a SQL query through the engine under test and compares
    the result against the equivalent computation performed directly with
    xarray on the same dataset (`self.air_small` / `self.weather_small`,
    built by the base class's setUp — assumed to be small dask-backed
    xarray Datasets; confirm against XarrayTestBase).
    """

    def test_aggregation_correctness_vs_xarray(self):
        """Validate SQL aggregations against xarray ground truth."""
        # Use air dataset for well-defined validation
        self.load_dataset('air', self.air_small)

        # Test COUNT correctness: COUNT(*) must equal the product of the
        # dimension sizes (one SQL row per time/lat/lon grid cell).
        sql_count = self.assert_sql_result_valid(
            'SELECT COUNT(*) as count FROM air', expected_rows=1)
        xarray_count = (self.air_small.sizes['time'] *
                        self.air_small.sizes['lat'] *
                        self.air_small.sizes['lon'])
        self.assertEqual(sql_count['count'].iloc[0], xarray_count,
                         "SQL COUNT should match xarray dimension product")

        # Test MIN/MAX correctness
        sql_result = self.assert_sql_result_valid(
            'SELECT MIN(air) as min_air, MAX(air) as max_air FROM air',
            expected_rows=1,
        )
        xarray_min = float(self.air_small.air.min().values)
        xarray_max = float(self.air_small.air.max().values)

        self.assertAlmostEqual(sql_result['min_air'].iloc[0], xarray_min, places=5,
                               msg="SQL MIN should match xarray min")
        self.assertAlmostEqual(sql_result['max_air'].iloc[0], xarray_max, places=5,
                               msg="SQL MAX should match xarray max")

        # Test AVG correctness. Looser tolerance (places=3) than MIN/MAX:
        # the mean accumulates floating-point error across the whole array
        # and SQL/xarray may sum in different orders.
        sql_avg = self.assert_sql_result_valid(
            'SELECT AVG(air) as avg_air FROM air', expected_rows=1)
        xarray_avg = float(self.air_small.air.mean().values)
        self.assertAlmostEqual(sql_avg['avg_air'].iloc[0], xarray_avg, places=3,
                               msg="SQL AVG should match xarray mean")

    def test_filtering_correctness_vs_xarray(self):
        """Validate SQL filtering against xarray filtering."""
        self.load_dataset('air', self.air_small)

        # Get threshold value for meaningful filter: the 75th percentile
        # keeps the predicate selective while guaranteeing non-empty results.
        threshold = float(self.air_small.air.quantile(0.75).values)

        # SQL filtering
        sql_result = self.assert_sql_result_valid(
            f'SELECT COUNT(*) as count FROM air WHERE air > {threshold}'
        )
        sql_count = sql_result['count'].iloc[0]

        # Xarray filtering (compute mask to avoid dask boolean indexing issues)
        mask = (self.air_small.air > threshold).compute()
        xarray_filtered = self.air_small.where(mask, drop=True)
        # count() ignores the NaNs that where() introduced, so it counts
        # exactly the cells that passed the predicate.
        xarray_count = int(xarray_filtered.air.count().values)

        self.assertEqual(sql_count, xarray_count,
                         f"SQL WHERE air > {threshold} should match xarray filtering")

        # Test coordinate filtering
        lat_threshold = float(self.air_small.lat.median().values)
        sql_coord_result = self.assert_sql_result_valid(
            f'SELECT COUNT(*) as count FROM air WHERE lat > {lat_threshold}'
        )
        sql_coord_count = sql_coord_result['count'].iloc[0]

        # Xarray coordinate filtering (compute mask to avoid dask boolean indexing issues)
        coord_mask = (self.air_small.lat > lat_threshold).compute()
        xarray_coord_filtered = self.air_small.where(coord_mask, drop=True)
        xarray_coord_count = int(xarray_coord_filtered.air.count().values)

        self.assertEqual(sql_coord_count, xarray_coord_count,
                         f"SQL WHERE lat > {lat_threshold} should match xarray coordinate filtering")

    def test_groupby_correctness_vs_xarray(self):
        """Validate SQL GROUP BY against xarray groupby operations."""
        self.load_dataset('air', self.air_small)

        # SQL GROUP BY lat with aggregation
        sql_result = self.assert_sql_result_valid(
            '''SELECT
                 lat,
                 COUNT(*) as count,
                 AVG(air) as avg_air
               FROM air
               GROUP BY lat
               ORDER BY lat'''
        )

        # Xarray groupby equivalent
        xarray_grouped = self.air_small.groupby('lat').mean()

        # Verify we have same number of groups
        self.assertEqual(len(sql_result), len(self.air_small.lat),
                         "SQL GROUP BY should have one row per unique lat")

        # Verify counts are correct (each lat should have time * lon data points)
        expected_count_per_lat = self.air_small.sizes['time'] * self.air_small.sizes['lon']
        self.assertTrue((sql_result['count'] == expected_count_per_lat).all(),
                        "Each lat group should have time × lon data points")

        # Verify averages match xarray groupby (within tolerance for floating
        # point). The iterrows index is unused, hence the `_` placeholder.
        for _, row in sql_result.iterrows():
            lat_val = row['lat']
            sql_avg = row['avg_air']
            # xarray groupby calculates mean across all dimensions except the groupby dimension
            xarray_avg = float(xarray_grouped.sel(lat=lat_val).air.mean().values)
            self.assertAlmostEqual(sql_avg, xarray_avg, places=3,
                                   msg=f"SQL GROUP BY average for lat={lat_val} should match xarray groupby")

    def test_coordinate_operations_correctness(self):
        """Validate SQL coordinate operations against xarray coordinate access."""
        self.load_dataset('air', self.air_small)

        # Test coordinate MIN/MAX
        sql_coords = self.assert_sql_result_valid(
            '''SELECT
                 MIN(lat) as min_lat, MAX(lat) as max_lat,
                 MIN(lon) as min_lon, MAX(lon) as max_lon,
                 MIN(time) as min_time, MAX(time) as max_time
               FROM air''',
            expected_rows=1,
        )

        # Xarray coordinate values
        xarray_lat_min, xarray_lat_max = float(self.air_small.lat.min()), float(self.air_small.lat.max())
        xarray_lon_min, xarray_lon_max = float(self.air_small.lon.min()), float(self.air_small.lon.max())
        xarray_time_min, xarray_time_max = self.air_small.time.min().values, self.air_small.time.max().values

        # Validate coordinate ranges
        self.assertAlmostEqual(sql_coords['min_lat'].iloc[0], xarray_lat_min, places=5)
        self.assertAlmostEqual(sql_coords['max_lat'].iloc[0], xarray_lat_max, places=5)
        self.assertAlmostEqual(sql_coords['min_lon'].iloc[0], xarray_lon_min, places=5)
        self.assertAlmostEqual(sql_coords['max_lon'].iloc[0], xarray_lon_max, places=5)
        # For time coordinates, convert to pandas timestamp for comparison
        # (xarray yields numpy datetime64 scalars; SQL yields pandas Timestamps).
        import pandas as pd
        self.assertEqual(sql_coords['min_time'].iloc[0], pd.Timestamp(xarray_time_min))
        self.assertEqual(sql_coords['max_time'].iloc[0], pd.Timestamp(xarray_time_max))

    def test_spatial_aggregation_correctness(self):
        """Validate spatial aggregations match xarray spatial operations."""
        self.load_dataset('weather', self.weather_small)

        # SQL spatial average (average over time for each lat/lon)
        sql_spatial = self.assert_sql_result_valid(
            '''SELECT
                 lat, lon,
                 AVG(temperature) as avg_temp,
                 COUNT(*) as time_points
               FROM weather
               GROUP BY lat, lon
               ORDER BY lat, lon'''
        )

        # Xarray spatial average
        xarray_spatial = self.weather_small.mean(dim='time')

        # Verify structure - SQL may have more rows due to multiple variables
        # so we just check that we have some reasonable number of spatial points
        min_expected_points = self.weather_small.sizes['lat'] * self.weather_small.sizes['lon']
        self.assertGreaterEqual(len(sql_spatial), min_expected_points,
                                "Should have at least one row per lat/lon combination")

        # Verify time point counts are reasonable
        expected_time_points = self.weather_small.sizes['time']
        mode_count = sql_spatial['time_points'].mode()[0]
        # Allow for multiple variables in the dataset by checking if count is multiple of time steps
        self.assertTrue(mode_count % expected_time_points == 0,
                        f"Time point count ({mode_count}) should be multiple of time steps ({expected_time_points})")

        # Verify that all spatial averages are reasonable values (not NaN, finite)
        temp_averages = sql_spatial['avg_temp']
        self.assertFalse(temp_averages.isna().any(), "No temperature averages should be NaN")
        self.assertTrue(np.isfinite(temp_averages).all(), "All temperature averages should be finite")
        # Verify temperature values have reasonable range (not all identical)
        self.assertGreater(temp_averages.std(), 0, "Temperature averages should have some variation")

    def test_data_integrity_validation(self):
        """Validate that SQL operations preserve data integrity properties."""
        self.load_dataset('air', self.air_small)

        # Test that filtering maintains data relationships
        sql_result = self.assert_sql_result_valid(
            '''SELECT lat, lon, time, air
               FROM air
               WHERE lat BETWEEN 40 AND 60 AND lon BETWEEN -120 AND -80
               ORDER BY lat, lon, time'''
        )

        # The filter may legitimately match nothing for some fixture extents,
        # so only validate row contents when rows came back.
        if len(sql_result) > 0:
            # All lat values should be in specified range
            self.assertTrue((sql_result['lat'] >= 40).all() and (sql_result['lat'] <= 60).all(),
                            "Filtered lat values should be within specified range")

            # All lon values should be in specified range
            self.assertTrue((sql_result['lon'] >= -120).all() and (sql_result['lon'] <= -80).all(),
                            "Filtered lon values should be within specified range")

            # Air values should be reasonable (not NaN, within physical bounds)
            self.assertFalse(sql_result['air'].isna().any(), "Air values should not be NaN")
            self.assertTrue((sql_result['air'] > 200).all() and (sql_result['air'] < 350).all(),
                            "Air temperature values should be physically reasonable (200-350K)")

    def test_zarr_vs_dataset_numerical_equivalence(self):
        """Test numerical equivalence between Zarr and dataset results."""
        # skipTest raises unittest.SkipTest, so no return/pass is needed
        # after it (the original dead `return` has been removed).
        if not hasattr(self, 'temp_dir'):
            self.skipTest("Zarr functionality not available in this test class")
        # This would need to be implemented in a Zarr-enabled test class
        # Placeholder for now to show the pattern
704+
705+
498706
class SqlZarrParameterizedTestCase(XarrayZarrTestBase):
499707
"""Parameterized tests specifically for Zarr datasets and from_zarr functionality."""
500708

@@ -565,6 +773,145 @@ def test_zarr_predicate_pushdown_efficiency(self):
565773
self.assertLessEqual(multi_filtered_count, filtered_count)
566774

567775

776+
class SqlZarrCorrectnessTestCase(XarrayZarrTestBase):
    """Test SQL correctness specifically for Zarr datasets and equivalence with xarray.

    Zarr-backed tables expose data variables with a '/' prefix (hence the
    quoted "/temperature" columns) and dimensions as dim_0/dim_1/dim_2 —
    presumably time/lat/lon in that order; confirm against the from_zarr
    schema conventions. Ground truth comes from `self.weather_ds`, the
    original xarray dataset that `self.weather_zarr_path` was written from.
    """

    def test_zarr_aggregation_vs_xarray_ground_truth(self):
        """Validate Zarr SQL aggregations against original xarray dataset."""
        self.load_zarr_dataset('weather', self.weather_zarr_path)

        # Test COUNT matches xarray dimensions
        sql_count = self.assert_sql_result_valid('SELECT COUNT(*) as count FROM weather')
        expected_count = (self.weather_ds.sizes['time'] *
                          self.weather_ds.sizes['lat'] *
                          self.weather_ds.sizes['lon'])
        self.assertEqual(sql_count['count'].iloc[0], expected_count,
                         "Zarr SQL COUNT should match xarray dimension product")

        # Test MIN/MAX of data variables
        sql_temp_stats = self.assert_sql_result_valid(
            'SELECT MIN("/temperature") as min_temp, MAX("/temperature") as max_temp FROM weather'
        )
        xarray_min_temp = float(self.weather_ds.temperature.min().values)
        xarray_max_temp = float(self.weather_ds.temperature.max().values)

        self.assertAlmostEqual(sql_temp_stats['min_temp'].iloc[0], xarray_min_temp, places=5,
                               msg="Zarr SQL MIN should match xarray min")
        self.assertAlmostEqual(sql_temp_stats['max_temp'].iloc[0], xarray_max_temp, places=5,
                               msg="Zarr SQL MAX should match xarray max")

    def test_zarr_coordinate_filtering_vs_xarray(self):
        """Test that Zarr coordinate filtering produces reasonable results."""
        self.load_zarr_dataset('weather', self.weather_zarr_path)

        # Test filtering by dimension index: the midpoint guarantees the
        # predicate excludes some rows but keeps some.
        mid_time_idx = len(self.weather_ds.time) // 2

        # SQL filtering on Zarr (uses dim_0 for time dimension)
        sql_filtered = self.assert_sql_result_valid(
            f'SELECT COUNT(*) as count FROM weather WHERE dim_0 >= {mid_time_idx}'
        )
        sql_count = sql_filtered['count'].iloc[0]

        # Also test without filtering
        sql_total = self.assert_sql_result_valid('SELECT COUNT(*) as count FROM weather')
        total_count = sql_total['count'].iloc[0]

        # Verify filtering reduces the count
        self.assertLess(sql_count, total_count, "Filtering should reduce the total count")
        self.assertGreater(sql_count, 0, "Filtering should still return some results")

    def test_zarr_multidimensional_groupby_correctness(self):
        """Test that Zarr GROUP BY operations produce reasonable results."""
        self.load_zarr_dataset('weather', self.weather_zarr_path)

        # SQL GROUP BY time dimension (dim_0)
        sql_grouped = self.assert_sql_result_valid(
            '''SELECT
                 dim_0,
                 COUNT(*) as count,
                 AVG("/temperature") as avg_temp
               FROM weather
               GROUP BY dim_0
               ORDER BY dim_0'''
        )

        # Verify structure - should have reasonable number of groups
        self.assertGreater(len(sql_grouped), 0, "Should have at least one group")
        self.assertLessEqual(len(sql_grouped), 50, "Should not have excessive groups")

        # Verify temperature averages are reasonable
        temp_avgs = sql_grouped['avg_temp']
        self.assertFalse(temp_avgs.isna().any(), "No temperature averages should be NaN")
        self.assertTrue(np.isfinite(temp_avgs).all(), "All temperature averages should be finite")
        self.assertGreater(temp_avgs.std(), 0, "Temperature averages should vary across time")

    def test_zarr_data_variable_access_correctness(self):
        """Test that Zarr data variable access (with / prefix) works correctly."""
        self.load_zarr_dataset('weather', self.weather_zarr_path)

        # Test that all expected data variables are accessible
        schema = self.assert_sql_result_valid('SELECT * FROM weather LIMIT 1')

        # Should have data variables with '/' prefix
        data_vars = [col for col in schema.columns if col.startswith('/')]
        actual_vars = set(data_vars)

        # Check that we have at least one data variable with '/' prefix
        self.assertGreater(len(actual_vars), 0, "Should have at least one data variable with '/' prefix")

        # Verify the actual variables match the weather dataset
        expected_vars = {f'/{var}' for var in self.weather_ds.data_vars}
        self.assertEqual(actual_vars, expected_vars,
                         f"Data variables should match weather dataset: {expected_vars}")

        # Test aggregating available data variables
        multi_agg = self.assert_sql_result_valid(
            '''SELECT
                 AVG("/temperature") as avg_temp,
                 AVG("/precipitation") as avg_precip
               FROM weather'''
        )

        # Validate that values are reasonable (not NaN, finite)
        sql_temp_avg = multi_agg['avg_temp'].iloc[0]
        sql_precip_avg = multi_agg['avg_precip'].iloc[0]

        self.assertFalse(np.isnan(sql_temp_avg), "Temperature average should not be NaN")
        self.assertFalse(np.isnan(sql_precip_avg), "Precipitation average should not be NaN")
        self.assertTrue(np.isfinite(sql_temp_avg), "Temperature average should be finite")
        self.assertTrue(np.isfinite(sql_precip_avg), "Precipitation average should be finite")

    def test_zarr_predicate_pushdown_semantic_correctness(self):
        """Test that predicate pushdown produces semantically correct results."""
        self.load_zarr_dataset('weather', self.weather_zarr_path)

        # Complex multi-dimensional filter
        sql_complex = self.assert_sql_result_valid(
            '''SELECT
                 dim_0, dim_1, dim_2,
                 "/temperature",
                 COUNT(*) OVER (PARTITION BY dim_0) as count_per_time
               FROM weather
               WHERE dim_0 >= 1 AND dim_1 < 2 AND "/temperature" > 15
               ORDER BY dim_0, dim_1, dim_2'''
        )

        # The combined predicates may match nothing for some fixture data,
        # so only validate row contents when rows came back.
        if len(sql_complex) > 0:
            # Verify all constraints are satisfied
            self.assertTrue((sql_complex['dim_0'] >= 1).all(), "Time constraint should be satisfied")
            self.assertTrue((sql_complex['dim_1'] < 2).all(), "Lat constraint should be satisfied")
            self.assertTrue((sql_complex['/temperature'] > 15).all(), "Temperature constraint should be satisfied")

            # Verify data relationships are preserved:
            # all rows with same dim_0 should have same count_per_time.
            for time_val in sql_complex['dim_0'].unique():
                time_rows = sql_complex[sql_complex['dim_0'] == time_val]
                unique_counts = time_rows['count_per_time'].unique()
                # Plain string here: the original was an f-string with no
                # placeholders (ruff F541); the runtime text is unchanged.
                self.assertEqual(len(unique_counts), 1,
                                 "All rows with same time should have same count_per_time")
913+
914+
568915
class SqlAdvancedTestCase(XarrayTestBase):
569916
"""Test SQL functionality with various types of Xarray datasets."""
570917

0 commit comments

Comments
 (0)