Commit a67cd19

Implement a Vec<RecordBatch> wrapper for pyarrow.Table convenience (#8790)
# Rationale for this change

When dealing with Parquet files that have an exceedingly large amount of Binary or UTF8 data in one row group, returning a single RecordBatch can fail because of index overflows (#7973). In `pyarrow` this is usually solved by representing the data as a `pyarrow.Table` whose columns are `ChunkedArray`s, which are essentially just lists of Arrow Arrays; equivalently, a `pyarrow.Table` can be viewed as a representation of a list of `RecordBatch`es.

I'd like to build a function in PyO3 that returns a `pyarrow.Table`, very similar to [pyarrow's read_row_group method](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetFile.html#pyarrow.parquet.ParquetFile.read_row_group). With that, we could have feature parity with `pyarrow` in circumstances of potential index overflows without resorting to type changes (such as reading the data as `LargeString` or `StringView` columns).

Currently, as far as I can see, there is no way in `arrow-pyarrow` to export a `pyarrow.Table` directly; in particular, convenience methods for converting from `Vec<RecordBatch>` seem to be missing. This PR implements a convenience wrapper that allows exporting a `pyarrow.Table` directly.

# What changes are included in this PR?

A new struct `Table` is added to the `arrow-pyarrow` crate. It can be constructed from `Vec<RecordBatch>` or from `ArrowArrayStreamReader`, and it implements `FromPyArrow` and `IntoPyArrow`. `FromPyArrow` accepts anything that implements the Arrow stream PyCapsule interface (`__arrow_c_stream__`), is a `RecordBatchReader`, or has a `to_reader()` method that returns one; `pyarrow.Table` supports both the stream interface and `to_reader()`. `IntoPyArrow` produces a `pyarrow.Table` on the Python side, constructed through `pyarrow.Table.from_batches(...)`.

# Are these changes tested?

Yes, in `arrow-pyarrow-integration-tests`.

# Are there any user-facing changes?

A new `Table` convenience wrapper is added!
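To make the new API concrete, here is a minimal sketch of the kind of PyO3 entry point this enables. It is illustrative only: the function name `example_table` and its hard-coded data are hypothetical, while `Table`, `Table::try_new`, and `PyArrowType` are the items this commit adds.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::pyarrow::{PyArrowType, Table};
use arrow::record_batch::RecordBatch;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Hypothetical example: return a `pyarrow.Table` assembled from two record
/// batches that share a single schema.
#[pyfunction]
fn example_table() -> PyResult<PyArrowType<Table>> {
    let schema = Arc::new(Schema::new(vec![Field::new("ints", DataType::Int32, false)]));
    let batches = vec![
        RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef],
        ),
        RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![3])) as ArrayRef],
        ),
    ]
    .into_iter()
    .collect::<Result<Vec<_>, _>>()
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Table::try_new rejects batches whose schema differs from the given one.
    Table::try_new(batches, schema)
        .map(PyArrowType)
        .map_err(|e| PyValueError::new_err(e.to_string()))
}
```

On the Python side the return value surfaces as a plain `pyarrow.Table`, since `IntoPyArrow` builds it through `pyarrow.Table.from_batches(...)`.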
1 parent ce4edd5 commit a67cd19

File tree: 3 files changed (+223, -20 lines)

arrow-pyarrow-integration-testing/src/lib.rs (23 additions, 1 deletion)
```diff
@@ -32,7 +32,7 @@ use arrow::compute::kernels;
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::error::ArrowError;
 use arrow::ffi_stream::ArrowArrayStreamReader;
-use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, ToPyArrow};
+use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
 use arrow::record_batch::RecordBatch;
 
 fn to_py_err(err: ArrowError) -> PyErr {
@@ -140,6 +140,26 @@ fn round_trip_record_batch_reader(
     Ok(obj)
 }
 
+#[pyfunction]
+fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
+    Ok(obj)
+}
+
+/// Builds a Table from a list of RecordBatches and a Schema.
+#[pyfunction]
+pub fn build_table(
+    record_batches: Vec<PyArrowType<RecordBatch>>,
+    schema: PyArrowType<Schema>,
+) -> PyResult<PyArrowType<Table>> {
+    Ok(PyArrowType(
+        Table::try_new(
+            record_batches.into_iter().map(|rb| rb.0).collect(),
+            Arc::new(schema.0),
+        )
+        .map_err(to_py_err)?,
+    ))
+}
+
 #[pyfunction]
 fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
     // This makes sure we can correctly consume a RBR and return the error,
@@ -178,6 +198,8 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &Bound<PyModule>) -> PyResu
     m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
+    m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
+    m.add_wrapped(wrap_pyfunction!(build_table))?;
     m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
     m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
     Ok(())
```
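For the consuming direction, a brief hedged sketch to complement `build_table` above (the function name `table_num_rows` is hypothetical): `PyArrowType<Table>` implements `FromPyObject`, so a `pyarrow.Table` passed from Python arrives as a ready-made `Table`, and `.0` unwraps the newtype just as in `build_table`.

```rust
use arrow::pyarrow::{PyArrowType, Table};
use pyo3::prelude::*;

/// Hypothetical consumer: counts the total rows of an incoming `pyarrow.Table`.
#[pyfunction]
fn table_num_rows(table: PyArrowType<Table>) -> PyResult<usize> {
    // `.0` unwraps the PyArrowType newtype; `record_batches()` exposes the
    // underlying Vec<RecordBatch> added by this commit.
    Ok(table.0.record_batches().iter().map(|rb| rb.num_rows()).sum())
}
```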

arrow-pyarrow-integration-testing/tests/test_sql.py (95 additions, 11 deletions)
```diff
@@ -20,6 +20,7 @@
 import datetime
 import decimal
 import string
+from typing import Union, Tuple, Protocol
 
 import pytest
 import pyarrow as pa
@@ -130,28 +131,50 @@ def assert_pyarrow_leak():
 # This defines that Arrow consumers should allow any object that has specific "dunder"
 # methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle
 # _any_ class, without pyarrow-specific handling.
-class SchemaWrapper:
-    def __init__(self, schema):
+
+
+class ArrowSchemaExportable(Protocol):
+    def __arrow_c_schema__(self) -> object: ...
+
+
+class ArrowArrayExportable(Protocol):
+    def __arrow_c_array__(
+        self,
+        requested_schema: Union[object, None] = None
+    ) -> Tuple[object, object]:
+        ...
+
+
+class ArrowStreamExportable(Protocol):
+    def __arrow_c_stream__(
+        self,
+        requested_schema: Union[object, None] = None
+    ) -> object:
+        ...
+
+
+class SchemaWrapper(ArrowSchemaExportable):
+    def __init__(self, schema: ArrowSchemaExportable) -> None:
         self.schema = schema
 
-    def __arrow_c_schema__(self):
+    def __arrow_c_schema__(self) -> object:
         return self.schema.__arrow_c_schema__()
 
 
-class ArrayWrapper:
-    def __init__(self, array):
+class ArrayWrapper(ArrowArrayExportable):
+    def __init__(self, array: ArrowArrayExportable) -> None:
         self.array = array
 
-    def __arrow_c_array__(self):
-        return self.array.__arrow_c_array__()
+    def __arrow_c_array__(self, requested_schema: Union[object, None] = None) -> Tuple[object, object]:
+        return self.array.__arrow_c_array__(requested_schema=requested_schema)
 
 
-class StreamWrapper:
-    def __init__(self, stream):
+class StreamWrapper(ArrowStreamExportable):
+    def __init__(self, stream: ArrowStreamExportable) -> None:
         self.stream = stream
 
-    def __arrow_c_stream__(self):
-        return self.stream.__arrow_c_stream__()
+    def __arrow_c_stream__(self, requested_schema: Union[object, None] = None) -> object:
+        return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
 
 
 @pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
@@ -632,6 +655,67 @@ def test_table_pycapsule():
     assert len(table.to_batches()) == len(new_table.to_batches())
 
 
+def test_table_empty():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    table = pa.Table.from_batches([], schema=schema)
+    new_table = rust.build_table([], schema=schema)
+
+    assert table.schema == new_table.schema
+    assert table == new_table
+    assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_roundtrip():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))])
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+    table = pa.Table.from_batches(batches, schema=schema)
+    new_table = rust.round_trip_table(table)
+
+    assert table.schema == new_table.schema
+    assert table == new_table
+    assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_from_batches():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+    table = pa.Table.from_batches(batches)
+    new_table = rust.build_table(batches, schema)
+
+    assert table.schema == new_table.schema
+    assert table == new_table
+    assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_error_inconsistent_schema():
+    """
+    Python -> Rust -> Python
+    """
+    schema_1 = pa.schema([('ints', pa.list_(pa.int32()))])
+    schema_2 = pa.schema([('floats', pa.list_(pa.float32()))])
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema_1),
+        pa.record_batch([[None, [], [5.6, 6.4]]], schema_2),
+    ]
+    with pytest.raises(pa.ArrowException, match="Schema error: All record batches must have the same schema."):
+        rust.build_table(batches, schema_1)
+
+
 def test_reject_other_classes():
     # Arbitrary type that is not a PyArrow type
     not_pyarrow = ["hello"]
```

arrow-pyarrow/src/lib.rs (105 additions, 8 deletions)
```diff
@@ -44,17 +44,20 @@
 //! | `pyarrow.Array` | [ArrayData] |
 //! | `pyarrow.RecordBatch` | [RecordBatch] |
 //! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
+//! | `pyarrow.Table` | [Table] (2) |
 //!
 //! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
 //! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
 //! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
 //! easier to create.)
 //!
-//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
-//! have these same concepts. A chunked table is instead represented with
-//! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
-//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
-//! and then importing the reader as a [ArrowArrayStreamReader].
+//! (2) Although arrow-rs offers [Table], a convenience wrapper for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table)
+//! that internally holds `Vec<RecordBatch>`, it is meant primarily for use cases where you already
+//! have `Vec<RecordBatch>` on the Rust side and want to export that in bulk as a `pyarrow.Table`.
+//! In general, it is recommended to use streaming approaches instead of dealing with data in bulk.
+//! For example, a `pyarrow.Table` (or any other object that implements the ArrayStream PyCapsule
+//! interface) can be imported to Rust through `PyArrowType<ArrowArrayStreamReader>` instead of
+//! forcing eager reading into `Vec<RecordBatch>`.
 
 use std::convert::{From, TryFrom};
 use std::ptr::{addr_of, addr_of_mut};
@@ -68,13 +71,13 @@ use arrow_array::{
     make_array,
 };
 use arrow_data::ArrayData;
-use arrow_schema::{ArrowError, DataType, Field, Schema};
+use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::ffi::Py_uintptr_t;
-use pyo3::import_exception;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyCapsule, PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
+use pyo3::{import_exception, intern};
 
 import_exception!(pyarrow, ArrowException);
 /// Represents an exception raised by PyArrow.
@@ -484,6 +487,100 @@ impl IntoPyArrow for ArrowArrayStreamReader {
     }
 }
 
+/// This is a convenience wrapper around `Vec<RecordBatch>` that tries to simplify conversion from
+/// and to `pyarrow.Table`.
+///
+/// This could be used in circumstances where you either want to consume a `pyarrow.Table` directly
+/// (although technically, since `pyarrow.Table` implements the ArrayStreamReader PyCapsule
+/// interface, one could also consume a `PyArrowType<ArrowArrayStreamReader>` instead) or, more
+/// importantly, where one wants to export a `pyarrow.Table` from a `Vec<RecordBatch>` from the Rust
+/// side.
+///
+/// ```ignore
+/// #[pyfunction]
+/// fn return_table(...) -> PyResult<PyArrowType<Table>> {
+///     let batches: Vec<RecordBatch>;
+///     let schema: SchemaRef;
+///     PyArrowType(Table::try_new(batches, schema).map_err(|err| err.into_py_err(py))?)
+/// }
+/// ```
+#[derive(Clone)]
+pub struct Table {
+    record_batches: Vec<RecordBatch>,
+    schema: SchemaRef,
+}
+
+impl Table {
+    pub fn try_new(
+        record_batches: Vec<RecordBatch>,
+        schema: SchemaRef,
+    ) -> Result<Self, ArrowError> {
+        for record_batch in &record_batches {
+            if schema != record_batch.schema() {
+                return Err(ArrowError::SchemaError(format!(
+                    "All record batches must have the same schema. \
+                     Expected schema: {:?}, got schema: {:?}",
+                    schema,
+                    record_batch.schema()
+                )));
+            }
+        }
+        Ok(Self {
+            record_batches,
+            schema,
+        })
+    }
+
+    pub fn record_batches(&self) -> &[RecordBatch] {
+        &self.record_batches
+    }
+
+    pub fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
+        (self.record_batches, self.schema)
+    }
+}
+
+impl TryFrom<Box<dyn RecordBatchReader>> for Table {
+    type Error = ArrowError;
+
+    fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> {
+        let schema = value.schema();
+        let batches = value.collect::<Result<Vec<_>, _>>()?;
+        Self::try_new(batches, schema)
+    }
+}
+
+/// Convert a `pyarrow.Table` (or any other ArrowArrayStream compliant object) into [`Table`]
+impl FromPyArrow for Table {
+    fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
+        let reader: Box<dyn RecordBatchReader> =
+            Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
+        Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
+    }
+}
+
+/// Convert a [`Table`] into `pyarrow.Table`.
+impl IntoPyArrow for Table {
+    fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
+        let module = py.import(intern!(py, "pyarrow"))?;
+        let class = module.getattr(intern!(py, "Table"))?;
+
+        let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
+        let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));
+
+        let kwargs = PyDict::new(py);
+        kwargs.set_item("schema", py_schema)?;
+
+        let reader = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;
+
+        Ok(reader)
+    }
+}
+
 /// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
 ///
 /// When wrapped around a type `T: FromPyArrow`, it
```
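As a counterpart to the bulk-oriented `Table`, a hedged sketch of the streaming import path that the module docs above recommend (the function name `count_rows_streaming` is hypothetical):

```rust
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::PyArrowType;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Hypothetical streaming consumer: pulls record batches one at a time, so the
/// full table is never materialized as a Vec<RecordBatch> on the Rust side.
#[pyfunction]
fn count_rows_streaming(reader: PyArrowType<ArrowArrayStreamReader>) -> PyResult<usize> {
    let mut rows = 0;
    // ArrowArrayStreamReader iterates Result<RecordBatch, ArrowError>.
    for batch in reader.0 {
        rows += batch
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .num_rows();
    }
    Ok(rows)
}
```

Nothing here is specific to `pyarrow.Table`: any object exposing `__arrow_c_stream__` can be passed this way, which is exactly why the docs suggest this route when eager collection into `Vec<RecordBatch>` is unnecessary.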
