11import itertools
2- import warnings
32from collections .abc import Callable , Hashable , Iterator , Mapping
3+ from typing import Any
44
55import numpy as np
66import pandas as pd
77import pyarrow as pa
88import xarray as xr
9- from datafusion .context import ArrowStreamExportable
109
1110Block = dict [Hashable , slice ]
1211Chunks = dict [str , int ] | None
@@ -155,6 +154,61 @@ def pivot(ds: xr.Dataset) -> pd.DataFrame:
155154 return ds .to_dataframe ().reset_index () # type: ignore[no-any-return]
156155
157156
def dataset_to_record_batch(
    ds: xr.Dataset, schema: pa.Schema
) -> pa.RecordBatch:
    """Build an Arrow RecordBatch from a partition-sized xarray Dataset.

    Constructs each Arrow column directly from the underlying numpy
    arrays, skipping the pandas round-trip (to_dataframe → reset_index →
    from_pandas) that pivot() performs. For large partitions this keeps
    peak memory near ~2× the partition size instead of ~5×.

    Dimension coordinates are expanded to the full partition shape with
    np.broadcast_to (a zero-copy view) and then ravelled; the ravel
    forces one copy per coordinate, which is unavoidable because the
    broadcast view is non-contiguous. Data-variable arrays are ravelled
    directly — a zero-copy view whenever the backing array is already
    C-contiguous, the common case for numpy-backed xarray datasets.

    Args:
        ds: A partition-sized xarray Dataset (already sliced via isel).
        schema: The Arrow schema for the output, as produced by
            _parse_schema. Output column order follows the schema's
            field order.

    Returns:
        A RecordBatch with one column per dimension coordinate and data
        variable, ordered as in ``schema``.
    """
    # Use the first data variable's dimension ordering as canonical so
    # coordinate broadcasts and data-variable ravels agree on memory
    # layout. All data variables are validated elsewhere to share the
    # same dims tuple.
    data_vars = ds.data_vars
    if data_vars:
        template = next(iter(data_vars.values()))
        dim_order = list(template.dims)
        full_shape = template.shape
    else:
        # No data variables: fall back to the dataset's own dim order.
        dim_order = list(ds.sizes.keys())
        full_shape = tuple(ds.sizes[d] for d in dim_order)

    ndim = len(full_shape)
    columns = []
    for field in schema:
        col_name = field.name
        if col_name in ds.coords and col_name in ds.dims:
            # Reshape the 1-D coordinate so its values run along its own
            # axis, broadcast to the full partition shape, then flatten.
            values = ds.coords[col_name].values
            target_shape = [1] * ndim
            target_shape[dim_order.index(col_name)] = values.shape[0]
            flat = np.broadcast_to(
                values.reshape(target_shape), full_shape
            ).ravel()
        else:
            # Data variable: flatten to 1-D (zero-copy when the array is
            # C-contiguous).
            flat = ds[col_name].values.ravel()
        columns.append(pa.array(flat, type=field.type))

    return pa.RecordBatch.from_arrays(columns, schema=schema)
210+
211+
158212def _parse_schema (ds ) -> pa .Schema :
159213 """Extracts a `pa.Schema` from the Dataset, treating dims and data_vars as columns."""
160214 columns = []
@@ -173,12 +227,12 @@ def _parse_schema(ds) -> pa.Schema:
173227
174228
# Partition metadata: maps each dimension name to that partition's
# (min, max, dtype_str) coordinate values.
PartitionBounds = dict[str, tuple[Any, Any, str]]
177231
178232
179233def partition_metadata (
180- ds : xr .Dataset , blocks : t . List [Block ]
181- ) -> t . List [PartitionBounds ]:
234+ ds : xr .Dataset , blocks : list [Block ]
235+ ) -> list [PartitionBounds ]:
182236 """Compute min/max coordinate values for each partition.
183237
184238 This metadata enables filter pushdown: SQL queries with WHERE clauses
0 commit comments