Merged
Changes from 1 commit
2 changes: 1 addition & 1 deletion datafusion_ray/__init__.py
@@ -20,6 +20,6 @@
except ImportError:
import importlib_metadata

from .core import RayContext, prettify, runtime_env, RayStagePool
from .core import RayContext, exec_sql_on_tables, prettify, runtime_env, RayStagePool

__version__ = importlib_metadata.version(__name__)
10 changes: 8 additions & 2 deletions datafusion_ray/core.py
@@ -31,6 +31,7 @@
from datafusion_ray._datafusion_ray_internal import (
RayContext as RayContextInternal,
RayDataFrame as RayDataFrameInternal,
exec_sql_on_tables,
prettify,
)

@@ -465,6 +466,9 @@ def stages(self):

return self._stages

def schema(self):
return self.df.schema()

def execution_plan(self):
return self.df.execution_plan()

@@ -479,7 +483,7 @@ def collect(self) -> list[pa.RecordBatch]:
t1 = time.time()
self.stages()
t2 = time.time()
log.debug(f"creating stages took {t2 -t1}s")
log.debug(f"creating stages took {t2 - t1}s")

last_stage_id = max([stage.stage_id for stage in self._stages])
log.debug(f"last stage is {last_stage_id}")
@@ -553,7 +557,9 @@ def __init__(
s = time.time()
call_sync(wait_for([start_ref], "RayContextSupervisor start"))
e = time.time()
log.info(f"RayContext::__init__ waiting for supervisor to be ready took {e-s}s")
log.info(
f"RayContext::__init__ waiting for supervisor to be ready took {e - s}s"
)

def register_parquet(self, name: str, path: str):
self.ctx.register_parquet(name, path)
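As a side note, here is a minimal sketch of how the new `schema()` accessor on a `RayDataFrame` might be used from Python. Only `register_parquet`, `sql`, `schema`, and `collect` are taken from this diff; the `RayContext` constructor arguments and the data path are placeholders.

```python
from datafusion_ray import RayContext

# Placeholder setup: the real RayContext constructor arguments (batch size,
# prefetch buffer, etc.) are not shown in this diff.
ctx = RayContext()
ctx.register_parquet("customer", "/data/customer.parquet")  # hypothetical path

df = ctx.sql("SELECT c_custkey, c_name FROM customer LIMIT 10")
print(df.schema())      # pyarrow schema of the query result (new in this PR)
batches = df.collect()  # list of pyarrow RecordBatches
```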
4 changes: 4 additions & 0 deletions src/dataframe.rs
@@ -266,6 +266,10 @@ impl RayDataFrame {
Ok(PyLogicalPlan::new(self.df.logical_plan().clone()))
}

fn schema(&self, py: Python) -> PyResult<PyObject> {
self.df.schema().as_arrow().to_pyarrow(py)
}

fn optimized_logical_plan(&self) -> PyResult<PyLogicalPlan> {
Ok(PyLogicalPlan::new(self.df.clone().into_optimized_plan()?))
}
1 change: 1 addition & 0 deletions src/lib.rs
@@ -44,6 +44,7 @@ fn _datafusion_ray_internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<dataframe::PyDataFrameStage>()?;
m.add_class::<stage_service::StageService>()?;
m.add_function(wrap_pyfunction!(util::prettify, m)?)?;
m.add_function(wrap_pyfunction!(util::exec_sql_on_tables, m)?)?;
Ok(())
}

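Once registered here, the binding lives on the private extension module and is re-exported by `datafusion_ray/core.py` and `datafusion_ray/__init__.py` above. A short illustration of the import chain, nothing more:

```python
# The pyo3 binding lives in the private extension module; user code would
# normally use the package-level re-export instead (see __init__.py above).
from datafusion_ray._datafusion_ray_internal import exec_sql_on_tables as _internal_fn
from datafusion_ray import exec_sql_on_tables

assert exec_sql_on_tables is _internal_fn  # same underlying Rust function
```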
51 changes: 51 additions & 0 deletions src/util.rs
@@ -8,6 +8,7 @@ use std::task::{Context, Poll};
use std::time::Duration;

use arrow::array::RecordBatch;
use arrow::compute::concat_batches;
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use arrow::ipc::convert::fb_to_schema;
@@ -20,13 +21,17 @@ use arrow_flight::{FlightClient, FlightData, Ticket};
use async_stream::stream;
use datafusion::common::internal_datafusion_err;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::datasource::file_format::options::ParquetReadOptions;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::ListingOptions;
use datafusion::datasource::physical_plan::ParquetExec;
use datafusion::error::DataFusionError;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, SessionStateBuilder};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_proto::physical_plan::AsExecutionPlan;
use datafusion_python::utils::wait_for_future;
use futures::{Stream, StreamExt};
use parking_lot::Mutex;
use pyo3::prelude::*;
@@ -397,6 +402,52 @@ fn print_node(plan: &Arc<dyn ExecutionPlan>, indent: usize, output: &mut String)
}
}

async fn exec_sql(query: String, tables: Vec<(String, String)>) -> PyResult<RecordBatch> {
let ctx = SessionContext::new();
for (name, path) in tables {
if path.ends_with(".parquet") {
[Inline review comment on the line above]

Member:
It may be better to check whether path is a file or a directory rather than basing the logic on the file extension.
For example, in my local setup, I have a directory named customer.parquet that contains multiple Parquet files.

@vmingchen (Contributor, author) on Feb 27, 2025:
Thanks for pointing this out! I looked into it, and it turns out register_parquet internally uses register_listing_table as well. The latter is capable of registering both a single file and a directory of files. So I have changed the function to use register_listing_table only in b8e0c6b; the new commit also adds a unit test to check that it works for both a file and a directory, and a doc pointing to the format of the URI.
Please take another look; thanks!

Member:
Thanks for updating that. LGTM.

let opt = ParquetReadOptions::default();
ctx.register_parquet(&name, &path, opt).await?;
} else {
let opt =
ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(".parquet");
ctx.register_listing_table(&name, &path, opt, None, None)
.await?;
}
}
let df = ctx.sql(&query).await?;
let schema = df.schema().inner().clone();
let batches = df.collect().await?;
concat_batches(&schema, batches.iter()).to_py_err()
}
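Regarding the review thread above: a rough Python-level illustration of the listing-table approach the author says the follow-up commit switches to, using the datafusion `SessionContext` API that tpcbench.py already relies on. The paths are hypothetical.

```python
from datafusion import SessionContext

ctx = SessionContext()

# Per the review reply, a listing table can be backed by either a directory of
# Parquet files or a single file, so no extension-based branching is needed.
# Both paths below are made up for illustration.
ctx.register_listing_table("lineitem", "/data/lineitem.parquet/")  # directory of Parquet files
ctx.register_listing_table("nation", "/data/nation.parquet")       # single Parquet file

print(ctx.sql("SELECT count(*) FROM lineitem").collect())
```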

/// Executes a query on the specified tables using DataFusion without Ray.
///
/// Returns the query results as a RecordBatch that can be used to verify the
/// correctness of DataFusion-Ray execution of the same query.
///
/// # Arguments
///
/// * `py`: the Python token
/// * `query`: the SQL query string to execute
/// * `tables`: a list of `(name, path)` tuples specifying the tables to query
#[pyfunction]
pub fn exec_sql_on_tables(
py: Python,
query: String,
tables: Bound<'_, PyList>,
) -> PyResult<PyObject> {
let table_vec = {
let mut v = Vec::with_capacity(tables.len());
for entry in tables.iter() {
v.push(entry.extract::<(String, String)>()?);
}
v
};
let batch = wait_for_future(py, exec_sql(query, table_vec))?;
batch.to_pyarrow(py)
}
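A minimal sketch of calling the new binding from Python, mirroring the way tpcbench.py uses it further down; the table name and path are placeholders.

```python
from datafusion_ray import exec_sql_on_tables

# (name, path) pairs. In this commit a path ending in ".parquet" is registered
# as a single file, anything else as a directory of Parquet files; the path
# below is a placeholder.
tables = [("customer", "/data/customer.parquet")]

batch = exec_sql_on_tables("SELECT count(*) AS n FROM customer", tables)
print(batch.num_rows, batch.schema)  # a single pyarrow RecordBatch with the full result
```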

#[cfg(test)]
mod test {
use std::sync::Arc;
14 changes: 7 additions & 7 deletions tpch/tpcbench.py
@@ -18,7 +18,7 @@
import argparse
import ray
from datafusion import SessionContext, SessionConfig
from datafusion_ray import RayContext, prettify, runtime_env
from datafusion_ray import RayContext, exec_sql_on_tables, prettify, runtime_env
from datetime import datetime
import json
import os
@@ -49,7 +49,6 @@ def main(
validate: bool,
prefetch_buffer_size: int,
):

# Register the tables
table_names = [
"customer",
@@ -79,17 +78,13 @@

local_config = SessionConfig()

local_ctx = SessionContext(local_config)

for table in table_names:
path = os.path.join(data_path, f"{table}.parquet")
print(f"Registering table {table} using path {path}")
if listing_tables:
ctx.register_listing_table(table, f"{path}/")
local_ctx.register_listing_table(table, f"{path}/")
else:
ctx.register_parquet(table, path)
local_ctx.register_parquet(table, path)

current_time_millis = int(datetime.now().timestamp() * 1000)
results_path = f"datafusion-ray-tpch-{current_time_millis}.json"
Expand Down Expand Up @@ -125,6 +120,7 @@ def main(
start_time = time.time()
df = ctx.sql(sql)
end_time = time.time()
print(f"Ray output schema {df.schema()}")
print("Logical plan \n", df.logical_plan().display_indent())
print("Optimized Logical plan \n", df.optimized_logical_plan().display_indent())
part1 = end_time - start_time
@@ -143,7 +139,11 @@ def main(
print(calculated)
if validate:
start_time = time.time()
answer_batches = local_ctx.sql(sql).collect()
tables = [
(name, os.path.join(data_path, f"{name}.parquet"))
for name in table_names
]
answer_batches = [b for b in [exec_sql_on_tables(sql, tables)] if b]
end_time = time.time()
results["local_queries"][qnum] = end_time - start_time

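For the validation step, one plausible way to compare the two results, sketched with plain pyarrow rather than any helper from this repository. Here `ray_batches` is assumed to hold the DataFusion-Ray output, `answer_batches` the `exec_sql_on_tables` output, and `sort_keys` an assumed column ordering that makes row order deterministic.

```python
import pyarrow as pa

def results_match(ray_batches, answer_batches, sort_keys):
    # Concatenate each side into a single table and normalize row order before
    # comparing; sort_keys is e.g. [("c_custkey", "ascending")] and is an
    # assumption, since the query's ORDER BY may not cover every column.
    ray_tbl = pa.Table.from_batches(ray_batches).sort_by(sort_keys)
    ans_tbl = pa.Table.from_batches(answer_batches).sort_by(sort_keys)
    return ray_tbl.equals(ans_tbl)
```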