diff --git a/Cargo.toml b/Cargo.toml
index fa4dc89..9d9b659 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -52,6 +52,7 @@ object_store = { version = "0.11.0", features = [
 ] }
 parking_lot = { version = "0.12", features = ["deadlock_detection"] }
 prost = "0.13"
+protobuf-src = "2.1"
 pyo3 = { version = "0.23", features = [
     "extension-module",
     "abi3",
@@ -85,7 +86,6 @@ tonic-build = { version = "0.8", default-features = false, features = [
     "prost",
 ] }
 url = "2"
-protobuf-src = "2.1"
 
 [dev-dependencies]
 tempfile = "3.17"
diff --git a/README.md b/README.md
index 1fc484f..d47dcda 100644
--- a/README.md
+++ b/README.md
@@ -51,12 +51,20 @@ Once installed, you can run queries using DataFusion's familiar API while levera
 capabilities of Ray.
 
 ```python
+# from example in ./examples/http_csv.py
 import ray
 
 from datafusion_ray import DFRayContext, df_ray_runtime_env
 
 ray.init(runtime_env=df_ray_runtime_env)
-session = DFRayContext()
-df = session.sql("SELECT * FROM my_table WHERE value > 100")
+
+ctx = DFRayContext()
+ctx.register_csv(
+    "aggregate_test_100",
+    "https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv",
+)
+
+df = ctx.sql("SELECT c1,c2,c3 FROM aggregate_test_100 LIMIT 5")
+
 df.show()
 ```
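Because the context now registers the matching object store from the table path itself (see the `maybe_register_object_store` changes in `src/util.rs` further down), the same API also works against object-store URLs. A minimal sketch, assuming a hypothetical `my-bucket` S3 bucket and AWS credentials available in the environment:

```python
import ray

from datafusion_ray import DFRayContext, df_ray_runtime_env

ray.init(runtime_env=df_ray_runtime_env)

ctx = DFRayContext()
# hypothetical bucket and key; credentials are resolved by the AWS SDK from the
# environment (env vars, credential files, etc.)
ctx.register_parquet("trips", "s3://my-bucket/trips.parquet")

ctx.sql("SELECT count(*) FROM trips").show()
```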
diff --git a/datafusion_ray/core.py b/datafusion_ray/core.py
index 70955bc..0d1736a 100644
--- a/datafusion_ray/core.py
+++ b/datafusion_ray/core.py
@@ -86,9 +86,7 @@ async def wait_for(coros, name=""):
     # wrap the coro in a task to work with python 3.10 and 3.11+ where asyncio.wait semantics
     # changed to not accept any awaitable
     start = time.time()
-    done, _ = await asyncio.wait(
-        [asyncio.create_task(_ensure_coro(c)) for c in coros]
-    )
+    done, _ = await asyncio.wait([asyncio.create_task(_ensure_coro(c)) for c in coros])
     end = time.time()
     log.info(f"waiting for {name} took {end - start}s")
     for d in done:
@@ -166,9 +164,7 @@ async def acquire(self, need=1):
         need_to_make = need - have
 
         if need_to_make > can_make:
-            raise Exception(
-                f"Cannot allocate workers above {self.max_workers}"
-            )
+            raise Exception(f"Cannot allocate workers above {self.max_workers}")
 
         if need_to_make > 0:
             log.debug(f"creating {need_to_make} additional processors")
@@ -197,9 +193,9 @@ def _new_processor(self):
         self.processors_ready.clear()
         processor_key = new_friendly_name()
         log.debug(f"starting processor: {processor_key}")
-        processor = DFRayProcessor.options(
-            name=f"Processor : {processor_key}"
-        ).remote(processor_key)
+        processor = DFRayProcessor.options(name=f"Processor : {processor_key}").remote(
+            processor_key
+        )
         self.pool[processor_key] = processor
         self.processors_started.add(processor.start_up.remote())
         self.available.add(processor_key)
@@ -248,9 +244,7 @@ async def _wait_for_serve(self):
 
     async def all_done(self):
         log.info("calling processor all done")
-        refs = [
-            processor.all_done.remote() for processor in self.pool.values()
-        ]
+        refs = [processor.all_done.remote() for processor in self.pool.values()]
         await wait_for(refs, "processors to be all done")
         log.info("all processors shutdown")
 
@@ -293,9 +287,7 @@ async def update_plan(
         )
 
     async def serve(self):
-        log.info(
-            f"[{self.processor_key}] serving on {self.processor_service.addr()}"
-        )
+        log.info(f"[{self.processor_key}] serving on {self.processor_service.addr()}")
         await self.processor_service.serve()
         log.info(f"[{self.processor_key}] done serving")
 
@@ -332,9 +324,7 @@ def __init__(
         worker_pool_min: int,
         worker_pool_max: int,
     ) -> None:
-        log.info(
-            f"Creating DFRayContextSupervisor worker_pool_min: {worker_pool_min}"
-        )
+        log.info(f"Creating DFRayContextSupervisor worker_pool_min: {worker_pool_min}")
         self.pool = DFRayProcessorPool(worker_pool_min, worker_pool_max)
         self.stages: dict[str, InternalStageData] = {}
         log.info("Created DFRayContextSupervisor")
@@ -347,9 +337,7 @@ async def wait_for_ready(self):
 
     async def get_stage_addrs(self, stage_id: int):
         addrs = [
-            sd.remote_addr
-            for sd in self.stages.values()
-            if sd.stage_id == stage_id
+            sd.remote_addr for sd in self.stages.values() if sd.stage_id == stage_id
         ]
         return addrs
 
@@ -399,10 +387,7 @@ async def new_query(
             refs.append(
                 isd.remote_processor.update_plan.remote(
                     isd.stage_id,
-                    {
-                        stage_id: val["child_addrs"]
-                        for (stage_id, val) in kid.items()
-                    },
+                    {stage_id: val["child_addrs"] for (stage_id, val) in kid.items()},
                     isd.partition_group,
                     isd.plan_bytes,
                 )
             )
@@ -434,9 +419,7 @@ async def sort_out_addresses(self):
             ]
 
             # sanity check
-            assert all(
-                [op == output_partitions[0] for op in output_partitions]
-            )
+            assert all([op == output_partitions[0] for op in output_partitions])
             output_partitions = output_partitions[0]
 
             for child_stage_isd in child_stage_datas:
@@ -520,9 +503,7 @@ def collect(self) -> list[pa.RecordBatch]:
             )
             log.debug(f"last stage addrs {last_stage_addrs}")
 
-            reader = self.df.read_final_stage(
-                last_stage_id, last_stage_addrs[0]
-            )
+            reader = self.df.read_final_stage(last_stage_id, last_stage_addrs[0])
             log.debug("got reader")
             self._batches = list(reader)
         return self._batches
@@ -589,11 +570,55 @@ def __init__(
         )
 
     def register_parquet(self, name: str, path: str):
+        """
+        Register a Parquet file with the given name and path.
+        The path can be a relative or absolute filesystem path, or a URL.
+
+        If the path is an object store URL, the appropriate object store will be registered.
+        Configuration of the object store will be gathered from the environment.
+
+        For example, for s3:// URLs, credentials will be discovered by the AWS SDK,
+        which will check environment variables, credential files, etc.
+
+        Parameters:
+        name (str): The name to register the Parquet file under.
+        path (str): The file path to the Parquet file.
+        """
         self.ctx.register_parquet(name, path)
 
-    def register_listing_table(
-        self, name: str, path: str, file_extention="parquet"
-    ):
+    def register_csv(self, name: str, path: str):
+        """
+        Register a CSV file with the given name and path.
+        The path can be a relative or absolute filesystem path, or a URL.
+
+        If the path is an object store URL, the appropriate object store will be registered.
+        Configuration of the object store will be gathered from the environment.
+
+        For example, for s3:// URLs, credentials will be discovered by the AWS SDK,
+        which will check environment variables, credential files, etc.
+
+        Parameters:
+        name (str): The name to register the CSV file under.
+        path (str): The file path to the CSV file.
+        """
+        self.ctx.register_csv(name, path)
+
+    def register_listing_table(self, name: str, path: str, file_extention="parquet"):
+        """
+        Register a directory of Parquet files with the given name.
+        The path can be a relative or absolute filesystem path, or a URL.
+
+        If the path is an object store URL, the appropriate object store will be registered.
+        Configuration of the object store will be gathered from the environment.
+
+        For example, for s3:// URLs, credentials will be discovered by the AWS SDK,
+        which will check environment variables, credential files, etc.
+
+        Parameters:
+        name (str): The name to register the table under.
+        path (str): The path to the directory of Parquet files.
+        """
+        self.ctx.register_listing_table(name, path, file_extention)
 
     def sql(self, query: str) -> DFRayDataFrame:
diff --git a/examples/http_csv.py b/examples/http_csv.py
new file mode 100644
index 0000000..9fc7de4
--- /dev/null
+++ b/examples/http_csv.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# this is a port of the example at
+# https://github.com/apache/datafusion/blob/45.0.0/datafusion-examples/examples/query-http-csv.rs
+
+import ray
+
+from datafusion_ray import DFRayContext, df_ray_runtime_env
+
+
+def main():
+    ctx = DFRayContext()
+    ctx.register_csv(
+        "aggregate_test_100",
+        "https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv",
+    )
+
+    df = ctx.sql("SELECT c1,c2,c3 FROM aggregate_test_100 LIMIT 5")
+
+    df.show()
+
+
+if __name__ == "__main__":
+    ray.init(namespace="http_csv", runtime_env=df_ray_runtime_env)
+    main()
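Relatedly, `register_listing_table` now appends the trailing slash and registers the object store on the Rust side, so callers can hand it a bare directory (the `tpch/tpcbench.py` change below relies on this). A minimal sketch, with a hypothetical local directory of Parquet files:

```python
import ray

from datafusion_ray import DFRayContext, df_ray_runtime_env

ray.init(runtime_env=df_ray_runtime_env)

ctx = DFRayContext()
# hypothetical directory containing *.parquet files; an s3:// or gs:// prefix
# would be handled the same way, with credentials taken from the environment
ctx.register_listing_table("events", "/data/events")

ctx.sql("SELECT count(*) FROM events").show()
```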
diff --git a/examples/tips.py b/examples/tips.py
index 7d72ba5..7537f5f 100644
--- a/examples/tips.py
+++ b/examples/tips.py
@@ -16,40 +16,27 @@
 # under the License.
 
 import argparse
-import datafusion
+import os
 import ray
 
-from datafusion_ray import DFRayContext
+from datafusion_ray import DFRayContext, df_ray_runtime_env
 
 
 def go(data_dir: str):
     ctx = DFRayContext()
-    # we could set this value to however many CPUs we plan to give each
-    # ray task
-    ctx.set("datafusion.execution.target_partitions", "1")
-    ctx.set("datafusion.optimizer.enable_round_robin_repartition", "false")
-    ctx.register_parquet("tips", f"{data_dir}/tips*.parquet")
+    ctx.register_parquet("tips", os.path.join(data_dir, "tips.parquet"))
 
     df = ctx.sql(
         "select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker order by sex, smoker"
     )
     df.show()
-    print("no ray result:")
-
-    # compare to non ray version
-    ctx = datafusion.SessionContext()
-    ctx.register_parquet("tips", f"{data_dir}/tips*.parquet")
-    ctx.sql(
-        "select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker order by sex, smoker"
-    ).show()
-
 
 if __name__ == "__main__":
-    ray.init(namespace="tips")
+    ray.init(namespace="tips", runtime_env=df_ray_runtime_env)
 
     parser = argparse.ArgumentParser()
-    parser.add_argument("--data-dir", required=True, help="path to tips*.parquet files")
+    parser.add_argument("--data-dir", required=True, help="directory containing tips.parquet")
     args = parser.parse_args()
 
     go(args.data_dir)
diff --git a/src/context.rs b/src/context.rs
index 191d632..ee76080 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -16,17 +16,17 @@
 // under the License.
 
 use datafusion::datasource::file_format::parquet::ParquetFormat;
-use datafusion::datasource::listing::ListingOptions;
-use datafusion::{execution::SessionStateBuilder, prelude::*};
+use datafusion::datasource::listing::{ListingOptions, ListingTableUrl};
+use datafusion::execution::SessionStateBuilder;
+use datafusion::prelude::{CsvReadOptions, ParquetReadOptions, SessionConfig, SessionContext};
 use datafusion_python::utils::wait_for_future;
-use object_store::aws::AmazonS3Builder;
+use log::debug;
 use pyo3::prelude::*;
 use std::sync::Arc;
 
 use crate::dataframe::DFRayDataFrame;
 use crate::physical::RayStageOptimizerRule;
-use crate::util::ResultExt;
-use url::Url;
+use crate::util::{maybe_register_object_store, ResultExt};
 
 /// Internal Session Context object for the python class DFRayContext
 #[pyclass]
@@ -54,23 +54,27 @@ impl DFRayContext {
         Ok(Self { ctx })
     }
 
-    pub fn register_s3(&self, bucket_name: String) -> PyResult<()> {
-        let s3 = AmazonS3Builder::from_env()
-            .with_bucket_name(&bucket_name)
-            .build()
-            .to_py_err()?;
+    pub fn register_parquet(&self, py: Python, name: String, path: String) -> PyResult<()> {
+        let options = ParquetReadOptions::default();
+
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
 
-        let path = format!("s3://{bucket_name}");
-        let s3_url = Url::parse(&path).to_py_err()?;
-        let arc_s3 = Arc::new(s3);
-        self.ctx.register_object_store(&s3_url, arc_s3.clone());
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+        debug!("register_parquet: registering table {} at {}", name, path);
+
+        wait_for_future(py, self.ctx.register_parquet(&name, &path, options.clone()))?;
 
         Ok(())
     }
 
-    pub fn register_parquet(&self, py: Python, name: String, path: String) -> PyResult<()> {
-        let options = ParquetReadOptions::default();
+    pub fn register_csv(&self, py: Python, name: String, path: String) -> PyResult<()> {
+        let options = CsvReadOptions::default();
 
-        wait_for_future(py, self.ctx.register_parquet(&name, &path, options.clone()))?;
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
+
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+        debug!("register_csv: registering table {} at {}", name, path);
+
+        wait_for_future(py, self.ctx.register_csv(&name, &path, options.clone()))?;
 
         Ok(())
     }
 
@@ -85,6 +89,15 @@ impl DFRayContext {
         let options =
             ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(file_extension);
 
+        let path = format!("{path}/");
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
+
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+
+        debug!(
+            "register_listing_table: registering table {} at {}",
+            name, path
+        );
         wait_for_future(
             py,
             self.ctx
diff --git a/src/processor_service.rs b/src/processor_service.rs
index 5164577..120ba21 100644
--- a/src/processor_service.rs
+++ b/src/processor_service.rs
@@ -15,15 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::collections::hash_map::Entry;
 use std::collections::HashMap;
+use std::collections::hash_map::Entry;
 use std::error::Error;
 use std::sync::Arc;
 
 use arrow::array::RecordBatch;
+use arrow_flight::FlightClient;
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
-use arrow_flight::FlightClient;
 use datafusion::common::internal_datafusion_err;
 use datafusion::execution::SessionStateBuilder;
 use datafusion::physical_plan::ExecutionPlan;
@@ -35,23 +35,23 @@ use log::{debug, error, info, trace};
 use tokio::net::TcpListener;
 use tonic::transport::Server;
-use tonic::{async_trait, Request, Response, Status};
+use tonic::{Request, Response, Status, async_trait};
 
 use datafusion::error::Result as DFResult;
 
-use arrow_flight::{flight_service_server::FlightServiceServer, Ticket};
+use arrow_flight::{Ticket, flight_service_server::FlightServiceServer};
 
 use pyo3::prelude::*;
 
 use parking_lot::{Mutex, RwLock};
-use tokio::sync::mpsc::{channel, Receiver, Sender};
+use tokio::sync::mpsc::{Receiver, Sender, channel};
 
 use crate::flight::{FlightHandler, FlightServ};
 use crate::isolator::PartitionGroup;
 use crate::util::{
-    bytes_to_physical_plan, display_plan_with_partition_counts, extract_ticket, input_stage_ids,
-    make_client, ResultExt,
+    ResultExt, bytes_to_physical_plan, display_plan_with_partition_counts, extract_ticket,
+    input_stage_ids, make_client, register_object_store_for_paths_in_plan,
 };
 
 /// a map of stage_id, partition to a list FlightClients that can serve
@@ -102,7 +102,7 @@ impl DFRayProcessorHandlerInner {
         plan: Arc<dyn ExecutionPlan>,
         partition_group: Vec<usize>,
     ) -> DFResult<Self> {
-        let ctx = Self::configure_ctx(stage_id, stage_addrs, &plan, partition_group).await?;
+        let ctx = Self::configure_ctx(stage_id, stage_addrs, plan.clone(), partition_group).await?;
         Ok(Self { plan, ctx })
     }
 
@@ -110,10 +110,10 @@ impl DFRayProcessorHandlerInner {
     async fn configure_ctx(
         stage_id: usize,
         stage_addrs: HashMap<usize, HashMap<usize, Vec<String>>>,
-        plan: &Arc<dyn ExecutionPlan>,
+        plan: Arc<dyn ExecutionPlan>,
         partition_group: Vec<usize>,
     ) -> DFResult<SessionContext> {
-        let stage_ids_i_need = input_stage_ids(plan)?;
+        let stage_ids_i_need = input_stage_ids(&plan)?;
 
         // map of stage_id, partition -> Vec<FlightClient>
         let mut client_map = HashMap::new();
@@ -163,6 +163,8 @@ impl DFRayProcessorHandlerInner {
             .build();
 
         let ctx = SessionContext::new_with_state(state);
+        register_object_store_for_paths_in_plan(&ctx, plan.clone())?;
+
         trace!("ctx configured for stage {}", stage_id);
 
         Ok(ctx)
@@ -212,9 +214,7 @@ impl FlightHandler for DFRayProcessorHandler {
 
         trace!(
             "{}, request for partition {} from {}",
-            self.name,
-            partition,
-            remote_addr
+            self.name, partition, remote_addr
         );
 
         let name = self.name.clone();
diff --git a/src/util.rs b/src/util.rs
index 0c07c5c..1fa36e8 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -22,9 +22,12 @@ use async_stream::stream;
 use datafusion::common::internal_datafusion_err;
 use datafusion::common::tree_node::{Transformed, TreeNode};
 use datafusion::datasource::file_format::parquet::ParquetFormat;
-use datafusion::datasource::listing::ListingOptions;
-use datafusion::datasource::physical_plan::ParquetExec;
+use datafusion::datasource::listing::{ListingOptions, ListingTableUrl};
+use datafusion::datasource::physical_plan::{
+    ArrowExec, AvroExec, CsvExec, NdJsonExec, ParquetExec,
+};
 use datafusion::error::DataFusionError;
+use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, SessionStateBuilder};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties};
@@ -32,10 +35,16 @@ use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use datafusion_python::utils::wait_for_future;
 use futures::{Stream, StreamExt};
+use log::debug;
+use object_store::aws::AmazonS3Builder;
+use object_store::gcp::GoogleCloudStorageBuilder;
+use object_store::http::HttpBuilder;
+use object_store::ObjectStore;
 use parking_lot::Mutex;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyList};
 use tonic::transport::Channel;
+use url::Url;
 
 use crate::codec::RayCodec;
 use crate::processor_service::ServiceClients;
@@ -410,6 +419,12 @@ async fn exec_sql(
     for (name, path) in tables {
         let opt =
             ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(".parquet");
+        debug!("exec_sql: registering table {} at {}", name, path);
+
+        let url = ListingTableUrl::parse(&path)?;
+
+        maybe_register_object_store(&ctx, url.as_ref())?;
+
         ctx.register_listing_table(&name, &path, opt, None, None)
             .await?;
     }
@@ -432,16 +447,21 @@ async fn exec_sql(
 /// the `url` identifies the parquet files for each listing table and see
 /// [`datafusion::datasource::listing::ListingTableUrl::parse`] for details
 /// of supported URL formats
+/// * `listing`: boolean indicating whether this is a listing table path or not
 #[pyfunction]
+#[pyo3(signature = (query, tables, listing=false))]
 pub fn exec_sql_on_tables(
     py: Python,
     query: String,
     tables: Bound<'_, PyList>,
+    listing: bool,
 ) -> PyResult<PyObject> {
     let table_vec = {
         let mut v = Vec::with_capacity(tables.len());
         for entry in tables.iter() {
-            v.push(entry.extract::<(String, String)>()?);
+            let (name, path) = entry.extract::<(String, String)>()?;
+            let path = if listing { format!("{path}/") } else { path };
+            v.push((name, path));
         }
         v
     };
@@ -449,6 +469,102 @@ pub fn exec_sql_on_tables(
     batch.to_pyarrow(py)
 }
 
+pub(crate) fn register_object_store_for_paths_in_plan(
+    ctx: &SessionContext,
+    plan: Arc<dyn ExecutionPlan>,
+) -> Result<(), DataFusionError> {
+    let check_plan = |plan: Arc<dyn ExecutionPlan>| -> Result<_, DataFusionError> {
+        for input in plan.children().into_iter() {
+            if let Some(node) = input.as_any().downcast_ref::<ParquetExec>() {
+                let url = &node.base_config().object_store_url;
+                maybe_register_object_store(ctx, url.as_ref())?
+            } else if let Some(node) = input.as_any().downcast_ref::<CsvExec>() {
+                let url = &node.base_config().object_store_url;
+                maybe_register_object_store(ctx, url.as_ref())?
+            } else if let Some(node) = input.as_any().downcast_ref::<NdJsonExec>() {
+                let url = &node.base_config().object_store_url;
+                maybe_register_object_store(ctx, url.as_ref())?
+            } else if let Some(node) = input.as_any().downcast_ref::<AvroExec>() {
+                let url = &node.base_config().object_store_url;
+                maybe_register_object_store(ctx, url.as_ref())?
+            } else if let Some(node) = input.as_any().downcast_ref::<ArrowExec>() {
+                let url = &node.base_config().object_store_url;
+                maybe_register_object_store(ctx, url.as_ref())?
+            }
+        }
+        Ok(Transformed::no(plan))
+    };
+
+    plan.transform_down(check_plan)?;
+
+    Ok(())
+}
+
+/// Registers an object store with the given session context based on the provided URL.
+///
+/// # Arguments
+///
+/// * `ctx` - A reference to the `SessionContext` where the object store will be registered.
+/// * `url` - The parsed URL of the table path, used to decide which object store to build.
+pub(crate) fn maybe_register_object_store(
+    ctx: &SessionContext,
+    url: &Url,
+) -> Result<(), DataFusionError> {
+    let (ob_url, object_store) = if url.as_str().starts_with("s3://") {
+        let bucket = url
+            .host_str()
+            .ok_or(internal_datafusion_err!("missing bucket name in s3:// url"))?;
+
+        let s3 = AmazonS3Builder::from_env()
+            .with_bucket_name(bucket)
+            .build()?;
+        (
+            ObjectStoreUrl::parse(format!("s3://{bucket}"))?,
+            Arc::new(s3) as Arc<dyn ObjectStore>,
+        )
+    } else if url.as_str().starts_with("gs://") || url.as_str().starts_with("gcs://") {
+        let bucket = url
+            .host_str()
+            .ok_or(internal_datafusion_err!("missing bucket name in gs:// url"))?;
+
+        let gs = GoogleCloudStorageBuilder::new()
+            .with_bucket_name(bucket)
+            .build()?;
+
+        (
+            ObjectStoreUrl::parse(format!("gs://{bucket}"))?,
+            Arc::new(gs) as Arc<dyn ObjectStore>,
+        )
+    } else if url.as_str().starts_with("http://") || url.as_str().starts_with("https://") {
+        let scheme = url.scheme();
+
+        let host = url.host_str().ok_or(internal_datafusion_err!(
+            "missing host name in {}:// url",
+            scheme
+        ))?;
+
+        let http = HttpBuilder::new()
+            .with_url(format!("{scheme}://{host}"))
+            .build()?;
+
+        (
+            ObjectStoreUrl::parse(format!("{scheme}://{host}"))?,
+            Arc::new(http) as Arc<dyn ObjectStore>,
+        )
+    } else {
+        let local = object_store::local::LocalFileSystem::new();
+        (
+            ObjectStoreUrl::parse("file://")?,
+            Arc::new(local) as Arc<dyn ObjectStore>,
+        )
+    };
+
+    debug!("Registering object store for {}", ob_url);
+
+    ctx.register_object_store(ob_url.as_ref(), object_store);
+    Ok(())
+}
+
 #[cfg(test)]
 mod test {
     use std::{sync::Arc, vec};
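The new `listing` flag on `exec_sql_on_tables` mirrors that behavior for the validation helper, as `tpch/tpcbench.py` below now uses. A minimal sketch of calling it directly, assuming hypothetical local datasets and that `exec_sql_on_tables` is imported the same way `tpcbench.py` imports it (the import line is outside this diff):

```python
# hypothetical (table name, path) pairs; with listing=True each path would be a
# directory of Parquet files and gets a trailing "/" appended on the Rust side
tables = [
    ("customer", "/data/tpch/customer.parquet"),
    ("orders", "/data/tpch/orders.parquet"),
]

batch = exec_sql_on_tables(
    "select count(*) as cnt from orders o join customer c on o.o_custkey = c.c_custkey",
    tables,
    False,  # listing: these paths are single Parquet files
)
print(batch)
```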
diff --git a/tpch/tpcbench.py b/tpch/tpcbench.py
index 13960bf..dd6df1e 100644
--- a/tpch/tpcbench.py
+++ b/tpch/tpcbench.py
@@ -72,7 +72,7 @@ def main(
         path = os.path.join(data_path, f"{table}.parquet")
         print(f"Registering table {table} using path {path}")
         if listing_tables:
-            ctx.register_listing_table(table, f"{path}/")
+            ctx.register_listing_table(table, path)
         else:
             ctx.register_parquet(table, path)
 
@@ -93,7 +93,6 @@ def main(
         "queries": {},
     }
     if validate:
-        results["local_queries"] = {}
         results["validated"] = {}
 
     queries = range(1, 23) if qnum == -1 else [qnum]
@@ -114,15 +113,13 @@ def main(
         calculated = prettify(batches)
         print(calculated)
         if validate:
-            start_time = time.time()
             tables = [
                 (name, os.path.join(data_path, f"{name}.parquet"))
                 for name in table_names
            ]
-            answer_batches = [b for b in [exec_sql_on_tables(sql, tables)] if b]
-            end_time = time.time()
-            results["local_queries"][qnum] = end_time - start_time
-
+            answer_batches = [
+                b for b in [exec_sql_on_tables(sql, tables, listing_tables)] if b
+            ]
             expected = prettify(answer_batches)
             results["validated"][qnum] = calculated == expected
 
@@ -137,7 +134,7 @@ def main(
     print(results_dump)
 
     # give ray a moment to clean up
-    print("sleeping for 3 seconds for ray to clean up")
+    print("benchmark complete. sleeping for 3 seconds for ray to clean up")
     time.sleep(3)
 
     if validate and False in results["validated"].values():