24 changes: 23 additions & 1 deletion arrow-pyarrow-integration-testing/src/lib.rs
@@ -32,7 +32,7 @@ use arrow::compute::kernels;
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::error::ArrowError;
 use arrow::ffi_stream::ArrowArrayStreamReader;
-use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, ToPyArrow};
+use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
 use arrow::record_batch::RecordBatch;
 
 fn to_py_err(err: ArrowError) -> PyErr {
@@ -140,6 +140,26 @@ fn round_trip_record_batch_reader(
     Ok(obj)
 }
 
+#[pyfunction]
+fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
+    Ok(obj)
+}
+
+/// Builds a Table from a list of RecordBatches and a Schema.
+#[pyfunction]
+pub fn build_table(
+    record_batches: Vec<PyArrowType<RecordBatch>>,
+    schema: PyArrowType<Schema>,
+) -> PyResult<PyArrowType<Table>> {
+    Ok(PyArrowType(
+        Table::try_new(
+            record_batches.into_iter().map(|rb| rb.0).collect(),
+            Arc::new(schema.0),
+        )
+        .map_err(to_py_err)?,
+    ))
+}
+
 #[pyfunction]
 fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
     // This makes sure we can correctly consume a RBR and return the error,
@@ -178,6 +198,8 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
+    m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
+    m.add_wrapped(wrap_pyfunction!(build_table))?;
     m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
     m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
     Ok(())
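For orientation, here is a minimal sketch of how a further consumer of the new `Table` wrapper could look inside this test crate. `table_num_rows` is a hypothetical function, not part of this PR; it assumes the `Table` API (`record_batches()`) introduced in `arrow-pyarrow/src/lib.rs` below:

```rust
use arrow::pyarrow::{PyArrowType, Table};
use pyo3::prelude::*;

/// Hypothetical example: import a `pyarrow.Table` and report its total row count.
#[pyfunction]
fn table_num_rows(table: PyArrowType<Table>) -> PyResult<usize> {
    // `PyArrowType<Table>` drives the `FromPyArrow` conversion; `.0` unwraps the `Table`.
    Ok(table.0.record_batches().iter().map(|rb| rb.num_rows()).sum())
}
```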
106 changes: 95 additions & 11 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -20,6 +20,7 @@
 import datetime
 import decimal
 import string
+from typing import Union, Tuple, Protocol
 
 import pytest
 import pyarrow as pa
@@ -120,28 +121,50 @@ def assert_pyarrow_leak():
 # This defines that Arrow consumers should allow any object that has specific "dunder"
 # methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle
 # _any_ class, without pyarrow-specific handling.
-class SchemaWrapper:
-    def __init__(self, schema):
+
+
+class ArrowSchemaExportable(Protocol):
+    def __arrow_c_schema__(self) -> object: ...
+
+
+class ArrowArrayExportable(Protocol):
+    def __arrow_c_array__(
+        self,
+        requested_schema: Union[object, None] = None
+    ) -> Tuple[object, object]:
+        ...
+
+
+class ArrowStreamExportable(Protocol):
+    def __arrow_c_stream__(
+        self,
+        requested_schema: Union[object, None] = None
+    ) -> object:
+        ...
+
+
+class SchemaWrapper(ArrowSchemaExportable):
+    def __init__(self, schema: ArrowSchemaExportable) -> None:
         self.schema = schema
 
-    def __arrow_c_schema__(self):
+    def __arrow_c_schema__(self) -> object:
         return self.schema.__arrow_c_schema__()
 
 
-class ArrayWrapper:
-    def __init__(self, array):
+class ArrayWrapper(ArrowArrayExportable):
> **Review thread** on `class ArrayWrapper(ArrowArrayExportable):`
>
> **Member:** I don't think you actually have to subclass from the Protocol; the type checker will automatically check for structural type equality.
>
> **Contributor (author):** One doesn't have to subtype a `typing.Protocol`, yes; that is the whole idea behind it (structural rather than nominal typing, which allows consuming external objects that don't inherit from your classes, as long as they conform to a certain pattern).
>
> But in cases where you have strong control over your own classes anyway, I find it highly beneficial to inherit directly from the Protocol where possible. This moves the detection of type mismatches to the place where you define your class, instead of requiring you to make sure you used all the classes in business logic where objects conforming to a certain protocol are expected. I have also seen over the years that, with very intricate Protocols, subtle type errors can sometimes be caught a bit more reliably by existing Python type checkers when inheriting directly from a Protocol, but that shouldn't really be relevant here.
>
> Besides that, I sadly don't think there are any type checks actually running in the CI anyway 😅 I think the only thing done here is compiling the Python package and running the tests. There should probably be another PR introducing some type checking with `mypy --strict` or so.
+    def __init__(self, array: ArrowArrayExportable) -> None:
         self.array = array
 
-    def __arrow_c_array__(self):
-        return self.array.__arrow_c_array__()
+    def __arrow_c_array__(self, requested_schema: Union[object, None] = None) -> Tuple[object, object]:
+        return self.array.__arrow_c_array__(requested_schema=requested_schema)
 
 
-class StreamWrapper:
-    def __init__(self, stream):
+class StreamWrapper(ArrowStreamExportable):
+    def __init__(self, stream: ArrowStreamExportable) -> None:
         self.stream = stream
 
-    def __arrow_c_stream__(self):
-        return self.stream.__arrow_c_stream__()
+    def __arrow_c_stream__(self, requested_schema: Union[object, None] = None) -> object:
+        return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
 
 
 @pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
@@ -613,6 +636,67 @@ def test_table_pycapsule():
     assert len(table.to_batches()) == len(new_table.to_batches())
 
 
+def test_table_empty():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    table = pa.Table.from_batches([], schema=schema)
+    new_table = rust.build_table([], schema=schema)
+
+    assert table.schema == new_table.schema
+    assert table == new_table
+    assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_roundtrip():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))])
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+    table = pa.Table.from_batches(batches, schema=schema)
+    new_table = rust.round_trip_table(table)
+
+    assert table.schema == new_table.schema
+    assert table == new_table
+    assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_from_batches():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+    table = pa.Table.from_batches(batches)
+    new_table = rust.build_table(batches, schema)
+
+    assert table.schema == new_table.schema
+    assert table == new_table
+    assert len(table.to_batches()) == len(new_table.to_batches())
+
+
+def test_table_error_inconsistent_schema():
+    """
+    Python -> Rust, expected to fail with a schema error
+    """
+    schema_1 = pa.schema([('ints', pa.list_(pa.int32()))])
+    schema_2 = pa.schema([('floats', pa.list_(pa.float32()))])
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema_1),
+        pa.record_batch([[None, [], [5.6, 6.4]]], schema_2),
+    ]
+    with pytest.raises(pa.ArrowException, match="Schema error: All record batches must have the same schema."):
+        rust.build_table(batches, schema_1)
+
 def test_reject_other_classes():
     # Arbitrary type that is not a PyArrow type
     not_pyarrow = ["hello"]
113 changes: 105 additions & 8 deletions arrow-pyarrow/src/lib.rs
@@ -44,17 +44,20 @@
 //! | `pyarrow.Array` | [ArrayData] |
 //! | `pyarrow.RecordBatch` | [RecordBatch] |
 //! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
+//! | `pyarrow.Table` | [Table] (2) |
 //!
 //! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
 //! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
 //! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
 //! easier to create.)
 //!
-//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
-//! have these same concepts. A chunked table is instead represented with
-//! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
-//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
-//! and then importing the reader as a [ArrowArrayStreamReader].
+//! (2) Although arrow-rs offers [Table], a convenience wrapper for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table)
+//! that internally holds a `Vec<RecordBatch>`, it is meant primarily for use cases where you
+//! already have a `Vec<RecordBatch>` on the Rust side and want to export it in bulk as a
+//! `pyarrow.Table`. In general, prefer streaming over handling data in bulk: a `pyarrow.Table`
+//! (or any other object implementing the ArrowArrayStream PyCapsule interface) can be imported
+//! to Rust through `PyArrowType<ArrowArrayStreamReader>` instead of being read eagerly into a
+//! `Vec<RecordBatch>`.
 
 use std::convert::{From, TryFrom};
 use std::ptr::{addr_of, addr_of_mut};
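As a sketch of the streaming approach the note above recommends (illustrative only, not part of this diff; `count_rows` is a hypothetical function, and the `arrow::` paths assume use through the umbrella crate):

```rust
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::PyArrowType;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Hypothetical example: consume a `pyarrow.Table` (or any other
/// ArrowArrayStream-exporting object) batch by batch, without ever
/// materializing a `Vec<RecordBatch>`.
#[pyfunction]
fn count_rows(reader: PyArrowType<ArrowArrayStreamReader>) -> PyResult<usize> {
    let mut rows = 0;
    // `ArrowArrayStreamReader` is an iterator of `Result<RecordBatch, ArrowError>`.
    for batch in reader.0 {
        let batch = batch.map_err(|e| PyValueError::new_err(e.to_string()))?;
        rows += batch.num_rows();
    }
    Ok(rows)
}
```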
@@ -68,13 +71,13 @@ use arrow_array::{
     make_array,
 };
 use arrow_data::ArrayData;
-use arrow_schema::{ArrowError, DataType, Field, Schema};
+use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::ffi::Py_uintptr_t;
-use pyo3::import_exception;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyCapsule, PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
+use pyo3::{import_exception, intern};
 
 import_exception!(pyarrow, ArrowException);
 /// Represents an exception raised by PyArrow.
@@ -484,6 +487,100 @@ impl IntoPyArrow for ArrowArrayStreamReader {
     }
 }
 
+/// A convenience wrapper around `Vec<RecordBatch>` that simplifies conversion from and to
+/// `pyarrow.Table`.
+///
+/// This can be used in circumstances where you either want to consume a `pyarrow.Table` directly
+/// (although technically, since `pyarrow.Table` implements the ArrowArrayStream PyCapsule
+/// interface, one could also consume a `PyArrowType<ArrowArrayStreamReader>` instead) or, more
+/// importantly, where you want to export a `Vec<RecordBatch>` from the Rust side as a
+/// `pyarrow.Table`.
+///
+/// ```ignore
+/// #[pyfunction]
+/// fn return_table(...) -> PyResult<PyArrowType<Table>> {
+///     let batches: Vec<RecordBatch> = /* ... */;
+///     let schema: SchemaRef = /* ... */;
+///     Ok(PyArrowType(
+///         Table::try_new(batches, schema)
+///             .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))?,
+///     ))
+/// }
+/// ```
+#[derive(Clone)]
+pub struct Table {
+    record_batches: Vec<RecordBatch>,
+    schema: SchemaRef,
+}
+

> **Review comment (Contributor)** on the new `Table` wrapper: I am not familiar enough with how the Python interface works to know if this is reasonable or not. Perhaps @kylebarron can help review this part.
+impl Table {
+    pub fn try_new(
+        record_batches: Vec<RecordBatch>,
+        schema: SchemaRef,
+    ) -> Result<Self, ArrowError> {
+        for record_batch in &record_batches {
+            if schema != record_batch.schema() {
+                return Err(ArrowError::SchemaError(format!(
+                    "All record batches must have the same schema. \
+                     Expected schema: {:?}, got schema: {:?}",
+                    schema,
+                    record_batch.schema()
+                )));
+            }
+        }
+        Ok(Self {
+            record_batches,
+            schema,
+        })
+    }
+
+    pub fn record_batches(&self) -> &[RecordBatch] {
+        &self.record_batches
+    }
+
+    pub fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
+        (self.record_batches, self.schema)
+    }
+}
+
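A minimal sketch of constructing a `Table` purely on the Rust side (not part of the diff; `build_example_table` is hypothetical and the `arrow_pyarrow::Table` import path is an assumption):

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_pyarrow::Table; // assumed import path for the wrapper above
use arrow_schema::{ArrowError, DataType, Field, Schema};

fn build_example_table() -> Result<Table, ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("ints", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 42])) as ArrayRef],
    )?;
    // `try_new` checks every batch against `schema`; a mismatch
    // surfaces as `ArrowError::SchemaError`.
    Table::try_new(vec![batch], schema)
}
```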
+impl TryFrom<Box<dyn RecordBatchReader>> for Table {
+    type Error = ArrowError;
+
+    fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> {
+        let schema = value.schema();
+        let batches = value.collect::<Result<Vec<_>, _>>()?;
+        Self::try_new(batches, schema)
+    }
+}
+
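A companion sketch for this `TryFrom` impl: draining an arbitrary `RecordBatchReader` into a `Table` (illustrative only; `collect_reader_into_table` is hypothetical and the import path is assumed as in the previous sketch):

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_pyarrow::Table; // assumed import path
use arrow_schema::{ArrowError, DataType, Field, Schema};

fn collect_reader_into_table() -> Result<Table, ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef],
    )?;
    // `RecordBatchIterator` is a ready-made `RecordBatchReader`; the
    // `TryFrom` impl drains it and re-validates the schema via `try_new`.
    let reader: Box<dyn RecordBatchReader> =
        Box::new(RecordBatchIterator::new([Ok(batch)], schema));
    Table::try_from(reader)
}
```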
+/// Convert a `pyarrow.Table` (or any other ArrowArrayStream-compliant object) into a [`Table`].
+impl FromPyArrow for Table {
+    fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
+        let reader: Box<dyn RecordBatchReader> =
+            Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
+        Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
+    }
+}
+
+/// Convert a [`Table`] into a `pyarrow.Table`.
+impl IntoPyArrow for Table {
+    fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
+        let module = py.import(intern!(py, "pyarrow"))?;
+        let class = module.getattr(intern!(py, "Table"))?;
+
+        let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
+        let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));
+
+        let kwargs = PyDict::new(py);
+        kwargs.set_item("schema", py_schema)?;
+
+        let table = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;
+
+        Ok(table)
+    }
+}
+
 /// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
 ///
 /// When wrapped around a type `T: FromPyArrow`, it