Commit b5fad14 (parent 42681a1)

support running multiple queries in the TPC-H benchmark

5 files changed: 81 additions & 46 deletions

datafusion_ray/util.py (1 addition & 1 deletion)

@@ -1,4 +1,4 @@
 from datafusion_ray._datafusion_ray_internal import (
-    exec_sql_on_tables,
+    exec_sqls_on_tables,
     prettify,
 )

docs/contributing.md (2 additions & 2 deletions)

@@ -80,15 +80,15 @@ RAY_COLOR_PREFIX=1 RAY_DEDUP_LOGS=0 python tips.py --data-dir=$(pwd)/../testdata
 - In the `tpch` directory, use `make_data.py` to create a TPCH dataset at a provided scale factor, then
 
 ```bash
-RAY_COLOR_PREFIX=1 RAY_DEDUP_LOGS=0 python tpc.py --data=file:///path/to/your/tpch/directory/ --concurrency=2 --batch-size=8182 --worker-pool-min=10 --qnum 2
+RAY_COLOR_PREFIX=1 RAY_DEDUP_LOGS=0 python tpcbench.py --data=file:///path/to/your/tpch/directory/ --concurrency=2 --batch-size=8182 --worker-pool-min=10 --qnum 2
 ```
 
 This executes TPCH query #2. To execute an arbitrary query against the TPCH dataset, provide it with `--query` instead of `--qnum`. This is useful for validating plans that DataFusion Ray will create.
 
 For example, to execute the following query:
 
 ```bash
-RAY_COLOR_PREFIX=1 RAY_DEDUP_LOGS=0 python tpc.py --data=file:///path/to/your/tpch/directory/ --concurrency=2 --batch-size=8182 --worker-pool-min=10 --query 'select c.c_name, sum(o.o_totalprice) as total from orders o inner join customer c on o.o_custkey = c.c_custkey group by c_name limit 1'
+RAY_COLOR_PREFIX=1 RAY_DEDUP_LOGS=0 python tpcbench.py --data=file:///path/to/your/tpch/directory/ --concurrency=2 --batch-size=8182 --worker-pool-min=10 --query 'select c.c_name, sum(o.o_totalprice) as total from orders o inner join customer c on o.o_custkey = c.c_custkey group by c_name limit 1'
 ```
 
 To further parallelize execution, you can choose how many partitions will be served by each Stage with `--partitions-per-processor`. If this number is less than `--concurrency`, then multiple Actors will host portions of the stage. For example, if there are 10 stages calculated for a query, `concurrency=16` and `partitions-per-processor=4`, then `40` `RayStage` Actors will be created. If `partitions-per-processor=16` or is absent, then `10` `RayStage` Actors will be created.
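To sanity-check the Actor-count arithmetic in that last context paragraph, here is a minimal sketch; the formula `stages * ceil(concurrency / partitions_per_processor)` is inferred from the example numbers in the docs, not taken verbatim from the code:

```python
import math

def ray_stage_actor_count(stages: int, concurrency: int,
                          partitions_per_processor: int | None) -> int:
    # Each stage serves `concurrency` partitions; each Actor hosts at most
    # `partitions_per_processor` of them (all of them when the flag is absent).
    ppp = partitions_per_processor or concurrency
    return stages * math.ceil(concurrency / ppp)

# The two cases quoted in contributing.md:
assert ray_stage_actor_count(10, 16, 4) == 40
assert ray_stage_actor_count(10, 16, None) == 10
```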

src/lib.rs (1 addition & 1 deletion)

@@ -44,7 +44,7 @@ fn _datafusion_ray_internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<dataframe::PyDFRayStage>()?;
     m.add_class::<processor_service::DFRayProcessorService>()?;
     m.add_function(wrap_pyfunction!(util::prettify, m)?)?;
-    m.add_function(wrap_pyfunction!(util::exec_sql_on_tables, m)?)?;
+    m.add_function(wrap_pyfunction!(util::exec_sqls_on_tables, m)?)?;
     Ok(())
 }

src/util.rs (24 additions & 19 deletions)

@@ -14,7 +14,7 @@ use arrow::error::ArrowError;
 use arrow::ipc::convert::fb_to_schema;
 use arrow::ipc::reader::StreamReader;
 use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
-use arrow::ipc::{root_as_message, MetadataVersion};
+use arrow::ipc::{MetadataVersion, root_as_message};
 use arrow::pyarrow::*;
 use arrow::util::pretty;
 use arrow_flight::{FlightClient, FlightData, Ticket};
@@ -30,16 +30,16 @@ use datafusion::error::DataFusionError;
 use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, SessionStateBuilder};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
-use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties};
+use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable};
 use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use datafusion_python::utils::wait_for_future;
 use futures::{Stream, StreamExt};
 use log::debug;
+use object_store::ObjectStore;
 use object_store::aws::AmazonS3Builder;
 use object_store::gcp::GoogleCloudStorageBuilder;
 use object_store::http::HttpBuilder;
-use object_store::ObjectStore;
 use parking_lot::Mutex;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyList};
@@ -412,9 +412,9 @@ fn print_node(plan: &Arc<dyn ExecutionPlan>, indent: usize, output: &mut String)
 }
 
 async fn exec_sql(
-    query: String,
+    queries: Vec<String>,
     tables: Vec<(String, String)>,
-) -> Result<RecordBatch, DataFusionError> {
+) -> Result<Vec<RecordBatch>, DataFusionError> {
     let ctx = SessionContext::new();
     for (name, path) in tables {
         let opt =
@@ -428,34 +428,39 @@ async fn exec_sql(
         ctx.register_listing_table(&name, &path, opt, None, None)
             .await?;
     }
-    let df = ctx.sql(&query).await?;
-    let schema = df.schema().inner().clone();
-    let batches = df.collect().await?;
-    concat_batches(&schema, batches.iter()).map_err(|e| DataFusionError::ArrowError(e, None))
+    let mut results = vec![];
+    for query in queries {
+        let df = ctx.sql(&query).await?;
+        let schema = df.schema().inner().clone();
+        let batches = df.collect().await?;
+        let result = concat_batches(&schema, &batches)?;
+        results.push(result);
+    }
+    Ok(results)
 }
 
-/// Executes a query on the specified tables using DataFusion without Ray.
+/// Executes queries on the specified tables using DataFusion without Ray.
 ///
-/// Returns the query results as a RecordBatch that can be used to verify the
-/// correctness of DataFusion-Ray execution of the same query.
+/// Returns the query results as a Vec of RecordBatch that can be used to verify the
+/// correctness of DataFusion-Ray execution of the same queries.
 ///
 /// # Arguments
 ///
 /// * `py`: the Python token
-/// * `query`: the SQL query string to execute
+/// * `queries`: the SQL query strings to execute
 /// * `tables`: a list of `(name, url)` tuples specifying the tables to query;
 ///   the `url` identifies the parquet files for each listing table; see
 ///   [`datafusion::datasource::listing::ListingTableUrl::parse`] for details
 ///   of supported URL formats
 /// * `listing`: boolean indicating whether this is a listing table path or not
 #[pyfunction]
-#[pyo3(signature = (query, tables, listing=false))]
-pub fn exec_sql_on_tables(
+#[pyo3(signature = (queries, tables, listing=false))]
+pub fn exec_sqls_on_tables(
     py: Python,
-    query: String,
+    queries: Vec<String>,
     tables: Bound<'_, PyList>,
     listing: bool,
-) -> PyResult<PyObject> {
+) -> PyResult<Vec<PyObject>> {
     let table_vec = {
         let mut v = Vec::with_capacity(tables.len());
         for entry in tables.iter() {
@@ -465,8 +470,8 @@ pub fn exec_sql_on_tables(
         }
         v
     };
-    let batch = wait_for_future(py, exec_sql(query, table_vec))?;
-    batch.to_pyarrow(py)
+    let batches = wait_for_future(py, exec_sql(queries, table_vec))?;
+    batches.iter().map(|b| b.to_pyarrow(py)).collect()
 }
 
 pub(crate) fn register_object_store_for_paths_in_plan(
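A minimal sketch of calling the renamed binding from Python after this change; the table name and parquet URL below are placeholders for illustration, not taken from the commit:

```python
from datafusion_ray.util import exec_sqls_on_tables, prettify

# Hypothetical (name, url) pair; any parquet-backed TPCH table works here.
tables = [("customer", "file:///data/tpch/customer.parquet")]

# One pyarrow RecordBatch comes back per statement, in order.
batches = exec_sqls_on_tables(
    ["select count(*) from customer",
     "select c_name from customer limit 5"],
    tables,
    listing=False,
)
for batch in batches:
    print(prettify([batch]))
```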

tpch/tpcbench.py (53 additions & 23 deletions)

@@ -18,7 +18,7 @@
 import argparse
 import ray
 from datafusion_ray import DFRayContext, df_ray_runtime_env
-from datafusion_ray.util import exec_sql_on_tables, prettify
+from datafusion_ray.util import exec_sqls_on_tables, prettify
 from datetime import datetime
 import json
 import os
@@ -31,7 +31,7 @@ def tpch_query(qnum: int) -> str:
 
 
 def main(
-    qnum: int,
+    queries: list[(str, str)],
     data_path: str,
     concurrency: int,
     batch_size: int,
@@ -95,35 +95,43 @@ def main(
     if validate:
         results["validated"] = {}
 
-    queries = range(1, 23) if qnum == -1 else [qnum]
-    for qnum in queries:
-        sql = tpch_query(qnum)
-
-        statements = sql.split(";")
-        sql = statements[0]
-
+    for (qid, sql) in queries:
         print("executing ", sql)
 
+        statements = [s for s in sql.split(";") if s.strip() != ""]
         start_time = time.time()
-        df = ctx.sql(sql)
-        batches = df.collect()
+        batches = [ctx.sql(s).collect() for s in statements]
        end_time = time.time()
-        results["queries"][qnum] = end_time - start_time
+        results["queries"][qid] = end_time - start_time
 
-        calculated = prettify(batches)
-        print(calculated)
+        calculated = [prettify(batch) for batch in batches if batch]
+        for pretty_batch in calculated:
+            print(pretty_batch)
         if validate:
             tables = [
                 (name, os.path.join(data_path, f"{name}.parquet"))
                 for name in table_names
             ]
-            answer_batches = [
-                b for b in [exec_sql_on_tables(sql, tables, listing_tables)] if b
-            ]
-            expected = prettify(answer_batches)
-
-            results["validated"][qnum] = calculated == expected
-        print(f"done with query {qnum}")
+            answer_batches = [b for b in exec_sqls_on_tables(
+                statements, tables, listing_tables) if b]
+
+            validated = True
+            if len(answer_batches) == len(calculated):
+                expected = [prettify([answer_batch])
+                            for answer_batch in answer_batches]
+                validated = all(x[0] == x[1]
+                                for x in zip(calculated, expected))
+                for x in zip(calculated, expected):
+                    if x[0] != x[1]:
+                        print(f"Expected:\n{x[1]}")
+                        print(f"Got:\n{x[0]}")
+            else:
+                print(
+                    f"Expected {len(answer_batches)} batches, got {len(calculated)}")
+                validated = False
+
+            results["validated"][qid] = validated
+        print(f"done with query {qid}")
 
     # write the results as we go, so you can peek at them
     results_dump = json.dumps(results, indent=4)
@@ -151,7 +159,10 @@ def main(
     parser.add_argument(
         "--concurrency", required=True, help="Number of concurrent tasks"
     )
-    parser.add_argument("--qnum", type=int, default=-1, help="TPCH query number, 1-22")
+    parser.add_argument("--qnum", type=int, default=-1,
+                        help="TPCH query number, 1-22")
+    parser.add_argument("--query", required=False, type=str,
+                        help="Custom query to run with tpch tables")
     parser.add_argument("--listing-tables", action="store_true")
     parser.add_argument("--validate", action="store_true")
     parser.add_argument(
@@ -183,8 +194,27 @@ def main(
 
     args = parser.parse_args()
 
+    if (args.qnum != -1 and args.query is not None):
+        print("Please specify either --qnum or --query, but not both")
+
+    queries = []
+    if (args.qnum != -1):
+        if args.qnum < 1 or args.qnum > 22:
+            print("Invalid query number. Please specify a number between 1 and 22.")
+            exit(1)
+        else:
+            queries.append((str(args.qnum), tpch_query(args.qnum)))
+            print("Executing tpch query ", args.qnum)
+
+    elif (args.query is not None):
+        queries.append(("custom query", args.query))
+        print("Executing custom query ", args.query)
+    else:
+        print("Executing all tpch queries")
+        queries = [(str(i), tpch_query(i)) for i in range(1, 23)]
+
     main(
-        args.qnum,
+        queries,
         args.data,
         int(args.concurrency),
        int(args.batch_size),
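The semicolon splitting above is what lets the benchmark run multi-statement TPC-H queries; q15, in its standard form, is three statements (create a revenue view, select from it, drop it). A minimal sketch of that splitting, using an abbreviated q15-shaped string rather than the real query text:

```python
sql = """create view revenue0 as
    select l_suppkey as supplier_no from lineitem;
select supplier_no from revenue0;
drop view revenue0;"""

# Mirrors tpcbench.py: empty fragments (e.g. after the final
# semicolon) are dropped before execution.
statements = [s for s in sql.split(";") if s.strip() != ""]
assert len(statements) == 3
```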
