@@ -495,6 +495,214 @@ def test_ordering_all_datasets(self):
495495 f'Values not in descending order for { dataset_name } ' )
496496
497497
class SqlCorrectnessTestCase(XarrayTestBase):
    """Check SQL query results against values computed directly with xarray.

    Every test registers a fixture dataset with ``self.load_dataset``, runs a
    query through ``self.assert_sql_result_valid`` and compares the outcome
    with the same quantity derived from the fixture.  The query result is used
    like a pandas DataFrame (``result['col'].iloc[0]``, ``iterrows()``).
    The helpers and the ``air_small`` / ``weather_small`` fixtures are defined
    outside this class.
    """

    def test_aggregation_correctness_vs_xarray(self):
        """COUNT, MIN, MAX and AVG over ``air`` must agree with xarray."""
        air_ds = self.air_small
        self.load_dataset('air', air_ds)

        # COUNT(*) is expected to equal the product of the dimension sizes.
        sql_count = self.assert_sql_result_valid(
            'SELECT COUNT(*) as count FROM air', expected_rows=1
        )
        xarray_count = (
            air_ds.sizes['time'] * air_ds.sizes['lat'] * air_ds.sizes['lon']
        )
        self.assertEqual(sql_count['count'].iloc[0], xarray_count,
                         "SQL COUNT should match xarray dimension product")

        # MIN / MAX of the data variable.
        sql_result = self.assert_sql_result_valid(
            'SELECT MIN(air) as min_air, MAX(air) as max_air FROM air',
            expected_rows=1,
        )
        xarray_min = float(air_ds.air.min().values)
        xarray_max = float(air_ds.air.max().values)

        self.assertAlmostEqual(sql_result['min_air'].iloc[0], xarray_min, places=5,
                               msg="SQL MIN should match xarray min")
        self.assertAlmostEqual(sql_result['max_air'].iloc[0], xarray_max, places=5,
                               msg="SQL MAX should match xarray max")

        # AVG is compared with a looser tolerance (3 places) than MIN/MAX.
        sql_avg = self.assert_sql_result_valid(
            'SELECT AVG(air) as avg_air FROM air', expected_rows=1
        )
        xarray_avg = float(air_ds.air.mean().values)
        self.assertAlmostEqual(sql_avg['avg_air'].iloc[0], xarray_avg, places=3,
                               msg="SQL AVG should match xarray mean")

    def test_filtering_correctness_vs_xarray(self):
        """Row counts under a WHERE clause must match xarray masking."""
        air_ds = self.air_small
        self.load_dataset('air', air_ds)

        # Use the 75th percentile so the filter keeps a non-trivial subset.
        threshold = float(air_ds.air.quantile(0.75).values)

        # SQL filtering on the data variable.
        sql_result = self.assert_sql_result_valid(
            f'SELECT COUNT(*) as count FROM air WHERE air > {threshold}'
        )
        sql_count = sql_result['count'].iloc[0]

        # Xarray filtering (compute mask to avoid dask boolean indexing issues).
        mask = (air_ds.air > threshold).compute()
        xarray_filtered = air_ds.where(mask, drop=True)
        xarray_count = int(xarray_filtered.air.count().values)

        self.assertEqual(sql_count, xarray_count,
                         f"SQL WHERE air > {threshold} should match xarray filtering")

        # Same comparison, this time filtering on a coordinate.
        lat_threshold = float(air_ds.lat.median().values)
        sql_coord_result = self.assert_sql_result_valid(
            f'SELECT COUNT(*) as count FROM air WHERE lat > {lat_threshold}'
        )
        sql_coord_count = sql_coord_result['count'].iloc[0]

        # Xarray coordinate filtering (compute mask to avoid dask boolean
        # indexing issues).
        coord_mask = (air_ds.lat > lat_threshold).compute()
        xarray_coord_filtered = air_ds.where(coord_mask, drop=True)
        xarray_coord_count = int(xarray_coord_filtered.air.count().values)

        self.assertEqual(
            sql_coord_count, xarray_coord_count,
            f"SQL WHERE lat > {lat_threshold} should match xarray coordinate filtering",
        )

    def test_groupby_correctness_vs_xarray(self):
        """GROUP BY lat must yield per-latitude counts and means matching xarray."""
        air_ds = self.air_small
        self.load_dataset('air', air_ds)

        # SQL GROUP BY lat with aggregation.
        sql_result = self.assert_sql_result_valid(
            '''SELECT
                   lat,
                   COUNT(*) as count,
                   AVG(air) as avg_air
               FROM air
               GROUP BY lat
               ORDER BY lat'''
        )

        # Xarray groupby equivalent.
        xarray_grouped = air_ds.groupby('lat').mean()

        # One output row per latitude value.
        self.assertEqual(len(sql_result), len(air_ds.lat),
                         "SQL GROUP BY should have one row per unique lat")

        # Each latitude group should contain time * lon data points.
        expected_count_per_lat = air_ds.sizes['time'] * air_ds.sizes['lon']
        self.assertTrue((sql_result['count'] == expected_count_per_lat).all(),
                        "Each lat group should have time × lon data points")

        # Compare every group's average with xarray (3 decimal places).
        for _, row in sql_result.iterrows():
            lat_val = row['lat']
            sql_avg = row['avg_air']
            # The trailing .mean() collapses whatever dimensions are left
            # after selecting a single latitude.
            xarray_avg = float(xarray_grouped.sel(lat=lat_val).air.mean().values)
            self.assertAlmostEqual(
                sql_avg, xarray_avg, places=3,
                msg=f"SQL GROUP BY average for lat={lat_val} should match xarray groupby",
            )

    def test_coordinate_operations_correctness(self):
        """MIN/MAX over the coordinate columns must match the xarray coordinates."""
        # Function-scope import: the module's import block is not visible
        # from this class, so pandas is brought in where it is used.
        import pandas as pd

        air_ds = self.air_small
        self.load_dataset('air', air_ds)

        sql_coords = self.assert_sql_result_valid(
            '''SELECT
                   MIN(lat) as min_lat, MAX(lat) as max_lat,
                   MIN(lon) as min_lon, MAX(lon) as max_lon,
                   MIN(time) as min_time, MAX(time) as max_time
               FROM air''',
            expected_rows=1,
        )

        # Xarray coordinate extremes.
        xarray_lat_min, xarray_lat_max = float(air_ds.lat.min()), float(air_ds.lat.max())
        xarray_lon_min, xarray_lon_max = float(air_ds.lon.min()), float(air_ds.lon.max())
        xarray_time_min = air_ds.time.min().values
        xarray_time_max = air_ds.time.max().values

        # Numeric coordinates are compared to 5 decimal places.
        self.assertAlmostEqual(sql_coords['min_lat'].iloc[0], xarray_lat_min, places=5)
        self.assertAlmostEqual(sql_coords['max_lat'].iloc[0], xarray_lat_max, places=5)
        self.assertAlmostEqual(sql_coords['min_lon'].iloc[0], xarray_lon_min, places=5)
        self.assertAlmostEqual(sql_coords['max_lon'].iloc[0], xarray_lon_max, places=5)

        # Time values are wrapped in pandas Timestamps before comparing.
        self.assertEqual(sql_coords['min_time'].iloc[0], pd.Timestamp(xarray_time_min))
        self.assertEqual(sql_coords['max_time'].iloc[0], pd.Timestamp(xarray_time_max))

    def test_spatial_aggregation_correctness(self):
        """Sanity-check a per-(lat, lon) temperature average over ``weather``.

        Only structural properties are verified (row count, time-point
        multiples, finite and varying averages); the averages themselves are
        not compared with xarray values here.
        """
        weather_ds = self.weather_small
        self.load_dataset('weather', weather_ds)

        # SQL spatial average (average over time for each lat/lon).
        sql_spatial = self.assert_sql_result_valid(
            '''SELECT
                   lat, lon,
                   AVG(temperature) as avg_temp,
                   COUNT(*) as time_points
               FROM weather
               GROUP BY lat, lon
               ORDER BY lat, lon'''
        )

        # Structure: SQL may have more rows due to multiple variables, so only
        # a lower bound on the number of spatial points is checked.
        min_expected_points = weather_ds.sizes['lat'] * weather_ds.sizes['lon']
        self.assertGreaterEqual(len(sql_spatial), min_expected_points,
                                "Should have at least one row per lat/lon combination")

        # The most common per-cell count must be a multiple of the number of
        # time steps (allows for multiple variables in the dataset).
        expected_time_points = weather_ds.sizes['time']
        mode_count = sql_spatial['time_points'].mode()[0]
        self.assertTrue(
            mode_count % expected_time_points == 0,
            f"Time point count ({mode_count}) should be multiple of "
            f"time steps ({expected_time_points})",
        )

        # Averages must be finite, non-NaN and not all identical.
        temp_averages = sql_spatial['avg_temp']
        self.assertFalse(temp_averages.isna().any(),
                         "No temperature averages should be NaN")
        self.assertTrue(np.isfinite(temp_averages).all(),
                        "All temperature averages should be finite")
        self.assertGreater(temp_averages.std(), 0,
                           "Temperature averages should have some variation")

    def test_data_integrity_validation(self):
        """Rows returned by a lat/lon range filter must satisfy that filter."""
        self.load_dataset('air', self.air_small)

        sql_result = self.assert_sql_result_valid(
            '''SELECT lat, lon, time, air
               FROM air
               WHERE lat BETWEEN 40 AND 60 AND lon BETWEEN -120 AND -80
               ORDER BY lat, lon, time'''
        )

        # NOTE(review): the checks below run only when the filter matches at
        # least one row.  If the fixture's lon axis is expressed in 0-360
        # degrees the negative lon range selects nothing and this test passes
        # without asserting anything -- confirm against the fixture.
        if len(sql_result) > 0:
            # All lat values should be in the requested range.
            self.assertTrue(
                (sql_result['lat'] >= 40).all() and (sql_result['lat'] <= 60).all(),
                "Filtered lat values should be within specified range",
            )

            # All lon values should be in the requested range.
            self.assertTrue(
                (sql_result['lon'] >= -120).all() and (sql_result['lon'] <= -80).all(),
                "Filtered lon values should be within specified range",
            )

            # Air values should be present and inside the 200-350 bounds.
            self.assertFalse(sql_result['air'].isna().any(),
                             "Air values should not be NaN")
            self.assertTrue(
                (sql_result['air'] > 200).all() and (sql_result['air'] < 350).all(),
                "Air temperature values should be physically reasonable (200-350K)",
            )

    def test_zarr_vs_dataset_numerical_equivalence(self):
        """Placeholder for a Zarr-vs-dataset numerical comparison.

        Skipped when the instance has no ``temp_dir`` attribute; otherwise it
        currently asserts nothing.
        """
        if not hasattr(self, 'temp_dir'):
            # skipTest raises unittest.SkipTest, so nothing after it executes.
            self.skipTest("Zarr functionality not available in this test class")

        # TODO: implement in a Zarr-enabled test class; kept as a placeholder
        # to show the pattern.
