Skip to content

Commit 38f2796

Browse files
committed
Fixed signficant correctness bugs in the tabel provider impl related to corrent column names as well as correct scanning (type handing).
1 parent 27baf67 commit 38f2796

File tree

1 file changed

+143
-25
lines changed

1 file changed

+143
-25
lines changed

src/table_provider.rs

Lines changed: 143 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ use pyo3::types::PyCapsule;
1414
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
1515
use std::path::Path;
1616
use std::sync::Arc;
17+
use std::collections::HashSet;
1718
use async_trait::async_trait;
1819
use zarrs::filesystem::FilesystemStore;
1920
use zarrs::group::Group;
2021
use zarrs::array::Array;
2122
use zarrs::array::data_type::DataType as ZarrDataType;
2223
use zarrs::array_subset::ArraySubset;
2324
use zarrs::array::chunk_grid::ChunkGrid;
24-
use arrow_array::{Int64Array, Float64Array, Int32Array, Float32Array};
25+
use arrow_array::{Int64Array, Float64Array, Int32Array, Float32Array, Int16Array};
2526

2627
/// Represents a coordinate range constraint for chunk filtering
2728
#[derive(Debug, Clone)]
@@ -211,6 +212,28 @@ impl ToArrowArray for i32 {
211212
}
212213
}
213214

215+
impl ToArrowArray for i16 {
216+
type ArrowArray = arrow_array::Int16Array;
217+
218+
fn to_arrow_array(
219+
flat_data: &ndarray::ArrayBase<ndarray::CowRepr<'_, Self>, ndarray::Dim<[usize; 1]>>,
220+
) -> Vec<Self> {
221+
if flat_data.is_standard_layout() {
222+
flat_data.as_slice().unwrap().to_vec()
223+
} else {
224+
flat_data.iter().cloned().collect()
225+
}
226+
}
227+
228+
fn arrow_data_type() -> DataType {
229+
DataType::Int16
230+
}
231+
232+
fn from_vec(data: Vec<Self>) -> Arc<Self::ArrowArray> {
233+
Arc::new(arrow_array::Int16Array::from(data))
234+
}
235+
}
236+
214237

