Skip to content

Commit b5242e3

Browse files
authored
Iterate through each block as multiple record batches per partition. (#135)
This reduces peak memory and makes processing more parallelizable across partitions. Fixes #128.
1 parent 77d29f1 commit b5242e3

File tree

3 files changed

+165
-21
lines changed

3 files changed

+165
-21
lines changed

xarray_sql/df.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,82 @@ def dataset_to_record_batch(
209209
return pa.RecordBatch.from_arrays(arrays, schema=schema)
210210

211211

212+
#: Default number of rows per emitted Arrow RecordBatch.
#: 64 K rows balances DataFusion pipeline depth against per-batch overhead.
DEFAULT_BATCH_SIZE: int = 65_536


def iter_record_batches(
    ds: xr.Dataset,
    schema: pa.Schema,
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> Iterator[pa.RecordBatch]:
  """Yield RecordBatches of at most *batch_size* rows from a partition Dataset.

  Unlike :func:`dataset_to_record_batch`, which materialises the entire
  partition as one batch, this generator emits smaller batches so that
  DataFusion can begin filtering and aggregating before the full partition
  is loaded. Peak memory per batch is O(batch_size) for coordinate columns
  and O(partition_size) for data-variable columns (which must be loaded in
  full from storage).

  Coordinate values are computed per batch via strided index arithmetic —
  no broadcast array spanning the whole partition is ever allocated. Data
  variable flat arrays are loaded once (triggering any remote I/O) and then
  sliced as zero-copy views for each batch.

  Args:
    ds: A partition-sized xarray Dataset (already sliced via isel).
    schema: The Arrow schema for the output, as produced by _parse_schema.
    batch_size: Maximum number of rows per yielded RecordBatch. Must be >= 1.

  Yields:
    RecordBatches in schema column order, covering all rows of the
    partition exactly once.

  Raises:
    ValueError: If *batch_size* is smaller than 1.
  """
  # A non-positive batch_size would either raise an opaque error from
  # range() (0) or silently yield nothing (negative) — fail loudly instead.
  if batch_size < 1:
    raise ValueError(f'batch_size must be >= 1, got {batch_size!r}')

  if ds.data_vars:
    first_var = next(iter(ds.data_vars.values()))
    dim_names = list(first_var.dims)
    shape = first_var.shape
  else:
    dim_names = list(ds.sizes.keys())
    shape = tuple(ds.sizes[d] for d in dim_names)

  total_rows = int(np.prod(shape))

  def _is_dim_coord(name: str) -> bool:
    # Columns served by strided coordinate arithmetic. Everything else —
    # data variables, non-dimension coordinates, and dimensions *without*
    # a coordinate variable — goes through the flattened-array path.
    return name in ds.coords and name in ds.dims

  # Preload small 1-D coordinate arrays (negligible memory). Guarded so a
  # coordinate-less dimension doesn't raise a KeyError here even when the
  # schema never references it.
  coord_values = {
      name: ds.coords[name].values for name in dim_names if name in ds.coords
  }

  # C-order stride for each dimension: stride[k] = prod(shape[k+1:]).
  # Flat row index i → coordinate index for dim k: (i // stride[k]) % shape[k].
  strides = [int(np.prod(shape[k + 1 :])) for k in range(len(shape))]

  # Load non-coordinate arrays fully (triggers Dask/Zarr compute once).
  # ravel() is a zero-copy view for C-contiguous arrays. The predicate must
  # mirror the per-batch branch below — the original `not in ds.dims` check
  # skipped dimensions lacking a coordinate variable, which then failed with
  # a confusing KeyError inside the batch loop.
  data_arrays = {}
  for field in schema:
    if not _is_dim_coord(field.name):
      data_arrays[field.name] = ds[field.name].values.ravel()

  for row_start in range(0, total_rows, batch_size):
    row_end = min(row_start + batch_size, total_rows)
    row_idx = np.arange(row_start, row_end)

    arrays = []
    for field in schema:
      name = field.name
      if _is_dim_coord(name):
        k = dim_names.index(name)
        coord_idx = (row_idx // strides[k]) % shape[k]
        arrays.append(pa.array(coord_values[name][coord_idx], type=field.type))
      else:
        arrays.append(
            pa.array(data_arrays[name][row_start:row_end], type=field.type)
        )

    yield pa.RecordBatch.from_arrays(arrays, schema=schema)
212288
def _parse_schema(ds) -> pa.Schema:
213289
"""Extracts a `pa.Schema` from the Dataset, treating dims and data_vars as columns."""
214290
columns = []

xarray_sql/df_test.py

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,17 @@
88
import xarray as xr
99

1010
from .reader import read_xarray
11-
from .df import explode, block_slices, dataset_to_record_batch, from_map, pivot, from_map_batched, _parse_schema
11+
from .df import (
12+
DEFAULT_BATCH_SIZE,
13+
_parse_schema,
14+
block_slices,
15+
dataset_to_record_batch,
16+
explode,
17+
from_map,
18+
from_map_batched,
19+
iter_record_batches,
20+
pivot,
21+
)
1222

1323

1424
def rand_wx(start: str, end: str) -> xr.Dataset:
@@ -177,6 +187,55 @@ def make_arrow_table(x):
177187
assert len(result) == 3
178188

179189

190+
def test_iter_record_batches_splits_into_multiple_batches(air_small):
  """iter_record_batches should emit >1 batch when partition exceeds batch_size."""
  schema = _parse_schema(air_small)
  first_block = next(
      block_slices(air_small, chunks={'time': 4, 'lat': 3, 'lon': 4})
  )
  partition = air_small.isel(first_block)

  rows_in_partition = 1
  for dim in partition.sizes:
    rows_in_partition *= int(partition.sizes[dim])

  limit = 16  # force many small batches
  emitted = list(iter_record_batches(partition, schema, batch_size=limit))

  # Ceiling division without math.ceil: expected number of batches.
  assert len(emitted) == (rows_in_partition + limit - 1) // limit
  for batch in emitted:
    assert batch.num_rows <= limit
  assert sum(batch.num_rows for batch in emitted) == rows_in_partition
205+
def test_iter_record_batches_matches_dataset_to_record_batch(air_small):
  """Concatenating all iter_record_batches output must equal dataset_to_record_batch."""
  schema = _parse_schema(air_small)
  dim_cols = [field.name for field in schema if field.name in air_small.dims]
  first_block = next(
      block_slices(air_small, chunks={'time': 4, 'lat': 3, 'lon': 4})
  )
  partition = air_small.isel(first_block)

  def _as_sorted_df(table_like):
    # Normalise row order so batch boundaries can't affect the comparison.
    return table_like.to_pandas().sort_values(dim_cols).reset_index(drop=True)

  chunked = pa.Table.from_batches(
      list(iter_record_batches(partition, schema, batch_size=16))
  )
  single = dataset_to_record_batch(partition, schema)

  pd.testing.assert_frame_equal(_as_sorted_df(chunked), _as_sorted_df(single))
228+
def test_iter_record_batches_default_batch_size():
  """A single-batch partition (rows <= DEFAULT_BATCH_SIZE) yields exactly one batch."""
  # Build a small synthetic dataset instead of xr.tutorial.open_dataset,
  # which downloads the tutorial file over the network at test time —
  # slow, flaky offline, and dependent on an on-disk cache.
  times = pd.date_range('2000-01-01', periods=2)
  lats = np.arange(25.0, 50.0, 1.0)
  lons = np.arange(200.0, 230.0, 1.0)
  ds = xr.Dataset(
      {
          'air': (
              ('time', 'lat', 'lon'),
              np.zeros((len(times), len(lats), len(lons))),
          )
      },
      coords={'time': times, 'lat': lats, 'lon': lons},
  )
  schema = _parse_schema(ds)
  total_rows = int(np.prod([ds.sizes[d] for d in ds.sizes]))
  assert total_rows <= DEFAULT_BATCH_SIZE, 'fixture too large — adjust sizes'
  batches = list(iter_record_batches(ds, schema))
  assert len(batches) == 1
  assert batches[0].num_rows == total_rows
180239
def test_dataset_to_record_batch_matches_pivot(air_small):
181240
"""dataset_to_record_batch should contain the same rows as pivot.
182241
@@ -371,20 +430,19 @@ def test_read_xarray_loads_one_chunk_at_a_time(large_ds):
371430
sizes.append(cur_size)
372431
peaks.append(cur_peak)
373432

374-
mean_size = np.mean(sizes)
375-
mean_peak = np.mean(peaks)
376-
377433
for size in sizes:
378-
assert mean_size * 1.1 > size
379-
assert chunk_size * 3 > size
380-
assert chunk_size * 2 < size
434+
# Observed range: 1.59–1.83× chunk_size.
435+
# iter_record_batches holds data-variable arrays (≈1× chunk) while
436+
# yielding sub-batches, plus the current Arrow batch (≈0.65× chunk).
437+
assert chunk_size * 1.3 < size, f"size {size} unexpectedly low"
438+
assert chunk_size * 2.2 > size, f"size {size} unexpectedly high"
381439

382440
for peak in peaks:
383-
assert mean_peak * 1.1 > peak
384-
assert chunk_size * 7 > peak
385-
# Lower bound: at least chunk + Arrow output must be allocated.
386-
# The numpy-direct path peaks at ~2.5x (vs ~5x for the old pandas path).
387-
assert chunk_size * 1.5 < peak
441+
# Observed range: 1.84–3.28× chunk_size.
442+
# Peak includes data arrays + Arrow batch + temporary coordinate index
443+
# arrays; the first batch of each chunk is highest (Dask compute overhead).
444+
assert chunk_size * 1.5 < peak, f"peak {peak} unexpectedly low"
445+
assert chunk_size * 4.0 > peak, f"peak {peak} unexpectedly high"
388446

389447
assert max(peaks) < large_ds.nbytes
390448

xarray_sql/reader.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
from .df import (
2020
Block,
2121
Chunks,
22+
DEFAULT_BATCH_SIZE,
2223
_parse_schema,
2324
block_slices,
24-
dataset_to_record_batch,
25+
iter_record_batches,
2526
partition_metadata,
2627
)
2728

@@ -61,6 +62,7 @@ def __init__(
6162
ds: xr.Dataset,
6263
chunks: Chunks = None,
6364
*,
65+
batch_size: int = DEFAULT_BATCH_SIZE,
6466
_iteration_callback: Callable[[Block], None] | None = None,
6567
):
6668
"""Initialize the lazy reader.
@@ -69,12 +71,16 @@ def __init__(
6971
ds: An xarray Dataset. All data_vars must share the same dimensions.
7072
chunks: Xarray-like chunks specification. If not provided, uses
7173
the Dataset's existing chunks.
74+
batch_size: Maximum rows per emitted Arrow RecordBatch. Smaller
75+
values let DataFusion start processing earlier at the cost of
76+
more Python→Arrow conversion calls.
7277
_iteration_callback: Internal callback for testing. Called with
7378
each block dict just before it's converted to Arrow. This
7479
allows tests to track when iteration actually occurs.
7580
"""
7681
self._ds = ds
7782
self._chunks = chunks
83+
self._batch_size = batch_size
7884
self._schema = _parse_schema(ds)
7985
self._iteration_callback = _iteration_callback
8086
self._consumed = False
@@ -95,16 +101,17 @@ def _generate_batches(self) -> Iterator[pa.RecordBatch]:
95101
"""Generate RecordBatches lazily from xarray blocks.
96102
97103
This generator is only consumed when the Arrow stream's get_next
98-
is called, ensuring true lazy evaluation.
104+
is called, ensuring true lazy evaluation. Each xarray block is
105+
emitted as one or more RecordBatches of at most self._batch_size rows.
99106
"""
100107
for block in block_slices(self._ds, self._chunks):
101108
# Call the iteration callback if provided (for testing)
102109
if self._iteration_callback is not None:
103110
self._iteration_callback(block)
104111

105-
# Convert this block to a RecordBatch directly from numpy arrays,
106-
# bypassing the pandas round-trip for lower peak memory usage.
107-
yield dataset_to_record_batch(self._ds.isel(block), self._schema)
112+
yield from iter_record_batches(
113+
self._ds.isel(block), self._schema, self._batch_size
114+
)
108115

109116
def __arrow_c_stream__(
110117
self, requested_schema: object | None = None
@@ -179,6 +186,7 @@ def read_xarray_table(
179186
ds: xr.Dataset,
180187
chunks: Chunks = None,
181188
*,
189+
batch_size: int = DEFAULT_BATCH_SIZE,
182190
_iteration_callback: Callable[[Block], None] | None = None,
183191
) -> "LazyArrowStreamTable":
184192
"""Create a lazy DataFusion table from an xarray Dataset.
@@ -208,6 +216,9 @@ def read_xarray_table(
208216
ds: An xarray Dataset. All data_vars must share the same dimensions.
209217
chunks: Xarray-like chunks specification. If not provided, uses
210218
the Dataset's existing chunks.
219+
batch_size: Maximum rows per Arrow RecordBatch emitted per partition.
220+
Smaller values let DataFusion start processing earlier; the default
221+
(65 536) works well for most datasets.
211222
_iteration_callback: Internal callback for testing. Called with
212223
each block dict just before it's converted to Arrow.
213224
@@ -253,10 +264,9 @@ def make_stream() -> pa.RecordBatchReader:
253264
if _iteration_callback is not None:
254265
_iteration_callback(block)
255266

256-
# Convert this block to Arrow directly from numpy arrays,
257-
# bypassing the pandas round-trip for lower peak memory usage.
258-
batch = dataset_to_record_batch(ds.isel(block), schema)
259-
return pa.RecordBatchReader.from_batches(schema, [batch])
267+
return pa.RecordBatchReader.from_batches(
268+
schema, iter_record_batches(ds.isel(block), schema, batch_size)
269+
)
260270

261271
return make_stream
262272

0 commit comments

Comments
 (0)