Skip to content

Commit 7c4baad

Browse files
alxmrs and claude
committed
Address self-review issues: decouple abstraction, fix filters, Arc projection.
- (#21) Define `ProjectableStream` trait so `PrunableStreamingTable` no longer stores `Arc<PyArrowStreamPartition>` directly; uses `Arc<dyn ProjectableStream>` instead, decoupling the pruning logic from the Python-specific stream type. - (#22) Forward `filters` to `StreamingTable::scan()` in both the push-projection and fallback branches (was silently passing `&[]`). - (#23) Use `Arc<[String]>` for projected column names so N partition clones share one allocation via atomic refcount increment rather than N Vec copies. - (#24) When projection contains only dimension columns (empty `data_vars_needed`), use `ds[[]]` to avoid loading data variables unnecessarily. - (#25) Add `test_count_star_passes_none_projection` to verify COUNT(*) does not push a projection to the factory (factory must receive None). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 55681fc commit 7c4baad

File tree

3 files changed

+92
-32
lines changed

3 files changed

+92
-32
lines changed

src/lib.rs

Lines changed: 64 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -158,16 +158,17 @@ impl PartitionMetadata {
158158
struct PrunableStreamingTable {
159159
schema: SchemaRef,
160160
/// Partition streams paired with their coordinate range metadata.
161-
/// Stored as the concrete type so scan() can clone them with a projection.
162-
partitions: Vec<(Arc<PyArrowStreamPartition>, PartitionMetadata)>,
161+
/// Stored behind the `ProjectableStream` trait so `PrunableStreamingTable`
162+
/// is not coupled to `PyArrowStreamPartition`.
163+
partitions: Vec<(Arc<dyn ProjectableStream>, PartitionMetadata)>,
163164
/// Set of column names that are dimension columns (eligible for pruning)
164165
dimension_columns: HashSet<String>,
165166
}
166167

167168
impl PrunableStreamingTable {
168169
fn new(
169170
schema: SchemaRef,
170-
partitions: Vec<(Arc<PyArrowStreamPartition>, PartitionMetadata)>,
171+
partitions: Vec<(Arc<dyn ProjectableStream>, PartitionMetadata)>,
171172
) -> Self {
172173
// Collect dimension column names from the first partition that has
173174
// non-empty metadata. All partitions share the same dimension names,
@@ -436,6 +437,24 @@ impl PrunableStreamingTable {
436437
}
437438
}
438439

440+
/// Extension trait for partition streams that support column projection.
441+
///
442+
/// Implemented by `PyArrowStreamPartition` so that `PrunableStreamingTable`
443+
/// can push projections to Python factories without coupling to the concrete type.
444+
/// Any new stream implementation (e.g. for non-Python backends) can implement this
445+
/// trait and be used with `PrunableStreamingTable` directly.
446+
trait ProjectableStream: PartitionStream + Debug {
447+
/// Return a new stream that emits only the specified columns.
448+
fn clone_with_projection(
449+
&self,
450+
projection: Arc<[String]>,
451+
projected_schema: SchemaRef,
452+
) -> Arc<dyn PartitionStream>;
453+
454+
/// Clone this stream as a generic `PartitionStream` Arc.
455+
fn clone_as_stream(&self) -> Arc<dyn PartitionStream>;
456+
}
457+
439458
/// Flip a comparison operator (for when literal is on left side).
440459
fn flip_operator(op: &Operator) -> Operator {
441460
match op {
@@ -581,37 +600,40 @@ impl TableProvider for PrunableStreamingTable {
581600
.collect();
582601
let projected_schema = Arc::new(Schema::new(proj_fields));
583602

584-
// Collect the requested column names to send to the factory
585-
let proj_col_names: Vec<String> = indices
603+
// Collect the requested column names to send to the factory.
604+
// Stored in an Arc so each clone_with_projection call shares the
605+
// same allocation via an atomic refcount increment (no N Vec copies).
606+
let proj_col_names: Arc<[String]> = indices
586607
.iter()
587608
.map(|&i| self.schema.field(i).name().to_string())
588-
.collect();
609+
.collect::<Vec<_>>()
610+
.into();
589611

590612
// Clone each pruned partition with the projection baked in.
591613
// The factory will receive proj_col_names and load only those vars.
592614
let projected_partitions: Vec<Arc<dyn PartitionStream>> = included_indices
593615
.iter()
594616
.map(|&idx| {
595-
Arc::new(self.partitions[idx].0.clone_with_projection(
596-
proj_col_names.clone(),
617+
self.partitions[idx].0.clone_with_projection(
618+
Arc::clone(&proj_col_names),
597619
Arc::clone(&projected_schema),
598-
)) as Arc<dyn PartitionStream>
620+
)
599621
})
600622
.collect();
601623

602624
// StreamingTable already has the projected schema — pass None for
603625
// projection so it doesn't wrap the stream in a redundant ProjectionExec.
604626
let streaming = StreamingTable::try_new(projected_schema, projected_partitions)?;
605-
streaming.scan(state, None, &[], limit).await
627+
streaming.scan(state, None, filters, limit).await
606628
} else {
607629
// No projection pushdown — factory is called with None (loads all
608630
// columns). StreamingTable applies projection via ProjectionExec.
609631
let included_partitions: Vec<Arc<dyn PartitionStream>> = included_indices
610632
.iter()
611-
.map(|&idx| Arc::clone(&self.partitions[idx].0) as Arc<dyn PartitionStream>)
633+
.map(|&idx| self.partitions[idx].0.clone_as_stream())
612634
.collect();
613635
let streaming = StreamingTable::try_new(Arc::clone(&self.schema), included_partitions)?;
614-
streaming.scan(state, projection, &[], limit).await
636+
streaming.scan(state, projection, filters, limit).await
615637
}
616638
}
617639
}
@@ -629,12 +651,13 @@ struct PyArrowStreamPartition {
629651
/// A Python callable (factory) that returns a fresh stream.
630652
/// Signature: `make_stream(projection_names: Optional[List[str]]) -> RecordBatchReader`
631653
///
632-
/// Wrapped in `Arc` so `clone_with_projection` can share the same Python
633-
/// object across projected partitions without acquiring the GIL — only an
634-
/// atomic reference-count increment is needed.
654+
/// Wrapped in `Arc` so `ProjectableStream::clone_with_projection` can share
655+
/// the same Python object across projected partitions without acquiring the
656+
/// GIL — only an atomic reference-count increment is needed.
635657
stream_factory: Arc<Py<PyAny>>,
636658
/// Column names to pass to the factory. `None` means load all columns.
637-
projection: Option<Vec<String>>,
659+
/// Stored as `Arc<[String]>` so multiple projected clones share one allocation.
660+
projection: Option<Arc<[String]>>,
638661
}
639662

640663
impl PyArrowStreamPartition {
@@ -645,18 +668,31 @@ impl PyArrowStreamPartition {
645668
projection: None,
646669
}
647670
}
671+
}
648672

649-
/// Create a new partition with a baked-in column projection.
673+
impl ProjectableStream for PyArrowStreamPartition {
674+
/// Return a new partition that emits only the given columns.
650675
///
651-
/// Clones the factory `Arc` (atomic refcount increment, no GIL) and
652-
/// uses `projected_schema` so the stream it produces has only the
653-
/// requested columns.
654-
fn clone_with_projection(&self, projection: Vec<String>, projected_schema: SchemaRef) -> Self {
655-
Self {
676+
/// Clones the factory `Arc` (atomic refcount increment, no GIL) so the
677+
/// same Python callable is shared across all projected partitions.
678+
fn clone_with_projection(
679+
&self,
680+
projection: Arc<[String]>,
681+
projected_schema: SchemaRef,
682+
) -> Arc<dyn PartitionStream> {
683+
Arc::new(Self {
656684
schema: projected_schema,
657685
stream_factory: Arc::clone(&self.stream_factory),
658686
projection: Some(projection),
659-
}
687+
})
688+
}
689+
690+
fn clone_as_stream(&self) -> Arc<dyn PartitionStream> {
691+
Arc::new(Self {
692+
schema: Arc::clone(&self.schema),
693+
stream_factory: Arc::clone(&self.stream_factory),
694+
projection: self.projection.clone(),
695+
})
660696
}
661697
}
662698

@@ -834,9 +870,9 @@ impl LazyArrowStreamTable {
834870
// eliminating the per-partition Python::attach() calls of the old
835871
// three-list approach. Python can release each block dict, factory
836872
// closure, and metadata dict as soon as Rust has ingested them.
837-
// Stored as Arc<PyArrowStreamPartition> (not erased to dyn PartitionStream)
838-
// so that scan() can clone them with a projection at query time.
839-
let mut partition_list: Vec<(Arc<PyArrowStreamPartition>, PartitionMetadata)> = Vec::new();
873+
// Stored as Arc<dyn ProjectableStream> so PrunableStreamingTable
874+
// is decoupled from PyArrowStreamPartition.
875+
let mut partition_list: Vec<(Arc<dyn ProjectableStream>, PartitionMetadata)> = Vec::new();
840876
for item_result in partitions.try_iter()? {
841877
let item = item_result?;
842878
let (factory_obj, meta_obj): (Py<PyAny>, Py<PyAny>) = item.extract().map_err(|e| {
@@ -845,7 +881,8 @@ impl LazyArrowStreamTable {
845881
))
846882
})?;
847883
let meta = convert_python_metadata_from_bound(meta_obj.bind(partitions.py()))?;
848-
let partition = Arc::new(PyArrowStreamPartition::new(factory_obj, schema_ref.clone()));
884+
let partition: Arc<dyn ProjectableStream> =
885+
Arc::new(PyArrowStreamPartition::new(factory_obj, schema_ref.clone()));
849886
partition_list.push((partition, meta));
850887
}
851888

xarray_sql/reader.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -270,11 +270,12 @@ def make_stream(
270270
# Restrict to the data variables mentioned in the projection.
271271
# Dimension coordinates come along automatically via coords.
272272
data_vars_needed = [c for c in projection_names if c in data_var_names]
273-
ds_block = (
274-
ds[data_vars_needed].isel(block)
275-
if data_vars_needed
276-
else ds.isel(block)
277-
)
273+
if data_vars_needed:
274+
ds_block = ds[data_vars_needed].isel(block)
275+
else:
276+
# Only dimension coords requested — drop all data vars to avoid
277+
# loading them unnecessarily (e.g. for queries like SELECT lat, lon).
278+
ds_block = ds[[]].isel(block)
278279
batch_schema = pa.schema(
279280
[schema.field(name) for name in projection_names]
280281
)

xarray_sql/reader_test.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1358,3 +1358,25 @@ def test_projection_result_correctness(self, two_var_ds):
13581358
assert (
13591359
abs(projected - expected) < 1e-4
13601360
), f"Projected AVG {projected} differs from expected {expected}"
1361+
1362+
def test_count_star_passes_none_projection(self, two_var_ds):
1363+
"""COUNT(*) should not push a projection — factory receives None."""
1364+
projections_seen = []
1365+
1366+
def callback(block, projection_names):
1367+
projections_seen.append(projection_names)
1368+
1369+
table = read_xarray_table(
1370+
two_var_ds,
1371+
chunks={"time": 5},
1372+
_iteration_callback=callback,
1373+
)
1374+
ctx = SessionContext()
1375+
ctx.register_table("data", table)
1376+
result = ctx.sql("SELECT COUNT(*) FROM data").to_arrow_table()
1377+
1378+
total_rows = 10 * 5 # time=10, lat=5
1379+
assert result[0][0].as_py() == total_rows
1380+
assert all(
1381+
p is None for p in projections_seen
1382+
), f"COUNT(*) should not push a projection, but got: {projections_seen}"

0 commit comments

Comments (0)