Skip to content

Commit d055c12

Browse files
committed
Convert an Xarray Dataset partition to a pyarrow record batch directly (without going through pandas).
1 parent 43a8aac commit d055c12

File tree

3 files changed

+125
-14
lines changed

3 files changed

+125
-14
lines changed

xarray_sql/df.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import itertools
2-
import warnings
32
from collections.abc import Callable, Hashable, Iterator, Mapping
3+
from typing import Any
44

55
import numpy as np
66
import pandas as pd
77
import pyarrow as pa
88
import xarray as xr
9-
from datafusion.context import ArrowStreamExportable
109

1110
Block = dict[Hashable, slice]
1211
Chunks = dict[str, int] | None
@@ -155,6 +154,61 @@ def pivot(ds: xr.Dataset) -> pd.DataFrame:
155154
return ds.to_dataframe().reset_index() # type: ignore[no-any-return]
156155

157156

157+
def dataset_to_record_batch(
    ds: xr.Dataset, schema: pa.Schema
) -> pa.RecordBatch:
  """Build an Arrow RecordBatch for one xarray Dataset partition.

  Skips the pandas round-trip (to_dataframe → reset_index → from_pandas)
  that pivot() performs and constructs the batch straight from numpy
  arrays, cutting peak memory for a large partition from roughly 5x to
  roughly 2x the partition size.

  Each dimension coordinate is broadcast to the full partition shape and
  flattened: np.broadcast_to() itself is zero-copy, but the subsequent
  ravel() must copy once per coordinate because the broadcast view is not
  contiguous. Data variables are flattened in place — a zero-copy view
  whenever the backing array is already C-contiguous, which is the common
  case for numpy-backed xarray datasets.

  Args:
    ds: A partition-sized xarray Dataset (already sliced via isel).
    schema: The Arrow schema for the output, as produced by _parse_schema.
      Output column order follows the schema's field order.

  Returns:
    A RecordBatch holding one column per dimension coordinate and data
    variable, ordered per the schema.
  """
  # Take the dimension layout from a data variable so that coordinate
  # broadcasting and data-variable flattening agree on axis order. All
  # data variables are validated elsewhere to share one dims tuple.
  if ds.data_vars:
    template = next(iter(ds.data_vars.values()))
    dims = list(template.dims)
    full_shape = template.shape
  else:
    dims = list(ds.sizes.keys())
    full_shape = tuple(ds.sizes[d] for d in dims)

  def _column(field: pa.Field) -> pa.Array:
    """Materialize one schema field as a flat Arrow array."""
    key = field.name
    if key in ds.coords and key in ds.dims:
      # 1-D dim coordinate: reshape to rank-N with singleton axes, then
      # broadcast across the whole partition and flatten.
      values = ds.coords[key].values
      target = [1] * len(full_shape)
      target[dims.index(key)] = values.shape[0]
      flat = np.broadcast_to(values.reshape(target), full_shape).ravel()
      return pa.array(flat, type=field.type)
    # Data variable: flatten to 1-D (no copy when C-contiguous).
    return pa.array(ds[key].values.ravel(), type=field.type)

  return pa.RecordBatch.from_arrays([_column(f) for f in schema], schema=schema)
210+
211+
158212
def _parse_schema(ds) -> pa.Schema:
159213
"""Extracts a `pa.Schema` from the Dataset, treating dims and data_vars as columns."""
160214
columns = []
@@ -173,12 +227,12 @@ def _parse_schema(ds) -> pa.Schema:
173227

174228

175229
# Type alias for partition metadata: maps dimension name to (min, max, dtype_str) values
176-
PartitionBounds = t.Dict[str, t.Tuple[t.Any, t.Any, str]]
230+
PartitionBounds = dict[str, tuple[Any, Any, str]]
177231

178232