215238
/// A DataFusion TableProvider that reads from Zarr stores
216239
#[pyclass(name = "ZarrTableProvider", module = "zarrquet", subclass)]
@@ -324,20 +347,106 @@ impl ZarrTableProvider {
324347
// Build unified schema: dimensions first, then data variables
325348
let mut fields = Vec::new();
326349

327-
// Add coordinate/dimension fields
350+
// Add coordinate/dimension fields using actual coordinate names
328351
if let Some(ref shape) = reference_shape {
329-
for (dim_idx, &_dim_size) in shape.iter().enumerate() {
330-
// TODO: Extract actual coordinate names from metadata
352+
// For typical xarray-generated Zarr, we expect coordinates in a specific order
353+
// The air dataset has dimensions (time, lat, lon)
354+
// We need to match coordinate arrays to dimensions
355+
356+
// First, collect all 1D coordinate arrays that match dimension sizes
357+
let mut valid_coords: Vec<(String, usize)> = Vec::new();
358+
359+
for (name, coord_shape, _) in &coordinate_arrays {
360+
if coord_shape.len() == 1 {
361+
let size = coord_shape[0];
362+
// Find all dimension indices that match this size
363+
for (dim_idx, &dim_size) in shape.iter().enumerate() {
364+
if dim_size == size {
365+
let clean_name = if name.starts_with('/') {
366+
name.chars().skip(1).collect()
367+
} else {
368+
name.clone()
369+
};
370+
valid_coords.push((clean_name, dim_idx));
371+
}
372+
}
373+
}
374+
}
375+
376+
// Sort by the typical xarray dimension order: time, lat, lon, etc.
377+
// We'll use a simple heuristic based on common coordinate names
378+
valid_coords.sort_by(|a, b| {
379+
let name_a = &a.0;
380+
let name_b = &b.0;
381+
382+
// Define priority order for common coordinate names
383+
let priority = |name: &str| -> i32 {
384+
match name {
385+
"time" => 0,
386+
"lat" | "latitude" => 1,
387+
"lon" | "longitude" => 2,
388+
"level" | "lev" => 3,
389+
_ => 100
390+
}
391+
};
392+
393+
let priority_a = priority(name_a);
394+
let priority_b = priority(name_b);
395+
396+
if priority_a != priority_b {
397+
priority_a.cmp(&priority_b)
398+
} else {
399+
// If same priority, sort by dimension index
400+
a.1.cmp(&b.1)
401+
}
402+
});
403+
404+
// Add coordinate fields, avoiding duplicates
405+
let mut added_coords: HashSet<String> = HashSet::new();
406+
407+
for (coord_name, _) in valid_coords {
408+
if !added_coords.contains(&coord_name) {
409+
fields.push(Field::new(coord_name.clone(), DataType::Int64, false));
410+
added_coords.insert(coord_name);
411+
}
412+
}
413+
414+
// If we didn't find enough coordinates, add generic dimension names
415+
// Only add generic names if we have fewer coordinates than dimensions
416+
for dim_idx in added_coords.len()..shape.len() {
331417
let dim_name = format!("dim_{}", dim_idx);
332-
fields.push(Field::new(dim_name, DataType::Int64, false));
418+
fields.push(Field::new(dim_name.clone(), DataType::Int64, false));
419+
added_coords.insert(dim_name);
333420
}
334421
}
335422

336-
// Add data variable fields
423+
// Add data variable fields (remove leading slash if present)
424+
// But exclude coordinate arrays that are already added as dimensions
425+
let coord_names: HashSet<String> = coordinate_arrays.iter()
426+
.map(|(name, _, _)| {
427+
if name.starts_with('/') {
428+
name.chars().skip(1).collect()
429+
} else {
430+
name.clone()
431+
}
432+
})
433+
.collect();
434+
337435
for (var_name, _shape, data_type) in &data_variables {
338436
let arrow_type = self.zarr_type_to_arrow(data_type)
339437
.unwrap_or(DataType::Float64); // Default fallback
340-
fields.push(Field::new(var_name.clone(), arrow_type, true));
438+
439+
// Remove leading slash from variable name if present
440+
let clean_name = if var_name.starts_with('/') {
441+
var_name.chars().skip(1).collect()
442+
} else {
443+
var_name.clone()
444+
};
445+
446+
// Only add if this is not a coordinate array
447+
if !coord_names.contains(&clean_name) {
448+
fields.push(Field::new(clean_name, arrow_type, true));
449+
}
341450
}
342451

343452
Ok(Arc::new(Schema::new(fields)))
@@ -592,6 +701,11 @@ impl ZarrTableProvider {
592701
.map_err(|e| DataFusionError::External(Box::new(e)))?;
593702
self.create_data_array_i32(chunk_data)?
594703
}
704+
ZarrDataType::Int16 => {
705+
let chunk_data = array.retrieve_chunk_ndarray::<i16>(chunk_indices)
706+
.map_err(|e| DataFusionError::External(Box::new(e)))?;
707+
self.create_data_array_i16(chunk_data)?
708+
}
595709
other => {
596710
return Err(DataFusionError::External(
597711
format!("Unsupported zarr data type for variable '{}': {:?}", var_name, other).into()
@@ -615,22 +729,8 @@ impl ZarrTableProvider {
615729
arrows.push(data_array);
616730
}
617731

618-
// Create the schema
619-
let mut fields = Vec::new();
620-
621-
// Add dimension fields
622-
for dim_idx in 0..ndim {
623-
fields.push(Field::new(format!("dim_{}", dim_idx), DataType::Int64, false));
624-
}
625-
626-
// Add data variable fields
627-
for (var_name, array, _) in &arrays_data {
628-
let arrow_type = self.zarr_type_to_arrow(array.data_type())
629-
.unwrap_or(DataType::Float64);
630-
fields.push(Field::new(var_name.clone(), arrow_type, true));
631-
}
632-
633-
let schema = Arc::new(Schema::new(fields));
732+
// Create the schema using the same logic as infer_schema
733+
let schema = self.infer_schema()?;
634734

635735
// Create the RecordBatch
636736
RecordBatch::try_new(schema, arrows)
@@ -688,6 +788,16 @@ impl ZarrTableProvider {
688788
Ok(i32::from_vec(data_vec) as Arc<dyn arrow_array::Array>)
689789
}
690790

791+
/// Create Arrow array from i16 ndarray
792+
fn create_data_array_i16(&self, data: ndarray::ArrayD<i16>) -> Result<Arc<dyn arrow_array::Array>, DataFusionError> {
793+
let total_elements = data.len();
794+
let flat_data = data.to_shape(total_elements)
795+
.map_err(|e| DataFusionError::External(format!("Failed to reshape i16 array: {}", e).into()))?;
796+
797+
let data_vec = i16::to_arrow_array(&flat_data);
798+
Ok(i16::from_vec(data_vec) as Arc<dyn arrow_array::Array>)
799+
}
800+
691801
/// Convert a specific array chunk to RecordBatch
692802
fn array_chunk_to_record_batch(
693803
&self,
@@ -798,12 +908,20 @@ impl ZarrTableProvider {
798908
let data_array = T::from_vec(data_vec);
799909
arrays.push(data_array as Arc<dyn arrow_array::Array>);
800910

801-
// Create the schema
911+
// Create the schema using the same logic as infer_schema
912+
// For single array case, we need to create a minimal schema
802913
let mut fields = Vec::new();
803914
for dim_idx in 0..ndim {
804915
fields.push(Field::new(format!("dim_{}", dim_idx), DataType::Int64, false));
805916
}
806-
fields.push(Field::new(array_name, T::arrow_data_type(), true));
917+
918+
// Remove leading slash from array name if present
919+
let clean_name = if array_name.starts_with('/') {
920+
array_name.chars().skip(1).collect()
921+
} else {
922+
array_name.to_string()
923+
};
924+
fields.push(Field::new(clean_name, T::arrow_data_type(), true));
807925

808926
let schema = Arc::new(Schema::new(fields));
809927

0 commit comments

Comments
 (0)