Commit 32d6975

Pyo3 refactorings (#740)

* Let pyo3 convert the StorageContexts argument in PySessionContext::register_object_store.
* Clean up PySessionContext methods from_pylist and from_pydict.
* Clean up PySessionContext methods from_polars, from_pandas, and from_arrow_table.
* Prefer a bound Python token over Python::with_gil: when one is available, using an already bound Python token is zero-cost, whereas Python::with_gil carries a runtime check. Ref: PyO3/pyo3#4274

1 parent faa26b2 commit 32d6975
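
For context, this is the core pattern the commit applies throughout. A minimal sketch, assuming PyO3 0.21 (the function names are hypothetical, not this crate's API):

    use pyo3::prelude::*;
    use pyo3::types::PyList;

    // Before: the argument arrives as an opaque PyObject and the body
    // re-acquires the GIL, paying a runtime check the caller (always a
    // Python caller that already holds the GIL) has already satisfied.
    #[pyfunction]
    fn first_item_before(data: PyObject) -> PyResult<PyObject> {
        Python::with_gil(|py| {
            let list = data.bind(py).downcast::<PyList>()?;
            Ok(list.get_item(0)?.unbind())
        })
    }

    // After: declare the concrete Bound type so PyO3 converts at the
    // boundary; Bound::py() then yields the GIL token at zero cost.
    #[pyfunction]
    fn first_item_after(data: Bound<'_, PyList>) -> PyResult<PyObject> {
        let _py = data.py(); // already-bound token, free to reuse
        Ok(data.get_item(0)?.unbind())
    }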

3 files changed, +89 -116 lines

src/context.rs

Lines changed: 65 additions & 81 deletions
@@ -60,7 +60,7 @@ use datafusion::prelude::{
     AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
 };
 use datafusion_common::ScalarValue;
-use pyo3::types::PyTuple;
+use pyo3::types::{PyDict, PyList, PyTuple};
 use tokio::task::JoinHandle;
 
 /// Configuration options for a SessionContext
@@ -291,24 +291,17 @@ impl PySessionContext {
     pub fn register_object_store(
         &mut self,
         scheme: &str,
-        store: &Bound<'_, PyAny>,
+        store: StorageContexts,
         host: Option<&str>,
     ) -> PyResult<()> {
-        let res: Result<(Arc<dyn ObjectStore>, String), PyErr> =
-            match StorageContexts::extract_bound(store) {
-                Ok(store) => match store {
-                    StorageContexts::AmazonS3(s3) => Ok((s3.inner, s3.bucket_name)),
-                    StorageContexts::GoogleCloudStorage(gcs) => Ok((gcs.inner, gcs.bucket_name)),
-                    StorageContexts::MicrosoftAzure(azure) => {
-                        Ok((azure.inner, azure.container_name))
-                    }
-                    StorageContexts::LocalFileSystem(local) => Ok((local.inner, "".to_string())),
-                },
-                Err(_e) => Err(PyValueError::new_err("Invalid object store")),
-            };
-
         // for most stores the "host" is the bucket name and can be inferred from the store
-        let (store, upstream_host) = res?;
+        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
+            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
+            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
+            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
+            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
+        };
+
         // let users override the host to match the api signature from upstream
         let derived_host = if let Some(host) = host {
             host
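
This works because StorageContexts implements FromPyObject, so PyO3 performs the extraction at the call boundary and surfaces its own conversion error on a mismatch (previously mapped by hand to "Invalid object store"). A minimal sketch of the same pattern with hypothetical types, assuming PyO3 0.21:

    use pyo3::prelude::*;

    // Hypothetical stand-ins for the real storage-context pyclasses.
    #[pyclass]
    #[derive(Clone)]
    struct LocalStore {
        root: String,
    }

    #[pyclass]
    #[derive(Clone)]
    struct S3Store {
        bucket: String,
    }

    // Deriving FromPyObject on an enum makes PyO3 try each variant in
    // order, so one signature can accept a union of Python types.
    #[derive(FromPyObject)]
    enum AnyStore {
        Local(LocalStore),
        S3(S3Store),
    }

    #[pyfunction]
    fn describe(store: AnyStore) -> String {
        // If the argument matches no variant, PyO3 raises TypeError
        // before this body ever runs.
        match store {
            AnyStore::Local(local) => format!("local store at {}", local.root),
            AnyStore::S3(s3) => format!("s3 store for bucket {}", s3.bucket),
        }
    }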
@@ -434,105 +427,96 @@
     }
 
     /// Construct datafusion dataframe from Python list
-    #[allow(clippy::wrong_self_convention)]
     pub fn from_pylist(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyList>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to Arrow Table
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[data]);
-            let table = table_class.call_method1("from_pylist", args)?.into();
-
-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Acquire GIL Token
+        let py = data.py();
+
+        // Instantiate pyarrow Table object & convert to Arrow Table
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[data]);
+        let table = table_class.call_method1("from_pylist", args)?;
+
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, py)?;
+        Ok(df)
     }
 
     /// Construct datafusion dataframe from Python dictionary
-    #[allow(clippy::wrong_self_convention)]
     pub fn from_pydict(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyDict>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to Arrow Table
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[data]);
-            let table = table_class.call_method1("from_pydict", args)?.into();
-
-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Acquire GIL Token
+        let py = data.py();
+
+        // Instantiate pyarrow Table object & convert to Arrow Table
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[data]);
+        let table = table_class.call_method1("from_pydict", args)?;
+
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, py)?;
+        Ok(df)
     }
 
     /// Construct datafusion dataframe from Arrow Table
-    #[allow(clippy::wrong_self_convention)]
     pub fn from_arrow_table(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyAny>,
         name: Option<&str>,
-        _py: Python,
+        py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to batches
-            let table = data.call_method0(py, "to_batches")?;
-
-            let schema = data.getattr(py, "schema")?;
-            let schema = schema.extract::<PyArrowType<Schema>>(py)?;
-
-            // Cast PyObject to RecordBatch type
-            // Because create_dataframe() expects a vector of vectors of record batches
-            // here we need to wrap the vector of record batches in an additional vector
-            let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>(py)?;
-            let list_of_batches = PyArrowType::from(vec![batches.0]);
-            self.create_dataframe(list_of_batches, name, Some(schema), py)
-        })
+        // Instantiate pyarrow Table object & convert to batches
+        let table = data.call_method0("to_batches")?;
+
+        let schema = data.getattr("schema")?;
+        let schema = schema.extract::<PyArrowType<Schema>>()?;
+
+        // Cast PyAny to RecordBatch type
+        // Because create_dataframe() expects a vector of vectors of record batches
+        // here we need to wrap the vector of record batches in an additional vector
+        let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>()?;
+        let list_of_batches = PyArrowType::from(vec![batches.0]);
+        self.create_dataframe(list_of_batches, name, Some(schema), py)
     }
 
     /// Construct datafusion dataframe from pandas
     #[allow(clippy::wrong_self_convention)]
     pub fn from_pandas(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyAny>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to Arrow Table
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[data]);
-            let table = table_class.call_method1("from_pandas", args)?.into();
-
-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Obtain GIL token
+        let py = data.py();
+
+        // Instantiate pyarrow Table object & convert to Arrow Table
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[data]);
+        let table = table_class.call_method1("from_pandas", args)?;
+
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, py)?;
+        Ok(df)
     }
 
     /// Construct datafusion dataframe from polars
-    #[allow(clippy::wrong_self_convention)]
     pub fn from_polars(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyAny>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Convert Polars dataframe to Arrow Table
-            let table = data.call_method0(py, "to_arrow")?;
+        // Convert Polars dataframe to Arrow Table
+        let table = data.call_method0("to_arrow")?;
 
-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, data.py())?;
+        Ok(df)
     }
 
     pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
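
The from_arrow_table rewrite above also shows the API shift cleanly: PyObject methods are GIL-independent and take an explicit py token, while Bound methods already carry it. A minimal sketch of the two call shapes, assuming PyO3 0.21 (function names are illustrative):

    use pyo3::prelude::*;

    // Unbound: Py<T>/PyObject methods need the token as a first argument.
    fn schema_of_unbound(obj: &PyObject, py: Python<'_>) -> PyResult<PyObject> {
        obj.getattr(py, "schema")
    }

    // Bound: the value carries the token, so the same call drops `py`.
    fn schema_of_bound<'py>(obj: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyAny>> {
        obj.getattr("schema")
    }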

src/dataframe.rs

Lines changed: 22 additions & 33 deletions
@@ -423,17 +423,15 @@ impl PyDataFrame {
 
     /// Convert to Arrow Table
     /// Collect the batches and pass to Arrow Table
-    fn to_arrow_table(&self, py: Python) -> PyResult<PyObject> {
+    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
         let batches = self.collect(py)?.to_object(py);
         let schema: PyObject = self.schema().into_py(py);
 
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object and use its from_batches method
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[batches, schema]);
-            let table: PyObject = table_class.call_method1("from_batches", args)?.into();
-            Ok(table)
-        })
+        // Instantiate pyarrow Table object and use its from_batches method
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[batches, schema]);
+        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
+        Ok(table)
     }
 
     fn execute_stream(&self, py: Python) -> PyResult<PyRecordBatchStream> {
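
A side note on the signature tweak above: py: Python and py: Python<'_> name the same type. Python<'py> has a lifetime parameter tied to the GIL scope, and <'_> merely spells out the previously elided lifetime, so behavior is unchanged:

    use pyo3::Python;

    fn before(_py: Python) {}     // lifetime elided
    fn after(_py: Python<'_>) {}  // identical type, lifetime explicit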
@@ -464,51 +462,42 @@
 
     /// Convert to pandas dataframe with pyarrow
     /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
-    fn to_pandas(&self, py: Python) -> PyResult<PyObject> {
+    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;
 
-        Python::with_gil(|py| {
-            // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
-            let result = table.call_method0(py, "to_pandas")?;
-            Ok(result)
-        })
+        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
+        let result = table.call_method0(py, "to_pandas")?;
+        Ok(result)
     }
 
     /// Convert to Python list using pyarrow
     /// Each list item represents one row encoded as dictionary
-    fn to_pylist(&self, py: Python) -> PyResult<PyObject> {
+    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;
 
-        Python::with_gil(|py| {
-            // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
-            let result = table.call_method0(py, "to_pylist")?;
-            Ok(result)
-        })
+        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
+        let result = table.call_method0(py, "to_pylist")?;
+        Ok(result)
     }
 
     /// Convert to Python dictionary using pyarrow
     /// Each dictionary key is a column and the dictionary value represents the column values
     fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;
 
-        Python::with_gil(|py| {
-            // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
-            let result = table.call_method0(py, "to_pydict")?;
-            Ok(result)
-        })
+        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
+        let result = table.call_method0(py, "to_pydict")?;
+        Ok(result)
     }
 
     /// Convert to polars dataframe with pyarrow
     /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
-    fn to_polars(&self, py: Python) -> PyResult<PyObject> {
+    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;
-
-        Python::with_gil(|py| {
-            let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
-            let args = PyTuple::new_bound(py, &[table]);
-            let result: PyObject = dataframe.call1(args)?.into();
-            Ok(result)
-        })
+        let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
+        let args = PyTuple::new_bound(py, &[table]);
+        let result: PyObject = dataframe.call1(args)?.into();
+        Ok(result)
     }
 
     // Executes this DataFrame to get the total number of rows.
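
The to_* conversions above follow the commit's other recurring move: a method that already receives py threads that token into helpers instead of opening a second Python::with_gil scope, which also spares moving values into a nested closure. A hedged sketch of the two shapes (hypothetical helpers, not this crate's API):

    use pyo3::prelude::*;

    // Redundant when callers already hold the GIL: with_gil still
    // performs a runtime state check before running the closure.
    fn length_nested(obj: PyObject) -> PyResult<usize> {
        Python::with_gil(|py| obj.bind(py).len())
    }

    // Zero-cost: the caller's token is proof the GIL is held.
    fn length_threaded(py: Python<'_>, obj: &PyObject) -> PyResult<usize> {
        obj.bind(py).len()
    }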

src/sql/logical.rs

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,7 @@ impl PyLogicalPlan {
 impl PyLogicalPlan {
     /// Return the specific logical operator
     pub fn to_variant(&self, py: Python) -> PyResult<PyObject> {
-        Python::with_gil(|_| match self.plan.as_ref() {
+        match self.plan.as_ref() {
             LogicalPlan::Aggregate(plan) => PyAggregate::from(plan.clone()).to_variant(py),
             LogicalPlan::Analyze(plan) => PyAnalyze::from(plan.clone()).to_variant(py),
             LogicalPlan::CrossJoin(plan) => PyCrossJoin::from(plan.clone()).to_variant(py),
@@ -85,7 +85,7 @@ impl PyLogicalPlan {
                 "Cannot convert this plan to a LogicalNode: {:?}",
                 other
             ))),
-        })
+        }
     }
 
     /// Get the inputs to this plan
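
(The removed wrapper here passed |_| and ignored the fresh token, while to_variant already receives py: Python; the GIL was therefore already held, and dropping the Python::with_gil call removes a redundant runtime check without changing behavior.)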
