@@ -14,14 +14,15 @@ use pyo3::types::PyCapsule;
1414use pyo3:: { pyclass, pymethods, Bound , PyResult , Python } ;
1515use std:: path:: Path ;
1616use std:: sync:: Arc ;
17+ use std:: collections:: HashSet ;
1718use async_trait:: async_trait;
1819use zarrs:: filesystem:: FilesystemStore ;
1920use zarrs:: group:: Group ;
2021use zarrs:: array:: Array ;
2122use zarrs:: array:: data_type:: DataType as ZarrDataType ;
2223use zarrs:: array_subset:: ArraySubset ;
2324use zarrs:: array:: chunk_grid:: ChunkGrid ;
24- use arrow_array:: { Int64Array , Float64Array , Int32Array , Float32Array } ;
25+ use arrow_array:: { Int64Array , Float64Array , Int32Array , Float32Array , Int16Array } ;
2526
2627/// Represents a coordinate range constraint for chunk filtering
2728#[ derive( Debug , Clone ) ]
@@ -211,6 +212,28 @@ impl ToArrowArray for i32 {
211212 }
212213}
213214
215+ impl ToArrowArray for i16 {
216+ type ArrowArray = arrow_array:: Int16Array ;
217+
218+ fn to_arrow_array (
219+ flat_data : & ndarray:: ArrayBase < ndarray:: CowRepr < ' _ , Self > , ndarray:: Dim < [ usize ; 1 ] > > ,
220+ ) -> Vec < Self > {
221+ if flat_data. is_standard_layout ( ) {
222+ flat_data. as_slice ( ) . unwrap ( ) . to_vec ( )
223+ } else {
224+ flat_data. iter ( ) . cloned ( ) . collect ( )
225+ }
226+ }
227+
228+ fn arrow_data_type ( ) -> DataType {
229+ DataType :: Int16
230+ }
231+
232+ fn from_vec ( data : Vec < Self > ) -> Arc < Self :: ArrowArray > {
233+ Arc :: new ( arrow_array:: Int16Array :: from ( data) )
234+ }
235+ }
236+
214237
215238/// A DataFusion TableProvider that reads from Zarr stores
216239#[ pyclass( name = "ZarrTableProvider" , module = "zarrquet" , subclass) ]
@@ -324,20 +347,106 @@ impl ZarrTableProvider {
324347 // Build unified schema: dimensions first, then data variables
325348 let mut fields = Vec :: new ( ) ;
326349
327- // Add coordinate/dimension fields
350+ // Add coordinate/dimension fields using actual coordinate names
328351 if let Some ( ref shape) = reference_shape {
329- for ( dim_idx, & _dim_size) in shape. iter ( ) . enumerate ( ) {
330- // TODO: Extract actual coordinate names from metadata
352+ // For typical xarray-generated Zarr, we expect coordinates in a specific order
353+ // The air dataset has dimensions (time, lat, lon)
354+ // We need to match coordinate arrays to dimensions
355+
356+ // First, collect all 1D coordinate arrays that match dimension sizes
357+ let mut valid_coords: Vec < ( String , usize ) > = Vec :: new ( ) ;
358+
359+ for ( name, coord_shape, _) in & coordinate_arrays {
360+ if coord_shape. len ( ) == 1 {
361+ let size = coord_shape[ 0 ] ;
362+ // Find all dimension indices that match this size
363+ for ( dim_idx, & dim_size) in shape. iter ( ) . enumerate ( ) {
364+ if dim_size == size {
365+ let clean_name = if name. starts_with ( '/' ) {
366+ name. chars ( ) . skip ( 1 ) . collect ( )
367+ } else {
368+ name. clone ( )
369+ } ;
370+ valid_coords. push ( ( clean_name, dim_idx) ) ;
371+ }
372+ }
373+ }
374+ }
375+
376+ // Sort by the typical xarray dimension order: time, lat, lon, etc.
377+ // We'll use a simple heuristic based on common coordinate names
378+ valid_coords. sort_by ( |a, b| {
379+ let name_a = & a. 0 ;
380+ let name_b = & b. 0 ;
381+
382+ // Define priority order for common coordinate names
383+ let priority = |name : & str | -> i32 {
384+ match name {
385+ "time" => 0 ,
386+ "lat" | "latitude" => 1 ,
387+ "lon" | "longitude" => 2 ,
388+ "level" | "lev" => 3 ,
389+ _ => 100
390+ }
391+ } ;
392+
393+ let priority_a = priority ( name_a) ;
394+ let priority_b = priority ( name_b) ;
395+
396+ if priority_a != priority_b {
397+ priority_a. cmp ( & priority_b)
398+ } else {
399+ // If same priority, sort by dimension index
400+ a. 1 . cmp ( & b. 1 )
401+ }
402+ } ) ;
403+
404+ // Add coordinate fields, avoiding duplicates
405+ let mut added_coords: HashSet < String > = HashSet :: new ( ) ;
406+
407+ for ( coord_name, _) in valid_coords {
408+ if !added_coords. contains ( & coord_name) {
409+ fields. push ( Field :: new ( coord_name. clone ( ) , DataType :: Int64 , false ) ) ;
410+ added_coords. insert ( coord_name) ;
411+ }
412+ }
413+
414+ // If we didn't find enough coordinates, add generic dimension names
415+ // Only add generic names if we have fewer coordinates than dimensions
416+ for dim_idx in added_coords. len ( ) ..shape. len ( ) {
331417 let dim_name = format ! ( "dim_{}" , dim_idx) ;
332- fields. push ( Field :: new ( dim_name, DataType :: Int64 , false ) ) ;
418+ fields. push ( Field :: new ( dim_name. clone ( ) , DataType :: Int64 , false ) ) ;
419+ added_coords. insert ( dim_name) ;
333420 }
334421 }
335422
336- // Add data variable fields
423+ // Add data variable fields (remove leading slash if present)
424+ // But exclude coordinate arrays that are already added as dimensions
425+ let coord_names: HashSet < String > = coordinate_arrays. iter ( )
426+ . map ( |( name, _, _) | {
427+ if name. starts_with ( '/' ) {
428+ name. chars ( ) . skip ( 1 ) . collect ( )
429+ } else {
430+ name. clone ( )
431+ }
432+ } )
433+ . collect ( ) ;
434+
337435 for ( var_name, _shape, data_type) in & data_variables {
338436 let arrow_type = self . zarr_type_to_arrow ( data_type)
339437 . unwrap_or ( DataType :: Float64 ) ; // Default fallback
340- fields. push ( Field :: new ( var_name. clone ( ) , arrow_type, true ) ) ;
438+
439+ // Remove leading slash from variable name if present
440+ let clean_name = if var_name. starts_with ( '/' ) {
441+ var_name. chars ( ) . skip ( 1 ) . collect ( )
442+ } else {
443+ var_name. clone ( )
444+ } ;
445+
446+ // Only add if this is not a coordinate array
447+ if !coord_names. contains ( & clean_name) {
448+ fields. push ( Field :: new ( clean_name, arrow_type, true ) ) ;
449+ }
341450 }
342451
343452 Ok ( Arc :: new ( Schema :: new ( fields) ) )
@@ -592,6 +701,11 @@ impl ZarrTableProvider {
592701 . map_err ( |e| DataFusionError :: External ( Box :: new ( e) ) ) ?;
593702 self . create_data_array_i32 ( chunk_data) ?
594703 }
704+ ZarrDataType :: Int16 => {
705+ let chunk_data = array. retrieve_chunk_ndarray :: < i16 > ( chunk_indices)
706+ . map_err ( |e| DataFusionError :: External ( Box :: new ( e) ) ) ?;
707+ self . create_data_array_i16 ( chunk_data) ?
708+ }
595709 other => {
596710 return Err ( DataFusionError :: External (
597711 format ! ( "Unsupported zarr data type for variable '{}': {:?}" , var_name, other) . into ( )
@@ -615,22 +729,8 @@ impl ZarrTableProvider {
615729 arrows. push ( data_array) ;
616730 }
617731
618- // Create the schema
619- let mut fields = Vec :: new ( ) ;
620-
621- // Add dimension fields
622- for dim_idx in 0 ..ndim {
623- fields. push ( Field :: new ( format ! ( "dim_{}" , dim_idx) , DataType :: Int64 , false ) ) ;
624- }
625-
626- // Add data variable fields
627- for ( var_name, array, _) in & arrays_data {
628- let arrow_type = self . zarr_type_to_arrow ( array. data_type ( ) )
629- . unwrap_or ( DataType :: Float64 ) ;
630- fields. push ( Field :: new ( var_name. clone ( ) , arrow_type, true ) ) ;
631- }
632-
633- let schema = Arc :: new ( Schema :: new ( fields) ) ;
732+ // Create the schema using the same logic as infer_schema
733+ let schema = self . infer_schema ( ) ?;
634734
635735 // Create the RecordBatch
636736 RecordBatch :: try_new ( schema, arrows)
@@ -688,6 +788,16 @@ impl ZarrTableProvider {
688788 Ok ( i32:: from_vec ( data_vec) as Arc < dyn arrow_array:: Array > )
689789 }
690790
791+ /// Create Arrow array from i16 ndarray
792+ fn create_data_array_i16 ( & self , data : ndarray:: ArrayD < i16 > ) -> Result < Arc < dyn arrow_array:: Array > , DataFusionError > {
793+ let total_elements = data. len ( ) ;
794+ let flat_data = data. to_shape ( total_elements)
795+ . map_err ( |e| DataFusionError :: External ( format ! ( "Failed to reshape i16 array: {}" , e) . into ( ) ) ) ?;
796+
797+ let data_vec = i16:: to_arrow_array ( & flat_data) ;
798+ Ok ( i16:: from_vec ( data_vec) as Arc < dyn arrow_array:: Array > )
799+ }
800+
691801 /// Convert a specific array chunk to RecordBatch
692802 fn array_chunk_to_record_batch (
693803 & self ,
@@ -798,12 +908,20 @@ impl ZarrTableProvider {
798908 let data_array = T :: from_vec ( data_vec) ;
799909 arrays. push ( data_array as Arc < dyn arrow_array:: Array > ) ;
800910
801- // Create the schema
911+ // Create the schema using the same logic as infer_schema
912+ // For single array case, we need to create a minimal schema
802913 let mut fields = Vec :: new ( ) ;
803914 for dim_idx in 0 ..ndim {
804915 fields. push ( Field :: new ( format ! ( "dim_{}" , dim_idx) , DataType :: Int64 , false ) ) ;
805916 }
806- fields. push ( Field :: new ( array_name, T :: arrow_data_type ( ) , true ) ) ;
917+
918+ // Remove leading slash from array name if present
919+ let clean_name = if array_name. starts_with ( '/' ) {
920+ array_name. chars ( ) . skip ( 1 ) . collect ( )
921+ } else {
922+ array_name. to_string ( )
923+ } ;
924+ fields. push ( Field :: new ( clean_name, T :: arrow_data_type ( ) , true ) ) ;
807925
808926 let schema = Arc :: new ( Schema :: new ( fields) ) ;
809927
0 commit comments