704+
705+
498706class SqlZarrParameterizedTestCase (XarrayZarrTestBase ):
499707 """Parameterized tests specifically for Zarr datasets and from_zarr functionality."""
500708
@@ -565,6 +773,145 @@ def test_zarr_predicate_pushdown_efficiency(self):
565773 self .assertLessEqual (multi_filtered_count , filtered_count )
566774
567775
class SqlZarrCorrectnessTestCase(XarrayZarrTestBase):
    """Check SQL results over Zarr-backed tables against the source dataset.

    Every test registers the store at ``self.weather_zarr_path`` as table
    ``weather`` via ``self.load_zarr_dataset`` and queries it through
    ``self.assert_sql_result_valid``.  In these queries the dimensions are
    addressed as ``dim_0`` / ``dim_1`` / ``dim_2`` and data variables carry a
    ``/`` prefix (e.g. ``"/temperature"``).  ``self.weather_ds`` is used as the
    xarray reference; the helpers and fixtures are defined outside this class.
    """

    def test_zarr_aggregation_vs_xarray_ground_truth(self):
        """COUNT, MIN and MAX over the Zarr table must agree with ``weather_ds``."""
        weather_ds = self.weather_ds
        self.load_zarr_dataset('weather', self.weather_zarr_path)

        # COUNT(*) is expected to equal the product of the dimension sizes.
        sql_count = self.assert_sql_result_valid('SELECT COUNT(*) as count FROM weather')
        expected_count = (weather_ds.sizes['time'] *
                          weather_ds.sizes['lat'] *
                          weather_ds.sizes['lon'])
        self.assertEqual(sql_count['count'].iloc[0], expected_count,
                         "Zarr SQL COUNT should match xarray dimension product")

        # MIN / MAX of a data variable.
        sql_temp_stats = self.assert_sql_result_valid(
            'SELECT MIN("/temperature") as min_temp, MAX("/temperature") as max_temp '
            'FROM weather'
        )
        xarray_min_temp = float(weather_ds.temperature.min().values)
        xarray_max_temp = float(weather_ds.temperature.max().values)

        self.assertAlmostEqual(sql_temp_stats['min_temp'].iloc[0], xarray_min_temp,
                               places=5, msg="Zarr SQL MIN should match xarray min")
        self.assertAlmostEqual(sql_temp_stats['max_temp'].iloc[0], xarray_max_temp,
                               places=5, msg="Zarr SQL MAX should match xarray max")

    def test_zarr_coordinate_filtering_vs_xarray(self):
        """Filtering on a dimension index must return a non-empty strict subset."""
        self.load_zarr_dataset('weather', self.weather_zarr_path)

        # Filter from the middle of the time axis onwards.
        mid_time_idx = len(self.weather_ds.time) // 2

        # SQL filtering on Zarr (uses dim_0 for the time dimension).
        sql_filtered = self.assert_sql_result_valid(
            f'SELECT COUNT(*) as count FROM weather WHERE dim_0 >= {mid_time_idx}'
        )
        sql_count = sql_filtered['count'].iloc[0]

        # Unfiltered total for comparison.
        sql_total = self.assert_sql_result_valid('SELECT COUNT(*) as count FROM weather')
        total_count = sql_total['count'].iloc[0]

        self.assertLess(sql_count, total_count,
                        "Filtering should reduce the total count")
        self.assertGreater(sql_count, 0,
                           "Filtering should still return some results")

    def test_zarr_multidimensional_groupby_correctness(self):
        """GROUP BY on ``dim_0`` must yield a sane number of finite, varying means."""
        self.load_zarr_dataset('weather', self.weather_zarr_path)

        # SQL GROUP BY time dimension (dim_0).
        sql_grouped = self.assert_sql_result_valid(
            '''SELECT
                   dim_0,
                   COUNT(*) as count,
                   AVG("/temperature") as avg_temp
               FROM weather
               GROUP BY dim_0
               ORDER BY dim_0'''
        )

        # Structure: between 1 and 50 groups.
        self.assertGreater(len(sql_grouped), 0, "Should have at least one group")
        self.assertLessEqual(len(sql_grouped), 50, "Should not have excessive groups")

        # Averages must be finite, non-NaN and not all identical.
        temp_avgs = sql_grouped['avg_temp']
        self.assertFalse(temp_avgs.isna().any(),
                         "No temperature averages should be NaN")
        self.assertTrue(np.isfinite(temp_avgs).all(),
                        "All temperature averages should be finite")
        self.assertGreater(temp_avgs.std(), 0,
                           "Temperature averages should vary across time")

    def test_zarr_data_variable_access_correctness(self):
        """Data variables must be exposed as ``/``-prefixed, aggregatable columns."""
        self.load_zarr_dataset('weather', self.weather_zarr_path)

        # A one-row SELECT * is enough to inspect the column names.
        schema = self.assert_sql_result_valid('SELECT * FROM weather LIMIT 1')

        # Columns starting with '/' are the data variables.
        actual_vars = {col for col in schema.columns if col.startswith('/')}
        self.assertGreater(len(actual_vars), 0,
                           "Should have at least one data variable with '/' prefix")

        # They must be exactly the dataset's data variables, prefixed with '/'.
        expected_vars = {f'/{var}' for var in self.weather_ds.data_vars}
        self.assertEqual(actual_vars, expected_vars,
                         f"Data variables should match weather dataset: {expected_vars}")

        # Aggregate two of the data variables in a single query.
        multi_agg = self.assert_sql_result_valid(
            '''SELECT
                   AVG("/temperature") as avg_temp,
                   AVG("/precipitation") as avg_precip
               FROM weather'''
        )

        # Values only need to be reasonable here (not NaN, finite).
        sql_temp_avg = multi_agg['avg_temp'].iloc[0]
        sql_precip_avg = multi_agg['avg_precip'].iloc[0]

        self.assertFalse(np.isnan(sql_temp_avg), "Temperature average should not be NaN")
        self.assertFalse(np.isnan(sql_precip_avg), "Precipitation average should not be NaN")
        self.assertTrue(np.isfinite(sql_temp_avg), "Temperature average should be finite")
        self.assertTrue(np.isfinite(sql_precip_avg), "Precipitation average should be finite")

    def test_zarr_predicate_pushdown_semantic_correctness(self):
        """Rows returned by a multi-column filter must satisfy every predicate."""
        self.load_zarr_dataset('weather', self.weather_zarr_path)

        # Complex multi-dimensional filter combined with a window function.
        sql_complex = self.assert_sql_result_valid(
            '''SELECT
                   dim_0, dim_1, dim_2,
                   "/temperature",
                   COUNT(*) OVER (PARTITION BY dim_0) as count_per_time
               FROM weather
               WHERE dim_0 >= 1 AND dim_1 < 2 AND "/temperature" > 15
               ORDER BY dim_0, dim_1, dim_2'''
        )

        # NOTE(review): nothing is asserted when the filter matches no rows,
        # so an empty result passes silently -- confirm the fixture produces
        # matches for these predicates.
        if len(sql_complex) > 0:
            # Every returned row must satisfy each WHERE predicate.
            self.assertTrue((sql_complex['dim_0'] >= 1).all(),
                            "Time constraint should be satisfied")
            self.assertTrue((sql_complex['dim_1'] < 2).all(),
                            "Lat constraint should be satisfied")
            self.assertTrue((sql_complex['/temperature'] > 15).all(),
                            "Temperature constraint should be satisfied")

            # The window count is per dim_0 partition, so all rows sharing a
            # dim_0 value must carry the same count_per_time.
            for _, time_rows in sql_complex.groupby('dim_0'):
                unique_counts = time_rows['count_per_time'].unique()
                self.assertEqual(
                    len(unique_counts), 1,
                    "All rows with same time should have same count_per_time",
                )
913+
914+
568915class SqlAdvancedTestCase (XarrayTestBase ):
569916 """Test SQL functionality with various types of Xarray datasets."""
570917
0 commit comments