179233
def partition_metadata(
180-
ds: xr.Dataset, blocks: t.List[Block]
181-
) -> t.List[PartitionBounds]:
234+
ds: xr.Dataset, blocks: list[Block]
235+
) -> list[PartitionBounds]:
182236
"""Compute min/max coordinate values for each partition.
183237
184238
This metadata enables filter pushdown: SQL queries with WHERE clauses

xarray_sql/df_test.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import xarray as xr
99

1010
from .reader import read_xarray
11-
from .df import explode, block_slices, from_map, pivot, from_map_batched
11+
from .df import explode, block_slices, dataset_to_record_batch, from_map, pivot, from_map_batched, _parse_schema
1212

1313

1414
def rand_wx(start: str, end: str) -> xr.Dataset:
@@ -177,6 +177,54 @@ def make_arrow_table(x):
177177
assert len(result) == 3
178178

179179

180+
def test_dataset_to_record_batch_matches_pivot(air_small):
  """dataset_to_record_batch should contain the same rows as pivot.

  Row ordering may differ (pivot uses ds.dims key order; dataset_to_record_batch
  uses the data variable's own dim order). Both orderings are valid for SQL, so
  we sort by the coordinate columns before comparing.
  """
  schema = _parse_schema(air_small)
  dim_cols = [f.name for f in schema if f.name in air_small.dims]
  chunks = {"time": 4, "lat": 3, "lon": 4}

  def normalize(batch: pa.RecordBatch) -> pd.DataFrame:
    # Canonical row order so the two construction paths are comparable.
    return batch.to_pandas().sort_values(dim_cols).reset_index(drop=True)

  for block in block_slices(air_small, chunks=chunks):
    ds_block = air_small.isel(block)
    actual_df = normalize(dataset_to_record_batch(ds_block, schema))
    expected_df = normalize(
        pa.RecordBatch.from_pandas(pivot(ds_block), schema=schema)
    )
    pd.testing.assert_frame_equal(actual_df, expected_df, check_like=False)
207+
208+
209+
def test_dataset_to_record_batch_column_order(air_small):
  """Output column order must match schema (dims first, then data vars)."""
  schema = _parse_schema(air_small)
  first_block = next(
      block_slices(air_small, chunks={"time": 4, "lat": 3, "lon": 4})
  )
  result = dataset_to_record_batch(air_small.isel(first_block), schema)
  assert result.schema.names == schema.names
215+
216+
217+
def test_dataset_to_record_batch_row_count(air_small):
  """Row count must equal the product of the block dimension sizes."""
  schema = _parse_schema(air_small)
  for block in block_slices(air_small, chunks={"time": 4, "lat": 3, "lon": 4}):
    ds_block = air_small.isel(block)
    expected_rows = math.prod(ds_block.sizes[d] for d in ds_block.sizes)
    assert dataset_to_record_batch(ds_block, schema).num_rows == expected_rows
226+
227+
180228
def test_from_map_batched_basic_functionality(air_small):
181229
blocks = list(block_slices(air_small, chunks={"time": 4, "lat": 3, "lon": 4}))
182230

@@ -334,7 +382,9 @@ def test_read_xarray_loads_one_chunk_at_a_time(large_ds):
334382
for peak in peaks:
335383
assert mean_peak * 1.1 > peak
336384
assert chunk_size * 7 > peak
337-
assert chunk_size * 4 < peak
385+
# Lower bound: at least chunk + Arrow output must be allocated.
386+
# The numpy-direct path peaks at ~2.5x (vs ~5x for the old pandas path).
387+
assert chunk_size * 1.5 < peak
338388

339389
assert max(peaks) < large_ds.nbytes
340390

xarray_sql/reader.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
import pyarrow as pa
1717
import xarray as xr
1818

19-
from .df import Block, Chunks, block_slices, partition_metadata, pivot, _parse_schema
19+
from .df import (
20+
Block,
21+
Chunks,
22+
_parse_schema,
23+
block_slices,
24+
dataset_to_record_batch,
25+
partition_metadata,
26+
)
2027

2128
if TYPE_CHECKING:
2229
from ._native import LazyArrowStreamTable
@@ -95,9 +102,9 @@ def _generate_batches(self) -> Iterator[pa.RecordBatch]:
95102
if self._iteration_callback is not None:
96103
self._iteration_callback(block)
97104

98-
# Convert this block to a RecordBatch
99-
df = pivot(self._ds.isel(block))
100-
yield pa.RecordBatch.from_pandas(df, schema=self._schema)
105+
# Convert this block to a RecordBatch directly from numpy arrays,
106+
# bypassing the pandas round-trip for lower peak memory usage.
107+
yield dataset_to_record_batch(self._ds.isel(block), self._schema)
101108

102109
def __arrow_c_stream__(
103110
self, requested_schema: object | None = None
@@ -246,9 +253,9 @@ def make_stream() -> pa.RecordBatchReader:
246253
if _iteration_callback is not None:
247254
_iteration_callback(block)
248255

249-
# Extract just this block from the dataset and convert to Arrow
250-
df = pivot(ds.isel(block))
251-
batch = pa.RecordBatch.from_pandas(df, schema=schema)
256+
# Convert this block to Arrow directly from numpy arrays,
257+
# bypassing the pandas round-trip for lower peak memory usage.
258+
batch = dataset_to_record_batch(ds.isel(block), schema)
252259
return pa.RecordBatchReader.from_batches(schema, [batch])
253260

254261
return make_stream

0 commit comments

Comments
 (0)