|
8 | 8 | import xarray as xr |
9 | 9 |
|
10 | 10 | from .reader import read_xarray |
11 | | -from .df import explode, block_slices, dataset_to_record_batch, from_map, pivot, from_map_batched, _parse_schema |
| 11 | +from .df import ( |
| 12 | + DEFAULT_BATCH_SIZE, |
| 13 | + _parse_schema, |
| 14 | + block_slices, |
| 15 | + dataset_to_record_batch, |
| 16 | + explode, |
| 17 | + from_map, |
| 18 | + from_map_batched, |
| 19 | + iter_record_batches, |
| 20 | + pivot, |
| 21 | +) |
12 | 22 |
|
13 | 23 |
|
14 | 24 | def rand_wx(start: str, end: str) -> xr.Dataset: |
@@ -177,6 +187,55 @@ def make_arrow_table(x): |
177 | 187 | assert len(result) == 3 |
178 | 188 |
|
179 | 189 |
|
def test_iter_record_batches_splits_into_multiple_batches(air_small):
    """iter_record_batches should emit >1 batch when partition exceeds batch_size."""
    schema = _parse_schema(air_small)
    first_slice = next(block_slices(air_small, chunks={"time": 4, "lat": 3, "lon": 4}))
    partition = air_small.isel(first_slice)

    # One output row per point of the partition's dimension grid.
    n_rows = 1
    for extent in partition.sizes.values():
        n_rows *= extent

    batch_size = 16  # deliberately tiny so the partition spans many batches
    emitted = list(iter_record_batches(partition, schema, batch_size=batch_size))

    expected_count = (n_rows + batch_size - 1) // batch_size  # ceil(n_rows / batch_size)
    assert len(emitted) == expected_count
    assert all(rb.num_rows <= batch_size for rb in emitted)
    assert sum(rb.num_rows for rb in emitted) == n_rows
| 204 | + |
def test_iter_record_batches_matches_dataset_to_record_batch(air_small):
    """Concatenating all iter_record_batches output must equal dataset_to_record_batch."""
    schema = _parse_schema(air_small)
    sort_cols = [field.name for field in schema if field.name in air_small.dims]
    first_slice = next(block_slices(air_small, chunks={"time": 4, "lat": 3, "lon": 4}))
    partition = air_small.isel(first_slice)

    def as_sorted_frame(arrow_obj):
        # Row order is not part of the contract, so canonicalize on the
        # dimension columns before comparing.
        return arrow_obj.to_pandas().sort_values(sort_cols).reset_index(drop=True)

    streamed = pa.Table.from_batches(
        list(iter_record_batches(partition, schema, batch_size=16))
    )
    single_shot = dataset_to_record_batch(partition, schema)

    pd.testing.assert_frame_equal(as_sorted_frame(streamed), as_sorted_frame(single_shot))
| 227 | + |
def test_iter_record_batches_default_batch_size():
    """A single-batch partition (rows <= DEFAULT_BATCH_SIZE) yields exactly one batch.

    Builds a small dataset locally (same 2 x 25 x 53 shape as two time steps
    of the ``air_temperature`` tutorial data) instead of calling
    ``xr.tutorial.open_dataset``, which downloads the file over the network
    and makes the test slow and flaky in offline CI environments.
    """
    rng = np.random.default_rng(0)  # seeded for determinism
    ds = xr.Dataset(
        {"air": (("time", "lat", "lon"), rng.standard_normal((2, 25, 53)))},
        coords={
            "time": pd.date_range("2000-01-01", periods=2, freq="6h"),
            "lat": np.linspace(15.0, 75.0, 25),
            "lon": np.linspace(200.0, 330.0, 53),
        },
    )
    schema = _parse_schema(ds)
    total_rows = int(np.prod([ds.sizes[d] for d in ds.sizes]))
    assert total_rows <= DEFAULT_BATCH_SIZE, "fixture too large — shrink the shape"

    batches = list(iter_record_batches(ds, schema))

    assert len(batches) == 1
    assert batches[0].num_rows == total_rows
| 238 | + |
180 | 239 | def test_dataset_to_record_batch_matches_pivot(air_small): |
181 | 240 | """dataset_to_record_batch should contain the same rows as pivot. |
182 | 241 |
|
@@ -371,20 +430,19 @@ def test_read_xarray_loads_one_chunk_at_a_time(large_ds): |
371 | 430 | sizes.append(cur_size) |
372 | 431 | peaks.append(cur_peak) |
373 | 432 |
|
374 | | - mean_size = np.mean(sizes) |
375 | | - mean_peak = np.mean(peaks) |
376 | | - |
377 | 433 | for size in sizes: |
378 | | - assert mean_size * 1.1 > size |
379 | | - assert chunk_size * 3 > size |
380 | | - assert chunk_size * 2 < size |
| 434 | + # Observed range: 1.59–1.83× chunk_size. |
| 435 | + # iter_record_batches holds data-variable arrays (≈1× chunk) while |
| 436 | + # yielding sub-batches, plus the current Arrow batch (≈0.65× chunk). |
| 437 | + assert chunk_size * 1.3 < size, f"size {size} unexpectedly low" |
| 438 | + assert chunk_size * 2.2 > size, f"size {size} unexpectedly high" |
381 | 439 |
|
382 | 440 | for peak in peaks: |
383 | | - assert mean_peak * 1.1 > peak |
384 | | - assert chunk_size * 7 > peak |
385 | | - # Lower bound: at least chunk + Arrow output must be allocated. |
386 | | - # The numpy-direct path peaks at ~2.5x (vs ~5x for the old pandas path). |
387 | | - assert chunk_size * 1.5 < peak |
| 441 | + # Observed range: 1.84–3.28× chunk_size. |
| 442 | + # Peak includes data arrays + Arrow batch + temporary coordinate index |
| 443 | + # arrays; the first batch of each chunk is highest (Dask compute overhead). |
| 444 | + assert chunk_size * 1.5 < peak, f"peak {peak} unexpectedly low" |
| 445 | + assert chunk_size * 4.0 > peak, f"peak {peak} unexpectedly high" |
388 | 446 |
|
389 | 447 | assert max(peaks) < large_ds.nbytes |
390 | 448 |
|
|
0 commit comments