2 changes: 1 addition & 1 deletion Cargo.toml
@@ -52,6 +52,7 @@ object_store = { version = "0.11.0", features = [
] }
parking_lot = { version = "0.12", features = ["deadlock_detection"] }
prost = "0.13"
protobuf-src = "2.1"
pyo3 = { version = "0.23", features = [
"extension-module",
"abi3",
@@ -85,7 +86,6 @@ tonic-build = { version = "0.8", default-features = false, features = [
"prost",
] }
url = "2"
protobuf-src = "2.1"

[dev-dependencies]
tempfile = "3.17"
12 changes: 10 additions & 2 deletions README.md
@@ -51,12 +51,20 @@ Once installed, you can run queries using DataFusion's familiar API while levera
capabilities of Ray.

```python
# from the example in ./examples/http_csv.py
import ray
from datafusion_ray import DFRayContext, df_ray_runtime_env

ray.init(runtime_env=df_ray_runtime_env)
session = DFRayContext()
df = session.sql("SELECT * FROM my_table WHERE value > 100")

ctx = DFRayContext()
ctx.register_csv(
"aggregate_test_100",
"https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv",
)

df = ctx.sql("SELECT c1,c2,c3 FROM aggregate_test_100 LIMIT 5")

df.show()
```
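
Other formats follow the same pattern: register the source, then query it with SQL. Below is a minimal sketch using `register_parquet`; the `./data/tips.parquet` path is hypothetical and assumes the tips dataset from `examples/tips.py` is available locally.

```python
import ray

from datafusion_ray import DFRayContext, df_ray_runtime_env

ray.init(runtime_env=df_ray_runtime_env)

ctx = DFRayContext()
# hypothetical local path; s3:// or https:// URLs also work, with the
# object store configured from the environment
ctx.register_parquet("tips", "./data/tips.parquet")

df = ctx.sql(
    "select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker"
)
df.show()
```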

93 changes: 59 additions & 34 deletions datafusion_ray/core.py
@@ -86,9 +86,7 @@ async def wait_for(coros, name=""):
# wrap the coro in a task to work with python 3.10 and 3.11+ where asyncio.wait semantics
# changed to not accept any awaitable
start = time.time()
done, _ = await asyncio.wait(
[asyncio.create_task(_ensure_coro(c)) for c in coros]
)
done, _ = await asyncio.wait([asyncio.create_task(_ensure_coro(c)) for c in coros])
end = time.time()
log.info(f"waiting for {name} took {end - start}s")
for d in done:
@@ -166,9 +164,7 @@ async def acquire(self, need=1):
need_to_make = need - have

if need_to_make > can_make:
raise Exception(
f"Cannot allocate workers above {self.max_workers}"
)
raise Exception(f"Cannot allocate workers above {self.max_workers}")

if need_to_make > 0:
log.debug(f"creating {need_to_make} additional processors")
@@ -197,9 +193,9 @@ def _new_processor(self):
self.processors_ready.clear()
processor_key = new_friendly_name()
log.debug(f"starting processor: {processor_key}")
processor = DFRayProcessor.options(
name=f"Processor : {processor_key}"
).remote(processor_key)
processor = DFRayProcessor.options(name=f"Processor : {processor_key}").remote(
processor_key
)
self.pool[processor_key] = processor
self.processors_started.add(processor.start_up.remote())
self.available.add(processor_key)
@@ -248,9 +244,7 @@ async def _wait_for_serve(self):

async def all_done(self):
log.info("calling processor all done")
refs = [
processor.all_done.remote() for processor in self.pool.values()
]
refs = [processor.all_done.remote() for processor in self.pool.values()]
await wait_for(refs, "processors to be all done")
log.info("all processors shutdown")

@@ -293,9 +287,7 @@ async def update_plan(
)

async def serve(self):
log.info(
f"[{self.processor_key}] serving on {self.processor_service.addr()}"
)
log.info(f"[{self.processor_key}] serving on {self.processor_service.addr()}")
await self.processor_service.serve()
log.info(f"[{self.processor_key}] done serving")

@@ -332,9 +324,7 @@ def __init__(
worker_pool_min: int,
worker_pool_max: int,
) -> None:
log.info(
f"Creating DFRayContextSupervisor worker_pool_min: {worker_pool_min}"
)
log.info(f"Creating DFRayContextSupervisor worker_pool_min: {worker_pool_min}")
self.pool = DFRayProcessorPool(worker_pool_min, worker_pool_max)
self.stages: dict[str, InternalStageData] = {}
log.info("Created DFRayContextSupervisor")
@@ -347,9 +337,7 @@ async def wait_for_ready(self):

async def get_stage_addrs(self, stage_id: int):
addrs = [
sd.remote_addr
for sd in self.stages.values()
if sd.stage_id == stage_id
sd.remote_addr for sd in self.stages.values() if sd.stage_id == stage_id
]
return addrs

@@ -399,10 +387,7 @@ async def new_query(
refs.append(
isd.remote_processor.update_plan.remote(
isd.stage_id,
{
stage_id: val["child_addrs"]
for (stage_id, val) in kid.items()
},
{stage_id: val["child_addrs"] for (stage_id, val) in kid.items()},
isd.partition_group,
isd.plan_bytes,
)
@@ -434,9 +419,7 @@ async def sort_out_addresses(self):
]

# sanity check
assert all(
[op == output_partitions[0] for op in output_partitions]
)
assert all([op == output_partitions[0] for op in output_partitions])
output_partitions = output_partitions[0]

for child_stage_isd in child_stage_datas:
@@ -520,9 +503,7 @@ def collect(self) -> list[pa.RecordBatch]:
)
log.debug(f"last stage addrs {last_stage_addrs}")

reader = self.df.read_final_stage(
last_stage_id, last_stage_addrs[0]
)
reader = self.df.read_final_stage(last_stage_id, last_stage_addrs[0])
log.debug("got reader")
self._batches = list(reader)
return self._batches
@@ -589,11 +570,55 @@ def __init__(
)

def register_parquet(self, name: str, path: str):
"""
Register a Parquet file with the given name and path.
The path can be a relative or absolute filesystem path, or a URL.

If the path is an object store URL, the appropriate object store will be registered.
Configuration of the object store will be gathered from the environment.

For example, for s3:// URLs, credentials are looked up by the AWS SDK,
which checks environment variables, credential files, etc.

Parameters:
name (str): The name to register the Parquet file under.
path (str): The file path to the Parquet file.
"""
self.ctx.register_parquet(name, path)

def register_listing_table(
self, name: str, path: str, file_extention="parquet"
):
def register_csv(self, name: str, path: str):
"""
Register a CSV file with the given name and path.
The path can be a relative or absolute filesystem path, or a URL.

If the path is an object store URL, the appropriate object store will be registered.
Configuration of the object store will be gathered from the environment.

For example, for s3:// URLs, credentials are looked up by the AWS SDK,
which checks environment variables, credential files, etc.

Parameters:
name (str): The name to register the CSV file under.
path (str): The file path to the CSV file.
"""
self.ctx.register_csv(name, path)

def register_listing_table(self, name: str, path: str, file_extention="parquet"):
"""
Register a directory of Parquet files with the given name.
The path can be a relative or absolute filesystem path, or a URL.

If the path is an object store URL, the appropriate object store will be registered.
Configuration of the object store will be gathered from the environment.

For example, for s3:// URLs, credentials are looked up by the AWS SDK,
which checks environment variables, credential files, etc.

Parameters:
name (str): The name to register the table under.
path (str): The path to the directory of Parquet files.
"""

self.ctx.register_listing_table(name, path, file_extention)

def sql(self, query: str) -> DFRayDataFrame:
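
The object store handling described in the docstrings above means that object store URLs can be passed directly to these registration methods; the store is registered automatically and its configuration is taken from the environment. A short sketch, assuming a hypothetical S3 bucket and prefix (credentials are discovered by the AWS SDK from environment variables, credential files, etc.):

```python
import ray

from datafusion_ray import DFRayContext, df_ray_runtime_env

ray.init(runtime_env=df_ray_runtime_env)

ctx = DFRayContext()
# "my-bucket" and "events/" are hypothetical; any directory of parquet
# files reachable through a supported object store should work
ctx.register_listing_table("events", "s3://my-bucket/events/")

ctx.sql("SELECT count(*) FROM events").show()
```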
40 changes: 40 additions & 0 deletions examples/http_csv.py
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# this is a port of the example at
# https://github.com/apache/datafusion/blob/45.0.0/datafusion-examples/examples/query-http-csv.rs

import ray

from datafusion_ray import DFRayContext, df_ray_runtime_env


def main():
ctx = DFRayContext()
ctx.register_csv(
"aggregate_test_100",
"https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv",
)

df = ctx.sql("SELECT c1,c2,c3 FROM aggregate_test_100 LIMIT 5")

df.show()


if __name__ == "__main__":
ray.init(namespace="http_csv", runtime_env=df_ray_runtime_env)
main()
23 changes: 5 additions & 18 deletions examples/tips.py
@@ -16,40 +16,27 @@
# under the License.

import argparse
import datafusion
import os
import ray

from datafusion_ray import DFRayContext
from datafusion_ray import DFRayContext, df_ray_runtime_env


def go(data_dir: str):
ctx = DFRayContext()
# we could set this value to however many CPUs we plan to give each
# ray task
ctx.set("datafusion.execution.target_partitions", "1")
ctx.set("datafusion.optimizer.enable_round_robin_repartition", "false")

ctx.register_parquet("tips", f"{data_dir}/tips*.parquet")
ctx.register_parquet("tips", os.path.join(data_dir, "tips.parquet"))

df = ctx.sql(
"select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker order by sex, smoker"
)
df.show()

print("no ray result:")

# compare to non ray version
ctx = datafusion.SessionContext()
ctx.register_parquet("tips", f"{data_dir}/tips*.parquet")
ctx.sql(
"select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker order by sex, smoker"
).show()


if __name__ == "__main__":
ray.init(namespace="tips")
ray.init(namespace="tips", runtime_env=df_ray_runtime_env)
parser = argparse.ArgumentParser()
parser.add_argument("--data-dir", required=True, help="path to tips*.parquet files")
parser.add_argument("--data-dir", required=True, help="path to tips.parquet files")
args = parser.parse_args()

go(args.data_dir)
47 changes: 30 additions & 17 deletions src/context.rs
@@ -16,17 +16,17 @@
// under the License.

use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::ListingOptions;
use datafusion::{execution::SessionStateBuilder, prelude::*};
use datafusion::datasource::listing::{ListingOptions, ListingTableUrl};
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions, SessionConfig, SessionContext};
use datafusion_python::utils::wait_for_future;
use object_store::aws::AmazonS3Builder;
use log::debug;
use pyo3::prelude::*;
use std::sync::Arc;

use crate::dataframe::DFRayDataFrame;
use crate::physical::RayStageOptimizerRule;
use crate::util::ResultExt;
use url::Url;
use crate::util::{maybe_register_object_store, ResultExt};

/// Internal Session Context object for the python class DFRayContext
#[pyclass]
@@ -54,23 +54,27 @@ impl DFRayContext {
Ok(Self { ctx })
}

pub fn register_s3(&self, bucket_name: String) -> PyResult<()> {
let s3 = AmazonS3Builder::from_env()
.with_bucket_name(&bucket_name)
.build()
.to_py_err()?;
pub fn register_parquet(&self, py: Python, name: String, path: String) -> PyResult<()> {
let options = ParquetReadOptions::default();

let url = ListingTableUrl::parse(&path).to_py_err()?;

let path = format!("s3://{bucket_name}");
let s3_url = Url::parse(&path).to_py_err()?;
let arc_s3 = Arc::new(s3);
self.ctx.register_object_store(&s3_url, arc_s3.clone());
maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
debug!("register_parquet: registering table {} at {}", name, path);

wait_for_future(py, self.ctx.register_parquet(&name, &path, options.clone()))?;
Ok(())
}

pub fn register_parquet(&self, py: Python, name: String, path: String) -> PyResult<()> {
let options = ParquetReadOptions::default();
pub fn register_csv(&self, py: Python, name: String, path: String) -> PyResult<()> {
let options = CsvReadOptions::default();

wait_for_future(py, self.ctx.register_parquet(&name, &path, options.clone()))?;
let url = ListingTableUrl::parse(&path).to_py_err()?;

maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
debug!("register_csv: registering table {} at {}", name, path);

wait_for_future(py, self.ctx.register_csv(&name, &path, options.clone()))?;
Ok(())
}

@@ -85,6 +89,15 @@ impl DFRayContext {
let options =
ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(file_extension);

let path = format!("{path}/");
let url = ListingTableUrl::parse(&path).to_py_err()?;

maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;

debug!(
"register_listing_table: registering table {} at {}",
name, path
);
wait_for_future(
py,
self.ctx