diff --git a/.env b/.env index ae2739187db9..1668ae84695d 100644 --- a/.env +++ b/.env @@ -47,7 +47,7 @@ FEDORA=33 PYTHON=3.6 LLVM=11 CLANG_TOOLS=8 -RUST=nightly-2021-10-23 +RUST=nightly-2022-01-17 GO=1.15 NODE=14 MAVEN=3.5.4 diff --git a/.github/dependabot.yml b/.github/dependabot.yml index a4557c17fe9f..9bd42dbaa0d6 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -3,9 +3,7 @@ updates: - package-ecosystem: cargo directory: "/" schedule: - interval: weekly - day: sunday - time: "7:00" + interval: daily open-pull-requests-limit: 10 target-branch: master - labels: [auto-dependencies] \ No newline at end of file + labels: [auto-dependencies] diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 096ed7817aa6..8a7f6737ded5 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -293,7 +293,7 @@ jobs: strategy: matrix: arch: [amd64] - rust: [nightly-2021-10-23] + rust: [nightly-2022-01-17] steps: - uses: actions/checkout@v2 with: diff --git a/Cargo.toml b/Cargo.toml index 757d671fbe0a..d030266955d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,13 +28,11 @@ members = [ "ballista-examples", ] -exclude = ["python"] - [profile.release] lto = true codegen-units = 1 [patch.crates-io] -arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", rev = "ef7937dfe56033c2cc491482c67587b52cd91554" } +arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", branch = "main" } #arrow2 = { git = "https://github.com/blaze-init/arrow2.git", branch = "shuffle_ipc" } #parquet2 = { git = "https://github.com/blaze-init/parquet2.git", branch = "meta_new" } diff --git a/README.md b/README.md index 6bef96637712..25fc16c8956c 100644 --- a/README.md +++ b/README.md @@ -49,19 +49,25 @@ the convenience of an SQL interface or a DataFrame API. ## Known Uses +Projects that build on DataFusion or serve as plugins to it: + +- [datafusion-python](https://github.com/datafusion-contrib/datafusion-python) +- [datafusion-java](https://github.com/datafusion-contrib/datafusion-java) +- [datafusion-ruby](https://github.com/j-a-m-l/datafusion-ruby) +- [datafusion-objectstore-s3](https://github.com/datafusion-contrib/datafusion-objectstore-s3) +- [datafusion-hdfs-native](https://github.com/datafusion-contrib/datafusion-hdfs-native) + Here are some of the projects known to use DataFusion: - [Ballista](ballista) Distributed Compute Platform - [Cloudfuse Buzz](https://github.com/cloudfuse-io/buzz-rust) - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) -- [datafusion-python](https://pypi.org/project/datafusion) -- [datafusion-java](https://github.com/datafusion-contrib/datafusion-java) -- [datafusion-ruby](https://github.com/j-a-m-l/datafusion-ruby) - [delta-rs](https://github.com/delta-io/delta-rs) - [InfluxDB IOx](https://github.com/influxdata/influxdb_iox) Time Series Database - [ROAPI](https://github.com/roapi/roapi) - [Tensorbase](https://github.com/tensorbase/tensorbase) - [Squirtle](https://github.com/DSLAM-UMD/Squirtle) +- [VegaFusion](https://vegafusion.io/) Server-side acceleration for the [Vega](https://vega.github.io/) visualization grammar (if you know of another project, please submit a PR to add a link!) @@ -134,6 +140,60 @@ datafusion = "6.0.0" DataFusion also includes a simple command-line interactive SQL utility. See the [CLI reference](https://arrow.apache.org/datafusion/cli/index.html) for more information.
+# Roadmap + +A quarterly roadmap will be published to give the DataFusion community visibility into the priorities of the project's contributors. This roadmap is not binding. + +## 2022 Q1 + +### DataFusion Core + +- Publish official Arrow2 branch +- Implementation of a memory manager (i.e. to enable spilling to disk as needed) + +### Benchmarking + +- Inclusion in db-benchmark with all queries covered +- All TPC-H queries covered + +### Performance Improvements + +- Predicate evaluation +- Improve multi-column comparisons (that can't be vectorized at the moment) +- Null constant support + +### New Features + +- Read JSON as a table +- Simplify DDL with datafusion-cli +- Add the Decimal128 data type and attendant features such as Arrow kernel and UDF support +- Add a new experimental e-graph based optimizer + +### Ballista + +- Begin work on design documents and a plan / priorities for development + +### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib)) + +- Stable S3 support +- Begin design discussions and prototyping of a stream provider + +## Beyond 2022 Q1 + +There is no clear timeline for the items below, but community members have expressed interest in working on these topics. + +### DataFusion Core + +- Custom SQL support +- Split DataFusion into multiple crates +- Push-based query execution and code generation + +### Ballista + +- Evolve the architecture so that it can be deployed in a multi-tenant, cloud-native environment +- Ensure Ballista is scalable, elastic, and stable for production usage +- Develop distributed ML capabilities + # Status ## General @@ -266,7 +326,7 @@ This library currently supports many SQL constructs, including - `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)` - Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`.
- `WHERE` to filter -- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `VAR`, `STDDEV` (sample and population) +- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population) - `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST` ## Supported Functions diff --git a/ballista-examples/Cargo.toml b/ballista-examples/Cargo.toml index 338f69994bfd..d5f7d65d83ef 100644 --- a/ballista-examples/Cargo.toml +++ b/ballista-examples/Cargo.toml @@ -26,7 +26,7 @@ license = "Apache-2.0" keywords = [ "arrow", "distributed", "query", "sql" ] edition = "2021" publish = false -rust-version = "1.57" +rust-version = "1.58" [dependencies] datafusion = { path = "../datafusion" } diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml index 7736e949d29f..aa8297f8d06d 100644 --- a/ballista/rust/client/Cargo.toml +++ b/ballista/rust/client/Cargo.toml @@ -24,7 +24,7 @@ homepage = "https://github.com/apache/arrow-datafusion" repository = "https://github.com/apache/arrow-datafusion" authors = ["Apache Arrow "] edition = "2021" -rust-version = "1.57" +rust-version = "1.58" [dependencies] ballista-core = { path = "../core", version = "0.6.0" } @@ -33,6 +33,8 @@ ballista-scheduler = { path = "../scheduler", version = "0.6.0", optional = true futures = "0.3" log = "0.4" tokio = "1.0" +tempfile = "3" +sqlparser = "0.13" datafusion = { path = "../../../datafusion", version = "6.0.0" } diff --git a/ballista/rust/client/src/columnar_batch.rs b/ballista/rust/client/src/columnar_batch.rs index 9460bed1a8d3..5177261a2bd2 100644 --- a/ballista/rust/client/src/columnar_batch.rs +++ b/ballista/rust/client/src/columnar_batch.rs @@ -23,8 +23,9 @@ use datafusion::arrow::{ array::ArrayRef, compute::aggregate::estimated_bytes_size, datatypes::{DataType, Schema}, - record_batch::RecordBatch, }; +use datafusion::field_util::{FieldExt, SchemaExt}; +use datafusion::record_batch::RecordBatch; use datafusion::scalar::ScalarValue; pub type MaybeColumnarBatch = Result>; @@ -44,7 +45,7 @@ impl ColumnarBatch { .enumerate() .map(|(i, array)| { ( - batch.schema().field(i).name().clone(), + batch.schema().field(i).name().to_string(), ColumnarValue::Columnar(array.clone()), ) }) @@ -61,7 +62,7 @@ impl ColumnarBatch { .fields() .iter() .enumerate() - .map(|(i, f)| (f.name().clone(), values[i].clone())) + .map(|(i, f)| (f.name().to_string(), values[i].clone())) .collect(); Self { diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs index fff6f26305fa..3fb347bddbce 100644 --- a/ballista/rust/client/src/context.rs +++ b/ballista/rust/client/src/context.rs @@ -17,6 +17,7 @@ //! Distributed execution context. 
+use sqlparser::ast::Statement; use std::collections::HashMap; use std::fs; use std::path::PathBuf; @@ -31,8 +32,10 @@ use datafusion::datasource::TableProvider; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::dataframe_impl::DataFrameImpl; use datafusion::logical_plan::{CreateExternalTable, LogicalPlan, TableScan}; -use datafusion::prelude::{AvroReadOptions, CsvReadOptions}; -use datafusion::sql::parser::FileType; +use datafusion::prelude::{ + AvroReadOptions, CsvReadOptions, ExecutionConfig, ExecutionContext, +}; +use datafusion::sql::parser::{DFParser, FileType, Statement as DFStatement}; struct BallistaContextState { /// Ballista configuration @@ -242,6 +245,35 @@ impl BallistaContext { } } + /// is a 'show *' sql + pub async fn is_show_statement(&self, sql: &str) -> Result { + let mut is_show_variable: bool = false; + let statements = DFParser::parse_sql(sql)?; + + if statements.len() != 1 { + return Err(DataFusionError::NotImplemented( + "The context currently only supports a single SQL statement".to_string(), + )); + } + + if let DFStatement::Statement(s) = &statements[0] { + let st: &Statement = s; + match st { + Statement::ShowVariable { .. } => { + is_show_variable = true; + } + Statement::ShowColumns { .. } => { + is_show_variable = true; + } + _ => { + is_show_variable = false; + } + } + }; + + Ok(is_show_variable) + } + /// Create a DataFrame from a SQL statement. /// /// This method is `async` because queries of type `CREATE EXTERNAL TABLE` @@ -256,6 +288,17 @@ impl BallistaContext { ) }; + let is_show = self.is_show_statement(sql).await?; + // the show tables、 show columns sql can not run at scheduler because the tables is store at client + if is_show { + let state = self.state.lock().unwrap(); + ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema( + state.config.default_with_information_schema(), + ), + ); + } + // register tables with DataFusion context { let state = self.state.lock().unwrap(); @@ -268,6 +311,7 @@ impl BallistaContext { } let plan = ctx.create_logical_plan(sql)?; + match plan { LogicalPlan::CreateExternalTable(CreateExternalTable { ref schema, @@ -309,6 +353,7 @@ impl BallistaContext { #[cfg(test)] mod tests { + #[tokio::test] #[cfg(feature = "standalone")] async fn test_standalone_mode() { @@ -319,4 +364,161 @@ mod tests { let df = context.sql("SELECT 1;").await.unwrap(); df.collect().await.unwrap(); } + + #[tokio::test] + #[cfg(feature = "standalone")] + async fn test_ballista_show_tables() { + use super::*; + use std::fs::File; + use std::io::Write; + use tempfile::TempDir; + let context = BallistaContext::standalone(&BallistaConfig::new().unwrap(), 1) + .await + .unwrap(); + + let data = "Jorge,2018-12-13T12:12:10.011Z\n\ + Andrew,2018-11-13T17:11:10.011Z"; + + let tmp_dir = TempDir::new().unwrap(); + let file_path = tmp_dir.path().join("timestamps.csv"); + + // scope to ensure the file is closed and written + { + File::create(&file_path) + .expect("creating temp file") + .write_all(data.as_bytes()) + .expect("writing data"); + } + + let sql = format!( + "CREATE EXTERNAL TABLE csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP + ) + STORED AS CSV + LOCATION '{}' + ", + file_path.to_str().expect("path is utf8") + ); + + context.sql(sql.as_str()).await.unwrap(); + + let df = context.sql("show columns from csv_with_timestamps;").await; + + assert!(df.is_err()); + } + + #[tokio::test] + #[cfg(feature = "standalone")] + async fn test_show_tables_not_with_information_schema() { + use super::*; 
+ use ballista_core::config::{ + BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA, + }; + use std::fs::File; + use std::io::Write; + use tempfile::TempDir; + let config = BallistaConfigBuilder::default() + .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true") + .build() + .unwrap(); + let context = BallistaContext::standalone(&config, 1).await.unwrap(); + + let data = "Jorge,2018-12-13T12:12:10.011Z\n\ + Andrew,2018-11-13T17:11:10.011Z"; + + let tmp_dir = TempDir::new().unwrap(); + let file_path = tmp_dir.path().join("timestamps.csv"); + + // scope to ensure the file is closed and written + { + File::create(&file_path) + .expect("creating temp file") + .write_all(data.as_bytes()) + .expect("writing data"); + } + + let sql = format!( + "CREATE EXTERNAL TABLE csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP + ) + STORED AS CSV + LOCATION '{}' + ", + file_path.to_str().expect("path is utf8") + ); + + context.sql(sql.as_str()).await.unwrap(); + let df = context.sql("show tables;").await; + assert!(df.is_ok()); + } + + #[tokio::test] + #[cfg(feature = "standalone")] + async fn test_task_stuck_when_referenced_task_failed() { + use super::*; + use datafusion::arrow::datatypes::Schema; + use datafusion::arrow::util::pretty; + use datafusion::datasource::file_format::csv::CsvFormat; + use datafusion::datasource::file_format::parquet::ParquetFormat; + use datafusion::datasource::listing::{ListingOptions, ListingTable}; + + use ballista_core::config::{ + BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA, + }; + use std::fs::File; + use std::io::Write; + use tempfile::TempDir; + let config = BallistaConfigBuilder::default() + .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true") + .build() + .unwrap(); + let context = BallistaContext::standalone(&config, 1).await.unwrap(); + + let testdata = datafusion::test_util::parquet_test_data(); + context + .register_parquet("single_nan", &format!("{}/single_nan.parquet", testdata)) + .await + .unwrap(); + + { + let mut guard = context.state.lock().unwrap(); + let csv_table = guard.tables.get("single_nan"); + + if let Some(table_provide) = csv_table { + if let Some(listing_table) = table_provide + .clone() + .as_any() + .downcast_ref::() + { + let x = listing_table.options(); + let error_options = ListingOptions { + file_extension: x.file_extension.clone(), + format: Arc::new(CsvFormat::default()), + table_partition_cols: x.table_partition_cols.clone(), + collect_stat: x.collect_stat, + target_partitions: x.target_partitions, + }; + let error_table = ListingTable::new( + listing_table.object_store().clone(), + listing_table.table_path().to_string(), + Arc::new(Schema::new(vec![])), + error_options, + ); + // change the table to an error table + guard + .tables + .insert("single_nan".to_string(), Arc::new(error_table)); + } + } + } + + let df = context + .sql("select count(1) from single_nan;") + .await + .unwrap(); + let results = df.collect().await.unwrap(); + pretty::print_batches(&results); + } } diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 3415d13a3487..caa9ca84f12d 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -33,7 +33,7 @@ simd = ["datafusion/simd"] ahash = { version = "0.7", default-features = false } async-trait = "0.1.36" futures = "0.3" -hashbrown = "0.11" +hashbrown = "0.12" log = "0.4" prost = "0.9" serde = {version = "1", features = ["derive"]} @@ -42,9 +42,11 @@ tokio = "1.0" tonic = "0.6" uuid = { version = "0.8", features = ["v4"] } chrono = { version = "0.4", 
default-features = false } +clap = { version = "3", features = ["derive", "cargo"] } +parse_arg = "0.1.3" -arrow-format = { version = "0.3", features = ["flight-data", "flight-service"] } -arrow = { package = "arrow2", version="0.8", features = ["io_ipc", "io_flight"] } +arrow-format = { version = "0.4", features = ["flight-data", "flight-service"] } +arrow = { package = "arrow2", version="0.9", features = ["io_ipc", "io_flight"] } datafusion = { path = "../../../datafusion", version = "6.0.0" } diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 5a755cc9a2ac..ea0d15f9e8ef 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -171,8 +171,11 @@ enum AggregateFunction { ARRAY_AGG = 6; VARIANCE=7; VARIANCE_POP=8; - STDDEV=9; - STDDEV_POP=10; + COVARIANCE=9; + COVARIANCE_POP=10; + STDDEV=11; + STDDEV_POP=12; + CORRELATION=13; } message AggregateExprNode { @@ -623,7 +626,6 @@ message ScanLimit { message FileScanExecConf { repeated FileGroup file_groups = 1; Schema schema = 2; - uint32 batch_size = 3; repeated uint32 projection = 4; ScanLimit limit = 5; Statistics statistics = 6; @@ -836,6 +838,7 @@ message ExecutorMetadata { string id = 1; string host = 2; uint32 port = 3; + uint32 grpc_port = 4; } message ExecutorRegistration { @@ -846,12 +849,46 @@ message ExecutorRegistration { string host = 2; } uint32 port = 3; + uint32 grpc_port = 4; } message ExecutorHeartbeat { ExecutorMetadata meta = 1; // Unix epoch-based timestamp in seconds uint64 timestamp = 2; + ExecutorState state = 3; +} + +message ExecutorState { + repeated ExecutorMetric metrics = 1; +} + +message ExecutorMetric { + // TODO add more metrics + oneof metric { + uint64 available_memory = 1; + } +} + +message ExecutorSpecification { + repeated ExecutorResource resources = 1; +} + +message ExecutorResource { + // TODO add more resources + oneof resource { + uint32 task_slots = 1; + } +} + +message ExecutorData { + string executor_id = 1; + repeated ExecutorResourcePair resources = 2; +} + +message ExecutorResourcePair { + ExecutorResource total = 1; + ExecutorResource available = 2; } message RunningTask { @@ -904,6 +941,41 @@ message PollWorkResult { TaskDefinition task = 1; } +message RegisterExecutorParams { + ExecutorRegistration metadata = 1; + ExecutorSpecification specification = 2; +} + +message RegisterExecutorResult { + bool success = 1; +} + +message SendHeartBeatParams { + ExecutorRegistration metadata = 1; + ExecutorState state = 2; +} + +message SendHeartBeatResult { + // TODO this flag is borrowed from Spark's BlockManager + bool reregister = 1; +} + +message StopExecutorParams { +} + +message StopExecutorResult { +} + +message UpdateTaskStatusParams { + ExecutorRegistration metadata = 1; + // All tasks must be reported until they reach the failed or completed state + repeated TaskStatus task_status = 2; +} + +message UpdateTaskStatusResult { + bool success = 1; +} + message ExecuteQueryParams { oneof query { LogicalPlanNode logical_plan = 1; @@ -963,10 +1035,28 @@ message FilePartitionMetadata { repeated string filename = 1; } +message LaunchTaskParams { + // Allows launching a set of tasks on an executor at once + repeated TaskDefinition task = 1; +} + +message LaunchTaskResult { + bool success = 1; + // TODO handle the case where only part of the task set is scheduled successfully +} + service SchedulerGrpc { // Executors must poll the scheduler to send heartbeats and to receive tasks rpc PollWork (PollWorkParams) returns (PollWorkResult) {} + rpc
RegisterExecutor(RegisterExecutorParams) returns (RegisterExecutorResult) {} + + // Push-based task scheduler will only leverage this interface + // rather than the PollWork interface to report executor states + rpc SendHeartBeat (SendHeartBeatParams) returns (SendHeartBeatResult) {} + + rpc UpdateTaskStatus (UpdateTaskStatusParams) returns (UpdateTaskStatusResult) {} + rpc GetFileMetadata (GetFileMetadataParams) returns (GetFileMetadataResult) {} rpc ExecuteQuery (ExecuteQueryParams) returns (ExecuteQueryResult) {} @@ -974,6 +1064,12 @@ service SchedulerGrpc { rpc GetJobStatus (GetJobStatusParams) returns (GetJobStatusResult) {} } +service ExecutorGrpc { + rpc LaunchTask (LaunchTaskParams) returns (LaunchTaskResult) {} + + rpc StopExecutor (StopExecutorParams) returns (StopExecutorResult) {} +} + /////////////////////////////////////////////////////////////////////////////////////////////////// // Arrow Data Types /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1041,11 +1137,18 @@ message Struct{ repeated Field sub_field_types = 1; } +enum UnionMode{ + sparse = 0; + dense = 1; +} + message Union{ repeated Field union_types = 1; + UnionMode union_mode = 2; } + message ScalarListValue{ ScalarType datatype = 1; repeated ScalarValue values = 2; @@ -1076,9 +1179,21 @@ message ScalarValue{ ScalarType null_list_value = 18; PrimitiveScalarType null_value = 19; + Decimal128 decimal128_value = 20; + int64 date_64_value = 21; + int64 time_second_value = 22; + int64 time_millisecond_value = 23; + int32 interval_yearmonth_value = 24; + int64 interval_daytime_value = 25; } } +message Decimal128{ + bytes value = 1; + int64 p = 2; + int64 s = 3; +} + // Contains all valid datafusion scalar type except for // List enum PrimitiveScalarType{ @@ -1100,6 +1215,13 @@ enum PrimitiveScalarType{ TIME_MICROSECOND = 14; TIME_NANOSECOND = 15; NULL = 16; + + DECIMAL128 = 17; + DATE64 = 20; + TIME_SECOND = 21; + TIME_MILLISECOND = 22; + INTERVAL_YEARMONTH = 23; + INTERVAL_DAYTIME = 24; } message ScalarType{ diff --git a/ballista/rust/core/src/client.rs b/ballista/rust/core/src/client.rs index eaacda8badf2..6adaa8c0ac92 100644 --- a/ballista/rust/core/src/client.rs +++ b/ballista/rust/core/src/client.rs @@ -27,7 +27,6 @@ use std::{ }; use crate::error::{ballista_error, BallistaError, Result}; -use crate::memory_stream::MemoryStream; use crate::serde::protobuf::{self}; use crate::serde::scheduler::{ Action, ExecutePartition, ExecutePartitionResult, PartitionId, PartitionStats, @@ -39,10 +38,11 @@ use datafusion::arrow::{ array::{StructArray, Utf8Array}, datatypes::{Schema, SchemaRef}, error::{ArrowError, Result as ArrowResult}, - record_batch::RecordBatch, }; +use datafusion::field_util::SchemaExt; use datafusion::physical_plan::common::collect; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use datafusion::record_batch::RecordBatch; use datafusion::{logical_plan::LogicalPlan, physical_plan::RecordBatchStream}; use futures::{Stream, StreamExt}; use log::debug; @@ -164,7 +164,7 @@ impl Stream for FlightDataStream { self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let mut stream = self.stream.lock().unwrap(); + let mut stream = self.stream.lock().expect("mutex is bad"); stream.poll_next_unpin(cx).map(|x| match x { Some(flight_data_chunk_result) => { let converted_chunk = flight_data_chunk_result @@ -174,11 +174,12 @@ impl Stream for FlightDataStream { arrow::io::flight::deserialize_batch( &flight_data_chunk, - 
self.schema.clone(), + self.schema.fields(), &self.ipc_schema, &hm, ) - }); + }) + .map(|c| RecordBatch::new_with_chunk(&self.schema, c)); Some(converted_chunk) } None => None, diff --git a/ballista/rust/core/src/config.rs b/ballista/rust/core/src/config.rs index 5d7b3c5cacb0..601d63756623 100644 --- a/ballista/rust/core/src/config.rs +++ b/ballista/rust/core/src/config.rs @@ -18,7 +18,11 @@ //! Ballista configuration +use clap::ArgEnum; +use core::fmt; use std::collections::HashMap; +use std::result; +use std::string::ParseError; use crate::error::{BallistaError, Result}; @@ -26,6 +30,9 @@ use datafusion::arrow::datatypes::DataType; use log::warn; pub const BALLISTA_DEFAULT_SHUFFLE_PARTITIONS: &str = "ballista.shuffle.partitions"; +pub const BALLISTA_WITH_INFORMATION_SCHEMA: &str = "ballista.with_information_schema"; + +pub type ParseResult = result::Result; /// Configuration option meta-data #[derive(Debug, Clone)] @@ -103,9 +110,9 @@ impl BallistaConfig { for (name, entry) in &supported_entries { if let Some(v) = settings.get(name) { // validate that we can parse the user-supplied value - let _ = v.parse::().map_err(|e| BallistaError::General(format!("Failed to parse user-supplied value '{}' for configuration setting '{}': {:?}", name, v, e)))?; + let _ = Self::parse_value(v.as_str(), entry._data_type.clone()).map_err(|e| BallistaError::General(format!("Failed to parse user-supplied value '{}' for configuration setting '{}': {}", name, v, e)))?; } else if let Some(v) = entry.default_value.clone() { - let _ = v.parse::().map_err(|e| BallistaError::General(format!("Failed to parse default value '{}' for configuration setting '{}': {:?}", name, v, e)))?; + let _ = Self::parse_value(v.as_str(), entry._data_type.clone()).map_err(|e| BallistaError::General(format!("Failed to parse default value '{}' for configuration setting '{}': {}", name, v, e)))?; } else { return Err(BallistaError::General(format!( "No value specified for mandatory configuration setting '{}'", @@ -117,12 +124,35 @@ impl BallistaConfig { Ok(Self { settings }) } + pub fn parse_value(val: &str, data_type: DataType) -> ParseResult<()> { + match data_type { + DataType::UInt16 => { + val.to_string() + .parse::() + .map_err(|e| format!("{:?}", e))?; + } + DataType::Boolean => { + val.to_string() + .parse::() + .map_err(|e| format!("{:?}", e))?; + } + _ => { + return Err(format!("not support data type: {:?}", data_type)); + } + } + + Ok(()) + } + /// All available configuration options pub fn valid_entries() -> HashMap { let entries = vec![ ConfigEntry::new(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS.to_string(), "Sets the default number of partitions to create when repartitioning query stages".to_string(), DataType::UInt16, Some("2".to_string())), + ConfigEntry::new(BALLISTA_WITH_INFORMATION_SCHEMA.to_string(), + "Sets whether enable information_schema".to_string(), + DataType::Boolean,Some("false".to_string())), ]; entries .iter() @@ -138,6 +168,10 @@ impl BallistaConfig { self.get_usize_setting(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS) } + pub fn default_with_information_schema(&self) -> bool { + self.get_bool_setting(BALLISTA_WITH_INFORMATION_SCHEMA) + } + fn get_usize_setting(&self, key: &str) -> usize { if let Some(v) = self.settings.get(key) { // infallible because we validate all configs in the constructor @@ -149,6 +183,40 @@ impl BallistaConfig { v.parse().unwrap() } } + + fn get_bool_setting(&self, key: &str) -> bool { + if let Some(v) = self.settings.get(key) { + // infallible because we validate all configs in the constructor + 
v.parse::().unwrap() + } else { + let entries = Self::valid_entries(); + // infallible because we validate all configs in the constructor + let v = entries.get(key).unwrap().default_value.as_ref().unwrap(); + v.parse::().unwrap() + } + } +} + +// an enum used to configure the scheduler policy +// needs to be visible to code generated by configure_me +#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)] +pub enum TaskSchedulingPolicy { + PullStaged, + PushStaged, +} + +impl std::str::FromStr for TaskSchedulingPolicy { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + ArgEnum::from_str(s, true) + } +} + +impl parse_arg::ParseArgFromStr for TaskSchedulingPolicy { + fn describe_type(mut writer: W) -> fmt::Result { + write!(writer, "The scheduler policy for the scheduler") + } } #[cfg(test)] @@ -159,6 +227,7 @@ mod tests { fn default_config() -> Result<()> { let config = BallistaConfig::new()?; assert_eq!(2, config.default_shuffle_partitions()); + assert!(!config.default_with_information_schema()); Ok(()) } @@ -166,8 +235,10 @@ mod tests { fn custom_config() -> Result<()> { let config = BallistaConfig::builder() .set(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, "123") + .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true") .build()?; assert_eq!(123, config.default_shuffle_partitions()); + assert!(config.default_with_information_schema()); Ok(()) } @@ -178,6 +249,13 @@ mod tests { .build(); assert!(config.is_err()); assert_eq!("General(\"Failed to parse user-supplied value 'ballista.shuffle.partitions' for configuration setting 'true': ParseIntError { kind: InvalidDigit }\")", format!("{:?}", config.unwrap_err())); + + let config = BallistaConfig::builder() + .set(BALLISTA_WITH_INFORMATION_SCHEMA, "123") + .build(); + assert!(config.is_err()); + assert_eq!("General(\"Failed to parse user-supplied value 'ballista.with_information_schema' for configuration setting '123': ParseBoolError\")", format!("{:?}", config.unwrap_err())); + Ok(()) } } diff --git a/ballista/rust/core/src/error.rs b/ballista/rust/core/src/error.rs index b2c8d99ae9f9..e9ffcd8180eb 100644 --- a/ballista/rust/core/src/error.rs +++ b/ballista/rust/core/src/error.rs @@ -139,7 +139,7 @@ impl From for BallistaError { } impl Display for BallistaError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match self { BallistaError::NotImplemented(ref desc) => { write!(f, "Not implemented: {}", desc) diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs index bebc98f08cc4..619cc9bc925d 100644 --- a/ballista/rust/core/src/execution_plans/distributed_query.rs +++ b/ballista/rust/core/src/execution_plans/distributed_query.rs @@ -39,6 +39,7 @@ use datafusion::physical_plan::{ }; use async_trait::async_trait; +use datafusion::execution::runtime_env::RuntimeEnv; use futures::future; use futures::StreamExt; use log::{error, info}; @@ -99,7 +100,8 @@ impl ExecutionPlan for DistributedQueryExec { async fn execute( &self, partition: usize, - ) -> datafusion::error::Result { + _runtime: Arc, + ) -> Result { assert_eq!(0, partition); info!("Connecting to Ballista scheduler at {}", self.scheduler_url); diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs index 6cdd8cc7665a..496aa5d09065 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs +++ 
b/ballista/rust/core/src/execution_plans/shuffle_reader.rs @@ -20,20 +20,21 @@ use std::sync::Arc; use std::{any::Any, pin::Pin}; use crate::client::BallistaClient; -use crate::memory_stream::MemoryStream; use crate::serde::scheduler::{PartitionLocation, PartitionStats}; use crate::utils::WrappedStream; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::error::Result as ArrowResult; -use datafusion::arrow::record_batch::RecordBatch; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::metrics::{ ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; use datafusion::physical_plan::{ - DisplayFormatType, ExecutionPlan, Metric, Partitioning, Statistics, + DisplayFormatType, ExecutionPlan, Metric, Partitioning, SendableRecordBatchStream, + Statistics, }; +use datafusion::record_batch::RecordBatch; use datafusion::{ error::{DataFusionError, Result}, physical_plan::RecordBatchStream, @@ -100,7 +101,8 @@ impl ExecutionPlan for ShuffleReaderExec { async fn execute( &self, partition: usize, - ) -> Result>> { + _runtime: Arc, + ) -> Result { info!("ShuffleReaderExec::execute({})", partition); let fetch_time = diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs index 52386049b13b..2c4b2401b4f3 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs @@ -28,11 +28,11 @@ use std::time::Instant; use std::{any::Any, pin::Pin}; use crate::error::BallistaError; -use crate::memory_stream::MemoryStream; use crate::utils; use crate::serde::protobuf::ShuffleWritePartition; use crate::serde::scheduler::{PartitionLocation, PartitionStats}; +use arrow::chunk::Chunk; use arrow::io::ipc::write::WriteOptions; use async_trait::async_trait; use datafusion::arrow::array::*; @@ -41,17 +41,22 @@ use datafusion::arrow::compute::take; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::io::ipc::read::FileReader; use datafusion::arrow::io::ipc::write::FileWriter; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::field_util::SchemaExt; +use datafusion::physical_plan::common::IPCWriter; use datafusion::physical_plan::hash_utils::create_hashes; +use datafusion::physical_plan::memory::MemoryStream; use datafusion::physical_plan::metrics::{ self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::Partitioning::RoundRobinBatch; use datafusion::physical_plan::{ - DisplayFormatType, ExecutionPlan, Metric, Partitioning, RecordBatchStream, Statistics, + DisplayFormatType, ExecutionPlan, Metric, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, }; +use datafusion::record_batch::RecordBatch; use futures::StreamExt; use hashbrown::HashMap; use log::{debug, info}; @@ -140,10 +145,11 @@ impl ShuffleWriterExec { pub async fn execute_shuffle_write( &self, input_partition: usize, + runtime: Arc, ) -> Result> { let now = Instant::now(); - let mut stream = self.plan.execute(input_partition).await?; + let mut stream = self.plan.execute(input_partition, runtime).await?; let mut path = PathBuf::from(&self.work_dir); path.push(&self.job_id); @@ -198,7 +204,7 @@ impl ShuffleWriterExec { // we won't necessary produce output for every possible 
partition, so we // create writers on demand - let mut writers: Vec> = vec![]; + let mut writers: Vec> = vec![]; for _ in 0..num_output_partitions { writers.push(None); } @@ -265,11 +271,10 @@ impl ShuffleWriterExec { std::fs::create_dir_all(&path)?; path.push(format!("data-{}.arrow", input_partition)); - let path = path.to_str().unwrap(); - info!("Writing results to {}", path); + info!("Writing results to {:?}", path); let mut writer = - ShuffleWriter::new(path, stream.schema().as_ref())?; + IPCWriter::new(&path, stream.schema().as_ref())?; writer.write(&output_batch)?; writers[output_partition] = Some(writer); @@ -287,7 +292,7 @@ impl ShuffleWriterExec { Some(w) => { w.finish()?; info!( - "Finished writing shuffle partition {} at {}. Batches: {}. Rows: {}. Bytes: {}.", + "Finished writing shuffle partition {} at {:?}. Batches: {}. Rows: {}. Bytes: {}.", i, w.path(), w.num_batches, @@ -297,7 +302,7 @@ impl ShuffleWriterExec { part_locs.push(ShuffleWritePartition { partition_id: i as u64, - path: w.path().to_owned(), + path: w.path().to_string_lossy().to_string(), num_batches: w.num_batches, num_rows: w.num_rows, num_bytes: w.num_bytes, @@ -354,9 +359,10 @@ impl ExecutionPlan for ShuffleWriterExec { async fn execute( &self, - input_partition: usize, - ) -> Result>> { - let part_loc = self.execute_shuffle_write(input_partition).await?; + partition: usize, + runtime: Arc, + ) -> Result { + let part_loc = self.execute_shuffle_write(partition, runtime).await?; // build metadata result batch let num_writers = part_loc.len(); @@ -434,61 +440,6 @@ fn result_schema() -> SchemaRef { ])) } -struct ShuffleWriter { - path: String, - writer: FileWriter>, - num_batches: u64, - num_rows: u64, - num_bytes: u64, -} - -impl ShuffleWriter { - fn new(path: &str, schema: &Schema) -> Result { - let file = File::create(path) - .map_err(|e| { - BallistaError::General(format!( - "Failed to create partition file at {}: {:?}", - path, e - )) - }) - .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; - let buffer_writer = std::io::BufWriter::new(file); - Ok(Self { - num_batches: 0, - num_rows: 0, - num_bytes: 0, - path: path.to_owned(), - writer: FileWriter::try_new( - buffer_writer, - schema, - None, - WriteOptions::default(), - )?, - }) - } - - fn write(&mut self, batch: &RecordBatch) -> Result<()> { - self.writer.write(batch, None)?; - self.num_batches += 1; - self.num_rows += batch.num_rows() as u64; - let num_bytes: usize = batch - .columns() - .iter() - .map(|array| estimated_bytes_size(array.as_ref())) - .sum(); - self.num_bytes += num_bytes as u64; - Ok(()) - } - - fn finish(&mut self) -> Result<()> { - self.writer.finish().map_err(DataFusionError::ArrowError) - } - - fn path(&self) -> &str { - &self.path - } -} - #[cfg(test)] mod tests { use super::*; @@ -502,6 +453,8 @@ mod tests { #[tokio::test] async fn test() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let input_plan = Arc::new(CoalescePartitionsExec::new(create_input_plan()?)); let work_dir = TempDir::new()?; let query_stage = ShuffleWriterExec::try_new( @@ -511,7 +464,7 @@ mod tests { work_dir.into_path().to_str().unwrap().to_owned(), Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)), )?; - let mut stream = query_stage.execute(0).await?; + let mut stream = query_stage.execute(0, runtime).await?; let batches = utils::collect_stream(&mut stream) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; @@ -554,6 +507,8 @@ mod tests { #[tokio::test] async fn test_partitioned() -> Result<()> { + 
let runtime = Arc::new(RuntimeEnv::default()); + let input_plan = create_input_plan()?; let work_dir = TempDir::new()?; let query_stage = ShuffleWriterExec::try_new( @@ -563,7 +518,7 @@ mod tests { work_dir.into_path().to_str().unwrap().to_owned(), Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)), )?; - let mut stream = query_stage.execute(0).await?; + let mut stream = query_stage.execute(0, runtime).await?; let batches = utils::collect_stream(&mut stream) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; diff --git a/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs b/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs index 6290add4e2b4..e14c1ebf0e65 100644 --- a/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs +++ b/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs @@ -18,13 +18,13 @@ use std::sync::Arc; use std::{any::Any, pin::Pin}; -use crate::memory_stream::MemoryStream; use crate::serde::scheduler::PartitionLocation; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::{ - DisplayFormatType, ExecutionPlan, Partitioning, Statistics, + DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; use datafusion::{ error::{DataFusionError, Result}, @@ -102,7 +102,8 @@ impl ExecutionPlan for UnresolvedShuffleExec { async fn execute( &self, _partition: usize, - ) -> Result>> { + _runtime: Arc, + ) -> Result { Err(DataFusionError::Plan( "Ballista UnresolvedShuffleExec does not support execution".to_owned(), )) diff --git a/ballista/rust/core/src/lib.rs b/ballista/rust/core/src/lib.rs index 4e51067ec976..bc7be4f88651 100644 --- a/ballista/rust/core/src/lib.rs +++ b/ballista/rust/core/src/lib.rs @@ -27,7 +27,6 @@ pub mod client; pub mod config; pub mod error; pub mod execution_plans; -pub mod memory_stream; pub mod utils; #[macro_use] diff --git a/ballista/rust/core/src/memory_stream.rs b/ballista/rust/core/src/memory_stream.rs index ab72bdc82aee..8b137891791f 100644 --- a/ballista/rust/core/src/memory_stream.rs +++ b/ballista/rust/core/src/memory_stream.rs @@ -1,93 +1 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -//! This is copied from DataFusion because it is declared as `pub(crate)`. See -//! https://issues.apache.org/jira/browse/ARROW-11276. 
- -use std::task::{Context, Poll}; - -use datafusion::arrow::{datatypes::SchemaRef, error::Result, record_batch::RecordBatch}; -use datafusion::physical_plan::RecordBatchStream; -use futures::Stream; - -/// Iterator over batches - -pub struct MemoryStream { - /// Vector of record batches - data: Vec, - /// Schema representing the data - schema: SchemaRef, - /// Optional projection for which columns to load - projection: Option>, - /// Index into the data - index: usize, -} - -impl MemoryStream { - /// Create an iterator for a vector of record batches - - pub fn try_new( - data: Vec, - schema: SchemaRef, - projection: Option>, - ) -> Result { - Ok(Self { - data, - schema, - projection, - index: 0, - }) - } -} - -impl Stream for MemoryStream { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(if self.index < self.data.len() { - self.index += 1; - - let batch = &self.data[self.index - 1]; - - // apply projection - match &self.projection { - Some(columns) => Some(RecordBatch::try_new( - self.schema.clone(), - columns.iter().map(|i| batch.column(*i).clone()).collect(), - )), - None => Some(Ok(batch.clone())), - } - } else { - None - }) - } - - fn size_hint(&self) -> (usize, Option) { - (self.data.len(), Some(self.data.len())) - } -} - -impl RecordBatchStream for MemoryStream { - /// Get the schema - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index f429e175664f..32ed6f1c1a4f 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -18,7 +18,9 @@ //! Serde code to convert from protocol buffers to Rust data structures. 
use crate::error::BallistaError; -use crate::serde::{from_proto_binary_op, proto_error, protobuf, str_to_byte}; +use crate::serde::{ + from_proto_binary_op, proto_error, protobuf, str_to_byte, vec_to_array, +}; use crate::{convert_box_required, convert_required}; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::datasource::file_format::avro::AvroFormat; @@ -246,8 +248,8 @@ impl TryInto for &protobuf::LogicalPlanNode { .collect::, _>>()?, partition_count as usize, ), - PartitionMethod::RoundRobin(batch_size) => { - Partitioning::RoundRobinBatch(batch_size as usize) + PartitionMethod::RoundRobin(partition_count) => { + Partitioning::RoundRobinBatch(partition_count as usize) } }; @@ -540,12 +542,51 @@ fn typechecked_scalar_value_conversion( "Untyped scalar null is not a valid scalar value", )) } + PrimitiveScalarType::Decimal128 => { + ScalarValue::Decimal128(None, 0, 0) + } + PrimitiveScalarType::Date64 => ScalarValue::Date64(None), + PrimitiveScalarType::TimeSecond => { + ScalarValue::TimestampSecond(None, None) + } + PrimitiveScalarType::TimeMillisecond => { + ScalarValue::TimestampMillisecond(None, None) + } + PrimitiveScalarType::IntervalYearmonth => { + ScalarValue::IntervalYearMonth(None) + } + PrimitiveScalarType::IntervalDaytime => { + ScalarValue::IntervalDayTime(None) + } }; scalar_value } else { return Err(proto_error("Could not convert to the proper type")); } } + (Value::Decimal128Value(val), PrimitiveScalarType::Decimal128) => { + let array = vec_to_array(val.value.clone()); + ScalarValue::Decimal128( + Some(i128::from_be_bytes(array)), + val.p as usize, + val.s as usize, + ) + } + (Value::Date64Value(v), PrimitiveScalarType::Date64) => { + ScalarValue::Date64(Some(*v)) + } + (Value::TimeSecondValue(v), PrimitiveScalarType::TimeSecond) => { + ScalarValue::TimestampSecond(Some(*v), None) + } + (Value::TimeMillisecondValue(v), PrimitiveScalarType::TimeMillisecond) => { + ScalarValue::TimestampMillisecond(Some(*v), None) + } + (Value::IntervalYearmonthValue(v), PrimitiveScalarType::IntervalYearmonth) => { + ScalarValue::IntervalYearMonth(Some(*v)) + } + (Value::IntervalDaytimeValue(v), PrimitiveScalarType::IntervalDaytime) => { + ScalarValue::IntervalDayTime(Some(days_ms::new(1i32, *v as i32))) + } _ => return Err(proto_error("Could not convert to the proper type")), }) } @@ -607,6 +648,29 @@ impl TryInto for &protobuf::scalar_value::Value .ok_or_else(|| proto_error("Invalid scalar type"))? .try_into()? 
} + protobuf::scalar_value::Value::Decimal128Value(val) => { + let array = vec_to_array(val.value.clone()); + ScalarValue::Decimal128( + Some(i128::from_be_bytes(array)), + val.p as usize, + val.s as usize, + ) + } + protobuf::scalar_value::Value::Date64Value(v) => { + ScalarValue::Date64(Some(*v)) + } + protobuf::scalar_value::Value::TimeSecondValue(v) => { + ScalarValue::TimestampSecond(Some(*v), None) + } + protobuf::scalar_value::Value::TimeMillisecondValue(v) => { + ScalarValue::TimestampMillisecond(Some(*v), None) + } + protobuf::scalar_value::Value::IntervalYearmonthValue(v) => { + ScalarValue::IntervalYearMonth(Some(*v)) + } + protobuf::scalar_value::Value::IntervalDaytimeValue(v) => { + ScalarValue::IntervalDayTime(Some(days_ms::new(1, *v as i32))) + } }; Ok(scalar) } @@ -763,6 +827,22 @@ impl TryInto for protobuf::PrimitiveScalarType protobuf::PrimitiveScalarType::TimeNanosecond => { ScalarValue::TimestampNanosecond(None, None) } + protobuf::PrimitiveScalarType::Decimal128 => { + ScalarValue::Decimal128(None, 0, 0) + } + protobuf::PrimitiveScalarType::Date64 => ScalarValue::Date64(None), + protobuf::PrimitiveScalarType::TimeSecond => { + ScalarValue::TimestampSecond(None, None) + } + protobuf::PrimitiveScalarType::TimeMillisecond => { + ScalarValue::TimestampMillisecond(None, None) + } + protobuf::PrimitiveScalarType::IntervalYearmonth => { + ScalarValue::IntervalYearMonth(None) + } + protobuf::PrimitiveScalarType::IntervalDaytime => { + ScalarValue::IntervalDayTime(None) + } }) } } @@ -846,6 +926,29 @@ impl TryInto for &protobuf::ScalarValue { .ok_or_else(|| proto_error("Protobuf deserialization error found invalid enum variant for DatafusionScalar"))?; null_type_enum.try_into()? } + protobuf::scalar_value::Value::Decimal128Value(val) => { + let array = vec_to_array(val.value.clone()); + ScalarValue::Decimal128( + Some(i128::from_be_bytes(array)), + val.p as usize, + val.s as usize, + ) + } + protobuf::scalar_value::Value::Date64Value(v) => { + ScalarValue::Date64(Some(*v)) + } + protobuf::scalar_value::Value::TimeSecondValue(v) => { + ScalarValue::TimestampSecond(Some(*v), None) + } + protobuf::scalar_value::Value::TimeMillisecondValue(v) => { + ScalarValue::TimestampMillisecond(Some(*v), None) + } + protobuf::scalar_value::Value::IntervalYearmonthValue(v) => { + ScalarValue::IntervalYearMonth(Some(*v)) + } + protobuf::scalar_value::Value::IntervalDaytimeValue(v) => { + ScalarValue::IntervalDayTime(Some(days_ms::new(1, *v as i32))) + } }) } } @@ -1169,6 +1272,8 @@ impl TryInto for &protobuf::Field { } use crate::serde::protobuf::ColumnStats; +use arrow::types::days_ms; +use datafusion::field_util::SchemaExt; use datafusion::physical_plan::{aggregates, windows}; use datafusion::prelude::{ array, date_part, date_trunc, length, lower, ltrim, md5, rtrim, sha224, sha256, diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 50ab4c7b7c91..74cf7091faf9 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -23,8 +23,10 @@ mod roundtrip_tests { use super::super::{super::error::Result, protobuf}; use crate::error::BallistaError; - use arrow::datatypes::UnionMode; + use arrow::datatypes::IntegerType; use core::panic; + use datafusion::arrow::datatypes::UnionMode; + use datafusion::field_util::SchemaExt; use datafusion::logical_plan::Repartition; use datafusion::{ arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}, @@ -65,7 +67,7 @@ mod 
roundtrip_tests { async fn roundtrip_repartition() -> Result<()> { use datafusion::logical_plan::Partitioning; - let test_batch_sizes = [usize::MIN, usize::MAX, 43256]; + let test_partition_counts = [usize::MIN, usize::MAX, 43256]; let test_expr: Vec = vec![col("c1") + col("c2"), Expr::Literal((4.0).into())]; @@ -92,8 +94,8 @@ mod roundtrip_tests { .map_err(BallistaError::DataFusionError)?, ); - for batch_size in test_batch_sizes.iter() { - let rr_repartition = Partitioning::RoundRobinBatch(*batch_size); + for partition_count in test_partition_counts.iter() { + let rr_repartition = Partitioning::RoundRobinBatch(*partition_count); let roundtrip_plan = LogicalPlan::Repartition(Repartition { input: plan.clone(), @@ -102,7 +104,7 @@ mod roundtrip_tests { roundtrip_test!(roundtrip_plan); - let h_repartition = Partitioning::Hash(test_expr.clone(), *batch_size); + let h_repartition = Partitioning::Hash(test_expr.clone(), *partition_count); let roundtrip_plan = LogicalPlan::Repartition(Repartition { input: plan.clone(), @@ -111,7 +113,7 @@ mod roundtrip_tests { roundtrip_test!(roundtrip_plan); - let no_expr_hrepartition = Partitioning::Hash(Vec::new(), *batch_size); + let no_expr_hrepartition = Partitioning::Hash(Vec::new(), *partition_count); let roundtrip_plan = LogicalPlan::Repartition(Repartition { input: plan.clone(), @@ -438,7 +440,24 @@ mod roundtrip_tests { ), ], None, - UnionMode::Dense, + UnionMode::Sparse, + ), + DataType::Dictionary( + IntegerType::UInt8, + Box::new(DataType::Struct(vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ])), + false, + ), + DataType::Dictionary( + IntegerType::UInt64, + Box::new(DataType::FixedSizeList( + new_box_field("Level1", DataType::Binary, true), + 4, + )), + false, ), ]; @@ -557,7 +576,7 @@ mod roundtrip_tests { Field::new("datatype", DataType::Binary, false), ], None, - UnionMode::Dense, + UnionMode::Sparse, ), DataType::Union( vec![ @@ -577,6 +596,23 @@ mod roundtrip_tests { None, UnionMode::Dense, ), + DataType::Dictionary( + IntegerType::UInt8, + Box::new(DataType::Struct(vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ])), + false, + ), + DataType::Dictionary( + IntegerType::UInt64, + Box::new(DataType::FixedSizeList( + new_box_field("Level1", DataType::Binary, true), + 4, + )), + false, + ), ]; for test_case in test_cases.into_iter() { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 573cf86e607d..304d2db1cd83 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -22,9 +22,9 @@ use super::super::proto_error; use crate::serde::protobuf::integer_type::IntegerTypeEnum; use crate::serde::{byte_to_string, protobuf, BallistaError}; -use arrow::datatypes::{IntegerType, UnionMode}; +use arrow::datatypes::IntegerType; use datafusion::arrow::datatypes::{ - DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, + DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode, }; use datafusion::datasource::file_format::avro::AvroFormat; use datafusion::datasource::file_format::csv::CsvFormat; @@ -32,6 +32,7 @@ use datafusion::datasource::TableProvider; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingTable; +use 
datafusion::field_util::{FieldExt, SchemaExt}; use datafusion::logical_plan::plan::{ Aggregate, EmptyRelation, Filter, Join, Projection, Sort, Window, }; @@ -290,12 +291,19 @@ impl From<&DataType> for protobuf::arrow_type::ArrowTypeEnum { .map(|field| field.into()) .collect::>(), }), - DataType::Union(union_types, _, _) => ArrowTypeEnum::Union(protobuf::Union { - union_types: union_types - .iter() - .map(|field| field.into()) - .collect::>(), - }), + DataType::Union(union_types, _, union_mode) => { + let union_mode = match union_mode { + UnionMode::Sparse => protobuf::UnionMode::Sparse, + UnionMode::Dense => protobuf::UnionMode::Dense, + }; + ArrowTypeEnum::Union(protobuf::Union { + union_types: union_types + .iter() + .map(|field| field.into()) + .collect::>(), + union_mode: union_mode.into(), + }) + } DataType::Dictionary(key_type, value_type, _) => { ArrowTypeEnum::Dictionary(Box::new(protobuf::Dictionary { key: Some(key_type.into()), @@ -444,6 +452,7 @@ impl TryFrom<&DataType> for protobuf::scalar_type::Datatype { | DataType::Struct(_) | DataType::Union(_, _, _) | DataType::Dictionary(_, _, _) + | DataType::Map(_, _) | DataType::Decimal(_, _) => { return Err(proto_error(format!( "Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.", @@ -452,8 +461,6 @@ impl TryFrom<&DataType> for protobuf::scalar_type::Datatype { } DataType::Extension(_, _, _) => panic!("DataType::Extension is not supported"), - DataType::Map(_, _) => - panic!("DataType::Map is not supported"), }; Ok(scalar_value) } @@ -612,6 +619,51 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue { Value::TimeNanosecondValue(*s) }) } + datafusion::scalar::ScalarValue::Decimal128(val, p, s) => { + match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + protobuf::ScalarValue { + value: Some(Value::Decimal128Value(protobuf::Decimal128 { + value: vec_val, + p: *p as i64, + s: *s as i64, + })), + } + } + None => { + protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue(PrimitiveScalarType::Decimal128 as i32)) + } + } + } + } + datafusion::scalar::ScalarValue::Date64(val) => { + create_proto_scalar(val, PrimitiveScalarType::Date64, |s| { + Value::Date64Value(*s) + }) + } + datafusion::scalar::ScalarValue::TimestampSecond(val, _) => { + create_proto_scalar(val, PrimitiveScalarType::TimeSecond, |s| { + Value::TimeSecondValue(*s) + }) + } + datafusion::scalar::ScalarValue::TimestampMillisecond(val, _) => { + create_proto_scalar(val, PrimitiveScalarType::TimeMillisecond, |s| { + Value::TimeMillisecondValue(*s) + }) + } + datafusion::scalar::ScalarValue::IntervalYearMonth(val) => { + create_proto_scalar(val, PrimitiveScalarType::IntervalYearmonth, |s| { + Value::IntervalYearmonthValue(*s) + }) + } + datafusion::scalar::ScalarValue::IntervalDayTime(val) => { + create_proto_scalar(val, PrimitiveScalarType::IntervalDaytime, |s| { + Value::IntervalDaytimeValue(s.milliseconds() as i64) + }) + } _ => { return Err(proto_error(format!( "Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.", @@ -875,8 +927,8 @@ impl TryInto for &LogicalPlan { partition_count: *partition_count as u64, }) } - Partitioning::RoundRobinBatch(batch_size) => { - PartitionMethod::RoundRobin(*batch_size as u64) + Partitioning::RoundRobinBatch(partition_count) => { + PartitionMethod::RoundRobin(*partition_count as u64) } }; @@ -1088,10 +1140,19 @@ impl TryInto for &Expr { AggregateFunction::VariancePop => { 
protobuf::AggregateFunction::VariancePop } + AggregateFunction::Covariance => { + protobuf::AggregateFunction::Covariance + } + AggregateFunction::CovariancePop => { + protobuf::AggregateFunction::CovariancePop + } AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, AggregateFunction::StddevPop => { protobuf::AggregateFunction::StddevPop } + AggregateFunction::Correlation => { + protobuf::AggregateFunction::Correlation + } }; let arg = &args[0]; @@ -1324,8 +1385,11 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Variance => Self::Variance, AggregateFunction::VariancePop => Self::VariancePop, + AggregateFunction::Covariance => Self::Covariance, + AggregateFunction::CovariancePop => Self::CovariancePop, AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, + AggregateFunction::Correlation => Self::Correlation, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 9ff2a6cedb17..b2f3db2a6d52 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -20,13 +20,14 @@ use std::{convert::TryInto, io::Cursor}; +use datafusion::arrow::datatypes::{IntervalUnit, UnionMode}; use datafusion::logical_plan::{JoinConstraint, JoinType, Operator}; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::window_functions::BuiltInWindowFunction; use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; -use arrow::datatypes::{IntegerType, UnionMode}; +use arrow::datatypes::IntegerType; use prost::Message; // include the generated protobuf source as a submodule @@ -122,8 +123,13 @@ impl From for AggregateFunction { protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg, protobuf::AggregateFunction::Variance => AggregateFunction::Variance, protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop, + protobuf::AggregateFunction::Covariance => AggregateFunction::Covariance, + protobuf::AggregateFunction::CovariancePop => { + AggregateFunction::CovariancePop + } protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev, protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop, + protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation, } } } @@ -250,15 +256,24 @@ impl TryInto .map(|field| field.try_into()) .collect::, _>>()?, ), - arrow_type::ArrowTypeEnum::Union(union) => DataType::Union( - union + arrow_type::ArrowTypeEnum::Union(union) => { + let union_mode = protobuf::UnionMode::from_i32(union.union_mode) + .ok_or_else(|| { + proto_error( + "Protobuf deserialization error: Unknown union mode type", + ) + })?; + let union_mode = match union_mode { + protobuf::UnionMode::Dense => UnionMode::Dense, + protobuf::UnionMode::Sparse => UnionMode::Sparse, + }; + let union_types = union .union_types .iter() .map(|field| field.try_into()) - .collect::, _>>()?, - None, - UnionMode::Dense, - ), + .collect::, _>>()?; + DataType::Union(union_types, None, union_mode) + } arrow_type::ArrowTypeEnum::Dictionary(dict) => { let pb_key_datatype = dict .as_ref() @@ -304,6 +319,20 @@ impl Into for protobuf::PrimitiveScalarT DataType::Time64(TimeUnit::Nanosecond) } protobuf::PrimitiveScalarType::Null => DataType::Null, + protobuf::PrimitiveScalarType::Decimal128 => DataType::Decimal(0, 0), + protobuf::PrimitiveScalarType::Date64 => DataType::Date64, + protobuf::PrimitiveScalarType::TimeSecond => { 
+ DataType::Timestamp(TimeUnit::Second, None) + } + protobuf::PrimitiveScalarType::TimeMillisecond => { + DataType::Timestamp(TimeUnit::Millisecond, None) + } + protobuf::PrimitiveScalarType::IntervalYearmonth => { + DataType::Interval(IntervalUnit::YearMonth) + } + protobuf::PrimitiveScalarType::IntervalDaytime => { + DataType::Interval(IntervalUnit::DayTime) + } } } } @@ -365,3 +394,9 @@ fn str_to_byte(s: &str) -> Result { } Ok(s.as_bytes()[0]) } + +fn vec_to_array(v: Vec) -> [T; N] { + v.try_into().unwrap_or_else(|v: Vec| { + panic!("Expected a Vec of length {} but it was {}", N, v.len()) + }) +} diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 4f4f72eca74b..1986d8114a87 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -17,6 +17,7 @@ //! Serde code to convert from protocol buffers to Rust data structures. +use arrow::compute::cast::CastOptions; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::sync::Arc; @@ -41,18 +42,20 @@ use datafusion::datasource::PartitionedFile; use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::{ window_frames::WindowFrame, DFSchema, Expr, JoinConstraint, JoinType, }; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction}; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::file_format::{ - AvroExec, CsvExec, ParquetExec, PhysicalPlanConfig, + AvroExec, CsvExec, FileScanConfig, ParquetExec, }; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use datafusion::physical_plan::hash_join::PartitionMode; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::planner::DefaultPhysicalPlanner; +use datafusion::physical_plan::sorts::sort::{SortExec, SortOptions}; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; @@ -71,7 +74,6 @@ use datafusion::physical_plan::{ limit::{GlobalLimitExec, LocalLimitExec}, projection::ProjectionExec, repartition::RepartitionExec, - sort::{SortExec, SortOptions}, Partitioning, }; use datafusion::physical_plan::{ @@ -593,6 +595,11 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { ExprType::Cast(e) => Arc::new(CastExpr::new( convert_box_required!(e.expr)?, convert_required!(e.arrow_type)?, + // TODO: shouldn't this be added to proto ? 
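The `vec_to_array` helper added to `serde/mod.rs` above converts a `Vec` into a fixed-size array and panics with a descriptive message on a length mismatch. A compilable sketch of a helper in that style is below; the const-generic signature is my reading of the intent and may not match the patch exactly:

```rust
use std::convert::TryInto;

// Convert a Vec into a fixed-size array, panicking with a descriptive message
// if the length does not match. Relies on `TryFrom<Vec<T>> for [T; N]`.
fn vec_to_array<T, const N: usize>(v: Vec<T>) -> [T; N] {
    v.try_into().unwrap_or_else(|v: Vec<T>| {
        panic!("Expected a Vec of length {} but it was {}", N, v.len())
    })
}

fn main() {
    let [x, y, z]: [u8; 3] = vec_to_array(vec![1, 2, 3]);
    assert_eq!((x, y, z), (1, 2, 3));
}
```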
+ CastOptions { + wrapped: false, + partial: false, + }, )), ExprType::TryCast(e) => Arc::new(TryCastExpr::new( convert_box_required!(e.expr)?, @@ -624,6 +631,7 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { config: ExecutionConfig::new(), execution_props: ExecutionProps::new(), object_store_registry: Arc::new(ObjectStoreRegistry::new()), + runtime_env: Arc::new(RuntimeEnv::default()), }; let fun_expr = functions::create_physical_fun( @@ -766,10 +774,10 @@ impl TryInto for &protobuf::Statistics { } } -impl TryInto for &protobuf::FileScanExecConf { +impl TryInto for &protobuf::FileScanExecConf { type Error = BallistaError; - fn try_into(self) -> Result { + fn try_into(self) -> Result { let schema = Arc::new(convert_required!(self.schema)?); let projection = self .projection @@ -783,7 +791,7 @@ impl TryInto for &protobuf::FileScanExecConf { }; let statistics = convert_required!(self.statistics)?; - Ok(PhysicalPlanConfig { + Ok(FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema: schema, file_groups: self @@ -793,7 +801,6 @@ impl TryInto for &protobuf::FileScanExecConf { .collect::, _>>()?, statistics, projection, - batch_size: self.batch_size as usize, limit: self.limit.as_ref().map(|sl| sl.limit as usize), table_partition_cols: vec![], }) diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index 23826605b797..70354a15c3e3 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -22,6 +22,8 @@ pub mod to_proto; mod roundtrip_tests { use std::{convert::TryInto, sync::Arc}; + use datafusion::field_util::SchemaExt; + use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::{ arrow::{ compute::sort::SortOptions, @@ -36,7 +38,6 @@ mod roundtrip_tests { hash_aggregate::{AggregateMode, HashAggregateExec}, hash_join::{HashJoinExec, PartitionMode}, limit::{GlobalLimitExec, LocalLimitExec}, - sort::SortExec, AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, }, diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 41484db57a7b..9c6d6d6a4bce 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -26,10 +26,11 @@ use std::{ sync::Arc, }; +use datafusion::field_util::FieldExt; use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; -use datafusion::physical_plan::sort::SortExec; +use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{cross_join::CrossJoinExec, ColumnStatistics}; use datafusion::physical_plan::{ expressions::{ @@ -43,7 +44,7 @@ use datafusion::physical_plan::{ }; use datafusion::physical_plan::{file_format::AvroExec, filter::FilterExec}; use datafusion::physical_plan::{ - file_format::PhysicalPlanConfig, hash_aggregate::AggregateMode, + file_format::FileScanConfig, hash_aggregate::AggregateMode, }; use datafusion::{ datasource::PartitionedFile, physical_plan::coalesce_batches::CoalesceBatchesExec, @@ -197,7 +198,7 @@ impl TryInto for Arc { .aggr_expr() .iter() .map(|expr| match expr.field() { - Ok(field) => Ok(field.name().clone()), + Ok(field) => Ok(field.name().to_string()), Err(e) => Err(BallistaError::DataFusionError(e)), }) .collect::>()?; @@ -677,10 
+678,10 @@ impl From<&Statistics> for protobuf::Statistics { } } -impl TryFrom<&PhysicalPlanConfig> for protobuf::FileScanExecConf { +impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { type Error = BallistaError; fn try_from( - conf: &PhysicalPlanConfig, + conf: &FileScanConfig, ) -> Result { let file_groups = conf .file_groups @@ -700,7 +701,6 @@ impl TryFrom<&PhysicalPlanConfig> for protobuf::FileScanExecConf { .map(|n| *n as u32) .collect(), schema: Some(conf.file_schema.as_ref().into()), - batch_size: conf.batch_size as u32, table_partition_cols: conf.table_partition_cols.to_vec(), }) } diff --git a/ballista/rust/core/src/serde/scheduler/mod.rs b/ballista/rust/core/src/serde/scheduler/mod.rs index d76f432aaad1..5a9d6dd8af40 100644 --- a/ballista/rust/core/src/serde/scheduler/mod.rs +++ b/ballista/rust/core/src/serde/scheduler/mod.rs @@ -75,6 +75,7 @@ pub struct ExecutorMeta { pub id: String, pub host: String, pub port: u16, + pub grpc_port: u16, } #[allow(clippy::from_over_into)] @@ -84,6 +85,7 @@ impl Into for ExecutorMeta { id: self.id, host: self.host, port: self.port as u32, + grpc_port: self.grpc_port as u32, } } } @@ -94,10 +96,149 @@ impl From for ExecutorMeta { id: meta.id, host: meta.host, port: meta.port as u16, + grpc_port: meta.grpc_port as u16, } } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] +pub struct ExecutorSpecification { + pub task_slots: u32, +} + +#[allow(clippy::from_over_into)] +impl Into for ExecutorSpecification { + fn into(self) -> protobuf::ExecutorSpecification { + protobuf::ExecutorSpecification { + resources: vec![protobuf::executor_resource::Resource::TaskSlots( + self.task_slots, + )] + .into_iter() + .map(|r| protobuf::ExecutorResource { resource: Some(r) }) + .collect(), + } + } +} + +impl From for ExecutorSpecification { + fn from(input: protobuf::ExecutorSpecification) -> Self { + let mut ret = Self { task_slots: 0 }; + for resource in input.resources { + if let Some(protobuf::executor_resource::Resource::TaskSlots(task_slots)) = + resource.resource + { + ret.task_slots = task_slots + } + } + ret + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct ExecutorData { + pub executor_id: String, + pub total_task_slots: u32, + pub available_task_slots: u32, +} + +struct ExecutorResourcePair { + total: protobuf::executor_resource::Resource, + available: protobuf::executor_resource::Resource, +} + +#[allow(clippy::from_over_into)] +impl Into for ExecutorData { + fn into(self) -> protobuf::ExecutorData { + protobuf::ExecutorData { + executor_id: self.executor_id, + resources: vec![ExecutorResourcePair { + total: protobuf::executor_resource::Resource::TaskSlots( + self.total_task_slots, + ), + available: protobuf::executor_resource::Resource::TaskSlots( + self.available_task_slots, + ), + }] + .into_iter() + .map(|r| protobuf::ExecutorResourcePair { + total: Some(protobuf::ExecutorResource { + resource: Some(r.total), + }), + available: Some(protobuf::ExecutorResource { + resource: Some(r.available), + }), + }) + .collect(), + } + } +} + +impl From for ExecutorData { + fn from(input: protobuf::ExecutorData) -> Self { + let mut ret = Self { + executor_id: input.executor_id, + total_task_slots: 0, + available_task_slots: 0, + }; + for resource in input.resources { + if let Some(task_slots) = resource.total { + if let Some(protobuf::executor_resource::Resource::TaskSlots( + task_slots, + )) = task_slots.resource + { + ret.total_task_slots = task_slots + } + }; + if let Some(task_slots) = resource.available { + if let 
Some(protobuf::executor_resource::Resource::TaskSlots( + task_slots, + )) = task_slots.resource + { + ret.available_task_slots = task_slots + } + }; + } + ret + } +} + +#[derive(Debug, Clone, Copy, Serialize)] +pub struct ExecutorState { + // in bytes + pub available_memory_size: u64, +} + +#[allow(clippy::from_over_into)] +impl Into for ExecutorState { + fn into(self) -> protobuf::ExecutorState { + protobuf::ExecutorState { + metrics: vec![protobuf::executor_metric::Metric::AvailableMemory( + self.available_memory_size, + )] + .into_iter() + .map(|m| protobuf::ExecutorMetric { metric: Some(m) }) + .collect(), + } + } +} + +impl From for ExecutorState { + fn from(input: protobuf::ExecutorState) -> Self { + let mut ret = Self { + available_memory_size: u64::MAX, + }; + for metric in input.metrics { + if let Some(protobuf::executor_metric::Metric::AvailableMemory( + available_memory_size, + )) = metric.metric + { + ret.available_memory_size = available_memory_size + } + } + ret + } +} + /// Summary of executed partition #[derive(Debug, Copy, Clone, Default)] pub struct PartitionStats { diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index f1d46556cfde..c3550c32fcc2 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -26,10 +26,10 @@ use crate::error::{BallistaError, Result}; use crate::execution_plans::{ DistributedQueryExec, ShuffleWriterExec, UnresolvedShuffleExec, }; -use crate::memory_stream::MemoryStream; use crate::serde::scheduler::PartitionStats; use crate::config::BallistaConfig; +use arrow::chunk::Chunk; use async_trait::async_trait; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::datatypes::SchemaRef; @@ -41,18 +41,19 @@ use datafusion::arrow::{ datatypes::{DataType, Field}, io::ipc::read::FileReader, io::ipc::write::FileWriter, - record_batch::RecordBatch, }; use datafusion::error::DataFusionError; use datafusion::execution::context::{ ExecutionConfig, ExecutionContext, ExecutionContextState, QueryPlanner, }; -use datafusion::logical_plan::{LogicalPlan, Operator}; +use datafusion::field_util::SchemaExt; +use datafusion::logical_plan::{LogicalPlan, Operator, TableScan}; use datafusion::physical_optimizer::coalesce_batches::CoalesceBatches; use datafusion::physical_optimizer::merge_exec::AddCoalescePartitionsExec; use datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::common::batch_byte_size; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{BinaryExpr, Column, Literal}; use datafusion::physical_plan::file_format::{CsvExec, ParquetExec}; @@ -60,10 +61,11 @@ use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::HashAggregateExec; use datafusion::physical_plan::hash_join::HashJoinExec; use datafusion::physical_plan::projection::ProjectionExec; -use datafusion::physical_plan::sort::SortExec; +use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{ metrics, AggregateExpr, ExecutionPlan, Metric, PhysicalExpr, RecordBatchStream, }; +use datafusion::record_batch::RecordBatch; use futures::{future, Stream, StreamExt}; use std::time::Instant; @@ -94,17 +96,14 @@ pub async fn write_stream_to_disk( while let Some(result) = stream.next().await { let batch = result?; - let batch_size_bytes: usize = batch - 
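The `ExecutorSpecification`/`ExecutorData` conversions above ship executor capacity as a repeated, typed resource list (currently only task slots) and fold it back into flat counters on the receiving side, so new resource kinds can be added later without changing the message shape. A simplified sketch of that encode/decode round trip, using plain Rust enums in place of the prost-generated oneof types:

```rust
// Stand-in for the proto `executor_resource::Resource` oneof.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Resource {
    TaskSlots(u32),
}

#[derive(Debug, Default, PartialEq)]
struct ExecutorSpec {
    task_slots: u32,
}

fn encode(spec: &ExecutorSpec) -> Vec<Resource> {
    vec![Resource::TaskSlots(spec.task_slots)]
}

fn decode(resources: &[Resource]) -> ExecutorSpec {
    let mut spec = ExecutorSpec::default();
    for resource in resources {
        match resource {
            // Fold each known resource kind into the flat struct.
            Resource::TaskSlots(slots) => spec.task_slots = *slots,
        }
    }
    spec
}

fn main() {
    let spec = ExecutorSpec { task_slots: 4 };
    assert_eq!(decode(&encode(&spec)), spec);
}
```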
.columns() - .iter() - .map(|array| estimated_bytes_size(array.as_ref())) - .sum(); + let batch_size_bytes: usize = batch_byte_size(&batch); num_batches += 1; num_rows += batch.num_rows(); num_bytes += batch_size_bytes; let timer = disk_write_metric.timer(); - writer.write(&batch, None)?; + let chunk = Chunk::new(batch.columns().to_vec()); + writer.write(&chunk, None)?; timer.done(); } let timer = disk_write_metric.timer(); @@ -257,7 +256,8 @@ pub fn create_df_ctx_with_ballista_query_planner( scheduler_url, config.clone(), ))) - .with_target_partitions(config.default_shuffle_partitions()); + .with_target_partitions(config.default_shuffle_partitions()) + .with_information_schema(true); ExecutionContext::with_config(config) } diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index 8943c2a60927..a30f1a25d02f 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -29,8 +29,8 @@ edition = "2018" snmalloc = ["snmalloc-rs"] [dependencies] -arrow-format = { version = "0.3", features = ["flight-data", "flight-service"] } -arrow = { package = "arrow2", version="0.8", features = ["io_ipc"] } +arrow-format = { version = "0.4", features = ["flight-data", "flight-service"] } +arrow = { package = "arrow2", version="0.9", features = ["io_ipc"] } anyhow = "1" async-trait = "0.1.36" ballista-core = { path = "../core", version = "0.6.0" } @@ -45,6 +45,7 @@ tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } tokio-stream = { version = "0.1", features = ["net"] } tonic = "0.6" uuid = { version = "0.8", features = ["v4"] } +hyper = "0.14.4" [dev-dependencies] diff --git a/ballista/rust/executor/executor_config_spec.toml b/ballista/rust/executor/executor_config_spec.toml index 6f170c85e823..1dd3de99012c 100644 --- a/ballista/rust/executor/executor_config_spec.toml +++ b/ballista/rust/executor/executor_config_spec.toml @@ -54,6 +54,12 @@ type = "u16" default = "50051" doc = "bind port" +[[param]] +name = "bind_grpc_port" +type = "u16" +default = "50052" +doc = "bind grpc service port" + [[param]] name = "work_dir" type = "String" @@ -65,3 +71,10 @@ name = "concurrent_tasks" type = "usize" default = "4" doc = "Max concurrent tasks." + +[[param]] +abbr = "s" +name = "task_scheduling_policy" +type = "ballista_core::config::TaskSchedulingPolicy" +doc = "The task scheduing policy for the scheduler, see TaskSchedulingPolicy::variants() for options. 
Default: PullStaged" +default = "ballista_core::config::TaskSchedulingPolicy::PullStaged" diff --git a/ballista/rust/executor/src/collect.rs b/ballista/rust/executor/src/collect.rs index c3fadaed6645..34bbf2721a23 100644 --- a/ballista/rust/executor/src/collect.rs +++ b/ballista/rust/executor/src/collect.rs @@ -23,13 +23,13 @@ use std::task::{Context, Poll}; use std::{any::Any, pin::Pin}; use async_trait::async_trait; -use datafusion::arrow::{ - datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch, -}; +use datafusion::arrow::{datatypes::SchemaRef, error::Result as ArrowResult}; use datafusion::error::DataFusionError; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; +use datafusion::record_batch::RecordBatch; use datafusion::{error::Result, physical_plan::RecordBatchStream}; use futures::stream::SelectAll; use futures::Stream; @@ -75,11 +75,12 @@ impl ExecutionPlan for CollectExec { async fn execute( &self, partition: usize, - ) -> Result>> { + runtime: Arc, + ) -> Result { assert_eq!(0, partition); let num_partitions = self.plan.output_partitioning().partition_count(); - let futures = (0..num_partitions).map(|i| self.plan.execute(i)); + let futures = (0..num_partitions).map(|i| self.plan.execute(i, runtime.clone())); let streams = futures::future::join_all(futures) .await .into_iter() diff --git a/ballista/rust/executor/src/execution_loop.rs b/ballista/rust/executor/src/execution_loop.rs index 4d12dfc1c755..69bc8380c2c2 100644 --- a/ballista/rust/executor/src/execution_loop.rs +++ b/ballista/rust/executor/src/execution_loop.rs @@ -26,12 +26,11 @@ use tonic::transport::Channel; use ballista_core::serde::protobuf::ExecutorRegistration; use ballista_core::serde::protobuf::{ - self, scheduler_grpc_client::SchedulerGrpcClient, task_status, FailedTask, - PartitionId, PollWorkParams, PollWorkResult, ShuffleWritePartition, TaskDefinition, - TaskStatus, + scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult, + TaskDefinition, TaskStatus, }; -use protobuf::CompletedTask; +use crate::as_task_status; use crate::executor::Executor; use ballista_core::error::BallistaError; use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning; @@ -144,37 +143,6 @@ async fn run_received_tasks( Ok(()) } -fn as_task_status( - execution_result: ballista_core::error::Result>, - executor_id: String, - task_id: PartitionId, -) -> TaskStatus { - match execution_result { - Ok(partitions) => { - info!("Task {:?} finished", task_id); - - TaskStatus { - partition_id: Some(task_id), - status: Some(task_status::Status::Completed(CompletedTask { - executor_id, - partitions, - })), - } - } - Err(e) => { - let error_msg = e.to_string(); - info!("Task {:?} failed: {}", task_id, error_msg); - - TaskStatus { - partition_id: Some(task_id), - status: Some(task_status::Status::Failed(FailedTask { - error: format!("Task failed due to Tokio error: {}", error_msg), - })), - } - } - } -} - async fn sample_tasks_status( task_status_receiver: &mut Receiver, ) -> Vec { diff --git a/ballista/rust/executor/src/executor.rs b/ballista/rust/executor/src/executor.rs index d073d60f7209..6bf1aeb4f182 100644 --- a/ballista/rust/executor/src/executor.rs +++ b/ballista/rust/executor/src/executor.rs @@ -22,21 +22,38 @@ use std::sync::Arc; use ballista_core::error::BallistaError; use ballista_core::execution_plans::ShuffleWriterExec; use 
ballista_core::serde::protobuf; +use ballista_core::serde::scheduler::ExecutorSpecification; use datafusion::error::DataFusionError; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{ExecutionPlan, Partitioning}; +use datafusion::prelude::ExecutionConfig; /// Ballista executor pub struct Executor { /// Directory for storing partial results work_dir: String, + + /// Specification like total task slots + pub specification: ExecutorSpecification, } impl Executor { /// Create a new executor instance pub fn new(work_dir: &str) -> Self { + Executor::new_with_specification( + work_dir, + ExecutorSpecification { task_slots: 4 }, + ) + } + + pub fn new_with_specification( + work_dir: &str, + specification: ExecutorSpecification, + ) -> Self { Self { work_dir: work_dir.to_owned(), + specification, } } } @@ -71,7 +88,10 @@ impl Executor { )) }?; - let partitions = exec.execute_shuffle_write(part).await?; + let config = ExecutionConfig::new().with_temp_file_path(self.work_dir.clone()); + let runtime = Arc::new(RuntimeEnv::new(config.runtime)?); + + let partitions = exec.execute_shuffle_write(part, runtime).await?; println!( "=== [{}/{}/{}] Physical plan with metrics ===\n{}\n", diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs new file mode 100644 index 000000000000..3f220ea83a39 --- /dev/null +++ b/ballista/rust/executor/src/executor_server.rs @@ -0,0 +1,291 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
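The new `executor_server.rs` that follows wires four pieces together at startup: an `ExecutorGrpc` service, registration with the scheduler, a heartbeat loop, and a task-runner pool fed by an mpsc channel. A trimmed, tokio-only sketch of that shape (no tonic; `FakeTask` and `run_task` are placeholders rather than Ballista types, and it assumes tokio with the `macros`, `rt`, `sync`, and `time` features):

```rust
use std::time::Duration;
use tokio::sync::mpsc;

#[derive(Debug)]
struct FakeTask(u64);

async fn run_task(task: FakeTask) {
    println!("running {:?}", task);
}

#[tokio::main]
async fn main() {
    // Bounded channel between the gRPC LaunchTask handler and the runner loop.
    let (tx_task, mut rx_task) = mpsc::channel::<FakeTask>(1000);

    // Heartbeater: periodically report executor state back to the scheduler.
    tokio::spawn(async {
        loop {
            println!("heartbeat");
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    });

    // Task runner pool: pull launched tasks off the channel, run each concurrently.
    let runner = tokio::spawn(async move {
        while let Some(task) = rx_task.recv().await {
            tokio::spawn(run_task(task));
        }
    });

    // Stand-in for the scheduler pushing tasks over gRPC.
    for i in 0..3 {
        tx_task.send(FakeTask(i)).await.unwrap();
    }
    drop(tx_task); // closing the channel lets the runner loop exit
    runner.await.unwrap();
}
```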
+ +use std::convert::TryInto; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::sync::mpsc; + +use log::{debug, info}; +use tonic::transport::{Channel, Server}; +use tonic::{Request, Response, Status}; + +use ballista_core::error::BallistaError; +use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning; +use ballista_core::serde::protobuf::executor_grpc_server::{ + ExecutorGrpc, ExecutorGrpcServer, +}; +use ballista_core::serde::protobuf::executor_registration::OptionalHost; +use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient; +use ballista_core::serde::protobuf::{ + ExecutorRegistration, LaunchTaskParams, LaunchTaskResult, RegisterExecutorParams, + SendHeartBeatParams, StopExecutorParams, StopExecutorResult, TaskDefinition, + UpdateTaskStatusParams, +}; +use ballista_core::serde::scheduler::{ExecutorSpecification, ExecutorState}; +use datafusion::physical_plan::ExecutionPlan; + +use crate::as_task_status; +use crate::executor::Executor; + +pub async fn startup( + mut scheduler: SchedulerGrpcClient, + executor: Arc, + executor_meta: ExecutorRegistration, +) { + // TODO make the buffer size configurable + let (tx_task, rx_task) = mpsc::channel::(1000); + + let executor_server = ExecutorServer::new( + scheduler.clone(), + executor.clone(), + executor_meta.clone(), + ExecutorEnv { tx_task }, + ); + + // 1. Start executor grpc service + { + let executor_meta = executor_meta.clone(); + let addr = format!( + "{}:{}", + executor_meta + .optional_host + .map(|h| match h { + OptionalHost::Host(host) => host, + }) + .unwrap_or_else(|| String::from("127.0.0.1")), + executor_meta.grpc_port + ); + let addr = addr.parse().unwrap(); + info!("Setup executor grpc service for {:?}", addr); + + let server = ExecutorGrpcServer::new(executor_server.clone()); + let grpc_server_future = Server::builder().add_service(server).serve(addr); + tokio::spawn(async move { grpc_server_future.await }); + } + + let executor_server = Arc::new(executor_server); + + // 2. Do executor registration + match register_executor(&mut scheduler, &executor_meta, &executor.specification).await + { + Ok(_) => { + info!("Executor registration succeed"); + } + Err(error) => { + panic!("Executor registration failed due to: {}", error); + } + }; + + // 3. Start Heartbeater + { + let heartbeater = Heartbeater::new(executor_server.clone()); + heartbeater.start().await; + } + + // 4. 
Start TaskRunnerPool + { + let task_runner_pool = TaskRunnerPool::new(executor_server.clone()); + task_runner_pool.start(rx_task).await; + } +} + +#[allow(clippy::clone_on_copy)] +async fn register_executor( + scheduler: &mut SchedulerGrpcClient, + executor_meta: &ExecutorRegistration, + specification: &ExecutorSpecification, +) -> Result<(), BallistaError> { + let result = scheduler + .register_executor(RegisterExecutorParams { + metadata: Some(executor_meta.clone()), + specification: Some(specification.clone().into()), + }) + .await?; + if result.into_inner().success { + Ok(()) + } else { + Err(BallistaError::General( + "Executor registration failed!!!".to_owned(), + )) + } +} + +#[derive(Clone)] +pub struct ExecutorServer { + _start_time: u128, + executor: Arc, + executor_meta: ExecutorRegistration, + scheduler: SchedulerGrpcClient, + executor_env: ExecutorEnv, +} + +#[derive(Clone)] +struct ExecutorEnv { + tx_task: mpsc::Sender, +} + +unsafe impl Sync for ExecutorEnv {} + +impl ExecutorServer { + fn new( + scheduler: SchedulerGrpcClient, + executor: Arc, + executor_meta: ExecutorRegistration, + executor_env: ExecutorEnv, + ) -> Self { + Self { + _start_time: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis(), + executor, + executor_meta, + scheduler, + executor_env, + } + } + + async fn heartbeat(&self) { + // TODO Error handling + self.scheduler + .clone() + .send_heart_beat(SendHeartBeatParams { + metadata: Some(self.executor_meta.clone()), + state: Some(self.get_executor_state().await.into()), + }) + .await + .unwrap(); + } + + async fn run_task(&self, task: TaskDefinition) -> Result<(), BallistaError> { + let task_id = task.task_id.unwrap(); + let task_id_log = format!( + "{}/{}/{}", + task_id.job_id, task_id.stage_id, task_id.partition_id + ); + info!("Start to run task {}", task_id_log); + + let plan: Arc = (&task.plan.unwrap()).try_into().unwrap(); + let shuffle_output_partitioning = + parse_protobuf_hash_partitioning(task.output_partitioning.as_ref())?; + + let execution_result = self + .executor + .execute_shuffle_write( + task_id.job_id.clone(), + task_id.stage_id as usize, + task_id.partition_id as usize, + plan, + shuffle_output_partitioning, + ) + .await; + info!("Done with task {}", task_id_log); + debug!("Statistics: {:?}", execution_result); + + // TODO use another channel to update the status of a task set + self.scheduler + .clone() + .update_task_status(UpdateTaskStatusParams { + metadata: Some(self.executor_meta.clone()), + task_status: vec![as_task_status( + execution_result, + self.executor_meta.id.clone(), + task_id, + )], + }) + .await?; + + Ok(()) + } + + // TODO with real state + async fn get_executor_state(&self) -> ExecutorState { + ExecutorState { + available_memory_size: u64::MAX, + } + } +} + +struct Heartbeater { + executor_server: Arc, +} + +impl Heartbeater { + fn new(executor_server: Arc) -> Self { + Self { executor_server } + } + + async fn start(&self) { + let executor_server = self.executor_server.clone(); + tokio::spawn(async move { + info!("Starting heartbeater to send heartbeat the scheduler periodically"); + loop { + executor_server.heartbeat().await; + tokio::time::sleep(Duration::from_millis(60000)).await; + } + }); + } +} + +struct TaskRunnerPool { + executor_server: Arc, +} + +impl TaskRunnerPool { + fn new(executor_server: Arc) -> Self { + Self { executor_server } + } + + async fn start(&self, mut rx_task: mpsc::Receiver) { + let executor_server = self.executor_server.clone(); + tokio::spawn(async move { + 
info!("Starting the task runner pool"); + loop { + let task = rx_task.recv().await.unwrap(); + info!("Received task {:?}", task); + + let server = executor_server.clone(); + tokio::spawn(async move { + server.run_task(task).await.unwrap(); + }); + } + }); + } +} + +#[tonic::async_trait] +impl ExecutorGrpc for ExecutorServer { + async fn launch_task( + &self, + request: Request, + ) -> Result, Status> { + let tasks = request.into_inner().task; + let task_sender = self.executor_env.tx_task.clone(); + for task in tasks { + task_sender.send(task).await.unwrap(); + } + Ok(Response::new(LaunchTaskResult { success: true })) + } + + async fn stop_executor( + &self, + _request: Request, + ) -> Result, Status> { + todo!() + } +} diff --git a/ballista/rust/executor/src/flight_service.rs b/ballista/rust/executor/src/flight_service.rs index 79666332a7f4..a936768006e7 100644 --- a/ballista/rust/executor/src/flight_service.rs +++ b/ballista/rust/executor/src/flight_service.rs @@ -17,6 +17,8 @@ //! Implementation of the Apache Arrow Flight protocol that wraps an executor. +use arrow::array::ArrayRef; +use arrow::chunk::Chunk; use std::fs::File; use std::pin::Pin; use std::sync::Arc; @@ -34,7 +36,6 @@ use arrow_format::flight::data::{ use arrow_format::flight::service::flight_service_server::FlightService; use datafusion::arrow::{ error::ArrowError, io::ipc::read::FileReader, io::ipc::write::WriteOptions, - record_batch::RecordBatch, }; use futures::{Stream, StreamExt}; use log::{info, warn}; @@ -175,11 +176,11 @@ impl FlightService for BallistaFlightService { /// Convert a single RecordBatch into an iterator of FlightData (containing /// dictionaries and batches) fn create_flight_iter( - batch: &RecordBatch, + chunk: &Chunk, options: &WriteOptions, ) -> Box>> { let (flight_dictionaries, flight_batch) = - arrow::io::flight::serialize_batch(batch, &[], options); + arrow::io::flight::serialize_batch(chunk, &[], options); Box::new( flight_dictionaries .into_iter() @@ -201,14 +202,13 @@ async fn stream_flight_data(path: String, tx: FlightDataSender) -> Result<(), St let reader = FileReader::new(&mut file, file_meta, None); let options = WriteOptions::default(); - let schema_flight_data = - arrow::io::flight::serialize_schema(reader.schema().as_ref(), &[]); + let schema_flight_data = arrow::io::flight::serialize_schema(reader.schema(), None); send_response(&tx, Ok(schema_flight_data)).await?; let mut row_count = 0; for batch in reader { if let Ok(x) = &batch { - row_count += x.num_rows(); + row_count += x.len(); } let batch_flight_data: Vec<_> = batch .map(|b| create_flight_iter(&b, &options).collect()) diff --git a/ballista/rust/executor/src/lib.rs b/ballista/rust/executor/src/lib.rs index 714698b357f6..a2711da08cc4 100644 --- a/ballista/rust/executor/src/lib.rs +++ b/ballista/rust/executor/src/lib.rs @@ -20,7 +20,46 @@ pub mod collect; pub mod execution_loop; pub mod executor; +pub mod executor_server; pub mod flight_service; mod standalone; pub use standalone::new_standalone_executor; + +use log::info; + +use ballista_core::serde::protobuf::{ + task_status, CompletedTask, FailedTask, PartitionId, ShuffleWritePartition, + TaskStatus, +}; + +pub fn as_task_status( + execution_result: ballista_core::error::Result>, + executor_id: String, + task_id: PartitionId, +) -> TaskStatus { + match execution_result { + Ok(partitions) => { + info!("Task {:?} finished", task_id); + + TaskStatus { + partition_id: Some(task_id), + status: Some(task_status::Status::Completed(CompletedTask { + executor_id, + partitions, + })), 
+ } + } + Err(e) => { + let error_msg = e.to_string(); + info!("Task {:?} failed: {}", task_id, error_msg); + + TaskStatus { + partition_id: Some(task_id), + status: Some(task_status::Status::Failed(FailedTask { + error: format!("Task failed due to Tokio error: {}", error_msg), + })), + } + } + } +} diff --git a/ballista/rust/executor/src/main.rs b/ballista/rust/executor/src/main.rs index af1659a307d0..877b5eb3ce30 100644 --- a/ballista/rust/executor/src/main.rs +++ b/ballista/rust/executor/src/main.rs @@ -21,16 +21,18 @@ use std::sync::Arc; use anyhow::{Context, Result}; use arrow_format::flight::service::flight_service_server::FlightServiceServer; -use ballista_executor::execution_loop; +use ballista_executor::{execution_loop, executor_server}; use log::info; use tempfile::TempDir; use tonic::transport::Server; use uuid::Uuid; +use ballista_core::config::TaskSchedulingPolicy; use ballista_core::serde::protobuf::{ executor_registration, scheduler_grpc_client::SchedulerGrpcClient, ExecutorRegistration, }; +use ballista_core::serde::scheduler::ExecutorSpecification; use ballista_core::{print_version, BALLISTA_VERSION}; use ballista_executor::executor::Executor; use ballista_executor::flight_service::BallistaFlightService; @@ -67,6 +69,7 @@ async fn main() -> Result<()> { let external_host = opt.external_host; let bind_host = opt.bind_host; let port = opt.bind_port; + let grpc_port = opt.bind_grpc_port; let addr = format!("{}:{}", bind_host, port); let addr = addr @@ -94,32 +97,54 @@ async fn main() -> Result<()> { .clone() .map(executor_registration::OptionalHost::Host), port: port as u32, + grpc_port: grpc_port as u32, }; + let executor_specification = ExecutorSpecification { + task_slots: opt.concurrent_tasks as u32, + }; + let executor = Arc::new(Executor::new_with_specification( + &work_dir, + executor_specification, + )); let scheduler = SchedulerGrpcClient::connect(scheduler_url) .await .context("Could not connect to scheduler")?; - let executor = Arc::new(Executor::new(&work_dir)); - - let service = BallistaFlightService::new(executor.clone()); + let scheduler_policy = opt.task_scheduling_policy; + match scheduler_policy { + TaskSchedulingPolicy::PushStaged => { + tokio::spawn(executor_server::startup( + scheduler, + executor.clone(), + executor_meta, + )); + } + _ => { + tokio::spawn(execution_loop::poll_loop( + scheduler, + executor.clone(), + executor_meta, + opt.concurrent_tasks, + )); + } + } - let server = FlightServiceServer::new(service); - info!( - "Ballista v{} Rust Executor listening on {:?}", - BALLISTA_VERSION, addr - ); - let server_future = tokio::spawn(Server::builder().add_service(server).serve(addr)); - tokio::spawn(execution_loop::poll_loop( - scheduler, - executor, - executor_meta, - opt.concurrent_tasks, - )); + // Arrow flight service + { + let service = BallistaFlightService::new(executor.clone()); + let server = FlightServiceServer::new(service); + info!( + "Ballista v{} Rust Executor listening on {:?}", + BALLISTA_VERSION, addr + ); + let server_future = + tokio::spawn(Server::builder().add_service(server).serve(addr)); + server_future + .await + .context("Tokio error")? + .context("Could not start executor server")?; + } - server_future - .await - .context("Tokio error")? 
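In the executor's `main` above, the new `task_scheduling_policy` setting selects between the push-based gRPC server and the existing pull-based poll loop. A minimal sketch of that dispatch, with placeholder async functions standing in for `executor_server::startup` and `execution_loop::poll_loop`:

```rust
#[derive(Debug, Clone, Copy)]
enum TaskSchedulingPolicy {
    PullStaged,
    PushStaged,
}

// Placeholder for executor_server::startup: serve ExecutorGrpc and wait for LaunchTask calls.
async fn start_push_based(_concurrent_tasks: usize) {
    println!("push-based: waiting for the scheduler to launch tasks");
}

// Placeholder for execution_loop::poll_loop: poll the scheduler for work.
async fn start_pull_based(concurrent_tasks: usize) {
    println!("pull-based: polling with {} task slots", concurrent_tasks);
}

#[tokio::main]
async fn main() {
    let policy = TaskSchedulingPolicy::PushStaged; // would come from config/CLI
    let concurrent_tasks = 4;
    let handle = match policy {
        TaskSchedulingPolicy::PushStaged => tokio::spawn(start_push_based(concurrent_tasks)),
        _ => tokio::spawn(start_pull_based(concurrent_tasks)),
    };
    handle.await.unwrap();
}
```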
- .context("Could not start executor server")?; Ok(()) } diff --git a/ballista/rust/executor/src/standalone.rs b/ballista/rust/executor/src/standalone.rs index 89f98082e9f7..da74ca94fcdb 100644 --- a/ballista/rust/executor/src/standalone.rs +++ b/ballista/rust/executor/src/standalone.rs @@ -62,6 +62,8 @@ pub async fn new_standalone_executor( id: Uuid::new_v4().to_string(), // assign this executor a unique ID optional_host: Some(OptionalHost::Host("localhost".to_string())), port: addr.port() as u32, + // TODO Make it configurable + grpc_port: 50020, }; tokio::spawn(execution_loop::poll_loop( scheduler, diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml index 0bacccf031d8..10b3723712da 100644 --- a/ballista/rust/scheduler/Cargo.toml +++ b/ballista/rust/scheduler/Cargo.toml @@ -33,11 +33,11 @@ sled = ["sled_package", "tokio-stream"] [dependencies] anyhow = "1" ballista-core = { path = "../core", version = "0.6.0" } -clap = "2" +clap = { version = "3", features = ["derive", "cargo"] } configure_me = "0.4.0" datafusion = { path = "../../../datafusion", version = "6.0.0" } env_logger = "0.9" -etcd-client = { version = "0.7", optional = true } +etcd-client = { version = "0.8", optional = true } futures = "0.3" http = "0.2" http-body = "0.4" diff --git a/ballista/rust/scheduler/scheduler_config_spec.toml b/ballista/rust/scheduler/scheduler_config_spec.toml index 81e77d31b0a0..cf03fc08a72a 100644 --- a/ballista/rust/scheduler/scheduler_config_spec.toml +++ b/ballista/rust/scheduler/scheduler_config_spec.toml @@ -57,4 +57,11 @@ abbr = "p" name = "bind_port" type = "u16" default = "50050" -doc = "bind port. Default: 50050" \ No newline at end of file +doc = "bind port. Default: 50050" + +[[param]] +abbr = "s" +name = "scheduler_policy" +type = "ballista_core::config::TaskSchedulingPolicy" +doc = "The scheduing policy for the scheduler, see TaskSchedulingPolicy::variants() for options. Default: PullStaged" +default = "ballista_core::config::TaskSchedulingPolicy::PullStaged" \ No newline at end of file diff --git a/ballista/rust/scheduler/src/lib.rs b/ballista/rust/scheduler/src/lib.rs index 107ea28ff68b..cc508d39c18e 100644 --- a/ballista/rust/scheduler/src/lib.rs +++ b/ballista/rust/scheduler/src/lib.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#![doc = include_str!("../README.md")] +#![doc = include_str ! 
("../README.md")] pub mod api; pub mod planner; @@ -41,31 +41,43 @@ pub mod externalscaler { include!(concat!(env!("OUT_DIR"), "/externalscaler.rs")); } +use std::collections::{HashMap, HashSet}; +use std::fmt; use std::{convert::TryInto, sync::Arc}; -use std::{fmt, net::IpAddr}; use ballista_core::serde::protobuf::{ execute_query_params::Query, executor_registration::OptionalHost, job_status, scheduler_grpc_server::SchedulerGrpc, task_status, ExecuteQueryParams, ExecuteQueryResult, FailedJob, FileType, GetFileMetadataParams, GetFileMetadataResult, GetJobStatusParams, GetJobStatusResult, JobStatus, - PartitionId, PollWorkParams, PollWorkResult, QueuedJob, RunningJob, TaskDefinition, - TaskStatus, + LaunchTaskParams, PartitionId, PollWorkParams, PollWorkResult, QueuedJob, + RegisterExecutorParams, RegisterExecutorResult, RunningJob, SendHeartBeatParams, + SendHeartBeatResult, TaskDefinition, TaskStatus, UpdateTaskStatusParams, + UpdateTaskStatusResult, +}; +use ballista_core::serde::scheduler::{ + ExecutorData, ExecutorMeta, ExecutorSpecification, }; -use ballista_core::serde::scheduler::ExecutorMeta; -use clap::arg_enum; +use clap::ArgEnum; use datafusion::physical_plan::ExecutionPlan; + #[cfg(feature = "sled")] extern crate sled_package as sled; // an enum used to configure the backend // needs to be visible to code generated by configure_me -arg_enum! { - #[derive(Debug, serde::Deserialize)] - pub enum ConfigBackend { - Etcd, - Standalone +#[derive(Debug, Clone, ArgEnum, serde::Deserialize)] +pub enum ConfigBackend { + Etcd, + Standalone, +} + +impl std::str::FromStr for ConfigBackend { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + ArgEnum::from_str(s, true) } } @@ -81,29 +93,51 @@ use crate::externalscaler::{ }; use crate::planner::DistributedPlanner; -use log::{debug, error, info, warn}; +use log::{debug, error, info, trace, warn}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use tonic::{Request, Response, Status}; use self::state::{ConfigBackendClient, SchedulerState}; -use ballista_core::config::BallistaConfig; +use anyhow::Context; +use ballista_core::config::{BallistaConfig, TaskSchedulingPolicy}; +use ballista_core::error::BallistaError; use ballista_core::execution_plans::ShuffleWriterExec; +use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient; use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto; use datafusion::prelude::{ExecutionConfig, ExecutionContext}; -use std::time::{Instant, SystemTime, UNIX_EPOCH}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +use tokio::sync::{mpsc, RwLock}; +use tonic::transport::Channel; #[derive(Clone)] pub struct SchedulerServer { - caller_ip: IpAddr, pub(crate) state: Arc, start_time: u128, + policy: TaskSchedulingPolicy, + scheduler_env: Option, + executors_client: Arc>>>, +} + +#[derive(Clone)] +pub struct SchedulerEnv { + pub tx_job: mpsc::Sender, } impl SchedulerServer { - pub fn new( + pub fn new(config: Arc, namespace: String) -> Self { + SchedulerServer::new_with_policy( + config, + namespace, + TaskSchedulingPolicy::PullStaged, + None, + ) + } + + pub fn new_with_policy( config: Arc, namespace: String, - caller_ip: IpAddr, + policy: TaskSchedulingPolicy, + scheduler_env: Option, ) -> Self { let state = Arc::new(SchedulerState::new(config, namespace)); let state_clone = state.clone(); @@ -112,13 +146,178 @@ impl SchedulerServer { tokio::spawn(async move { state_clone.synchronize_job_status_loop().await }); Self { - caller_ip, state, 
start_time: SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_millis(), + policy, + scheduler_env, + executors_client: Arc::new(RwLock::new(HashMap::new())), + } + } + + async fn schedule_job(&self, job_id: String) -> Result<(), BallistaError> { + let alive_executors = self + .state + .get_alive_executors_metadata_within_one_minute() + .await?; + let alive_executors: HashMap = alive_executors + .into_iter() + .map(|e| (e.id.clone(), e)) + .collect(); + let available_executors = self.state.get_available_executors_data().await?; + let mut available_executors: Vec = available_executors + .into_iter() + .filter(|e| alive_executors.contains_key(&e.executor_id)) + .collect(); + + // In case of there's no enough resources, reschedule the tasks of the job + if available_executors.is_empty() { + let tx_job = self.scheduler_env.as_ref().unwrap().tx_job.clone(); + // TODO Maybe it's better to use an exclusive runtime for this kind task scheduling + tokio::spawn(async move { + warn!("Not enough available executors for task running"); + tokio::time::sleep(Duration::from_millis(100)).await; + tx_job.send(job_id).await.unwrap(); + }); + return Ok(()); + } + + let (tasks_assigment, num_tasks) = + self.fetch_tasks(&mut available_executors, &job_id).await?; + if num_tasks > 0 { + for (idx_executor, tasks) in tasks_assigment.into_iter().enumerate() { + if !tasks.is_empty() { + let executor_data = &available_executors[idx_executor]; + debug!( + "Start to launch tasks {:?} to executor {:?}", + tasks, executor_data.executor_id + ); + let mut client = { + let clients = self.executors_client.read().await; + info!("Size of executor clients: {:?}", clients.len()); + clients.get(&executor_data.executor_id).unwrap().clone() + }; + // Update the resources first + self.state.save_executor_data(executor_data.clone()).await?; + // TODO check whether launching task is successful or not + client.launch_task(LaunchTaskParams { task: tasks }).await?; + } else { + // Since the task assignment policy is round robin, + // if find tasks for one executor is empty, just break fast + break; + } + } + return Ok(()); + } + + Ok(()) + } + + async fn fetch_tasks( + &self, + available_executors: &mut Vec, + job_id: &str, + ) -> Result<(Vec>, usize), BallistaError> { + let mut ret: Vec> = + Vec::with_capacity(available_executors.len()); + for _idx in 0..available_executors.len() { + ret.push(Vec::new()); + } + let mut num_tasks = 0; + loop { + info!("Go inside fetching task loop"); + let mut has_tasks = true; + for (idx, executor) in available_executors.iter_mut().enumerate() { + if executor.available_task_slots == 0 { + break; + } + let plan = self + .state + .assign_next_schedulable_job_task(&executor.executor_id, job_id) + .await + .map_err(|e| { + let msg = format!("Error finding next assignable task: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + if let Some((task, _plan)) = &plan { + let partition_id = task.partition_id.as_ref().unwrap(); + info!( + "Sending new task to {}: {}/{}/{}", + executor.executor_id, + partition_id.job_id, + partition_id.stage_id, + partition_id.partition_id + ); + } + match plan { + Some((status, plan)) => { + let plan_clone = plan.clone(); + let output_partitioning = if let Some(shuffle_writer) = + plan_clone.as_any().downcast_ref::() + { + shuffle_writer.shuffle_output_partitioning() + } else { + return Err(BallistaError::General(format!( + "Task root plan was not a ShuffleWriterExec: {:?}", + plan_clone + ))); + }; + + ret[idx].push(TaskDefinition { + plan: 
Some(plan.try_into().unwrap()), + task_id: status.partition_id, + output_partitioning: hash_partitioning_to_proto( + output_partitioning, + ) + .map_err(|_| Status::internal("TBD".to_string()))?, + }); + executor.available_task_slots -= 1; + num_tasks += 1; + } + _ => { + // Indicate there's no more tasks to be scheduled + has_tasks = false; + break; + } + } + } + if !has_tasks { + break; + } + let has_executors = + available_executors.get(0).unwrap().available_task_slots > 0; + if !has_executors { + break; + } } + Ok((ret, num_tasks)) + } +} + +pub struct TaskScheduler { + scheduler_server: Arc, +} + +impl TaskScheduler { + pub fn new(scheduler_server: Arc) -> Self { + Self { scheduler_server } + } + + pub fn start(&self, mut rx_job: mpsc::Receiver) { + let scheduler_server = self.scheduler_server.clone(); + tokio::spawn(async move { + info!("Starting the task scheduler"); + loop { + let job_id = rx_job.recv().await.unwrap(); + info!("Fetch job {:?} to be scheduled", job_id.clone()); + + let server = scheduler_server.clone(); + server.schedule_job(job_id).await.unwrap(); + } + }); } } @@ -177,6 +376,13 @@ impl SchedulerGrpc for SchedulerServer { &self, request: Request, ) -> std::result::Result, tonic::Status> { + if let TaskSchedulingPolicy::PushStaged = self.policy { + error!("Poll work interface is not supported for push-based task scheduling"); + return Err(tonic::Status::failed_precondition( + "Bad request because poll work is not supported for push-based task scheduling", + )); + } + let remote_addr = request.remote_addr(); if let PollWorkParams { metadata: Some(metadata), can_accept_task, @@ -191,8 +397,9 @@ impl SchedulerGrpc for SchedulerServer { .map(|h| match h { OptionalHost::Host(host) => host, }) - .unwrap_or_else(|| self.caller_ip.to_string()), + .unwrap_or_else(|| remote_addr.unwrap().ip().to_string()), port: metadata.port as u16, + grpc_port: metadata.grpc_port as u16, }; let mut lock = self.state.lock().await.map_err(|e| { let msg = format!("Could not lock the state: {}", e); @@ -274,6 +481,195 @@ impl SchedulerGrpc for SchedulerServer { } } + async fn register_executor( + &self, + request: Request, + ) -> Result, Status> { + let remote_addr = request.remote_addr(); + if let RegisterExecutorParams { + metadata: Some(metadata), + specification: Some(specification), + } = request.into_inner() + { + info!("Received register executor request for {:?}", metadata); + let metadata: ExecutorMeta = ExecutorMeta { + id: metadata.id, + host: metadata + .optional_host + .map(|h| match h { + OptionalHost::Host(host) => host, + }) + .unwrap_or_else(|| remote_addr.unwrap().ip().to_string()), + port: metadata.port as u16, + grpc_port: metadata.grpc_port as u16, + }; + // Check whether the executor starts the grpc service + { + let executor_url = + format!("http://{}:{}", metadata.host, metadata.grpc_port); + info!("Connect to executor {:?}", executor_url); + let executor_client = ExecutorGrpcClient::connect(executor_url) + .await + .context("Could not connect to executor") + .map_err(|e| tonic::Status::internal(format!("{:?}", e)))?; + let mut clients = self.executors_client.write().await; + // TODO check duplicated registration + clients.insert(metadata.id.clone(), executor_client); + info!("Size of executor clients: {:?}", clients.len()); + } + let mut lock = self.state.lock().await.map_err(|e| { + let msg = format!("Could not lock the state: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + self.state + .save_executor_metadata(metadata.clone()) + .await + 
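`fetch_tasks` above assigns tasks round-robin: it walks the available executors in order, hands each one a task while it still has a free slot, and stops early when either tasks or slots run out. A simplified, self-contained sketch of that policy (the `Task` and `Executor` types are stand-ins, not the Ballista structs):

```rust
#[derive(Debug, Clone)]
struct Task(u32);

#[derive(Debug)]
struct Executor {
    id: String,
    available_task_slots: u32,
}

// Hand out tasks round-robin while executors (sorted by free slots, most first)
// still have capacity; returns one task list per executor, in the same order.
fn assign_round_robin(executors: &mut [Executor], mut pending: Vec<Task>) -> Vec<Vec<Task>> {
    let mut assignments = vec![Vec::new(); executors.len()];
    'outer: loop {
        for (idx, executor) in executors.iter_mut().enumerate() {
            if executor.available_task_slots == 0 {
                // Descending sort means every executor after this one is full too.
                break 'outer;
            }
            match pending.pop() {
                Some(task) => {
                    assignments[idx].push(task);
                    executor.available_task_slots -= 1;
                }
                None => break 'outer, // nothing left to schedule
            }
        }
    }
    assignments
}

fn main() {
    let mut executors = vec![
        Executor { id: "exec-1".to_string(), available_task_slots: 2 },
        Executor { id: "exec-2".to_string(), available_task_slots: 1 },
    ];
    let assignments = assign_round_robin(&mut executors, (0..5).map(Task).collect());
    // Only three slots were free, so three of the five tasks get assigned.
    assert_eq!(assignments.iter().map(Vec::len).sum::<usize>(), 3);
    println!("{:?} {:?}", executors, assignments);
}
```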
.map_err(|e| { + let msg = format!("Could not save executor metadata: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + let executor_spec: ExecutorSpecification = specification.into(); + let executor_data = ExecutorData { + executor_id: metadata.id.clone(), + total_task_slots: executor_spec.task_slots, + available_task_slots: executor_spec.task_slots, + }; + self.state + .save_executor_data(executor_data) + .await + .map_err(|e| { + let msg = format!("Could not save executor data: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + lock.unlock().await; + Ok(Response::new(RegisterExecutorResult { success: true })) + } else { + warn!("Received invalid register executor request"); + Err(tonic::Status::invalid_argument( + "Missing metadata in request", + )) + } + } + + async fn send_heart_beat( + &self, + request: Request, + ) -> Result, Status> { + let remote_addr = request.remote_addr(); + if let SendHeartBeatParams { + metadata: Some(metadata), + state: Some(state), + } = request.into_inner() + { + debug!("Received heart beat request for {:?}", metadata); + trace!("Related executor state is {:?}", state); + let metadata: ExecutorMeta = ExecutorMeta { + id: metadata.id, + host: metadata + .optional_host + .map(|h| match h { + OptionalHost::Host(host) => host, + }) + .unwrap_or_else(|| remote_addr.unwrap().ip().to_string()), + port: metadata.port as u16, + grpc_port: metadata.grpc_port as u16, + }; + { + let mut lock = self.state.lock().await.map_err(|e| { + let msg = format!("Could not lock the state: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + self.state + .save_executor_state(metadata, Some(state)) + .await + .map_err(|e| { + let msg = format!("Could not save executor metadata: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + lock.unlock().await; + } + Ok(Response::new(SendHeartBeatResult { reregister: false })) + } else { + warn!("Received invalid executor heart beat request"); + Err(tonic::Status::invalid_argument( + "Missing metadata or metrics in request", + )) + } + } + + async fn update_task_status( + &self, + request: Request, + ) -> Result, Status> { + if let UpdateTaskStatusParams { + metadata: Some(metadata), + task_status, + } = request.into_inner() + { + debug!("Received task status update request for {:?}", metadata); + trace!("Related task status is {:?}", task_status); + let mut jobs = HashSet::new(); + { + let mut lock = self.state.lock().await.map_err(|e| { + let msg = format!("Could not lock the state: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + let num_tasks = task_status.len(); + for task_status in task_status { + self.state + .save_task_status(&task_status) + .await + .map_err(|e| { + let msg = format!("Could not save task status: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + if task_status.partition_id.is_some() { + jobs.insert(task_status.partition_id.unwrap().job_id.clone()); + } + } + let mut executor_data = self + .state + .get_executor_data(&metadata.id) + .await + .map_err(|e| { + let msg = format!( + "Could not get metadata data for id {:?}: {}", + &metadata.id, e + ); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + executor_data.available_task_slots += num_tasks as u32; + self.state + .save_executor_data(executor_data) + .await + .map_err(|e| { + let msg = format!("Could not save metadata data: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + lock.unlock().await; + } + let tx_job = 
self.scheduler_env.as_ref().unwrap().tx_job.clone(); + for job_id in jobs { + tx_job.send(job_id).await.unwrap(); + } + Ok(Response::new(UpdateTaskStatusResult { success: true })) + } else { + warn!("Received invalid task status update request"); + Err(tonic::Status::invalid_argument( + "Missing metadata or task status in request", + )) + } + } + async fn get_file_metadata( &self, request: Request, @@ -386,6 +782,12 @@ impl SchedulerGrpc for SchedulerServer { let state = self.state.clone(); let job_id_spawn = job_id.clone(); + let tx_job: Option> = match self.policy { + TaskSchedulingPolicy::PullStaged => None, + TaskSchedulingPolicy::PushStaged => { + Some(self.scheduler_env.as_ref().unwrap().tx_job.clone()) + } + }; tokio::spawn(async move { // create physical plan using DataFusion let datafusion_ctx = create_datafusion_context(&config); @@ -499,6 +901,11 @@ impl SchedulerGrpc for SchedulerServer { )); } } + + if let Some(tx_job) = tx_job { + // Send job_id to the scheduler channel + tx_job.send(job_id_spawn).await.unwrap(); + } }); Ok(Response::new(ExecuteQueryResult { job_id })) @@ -533,10 +940,7 @@ pub fn create_datafusion_context(config: &BallistaConfig) -> ExecutionContext { #[cfg(all(test, feature = "sled"))] mod test { - use std::{ - net::{IpAddr, Ipv4Addr}, - sync::Arc, - }; + use std::sync::Arc; use tonic::Request; @@ -554,16 +958,13 @@ mod test { async fn test_poll_work() -> Result<(), BallistaError> { let state = Arc::new(StandaloneClient::try_new_temporary()?); let namespace = "default"; - let scheduler = SchedulerServer::new( - state.clone(), - namespace.to_owned(), - IpAddr::V4(Ipv4Addr::LOCALHOST), - ); + let scheduler = SchedulerServer::new(state.clone(), namespace.to_owned()); let state = SchedulerState::new(state, namespace.to_string()); let exec_meta = ExecutorRegistration { id: "abc".to_owned(), optional_host: Some(OptionalHost::Host("".to_owned())), port: 0, + grpc_port: 0, }; let request: Request = Request::new(PollWorkParams { metadata: Some(exec_meta.clone()), diff --git a/ballista/rust/scheduler/src/main.rs b/ballista/rust/scheduler/src/main.rs index 7b79eb1b39ac..5da5bbef6197 100644 --- a/ballista/rust/scheduler/src/main.rs +++ b/ballista/rust/scheduler/src/main.rs @@ -23,6 +23,7 @@ use futures::future::{self, Either, TryFutureExt}; use hyper::{server::conn::AddrStream, service::make_service_fn, Server}; use std::convert::Infallible; use std::{net::SocketAddr, sync::Arc}; +use tonic::transport::server::Connected; use tonic::transport::Server as TonicServer; use tower::Service; @@ -35,9 +36,14 @@ use ballista_scheduler::api::{get_routes, EitherBody, Error}; use ballista_scheduler::state::EtcdClient; #[cfg(feature = "sled")] use ballista_scheduler::state::StandaloneClient; -use ballista_scheduler::{state::ConfigBackendClient, ConfigBackend, SchedulerServer}; +use ballista_scheduler::{ + state::ConfigBackendClient, ConfigBackend, SchedulerEnv, SchedulerServer, + TaskScheduler, +}; +use ballista_core::config::TaskSchedulingPolicy; use log::info; +use tokio::sync::mpsc; #[macro_use] extern crate configure_me; @@ -51,25 +57,43 @@ mod config { "/scheduler_configure_me_config.rs" )); } + use config::prelude::*; async fn start_server( config_backend: Arc, namespace: String, addr: SocketAddr, + policy: TaskSchedulingPolicy, ) -> Result<()> { info!( "Ballista v{} Scheduler listening on {:?}", BALLISTA_VERSION, addr ); - - Ok(Server::bind(&addr) - .serve(make_service_fn(move |request: &AddrStream| { - let scheduler_server = SchedulerServer::new( + //should only call 
SchedulerServer::new() once in the process + info!( + "Starting Scheduler grpc server with task scheduling policy of {:?}", + policy + ); + let scheduler_server = match policy { + TaskSchedulingPolicy::PushStaged => { + // TODO make the buffer size configurable + let (tx_job, rx_job) = mpsc::channel::(10000); + let scheduler_server = SchedulerServer::new_with_policy( config_backend.clone(), namespace.clone(), - request.remote_addr().ip(), + policy, + Some(SchedulerEnv { tx_job }), ); + let task_scheduler = TaskScheduler::new(Arc::new(scheduler_server.clone())); + task_scheduler.start(rx_job); + scheduler_server + } + _ => SchedulerServer::new(config_backend.clone(), namespace.clone()), + }; + + Ok(Server::bind(&addr) + .serve(make_service_fn(move |request: &AddrStream| { let scheduler_grpc_server = SchedulerGrpcServer::new(scheduler_server.clone()); @@ -79,10 +103,16 @@ async fn start_server( .add_service(scheduler_grpc_server) .add_service(keda_scaler) .into_service(); - let mut warp = warp::service(get_routes(scheduler_server)); + let mut warp = warp::service(get_routes(scheduler_server.clone())); + let connect_info = request.connect_info(); future::ok::<_, Infallible>(tower::service_fn( move |req: hyper::Request| { + // Set the connect info from hyper to tonic + let (mut parts, body) = req.into_parts(); + parts.extensions.insert(connect_info.clone()); + let req = http::Request::from_parts(parts, body); + let header = req.headers().get(hyper::header::ACCEPT); if header.is_some() && header.unwrap().eq("application/json") { return Either::Left( @@ -158,6 +188,8 @@ async fn main() -> Result<()> { ) } }; - start_server(client, namespace, addr).await?; + + let policy: TaskSchedulingPolicy = opt.scheduler_policy; + start_server(client, namespace, addr, policy).await?; Ok(()) } diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs index efc7eb607e59..fdd143500b5f 100644 --- a/ballista/rust/scheduler/src/planner.rs +++ b/ballista/rust/scheduler/src/planner.rs @@ -254,7 +254,7 @@ mod test { use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use datafusion::physical_plan::hash_join::HashJoinExec; - use datafusion::physical_plan::sort::SortExec; + use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{ coalesce_partitions::CoalescePartitionsExec, projection::ProjectionExec, }; diff --git a/ballista/rust/scheduler/src/standalone.rs b/ballista/rust/scheduler/src/standalone.rs index 6ab5bd62a8f0..55239d8b5a5e 100644 --- a/ballista/rust/scheduler/src/standalone.rs +++ b/ballista/rust/scheduler/src/standalone.rs @@ -20,10 +20,7 @@ use ballista_core::{ BALLISTA_VERSION, }; use log::info; -use std::{ - net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::Arc, -}; +use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpListener; use tonic::transport::Server; @@ -35,7 +32,6 @@ pub async fn new_standalone_scheduler() -> Result { let server = SchedulerGrpcServer::new(SchedulerServer::new( Arc::new(client), "ballista".to_string(), - IpAddr::V4(Ipv4Addr::LOCALHOST), )); // Let the OS assign a random, free port let listener = TcpListener::bind("localhost:0").await?; diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs index ef6de8312702..45d915a06d8f 100644 --- a/ballista/rust/scheduler/src/state/mod.rs +++ b/ballista/rust/scheduler/src/state/mod.rs @@ -31,7 +31,7 @@ use ballista_core::serde::protobuf::{ 
ExecutorMetadata, FailedJob, FailedTask, JobStatus, PhysicalPlanNode, RunningJob, RunningTask, TaskStatus, }; -use ballista_core::serde::scheduler::PartitionStats; +use ballista_core::serde::scheduler::{ExecutorData, PartitionStats}; use ballista_core::{error::BallistaError, serde::scheduler::ExecutorMeta}; use ballista_core::{error::Result, execution_plans::UnresolvedShuffleExec}; @@ -118,6 +118,13 @@ impl SchedulerState { Ok(result) } + pub async fn get_alive_executors_metadata_within_one_minute( + &self, + ) -> Result> { + self.get_alive_executors_metadata(Duration::from_secs(60)) + .await + } + pub async fn get_alive_executors_metadata( &self, last_seen_threshold: Duration, @@ -133,6 +140,14 @@ impl SchedulerState { } pub async fn save_executor_metadata(&self, meta: ExecutorMeta) -> Result<()> { + self.save_executor_state(meta, None).await + } + + pub async fn save_executor_state( + &self, + meta: ExecutorMeta, + state: Option, + ) -> Result<()> { let key = get_executor_key(&self.namespace, &meta.id); let meta: ExecutorMetadata = meta.into(); let timestamp = SystemTime::now() @@ -142,11 +157,57 @@ impl SchedulerState { let heartbeat = ExecutorHeartbeat { meta: Some(meta), timestamp, + state, }; let value: Vec = encode_protobuf(&heartbeat)?; self.config_client.put(key, value).await } + pub async fn save_executor_data(&self, executor_data: ExecutorData) -> Result<()> { + let key = get_executor_data_key(&self.namespace, &executor_data.executor_id); + let executor_data: protobuf::ExecutorData = executor_data.into(); + let value: Vec = encode_protobuf(&executor_data)?; + self.config_client.put(key, value).await + } + + pub async fn get_executors_data(&self) -> Result> { + let mut result = vec![]; + + let entries = self + .config_client + .get_from_prefix(&get_executors_data_prefix(&self.namespace)) + .await?; + for (_key, entry) in entries { + let executor_data: protobuf::ExecutorData = decode_protobuf(&entry)?; + result.push(executor_data.into()); + } + Ok(result) + } + + pub async fn get_available_executors_data(&self) -> Result> { + let mut res = self + .get_executors_data() + .await? + .into_iter() + .filter_map(|exec| (exec.available_task_slots > 0).then(|| exec)) + .collect::>(); + res.sort_by(|a, b| Ord::cmp(&b.available_task_slots, &a.available_task_slots)); + Ok(res) + } + + pub async fn get_executor_data(&self, executor_id: &str) -> Result { + let key = get_executor_data_key(&self.namespace, executor_id); + let value = &self.config_client.get(&key).await?; + if value.is_empty() { + return Err(BallistaError::General(format!( + "No executor data found for {}", + key + ))); + } + let value: protobuf::ExecutorData = decode_protobuf(value)?; + Ok(value.into()) + } + pub async fn save_job_metadata( &self, job_id: &str, @@ -233,6 +294,18 @@ impl SchedulerState { Ok((&value).try_into()?) } + pub async fn get_job_tasks( + &self, + job_id: &str, + ) -> Result> { + self.config_client + .get_from_prefix(&get_task_prefix_for_job(&self.namespace, job_id)) + .await? 
+ .into_iter() + .map(|(key, bytes)| Ok((key, decode_protobuf(&bytes)?))) + .collect() + } + pub async fn get_all_tasks(&self) -> Result> { self.config_client .get_from_prefix(&get_task_prefix(&self.namespace)) @@ -281,6 +354,42 @@ impl SchedulerState { executor_id: &str, ) -> Result)>> { let tasks = self.get_all_tasks().await?; + self.assign_next_schedulable_task_inner(executor_id, tasks) + .await + } + + pub async fn assign_next_schedulable_job_task( + &self, + executor_id: &str, + job_id: &str, + ) -> Result)>> { + let job_tasks = self.get_job_tasks(job_id).await?; + self.assign_next_schedulable_task_inner(executor_id, job_tasks) + .await + } + + async fn assign_next_schedulable_task_inner( + &self, + executor_id: &str, + tasks: HashMap, + ) -> Result)>> { + match self.get_next_schedulable_task(tasks).await? { + Some((status, plan)) => { + let mut status = status.clone(); + status.status = Some(task_status::Status::Running(RunningTask { + executor_id: executor_id.to_owned(), + })); + self.save_task_status(&status).await?; + Ok(Some((status, plan))) + } + _ => Ok(None), + } + } + + async fn get_next_schedulable_task( + &self, + tasks: HashMap, + ) -> Result)>> { // TODO: Make the duration a configurable parameter let executors = self .get_alive_executors_metadata(Duration::from_secs(60)) @@ -320,34 +429,36 @@ impl SchedulerState { .await?; if task_is_dead { continue 'tasks; - } else if let Some(task_status::Status::Completed( - CompletedTask { + } + match &referenced_task.status { + Some(task_status::Status::Completed(CompletedTask { executor_id, partitions, - }, - )) = &referenced_task.status - { - debug!("Task for unresolved shuffle input partition {} completed and produced these shuffle partitions:\n\t{}", - shuffle_input_partition_id, - partitions.iter().map(|p| format!("{}={}", p.partition_id, &p.path)).collect::>().join("\n\t") - ); - let stage_shuffle_partition_locations = partition_locations - .entry(unresolved_shuffle.stage_id) - .or_insert_with(HashMap::new); - let executor_meta = executors - .iter() - .find(|exec| exec.id == *executor_id) - .unwrap() - .clone(); - - for shuffle_write_partition in partitions { - let temp = stage_shuffle_partition_locations - .entry(shuffle_write_partition.partition_id as usize) - .or_insert_with(Vec::new); - let executor_meta = executor_meta.clone(); - let partition_location = - ballista_core::serde::scheduler::PartitionLocation { - partition_id: + })) => { + debug!("Task for unresolved shuffle input partition {} completed and produced these shuffle partitions:\n\t{}", + shuffle_input_partition_id, + partitions.iter().map(|p| format!("{}={}", p.partition_id, &p.path)).collect::>().join("\n\t") + ); + let stage_shuffle_partition_locations = + partition_locations + .entry(unresolved_shuffle.stage_id) + .or_insert_with(HashMap::new); + let executor_meta = executors + .iter() + .find(|exec| exec.id == *executor_id) + .unwrap() + .clone(); + + for shuffle_write_partition in partitions { + let temp = stage_shuffle_partition_locations + .entry( + shuffle_write_partition.partition_id as usize, + ) + .or_insert_with(Vec::new); + let executor_meta = executor_meta.clone(); + let partition_location = + ballista_core::serde::scheduler::PartitionLocation { + partition_id: ballista_core::serde::scheduler::PartitionId { job_id: partition.job_id.clone(), stage_id: unresolved_shuffle.stage_id, @@ -355,29 +466,43 @@ impl SchedulerState { .partition_id as usize, }, - executor_meta, - partition_stats: PartitionStats::new( - Some(shuffle_write_partition.num_rows), - 
Some(shuffle_write_partition.num_batches), - Some(shuffle_write_partition.num_bytes), - ), - path: shuffle_write_partition.path.clone(), - }; + executor_meta, + partition_stats: PartitionStats::new( + Some(shuffle_write_partition.num_rows), + Some(shuffle_write_partition.num_batches), + Some(shuffle_write_partition.num_bytes), + ), + path: shuffle_write_partition.path.clone(), + }; + debug!( + "Scheduler storing stage {} output partition {} path: {}", + unresolved_shuffle.stage_id, + partition_location.partition_id.partition_id, + partition_location.path + ); + temp.push(partition_location); + } + } + Some(task_status::Status::Failed(FailedTask { error })) => { + // A task should fail when its referenced_task fails + let mut status = status.clone(); + let err_msg = error.to_string(); + status.status = + Some(task_status::Status::Failed(FailedTask { + error: err_msg, + })); + self.save_task_status(&status).await?; + continue 'tasks; + } + _ => { debug!( - "Scheduler storing stage {} output partition {} path: {}", + "Stage {} input partition {} has not completed yet", unresolved_shuffle.stage_id, - partition_location.partition_id.partition_id, - partition_location.path - ); - temp.push(partition_location); + shuffle_input_partition_id, + ); + continue 'tasks; } - } else { - debug!( - "Stage {} input partition {} has not completed yet", - unresolved_shuffle.stage_id, shuffle_input_partition_id, - ); - continue 'tasks; - } + }; } } @@ -385,12 +510,7 @@ impl SchedulerState { remove_unresolved_shuffles(plan.as_ref(), &partition_locations)?; // If we get here, there are no more unresolved shuffled and the task can be run - let mut status = status.clone(); - status.status = Some(task_status::Status::Running(RunningTask { - executor_id: executor_id.to_owned(), - })); - self.save_task_status(&status).await?; - return Ok(Some((status, plan))); + return Ok(Some((status.clone(), plan))); } } Ok(None) @@ -583,6 +703,14 @@ fn get_executor_key(namespace: &str, id: &str) -> String { format!("{}/{}", get_executors_prefix(namespace), id) } +fn get_executors_data_prefix(namespace: &str) -> String { + format!("/ballista/{}/resources/executors", namespace) +} + +fn get_executor_data_key(namespace: &str, id: &str) -> String { + format!("{}/{}", get_executors_data_prefix(namespace), id) +} + fn get_job_prefix(namespace: &str) -> String { format!("/ballista/{}/jobs", namespace) } @@ -670,6 +798,7 @@ mod test { id: "123".to_owned(), host: "localhost".to_owned(), port: 123, + grpc_port: 124, }; state.save_executor_metadata(meta.clone()).await?; let result: Vec<_> = state diff --git a/ballista/rust/scheduler/src/test_utils.rs b/ballista/rust/scheduler/src/test_utils.rs index b9d7ee42f48b..18c4710885f0 100644 --- a/ballista/rust/scheduler/src/test_utils.rs +++ b/ballista/rust/scheduler/src/test_utils.rs @@ -19,6 +19,7 @@ use ballista_core::error::Result; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::execution::context::{ExecutionConfig, ExecutionContext}; +use datafusion::field_util::SchemaExt; use datafusion::prelude::CsvReadOptions; pub const TPCH_TABLES: &[&str] = &[ diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore index 6320cd248dd8..1269488f7fb1 100644 --- a/benchmarks/.gitignore +++ b/benchmarks/.gitignore @@ -1 +1 @@ -data \ No newline at end of file +data diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index db863d68f335..f9a8504c7a75 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -25,14 +25,14 @@ homepage = 
"https://github.com/apache/arrow-datafusion" repository = "https://github.com/apache/arrow-datafusion" license = "Apache-2.0" publish = false -rust-version = "1.57" +rust-version = "1.58" [features] simd = ["datafusion/simd"] snmalloc = ["snmalloc-rs"] [dependencies] -arrow = { package = "arrow2", version="0.8", features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc", "io_print", "ahash", "compute_merge_sort", "compute", "regex"] } +arrow = { package = "arrow2", version="0.9", features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc", "io_print", "ahash", "compute_merge_sort", "compute", "regex"] } datafusion = { path = "../datafusion" } ballista = { path = "../ballista/rust/client" } structopt = { version = "0.3", default-features = false } diff --git a/benchmarks/README.md b/benchmarks/README.md index e6c17430d6e2..6cad607f3db9 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -123,7 +123,7 @@ To run the benchmarks: ```bash cd $ARROW_HOME/benchmarks -cargo run --release benchmark ballista --host localhost --port 50050 --query 1 --path $(pwd)/data --format tbl +cargo run --release --bin tpch benchmark ballista --host localhost --port 50050 --query 1 --path $(pwd)/data --format tbl ``` ## Running the Ballista Benchmarks on docker-compose diff --git a/benchmarks/src/bin/nyctaxi.rs b/benchmarks/src/bin/nyctaxi.rs index 12eb9835d876..b2f18c7c4bb1 100644 --- a/benchmarks/src/bin/nyctaxi.rs +++ b/benchmarks/src/bin/nyctaxi.rs @@ -17,6 +17,8 @@ //! Apache Arrow Rust Benchmarks +use arrow::array::ArrayRef; +use arrow::chunk::Chunk; use std::collections::HashMap; use std::path::PathBuf; use std::process; @@ -28,6 +30,7 @@ use datafusion::arrow::io::print; use datafusion::error::Result; use datafusion::execution::context::{ExecutionConfig, ExecutionContext}; +use datafusion::field_util::SchemaExt; use datafusion::physical_plan::collect; use datafusion::prelude::CsvReadOptions; use structopt::StructOpt; @@ -116,15 +119,21 @@ async fn datafusion_sql_benchmarks( } async fn execute_sql(ctx: &mut ExecutionContext, sql: &str, debug: bool) -> Result<()> { + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); let plan = ctx.create_logical_plan(sql)?; let plan = ctx.optimize(&plan)?; if debug { println!("Optimized logical plan:\n{:?}", plan); } let physical_plan = ctx.create_physical_plan(&plan).await?; - let result = collect(physical_plan).await?; + let result = collect(physical_plan, runtime).await?; if debug { - print::print(&result); + let fields = result + .first() + .map(|b| b.schema().field_names()) + .unwrap_or(vec![]); + let chunks: Vec> = result.iter().map(|rb| rb.into()).collect(); + println!("{}", print::write(&chunks, &fields)); } Ok(()) } diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 9d3302055121..4494bb77977c 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -17,6 +17,8 @@ //! Benchmark derived from TPC-H. This is not an official TPC-H benchmark. 
+use arrow::array::ArrayRef; +use arrow::chunk::Chunk; use futures::future::join_all; use rand::prelude::*; use std::ops::Div; @@ -45,14 +47,14 @@ use datafusion::{ datasource::file_format::{csv::CsvFormat, FileFormat}, }; use datafusion::{ - arrow::record_batch::RecordBatch, datasource::file_format::parquet::ParquetFormat, + datasource::file_format::parquet::ParquetFormat, record_batch::RecordBatch, }; use arrow::io::parquet::write::{Compression, Version, WriteOptions}; -use arrow::io::print::print; use ballista::prelude::{ BallistaConfig, BallistaContext, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, }; +use datafusion::field_util::SchemaExt; use structopt::StructOpt; #[cfg(all(feature = "snmalloc", not(feature = "mimalloc")))] @@ -262,6 +264,7 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result Result Result<()> { millis.push(elapsed as f64); println!("Query {} iteration {} took {:.1} ms", opt.query, i, elapsed); if opt.debug { - print(&batches); + println!("{}", datafusion::arrow_print::write(&batches)); } } @@ -441,7 +444,7 @@ async fn loadtest_ballista(opt: BallistaLoadtestOpt) -> Result<()> { &client_id, &i, query_id, elapsed ); if opt.debug { - print(&batches); + println!("{}", datafusion::arrow_print::write(&batches)); } } }); @@ -543,13 +546,19 @@ async fn execute_query( displayable(physical_plan.as_ref()).indent() ); } - let result = collect(physical_plan.clone()).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let result = collect(physical_plan.clone(), runtime).await?; if debug { println!( "=== Physical plan with metrics ===\n{}\n", DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()).indent() ); - print::print(&result); + let fields = result + .first() + .map(|b| b.schema().field_names()) + .unwrap_or(vec![]); + let chunks: Vec> = result.iter().map(|rb| rb.into()).collect(); + println!("{}", print::write(&chunks, &fields)); } Ok(result) } @@ -785,6 +794,7 @@ mod tests { use arrow::array::get_display; use datafusion::arrow::array::*; + use datafusion::field_util::FieldExt; use datafusion::logical_plan::Expr; use datafusion::logical_plan::Expr::Cast; diff --git a/ci/docker/linux-apt-lint.dockerfile b/ci/docker/linux-apt-lint.dockerfile index bce5527d54aa..a5c063c74c5d 100644 --- a/ci/docker/linux-apt-lint.dockerfile +++ b/ci/docker/linux-apt-lint.dockerfile @@ -46,7 +46,7 @@ COPY ci/scripts/install_iwyu.sh /arrow/ci/scripts/ RUN arrow/ci/scripts/install_iwyu.sh /tmp/iwyu /usr/local ${clang_tools} # Rust linter -ARG rust=nightly-2021-10-23 +ARG rust=nightly-2022-01-17 RUN curl https://sh.rustup.rs -sSf | \ sh -s -- --default-toolchain stable -y ENV PATH /root/.cargo/bin:$PATH diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index f212de3223cc..285c8388be36 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -24,12 +24,12 @@ keywords = [ "arrow", "datafusion", "ballista", "query", "sql" ] license = "Apache-2.0" homepage = "https://github.com/apache/arrow-datafusion" repository = "https://github.com/apache/arrow-datafusion" -rust-version = "1.57" +rust-version = "1.58" [dependencies] -clap = "2.33" +clap = { version = "3", features = ["derive", "cargo"] } rustyline = "9.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } datafusion = { path = "../datafusion", version = "6.0.0" } -arrow = { package = "arrow2", version="0.8", features = ["io_print"] } +arrow = { package = "arrow2", version="0.9", features = ["io_print"] } ballista = { path = 
"../ballista/rust/client", version = "0.6.0" } diff --git a/datafusion-cli/Dockerfile b/datafusion-cli/Dockerfile index fed14188fded..cef7afa8c9da 100644 --- a/datafusion-cli/Dockerfile +++ b/datafusion-cli/Dockerfile @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -FROM rust:1.57 as builder +FROM rust:1.58 as builder COPY ./datafusion /usr/src/datafusion diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index 4c7c65bf537c..fa37059039a2 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -23,8 +23,9 @@ use crate::print_format::PrintFormat; use crate::print_options::{self, PrintOptions}; use datafusion::arrow::array::{ArrayRef, Utf8Array}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; +use datafusion::field_util::SchemaExt; +use datafusion::record_batch::RecordBatch; use std::str::FromStr; use std::sync::Arc; use std::time::Instant; diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 73e1b60ec42f..acc340db8222 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -24,9 +24,8 @@ use crate::{ print_format::{all_print_formats, PrintFormat}, print_options::PrintOptions, }; -use clap::SubCommand; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; +use datafusion::record_batch::RecordBatch; use rustyline::config::Config; use rustyline::error::ReadlineError; use rustyline::Editor; diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index c460a1d2f064..7839d4f69bcb 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -16,11 +16,14 @@ // under the License. //! 
Functions that are query-able and searchable via the `\h` command -use arrow::array::Utf8Array; +use arrow::array::{ArrayRef, Utf8Array}; +use arrow::chunk::Chunk; use arrow::datatypes::{DataType, Field, Schema}; -use arrow::record_batch::RecordBatch; use datafusion::arrow::io::print; use datafusion::error::{DataFusionError, Result}; +use datafusion::field_util::SchemaExt; +use datafusion::physical_plan::ColumnarValue::Array; +use datafusion::record_batch::RecordBatch; use std::fmt; use std::str::FromStr; use std::sync::Arc; @@ -187,14 +190,14 @@ impl fmt::Display for Function { pub fn display_all_functions() -> Result<()> { println!("Available help:"); - let array = StringArray::from_slice( + let array: ArrayRef = Arc::new(StringArray::from_slice( ALL_FUNCTIONS .iter() .map(|f| format!("{}", f)) .collect::>(), - ); + )); let schema = Schema::new(vec![Field::new("Function", DataType::Utf8, false)]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?; - print::print(&[batch]); + let batch = Chunk::try_new(vec![array])?; + println!("{}", print::write(&[batch], &schema.field_names())); Ok(()) } diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 577a3110dc11..4cb9e9ddef14 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -40,32 +40,32 @@ pub async fn main() -> Result<()> { Parquet files as well as querying directly against in-memory data.", ) .arg( - Arg::with_name("data-path") + Arg::new("data-path") .help("Path to your data, default to current directory") - .short("p") + .short('p') .long("data-path") .validator(is_valid_data_dir) .takes_value(true), ) .arg( - Arg::with_name("batch-size") + Arg::new("batch-size") .help("The batch size of each query, or use DataFusion default") - .short("c") + .short('c') .long("batch-size") .validator(is_valid_batch_size) .takes_value(true), ) .arg( - Arg::with_name("file") + Arg::new("file") .help("Execute commands from file(s), then exit") - .short("f") + .short('f') .long("file") - .multiple(true) + .multiple_occurrences(true) .validator(is_valid_file) .takes_value(true), ) .arg( - Arg::with_name("format") + Arg::new("format") .help("Output format") .long("format") .default_value("table") @@ -81,21 +81,21 @@ pub async fn main() -> Result<()> { .takes_value(true), ) .arg( - Arg::with_name("host") + Arg::new("host") .help("Ballista scheduler host") .long("host") .takes_value(true), ) .arg( - Arg::with_name("port") + Arg::new("port") .help("Ballista scheduler port") .long("port") .takes_value(true), ) .arg( - Arg::with_name("quiet") + Arg::new("quiet") .help("Reduce printing other than the results and work quietly") - .short("q") + .short('q') .long("quiet") .takes_value(false), ) @@ -154,23 +154,23 @@ pub async fn main() -> Result<()> { Ok(()) } -fn is_valid_file(dir: String) -> std::result::Result<(), String> { - if Path::new(&dir).is_file() { +fn is_valid_file(dir: &str) -> std::result::Result<(), String> { + if Path::new(dir).is_file() { Ok(()) } else { Err(format!("Invalid file '{}'", dir)) } } -fn is_valid_data_dir(dir: String) -> std::result::Result<(), String> { - if Path::new(&dir).is_dir() { +fn is_valid_data_dir(dir: &str) -> std::result::Result<(), String> { + if Path::new(dir).is_dir() { Ok(()) } else { Err(format!("Invalid data directory '{}'", dir)) } } -fn is_valid_batch_size(size: String) -> std::result::Result<(), String> { +fn is_valid_batch_size(size: &str) -> std::result::Result<(), String> { match size.parse::() { Ok(size) if size > 0 => Ok(()), _ => 
Err(format!("Invalid batch size '{}'", size)), diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 9ea811c3a92b..fa8bf2384cf3 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -18,8 +18,9 @@ //! Print format variants use arrow::io::json::write::{JsonArray, JsonFormat, LineDelimited}; use datafusion::arrow::io::{csv::write, print}; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; +use datafusion::field_util::SchemaExt; +use datafusion::record_batch::RecordBatch; use std::fmt; use std::str::FromStr; @@ -59,7 +60,7 @@ impl FromStr for PrintFormat { } impl fmt::Display for PrintFormat { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Self::Csv => write!(f, "csv"), Self::Tsv => write!(f, "tsv"), @@ -77,10 +78,14 @@ fn print_batches_to_json(batches: &[RecordBatch]) -> Result Result println!("{}", print_batches_with_sep(batches, b',')?), Self::Tsv => println!("{}", print_batches_with_sep(batches, b'\t')?), - Self::Table => print::print(batches), + Self::Table => println!("{}", datafusion::arrow_print::write(batches)), Self::Json => { println!("{}", print_batches_to_json::(batches)?) } diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 5e3792634a4e..bebd49831a5a 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -16,8 +16,8 @@ // under the License. use crate::print_format::PrintFormat; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::Result; +use datafusion::record_batch::RecordBatch; use std::time::Instant; #[derive(Debug, Clone)] diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 1474e6a75e06..b6724ae173f0 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -26,7 +26,7 @@ license = "Apache-2.0" keywords = [ "arrow", "query", "sql" ] edition = "2021" publish = false -rust-version = "1.57" +rust-version = "1.58" [[example]] name = "avro_sql" @@ -34,8 +34,8 @@ path = "examples/avro_sql.rs" required-features = ["datafusion/avro"] [dev-dependencies] -arrow-format = { version = "0.3", features = ["flight-service", "flight-data"] } -arrow = { package = "arrow2", version="0.8", features = ["io_ipc", "io_flight"] } +arrow-format = { version = "0.4", features = ["flight-service", "flight-data"] } +arrow = { package = "arrow2", version="0.9", features = ["io_ipc", "io_flight"] } datafusion = { path = "../datafusion" } prost = "0.9" tonic = "0.6" diff --git a/datafusion-examples/examples/dataframe_in_memory.rs b/datafusion-examples/examples/dataframe_in_memory.rs index 0990881c139b..b00bfdabe368 100644 --- a/datafusion-examples/examples/dataframe_in_memory.rs +++ b/datafusion-examples/examples/dataframe_in_memory.rs @@ -19,10 +19,11 @@ use std::sync::Arc; use datafusion::arrow::array::{Int32Array, Utf8Array}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; -use datafusion::arrow::record_batch::RecordBatch; +use datafusion::record_batch::RecordBatch; use datafusion::datasource::MemTable; use datafusion::error::Result; +use datafusion::field_util::SchemaExt; use datafusion::prelude::*; /// This example demonstrates how to use the DataFrame API against in-memory data. 
diff --git a/datafusion-examples/examples/flight_client.rs b/datafusion-examples/examples/flight_client.rs index 536aba30e610..5b8304c163c8 100644 --- a/datafusion-examples/examples/flight_client.rs +++ b/datafusion-examples/examples/flight_client.rs @@ -21,6 +21,8 @@ use arrow::io::flight::deserialize_schemas; use arrow_format::flight::data::{flight_descriptor, FlightDescriptor, Ticket}; use arrow_format::flight::service::flight_service_client::FlightServiceClient; use datafusion::arrow_print; +use datafusion::field_util::SchemaExt; +use datafusion::record_batch::RecordBatch; use std::collections::HashMap; /// This example shows how to wrap DataFusion with `FlightService` to support looking up schema information for @@ -64,13 +66,13 @@ async fn main() -> Result<(), Box> { let mut results = vec![]; let dictionaries_by_field = HashMap::new(); while let Some(flight_data) = stream.message().await? { - let record_batch = arrow::io::flight::deserialize_batch( + let chunk = arrow::io::flight::deserialize_batch( &flight_data, - schema.clone(), + schema.fields(), &ipc_schema, &dictionaries_by_field, )?; - results.push(record_batch); + results.push(RecordBatch::new_with_chunk(&schema, chunk)); } // print the results diff --git a/datafusion-examples/examples/flight_server.rs b/datafusion-examples/examples/flight_server.rs index 9a7b8a6bed21..b616cfb7bd29 100644 --- a/datafusion-examples/examples/flight_server.rs +++ b/datafusion-examples/examples/flight_server.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::chunk::Chunk; use std::pin::Pin; use std::sync::Arc; @@ -77,7 +78,7 @@ impl FlightService for FlightServiceImpl { .unwrap(); let schema_result = - arrow::io::flight::serialize_schema_to_result(schema.as_ref(), &[]); + arrow::io::flight::serialize_schema_to_result(schema.as_ref(), None); Ok(Response::new(schema_result)) } @@ -115,17 +116,20 @@ impl FlightService for FlightServiceImpl { // add an initial FlightData message that sends schema let options = WriteOptions::default(); - let schema_flight_data = - arrow::io::flight::serialize_schema(&df.schema().clone().into(), &[]); + let schema_flight_data = arrow::io::flight::serialize_schema( + &df.schema().clone().into(), + None, + ); let mut flights: Vec> = vec![Ok(schema_flight_data)]; let mut batches: Vec> = results - .iter() + .into_iter() .flat_map(|batch| { + let chunk = Chunk::new(batch.columns().to_vec()); let (flight_dictionaries, flight_batch) = - arrow::io::flight::serialize_batch(batch, &[], &options); + arrow::io::flight::serialize_batch(&chunk, &[], &options); flight_dictionaries .into_iter() .chain(std::iter::once(flight_batch)) diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 527ff84c0272..15c85bc6dd1f 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -17,11 +17,11 @@ /// In this example we will declare a single-type, single return type UDAF that computes the geometric mean. 
/// The geometric mean is described here: https://en.wikipedia.org/wiki/Geometric_mean -use datafusion::arrow::{ - array::Float32Array, array::Float64Array, datatypes::DataType, - record_batch::RecordBatch, -}; +use datafusion::arrow::{array::Float32Array, array::Float64Array, datatypes::DataType}; +use datafusion::record_batch::RecordBatch; +use arrow::array::ArrayRef; +use datafusion::field_util::SchemaExt; use datafusion::physical_plan::functions::Volatility; use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator}; use datafusion::{prelude::*, scalar::ScalarValue}; @@ -65,20 +65,6 @@ impl GeometricMean { pub fn new() -> Self { GeometricMean { n: 0, prod: 1.0 } } -} - -// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions -// to use them. -impl Accumulator for GeometricMean { - // this function serializes our state to `ScalarValue`, which DataFusion uses - // to pass this state between execution stages. - // Note that this can be arbitrary data. - fn state(&self) -> Result> { - Ok(vec![ - ScalarValue::from(self.prod), - ScalarValue::from(self.n), - ]) - } // this function receives one entry per argument of this accumulator. // DataFusion calls this function on every row, and expects this function to update the accumulator's state. @@ -113,6 +99,20 @@ impl Accumulator for GeometricMean { }; Ok(()) } +} + +// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions +// to use them. +impl Accumulator for GeometricMean { + // This function serializes our state to `ScalarValue`, which DataFusion uses + // to pass this state between execution stages. + // Note that this can be arbitrary data. + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.prod), + ScalarValue::from(self.n), + ]) + } // DataFusion expects this function to return the final value of this aggregator. // in this case, this is the formula of the geometric mean @@ -121,9 +121,37 @@ impl Accumulator for GeometricMean { Ok(ScalarValue::from(value)) } + // DataFusion calls this function to update the accumulator's state for a batch + // of inputs rows. In this case the product is updated with values from the first column + // and the count is updated based on the row count + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + }; + (0..values[0].len()).try_for_each(|index| { + let v = values + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + self.update(&v) + }) + } + // Optimization hint: this trait also supports `update_batch` and `merge_batch`, // that can be used to perform these operations on arrays instead of single values. // By default, these methods call `update` and `merge` row by row + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + }; + (0..states[0].len()).try_for_each(|index| { + let v = states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + self.merge(&v) + }) + } } #[tokio::main] diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 35ad4f491985..e30bd394a08e 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. 
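The `update_batch` / `merge_batch` additions above let the accumulator consume whole Arrow arrays instead of one `ScalarValue` row at a time. A hypothetical driver for the batch path, written as if it sat next to the `GeometricMean` type in this example; the input values and the direct trait calls are illustrative only:

```rust
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Float64Array};
use datafusion::error::Result;
use datafusion::physical_plan::Accumulator;

fn drive_batch_api() -> Result<()> {
    let mut acc = GeometricMean::new();

    // One input column with two rows; update_batch folds every row into the state.
    let column: ArrayRef = Arc::new(Float64Array::from_slice(&[2.0, 8.0]));
    acc.update_batch(&[column])?;

    // Geometric mean of 2 and 8 is sqrt(16) = 4.
    let value = acc.evaluate()?;
    println!("geometric mean = {:?}", value);
    Ok(())
}
```
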
+use datafusion::field_util::SchemaExt; +use datafusion::prelude::*; +use datafusion::record_batch::RecordBatch; use datafusion::{ arrow::{ array::{ArrayRef, Float32Array, Float64Array}, datatypes::DataType, - record_batch::RecordBatch, }, physical_plan::functions::Volatility, }; - -use datafusion::prelude::*; use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; use std::sync::Arc; diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 5a79041bbb85..d80bc090a5ec 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -31,7 +31,7 @@ include = [ "Cargo.toml", ] edition = "2021" -rust-version = "1.57" +rust-version = "1.58" [lib] name = "datafusion" @@ -55,7 +55,7 @@ avro = ["arrow/io_avro", "arrow/io_avro_async", "arrow/io_avro_compression", "nu [dependencies] ahash = { version = "0.7", default-features = false } hashbrown = { version = "0.11", features = ["raw"] } -parquet = { package = "parquet2", version = "0.8", default_features = false, features = ["stream"] } +parquet = { package = "parquet2", version = "0.9", default_features = false, features = ["stream"] } sqlparser = "0.13" paste = "^1.0" num_cpus = "1.13.0" @@ -66,9 +66,9 @@ pin-project-lite= "^0.2.7" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs"] } tokio-stream = "0.1" log = "^0.4" -md-5 = { version = "^0.9.1", optional = true } -sha2 = { version = "^0.9.1", optional = true } -blake2 = { version = "^0.9.2", optional = true } +md-5 = { version = "^0.10.0", optional = true } +sha2 = { version = "^0.10.1", optional = true } +blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } ordered-float = "2.0" unicode-segmentation = { version = "^1.7.1", optional = true } @@ -77,19 +77,20 @@ lazy_static = { version = "^1.4.0" } smallvec = { version = "1.6", features = ["union"] } rand = "0.8" num-traits = { version = "0.2", optional = true } -pyo3 = { version = "0.14", optional = true } +pyo3 = { version = "0.15", optional = true } +tempfile = "3" + avro-schema = { version = "0.2", optional = true } # used to print arrow arrays in a nice columnar format comfy-table = { version = "5.0", default-features = false } [dependencies.arrow] package = "arrow2" -version="0.8" +version="0.9" features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc", "ahash", "compute"] [dev-dependencies] criterion = "0.3" -tempfile = "3" doc-comment = "0.3" parquet-format-async-temp = "0" diff --git a/datafusion/benches/data_utils/mod.rs b/datafusion/benches/data_utils/mod.rs index 335d4465c627..7d2885e380ae 100644 --- a/datafusion/benches/data_utils/mod.rs +++ b/datafusion/benches/data_utils/mod.rs @@ -17,9 +17,11 @@ //! This module provides the in-memory table for more realistic benchmarking. 
-use arrow::{array::*, datatypes::*, record_batch::RecordBatch}; +use arrow::{array::*, datatypes::*}; use datafusion::datasource::MemTable; use datafusion::error::Result; +use datafusion::field_util::SchemaExt; +use datafusion::record_batch::RecordBatch; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; diff --git a/datafusion/benches/filter_query_sql.rs b/datafusion/benches/filter_query_sql.rs index dfcde1409c86..994010a025af 100644 --- a/datafusion/benches/filter_query_sql.rs +++ b/datafusion/benches/filter_query_sql.rs @@ -18,10 +18,11 @@ use arrow::{ array::{Float32Array, Float64Array}, datatypes::{DataType, Field, Schema}, - record_batch::RecordBatch, }; use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::field_util::SchemaExt; use datafusion::prelude::ExecutionContext; +use datafusion::record_batch::RecordBatch; use datafusion::{datasource::MemTable, error::Result}; use futures::executor::block_on; use std::sync::Arc; diff --git a/datafusion/benches/math_query_sql.rs b/datafusion/benches/math_query_sql.rs index fcc11dac94f4..0f6a697a808d 100644 --- a/datafusion/benches/math_query_sql.rs +++ b/datafusion/benches/math_query_sql.rs @@ -26,12 +26,13 @@ use tokio::runtime::Runtime; use arrow::{ array::{Float32Array, Float64Array}, datatypes::{DataType, Field, Schema}, - record_batch::RecordBatch, }; use datafusion::error::Result; +use datafusion::record_batch::RecordBatch; use datafusion::datasource::MemTable; use datafusion::execution::context::ExecutionContext; +use datafusion::field_util::SchemaExt; fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); diff --git a/datafusion/benches/physical_plan.rs b/datafusion/benches/physical_plan.rs index 6c608f4c537f..15ea20b41e76 100644 --- a/datafusion/benches/physical_plan.rs +++ b/datafusion/benches/physical_plan.rs @@ -23,17 +23,16 @@ extern crate datafusion; use std::sync::Arc; -use arrow::{ - array::{ArrayRef, Int64Array, Utf8Array}, - record_batch::RecordBatch, -}; +use arrow::array::{ArrayRef, Int64Array, Utf8Array}; +use datafusion::record_batch::RecordBatch; use tokio::runtime::Runtime; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::{ collect, expressions::{col, PhysicalSortExpr}, memory::MemoryExec, - sort_preserving_merge::SortPreservingMergeExec, }; // Initialise the operator using the provided record batches and the sort key @@ -55,10 +54,11 @@ fn sort_preserving_merge_operator(batches: Vec, sort: &[&str]) { None, ) .unwrap(); - let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 8192)); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); let rt = Runtime::new().unwrap(); - rt.block_on(collect(merge)).unwrap(); + let rt_env = Arc::new(RuntimeEnv::default()); + rt.block_on(collect(merge, rt_env)).unwrap(); } // Produces `n` record batches of row size `m`. 
Each record batch will have diff --git a/datafusion/benches/sort_limit_query_sql.rs b/datafusion/benches/sort_limit_query_sql.rs index 13e757c2bb7a..a6bf75e4760c 100644 --- a/datafusion/benches/sort_limit_query_sql.rs +++ b/datafusion/benches/sort_limit_query_sql.rs @@ -29,6 +29,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::MemTable; use datafusion::execution::context::ExecutionContext; +use datafusion::field_util::SchemaExt; use tokio::runtime::Runtime; fn query(ctx: Arc>, sql: &str) { @@ -76,13 +77,14 @@ fn create_context() -> Arc> { let partitions = 16; rt.block_on(async { - let mem_table = MemTable::load(Arc::new(csv), 16 * 1024, Some(partitions)) - .await - .unwrap(); - // create local execution context let mut ctx = ExecutionContext::new(); ctx.state.lock().unwrap().config.target_partitions = 1; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + + let mem_table = MemTable::load(Arc::new(csv), Some(partitions), runtime) + .await + .unwrap(); ctx.register_table("aggregate_test_100", Arc::new(mem_table)) .unwrap(); ctx_holder.lock().unwrap().push(Arc::new(Mutex::new(ctx))) diff --git a/datafusion/src/arrow_print.rs b/datafusion/src/arrow_print.rs index 9232870c5e94..d8bf0b82825f 100644 --- a/datafusion/src/arrow_print.rs +++ b/datafusion/src/arrow_print.rs @@ -20,8 +20,9 @@ // adapted from https://github.com/jorgecarleitao/arrow2/blob/ef7937dfe56033c2cc491482c67587b52cd91554/src/array/display.rs // see: https://github.com/jorgecarleitao/arrow2/issues/771 -use arrow::{array::*, record_batch::RecordBatch}; - +use crate::field_util::{FieldExt, SchemaExt}; +use crate::record_batch::RecordBatch; +use arrow::array::*; use comfy_table::{Cell, Table}; macro_rules! dyn_display { diff --git a/datafusion/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/src/avro_to_arrow/arrow_array_reader.rs index 1a8424ab8448..0fd50e9b2c1f 100644 --- a/datafusion/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/src/avro_to_arrow/arrow_array_reader.rs @@ -17,9 +17,9 @@ //! Avro to Arrow array readers -use crate::arrow::record_batch::RecordBatch; use crate::error::Result; -use crate::physical_plan::coalesce_batches::concat_batches; +use crate::physical_plan::coalesce_batches::concat_chunks; +use crate::record_batch::RecordBatch; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::io::avro::read::Reader as AvroReader; @@ -45,7 +45,7 @@ impl<'a, R: Read> AvroBatchReader { codec, ), avro_schemas, - schema.clone(), + schema.fields.clone(), ); Ok(Self { reader, schema }) } @@ -55,15 +55,15 @@ impl<'a, R: Read> AvroBatchReader { pub fn next_batch(&mut self, batch_size: usize) -> ArrowResult> { if let Some(Ok(batch)) = self.reader.next() { let mut batch = batch; - 'batch: while batch.num_rows() < batch_size { + 'batch: while batch.len() < batch_size { if let Some(Ok(next_batch)) = self.reader.next() { - let num_rows = batch.num_rows() + next_batch.num_rows(); - batch = concat_batches(&self.schema, &[batch, next_batch], num_rows)? + let num_rows = batch.len() + next_batch.len(); + batch = concat_chunks(&self.schema, &[batch, next_batch], num_rows)? 
} else { break 'batch; } } - Ok(Some(batch)) + Ok(Some(RecordBatch::new_with_chunk(&self.schema, batch))) } else { Ok(None) } @@ -75,6 +75,7 @@ mod test { use crate::arrow::array::Array; use crate::arrow::datatypes::{Field, TimeUnit}; use crate::avro_to_arrow::{Reader, ReaderBuilder}; + use crate::field_util::SchemaExt; use arrow::array::{Int32Array, Int64Array, ListArray}; use arrow::datatypes::DataType; use std::fs::File; diff --git a/datafusion/src/avro_to_arrow/reader.rs b/datafusion/src/avro_to_arrow/reader.rs index 415756eb3cea..7cb640e60560 100644 --- a/datafusion/src/avro_to_arrow/reader.rs +++ b/datafusion/src/avro_to_arrow/reader.rs @@ -17,8 +17,8 @@ use super::arrow_array_reader::AvroBatchReader; use crate::arrow::datatypes::SchemaRef; -use crate::arrow::record_batch::RecordBatch; use crate::error::Result; +use crate::record_batch::RecordBatch; use arrow::error::Result as ArrowResult; use arrow::io::avro::{read, Compression}; use std::io::{Read, Seek, SeekFrom}; @@ -196,6 +196,8 @@ mod tests { use super::*; use crate::arrow::array::*; use crate::arrow::datatypes::{DataType, Field}; + use crate::field_util::SchemaExt; + use crate::record_batch::RecordBatch; use arrow::datatypes::TimeUnit; use std::fs::File; diff --git a/datafusion/src/avro_to_arrow/schema.rs b/datafusion/src/avro_to_arrow/schema.rs new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/datafusion/src/avro_to_arrow/schema.rs @@ -0,0 +1 @@ + diff --git a/datafusion/src/cast.rs b/datafusion/src/cast.rs new file mode 100644 index 000000000000..2ebfa59696d5 --- /dev/null +++ b/datafusion/src/cast.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines helper functions for force Array type downcast + +use arrow::array::*; +use arrow::array::{Array, PrimitiveArray}; +use arrow::types::NativeType; + +/// Force downcast ArrayRef to PrimitiveArray +pub fn as_primitive_array(arr: &dyn Array) -> &PrimitiveArray +where + T: NativeType, +{ + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to primitive array") +} + +macro_rules! 
array_downcast_fn { + ($name: ident, $arrty: ty, $arrty_str:expr) => { + #[doc = "Force downcast ArrayRef to "] + #[doc = $arrty_str] + pub fn $name(arr: &dyn Array) -> &$arrty { + arr.as_any().downcast_ref::<$arrty>().expect(concat!( + "Unable to downcast to typed array through ", + stringify!($name) + )) + } + }; + + // use recursive macro to generate dynamic doc string for a given array type + ($name: ident, $arrty: ty) => { + array_downcast_fn!($name, $arrty, stringify!($arrty)); + }; +} + +array_downcast_fn!(as_string_array, Utf8Array); diff --git a/datafusion/src/catalog/catalog.rs b/datafusion/src/catalog/catalog.rs index 30fea1f45f2f..7dbfa5a80c3e 100644 --- a/datafusion/src/catalog/catalog.rs +++ b/datafusion/src/catalog/catalog.rs @@ -59,6 +59,12 @@ impl MemoryCatalogList { } } +impl Default for MemoryCatalogList { + fn default() -> Self { + Self::new() + } +} + impl CatalogList for MemoryCatalogList { fn as_any(&self) -> &dyn Any { self @@ -84,6 +90,12 @@ impl CatalogList for MemoryCatalogList { } } +impl Default for MemoryCatalogProvider { + fn default() -> Self { + Self::new() + } +} + /// Represents a catalog, comprising a number of named schemas. pub trait CatalogProvider: Sync + Send { /// Returns the catalog provider as [`Any`](std::any::Any) diff --git a/datafusion/src/catalog/information_schema.rs b/datafusion/src/catalog/information_schema.rs index a6585a497477..a36f6cf0f717 100644 --- a/datafusion/src/catalog/information_schema.rs +++ b/datafusion/src/catalog/information_schema.rs @@ -24,13 +24,14 @@ use std::{ sync::{Arc, Weak}, }; +use crate::record_batch::RecordBatch; use arrow::{ array::*, datatypes::{DataType, Field, Schema}, - record_batch::RecordBatch, }; use crate::datasource::{MemTable, TableProvider, TableType}; +use crate::field_util::{FieldExt, SchemaExt}; use super::{ catalog::{CatalogList, CatalogProvider}, diff --git a/datafusion/src/catalog/mod.rs b/datafusion/src/catalog/mod.rs index 10591f07e378..478cdefc0cb7 100644 --- a/datafusion/src/catalog/mod.rs +++ b/datafusion/src/catalog/mod.rs @@ -18,6 +18,7 @@ //! This module contains interfaces and default implementations //! of table namespacing concepts, including catalogs and schemas. +#![allow(clippy::module_inception)] pub mod catalog; pub mod information_schema; pub mod schema; diff --git a/datafusion/src/catalog/schema.rs b/datafusion/src/catalog/schema.rs index cf754f63d357..1379eb1894eb 100644 --- a/datafusion/src/catalog/schema.rs +++ b/datafusion/src/catalog/schema.rs @@ -79,6 +79,12 @@ impl MemorySchemaProvider { } } +impl Default for MemorySchemaProvider { + fn default() -> Self { + Self::new() + } +} + impl SchemaProvider for MemorySchemaProvider { fn as_any(&self) -> &dyn Any { self @@ -128,6 +134,7 @@ mod tests { use crate::catalog::schema::{MemorySchemaProvider, SchemaProvider}; use crate::datasource::empty::EmptyTable; + use crate::field_util::SchemaExt; #[tokio::test] async fn test_mem_provider() { diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs index c8c5dcc1c5e6..9c4a4e4aeb4d 100644 --- a/datafusion/src/dataframe.rs +++ b/datafusion/src/dataframe.rs @@ -17,11 +17,11 @@ //! DataFrame API for building and executing query plans. 
-use crate::arrow::record_batch::RecordBatch; use crate::error::Result; use crate::logical_plan::{ DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, Partitioning, }; +use crate::record_batch::RecordBatch; use std::sync::Arc; use crate::physical_plan::SendableRecordBatchStream; diff --git a/datafusion/src/datasource/datasource.rs b/datafusion/src/datasource/datasource.rs index 823b40807e93..1b59c857fb07 100644 --- a/datafusion/src/datasource/datasource.rs +++ b/datafusion/src/datasource/datasource.rs @@ -77,7 +77,6 @@ pub trait TableProvider: Sync + Send { async fn scan( &self, projection: &Option>, - batch_size: usize, filters: &[Expr], // limit can be used to reduce the amount scanned // from the datasource as a performance optimization. diff --git a/datafusion/src/datasource/empty.rs b/datafusion/src/datasource/empty.rs index 380c5a7ac5d1..5622d15a0d67 100644 --- a/datafusion/src/datasource/empty.rs +++ b/datafusion/src/datasource/empty.rs @@ -26,6 +26,7 @@ use async_trait::async_trait; use crate::datasource::TableProvider; use crate::error::Result; use crate::logical_plan::Expr; +use crate::physical_plan::project_schema; use crate::physical_plan::{empty::EmptyExec, ExecutionPlan}; /// A table with a schema but no data. @@ -53,21 +54,11 @@ impl TableProvider for EmptyTable { async fn scan( &self, projection: &Option>, - _batch_size: usize, _filters: &[Expr], _limit: Option, ) -> Result> { // even though there is no data, projections apply - let projection = match projection.clone() { - Some(p) => p, - None => (0..self.schema.fields().len()).collect(), - }; - let projected_schema = Schema::new( - projection - .iter() - .map(|i| self.schema.field(*i).clone()) - .collect(), - ); - Ok(Arc::new(EmptyExec::new(false, Arc::new(projected_schema)))) + let projected_schema = project_schema(&self.schema, projection.as_ref())?; + Ok(Arc::new(EmptyExec::new(false, projected_schema))) } } diff --git a/datafusion/src/datasource/file_format/avro.rs b/datafusion/src/datasource/file_format/avro.rs index 1f7e50663889..bd83e75c4f74 100644 --- a/datafusion/src/datasource/file_format/avro.rs +++ b/datafusion/src/datasource/file_format/avro.rs @@ -20,7 +20,6 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::Schema; use arrow::{self, datatypes::SchemaRef}; use async_trait::async_trait; use futures::StreamExt; @@ -30,7 +29,7 @@ use crate::avro_to_arrow::read_avro_schema_from_reader; use crate::datasource::object_store::{ObjectReader, ObjectReaderStream}; use crate::error::Result; use crate::logical_plan::Expr; -use crate::physical_plan::file_format::{AvroExec, PhysicalPlanConfig}; +use crate::physical_plan::file_format::{AvroExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; @@ -46,13 +45,12 @@ impl FileFormat for AvroFormat { async fn infer_schema(&self, mut readers: ObjectReaderStream) -> Result { let mut schemas = vec![]; - while let Some(obj_reader) = readers.next().await { + if let Some(obj_reader) = readers.next().await { let mut reader = obj_reader?.sync_reader()?; let schema = read_avro_schema_from_reader(&mut reader)?; schemas.push(schema); } - let merged_schema = Schema::try_merge(schemas)?; - Ok(Arc::new(merged_schema)) + Ok(Arc::new(schemas.first().unwrap().clone())) } async fn infer_stats(&self, _reader: Arc) -> Result { @@ -61,7 +59,7 @@ impl FileFormat for AvroFormat { async fn create_physical_plan( &self, - conf: PhysicalPlanConfig, + conf: FileScanConfig, _filters: &[Expr], ) -> Result> { let exec = AvroExec::new(conf); @@ 
-81,6 +79,8 @@ mod tests { }; use super::*; + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::field_util::{FieldExt, SchemaExt}; use arrow::array::{ BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, UInt64Array, }; @@ -88,9 +88,10 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); let projection = None; - let exec = get_exec("alltypes_plain.avro", &projection, 2, None).await?; - let stream = exec.execute(0).await?; + let exec = get_exec("alltypes_plain.avro", &projection, None).await?; + let stream = exec.execute(0, runtime).await?; let tt_batches = stream .map(|batch| { @@ -108,9 +109,10 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = None; - let exec = get_exec("alltypes_plain.avro", &projection, 1024, Some(1)).await?; - let batches = collect(exec).await?; + let exec = get_exec("alltypes_plain.avro", &projection, Some(1)).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -120,8 +122,9 @@ mod tests { #[tokio::test] async fn read_alltypes_plain_avro() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = None; - let exec = get_exec("alltypes_plain.avro", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.avro", &projection, None).await?; let x: Vec = exec .schema() @@ -146,7 +149,7 @@ mod tests { x ); - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -170,10 +173,11 @@ mod tests { #[tokio::test] async fn read_bool_alltypes_plain_avro() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![1]); - let exec = get_exec("alltypes_plain.avro", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -198,10 +202,11 @@ mod tests { #[tokio::test] async fn read_i32_alltypes_plain_avro() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0]); - let exec = get_exec("alltypes_plain.avro", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -223,10 +228,11 @@ mod tests { #[tokio::test] async fn read_i96_alltypes_plain_avro() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![10]); - let exec = get_exec("alltypes_plain.avro", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -248,10 +254,11 @@ mod tests { #[tokio::test] async fn read_f32_alltypes_plain_avro() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![6]); - let exec 
= get_exec("alltypes_plain.avro", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -276,10 +283,11 @@ mod tests { #[tokio::test] async fn read_f64_alltypes_plain_avro() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![7]); - let exec = get_exec("alltypes_plain.avro", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -304,10 +312,11 @@ mod tests { #[tokio::test] async fn read_binary_alltypes_plain_avro() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![9]); - let exec = get_exec("alltypes_plain.avro", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -333,7 +342,6 @@ mod tests { async fn get_exec( file_name: &str, projection: &Option>, - batch_size: usize, limit: Option, ) -> Result> { let testdata = crate::test_util::arrow_test_data(); @@ -350,13 +358,12 @@ mod tests { let file_groups = vec![vec![local_unpartitioned_file(filename.to_owned())]]; let exec = format .create_physical_plan( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema, file_groups, statistics, projection: projection.clone(), - batch_size, limit, table_partition_cols: vec![], }, diff --git a/datafusion/src/datasource/file_format/csv.rs b/datafusion/src/datasource/file_format/csv.rs index a65a1914e30c..c8897c2f011e 100644 --- a/datafusion/src/datasource/file_format/csv.rs +++ b/datafusion/src/datasource/file_format/csv.rs @@ -29,8 +29,9 @@ use futures::StreamExt; use super::FileFormat; use crate::datasource::object_store::{ObjectReader, ObjectReaderStream}; use crate::error::Result; +use crate::field_util::SchemaExt; use crate::logical_plan::Expr; -use crate::physical_plan::file_format::{CsvExec, PhysicalPlanConfig}; +use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; @@ -102,25 +103,18 @@ impl FileFormat for CsvFormat { .has_headers(self.has_header) .from_reader(obj_reader?.sync_reader()?); - let schema = csv::read::infer_schema( + let (fields, records_read) = csv::read::infer_schema( &mut reader, Some(records_to_read), self.has_header, &csv::read::infer, )?; - // if records_read == 0 { - // continue; - // } - // schemas.push(schema.clone()); - // records_to_read -= records_read; - // if records_to_read == 0 { - // break; - // } - // - // FIXME: return recods_read from infer_schema - schemas.push(schema.clone()); - records_to_read -= records_to_read; + if records_read == 0 { + continue; + } + schemas.push(Schema::new(fields)); + records_to_read -= records_read; if records_to_read == 0 { break; } @@ -136,7 +130,7 @@ impl FileFormat for CsvFormat { async fn create_physical_plan( &self, - conf: PhysicalPlanConfig, + conf: FileScanConfig, _filters: &[Expr], ) -> 
Result> { let exec = CsvExec::new(conf, self.has_header, self.delimiter); @@ -147,9 +141,11 @@ impl FileFormat for CsvFormat { #[cfg(test)] mod tests { use super::*; + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::field_util::{FieldExt, SchemaExt}; use crate::{ datasource::{ - file_format::PhysicalPlanConfig, + file_format::FileScanConfig, object_store::local::{ local_object_reader, local_object_reader_stream, local_unpartitioned_file, LocalFileSystem, @@ -161,10 +157,11 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); // skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work) let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]); - let exec = get_exec("aggregate_test_100.csv", &projection, 2, None).await?; - let stream = exec.execute(0).await?; + let exec = get_exec("aggregate_test_100.csv", &projection, None).await?; + let stream = exec.execute(0, runtime).await?; let tt_batches: i32 = stream .map(|batch| { @@ -186,9 +183,10 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0, 1, 2, 3]); - let exec = get_exec("aggregate_test_100.csv", &projection, 1024, Some(1)).await?; - let batches = collect(exec).await?; + let exec = get_exec("aggregate_test_100.csv", &projection, Some(1)).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(4, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -199,7 +197,7 @@ mod tests { #[tokio::test] async fn infer_schema() -> Result<()> { let projection = None; - let exec = get_exec("aggregate_test_100.csv", &projection, 1024, None).await?; + let exec = get_exec("aggregate_test_100.csv", &projection, None).await?; let x: Vec = exec .schema() @@ -231,10 +229,11 @@ mod tests { #[tokio::test] async fn read_char_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0]); - let exec = get_exec("aggregate_test_100.csv", &projection, 1024, None).await?; + let exec = get_exec("aggregate_test_100.csv", &projection, None).await?; - let batches = collect(exec).await.expect("Collect batches"); + let batches = collect(exec, runtime).await.expect("Collect batches"); assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); @@ -258,7 +257,6 @@ mod tests { async fn get_exec( file_name: &str, projection: &Option>, - batch_size: usize, limit: Option, ) -> Result> { let testdata = crate::test_util::arrow_test_data(); @@ -275,13 +273,12 @@ mod tests { let file_groups = vec![vec![local_unpartitioned_file(filename.to_owned())]]; let exec = format .create_physical_plan( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema, file_groups, statistics, projection: projection.clone(), - batch_size, limit, table_partition_cols: vec![], }, diff --git a/datafusion/src/datasource/file_format/json.rs b/datafusion/src/datasource/file_format/json.rs index 45c3d3af1195..5220e6f30fe7 100644 --- a/datafusion/src/datasource/file_format/json.rs +++ b/datafusion/src/datasource/file_format/json.rs @@ -27,9 +27,10 @@ use async_trait::async_trait; use futures::StreamExt; use super::FileFormat; -use super::PhysicalPlanConfig; +use super::FileScanConfig; use crate::datasource::object_store::{ObjectReader, ObjectReaderStream}; use crate::error::Result; +use 
crate::field_util::SchemaExt; use crate::logical_plan::Expr; use crate::physical_plan::file_format::NdJsonExec; use crate::physical_plan::ExecutionPlan; @@ -77,7 +78,7 @@ impl FileFormat for JsonFormat { async fn create_physical_plan( &self, - conf: PhysicalPlanConfig, + conf: FileScanConfig, _filters: &[Expr], ) -> Result> { let exec = NdJsonExec::new(conf); @@ -90,9 +91,11 @@ mod tests { use arrow::array::Int64Array; use super::*; + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::field_util::FieldExt; use crate::{ datasource::{ - file_format::PhysicalPlanConfig, + file_format::FileScanConfig, object_store::local::{ local_object_reader, local_object_reader_stream, local_unpartitioned_file, LocalFileSystem, @@ -103,9 +106,10 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); let projection = None; - let exec = get_exec(&projection, 2, None).await?; - let stream = exec.execute(0).await?; + let exec = get_exec(&projection, None).await?; + let stream = exec.execute(0, runtime).await?; let tt_batches: i32 = stream .map(|batch| { @@ -127,9 +131,10 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = None; - let exec = get_exec(&projection, 1024, Some(1)).await?; - let batches = collect(exec).await?; + let exec = get_exec(&projection, Some(1)).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(4, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -140,7 +145,7 @@ mod tests { #[tokio::test] async fn infer_schema() -> Result<()> { let projection = None; - let exec = get_exec(&projection, 1024, None).await?; + let exec = get_exec(&projection, None).await?; let x: Vec = exec .schema() @@ -155,10 +160,11 @@ mod tests { #[tokio::test] async fn read_int_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0]); - let exec = get_exec(&projection, 1024, None).await?; + let exec = get_exec(&projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await.expect("Collect batches"); assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); @@ -184,7 +190,6 @@ mod tests { async fn get_exec( projection: &Option>, - batch_size: usize, limit: Option, ) -> Result> { let filename = "tests/jsons/2.json"; @@ -200,13 +205,12 @@ mod tests { let file_groups = vec![vec![local_unpartitioned_file(filename.to_owned())]]; let exec = format .create_physical_plan( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema, file_groups, statistics, projection: projection.clone(), - batch_size, limit, table_partition_cols: vec![], }, diff --git a/datafusion/src/datasource/file_format/mod.rs b/datafusion/src/datasource/file_format/mod.rs index 54491615fc4c..21da2e1e6a27 100644 --- a/datafusion/src/datasource/file_format/mod.rs +++ b/datafusion/src/datasource/file_format/mod.rs @@ -29,7 +29,7 @@ use std::sync::Arc; use crate::arrow::datatypes::SchemaRef; use crate::error::Result; use crate::logical_plan::Expr; -use crate::physical_plan::file_format::PhysicalPlanConfig; +use crate::physical_plan::file_format::FileScanConfig; use crate::physical_plan::{ExecutionPlan, Statistics}; use async_trait::async_trait; @@ -59,7 +59,7 @@ pub trait FileFormat: Send + Sync + fmt::Debug { /// according to this file format. 
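
The format tests above now build a `FileScanConfig` (formerly `PhysicalPlanConfig`) without a `batch_size` field; batching is driven by the `RuntimeEnv` handed to `execute`/`collect` instead. A minimal sketch of the new call pattern, assuming the crate paths used in these hunks (the `scan_file_groups` helper name is hypothetical):

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::SchemaRef;
use datafusion::datasource::file_format::FileFormat;
use datafusion::datasource::object_store::local::LocalFileSystem;
use datafusion::datasource::PartitionedFile;
use datafusion::error::Result;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::physical_plan::file_format::FileScanConfig;
use datafusion::physical_plan::{collect, ExecutionPlan, Statistics};
use datafusion::record_batch::RecordBatch;

/// Hypothetical helper mirroring the updated `get_exec` test helpers: build a
/// `FileScanConfig` (note: no `batch_size` field any more) and run the
/// resulting plan with a `RuntimeEnv` that carries the batch size instead.
async fn scan_file_groups(
    format: &dyn FileFormat,
    file_schema: SchemaRef,
    file_groups: Vec<Vec<PartitionedFile>>,
    statistics: Statistics,
) -> Result<Vec<RecordBatch>> {
    let exec: Arc<dyn ExecutionPlan> = format
        .create_physical_plan(
            FileScanConfig {
                object_store: Arc::new(LocalFileSystem {}),
                file_schema,
                file_groups,
                statistics,
                projection: None,
                limit: None,
                table_partition_cols: vec![],
            },
            &[], // no pushed-down filters
        )
        .await?;

    // Batch size moved from the scan config to the runtime configuration.
    let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?);
    collect(exec, runtime).await
}
```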
async fn create_physical_plan( &self, - conf: PhysicalPlanConfig, + conf: FileScanConfig, filters: &[Expr], ) -> Result>; } diff --git a/datafusion/src/datasource/file_format/parquet.rs b/datafusion/src/datasource/file_format/parquet.rs index c74155ba3469..9af9e607dc31 100644 --- a/datafusion/src/datasource/file_format/parquet.rs +++ b/datafusion/src/datasource/file_format/parquet.rs @@ -17,15 +17,16 @@ //! Parquet format abstractions +use arrow::array::{BooleanArray, MutableArray, MutableUtf8Array}; use std::any::{type_name, Any}; use std::sync::Arc; use arrow::datatypes::Schema; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use futures::stream::StreamExt; use arrow::io::parquet::read::{get_schema, read_metadata}; +use futures::TryStreamExt; use parquet::statistics::{ BinaryStatistics as ParquetBinaryStatistics, BooleanStatistics as ParquetBooleanStatistics, @@ -33,19 +34,19 @@ use parquet::statistics::{ }; use super::FileFormat; -use super::PhysicalPlanConfig; +use super::FileScanConfig; use crate::arrow::datatypes::{DataType, Field}; use crate::datasource::object_store::{ObjectReader, ObjectReaderStream}; use crate::datasource::{create_max_min_accs, get_col_stats}; use crate::error::DataFusionError; use crate::error::Result; +use crate::field_util::SchemaExt; use crate::logical_plan::combine_filters; use crate::logical_plan::Expr; use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::file_format::ParquetExec; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::{Accumulator, Statistics}; -use crate::scalar::ScalarValue; /// The default file exetension of parquet files pub const DEFAULT_PARQUET_EXTENSION: &str = ".parquet"; @@ -83,16 +84,14 @@ impl FileFormat for ParquetFormat { self } - async fn infer_schema(&self, mut readers: ObjectReaderStream) -> Result { - // We currently get the schema information from the first file rather than do - // schema merging and this is a limitation. - // See https://issues.apache.org/jira/browse/ARROW-11017 - let first_file = readers - .next() - .await - .ok_or_else(|| DataFusionError::Plan("No data file found".to_owned()))??; - let schema = fetch_schema(first_file)?; - Ok(Arc::new(schema)) + async fn infer_schema(&self, readers: ObjectReaderStream) -> Result { + let merged_schema = readers + .try_fold(Schema::empty(), |acc, reader| async { + let next_schema = fetch_schema(reader); + Schema::try_merge([acc, next_schema?]) + }) + .await?; + Ok(Arc::new(merged_schema)) } async fn infer_stats(&self, reader: Arc) -> Result { @@ -102,7 +101,7 @@ impl FileFormat for ParquetFormat { async fn create_physical_plan( &self, - conf: PhysicalPlanConfig, + conf: FileScanConfig, filters: &[Expr], ) -> Result> { // If enable pruning then combine the filters to build the predicate. @@ -128,7 +127,7 @@ fn summarize_min_max( use arrow::io::parquet::read::PhysicalType; macro_rules! 
update_primitive_min_max { - ($DT:ident, $PRIMITIVE_TYPE:ident) => {{ + ($DT:ident, $PRIMITIVE_TYPE:ident, $ARRAY_TYPE:ident) => {{ if let DataType::$DT = fields[i].data_type() { let stats = stats .as_any() @@ -141,7 +140,9 @@ fn summarize_min_max( })?; if let Some(max_value) = &mut max_values[i] { if let Some(v) = stats.max_value { - match max_value.update(&[ScalarValue::$DT(Some(v))]) { + match max_value.update_batch(&[Arc::new( + arrow::array::$ARRAY_TYPE::from_slice(vec![v]), + )]) { Ok(_) => {} Err(_) => { max_values[i] = None; @@ -151,7 +152,9 @@ fn summarize_min_max( } if let Some(min_value) = &mut min_values[i] { if let Some(v) = stats.min_value { - match min_value.update(&[ScalarValue::$DT(Some(v))]) { + match min_value.update_batch(&[Arc::new( + arrow::array::$ARRAY_TYPE::from_slice(vec![v]), + )]) { Ok(_) => {} Err(_) => { min_values[i] = None; @@ -176,7 +179,9 @@ fn summarize_min_max( })?; if let Some(max_value) = &mut max_values[i] { if let Some(v) = stats.max_value { - match max_value.update(&[ScalarValue::Boolean(Some(v))]) { + match max_value + .update_batch(&[Arc::new(BooleanArray::from_slice(vec![v]))]) + { Ok(_) => {} Err(_) => { max_values[i] = None; @@ -186,7 +191,9 @@ fn summarize_min_max( } if let Some(min_value) = &mut min_values[i] { if let Some(v) = stats.min_value { - match min_value.update(&[ScalarValue::Boolean(Some(v))]) { + match min_value + .update_batch(&[Arc::new(BooleanArray::from_slice(vec![v]))]) + { Ok(_) => {} Err(_) => { min_values[i] = None; @@ -197,18 +204,18 @@ fn summarize_min_max( } } PhysicalType::Int32 => { - update_primitive_min_max!(Int32, i32); + update_primitive_min_max!(Int32, i32, Int32Array); } PhysicalType::Int64 => { - update_primitive_min_max!(Int64, i64); + update_primitive_min_max!(Int64, i64, Int64Array); } // 96 bit ints not supported PhysicalType::Int96 => {} PhysicalType::Float => { - update_primitive_min_max!(Float32, f32); + update_primitive_min_max!(Float32, f32, Float32Array); } PhysicalType::Double => { - update_primitive_min_max!(Float64, f64); + update_primitive_min_max!(Float64, f64, Float64Array); } PhysicalType::ByteArray => { if let DataType::Utf8 = fields[i].data_type() { @@ -222,9 +229,9 @@ fn summarize_min_max( })?; if let Some(max_value) = &mut max_values[i] { if let Some(v) = &stats.max_value { - match max_value.update(&[ScalarValue::Utf8( - std::str::from_utf8(&*v).map(|s| s.to_string()).ok(), - )]) { + let mut a = MutableUtf8Array::::with_capacity(1); + a.push(std::str::from_utf8(&*v).map(|s| s.to_string()).ok()); + match max_value.update_batch(&[a.as_arc()]) { Ok(_) => {} Err(_) => { max_values[i] = None; @@ -234,9 +241,9 @@ fn summarize_min_max( } if let Some(min_value) = &mut min_values[i] { if let Some(v) = &stats.min_value { - match min_value.update(&[ScalarValue::Utf8( - std::str::from_utf8(&*v).map(|s| s.to_string()).ok(), - )]) { + let mut a = MutableUtf8Array::::with_capacity(1); + a.push(std::str::from_utf8(&*v).map(|s| s.to_string()).ok()); + match min_value.update_batch(&[a.as_arc()]) { Ok(_) => {} Err(_) => { min_values[i] = None; @@ -331,6 +338,9 @@ mod tests { }; use super::*; + use crate::field_util::FieldExt; + + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use arrow::array::{ BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, }; @@ -339,9 +349,10 @@ mod tests { #[tokio::test] /// Parquet2 lacks the ability to set batch size for parquet reader async fn read_small_batches() -> Result<()> { + let runtime = 
Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); let projection = None; - let exec = get_exec("alltypes_plain.parquet", &projection, 2, None).await?; - let stream = exec.execute(0).await?; + let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; + let stream = exec.execute(0, runtime).await?; let tt_batches = stream .map(|batch| { @@ -362,14 +373,15 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = None; - let exec = get_exec("alltypes_plain.parquet", &projection, 1024, Some(1)).await?; + let exec = get_exec("alltypes_plain.parquet", &projection, Some(1)).await?; // note: even if the limit is set, the executor rounds up to the batch size assert_eq!(exec.statistics().num_rows, Some(8)); assert_eq!(exec.statistics().total_byte_size, Some(671)); assert!(exec.statistics().is_exact); - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -379,8 +391,9 @@ mod tests { #[tokio::test] async fn read_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = None; - let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; let x: Vec = exec .schema() @@ -404,7 +417,7 @@ mod tests { y ); - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); @@ -415,10 +428,11 @@ mod tests { #[tokio::test] async fn read_bool_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![1]); - let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -443,10 +457,11 @@ mod tests { #[tokio::test] async fn read_i32_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0]); - let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -468,10 +483,11 @@ mod tests { #[tokio::test] async fn read_i96_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![10]); - let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -493,10 +509,11 @@ mod tests { #[tokio::test] async fn read_f32_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![6]); - let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; + let 
exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -521,10 +538,11 @@ mod tests { #[tokio::test] async fn read_f64_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![7]); - let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -549,10 +567,11 @@ mod tests { #[tokio::test] async fn read_binary_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![9]); - let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; + let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -578,7 +597,6 @@ mod tests { async fn get_exec( file_name: &str, projection: &Option>, - batch_size: usize, limit: Option, ) -> Result> { let testdata = crate::test_util::parquet_test_data(); @@ -595,13 +613,12 @@ mod tests { let file_groups = vec![vec![local_unpartitioned_file(filename.clone())]]; let exec = format .create_physical_plan( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema, file_groups, statistics, projection: projection.clone(), - batch_size, limit, table_partition_cols: vec![], }, diff --git a/datafusion/src/datasource/listing/helpers.rs b/datafusion/src/datasource/listing/helpers.rs index abee565af260..0d52966f1065 100644 --- a/datafusion/src/datasource/listing/helpers.rs +++ b/datafusion/src/datasource/listing/helpers.rs @@ -19,10 +19,10 @@ use std::sync::Arc; +use crate::record_batch::RecordBatch; use arrow::{ array::*, datatypes::{DataType, Field, Schema}, - record_batch::RecordBatch, }; use chrono::{TimeZone, Utc}; use futures::{ @@ -43,6 +43,7 @@ use crate::datasource::{ object_store::{FileMeta, ObjectStore, SizedFile}, MemTable, PartitionedFile, PartitionedFileStream, }; +use crate::field_util::SchemaExt; const FILE_SIZE_COLUMN_NAME: &str = "_df_part_file_size_"; const FILE_PATH_COLUMN_NAME: &str = "_df_part_file_path_"; diff --git a/datafusion/src/datasource/listing/table.rs b/datafusion/src/datasource/listing/table.rs index 22e3f750370c..b3a7122cf1ae 100644 --- a/datafusion/src/datasource/listing/table.rs +++ b/datafusion/src/datasource/listing/table.rs @@ -28,8 +28,8 @@ use crate::{ logical_plan::Expr, physical_plan::{ empty::EmptyExec, - file_format::{PhysicalPlanConfig, DEFAULT_PARTITION_COLUMN_DATATYPE}, - ExecutionPlan, Statistics, + file_format::{FileScanConfig, DEFAULT_PARTITION_COLUMN_DATATYPE}, + project_schema, ExecutionPlan, Statistics, }, }; @@ -37,6 +37,7 @@ use crate::datasource::{ datasource::TableProviderFilterPushDown, file_format::FileFormat, get_statistics_with_limit, object_store::ObjectStore, PartitionedFile, TableProvider, }; +use crate::field_util::SchemaExt; use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; @@ -125,7 +126,7 @@ impl 
ListingTable { options: ListingOptions, ) -> Self { // Add the partition columns to the file schema - let mut table_fields = file_schema.fields().clone(); + let mut table_fields = file_schema.fields().to_vec(); for part in &options.table_partition_cols { table_fields.push(Field::new( part, @@ -138,7 +139,7 @@ impl ListingTable { object_store, table_path, file_schema, - table_schema: Arc::new(Schema::new(table_fields)), + table_schema: Arc::new(Schema::new(table_fields.to_vec())), options, } } @@ -170,7 +171,6 @@ impl TableProvider for ListingTable { async fn scan( &self, projection: &Option>, - batch_size: usize, filters: &[Expr], limit: Option, ) -> Result> { @@ -180,12 +180,7 @@ impl TableProvider for ListingTable { // if no files need to be read, return an `EmptyExec` if partitioned_file_lists.is_empty() { let schema = self.schema(); - let projected_schema = match &projection { - None => schema, - Some(p) => Arc::new(Schema::new( - p.iter().map(|i| schema.field(*i).clone()).collect(), - )), - }; + let projected_schema = project_schema(&schema, projection.as_ref())?; return Ok(Arc::new(EmptyExec::new(false, projected_schema))); } @@ -193,13 +188,12 @@ impl TableProvider for ListingTable { self.options .format .create_physical_plan( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::clone(&self.object_store), file_schema: Arc::clone(&self.file_schema), file_groups: partitioned_file_lists, statistics, projection: projection.clone(), - batch_size, limit, table_partition_cols: self.options.table_partition_cols.clone(), }, @@ -289,7 +283,7 @@ mod tests { let table = load_table("alltypes_plain.parquet").await?; let projection = None; let exec = table - .scan(&projection, 1024, &[], None) + .scan(&projection, &[], None) .await .expect("Scan table"); @@ -313,7 +307,7 @@ mod tests { .await?; let table = ListingTable::new(Arc::new(LocalFileSystem {}), filename, schema, opt); - let exec = table.scan(&None, 1024, &[], None).await?; + let exec = table.scan(&None, &[], None).await?; assert_eq!(exec.statistics().num_rows, Some(8)); assert_eq!(exec.statistics().total_byte_size, Some(671)); @@ -345,7 +339,7 @@ mod tests { let filter = Expr::not_eq(col("p1"), lit("v1")); let scan = table - .scan(&None, 1024, &[filter], None) + .scan(&None, &[filter], None) .await .expect("Empty execution plan"); diff --git a/datafusion/src/datasource/memory.rs b/datafusion/src/datasource/memory.rs index 57a71c33d584..4b1e09e68e71 100644 --- a/datafusion/src/datasource/memory.rs +++ b/datafusion/src/datasource/memory.rs @@ -23,12 +23,14 @@ use futures::StreamExt; use std::any::Any; use std::sync::Arc; +use crate::record_batch::RecordBatch; use arrow::datatypes::{Field, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; use async_trait::async_trait; use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::{FieldExt, SchemaExt}; use crate::logical_plan::Expr; use crate::physical_plan::common; use crate::physical_plan::memory::MemoryExec; @@ -80,18 +82,19 @@ impl MemTable { /// Create a mem table by reading from another data source pub async fn load( t: Arc, - batch_size: usize, output_partitions: Option, + runtime: Arc, ) -> Result { let schema = t.schema(); - let exec = t.scan(&None, batch_size, &[], None).await?; + let exec = t.scan(&None, &[], None).await?; let partition_count = exec.output_partitioning().partition_count(); let tasks = (0..partition_count) .map(|part_i| { + let runtime1 = runtime.clone(); 
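
`TableProvider::scan` loses its `batch_size` argument as well, and executing the resulting plan now requires a `RuntimeEnv`, as the `MemTable` changes around this hunk show. A small sketch of reading an already populated `MemTable` under the new signatures (assumed to match the paths used in these hunks):

```rust
use std::sync::Arc;

use datafusion::datasource::{MemTable, TableProvider};
use datafusion::error::Result;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::collect;

/// Scan a MemTable with the new `scan` signature (no batch_size parameter)
/// and collect it with an explicit RuntimeEnv.
async fn read_all(table: &MemTable) -> Result<()> {
    let exec = table.scan(&None, &[], None).await?;
    let batches = collect(exec, Arc::new(RuntimeEnv::default())).await?;
    println!("read {} batches", batches.len());
    Ok(())
}
```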
let exec = exec.clone(); tokio::spawn(async move { - let stream = exec.execute(part_i).await?; + let stream = exec.execute(part_i, runtime1.clone()).await?; common::collect(stream).await }) }) @@ -118,7 +121,7 @@ impl MemTable { let mut output_partitions = vec![]; for i in 0..exec.output_partitioning().partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i).await?; + let mut stream = exec.execute(i, runtime.clone()).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -145,7 +148,6 @@ impl TableProvider for MemTable { async fn scan( &self, projection: &Option>, - _batch_size: usize, _filters: &[Expr], _limit: Option, ) -> Result> { @@ -160,13 +162,16 @@ impl TableProvider for MemTable { #[cfg(test)] mod tests { use super::*; + use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; + use arrow::error::ArrowError; use futures::StreamExt; - use std::collections::HashMap; + use std::collections::BTreeMap; #[tokio::test] async fn test_with_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -187,8 +192,8 @@ mod tests { let provider = MemTable::try_new(schema, vec![vec![batch]])?; // scan with projection - let exec = provider.scan(&Some(vec![2, 1]), 1024, &[], None).await?; - let mut it = exec.execute(0).await?; + let exec = provider.scan(&Some(vec![2, 1]), &[], None).await?; + let mut it = exec.execute(0, runtime).await?; let batch2 = it.next().await.unwrap()?; assert_eq!(2, batch2.schema().fields().len()); assert_eq!("c", batch2.schema().field(0).name()); @@ -200,6 +205,7 @@ mod tests { #[tokio::test] async fn test_without_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -217,8 +223,8 @@ mod tests { let provider = MemTable::try_new(schema, vec![vec![batch]])?; - let exec = provider.scan(&None, 1024, &[], None).await?; - let mut it = exec.execute(0).await?; + let exec = provider.scan(&None, &[], None).await?; + let mut it = exec.execute(0, runtime).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); assert_eq!(3, batch1.num_columns()); @@ -247,11 +253,14 @@ mod tests { let projection: Vec = vec![0, 4]; - match provider.scan(&Some(projection), 1024, &[], None).await { - Err(DataFusionError::Internal(e)) => { - assert_eq!("\"Projection index out of range\"", format!("{:?}", e)) + match provider.scan(&Some(projection), &[], None).await { + Err(DataFusionError::ArrowError(ArrowError::InvalidArgumentError(e))) => { + assert_eq!( + "\"project index 4 out of bounds, max field 3\"", + format!("{:?}", e) + ) } - _ => panic!("Scan should failed on invalid projection"), + res => panic!("Scan should failed on invalid projection, got {:?}", res), }; Ok(()) @@ -325,18 +334,17 @@ mod tests { #[tokio::test] async fn test_merged_schema() -> Result<()> { - let mut metadata = HashMap::new(); + let runtime = Arc::new(RuntimeEnv::default()); + let mut metadata = BTreeMap::new(); metadata.insert("foo".to_string(), "bar".to_string()); - let schema1 = Schema::new_from( - vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ], - // test for comparing metadata - 
metadata, - ); + // test for comparing metadata + let schema1 = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]) + .with_metadata(metadata); let schema2 = Schema::new(vec![ // test for comparing nullability @@ -368,8 +376,8 @@ mod tests { let provider = MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?; - let exec = provider.scan(&None, 1024, &[], None).await?; - let mut it = exec.execute(0).await?; + let exec = provider.scan(&None, &[], None).await?; + let mut it = exec.execute(0, runtime).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); assert_eq!(3, batch1.num_columns()); diff --git a/datafusion/src/datasource/mod.rs b/datafusion/src/datasource/mod.rs index 9f4f77f7ea28..f7119cebcb95 100644 --- a/datafusion/src/datasource/mod.rs +++ b/datafusion/src/datasource/mod.rs @@ -17,6 +17,7 @@ //! DataFusion data sources +#![allow(clippy::module_inception)] pub mod datasource; pub mod empty; pub mod file_format; @@ -31,6 +32,7 @@ pub use self::memory::MemTable; use self::object_store::{FileMeta, SizedFile}; use crate::arrow::datatypes::{Schema, SchemaRef}; use crate::error::Result; +use crate::field_util::SchemaExt; use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics}; use crate::scalar::ScalarValue; @@ -70,7 +72,7 @@ pub async fn get_statistics_with_limit( if let Some(max_value) = &mut max_values[i] { if let Some(file_max) = cs.max_value.clone() { - match max_value.update(&[file_max]) { + match max_value.update_batch(&[file_max.to_array()]) { Ok(_) => {} Err(_) => { max_values[i] = None; @@ -81,7 +83,7 @@ pub async fn get_statistics_with_limit( if let Some(min_value) = &mut min_values[i] { if let Some(file_min) = cs.min_value.clone() { - match min_value.update(&[file_min]) { + match min_value.update_batch(&[file_min.to_array()]) { Ok(_) => {} Err(_) => { min_values[i] = None; diff --git a/datafusion/src/datasource/object_store/local.rs b/datafusion/src/datasource/object_store/local.rs index 5d254496e542..902fa43c04db 100644 --- a/datafusion/src/datasource/object_store/local.rs +++ b/datafusion/src/datasource/object_store/local.rs @@ -28,8 +28,7 @@ use crate::datasource::object_store::{ FileMeta, FileMetaStream, ListEntryStream, ObjectReader, ObjectStore, ReadSeek, }; use crate::datasource::PartitionedFile; -use crate::error::DataFusionError; -use crate::error::Result; +use crate::error::{DataFusionError, Result}; use super::{ObjectReaderStream, SizedFile}; diff --git a/datafusion/src/datasource/object_store/mod.rs b/datafusion/src/datasource/object_store/mod.rs index 43f27102c5ec..c581b171a57b 100644 --- a/datafusion/src/datasource/object_store/mod.rs +++ b/datafusion/src/datasource/object_store/mod.rs @@ -125,7 +125,7 @@ pub type ListEntryStream = /// Stream readers opened on a given object store pub type ObjectReaderStream = - Pin>> + Send + Sync + 'static>>; + Pin>> + Send + Sync>>; /// A ObjectStore abstracts access to an underlying file/object storage. /// It maps strings (e.g. 
URLs, filesystem paths, etc) to sources of bytes @@ -174,7 +174,7 @@ pub struct ObjectStoreRegistry { } impl fmt::Debug for ObjectStoreRegistry { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("ObjectStoreRegistry") .field( "schemes", @@ -189,6 +189,12 @@ impl fmt::Debug for ObjectStoreRegistry { } } +impl Default for ObjectStoreRegistry { + fn default() -> Self { + Self::new() + } +} + impl ObjectStoreRegistry { /// Create the registry that object stores can registered into. /// ['LocalFileSystem'] store is registered in by default to support read local files natively. diff --git a/datafusion/src/error.rs b/datafusion/src/error.rs index b5676669df00..fbad9a97d37b 100644 --- a/datafusion/src/error.rs +++ b/datafusion/src/error.rs @@ -29,6 +29,9 @@ use sqlparser::parser::ParserError; /// Result type for operations that could result in an [DataFusionError] pub type Result = result::Result; +/// Error type for generic operations that could result in DataFusionError::External +pub type GenericError = Box; + /// DataFusion error #[derive(Debug)] #[allow(missing_docs)] @@ -56,13 +59,12 @@ pub enum DataFusionError { /// Error returned during execution of the query. /// Examples include files not found, errors in parsing certain types. Execution(String), -} - -impl DataFusionError { - /// Wraps this [DataFusionError] as an [arrow::error::ArrowError]. - pub fn into_arrow_external_error(self) -> ArrowError { - ArrowError::from_external_error(Box::new(self)) - } + /// This error is thrown when a consumer cannot acquire memory from the Memory Manager + /// we can just cancel the execution of the partition. + ResourcesExhausted(String), + /// Errors originating from outside DataFusion's core codebase. 
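
In the parquet statistics code above, min/max accumulators are no longer fed `ScalarValue`s through `update`; statistics values are wrapped in single-element arrays and passed to `update_batch`. A minimal sketch of that pattern (the helper name is illustrative):

```rust
use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::error::Result;
use datafusion::physical_plan::Accumulator;

/// Feed one statistics value to a min/max accumulator using the array-based
/// `update_batch` API that replaces the scalar `update` calls above.
fn push_stat_value(acc: &mut dyn Accumulator, value: i32) -> Result<()> {
    acc.update_batch(&[Arc::new(Int32Array::from_slice(vec![value]))])
}
```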
+ /// For example, a custom S3Error from the crate datafusion-objectstore-s3 + External(GenericError), } impl From for DataFusionError { @@ -77,6 +79,16 @@ impl From for DataFusionError { } } +impl From for ArrowError { + fn from(e: DataFusionError) -> Self { + match e { + DataFusionError::ArrowError(e) => e, + DataFusionError::External(e) => ArrowError::External("".to_string(), e), + other => ArrowError::External("".to_string(), Box::new(other)), + } + } +} + impl From for DataFusionError { fn from(e: ParquetError) -> Self { DataFusionError::ParquetError(e) @@ -89,8 +101,14 @@ impl From for DataFusionError { } } +impl From for DataFusionError { + fn from(err: GenericError) -> Self { + DataFusionError::External(err) + } +} + impl Display for DataFusionError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match *self { DataFusionError::ArrowError(ref desc) => write!(f, "Arrow error: {}", desc), DataFusionError::ParquetError(ref desc) => { @@ -113,8 +131,58 @@ impl Display for DataFusionError { DataFusionError::Execution(ref desc) => { write!(f, "Execution error: {}", desc) } + DataFusionError::ResourcesExhausted(ref desc) => { + write!(f, "Resources exhausted: {}", desc) + } + DataFusionError::External(ref desc) => { + write!(f, "External error: {}", desc) + } } } } impl error::Error for DataFusionError {} + +#[cfg(test)] +mod test { + use crate::error::DataFusionError; + use arrow::error::ArrowError; + + #[test] + fn arrow_error_to_datafusion() { + let res = return_arrow_error().unwrap_err(); + assert_eq!( + res.to_string(), + "External error: Error during planning: foo" + ); + } + + #[test] + fn datafusion_error_to_arrow() { + let res = return_datafusion_error().unwrap_err(); + assert_eq!( + res.to_string(), + "Arrow error: Invalid argument error: Schema error: bar" + ); + } + + /// Model what happens when implementing SendableRecrordBatchStream: + /// DataFusion code needs to return an ArrowError + #[allow(clippy::try_err)] + fn return_arrow_error() -> arrow::error::Result<()> { + // Expect the '?' to work + let _foo = Err(DataFusionError::Plan("foo".to_string()))?; + Ok(()) + } + + /// Model what happens when using arrow kernels in DataFusion + /// code: need to turn an ArrowError into a DataFusionError + #[allow(clippy::try_err)] + fn return_datafusion_error() -> crate::error::Result<()> { + // Expect the '?' 
to work + let _bar = Err(ArrowError::InvalidArgumentError( + "Schema error: bar".to_string(), + ))?; + Ok(()) + } +} diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 89ea4380e1c0..2e70962c8360 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -39,7 +39,6 @@ use crate::{ }, }; use log::debug; -use std::fs; use std::path::Path; use std::string::String; use std::sync::Arc; @@ -47,17 +46,18 @@ use std::{ collections::{HashMap, HashSet}, sync::Mutex, }; +use std::{fs, path::PathBuf}; use futures::{StreamExt, TryStreamExt}; use tokio::task::{self, JoinHandle}; +use crate::record_batch::RecordBatch; use arrow::datatypes::SchemaRef; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::io::csv; use arrow::io::parquet; use arrow::io::parquet::write::FallibleStreamingIterator; use arrow::io::parquet::write::WriteOptions; -use arrow::record_batch::RecordBatch; use crate::catalog::{ catalog::{CatalogProvider, MemoryCatalogProvider}, @@ -82,6 +82,8 @@ use crate::physical_optimizer::coalesce_batches::CoalesceBatches; use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec; use crate::physical_optimizer::repartition::Repartition; +use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use crate::field_util::{FieldExt, SchemaExt}; use crate::logical_plan::plan::Explain; use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::physical_plan::planner::DefaultPhysicalPlanner; @@ -97,7 +99,12 @@ use crate::{dataframe::DataFrame, physical_plan::udaf::AggregateUDF}; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use super::options::{AvroReadOptions, CsvReadOptions}; +use super::{ + disk_manager::DiskManagerConfig, + memory_manager::MemoryManagerConfig, + options::{AvroReadOptions, CsvReadOptions}, + DiskManager, MemoryManager, +}; /// ExecutionContext is the main interface for executing queries with DataFusion. The context /// provides the following functionality: @@ -145,6 +152,12 @@ pub struct ExecutionContext { pub state: Arc>, } +impl Default for ExecutionContext { + fn default() -> Self { + Self::new() + } +} + impl ExecutionContext { /// Creates a new execution context using a default configuration. pub fn new() -> Self { @@ -176,6 +189,8 @@ impl ExecutionContext { .register_catalog(config.default_catalog.clone(), default_catalog); } + let runtime_env = Arc::new(RuntimeEnv::new(config.runtime.clone()).unwrap()); + Self { state: Arc::new(Mutex::new(ExecutionContextState { catalog_list, @@ -185,10 +200,16 @@ impl ExecutionContext { config, execution_props: ExecutionProps::new(), object_store_registry: Arc::new(ObjectStoreRegistry::new()), + runtime_env, })), } } + /// Return the [RuntimeEnv] used to run queries with this [ExecutionContext] + pub fn runtime_env(&self) -> Arc { + self.state.lock().unwrap().runtime_env.clone() + } + /// Creates a dataframe that will execute a SQL query. 
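
The error.rs changes above introduce `GenericError`, the `External` and `ResourcesExhausted` variants, and `From` conversions in both directions between `DataFusionError` and `ArrowError`, so `?` works across the boundary. A sketch of how downstream code might lean on those impls (function names are illustrative):

```rust
use datafusion::arrow::error::ArrowError;
use datafusion::error::{DataFusionError, Result};

/// Surface a non-DataFusion error (e.g. from an object store crate) as
/// `DataFusionError::External`.
fn parse_number(text: &str) -> Result<i64> {
    text.trim()
        .parse::<i64>()
        .map_err(|e| DataFusionError::External(Box::new(e)))
}

/// When an arrow result is required (e.g. in a record batch stream), the new
/// `impl From<DataFusionError> for ArrowError` lets `?` do the conversion.
fn parse_number_arrow(text: &str) -> std::result::Result<i64, ArrowError> {
    Ok(parse_number(text)?)
}
```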
/// /// This method is `async` because queries of type `CREATE EXTERNAL TABLE` @@ -570,6 +591,7 @@ impl ExecutionContext { .unwrap() .object_store_registry .get_by_uri(uri) + .map_err(DataFusionError::from) } /// Registers a table using a custom `TableProvider` so that @@ -711,6 +733,7 @@ impl ExecutionContext { let path = path.as_ref(); // create directory to contain the CSV files (one per partition) let fs_path = Path::new(path); + let runtime = self.runtime_env(); match fs::create_dir(fs_path) { Ok(()) => { let mut tasks = vec![]; @@ -722,16 +745,24 @@ impl ExecutionContext { let mut writer = csv::write::WriterBuilder::new() .from_path(path) .map_err(ArrowError::from)?; - - csv::write::write_header(&mut writer, plan.schema().as_ref())?; + let mut field_names = vec![]; + let schema = plan.schema(); + for f in schema.fields() { + field_names.push(f.name()); + } + csv::write::write_header(&mut writer, &field_names)?; let options = csv::write::SerializeOptions::default(); - let stream = plan.execute(i).await?; + let stream = plan.execute(i, runtime.clone()).await?; let handle: JoinHandle> = task::spawn(async move { stream .map(|batch| { - csv::write::write_batch(&mut writer, &batch?, &options) + csv::write::write_chunk( + &mut writer, + &batch?.into(), + &options, + ) }) .try_collect() .await @@ -759,6 +790,7 @@ impl ExecutionContext { let path = path.as_ref(); // create directory to contain the Parquet files (one per partition) let fs_path = Path::new(path); + let runtime = self.runtime_env(); match fs::create_dir(fs_path) { Ok(()) => { let mut tasks = vec![]; @@ -769,7 +801,7 @@ impl ExecutionContext { let path = fs_path.join(&filename); let mut file = fs::File::create(path)?; - let stream = plan.execute(i).await?; + let stream = plan.execute(i, runtime.clone()).await?; let handle: JoinHandle> = task::spawn(async move { let parquet_schema = parquet::write::to_parquet_schema(&schema)?; @@ -777,13 +809,13 @@ impl ExecutionContext { let row_groups = stream.map(|batch: ArrowResult| { // map each record batch to a row group - batch.map(|batch| { + let r = batch.map(|batch| { let batch_cols = batch.columns().to_vec(); // column chunk in row group let pages = batch_cols .into_iter() - .zip(a.columns().to_vec().into_iter()) + .zip(a.columns().iter().cloned()) .map(move |(array, descriptor)| { parquet::write::array_to_pages( array.as_ref(), @@ -809,13 +841,14 @@ impl ExecutionContext { }) }); parquet::write::DynIter::new(pages) - }) + }); + async { r } }); Ok(parquet::write::stream::write_stream( &mut file, row_groups, - schema.as_ref(), + schema.as_ref().clone(), parquet_schema, options, None, @@ -913,8 +946,6 @@ impl QueryPlanner for DefaultQueryPlanner { pub struct ExecutionConfig { /// Number of partitions for query execution. Increasing partitions can increase concurrency. 
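
`collect`, `collect_partitioned` and `ExecutionPlan::execute` now take an `Arc<RuntimeEnv>`, which the context exposes through the new `runtime_env()` accessor. A sketch of the end-to-end flow (assumes a table named `test` is already registered):

```rust
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::physical_plan::collect;

/// Plan and run a query, passing the context's RuntimeEnv to `collect`.
async fn run_query(ctx: &mut ExecutionContext) -> Result<()> {
    let logical_plan = ctx.create_logical_plan("SELECT * FROM test")?;
    let logical_plan = ctx.optimize(&logical_plan)?;
    let physical_plan = ctx.create_physical_plan(&logical_plan).await?;
    let batches = collect(physical_plan, ctx.runtime_env()).await?;
    println!("got {} batches", batches.len());
    Ok(())
}
```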
pub target_partitions: usize, - /// Default batch size when reading data sources - pub batch_size: usize, /// Responsible for optimizing a logical plan optimizers: Vec>, /// Responsible for optimizing a physical execution plan @@ -941,13 +972,14 @@ pub struct ExecutionConfig { pub repartition_windows: bool, /// Should Datafusion parquet reader using the predicate to prune data parquet_pruning: bool, + /// Runtime configurations such as memory threshold and local disk for spill + pub runtime: RuntimeConfig, } impl Default for ExecutionConfig { fn default() -> Self { Self { target_partitions: num_cpus::get(), - batch_size: 8192, optimizers: vec![ // Simplify expressions first to maximize the chance // of applying other optimizations @@ -975,6 +1007,7 @@ impl Default for ExecutionConfig { repartition_aggregations: true, repartition_windows: true, parquet_pruning: true, + runtime: RuntimeConfig::default(), } } } @@ -997,7 +1030,7 @@ impl ExecutionConfig { pub fn with_batch_size(mut self, n: usize) -> Self { // batch size must be greater than zero assert!(n > 0); - self.batch_size = n; + self.runtime.batch_size = n; self } @@ -1092,6 +1125,54 @@ impl ExecutionConfig { self.parquet_pruning = enabled; self } + + /// Customize runtime config + pub fn with_runtime_config(mut self, config: RuntimeConfig) -> Self { + self.runtime = config; + self + } + + /// Use an an existing [MemoryManager] + pub fn with_existing_memory_manager(mut self, existing: Arc) -> Self { + self.runtime = self + .runtime + .with_memory_manager(MemoryManagerConfig::new_existing(existing)); + self + } + + /// Specify the total memory to use while running the DataFusion + /// plan to `max_memory * memory_fraction` in bytes. + /// + /// Note DataFusion does not yet respect this limit in all cases. + pub fn with_memory_limit( + mut self, + max_memory: usize, + memory_fraction: f64, + ) -> Result { + self.runtime = + self.runtime + .with_memory_manager(MemoryManagerConfig::try_new_limit( + max_memory, + memory_fraction, + )?); + Ok(self) + } + + /// Use an an existing [DiskManager] + pub fn with_existing_disk_manager(mut self, existing: Arc) -> Self { + self.runtime = self + .runtime + .with_disk_manager(DiskManagerConfig::new_existing(existing)); + self + } + + /// Use the specified path to create any needed temporary files + pub fn with_temp_file_path(mut self, path: impl Into) -> Self { + self.runtime = self + .runtime + .with_disk_manager(DiskManagerConfig::new_specified(vec![path.into()])); + self + } } /// Holds per-execution properties and data (such as starting timestamps, etc). 
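
With `batch_size` gone from `ExecutionConfig`, the runtime-related knobs are grouped under `RuntimeConfig`, and the new builders above let callers cap execution memory or pick a spill location. A sketch using those builders (the 2 GB budget, 0.7 fraction and path are illustrative only):

```rust
use datafusion::error::Result;
use datafusion::execution::context::{ExecutionConfig, ExecutionContext};

/// Build a context with an explicit memory budget and spill directory.
fn context_with_limits() -> Result<ExecutionContext> {
    let config = ExecutionConfig::new()
        // allow at most 2 GB * 0.7 of tracked execution memory
        .with_memory_limit(2 * 1024 * 1024 * 1024, 0.7)?
        // spill files go here instead of an OS-chosen temp dir
        .with_temp_file_path("/tmp/datafusion-spill");
    Ok(ExecutionContext::with_config(config))
}
```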
@@ -1103,6 +1184,27 @@ pub struct ExecutionProps { pub(crate) query_execution_start_time: DateTime, } +impl Default for ExecutionProps { + fn default() -> Self { + Self::new() + } +} + +impl ExecutionProps { + /// Creates a new execution props + pub fn new() -> Self { + ExecutionProps { + query_execution_start_time: chrono::Utc::now(), + } + } + + /// Marks the execution of query started timestamp + pub fn start_execution(&mut self) -> &Self { + self.query_execution_start_time = chrono::Utc::now(); + &*self + } +} + /// Execution context for registering data sources and executing queries #[derive(Clone)] pub struct ExecutionContextState { @@ -1120,20 +1222,13 @@ pub struct ExecutionContextState { pub execution_props: ExecutionProps, /// Object Store that are registered with the context pub object_store_registry: Arc, + /// Runtime environment + pub runtime_env: Arc, } -impl ExecutionProps { - /// Creates a new execution props - pub fn new() -> Self { - ExecutionProps { - query_execution_start_time: chrono::Utc::now(), - } - } - - /// Marks the execution of query started timestamp - pub fn start_execution(&mut self) -> &Self { - self.query_execution_start_time = chrono::Utc::now(); - &*self +impl Default for ExecutionContextState { + fn default() -> Self { + Self::new() } } @@ -1148,6 +1243,7 @@ impl ExecutionContextState { config: ExecutionConfig::new(), execution_props: ExecutionProps::new(), object_store_registry: Arc::new(ObjectStoreRegistry::new()), + runtime_env: Arc::new(RuntimeEnv::default()), } } @@ -1231,11 +1327,14 @@ impl FunctionRegistry for ExecutionContextState { #[cfg(test)] mod tests { use super::*; + use crate::execution::context::QueryPlanner; + use crate::field_util::{FieldExt, SchemaExt}; use crate::logical_plan::plan::Projection; use crate::logical_plan::TableScan; use crate::logical_plan::{binary_expr, lit, Operator}; use crate::physical_plan::functions::{make_scalar_function, Volatility}; use crate::physical_plan::{collect, collect_partitioned}; + use crate::record_batch::RecordBatch; use crate::test; use crate::variable::VarType; use crate::{ @@ -1248,14 +1347,15 @@ mod tests { physical_plan::expressions::AvgAccumulator, }; use arrow::array::*; + use arrow::chunk::Chunk; use arrow::compute::arithmetics::basic::add; use arrow::datatypes::*; use arrow::io::parquet::write::{ to_parquet_schema, write_file, Compression, Encoding, RowGroupIterator, Version, WriteOptions, }; - use arrow::record_batch::RecordBatch; use async_trait::async_trait; + use std::collections::BTreeMap; use std::fs::File; use std::sync::Weak; use std::thread::{self, JoinHandle}; @@ -1263,6 +1363,40 @@ mod tests { use tempfile::TempDir; use test::*; + #[tokio::test] + async fn shared_memory_and_disk_manager() { + // Demonstrate the ability to share DiskManager and + // MemoryManager between two different executions. 
+ let ctx1 = ExecutionContext::new(); + + // configure with same memory / disk manager + let memory_manager = ctx1.runtime_env().memory_manager.clone(); + let disk_manager = ctx1.runtime_env().disk_manager.clone(); + let config = ExecutionConfig::new() + .with_existing_memory_manager(memory_manager.clone()) + .with_existing_disk_manager(disk_manager.clone()); + + let ctx2 = ExecutionContext::with_config(config); + + assert!(std::ptr::eq( + Arc::as_ptr(&memory_manager), + Arc::as_ptr(&ctx1.runtime_env().memory_manager) + )); + assert!(std::ptr::eq( + Arc::as_ptr(&memory_manager), + Arc::as_ptr(&ctx2.runtime_env().memory_manager) + )); + + assert!(std::ptr::eq( + Arc::as_ptr(&disk_manager), + Arc::as_ptr(&ctx1.runtime_env().disk_manager) + )); + assert!(std::ptr::eq( + Arc::as_ptr(&disk_manager), + Arc::as_ptr(&ctx2.runtime_env().disk_manager) + )); + } + #[test] fn optimize_explain() { let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); @@ -1413,7 +1547,8 @@ mod tests { let physical_plan = ctx.create_physical_plan(&logical_plan).await?; - let results = collect_partitioned(physical_plan).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect_partitioned(physical_plan, runtime).await?; // note that the order of partitions is not deterministic let mut num_rows = 0; @@ -1461,6 +1596,7 @@ mod tests { let tmp_dir = TempDir::new()?; let partition_count = 4; let ctx = create_ctx(&tmp_dir, partition_count).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); let table = ctx.table("test")?; let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) @@ -1490,9 +1626,9 @@ mod tests { let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); + assert_eq!("c2", physical_plan.schema().field(0).name()); - let batches = collect(physical_plan).await?; + let batches = collect(physical_plan, runtime).await?; assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::()); Ok(()) @@ -1566,9 +1702,10 @@ mod tests { let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("b", physical_plan.schema().field(0).name().as_str()); + assert_eq!("b", physical_plan.schema().field(0).name()); - let batches = collect(physical_plan).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let batches = collect(physical_plan, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(4, batches[0].num_rows()); @@ -2815,11 +2952,8 @@ mod tests { let plan = ctx.optimize(&plan)?; let physical_plan = ctx.create_physical_plan(&Arc::new(plan)).await?; - assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); - assert_eq!( - "total_salary", - physical_plan.schema().field(1).name().as_str() - ); + assert_eq!("c1", physical_plan.schema().field(0).name()); + assert_eq!("total_salary", physical_plan.schema().field(1).name()); Ok(()) } @@ -3346,7 +3480,8 @@ mod tests { let plan = ctx.optimize(&plan)?; let plan = ctx.create_physical_plan(&plan).await?; - let result = collect(plan).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let result = collect(plan, runtime).await?; let expected = vec![ "+-----+-----+-----------------+", @@ -3633,7 +3768,6 @@ mod tests { async fn scan( &self, _: &Option>, - _: usize, _: &[Expr], _: Option, ) -> Result> { @@ 
-4139,7 +4273,7 @@ mod tests { let table_dir = tmp_dir.path().join("parquet_test"); let table_path = Path::new(&table_dir); - let mut non_empty_metadata: HashMap = HashMap::new(); + let mut non_empty_metadata: BTreeMap = BTreeMap::new(); non_empty_metadata.insert("testing".to_string(), "metadata".to_string()); let fields = vec![ @@ -4147,7 +4281,9 @@ mod tests { Field::new("name", DataType::Utf8, true), ]; let schemas = vec![ - Arc::new(Schema::new_from(fields.clone(), non_empty_metadata.clone())), + Arc::new( + Schema::new(fields.clone()).with_metadata(non_empty_metadata.clone()), + ), Arc::new(Schema::new(fields.clone())), ]; @@ -4166,12 +4302,9 @@ mod tests { // create mock record batch let ids = Arc::new(Int32Array::from_slice(vec![i as i32])); let names = Arc::new(Utf8Array::::from_slice(vec!["test"])); - let rec_batch = - RecordBatch::try_new(schema.clone(), vec![ids, names]).unwrap(); - let schema_ref = schema.as_ref(); let parquet_schema = to_parquet_schema(schema_ref).unwrap(); - let iter = vec![Ok(rec_batch)]; + let iter = vec![Ok(Chunk::new(vec![ids as ArrayRef, names as ArrayRef]))]; let row_groups = RowGroupIterator::try_new( iter.into_iter(), schema_ref, @@ -4269,7 +4402,7 @@ mod tests { let logical_plan = ctx.create_logical_plan(sql)?; let logical_plan = ctx.optimize(&logical_plan)?; let physical_plan = ctx.create_physical_plan(&logical_plan).await?; - ctx.write_csv(physical_plan, out_dir.to_string()).await + ctx.write_csv(physical_plan, out_dir).await } /// Execute SQL and write results to partitioned parquet files @@ -4289,8 +4422,7 @@ mod tests { version: parquet::write::Version::V1, }); - ctx.write_parquet(physical_plan, out_dir.to_string(), options) - .await + ctx.write_parquet(physical_plan, out_dir, options).await } /// Generate CSV partitions within the supplied directory diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 4cf427d1be2b..f097ca9bf3a3 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -25,11 +25,11 @@ use crate::logical_plan::{ col, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, }; +use crate::record_batch::RecordBatch; use crate::{ dataframe::*, physical_plan::{collect, collect_partitioned}, }; -use arrow::record_batch::RecordBatch; use crate::physical_plan::{ execute_stream, execute_stream_partitioned, ExecutionPlan, SendableRecordBatchStream, @@ -161,7 +161,8 @@ impl DataFrame for DataFrameImpl { /// execute it, collecting all resulting batches into memory async fn collect(&self) -> Result> { let plan = self.create_physical_plan().await?; - Ok(collect(plan).await?) + let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + Ok(collect(plan, runtime).await?) } /// Print results. @@ -182,7 +183,8 @@ impl DataFrame for DataFrameImpl { /// execute it, returning a stream over a single partition async fn execute_stream(&self) -> Result { let plan = self.create_physical_plan().await?; - execute_stream(plan).await + let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + execute_stream(plan, runtime).await } /// Convert the logical plan represented by this DataFrame into a physical plan and @@ -190,14 +192,16 @@ impl DataFrame for DataFrameImpl { /// partitioning async fn collect_partitioned(&self) -> Result>> { let plan = self.create_physical_plan().await?; - Ok(collect_partitioned(plan).await?) 
+ let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + Ok(collect_partitioned(plan, runtime).await?) } /// Convert the logical plan represented by this DataFrame into a physical plan and /// execute it, returning a stream for each partition async fn execute_stream_partitioned(&self) -> Result> { let plan = self.create_physical_plan().await?; - Ok(execute_stream_partitioned(plan).await?) + let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + Ok(execute_stream_partitioned(plan, runtime).await?) } /// Returns the schema from the logical plan diff --git a/datafusion/src/execution/disk_manager.rs b/datafusion/src/execution/disk_manager.rs new file mode 100644 index 000000000000..79b70f1f8b9a --- /dev/null +++ b/datafusion/src/execution/disk_manager.rs @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Manages files generated during query execution, files are +//! hashed among the directories listed in RuntimeConfig::local_dirs. + +use crate::error::{DataFusionError, Result}; +use log::{debug, info}; +use rand::{thread_rng, Rng}; +use std::path::PathBuf; +use std::sync::Arc; +use tempfile::{Builder, NamedTempFile, TempDir}; + +/// Configuration for temporary disk access +#[derive(Debug, Clone)] +pub enum DiskManagerConfig { + /// Use the provided [DiskManager] instance + Existing(Arc), + + /// Create a new [DiskManager] that creates temporary files within + /// a temporary directory chosen by the OS + NewOs, + + /// Create a new [DiskManager] that creates temporary files within + /// the specified directories + NewSpecified(Vec), +} + +impl Default for DiskManagerConfig { + fn default() -> Self { + Self::NewOs + } +} + +impl DiskManagerConfig { + /// Create temporary files in a temporary directory chosen by the OS + pub fn new() -> Self { + Self::default() + } + + /// Create temporary files using the provided disk manager + pub fn new_existing(existing: Arc) -> Self { + Self::Existing(existing) + } + + /// Create temporary files in the specified directories + pub fn new_specified(paths: Vec) -> Self { + Self::NewSpecified(paths) + } +} + +/// Manages files generated during query execution, e.g. spill files generated +/// while processing dataset larger than available memory. 
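
DataFrame callers are unaffected by the runtime plumbing: as the dataframe_impl.rs hunks above show, `collect`, `execute_stream` and friends fetch the `RuntimeEnv` from the shared context state internally. For example (assumes a table named `example` is registered):

```rust
use datafusion::dataframe::DataFrame;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;

/// The public DataFrame API is unchanged; the RuntimeEnv is threaded through
/// behind the scenes.
async fn dataframe_still_works(ctx: &mut ExecutionContext) -> Result<()> {
    let df = ctx.sql("SELECT count(*) FROM example").await?;
    let _batches = df.collect().await?;
    Ok(())
}
```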
+#[derive(Debug)] +pub struct DiskManager { + local_dirs: Vec, +} + +impl DiskManager { + /// Create a DiskManager given the configuration + pub fn try_new(config: DiskManagerConfig) -> Result> { + match config { + DiskManagerConfig::Existing(manager) => Ok(manager), + DiskManagerConfig::NewOs => { + let tempdir = tempfile::tempdir().map_err(DataFusionError::IoError)?; + + debug!( + "Created directory {:?} as DataFusion working directory", + tempdir + ); + Ok(Arc::new(Self { + local_dirs: vec![tempdir], + })) + } + DiskManagerConfig::NewSpecified(conf_dirs) => { + let local_dirs = create_local_dirs(conf_dirs)?; + info!( + "Created local dirs {:?} as DataFusion working directory", + local_dirs + ); + Ok(Arc::new(Self { local_dirs })) + } + } + } + + /// Return a temporary file from a randomized choice in the configured locations + pub fn create_tmp_file(&self) -> Result { + create_tmp_file(&self.local_dirs) + } +} + +/// Setup local dirs by creating one new dir in each of the given dirs +fn create_local_dirs(local_dirs: Vec) -> Result> { + local_dirs + .iter() + .map(|root| { + Builder::new() + .prefix("datafusion-") + .tempdir_in(root) + .map_err(DataFusionError::IoError) + }) + .collect() +} + +fn create_tmp_file(local_dirs: &[TempDir]) -> Result { + let dir_index = thread_rng().gen_range(0..local_dirs.len()); + let dir = local_dirs.get(dir_index).ok_or_else(|| { + DataFusionError::Internal("No directories available to DiskManager".into()) + })?; + + Builder::new() + .tempfile_in(dir) + .map_err(DataFusionError::IoError) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use tempfile::TempDir; + + #[test] + fn file_in_right_dir() -> Result<()> { + let local_dir1 = TempDir::new()?; + let local_dir2 = TempDir::new()?; + let local_dir3 = TempDir::new()?; + let local_dirs = vec![local_dir1.path(), local_dir2.path(), local_dir3.path()]; + let config = DiskManagerConfig::new_specified( + local_dirs.iter().map(|p| p.into()).collect(), + ); + + let dm = DiskManager::try_new(config)?; + let actual = dm.create_tmp_file()?; + + // the file should be in one of the specified local directories + let found = local_dirs.iter().any(|p| { + actual + .path() + .ancestors() + .any(|candidate_path| *p == candidate_path) + }); + + assert!( + found, + "Can't find {:?} in specified local dirs: {:?}", + actual, local_dirs + ); + + Ok(()) + } +} diff --git a/datafusion/src/execution/memory_manager.rs b/datafusion/src/execution/memory_manager.rs new file mode 100644 index 000000000000..32f79750a70d --- /dev/null +++ b/datafusion/src/execution/memory_manager.rs @@ -0,0 +1,601 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Manages all available memory during query execution + +use crate::error::{DataFusionError, Result}; +use async_trait::async_trait; +use hashbrown::HashMap; +use log::info; +use std::fmt; +use std::fmt::{Debug, Display, Formatter}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Condvar, Mutex, Weak}; + +static CONSUMER_ID: AtomicUsize = AtomicUsize::new(0); + +#[derive(Debug, Clone)] +/// Configuration information for memory management +pub enum MemoryManagerConfig { + /// Use the existing [MemoryManager] + Existing(Arc), + + /// Create a new [MemoryManager] that will use up to some + /// fraction of total system memory. + New { + /// Max execution memory allowed for DataFusion. Defaults to + /// `usize::MAX`, which will not attempt to limit the memory + /// used during plan execution. + max_memory: usize, + + /// The fraction of `max_memory` that the memory manager will + /// use for execution. + /// + /// The purpose of this config is to set aside memory for + /// untracked data structures, and imprecise size estimation + /// during memory acquisition. Defaults to 0.7 + memory_fraction: f64, + }, +} + +impl Default for MemoryManagerConfig { + fn default() -> Self { + Self::New { + max_memory: usize::MAX, + memory_fraction: 0.7, + } + } +} + +impl MemoryManagerConfig { + /// Create a new memory [MemoryManager] with no limit on the + /// memory used + pub fn new() -> Self { + Default::default() + } + + /// Create a configuration based on an existing [MemoryManager] + pub fn new_existing(existing: Arc) -> Self { + Self::Existing(existing) + } + + /// Create a new [MemoryManager] with a `max_memory` and `fraction` + pub fn try_new_limit(max_memory: usize, memory_fraction: f64) -> Result { + if max_memory == 0 { + return Err(DataFusionError::Plan(format!( + "invalid max_memory. Expected greater than 0, got {}", + max_memory + ))); + } + if !(memory_fraction > 0f64 && memory_fraction <= 1f64) { + return Err(DataFusionError::Plan(format!( + "invalid fraction. Expected greater than 0 and less than 1.0, got {}", + memory_fraction + ))); + } + + Ok(Self::New { + max_memory, + memory_fraction, + }) + } + + /// return the maximum size of the memory, in bytes, this config will allow + fn pool_size(&self) -> usize { + match self { + MemoryManagerConfig::Existing(existing) => existing.pool_size, + MemoryManagerConfig::New { + max_memory, + memory_fraction, + } => (*max_memory as f64 * *memory_fraction) as usize, + } + } +} + +fn next_id() -> usize { + CONSUMER_ID.fetch_add(1, Ordering::SeqCst) +} + +/// Type of the memory consumer +pub enum ConsumerType { + /// consumers that can grow its memory usage by requesting more from the memory manager or + /// shrinks its memory usage when we can no more assign available memory to it. + /// Examples are spillable sorter, spillable hashmap, etc. + Requesting, + /// consumers that are not spillable, counting in for only tracking purpose. 
+ Tracking, +} + +#[derive(Clone, Debug, Hash, Eq, PartialEq)] +/// Id that uniquely identifies a Memory Consumer +pub struct MemoryConsumerId { + /// partition the consumer belongs to + pub partition_id: usize, + /// unique id + pub id: usize, +} + +impl MemoryConsumerId { + /// Auto incremented new Id + pub fn new(partition_id: usize) -> Self { + let id = next_id(); + Self { partition_id, id } + } +} + +impl Display for MemoryConsumerId { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}:{}", self.partition_id, self.id) + } +} + +#[async_trait] +/// A memory consumer that either takes up memory (of type `ConsumerType::Tracking`) +/// or grows/shrinks memory usage based on available memory (of type `ConsumerType::Requesting`). +pub trait MemoryConsumer: Send + Sync { + /// Display name of the consumer + fn name(&self) -> String; + + /// Unique id of the consumer + fn id(&self) -> &MemoryConsumerId; + + /// Ptr to MemoryManager + fn memory_manager(&self) -> Arc; + + /// Partition that the consumer belongs to + fn partition_id(&self) -> usize { + self.id().partition_id + } + + /// Type of the consumer + fn type_(&self) -> &ConsumerType; + + /// Grow memory by `required` to buffer more data in memory, + /// this may trigger spill before grow when the memory threshold is + /// reached for this consumer. + async fn try_grow(&self, required: usize) -> Result<()> { + let current = self.mem_used(); + info!( + "trying to acquire {} whiling holding {} from consumer {}", + human_readable_size(required), + human_readable_size(current), + self.id(), + ); + + let can_grow_directly = self + .memory_manager() + .can_grow_directly(required, current) + .await; + if !can_grow_directly { + info!( + "Failed to grow memory of {} directly from consumer {}, spilling first ...", + human_readable_size(required), + self.id() + ); + let freed = self.spill().await?; + self.memory_manager() + .record_free_then_acquire(freed, required); + } + Ok(()) + } + + /// Spill in-memory buffers to disk, free memory, return the previous used + async fn spill(&self) -> Result; + + /// Current memory used by this consumer + fn mem_used(&self) -> usize; +} + +impl Debug for dyn MemoryConsumer { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "{}[{}]: {}", + self.name(), + self.id(), + human_readable_size(self.mem_used()) + ) + } +} + +impl Display for dyn MemoryConsumer { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}[{}]", self.name(), self.id(),) + } +} + +/* +The memory management architecture is the following: + +1. User designates max execution memory by setting RuntimeConfig.max_memory and RuntimeConfig.memory_fraction (float64 between 0..1). + The actual max memory DataFusion could use `pool_size = max_memory * memory_fraction`. +2. The entities that take up memory during its execution are called 'Memory Consumers'. Operators or others are encouraged to + register themselves to the memory manager and report its usage through `mem_used()`. +3. There are two kinds of consumers: + - 'Requesting' consumers that would acquire memory during its execution and release memory through `spill` if no more memory is available. + - 'Tracking' consumers that exist for reporting purposes to provide a more accurate memory usage estimation for memory consumers. +4. Requesting and tracking consumers share the pool. Each controlling consumer could acquire a maximum of + (pool_size - all_tracking_used) / active_num_controlling_consumers. 
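+
+   For example (illustrative numbers only): with pool_size = 1000 bytes, tracking consumers
+   reporting 200 bytes in total, and 4 active requesting consumers, `can_grow_directly` caps each
+   requester at (1000 - 200) / 4 = 200 bytes; a `try_grow` that would exceed this cap triggers a
+   spill first.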
+ + Memory Space for the DataFusion Lib / Process of `pool_size` + ┌──────────────────────────────────────────────z─────────────────────────────┐ + │ z │ + │ z │ + │ Requesting z Tracking │ + │ Memory Consumers z Memory Consumers │ + │ z │ + │ z │ + └──────────────────────────────────────────────z─────────────────────────────┘ +*/ + +/// Manage memory usage during physical plan execution +#[derive(Debug)] +pub struct MemoryManager { + requesters: Arc>>>, + trackers: Arc>>>, + pool_size: usize, + requesters_total: Arc>, + cv: Condvar, +} + +impl MemoryManager { + /// Create new memory manager based on the configuration + #[allow(clippy::mutex_atomic)] + pub fn new(config: MemoryManagerConfig) -> Arc { + let pool_size = config.pool_size(); + + match config { + MemoryManagerConfig::Existing(manager) => manager, + MemoryManagerConfig::New { .. } => { + info!( + "Creating memory manager with initial size {}", + human_readable_size(pool_size) + ); + + Arc::new(Self { + requesters: Arc::new(Mutex::new(HashMap::new())), + trackers: Arc::new(Mutex::new(HashMap::new())), + pool_size, + requesters_total: Arc::new(Mutex::new(0)), + cv: Condvar::new(), + }) + } + } + } + + fn get_tracker_total(&self) -> usize { + let trackers = self.trackers.lock().unwrap(); + if trackers.len() > 0 { + trackers.values().fold(0usize, |acc, y| match y.upgrade() { + None => acc, + Some(t) => acc + t.mem_used(), + }) + } else { + 0 + } + } + + /// Register a new memory consumer for memory usage tracking + pub(crate) fn register_consumer(&self, consumer: &Arc) { + let id = consumer.id().clone(); + match consumer.type_() { + ConsumerType::Requesting => { + let mut requesters = self.requesters.lock().unwrap(); + requesters.insert(id, Arc::downgrade(consumer)); + } + ConsumerType::Tracking => { + let mut trackers = self.trackers.lock().unwrap(); + trackers.insert(id, Arc::downgrade(consumer)); + } + } + } + + fn max_mem_for_requesters(&self) -> usize { + let trk_total = self.get_tracker_total(); + self.pool_size - trk_total + } + + /// Grow memory attempt from a consumer, return if we could grant that much to it + async fn can_grow_directly(&self, required: usize, current: usize) -> bool { + let num_rqt = self.requesters.lock().unwrap().len(); + let mut rqt_current_used = self.requesters_total.lock().unwrap(); + let mut rqt_max = self.max_mem_for_requesters(); + + let granted; + loop { + let remaining = rqt_max - *rqt_current_used; + let max_per_rqt = rqt_max / num_rqt; + let min_per_rqt = max_per_rqt / 2; + + if required + current >= max_per_rqt { + granted = false; + break; + } + + if remaining >= required { + granted = true; + *rqt_current_used += required; + break; + } else if current < min_per_rqt { + // if we cannot acquire at lease 1/2n memory, just wait for others + // to spill instead spill self frequently with limited total mem + rqt_current_used = self.cv.wait(rqt_current_used).unwrap(); + } else { + granted = false; + break; + } + + rqt_max = self.max_mem_for_requesters(); + } + + granted + } + + fn record_free_then_acquire(&self, freed: usize, acquired: usize) { + let mut requesters_total = self.requesters_total.lock().unwrap(); + *requesters_total -= freed; + *requesters_total += acquired; + self.cv.notify_all() + } + + /// Drop a memory consumer from memory usage tracking + pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId) { + // find in requesters first + { + let mut requesters = self.requesters.lock().unwrap(); + if requesters.remove(id).is_some() { + return; + } + } + let mut trackers = 
self.trackers.lock().unwrap(); + trackers.remove(id); + } +} + +impl Display for MemoryManager { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let requesters = + self.requesters + .lock() + .unwrap() + .values() + .fold(vec![], |mut acc, consumer| match consumer.upgrade() { + None => acc, + Some(c) => { + acc.push(format!("{}", c)); + acc + } + }); + let tracker_mem = self.get_tracker_total(); + write!(f, + "MemoryManager usage statistics: total {}, tracker used {}, total {} requesters detail: \n {},", + human_readable_size(self.pool_size), + human_readable_size(tracker_mem), + &requesters.len(), + requesters.join("\n")) + } +} + +const TB: u64 = 1 << 40; +const GB: u64 = 1 << 30; +const MB: u64 = 1 << 20; +const KB: u64 = 1 << 10; + +fn human_readable_size(size: usize) -> String { + let size = size as u64; + let (value, unit) = { + if size >= 2 * TB { + (size as f64 / TB as f64, "TB") + } else if size >= 2 * GB { + (size as f64 / GB as f64, "GB") + } else if size >= 2 * MB { + (size as f64 / MB as f64, "MB") + } else if size >= 2 * KB { + (size as f64 / KB as f64, "KB") + } else { + (size as f64, "B") + } + }; + format!("{:.1} {}", value, unit) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use async_trait::async_trait; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + struct DummyRequester { + id: MemoryConsumerId, + runtime: Arc, + spills: AtomicUsize, + mem_used: AtomicUsize, + } + + impl DummyRequester { + fn new(partition: usize, runtime: Arc) -> Self { + Self { + id: MemoryConsumerId::new(partition), + runtime, + spills: AtomicUsize::new(0), + mem_used: AtomicUsize::new(0), + } + } + + async fn do_with_mem(&self, grow: usize) -> Result<()> { + self.try_grow(grow).await?; + self.mem_used.fetch_add(grow, Ordering::SeqCst); + Ok(()) + } + + fn get_spills(&self) -> usize { + self.spills.load(Ordering::SeqCst) + } + } + + #[async_trait] + impl MemoryConsumer for DummyRequester { + fn name(&self) -> String { + "dummy".to_owned() + } + + fn id(&self) -> &MemoryConsumerId { + &self.id + } + + fn memory_manager(&self) -> Arc { + self.runtime.memory_manager.clone() + } + + fn type_(&self) -> &ConsumerType { + &ConsumerType::Requesting + } + + async fn spill(&self) -> Result { + self.spills.fetch_add(1, Ordering::SeqCst); + let used = self.mem_used.swap(0, Ordering::SeqCst); + Ok(used) + } + + fn mem_used(&self) -> usize { + self.mem_used.load(Ordering::SeqCst) + } + } + + struct DummyTracker { + id: MemoryConsumerId, + runtime: Arc, + mem_used: usize, + } + + impl DummyTracker { + fn new(partition: usize, runtime: Arc, mem_used: usize) -> Self { + Self { + id: MemoryConsumerId::new(partition), + runtime, + mem_used, + } + } + } + + #[async_trait] + impl MemoryConsumer for DummyTracker { + fn name(&self) -> String { + "dummy".to_owned() + } + + fn id(&self) -> &MemoryConsumerId { + &self.id + } + + fn memory_manager(&self) -> Arc { + self.runtime.memory_manager.clone() + } + + fn type_(&self) -> &ConsumerType { + &ConsumerType::Tracking + } + + async fn spill(&self) -> Result { + Ok(0) + } + + fn mem_used(&self) -> usize { + self.mem_used + } + } + + #[tokio::test] + async fn basic_functionalities() { + let config = RuntimeConfig::new() + .with_memory_manager(MemoryManagerConfig::try_new_limit(100, 1.0).unwrap()); + let runtime = Arc::new(RuntimeEnv::new(config).unwrap()); + + let tracker1 = Arc::new(DummyTracker::new(0, runtime.clone(), 5)); + 
runtime.register_consumer(&(tracker1.clone() as Arc)); + assert_eq!(runtime.memory_manager.get_tracker_total(), 5); + + let tracker2 = Arc::new(DummyTracker::new(0, runtime.clone(), 10)); + runtime.register_consumer(&(tracker2.clone() as Arc)); + assert_eq!(runtime.memory_manager.get_tracker_total(), 15); + + let tracker3 = Arc::new(DummyTracker::new(0, runtime.clone(), 15)); + runtime.register_consumer(&(tracker3.clone() as Arc)); + assert_eq!(runtime.memory_manager.get_tracker_total(), 30); + + runtime.drop_consumer(tracker2.id()); + assert_eq!(runtime.memory_manager.get_tracker_total(), 20); + + let requester1 = Arc::new(DummyRequester::new(0, runtime.clone())); + runtime.register_consumer(&(requester1.clone() as Arc)); + + // first requester entered, should be able to use any of the remaining 80 + requester1.do_with_mem(40).await.unwrap(); + requester1.do_with_mem(10).await.unwrap(); + assert_eq!(requester1.get_spills(), 0); + assert_eq!(requester1.mem_used(), 50); + assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 50); + + let requester2 = Arc::new(DummyRequester::new(0, runtime.clone())); + runtime.register_consumer(&(requester2.clone() as Arc)); + + requester2.do_with_mem(20).await.unwrap(); + requester2.do_with_mem(30).await.unwrap(); + assert_eq!(requester2.get_spills(), 1); + assert_eq!(requester2.mem_used(), 30); + + requester1.do_with_mem(10).await.unwrap(); + assert_eq!(requester1.get_spills(), 1); + assert_eq!(requester1.mem_used(), 10); + + assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 40); + } + + #[tokio::test] + #[should_panic(expected = "invalid max_memory. Expected greater than 0, got 0")] + async fn test_try_new_with_limit_0() { + MemoryManagerConfig::try_new_limit(0, 1.0).unwrap(); + } + + #[tokio::test] + #[should_panic( + expected = "invalid fraction. Expected greater than 0 and less than 1.0, got -9.6" + )] + async fn test_try_new_with_limit_neg_fraction() { + MemoryManagerConfig::try_new_limit(100, -9.6).unwrap(); + } + + #[tokio::test] + #[should_panic( + expected = "invalid fraction. 
Expected greater than 0 and less than 1.0, got 9.6" + )] + async fn test_try_new_with_limit_too_large() { + MemoryManagerConfig::try_new_limit(100, 9.6).unwrap(); + } + + #[tokio::test] + async fn test_try_new_with_limit_pool_size() { + let config = MemoryManagerConfig::try_new_limit(100, 0.5).unwrap(); + assert_eq!(config.pool_size(), 50); + + let config = MemoryManagerConfig::try_new_limit(100000, 0.1).unwrap(); + assert_eq!(config.pool_size(), 10000); + } +} diff --git a/datafusion/src/execution/mod.rs b/datafusion/src/execution/mod.rs index e353a3160b8d..e3b42ae254a9 100644 --- a/datafusion/src/execution/mod.rs +++ b/datafusion/src/execution/mod.rs @@ -19,4 +19,10 @@ pub mod context; pub mod dataframe_impl; +pub(crate) mod disk_manager; +pub mod memory_manager; pub mod options; +pub mod runtime_env; + +pub use disk_manager::DiskManager; +pub use memory_manager::{MemoryConsumer, MemoryConsumerId, MemoryManager}; diff --git a/datafusion/src/execution/options.rs b/datafusion/src/execution/options.rs index c6b5ff646ea3..219e2fd89700 100644 --- a/datafusion/src/execution/options.rs +++ b/datafusion/src/execution/options.rs @@ -46,6 +46,12 @@ pub struct CsvReadOptions<'a> { pub file_extension: &'a str, } +impl<'a> Default for CsvReadOptions<'a> { + fn default() -> Self { + Self::new() + } +} + impl<'a> CsvReadOptions<'a> { /// Create a CSV read option with default presets pub fn new() -> Self { diff --git a/datafusion/src/execution/runtime_env.rs b/datafusion/src/execution/runtime_env.rs new file mode 100644 index 000000000000..cdcd1f71b4f5 --- /dev/null +++ b/datafusion/src/execution/runtime_env.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution runtime environment that tracks memory, disk and various configurations +//! that are used during physical plan execution. + +use crate::{ + error::Result, + execution::{ + disk_manager::{DiskManager, DiskManagerConfig}, + memory_manager::{ + MemoryConsumer, MemoryConsumerId, MemoryManager, MemoryManagerConfig, + }, + }, +}; + +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +#[derive(Clone)] +/// Execution runtime environment. This structure is passed to the +/// physical plans when they are run. 
+pub struct RuntimeEnv { + /// Default batch size while creating new batches + pub batch_size: usize, + /// Runtime memory management + pub memory_manager: Arc, + /// Manage temporary files during query execution + pub disk_manager: Arc, +} + +impl Debug for RuntimeEnv { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "RuntimeEnv") + } +} + +impl RuntimeEnv { + /// Create env based on configuration + pub fn new(config: RuntimeConfig) -> Result { + let RuntimeConfig { + batch_size, + memory_manager, + disk_manager, + } = config; + + Ok(Self { + batch_size, + memory_manager: MemoryManager::new(memory_manager), + disk_manager: DiskManager::try_new(disk_manager)?, + }) + } + + /// Get execution batch size based on config + pub fn batch_size(&self) -> usize { + self.batch_size + } + + /// Register the consumer to get it tracked + pub fn register_consumer(&self, memory_consumer: &Arc) { + self.memory_manager.register_consumer(memory_consumer); + } + + /// Drop the consumer from get tracked + pub fn drop_consumer(&self, id: &MemoryConsumerId) { + self.memory_manager.drop_consumer(id) + } +} + +impl Default for RuntimeEnv { + fn default() -> Self { + RuntimeEnv::new(RuntimeConfig::new()).unwrap() + } +} + +#[derive(Clone)] +/// Execution runtime configuration +pub struct RuntimeConfig { + /// Default batch size while creating new batches, it's especially useful + /// for buffer-in-memory batches since creating tiny batches would results + /// in too much metadata memory consumption. + pub batch_size: usize, + /// DiskManager to manage temporary disk file usage + pub disk_manager: DiskManagerConfig, + /// MemoryManager to limit access to memory + pub memory_manager: MemoryManagerConfig, +} + +impl RuntimeConfig { + /// New with default values + pub fn new() -> Self { + Default::default() + } + + /// Customize batch size + pub fn with_batch_size(mut self, n: usize) -> Self { + // batch size must be greater than zero + assert!(n > 0); + self.batch_size = n; + self + } + + /// Customize disk manager + pub fn with_disk_manager(mut self, disk_manager: DiskManagerConfig) -> Self { + self.disk_manager = disk_manager; + self + } + + /// Customize memory manager + pub fn with_memory_manager(mut self, memory_manager: MemoryManagerConfig) -> Self { + self.memory_manager = memory_manager; + self + } +} + +impl Default for RuntimeConfig { + fn default() -> Self { + Self { + batch_size: 8192, + disk_manager: DiskManagerConfig::default(), + memory_manager: MemoryManagerConfig::default(), + } + } +} diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs index 301925227722..2dfccb73092d 100644 --- a/datafusion/src/field_util.rs +++ b/datafusion/src/field_util.rs @@ -18,8 +18,10 @@ //! Utility functions for complex field access use arrow::array::{ArrayRef, StructArray}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, Metadata, Schema}; +use arrow::error::ArrowError; use std::borrow::Borrow; +use std::collections::BTreeMap; use crate::error::{DataFusionError, Result}; use crate::scalar::ScalarValue; @@ -109,3 +111,380 @@ pub fn struct_array_from(pairs: Vec<(Field, ArrayRef)>) -> StructArray { let values = pairs.iter().map(|v| v.1.clone()).collect(); StructArray::from_data(DataType::Struct(fields), values, None) } + +/// Imitate arrow-rs Schema behavior by extending arrow2 Schema +pub trait SchemaExt { + /// Creates a new [`Schema`] from a sequence of [`Field`] values. 
+ /// + /// # Example + /// + /// ``` + /// use arrow::datatypes::{Field, DataType, Schema}; + /// use datafusion::field_util::SchemaExt; + /// let field_a = Field::new("a", DataType::Int64, false); + /// let field_b = Field::new("b", DataType::Boolean, false); + /// + /// let schema = Schema::new(vec![field_a, field_b]); + /// ``` + fn new(fields: Vec) -> Self; + + /// Creates a new [`Schema`] from a sequence of [`Field`] values and [`arrow::datatypes::Metadata`] + /// + /// # Example + /// + /// ``` + /// use std::collections::BTreeMap; + /// use arrow::datatypes::{Field, DataType, Schema}; + /// use datafusion::field_util::SchemaExt; + /// + /// let field_a = Field::new("a", DataType::Int64, false); + /// let field_b = Field::new("b", DataType::Boolean, false); + /// + /// let schema_metadata: BTreeMap = + /// vec![("baz".to_string(), "barf".to_string())] + /// .into_iter() + /// .collect(); + /// let schema = Schema::new_with_metadata(vec![field_a, field_b], schema_metadata); + /// ``` + fn new_with_metadata(fields: Vec, metadata: Metadata) -> Self; + + /// Creates an empty [`Schema`]. + fn empty() -> Self; + + /// Look up a column by name and return a immutable reference to the column along with + /// its index. + fn column_with_name(&self, name: &str) -> Option<(usize, &Field)>; + + /// Returns the first [`Field`] named `name`. + fn field_with_name(&self, name: &str) -> Result<&Field>; + + /// Find the index of the column with the given name. + fn index_of(&self, name: &str) -> Result; + + /// Returns the [`Field`] at position `i`. + /// # Panics + /// Panics iff `i` is larger than the number of fields in this [`Schema`]. + fn field(&self, index: usize) -> &Field; + + /// Returns all [`Field`]s in this schema. + fn fields(&self) -> &[Field]; + + /// Returns an immutable reference to the Map of custom metadata key-value pairs. + fn metadata(&self) -> &BTreeMap; + + /// Merge schema into self if it is compatible. Struct fields will be merged recursively. 
+ /// + /// Example: + /// + /// ``` + /// use arrow::datatypes::*; + /// use datafusion::field_util::SchemaExt; + /// + /// let merged = Schema::try_merge(vec![ + /// Schema::new(vec![ + /// Field::new("c1", DataType::Int64, false), + /// Field::new("c2", DataType::Utf8, false), + /// ]), + /// Schema::new(vec![ + /// Field::new("c1", DataType::Int64, true), + /// Field::new("c2", DataType::Utf8, false), + /// Field::new("c3", DataType::Utf8, false), + /// ]), + /// ]).unwrap(); + /// + /// assert_eq!( + /// merged, + /// Schema::new(vec![ + /// Field::new("c1", DataType::Int64, true), + /// Field::new("c2", DataType::Utf8, false), + /// Field::new("c3", DataType::Utf8, false), + /// ]), + /// ); + /// ``` + fn try_merge(schemas: impl IntoIterator) -> Result + where + Self: Sized; + + /// Return the field names + fn field_names(&self) -> Vec; + + /// Returns a new schema with only the specified columns in the new schema + /// This carries metadata from the parent schema over as well + fn project(&self, indices: &[usize]) -> Result; +} + +impl SchemaExt for Schema { + fn new(fields: Vec) -> Self { + Self::from(fields) + } + + fn new_with_metadata(fields: Vec, metadata: Metadata) -> Self { + Self::new(fields).with_metadata(metadata) + } + + fn empty() -> Self { + Self::from(vec![]) + } + + fn column_with_name(&self, name: &str) -> Option<(usize, &Field)> { + self.fields.iter().enumerate().find(|(_, f)| f.name == name) + } + + fn field_with_name(&self, name: &str) -> Result<&Field> { + Ok(&self.fields[self.index_of(name)?]) + } + + fn index_of(&self, name: &str) -> Result { + self.column_with_name(name).map(|(i, _f)| i).ok_or_else(|| { + DataFusionError::ArrowError(ArrowError::InvalidArgumentError(format!( + "Unable to get field named \"{}\". Valid fields: {:?}", + name, + self.field_names() + ))) + }) + } + + fn field(&self, index: usize) -> &Field { + &self.fields[index] + } + + #[inline] + fn fields(&self) -> &[Field] { + &self.fields + } + + #[inline] + fn metadata(&self) -> &BTreeMap { + &self.metadata + } + + fn try_merge(schemas: impl IntoIterator) -> Result { + schemas + .into_iter() + .try_fold(Self::empty(), |mut merged, schema| { + let Schema { metadata, fields } = schema; + for (key, value) in metadata.into_iter() { + // merge metadata + if let Some(old_val) = merged.metadata.get(&key) { + if old_val != &value { + return Err(DataFusionError::ArrowError( + ArrowError::InvalidArgumentError( + "Fail to merge schema due to conflicting metadata." + .to_string(), + ), + )); + } + } + merged.metadata.insert(key, value); + } + // merge fields + for field in fields.into_iter() { + let mut new_field = true; + for merged_field in &mut merged.fields { + if field.name() != merged_field.name() { + continue; + } + new_field = false; + merged_field.try_merge(&field)? 
+ } + // found a new field, add to field list + if new_field { + merged.fields.push(field); + } + } + Ok(merged) + }) + } + + fn field_names(&self) -> Vec { + self.fields.iter().map(|f| f.name.to_string()).collect() + } + + fn project(&self, indices: &[usize]) -> Result { + let new_fields = indices + .iter() + .map(|i| { + self.fields.get(*i).cloned().ok_or_else(|| { + DataFusionError::ArrowError(ArrowError::InvalidArgumentError( + format!( + "project index {} out of bounds, max field {}", + i, + self.fields().len() + ), + )) + }) + }) + .collect::>>()?; + Ok(Self::new_with_metadata(new_fields, self.metadata.clone())) + } +} + +/// Imitate arrow-rs Field behavior by extending arrow2 Field +pub trait FieldExt { + /// The field name + fn name(&self) -> &str; + + /// Whether the field is nullable + fn is_nullable(&self) -> bool; + + /// Returns the field metadata + fn metadata(&self) -> &BTreeMap; + + /// Merge field into self if it is compatible. Struct will be merged recursively. + /// NOTE: `self` may be updated to unexpected state in case of merge failure. + /// + /// Example: + /// + /// ``` + /// use arrow2::datatypes::*; + /// + /// let mut field = Field::new("c1", DataType::Int64, false); + /// assert!(field.try_merge(&Field::new("c1", DataType::Int64, true)).is_ok()); + /// assert!(field.is_nullable()); + /// ``` + fn try_merge(&mut self, from: &Field) -> Result<()>; + + /// Sets the `Field`'s optional custom metadata. + /// The metadata is set as `None` for empty map. + fn set_metadata(&mut self, metadata: Option>); +} + +impl FieldExt for Field { + #[inline] + fn name(&self) -> &str { + &self.name + } + + #[inline] + fn is_nullable(&self) -> bool { + self.is_nullable + } + + #[inline] + fn metadata(&self) -> &BTreeMap { + &self.metadata + } + + fn try_merge(&mut self, from: &Field) -> Result<()> { + // merge metadata + for (key, from_value) in from.metadata() { + if let Some(self_value) = self.metadata.get(key) { + if self_value != from_value { + return Err(DataFusionError::ArrowError(ArrowError::InvalidArgumentError(format!( + "Fail to merge field due to conflicting metadata data value for key {}", + key + )))); + } + } else { + self.metadata.insert(key.clone(), from_value.clone()); + } + } + + match &mut self.data_type { + DataType::Struct(nested_fields) => match &from.data_type { + DataType::Struct(from_nested_fields) => { + for from_field in from_nested_fields { + let mut is_new_field = true; + for self_field in nested_fields.iter_mut() { + if self_field.name != from_field.name { + continue; + } + is_new_field = false; + self_field.try_merge(from_field)?; + } + if is_new_field { + nested_fields.push(from_field.clone()); + } + } + } + _ => { + return Err(DataFusionError::ArrowError( + ArrowError::InvalidArgumentError( + "Fail to merge schema Field due to conflicting datatype" + .to_string(), + ), + )); + } + }, + DataType::Union(nested_fields, _, _) => match &from.data_type { + DataType::Union(from_nested_fields, _, _) => { + for from_field in from_nested_fields { + let mut is_new_field = true; + for self_field in nested_fields.iter_mut() { + if from_field == self_field { + is_new_field = false; + break; + } + } + if is_new_field { + nested_fields.push(from_field.clone()); + } + } + } + _ => { + return Err(DataFusionError::ArrowError( + ArrowError::InvalidArgumentError( + "Fail to merge schema Field due to conflicting datatype" + .to_string(), + ), + )); + } + }, + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 
+ | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Binary + | DataType::LargeBinary + | DataType::Interval(_) + | DataType::LargeList(_) + | DataType::List(_) + | DataType::Dictionary(_, _, _) + | DataType::FixedSizeList(_, _) + | DataType::FixedSizeBinary(_) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Extension(_, _, _) + | DataType::Map(_, _) + | DataType::Decimal(_, _) => { + if self.data_type != from.data_type { + return Err(DataFusionError::ArrowError( + ArrowError::InvalidArgumentError( + "Fail to merge schema Field due to conflicting datatype" + .to_string(), + ), + )); + } + } + } + if from.is_nullable { + self.is_nullable = from.is_nullable; + } + + Ok(()) + } + + #[inline] + fn set_metadata(&mut self, metadata: Option>) { + if let Some(v) = metadata { + if !v.is_empty() { + self.metadata = v; + } + } + } +} diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index dd735b7621db..6b839807f9db 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -15,15 +15,6 @@ // specific language governing permissions and limitations // under the License. #![warn(missing_docs, clippy::needless_borrow)] -// Clippy lints, some should be disabled incrementally -#![allow( - clippy::float_cmp, - clippy::from_over_into, - clippy::module_inception, - clippy::new_without_default, - clippy::type_complexity, - clippy::upper_case_acronyms -)] //! [DataFusion](https://github.com/apache/arrow-datafusion) //! is an extensible query execution framework that uses @@ -39,7 +30,7 @@ //! ```rust //! # use datafusion::prelude::*; //! # use datafusion::error::Result; -//! # use datafusion::arrow::record_batch::RecordBatch; +//! # use datafusion::record_batch::RecordBatch; //! //! # #[tokio::main] //! # async fn main() -> Result<()> { @@ -77,7 +68,7 @@ //! ``` //! # use datafusion::prelude::*; //! # use datafusion::error::Result; -//! # use datafusion::arrow::record_batch::RecordBatch; +//! # use datafusion::record_batch::RecordBatch; //! //! # #[tokio::main] //! 
# async fn main() -> Result<()> { @@ -233,10 +224,13 @@ pub mod arrow_print; mod arrow_temporal_util; pub mod field_util; +pub mod record_batch; #[cfg(feature = "pyarrow")] mod pyarrow; +#[cfg(test)] +mod cast; #[cfg(test)] pub mod test; pub mod test_util; diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index fc609390bcc0..549db89035eb 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -25,17 +25,16 @@ use crate::datasource::{ MemTable, TableProvider, }; use crate::error::{DataFusionError, Result}; +use crate::field_util::SchemaExt; use crate::logical_plan::plan::{ Aggregate, Analyze, EmptyRelation, Explain, Filter, Join, Projection, Sort, TableScan, ToStringifiedPlan, Union, Window, }; use crate::optimizer::utils; use crate::prelude::*; +use crate::record_batch::RecordBatch; use crate::scalar::ScalarValue; -use arrow::{ - datatypes::{DataType, Schema, SchemaRef}, - record_batch::RecordBatch, -}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use std::convert::TryFrom; use std::iter; use std::{ @@ -1013,14 +1012,13 @@ pub(crate) fn expand_wildcard( let columns_to_skip = using_columns .into_iter() // For each USING JOIN condition, only expand to one column in projection - .map(|cols| { + .flat_map(|cols| { let mut cols = cols.into_iter().collect::>(); // sort join columns to make sure we consistently keep the same // qualified column cols.sort(); cols.into_iter().skip(1) }) - .flatten() .collect::>(); if columns_to_skip.is_empty() { diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index e8698b8b4f34..b89b2399e67a 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -25,6 +25,7 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::logical_plan::Column; +use crate::field_util::{FieldExt, SchemaExt}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use std::fmt::{Display, Formatter}; @@ -70,8 +71,8 @@ impl DFSchema { // deterministic let mut qualified_names = qualified_names .iter() - .map(|(l, r)| (l.to_owned(), r.to_owned())) - .collect::>(); + .map(|(l, r)| (l.as_str(), r.to_owned())) + .collect::>(); qualified_names.sort_by(|a, b| { let a = format!("{}.{}", a.0, a.1); let b = format!("{}.{}", b.0, b.1); @@ -297,19 +298,16 @@ impl DFSchema { } } -impl Into for DFSchema { - /// Convert a schema into a DFSchema - fn into(self) -> Schema { +impl From for Schema { + /// Convert DFSchema into a Schema + fn from(df_schema: DFSchema) -> Self { Schema::new( - self.fields + df_schema + .fields .into_iter() .map(|f| { if f.qualifier().is_some() { - Field::new( - f.name().as_str(), - f.data_type().to_owned(), - f.is_nullable(), - ) + Field::new(f.name(), f.data_type().to_owned(), f.is_nullable()) } else { f.field } @@ -319,10 +317,10 @@ impl Into for DFSchema { } } -impl Into for &DFSchema { - /// Convert a schema into a DFSchema - fn into(self) -> Schema { - Schema::new(self.fields.iter().map(|f| f.field.clone()).collect()) +impl From<&DFSchema> for Schema { + /// Convert DFSchema reference into a Schema + fn from(df_schema: &DFSchema) -> Self { + Schema::new(df_schema.fields.iter().map(|f| f.field.clone()).collect()) } } @@ -340,9 +338,9 @@ impl TryFrom for DFSchema { } } -impl Into for DFSchema { - fn into(self) -> SchemaRef { - SchemaRef::new(self.into()) +impl From for SchemaRef { + fn from(df_schema: DFSchema) -> Self { + SchemaRef::new(df_schema.into()) } } @@ -388,7 
+386,7 @@ impl ToDFSchema for Vec { } impl Display for DFSchema { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { write!( f, "{}", @@ -441,7 +439,7 @@ impl DFField { } /// Returns an immutable reference to the `DFField`'s unqualified name - pub fn name(&self) -> &String { + pub fn name(&self) -> &str { self.field.name() } @@ -537,8 +535,8 @@ mod tests { let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let arrow_schema: Schema = schema.into(); let expected = - "[Field { name: \"c0\", data_type: Boolean, nullable: true, metadata: {} }, \ - Field { name: \"c1\", data_type: Boolean, nullable: true, metadata: {} }]"; + "[Field { name: \"c0\", data_type: Boolean, is_nullable: true, metadata: {} }, \ + Field { name: \"c1\", data_type: Boolean, is_nullable: true, metadata: {} }]"; assert_eq!(expected, format!("{:?}", arrow_schema.fields)); Ok(()) } diff --git a/datafusion/src/logical_plan/display.rs b/datafusion/src/logical_plan/display.rs index 8178ef4484b2..5f28eea9775c 100644 --- a/datafusion/src/logical_plan/display.rs +++ b/datafusion/src/logical_plan/display.rs @@ -17,6 +17,7 @@ //! This module provides logic for displaying LogicalPlans in various styles use super::{LogicalPlan, PlanVisitor}; +use crate::field_util::{FieldExt, SchemaExt}; use arrow::datatypes::Schema; use std::fmt; diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 5a55f398cdab..4d81472da9dc 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -23,7 +23,7 @@ pub use super::Operator; use arrow::{compute::cast::can_cast_types, datatypes::DataType}; use crate::error::{DataFusionError, Result}; -use crate::field_util::get_indexed_field; +use crate::field_util::{get_indexed_field, FieldExt}; use crate::logical_plan::{ plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan, }; @@ -172,7 +172,7 @@ impl FromStr for Column { } impl fmt::Display for Column { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self.relation { Some(r) => write!(f, "#{}.{}", r, self.name), None => write!(f, "#{}", self.name), @@ -402,13 +402,16 @@ impl Expr { /// the expression is incorrectly typed (e.g. `[utf8] + [bool]`). pub fn get_type(&self, schema: &DFSchema) -> Result { match self { - Expr::Alias(expr, _) => expr.get_type(schema), + Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => { + expr.get_type(schema) + } Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()), Expr::ScalarVariable(_) => Ok(DataType::Utf8), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), - Expr::Cast { data_type, .. } => Ok(data_type.clone()), - Expr::TryCast { data_type, .. } => Ok(data_type.clone()), + Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => { + Ok(data_type.clone()) + } Expr::ScalarUDF { fun, args } => { let data_types = args .iter() @@ -444,10 +447,11 @@ impl Expr { .collect::>>()?; Ok((fun.return_type)(&data_types)?.as_ref().clone()) } - Expr::Not(_) => Ok(DataType::Boolean), - Expr::Negative(expr) => expr.get_type(schema), - Expr::IsNull(_) => Ok(DataType::Boolean), - Expr::IsNotNull(_) => Ok(DataType::Boolean), + Expr::Not(_) + | Expr::IsNull(_) + | Expr::Between { .. } + | Expr::InList { .. 
} + | Expr::IsNotNull(_) => Ok(DataType::Boolean), Expr::BinaryExpr { ref left, ref right, @@ -457,9 +461,6 @@ impl Expr { op, &right.get_type(schema)?, ), - Expr::Sort { ref expr, .. } => expr.get_type(schema), - Expr::Between { .. } => Ok(DataType::Boolean), - Expr::InList { .. } => Ok(DataType::Boolean), Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), @@ -479,10 +480,14 @@ impl Expr { /// This happens when the expression refers to a column that does not exist in the schema. pub fn nullable(&self, input_schema: &DFSchema) -> Result { match self { - Expr::Alias(expr, _) => expr.nullable(input_schema), + Expr::Alias(expr, _) + | Expr::Not(expr) + | Expr::Negative(expr) + | Expr::Sort { expr, .. } + | Expr::Between { expr, .. } + | Expr::InList { expr, .. } => expr.nullable(input_schema), Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()), Expr::Literal(value) => Ok(value.is_null()), - Expr::ScalarVariable(_) => Ok(true), Expr::Case { when_then_expr, else_expr, @@ -502,24 +507,19 @@ impl Expr { } } Expr::Cast { expr, .. } => expr.nullable(input_schema), - Expr::TryCast { .. } => Ok(true), - Expr::ScalarFunction { .. } => Ok(true), - Expr::ScalarUDF { .. } => Ok(true), - Expr::WindowFunction { .. } => Ok(true), - Expr::AggregateFunction { .. } => Ok(true), - Expr::AggregateUDF { .. } => Ok(true), - Expr::Not(expr) => expr.nullable(input_schema), - Expr::Negative(expr) => expr.nullable(input_schema), - Expr::IsNull(_) => Ok(false), - Expr::IsNotNull(_) => Ok(false), + Expr::ScalarVariable(_) + | Expr::TryCast { .. } + | Expr::ScalarFunction { .. } + | Expr::ScalarUDF { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateFunction { .. } + | Expr::AggregateUDF { .. } => Ok(true), + Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false), Expr::BinaryExpr { ref left, ref right, .. } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), - Expr::Sort { ref expr, .. } => expr.nullable(input_schema), - Expr::Between { ref expr, .. } => expr.nullable(input_schema), - Expr::InList { ref expr, .. } => expr.nullable(input_schema), Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), @@ -725,18 +725,23 @@ impl Expr { // recurse (and cover all expression types) let visitor = match self { - Expr::Alias(expr, _) => expr.accept(visitor), - Expr::Column(_) => Ok(visitor), - Expr::ScalarVariable(..) => Ok(visitor), - Expr::Literal(..) => Ok(visitor), + Expr::Alias(expr, _) + | Expr::Not(expr) + | Expr::IsNotNull(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) + | Expr::Cast { expr, .. } + | Expr::TryCast { expr, .. } + | Expr::Sort { expr, .. } + | Expr::GetIndexedField { expr, .. } => expr.accept(visitor), + Expr::Column(_) + | Expr::ScalarVariable(_) + | Expr::Literal(_) + | Expr::Wildcard => Ok(visitor), Expr::BinaryExpr { left, right, .. } => { let visitor = left.accept(visitor)?; right.accept(visitor) } - Expr::Not(expr) => expr.accept(visitor), - Expr::IsNotNull(expr) => expr.accept(visitor), - Expr::IsNull(expr) => expr.accept(visitor), - Expr::Negative(expr) => expr.accept(visitor), Expr::Between { expr, low, high, .. } => { @@ -767,13 +772,10 @@ impl Expr { Ok(visitor) } } - Expr::Cast { expr, .. } => expr.accept(visitor), - Expr::TryCast { expr, .. } => expr.accept(visitor), - Expr::Sort { expr, .. } => expr.accept(visitor), - Expr::ScalarFunction { args, .. 
} => args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::ScalarUDF { args, .. } => args + Expr::ScalarFunction { args, .. } + | Expr::ScalarUDF { args, .. } + | Expr::AggregateFunction { args, .. } + | Expr::AggregateUDF { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), Expr::WindowFunction { @@ -793,19 +795,11 @@ impl Expr { .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; Ok(visitor) } - Expr::AggregateFunction { args, .. } => args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::AggregateUDF { args, .. } => args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)), Expr::InList { expr, list, .. } => { let visitor = expr.accept(visitor)?; list.iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)) } - Expr::Wildcard => Ok(visitor), - Expr::GetIndexedField { ref expr, .. } => expr.accept(visitor), }?; visitor.post_visit(self) @@ -2454,7 +2448,7 @@ mod tests { #[test] fn test_partial_ord() { - // Test validates that partial ord is defined for Expr, not + // Test validates that partial ord is defined for Expr using hashes, not // intended to exhaustively test all possibilities let exp1 = col("a") + lit(1); let exp2 = col("a") + lit(2); diff --git a/datafusion/src/logical_plan/extension.rs b/datafusion/src/logical_plan/extension.rs index 43bf96ffb072..7e8361713fe7 100644 --- a/datafusion/src/logical_plan/extension.rs +++ b/datafusion/src/logical_plan/extension.rs @@ -53,7 +53,7 @@ pub trait UserDefinedLogicalNode: fmt::Debug { self.schema() .fields() .iter() - .map(|f| f.name().clone()) + .map(|f| f.name().to_string()) .collect() } diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 952572f4dea3..2a001c148ec8 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -22,6 +22,7 @@ use super::expr::{Column, Expr}; use super::extension::UserDefinedLogicalNode; use crate::datasource::TableProvider; use crate::error::DataFusionError; +use crate::field_util::SchemaExt; use crate::logical_plan::dfschema::DFSchemaRef; use crate::sql::parser::FileType; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -528,8 +529,7 @@ impl LogicalPlan { { self.using_columns.push( on.iter() - .map(|entry| [&entry.0, &entry.1]) - .flatten() + .flat_map(|entry| [&entry.0, &entry.1]) .cloned() .collect::>(), ); @@ -1027,7 +1027,7 @@ pub enum PlanType { } impl fmt::Display for PlanType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { PlanType::InitialLogicalPlan => write!(f, "initial_logical_plan"), PlanType::OptimizedLogicalPlan { optimizer_name } => { diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs b/datafusion/src/optimizer/common_subexpr_eliminate.rs index 233d112197fe..947073409d05 100644 --- a/datafusion/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs @@ -69,6 +69,12 @@ impl OptimizerRule for CommonSubexprEliminate { } } +impl Default for CommonSubexprEliminate { + fn default() -> Self { + Self::new() + } +} + impl CommonSubexprEliminate { #[allow(missing_docs)] pub fn new() -> Self { diff --git a/datafusion/src/optimizer/eliminate_limit.rs b/datafusion/src/optimizer/eliminate_limit.rs index 1f74ae2ef50f..c1fc2068d325 100644 --- a/datafusion/src/optimizer/eliminate_limit.rs +++ b/datafusion/src/optimizer/eliminate_limit.rs @@ -25,6 +25,7 @@ use 
super::utils; use crate::execution::context::ExecutionProps; /// Optimization rule that replaces LIMIT 0 with an [LogicalPlan::EmptyRelation] +#[derive(Default)] pub struct EliminateLimit; impl EliminateLimit { diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index c55a5cdbc6d3..d104e4435f53 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -18,7 +18,7 @@ use crate::datasource::datasource::TableProviderFilterPushDown; use crate::execution::context::ExecutionProps; use crate::logical_plan::plan::{Aggregate, Filter, Join, Projection}; use crate::logical_plan::{ - and, replace_col, Column, CrossJoin, Limit, LogicalPlan, TableScan, + and, replace_col, Column, CrossJoin, JoinType, Limit, LogicalPlan, TableScan, }; use crate::logical_plan::{DFSchema, Expr}; use crate::optimizer::optimizer::OptimizerRule; @@ -50,10 +50,11 @@ use std::{ /// Projection: #a AS b /// Filter: #a Gt Int64(10) <--- changed from #b to #a /// -/// This performs a single pass trought the plan. When it passes trought a filter, it stores that filter, +/// This performs a single pass through the plan. When it passes through a filter, it stores that filter, /// and when it reaches a node that does not commute with it, it adds the filter to that place. -/// When it passes through a projection, it re-writes the filter's expression taking into accoun that projection. +/// When it passes through a projection, it re-writes the filter's expression taking into account that projection. /// When multiple filters would have been written, it `AND` their expressions into a single expression. +#[derive(Default)] pub struct FilterPushDown {} #[derive(Debug, Clone, Default)] @@ -82,83 +83,6 @@ fn get_predicates<'a>( .unzip() } -// returns 3 (potentially overlaping) sets of predicates: -// * pushable to left: its columns are all on the left -// * pushable to right: its columns is all on the right -// * keep: the set of columns is not in only either left or right -// Note that a predicate can be both pushed to the left and to the right. 
-fn get_join_predicates<'a>( - state: &'a State, - left: &DFSchema, - right: &DFSchema, -) -> ( - Vec<&'a HashSet>, - Vec<&'a HashSet>, - Predicates<'a>, -) { - let left_columns = &left - .fields() - .iter() - .map(|f| { - [ - f.qualified_column(), - // we need to push down filter using unqualified column as well - f.unqualified_column(), - ] - }) - .flatten() - .collect::>(); - let right_columns = &right - .fields() - .iter() - .map(|f| { - [ - f.qualified_column(), - // we need to push down filter using unqualified column as well - f.unqualified_column(), - ] - }) - .flatten() - .collect::>(); - - let filters = state - .filters - .iter() - .map(|(predicate, columns)| { - ( - (predicate, columns), - ( - columns, - left_columns.intersection(columns).collect::>(), - right_columns.intersection(columns).collect::>(), - ), - ) - }) - .collect::>(); - - let pushable_to_left = filters - .iter() - .filter(|(_, (columns, left, _))| left.len() == columns.len()) - .map(|((_, b), _)| *b) - .collect(); - let pushable_to_right = filters - .iter() - .filter(|(_, (columns, _, right))| right.len() == columns.len()) - .map(|((_, b), _)| *b) - .collect(); - let keep = filters - .iter() - .filter(|(_, (columns, left, right))| { - // predicates whose columns are not in only one side of the join need to remain - let all_in_left = left.len() == columns.len(); - let all_in_right = right.len() == columns.len(); - !all_in_left && !all_in_right - }) - .map(|((a, b), _)| (a, b)) - .unzip(); - (pushable_to_left, pushable_to_right, keep) -} - /// Optimizes the plan fn push_down(state: &State, plan: &LogicalPlan) -> Result { let new_inputs = plan @@ -203,11 +127,11 @@ fn remove_filters( // keeps all filters from `filters` that are in `predicate_columns` fn keep_filters( filters: &[(Expr, HashSet)], - predicate_columns: &[&HashSet], + relevant_predicates: &Predicates, ) -> Vec<(Expr, HashSet)> { filters .iter() - .filter(|(_, columns)| predicate_columns.contains(&columns)) + .filter(|(expr, _)| relevant_predicates.0.contains(&expr)) .cloned() .collect::>() } @@ -252,33 +176,121 @@ fn split_members<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) { } } +// For a given JOIN logical plan, determine whether each side of the join is preserved. +// We say a join side is preserved if the join returns all or a subset of the rows from +// the relevant side, such that each row of the output table directly maps to a row of +// the preserved input table. If a table is not preserved, it can provide extra null rows. +// That is, there may be rows in the output table that don't directly map to a row in the +// input table. +// +// For example: +// - In an inner join, both sides are preserved, because each row of the output +// maps directly to a row from each side. +// - In a left join, the left side is preserved and the right is not, because +// there may be rows in the output that don't directly map to a row in the +// right input (due to nulls filling where there is no match on the right). +// +// This is important because we can always push down post-join filters to a preserved +// side of the join, assuming the filter only references columns from that side. For the +// non-preserved side it can be more tricky. +// +// Returns a tuple of booleans - (left_preserved, right_preserved). +fn lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) { + match plan { + LogicalPlan::Join(Join { join_type, .. 
}) => match join_type { + JoinType::Inner => (true, true), + JoinType::Left => (true, false), + JoinType::Right => (false, true), + JoinType::Full => (false, false), + // No columns from the right side of the join can be referenced in output + // predicates for semi/anti joins, so whether we specify t/f doesn't matter. + JoinType::Semi | JoinType::Anti => (true, false), + }, + LogicalPlan::CrossJoin(_) => (true, true), + _ => unreachable!("lr_is_preserved only valid for JOIN nodes"), + } +} + +// Determine which predicates in state can be pushed down to a given side of a join. +// To determine this, we need to know the schema of the relevant join side and whether +// or not the side's rows are preserved when joining. If the side is not preserved, we +// do not push down anything. Otherwise we can push down predicates where all of the +// relevant columns are contained on the relevant join side's schema. +fn get_pushable_join_predicates<'a>( + state: &'a State, + schema: &DFSchema, + preserved: bool, +) -> Predicates<'a> { + if !preserved { + return (vec![], vec![]); + } + + let schema_columns = schema + .fields() + .iter() + .flat_map(|f| { + [ + f.qualified_column(), + // we need to push down filter using unqualified column as well + f.unqualified_column(), + ] + }) + .collect::>(); + + state + .filters + .iter() + .filter(|(_, columns)| { + let all_columns_in_schema = schema_columns + .intersection(columns) + .collect::>() + .len() + == columns.len(); + all_columns_in_schema + }) + .map(|(a, b)| (a, b)) + .unzip() +} + fn optimize_join( mut state: State, plan: &LogicalPlan, left: &LogicalPlan, right: &LogicalPlan, ) -> Result { - let (pushable_to_left, pushable_to_right, keep) = - get_join_predicates(&state, left.schema(), right.schema()); + let (left_preserved, right_preserved) = lr_is_preserved(plan); + let to_left = get_pushable_join_predicates(&state, left.schema(), left_preserved); + let to_right = get_pushable_join_predicates(&state, right.schema(), right_preserved); + + let to_keep: Predicates = state + .filters + .iter() + .filter(|(expr, _)| { + let pushed_to_left = to_left.0.contains(&expr); + let pushed_to_right = to_right.0.contains(&expr); + !pushed_to_left && !pushed_to_right + }) + .map(|(a, b)| (a, b)) + .unzip(); let mut left_state = state.clone(); - left_state.filters = keep_filters(&left_state.filters, &pushable_to_left); + left_state.filters = keep_filters(&left_state.filters, &to_left); let left = optimize(left, left_state)?; let mut right_state = state.clone(); - right_state.filters = keep_filters(&right_state.filters, &pushable_to_right); + right_state.filters = keep_filters(&right_state.filters, &to_right); let right = optimize(right, right_state)?; // create a new Join with the new `left` and `right` let expr = plan.expressions(); let plan = utils::from_plan(plan, &expr, &[left, right])?; - if keep.0.is_empty() { + if to_keep.0.is_empty() { Ok(plan) } else { // wrap the join on the filter whose predicates must be kept - let plan = add_filter(plan, &keep.0); - state.filters = remove_filters(&state.filters, &keep.1); + let plan = add_filter(plan, &to_keep.0); + state.filters = remove_filters(&state.filters, &to_keep.1); Ok(plan) } @@ -399,63 +411,68 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { optimize_join(state, plan, left, right) } LogicalPlan::Join(Join { - left, right, on, .. + left, + right, + on, + join_type, + .. }) => { - // duplicate filters for joined columns so filters can be pushed down to both sides. 
- // Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. - let join_side_filters = state - .filters - .iter() - .filter_map(|(predicate, columns)| { - let mut join_cols_to_replace = HashMap::new(); - for col in columns.iter() { - for (l, r) in on { - if col == l { - join_cols_to_replace.insert(col, r); - break; - } else if col == r { - join_cols_to_replace.insert(col, l); - break; + if *join_type == JoinType::Inner { + // For inner joins, duplicate filters for joined columns so filters can be pushed down + // to both sides. Take the following query as an example: + // + // ```sql + // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 + // ``` + // + // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while + // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. + // + // Join clauses with `Using` constraints also take advantage of this logic to make sure + // predicates reference the shared join columns are pushed to both sides. + let join_side_filters = state + .filters + .iter() + .filter_map(|(predicate, columns)| { + let mut join_cols_to_replace = HashMap::new(); + for col in columns.iter() { + for (l, r) in on { + if col == l { + join_cols_to_replace.insert(col, r); + break; + } else if col == r { + join_cols_to_replace.insert(col, l); + break; + } } } - } - if join_cols_to_replace.is_empty() { - return None; - } - - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; - - let join_side_columns = columns - .clone() - .into_iter() - // replace keys in join_cols_to_replace with values in resulting column - // set - .filter(|c| !join_cols_to_replace.contains_key(c)) - .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone())) - .collect(); - - Some(Ok((join_side_predicate, join_side_columns))) - }) - .collect::>>()?; - state.filters.extend(join_side_filters); + if join_cols_to_replace.is_empty() { + return None; + } + let join_side_predicate = + match replace_col(predicate.clone(), &join_cols_to_replace) { + Ok(p) => p, + Err(e) => { + return Some(Err(e)); + } + }; + + let join_side_columns = columns + .clone() + .into_iter() + // replace keys in join_cols_to_replace with values in resulting column + // set + .filter(|c| !join_cols_to_replace.contains_key(c)) + .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone())) + .collect(); + + Some(Ok((join_side_predicate, join_side_columns))) + }) + .collect::>>()?; + state.filters.extend(join_side_filters); + } optimize_join(state, plan, left, right) } LogicalPlan::TableScan(TableScan { @@ -556,6 +573,7 @@ fn rewrite(expr: &Expr, projection: &HashMap) -> Result { mod tests { use super::*; use crate::datasource::TableProvider; + use crate::field_util::SchemaExt; use crate::logical_plan::{lit, sum, DFSchema, Expr, LogicalPlanBuilder, Operator}; use crate::physical_plan::ExecutionPlan; use crate::test::*; @@ -1138,6 +1156,170 @@ mod tests { Ok(()) } + /// post-join predicates on the right side of a left join are not duplicated + /// TODO: In this case we can sometimes convert the join to an INNER join + #[test] + 
fn filter_using_left_join() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join_using( + &right, + JoinType::Left, + vec![Column::from_name("a".to_string())], + )? + .filter(col("test2.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{:?}", plan), + "\ + Filter: #test2.a <= Int64(1)\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" + ); + + // filter not duplicated nor pushed down - i.e. noop + let expected = "\ + Filter: #test2.a <= Int64(1)\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + /// post-join predicates on the left side of a right join are not duplicated + /// TODO: In this case we can sometimes convert the join to an INNER join + #[test] + fn filter_using_right_join() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join_using( + &right, + JoinType::Right, + vec![Column::from_name("a".to_string())], + )? + .filter(col("test.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{:?}", plan), + "\ + Filter: #test.a <= Int64(1)\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" + ); + + // filter not duplicated nor pushed down - i.e. noop + let expected = "\ + Filter: #test.a <= Int64(1)\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + /// post-left-join predicate on a column common to both sides is only pushed to the left side + /// i.e. - not duplicated to the right side + #[test] + fn filter_using_left_join_on_common() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join_using( + &right, + JoinType::Left, + vec![Column::from_name("a".to_string())], + )? + .filter(col("a").lt_eq(lit(1i64)))? 
+ .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{:?}", plan), + "\ + Filter: #test.a <= Int64(1)\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" + ); + + // filter sent to left side of the join, not the right + let expected = "\ + Join: Using #test.a = #test2.a\ + \n Filter: #test.a <= Int64(1)\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + /// post-right-join predicate on a column common to both sides is only pushed to the right side + /// i.e. - not duplicated to the left side. + #[test] + fn filter_using_right_join_on_common() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join_using( + &right, + JoinType::Right, + vec![Column::from_name("a".to_string())], + )? + .filter(col("test2.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{:?}", plan), + "\ + Filter: #test2.a <= Int64(1)\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" + ); + + // filter sent to right side of join, not duplicated to the left + let expected = "\ + Join: Using #test.a = #test2.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n Filter: #test2.a <= Int64(1)\ + \n TableScan: test2 projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + struct PushDownProvider { pub filter_support: TableProviderFilterPushDown, } @@ -1157,7 +1339,6 @@ mod tests { async fn scan( &self, _: &Option>, - _: usize, _: &[Expr], _: Option, ) -> Result> { diff --git a/datafusion/src/optimizer/limit_push_down.rs b/datafusion/src/optimizer/limit_push_down.rs index 15d5093abed9..4fa6e27869e4 100644 --- a/datafusion/src/optimizer/limit_push_down.rs +++ b/datafusion/src/optimizer/limit_push_down.rs @@ -28,6 +28,7 @@ use std::sync::Arc; /// Optimization rule that tries pushes down LIMIT n /// where applicable to reduce the amount of scanned / processed data +#[derive(Default)] pub struct LimitPushDown {} impl LimitPushDown { diff --git a/datafusion/src/optimizer/mod.rs b/datafusion/src/optimizer/mod.rs index c5cab97926df..984cbee90947 100644 --- a/datafusion/src/optimizer/mod.rs +++ b/datafusion/src/optimizer/mod.rs @@ -18,6 +18,7 @@ //! This module contains a query optimizer that operates against a logical plan and applies //! some simple rules to a logical plan, such as "Projection Push Down" and "Type Coercion". 
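The filter push-down changes above only duplicate predicates across the equi-join columns of an INNER join; the new `filter_using_left_join` / `filter_using_right_join` tests check that predicates referencing the null-producing side of an outer join are neither pushed down nor duplicated, and that predicates on the shared `USING` column are pushed only to the preserved side. The core of the inner-join rewrite is the column remapping built from the join's `on` pairs. A minimal standalone sketch of that mapping, using plain strings in place of DataFusion's `Column` and `Expr` types (illustrative only, not the optimizer's actual API):

```rust
use std::collections::HashMap;

/// Map each predicate column that participates in an equi-join pair to its
/// counterpart on the other side. An empty map means the predicate touches
/// no join column and there is nothing to duplicate.
fn join_cols_to_replace<'a>(
    on: &[(&'a str, &'a str)],
    predicate_cols: &[&'a str],
) -> HashMap<&'a str, &'a str> {
    let mut map = HashMap::new();
    for col in predicate_cols {
        for (l, r) in on {
            if col == l {
                map.insert(*col, *r);
                break;
            } else if col == r {
                map.insert(*col, *l);
                break;
            }
        }
    }
    map
}

fn main() {
    // SELECT * FROM t1 JOIN t2 ON t1.id = t2.uid WHERE t1.id > 1
    let map = join_cols_to_replace(&[("t1.id", "t2.uid")], &["t1.id"]);
    // `t1.id > 1` can therefore also be pushed to t2 as `t2.uid > 1`.
    assert_eq!(map.get("t1.id"), Some(&"t2.uid"));
}
```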
+#![allow(clippy::module_inception)] pub mod common_subexpr_eliminate; pub mod eliminate_limit; pub mod filter_push_down; diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index f92ab653fd4e..d2f482f6caf6 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -20,6 +20,7 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::ExecutionProps; +use crate::field_util::SchemaExt; use crate::logical_plan::plan::{ Aggregate, Analyze, Join, Projection, TableScan, Window, }; @@ -31,7 +32,6 @@ use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::sql::utils::find_sort_exprs; use arrow::datatypes::{Field, Schema}; -use arrow::error::Result as ArrowResult; use std::{ collections::{BTreeSet, HashSet}, sync::Arc, @@ -39,6 +39,7 @@ use std::{ /// Optimizer that removes unused projections and aggregations from plans /// This reduces both scans and +#[derive(Default)] pub struct ProjectionPushDown {} impl OptimizerRule for ProjectionPushDown { @@ -87,7 +88,7 @@ fn get_projected_schema( .iter() .filter(|c| c.relation.is_none() || c.relation.as_ref() == table_name) .map(|c| schema.index_of(&c.name)) - .filter_map(ArrowResult::ok) + .filter_map(Result::ok) .collect(); if projection.is_empty() { diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 2f448ea73c04..4583a6730536 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -17,12 +17,13 @@ //! Simplify expressions optimizer rule +use crate::record_batch::RecordBatch; use arrow::array::new_null_array; use arrow::datatypes::{DataType, Field, Schema}; -use arrow::record_batch::RecordBatch; use crate::error::DataFusionError; use crate::execution::context::{ExecutionContextState, ExecutionProps}; +use crate::field_util::SchemaExt; use crate::logical_plan::{lit, DFSchemaRef, Expr}; use crate::logical_plan::{DFSchema, ExprRewriter, LogicalPlan, RewriteRecursion}; use crate::optimizer::optimizer::OptimizerRule; @@ -43,6 +44,7 @@ use crate::{error::Result, logical_plan::Operator}; /// is optimized to /// `Filter: b > 2` /// +#[derive(Default)] pub struct SimplifyExpressions {} /// returns true if `needle` is found in a chain of search_op @@ -332,29 +334,28 @@ impl ConstEvaluator { // at plan time match expr { // Has no runtime cost, but needed during planning - Expr::Alias(..) => false, - Expr::AggregateFunction { .. } => false, - Expr::AggregateUDF { .. } => false, - Expr::ScalarVariable(_) => false, - Expr::Column(_) => false, + Expr::Alias(..) + | Expr::AggregateFunction { .. } + | Expr::AggregateUDF { .. } + | Expr::ScalarVariable(_) + | Expr::Column(_) + | Expr::WindowFunction { .. } + | Expr::Sort { .. } + | Expr::Wildcard => false, Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()), Expr::ScalarUDF { fun, .. } => Self::volatility_ok(fun.signature.volatility), - Expr::WindowFunction { .. } => false, - Expr::Sort { .. } => false, - Expr::Wildcard => false, - - Expr::Literal(_) => true, - Expr::BinaryExpr { .. } => true, - Expr::Not(_) => true, - Expr::IsNotNull(_) => true, - Expr::IsNull(_) => true, - Expr::Negative(_) => true, - Expr::Between { .. } => true, - Expr::Case { .. } => true, - Expr::Cast { .. } => true, - Expr::TryCast { .. } => true, - Expr::InList { .. } => true, - Expr::GetIndexedField { .. 
} => true, + Expr::Literal(_) + | Expr::BinaryExpr { .. } + | Expr::Not(_) + | Expr::IsNotNull(_) + | Expr::IsNull(_) + | Expr::Negative(_) + | Expr::Between { .. } + | Expr::Case { .. } + | Expr::Cast { .. } + | Expr::TryCast { .. } + | Expr::InList { .. } + | Expr::GetIndexedField { .. } => true, } } diff --git a/datafusion/src/optimizer/single_distinct_to_groupby.rs b/datafusion/src/optimizer/single_distinct_to_groupby.rs index 9bddec997db6..02a24e214495 100644 --- a/datafusion/src/optimizer/single_distinct_to_groupby.rs +++ b/datafusion/src/optimizer/single_distinct_to_groupby.rs @@ -40,6 +40,7 @@ use std::sync::Arc; /// ) /// GROUP BY k /// ``` +#[derive(Default)] pub struct SingleDistinctToGroupBy {} const SINGLE_DISTINCT_ALIAS: &str = "alias1"; diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 541ac67a6117..f7ab836b398c 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -63,26 +63,26 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> { Expr::ScalarVariable(var_names) => { self.accum.insert(Column::from_name(var_names.join("."))); } - Expr::Alias(_, _) => {} - Expr::Literal(_) => {} - Expr::BinaryExpr { .. } => {} - Expr::Not(_) => {} - Expr::IsNotNull(_) => {} - Expr::IsNull(_) => {} - Expr::Negative(_) => {} - Expr::Between { .. } => {} - Expr::Case { .. } => {} - Expr::Cast { .. } => {} - Expr::TryCast { .. } => {} - Expr::Sort { .. } => {} - Expr::ScalarFunction { .. } => {} - Expr::ScalarUDF { .. } => {} - Expr::WindowFunction { .. } => {} - Expr::AggregateFunction { .. } => {} - Expr::AggregateUDF { .. } => {} - Expr::InList { .. } => {} - Expr::Wildcard => {} - Expr::GetIndexedField { .. } => {} + Expr::Alias(_, _) + | Expr::Literal(_) + | Expr::BinaryExpr { .. } + | Expr::Not(_) + | Expr::IsNotNull(_) + | Expr::IsNull(_) + | Expr::Negative(_) + | Expr::Between { .. } + | Expr::Case { .. } + | Expr::Cast { .. } + | Expr::TryCast { .. } + | Expr::Sort { .. } + | Expr::ScalarFunction { .. } + | Expr::ScalarUDF { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateFunction { .. } + | Expr::AggregateUDF { .. } + | Expr::InList { .. } + | Expr::Wildcard + | Expr::GetIndexedField { .. } => {} } Ok(Recursion::Continue(self)) } @@ -281,10 +281,19 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::BinaryExpr { left, right, .. } => { Ok(vec![left.as_ref().to_owned(), right.as_ref().to_owned()]) } - Expr::IsNull(e) => Ok(vec![e.as_ref().to_owned()]), - Expr::IsNotNull(e) => Ok(vec![e.as_ref().to_owned()]), - Expr::ScalarFunction { args, .. } => Ok(args.clone()), - Expr::ScalarUDF { args, .. } => Ok(args.clone()), + Expr::IsNull(expr) + | Expr::IsNotNull(expr) + | Expr::Cast { expr, .. } + | Expr::TryCast { expr, .. } + | Expr::Alias(expr, ..) + | Expr::Not(expr) + | Expr::Negative(expr) + | Expr::Sort { expr, .. } + | Expr::GetIndexedField { expr, .. } => Ok(vec![expr.as_ref().to_owned()]), + Expr::ScalarFunction { args, .. } + | Expr::ScalarUDF { args, .. } + | Expr::AggregateFunction { args, .. } + | Expr::AggregateUDF { args, .. } => Ok(args.clone()), Expr::WindowFunction { args, partition_by, @@ -299,8 +308,6 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { expr_list.extend(order_by.clone()); Ok(expr_list) } - Expr::AggregateFunction { args, .. } => Ok(args.clone()), - Expr::AggregateUDF { args, .. } => Ok(args.clone()), Expr::Case { expr, when_then_expr, @@ -322,15 +329,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { } Ok(expr_list) } - Expr::Cast { expr, .. 
} => Ok(vec![expr.as_ref().to_owned()]), - Expr::TryCast { expr, .. } => Ok(vec![expr.as_ref().to_owned()]), - Expr::Column(_) => Ok(vec![]), - Expr::Alias(expr, ..) => Ok(vec![expr.as_ref().to_owned()]), - Expr::Literal(_) => Ok(vec![]), - Expr::ScalarVariable(_) => Ok(vec![]), - Expr::Not(expr) => Ok(vec![expr.as_ref().to_owned()]), - Expr::Negative(expr) => Ok(vec![expr.as_ref().to_owned()]), - Expr::Sort { expr, .. } => Ok(vec![expr.as_ref().to_owned()]), + Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => Ok(vec![]), Expr::Between { expr, low, high, .. } => Ok(vec![ @@ -348,7 +347,6 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::Wildcard { .. } => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), - Expr::GetIndexedField { expr, .. } => Ok(vec![expr.as_ref().to_owned()]), } } @@ -473,9 +471,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { } Expr::Not(_) => Ok(Expr::Not(Box::new(expressions[0].clone()))), Expr::Negative(_) => Ok(Expr::Negative(Box::new(expressions[0].clone()))), - Expr::Column(_) => Ok(expr.clone()), - Expr::Literal(_) => Ok(expr.clone()), - Expr::ScalarVariable(_) => Ok(expr.clone()), + Expr::Column(_) + | Expr::Literal(_) + | Expr::InList { .. } + | Expr::ScalarVariable(_) => Ok(expr.clone()), Expr::Sort { asc, nulls_first, .. } => Ok(Expr::Sort { @@ -504,7 +503,6 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { Ok(expr) } } - Expr::InList { .. } => Ok(expr.clone()), Expr::Wildcard { .. } => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), diff --git a/datafusion/src/physical_optimizer/aggregate_statistics.rs b/datafusion/src/physical_optimizer/aggregate_statistics.rs index 8d59fd2571b7..13ffac95ed72 100644 --- a/datafusion/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/src/physical_optimizer/aggregate_statistics.rs @@ -32,8 +32,10 @@ use crate::scalar::ScalarValue; use super::optimizer::PhysicalOptimizerRule; use super::utils::optimize_children; use crate::error::Result; +use crate::field_util::SchemaExt; /// Optimizer that uses available statistics for aggregate functions +#[derive(Default)] pub struct AggregateStatistics {} impl AggregateStatistics { @@ -253,11 +255,12 @@ mod tests { use super::*; use std::sync::Arc; + use crate::record_batch::RecordBatch; use arrow::array::{Int32Array, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; use crate::error::Result; + use crate::execution::runtime_env::RuntimeEnv; use crate::logical_plan::Operator; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; @@ -294,6 +297,7 @@ mod tests { nulls: bool, ) -> Result<()> { let conf = ExecutionConfig::new(); + let runtime = Arc::new(RuntimeEnv::default()); let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?; let (col, count) = match nulls { @@ -303,8 +307,8 @@ mod tests { // A ProjectionExec is a sign that the count optimization was applied assert!(optimized.as_any().is::()); - let result = common::collect(optimized.execute(0).await?).await?; - assert_eq!(result[0].schema(), &Arc::new(Schema::new(vec![col]))); + let result = common::collect(optimized.execute(0, runtime).await?).await?; + assert_eq!(result[0].schema().as_ref(), &Schema::new(vec![col])); assert_eq!( result[0] .column(0) diff --git 
a/datafusion/src/physical_optimizer/coalesce_batches.rs b/datafusion/src/physical_optimizer/coalesce_batches.rs index 9af8911062df..98e65a2b1281 100644 --- a/datafusion/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/src/physical_optimizer/coalesce_batches.rs @@ -29,6 +29,7 @@ use crate::{ use std::sync::Arc; /// Optimizer that introduces CoalesceBatchesExec to avoid overhead with small batches +#[derive(Default)] pub struct CoalesceBatches {} impl CoalesceBatches { @@ -74,7 +75,7 @@ impl PhysicalOptimizerRule for CoalesceBatches { // we should do that once https://issues.apache.org/jira/browse/ARROW-11059 is // implemented. For now, we choose half the configured batch size to avoid copies // when a small number of rows are removed from a batch - let target_batch_size = config.batch_size / 2; + let target_batch_size = config.runtime.batch_size / 2; Arc::new(CoalesceBatchesExec::new(plan.clone(), target_batch_size)) } else { plan.clone() diff --git a/datafusion/src/physical_optimizer/hash_build_probe_order.rs b/datafusion/src/physical_optimizer/hash_build_probe_order.rs index 0d1c39fd8acb..e81cd8d86eb7 100644 --- a/datafusion/src/physical_optimizer/hash_build_probe_order.rs +++ b/datafusion/src/physical_optimizer/hash_build_probe_order.rs @@ -31,6 +31,7 @@ use crate::physical_plan::{ExecutionPlan, PhysicalExpr}; use super::optimizer::PhysicalOptimizerRule; use super::utils::optimize_children; use crate::error::Result; +use crate::field_util::{FieldExt, SchemaExt}; /// BuildProbeOrder reorders the build and probe phase of /// hash joins. This uses the amount of rows that a datasource has. @@ -38,6 +39,7 @@ use crate::error::Result; /// is the smallest. /// If the information is not available, the order stays the same, /// so that it could be optimized manually in a query. +#[derive(Default)] pub struct HashBuildProbeOrder {} impl HashBuildProbeOrder { diff --git a/datafusion/src/physical_optimizer/merge_exec.rs b/datafusion/src/physical_optimizer/merge_exec.rs index 0127313bb1eb..58823a665b16 100644 --- a/datafusion/src/physical_optimizer/merge_exec.rs +++ b/datafusion/src/physical_optimizer/merge_exec.rs @@ -26,6 +26,7 @@ use crate::{ use std::sync::Arc; /// Introduces CoalescePartitionsExec +#[derive(Default)] pub struct AddCoalescePartitionsExec {} impl AddCoalescePartitionsExec { diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs index cecafa0b2eee..fe577a644905 100644 --- a/datafusion/src/physical_optimizer/pruning.rs +++ b/datafusion/src/physical_optimizer/pruning.rs @@ -31,13 +31,16 @@ use std::convert::TryFrom; use std::{collections::HashSet, sync::Arc}; +use crate::record_batch::RecordBatch; use arrow::{ array::{new_null_array, ArrayRef, BooleanArray}, compute::cast, datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, }; +use crate::field_util::{FieldExt, SchemaExt}; +use crate::physical_plan::expressions::cast::cast_with_error; +use crate::prelude::lit; use crate::{ error::{DataFusionError, Result}, execution::context::ExecutionContextState, @@ -76,6 +79,12 @@ pub trait PruningStatistics { /// return the number of containers (e.g. row groups) being /// pruned with these statistics fn num_containers(&self) -> usize; + + /// return the number of null values for the named column as an + /// `Option`. + /// + /// Note: the returned array must contain `num_containers()` rows. 
+ fn null_counts(&self, column: &Column) -> Option; } /// Evaluates filter expressions on statistics in order to @@ -201,7 +210,7 @@ impl PruningPredicate { struct RequiredStatColumns { /// The statistics required to evaluate this predicate: /// * The unqualified column in the input schema - /// * Statistics type (e.g. Min or Max) + /// * Statistics type (e.g. Min or Max or Null_Count) /// * The field the statistics value should be placed in for /// pruning predicate evaluation columns: Vec<(Column, StatisticsType, Field)>, @@ -282,6 +291,22 @@ impl RequiredStatColumns { ) -> Result { self.stat_column_expr(column, column_expr, field, StatisticsType::Max, "max") } + + /// rewrite col --> col_null_count + fn null_count_column_expr( + &mut self, + column: &Column, + column_expr: &Expr, + field: &Field, + ) -> Result { + self.stat_column_expr( + column, + column_expr, + field, + StatisticsType::NullCount, + "null_count", + ) + } } impl From> for RequiredStatColumns { @@ -330,6 +355,7 @@ fn build_statistics_record_batch( let array = match statistics_type { StatisticsType::Min => statistics.min_values(column), StatisticsType::Max => statistics.max_values(column), + StatisticsType::NullCount => statistics.null_counts(column), }; let array = array .unwrap_or_else(|| new_null_array(data_type.clone(), num_containers).into()); @@ -345,7 +371,8 @@ fn build_statistics_record_batch( // cast statistics array to required data type (e.g. parquet // provides timestamp statistics as "Int64") let array = - cast::cast(array.as_ref(), data_type, cast::CastOptions::default())?.into(); + cast_with_error(array.as_ref(), data_type, cast::CastOptions::default())? + .into(); fields.push(stat_field.clone()); arrays.push(array); @@ -585,6 +612,32 @@ fn build_single_column_expr( } } +/// Given an expression reference to `expr`, if `expr` is a column expression, +/// returns a pruning expression in terms of IsNull that will evaluate to true +/// if the column may contain null, and false if definitely does not +/// contain null. +fn build_is_null_column_expr( + expr: &Expr, + schema: &Schema, + required_columns: &mut RequiredStatColumns, +) -> Option { + match expr { + Expr::Column(ref col) => { + let field = schema.field_with_name(&col.name).ok()?; + + let null_count_field = &Field::new(field.name(), DataType::UInt64, false); + required_columns + .null_count_column_expr(col, expr, null_count_field) + .map(|null_count_column_expr| { + // IsNull(column) => null_count > 0 + null_count_column_expr.gt(lit::(0)) + }) + .ok() + } + _ => None, + } +} + /// Translate logical filter expression into pruning predicate /// expression that will evaluate to FALSE if it can be determined no /// rows between the min/max values could pass the predicates. 
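The new `null_counts` statistic feeds the `IsNull` case added below: `col IS NULL` is rewritten to `col_null_count > 0`, so a container whose statistics report zero nulls for the column can be pruned. A minimal standalone sketch of that pruning decision, operating on a plain slice of per-container null counts rather than the `PruningStatistics` trait (illustrative only):

```rust
/// For each container (e.g. a Parquet row group), decide whether it may
/// contain rows matching `col IS NULL`, given the per-container null count
/// reported by its statistics. `None` means the statistic is missing, and
/// such a container must be kept.
fn may_match_is_null(null_counts: &[Option<u64>]) -> Vec<bool> {
    null_counts
        .iter()
        .map(|count| match count {
            // null_count > 0 => the container may hold NULL values
            Some(n) => *n > 0,
            // no statistics => cannot prune
            None => true,
        })
        .collect()
}

fn main() {
    let keep = may_match_is_null(&[Some(0), Some(3), None]);
    // Only the first container can safely be pruned for `col IS NULL`.
    assert_eq!(keep, vec![false, true, true]);
}
```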
@@ -605,6 +658,11 @@ fn build_predicate_expression( // predicate expression can only be a binary expression let (left, op, right) = match expr { Expr::BinaryExpr { left, op, right } => (left, *op, right), + Expr::IsNull(expr) => { + let expr = build_is_null_column_expr(expr, schema, required_columns) + .unwrap_or(unhandled); + return Ok(expr); + } Expr::Column(col) => { let expr = build_single_column_expr(col, schema, required_columns, false) .unwrap_or(unhandled); @@ -705,19 +763,20 @@ fn build_statistics_expr(expr_builder: &mut PruningExpressionBuilder) -> Result< enum StatisticsType { Min, Max, + NullCount, } #[cfg(test)] mod tests { - use std::collections::HashMap; - use super::*; + use crate::logical_plan::{col, lit}; use crate::{assert_batches_eq, physical_optimizer::pruning::StatisticsType}; use arrow::{ array::*, datatypes::{DataType, TimeUnit}, }; + use std::collections::HashMap; #[derive(Debug)] /// Test for container stats @@ -815,6 +874,10 @@ mod tests { .map(|container_stats| container_stats.len()) .unwrap_or(0) } + + fn null_counts(&self, _column: &Column) -> Option { + None + } } /// Returns the specified min/max container values @@ -836,6 +899,10 @@ mod tests { fn num_containers(&self) -> usize { self.num_containers } + + fn null_counts(&self, _column: &Column) -> Option { + None + } } #[test] diff --git a/datafusion/src/physical_optimizer/repartition.rs b/datafusion/src/physical_optimizer/repartition.rs index 8ac9dadd9548..461d19445ea4 100644 --- a/datafusion/src/physical_optimizer/repartition.rs +++ b/datafusion/src/physical_optimizer/repartition.rs @@ -26,6 +26,7 @@ use crate::physical_plan::{Distribution, Partitioning::*}; use crate::{error::Result, execution::context::ExecutionConfig}; /// Optimizer that introduces repartition to introduce more parallelism in the plan +#[derive(Default)] pub struct Repartition {} impl Repartition { @@ -110,7 +111,8 @@ mod tests { use super::*; use crate::datasource::PartitionedFile; - use crate::physical_plan::file_format::{ParquetExec, PhysicalPlanConfig}; + use crate::field_util::SchemaExt; + use crate::physical_plan::file_format::{FileScanConfig, ParquetExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::Statistics; use crate::test::object_store::TestObjectStore; @@ -121,13 +123,12 @@ mod tests { let parquet_project = ProjectionExec::try_new( vec![], Arc::new(ParquetExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: TestObjectStore::new_arc(&[("x", 100)]), file_schema, file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], statistics: Statistics::default(), projection: None, - batch_size: 2048, limit: None, table_partition_cols: vec![], }, @@ -160,7 +161,7 @@ mod tests { Arc::new(ProjectionExec::try_new( vec![], Arc::new(ParquetExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: TestObjectStore::new_arc(&[("x", 100)]), file_schema, file_groups: vec![vec![PartitionedFile::new( @@ -169,7 +170,6 @@ mod tests { )]], statistics: Statistics::default(), projection: None, - batch_size: 2048, limit: None, table_partition_cols: vec![], }, diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 888de9aeb8bc..ac87c25b6101 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -32,11 +32,11 @@ use super::{ }; use crate::error::{DataFusionError, Result}; use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types}; -use 
crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use expressions::{ - avg_return_type, stddev_return_type, sum_return_type, variance_return_type, + avg_return_type, correlation_return_type, covariance_return_type, stddev_return_type, + sum_return_type, variance_return_type, }; use std::{fmt, str::FromStr, sync::Arc}; @@ -74,6 +74,12 @@ pub enum AggregateFunction { Stddev, /// Standard Deviation (Population) StddevPop, + /// Covariance (Sample) + Covariance, + /// Covariance (Population) + CovariancePop, + /// Correlation + Correlation, } impl fmt::Display for AggregateFunction { @@ -100,6 +106,10 @@ impl FromStr for AggregateFunction { "stddev" => AggregateFunction::Stddev, "stddev_samp" => AggregateFunction::Stddev, "stddev_pop" => AggregateFunction::StddevPop, + "covar" => AggregateFunction::Covariance, + "covar_samp" => AggregateFunction::Covariance, + "covar_pop" => AggregateFunction::CovariancePop, + "corr" => AggregateFunction::Correlation, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -134,6 +144,11 @@ pub fn return_type( AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), AggregateFunction::VariancePop => variance_return_type(&coerced_data_types[0]), + AggregateFunction::Covariance => covariance_return_type(&coerced_data_types[0]), + AggregateFunction::CovariancePop => { + covariance_return_type(&coerced_data_types[0]) + } + AggregateFunction::Correlation => correlation_return_type(&coerced_data_types[0]), AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]), AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), @@ -182,14 +197,12 @@ pub fn create_aggregate_expr( name, return_type, )), - (AggregateFunction::Count, true) => { - Arc::new(distinct_expressions::DistinctCount::new( - coerced_exprs_types, - coerced_phy_exprs, - name, - return_type, - )) - } + (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new( + coerced_exprs_types, + coerced_phy_exprs, + name, + return_type, + )), (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new( coerced_phy_exprs[0].clone(), name, @@ -207,11 +220,18 @@ pub fn create_aggregate_expr( coerced_exprs_types[0].clone(), )) } - (AggregateFunction::ArrayAgg, _) => Arc::new(expressions::ArrayAgg::new( + (AggregateFunction::ArrayAgg, false) => Arc::new(expressions::ArrayAgg::new( coerced_phy_exprs[0].clone(), name, coerced_exprs_types[0].clone(), )), + (AggregateFunction::ArrayAgg, true) => { + Arc::new(expressions::DistinctArrayAgg::new( + coerced_phy_exprs[0].clone(), + name, + coerced_exprs_types[0].clone(), + )) + } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( coerced_phy_exprs[0].clone(), name, @@ -254,6 +274,30 @@ pub fn create_aggregate_expr( "VAR_POP(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::Covariance, false) => Arc::new(expressions::Covariance::new( + coerced_phy_exprs[0].clone(), + coerced_phy_exprs[1].clone(), + name, + return_type, + )), + (AggregateFunction::Covariance, true) => { + return Err(DataFusionError::NotImplemented( + "COVAR(DISTINCT) aggregations are not available".to_string(), + )); + } + (AggregateFunction::CovariancePop, false) => { + Arc::new(expressions::CovariancePop::new( + 
coerced_phy_exprs[0].clone(), + coerced_phy_exprs[1].clone(), + name, + return_type, + )) + } + (AggregateFunction::CovariancePop, true) => { + return Err(DataFusionError::NotImplemented( + "COVAR_POP(DISTINCT) aggregations are not available".to_string(), + )); + } (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( coerced_phy_exprs[0].clone(), name, @@ -274,6 +318,19 @@ pub fn create_aggregate_expr( "STDDEV_POP(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::Correlation, false) => { + Arc::new(expressions::Correlation::new( + coerced_phy_exprs[0].clone(), + coerced_phy_exprs[1].clone(), + name, + return_type, + )) + } + (AggregateFunction::Correlation, true) => { + return Err(DataFusionError::NotImplemented( + "CORR(DISTINCT) aggregations are not available".to_string(), + )); + } }) } @@ -326,6 +383,12 @@ pub fn signature(fun: &AggregateFunction) -> Signature { | AggregateFunction::StddevPop => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } + AggregateFunction::Covariance | AggregateFunction::CovariancePop => { + Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) + } + AggregateFunction::Correlation => { + Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) + } } } @@ -333,8 +396,10 @@ pub fn signature(fun: &AggregateFunction) -> Signature { mod tests { use super::*; use crate::error::Result; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::{ - ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum, Variance, + ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, DistinctArrayAgg, + DistinctCount, Max, Min, Stddev, Sum, Variance, }; #[test] @@ -401,6 +466,49 @@ mod tests { } _ => {} }; + + let result_distinct = create_aggregate_expr( + &fun, + true, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Count => { + assert!(result_distinct.as_any().is::()); + assert_eq!("c1", result_distinct.name()); + assert_eq!( + Field::new("c1", DataType::UInt64, true), + result_distinct.field().unwrap() + ); + } + AggregateFunction::ApproxDistinct => { + assert!(result_distinct.as_any().is::()); + assert_eq!("c1", result_distinct.name()); + assert_eq!( + Field::new("c1", DataType::UInt64, false), + result_distinct.field().unwrap() + ); + } + AggregateFunction::ArrayAgg => { + assert!(result_distinct.as_any().is::()); + assert_eq!("c1", result_distinct.name()); + assert_eq!( + Field::new( + "c1", + DataType::List(Box::new(Field::new( + "item", + data_type.clone(), + true + ))), + false + ), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; } } Ok(()) @@ -669,6 +777,147 @@ mod tests { Ok(()) } + #[test] + fn test_covar_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Covariance]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = Schema::new(vec![ + Field::new("c1", data_type.clone(), true), + Field::new("c2", data_type.clone(), true), + ]); + let input_phy_exprs: Vec> = vec![ + Arc::new( + expressions::Column::new_with_schema("c1", &input_schema) + .unwrap(), + ), + Arc::new( + expressions::Column::new_with_schema("c2", &input_schema) + .unwrap(), + ), + ]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..2], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Covariance { + 
assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + + #[test] + fn test_covar_pop_expr() -> Result<()> { + let funcs = vec![AggregateFunction::CovariancePop]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = Schema::new(vec![ + Field::new("c1", data_type.clone(), true), + Field::new("c2", data_type.clone(), true), + ]); + let input_phy_exprs: Vec> = vec![ + Arc::new( + expressions::Column::new_with_schema("c1", &input_schema) + .unwrap(), + ), + Arc::new( + expressions::Column::new_with_schema("c2", &input_schema) + .unwrap(), + ), + ]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..2], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Covariance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + + #[test] + fn test_corr_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Correlation]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = Schema::new(vec![ + Field::new("c1", data_type.clone(), true), + Field::new("c2", data_type.clone(), true), + ]); + let input_phy_exprs: Vec> = vec![ + Arc::new( + expressions::Column::new_with_schema("c1", &input_schema) + .unwrap(), + ), + Arc::new( + expressions::Column::new_with_schema("c2", &input_schema) + .unwrap(), + ), + ]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..2], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Covariance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + #[test] fn test_min_max() -> Result<()> { let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?; diff --git a/datafusion/src/physical_plan/analyze.rs b/datafusion/src/physical_plan/analyze.rs index 5cfd8421f7ca..538fb86f7f93 100644 --- a/datafusion/src/physical_plan/analyze.rs +++ b/datafusion/src/physical_plan/analyze.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use std::{any::Any, time::Instant}; +use crate::record_batch::RecordBatch; use crate::{ error::{DataFusionError, Result}, physical_plan::{ @@ -27,10 +28,11 @@ use crate::{ Partitioning, Statistics, }, }; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::datatypes::SchemaRef; use futures::StreamExt; use super::{stream::RecordBatchReceiverStream, Distribution, SendableRecordBatchStream}; +use crate::execution::runtime_env::RuntimeEnv; use arrow::array::MutableUtf8Array; use async_trait::async_trait; @@ -100,7 +102,11 @@ impl ExecutionPlan for AnalyzeExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( "AnalyzeExec invalid partition. 
Expected 0, got {}", @@ -120,7 +126,7 @@ impl ExecutionPlan for AnalyzeExec { let (tx, rx) = tokio::sync::mpsc::channel(input_partitions); let captured_input = self.input.clone(); - let mut input_stream = captured_input.execute(0).await?; + let mut input_stream = captured_input.execute(0, runtime).await?; let captured_schema = self.schema.clone(); let verbose = self.verbose; @@ -220,6 +226,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use futures::FutureExt; + use crate::field_util::SchemaExt; use crate::{ physical_plan::collect, test::{ @@ -232,6 +239,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -239,7 +247,7 @@ mod tests { let refs = blocking_exec.refs(); let analyze_exec = Arc::new(AnalyzeExec::new(true, blocking_exec, schema)); - let fut = collect(analyze_exec); + let fut = collect(analyze_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/coalesce_batches.rs b/datafusion/src/physical_plan/coalesce_batches.rs index 2a4d799fe271..3f17d4d50d92 100644 --- a/datafusion/src/physical_plan/coalesce_batches.rs +++ b/datafusion/src/physical_plan/coalesce_batches.rs @@ -18,6 +18,8 @@ //! CoalesceBatchesExec combines small batches into larger batches for more efficient use of //! vectorized processing by upstream operators. +use arrow::array::ArrayRef; +use arrow::chunk::Chunk; use std::any::Any; use std::pin::Pin; use std::sync::Arc; @@ -29,10 +31,12 @@ use crate::physical_plan::{ SendableRecordBatchStream, }; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::SchemaExt; +use crate::record_batch::RecordBatch; use arrow::compute::concatenate::concatenate; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; use async_trait::async_trait; use futures::stream::{Stream, StreamExt}; use log::debug; @@ -111,9 +115,13 @@ impl ExecutionPlan for CoalesceBatchesExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { Ok(Box::pin(CoalesceBatchesStream { - input: self.input.execute(partition).await?, + input: self.input.execute(partition, runtime).await?, schema: self.input.schema(), target_batch_size: self.target_batch_size, buffer: Vec::new(), @@ -288,9 +296,38 @@ pub fn concat_batches( RecordBatch::try_new(schema.clone(), arrays) } +/// Concatenates an array of `arrow::chunk::Chunk` into one +pub fn concat_chunks( + schema: &SchemaRef, + batches: &[Chunk], + row_count: usize, +) -> ArrowResult> { + if batches.is_empty() { + return Ok(Chunk::new(vec![])); + } + let mut arrays = Vec::with_capacity(schema.fields().len()); + for i in 0..schema.fields().len() { + let array = concatenate( + &batches + .iter() + .map(|batch| batch.columns()[i].as_ref()) + .collect::>(), + )? 
+ .into(); + arrays.push(array); + } + debug!( + "Combined {} batches containing {} rows", + batches.len(), + row_count + ); + Chunk::try_new(arrays) +} + #[cfg(test)] mod tests { use super::*; + use crate::physical_plan::{memory::MemoryExec, repartition::RepartitionExec}; use arrow::array::UInt32Array; use arrow::datatypes::{DataType, Field, Schema}; @@ -352,9 +389,10 @@ mod tests { // execute and collect results let output_partition_count = exec.output_partitioning().partition_count(); let mut output_partitions = Vec::with_capacity(output_partition_count); + let runtime = Arc::new(RuntimeEnv::default()); for i in 0..output_partition_count { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i).await?; + let mut stream = exec.execute(i, runtime.clone()).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); diff --git a/datafusion/src/physical_plan/coalesce_partitions.rs b/datafusion/src/physical_plan/coalesce_partitions.rs index 089c6b4617aa..e3bf890cee59 100644 --- a/datafusion/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/src/physical_plan/coalesce_partitions.rs @@ -27,7 +27,7 @@ use futures::Stream; use async_trait::async_trait; -use arrow::record_batch::RecordBatch; +use crate::record_batch::RecordBatch; use arrow::{datatypes::SchemaRef, error::Result as ArrowResult}; use super::common::AbortOnDropMany; @@ -37,6 +37,7 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; use super::SendableRecordBatchStream; +use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::common::spawn_execution; use pin_project_lite::pin_project; @@ -97,7 +98,11 @@ impl ExecutionPlan for CoalescePartitionsExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { // CoalescePartitionsExec produces a single partition if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -113,7 +118,7 @@ impl ExecutionPlan for CoalescePartitionsExec { )), 1 => { // bypass any threading / metrics if there is a single partition - self.input.execute(0).await + self.input.execute(0, runtime).await } _ => { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -136,6 +141,7 @@ impl ExecutionPlan for CoalescePartitionsExec { self.input.clone(), sender.clone(), part_i, + runtime.clone(), )); } @@ -207,7 +213,8 @@ mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; - use crate::physical_plan::file_format::{CsvExec, PhysicalPlanConfig}; + use crate::field_util::SchemaExt; + use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::{collect, common}; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::{self, assert_is_pending}; @@ -215,19 +222,19 @@ mod tests { #[tokio::test] async fn merge() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let num_partitions = 4; let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", num_partitions)?; let csv = CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema: schema, file_groups: files, statistics: Statistics::default(), projection: None, - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -244,7 +251,7 @@ mod tests { 
assert_eq!(merge.output_partitioning().partition_count(), 1); // the result should contain 4 batches (one per input partition) - let iter = merge.execute(0).await?; + let iter = merge.execute(0, runtime).await?; let batches = common::collect(iter).await?; assert_eq!(batches.len(), num_partitions); @@ -257,6 +264,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -265,7 +273,7 @@ mod tests { let coaelesce_partitions_exec = Arc::new(CoalescePartitionsExec::new(blocking_exec)); - let fut = collect(coaelesce_partitions_exec); + let fut = collect(coaelesce_partitions_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index 75672fd4fe99..92168b9dff8f 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -21,7 +21,8 @@ use crate::arrow::datatypes::Schema; use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ - is_avg_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type, + is_avg_support_arg_type, is_correlation_support_arg_type, + is_covariance_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type, is_variance_support_arg_type, try_cast, }; use crate::physical_plan::functions::{Signature, TypeSignature}; @@ -105,6 +106,24 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::Covariance => { + if !is_covariance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::CovariancePop => { + if !is_covariance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } AggregateFunction::Stddev => { if !is_stddev_support_arg_type(&input_types[0]) { return Err(DataFusionError::Plan(format!( @@ -123,6 +142,15 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::Correlation => { + if !is_correlation_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } } } diff --git a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs new file mode 100644 index 000000000000..cfb9828d710b --- /dev/null +++ b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs @@ -0,0 +1,625 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Coercion rules for matching argument types for binary operators + +use crate::arrow::datatypes::DataType; +use crate::error::{DataFusionError, Result}; +use crate::logical_plan::Operator; +use crate::scalar::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128}; + +/// Coercion rules for all binary operators. Returns the output type +/// of applying `op` to an argument of `lhs_type` and `rhs_type`. +pub(crate) fn coerce_types( + lhs_type: &DataType, + op: &Operator, + rhs_type: &DataType, +) -> Result { + // This result MUST be compatible with `binary_coerce` + let result = match op { + Operator::And | Operator::Or => match (lhs_type, rhs_type) { + // logical binary boolean operators can only be evaluated in bools + (DataType::Boolean, DataType::Boolean) => Some(DataType::Boolean), + _ => None, + }, + // logical equality operators have their own rules, and always return a boolean + Operator::Eq | Operator::NotEq => comparison_eq_coercion(lhs_type, rhs_type), + // order-comparison operators have their own rules + Operator::Lt | Operator::Gt | Operator::GtEq | Operator::LtEq => { + comparison_order_coercion(lhs_type, rhs_type) + } + // "like" operators operate on strings and always return a boolean + Operator::Like | Operator::NotLike => like_coercion(lhs_type, rhs_type), + // for math expressions, the final value of the coercion is also the return type + // because coercion favours higher information types + Operator::Plus + | Operator::Minus + | Operator::Modulo + | Operator::Divide + | Operator::Multiply => mathematics_numerical_coercion(op, lhs_type, rhs_type), + Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch => string_coercion(lhs_type, rhs_type), + Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => { + eq_coercion(lhs_type, rhs_type) + } + }; + + // re-write the error message of failed coercions to include the operator's information + match result { + None => Err(DataFusionError::Plan( + format!( + "'{:?} {} {:?}' can't be evaluated because there isn't a common type to coerce the types to", + lhs_type, op, rhs_type + ), + )), + Some(t) => Ok(t) + } +} + +fn comparison_eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + // can't compare dictionaries directly due to + // https://github.com/apache/arrow-rs/issues/1201 + if lhs_type == rhs_type && !is_dictionary(lhs_type) { + // same type => equality is possible + return Some(lhs_type.clone()); + } + comparison_binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| dictionary_coercion(lhs_type, rhs_type)) + .or_else(|| temporal_coercion(lhs_type, rhs_type)) +} + +fn comparison_order_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + // can't compare dictionaries directly due to + // https://github.com/apache/arrow-rs/issues/1201 + if lhs_type == rhs_type && !is_dictionary(lhs_type) { + // same type => all good + return Some(lhs_type.clone()); + } + comparison_binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type)) + 
.or_else(|| temporal_coercion(lhs_type, rhs_type)) +} + +fn comparison_binary_numeric_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + if !is_numeric(lhs_type) || !is_numeric(rhs_type) { + return None; + }; + + // same type => all good + if lhs_type == rhs_type { + return Some(lhs_type.clone()); + } + + // these are ordered from most informative to least informative so + // that the coercion removes the least amount of information + match (lhs_type, rhs_type) { + // support decimal data type for comparison operation + (Decimal(p1, s1), Decimal(p2, s2)) => Some(Decimal(*p1.max(p2), *s1.max(s2))), + (Decimal(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), + (_, Decimal(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), + (Float64, _) | (_, Float64) => Some(Float64), + (_, Float32) | (Float32, _) => Some(Float32), + (Int64, _) | (_, Int64) => Some(Int64), + (Int32, _) | (_, Int32) => Some(Int32), + (Int16, _) | (_, Int16) => Some(Int16), + (Int8, _) | (_, Int8) => Some(Int8), + (UInt64, _) | (_, UInt64) => Some(UInt64), + (UInt32, _) | (_, UInt32) => Some(UInt32), + (UInt16, _) | (_, UInt16) => Some(UInt16), + (UInt8, _) | (_, UInt8) => Some(UInt8), + _ => None, + } +} + +fn get_comparison_common_decimal_type( + decimal_type: &DataType, + other_type: &DataType, +) -> Option { + let other_decimal_type = &match other_type { + // This conversion rule is from spark + // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 + DataType::Int8 => DataType::Decimal(3, 0), + DataType::Int16 => DataType::Decimal(5, 0), + DataType::Int32 => DataType::Decimal(10, 0), + DataType::Int64 => DataType::Decimal(20, 0), + DataType::Float32 => DataType::Decimal(14, 7), + DataType::Float64 => DataType::Decimal(30, 15), + _ => { + return None; + } + }; + match (decimal_type, &other_decimal_type) { + (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => { + let new_precision = p1.max(p2); + let new_scale = s1.max(s2); + Some(DataType::Decimal(*new_precision, *new_scale)) + } + _ => None, + } +} + +// Convert the numeric data type to the decimal data type. +// Now, we just support the signed integer type and floating-point type. +fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { + match numeric_type { + DataType::Int8 => Some(DataType::Decimal(3, 0)), + DataType::Int16 => Some(DataType::Decimal(5, 0)), + DataType::Int32 => Some(DataType::Decimal(10, 0)), + DataType::Int64 => Some(DataType::Decimal(20, 0)), + // TODO if we convert the floating-point data to the decimal type, it maybe overflow. 
+ DataType::Float32 => Some(DataType::Decimal(14, 7)), + DataType::Float64 => Some(DataType::Decimal(30, 15)), + _ => None, + } +} + +fn mathematics_numerical_coercion( + mathematics_op: &Operator, + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + + // error on any non-numeric type + if !is_numeric(lhs_type) || !is_numeric(rhs_type) { + return None; + }; + + // same type => all good + if lhs_type == rhs_type { + return Some(lhs_type.clone()); + } + + // these are ordered from most informative to least informative so + // that the coercion removes the least amount of information + match (lhs_type, rhs_type) { + (Decimal(_, _), Decimal(_, _)) => { + coercion_decimal_mathematics_type(mathematics_op, lhs_type, rhs_type) + } + (Decimal(_, _), _) => { + let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type); + match converted_decimal_type { + None => None, + Some(right_decimal_type) => coercion_decimal_mathematics_type( + mathematics_op, + lhs_type, + &right_decimal_type, + ), + } + } + (_, Decimal(_, _)) => { + let converted_decimal_type = coerce_numeric_type_to_decimal(lhs_type); + match converted_decimal_type { + None => None, + Some(left_decimal_type) => coercion_decimal_mathematics_type( + mathematics_op, + &left_decimal_type, + rhs_type, + ), + } + } + (Float64, _) | (_, Float64) => Some(Float64), + (_, Float32) | (Float32, _) => Some(Float32), + (Int64, _) | (_, Int64) => Some(Int64), + (Int32, _) | (_, Int32) => Some(Int32), + (Int16, _) | (_, Int16) => Some(Int16), + (Int8, _) | (_, Int8) => Some(Int8), + (UInt64, _) | (_, UInt64) => Some(UInt64), + (UInt32, _) | (_, UInt32) => Some(UInt32), + (UInt16, _) | (_, UInt16) => Some(UInt16), + (UInt8, _) | (_, UInt8) => Some(UInt8), + _ => None, + } +} + +fn create_decimal_type(precision: usize, scale: usize) -> DataType { + DataType::Decimal( + MAX_PRECISION_FOR_DECIMAL128.min(precision), + MAX_SCALE_FOR_DECIMAL128.min(scale), + ) +} + +fn coercion_decimal_mathematics_type( + mathematics_op: &Operator, + left_decimal_type: &DataType, + right_decimal_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + match (left_decimal_type, right_decimal_type) { + // The coercion rule from spark + // https://github.com/apache/spark/blob/c20af535803a7250fef047c2bf0fe30be242369d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala#L35 + (Decimal(p1, s1), Decimal(p2, s2)) => { + match mathematics_op { + Operator::Plus | Operator::Minus => { + // max(s1, s2) + let result_scale = *s1.max(s2); + // max(s1, s2) + max(p1-s1, p2-s2) + 1 + let result_precision = result_scale + (*p1 - *s1).max(*p2 - *s2) + 1; + Some(create_decimal_type(result_precision, result_scale)) + } + Operator::Multiply => { + // s1 + s2 + let result_scale = *s1 + *s2; + // p1 + p2 + 1 + let result_precision = *p1 + *p2 + 1; + Some(create_decimal_type(result_precision, result_scale)) + } + Operator::Divide => { + // max(6, s1 + p2 + 1) + let result_scale = 6.max(*s1 + *p2 + 1); + // p1 - s1 + s2 + max(6, s1 + p2 + 1) + let result_precision = result_scale + *p1 - *s1 + *s2; + Some(create_decimal_type(result_precision, result_scale)) + } + Operator::Modulo => { + // max(s1, s2) + let result_scale = *s1.max(s2); + // min(p1-s1, p2-s2) + max(s1, s2) + let result_precision = result_scale + (*p1 - *s1).min(*p2 - *s2); + Some(create_decimal_type(result_precision, result_scale)) + } + _ => unreachable!(), + } + } + _ => unreachable!(), + } +} + +/// Determine if a DataType is signed 
numeric or not +pub fn is_signed_numeric(dt: &DataType) -> bool { + matches!( + dt, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal(_, _) + ) +} + +/// Determine if a DataType is numeric or not +pub fn is_numeric(dt: &DataType) -> bool { + is_signed_numeric(dt) + || match dt { + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + true + } + _ => false, + } +} + +/// Coercion rules for dictionary values (aka the type of the dictionary itself) +fn dictionary_value_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + numerical_coercion(lhs_type, rhs_type).or_else(|| string_coercion(lhs_type, rhs_type)) +} + +/// Coercion rules for Dictionaries: the type that both lhs and rhs +/// can be casted to for the purpose of a computation. +/// +/// It would likely be preferable to cast primitive values to +/// dictionaries, and thus avoid unpacking dictionary as well as doing +/// faster comparisons. However, the arrow compute kernels (e.g. eq) +/// don't have DictionaryArray support yet, so fall back to unpacking +/// the dictionaries +fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + match (lhs_type, rhs_type) { + ( + DataType::Dictionary(_lhs_index_type, lhs_value_type, _), + DataType::Dictionary(_rhs_index_type, rhs_value_type, _), + ) => dictionary_value_coercion(lhs_value_type, rhs_value_type), + (DataType::Dictionary(_index_type, value_type, _), _) => { + dictionary_value_coercion(value_type, rhs_type) + } + (_, DataType::Dictionary(_index_type, value_type, _)) => { + dictionary_value_coercion(lhs_type, value_type) + } + _ => None, + } +} + +/// Coercion rules for Strings: the type that both lhs and rhs can be +/// casted to for the purpose of a string computation +fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Utf8, Utf8) => Some(Utf8), + (LargeUtf8, Utf8) => Some(LargeUtf8), + (Utf8, LargeUtf8) => Some(LargeUtf8), + (LargeUtf8, LargeUtf8) => Some(LargeUtf8), + _ => None, + } +} + +/// coercion rules for like operations. 
+/// This is a union of string coercion rules and dictionary coercion rules
+fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
+    string_coercion(lhs_type, rhs_type)
+        .or_else(|| dictionary_coercion(lhs_type, rhs_type))
+}
+
+/// Coercion rules for Temporal columns: the type that both lhs and rhs can be
+/// cast to for the purpose of a date computation
+fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
+    use arrow::datatypes::DataType::*;
+    use arrow::datatypes::TimeUnit;
+    match (lhs_type, rhs_type) {
+        (Utf8, Date32) => Some(Date32),
+        (Date32, Utf8) => Some(Date32),
+        (Utf8, Date64) => Some(Date64),
+        (Date64, Utf8) => Some(Date64),
+        (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => {
+            let tz = match (lhs_tz, rhs_tz) {
+                // can't cast across timezones
+                (Some(lhs_tz), Some(rhs_tz)) => {
+                    if lhs_tz != rhs_tz {
+                        return None;
+                    } else {
+                        Some(lhs_tz.clone())
+                    }
+                }
+                (Some(lhs_tz), None) => Some(lhs_tz.clone()),
+                (None, Some(rhs_tz)) => Some(rhs_tz.clone()),
+                (None, None) => None,
+            };
+
+            let unit = match (lhs_unit, rhs_unit) {
+                (TimeUnit::Second, TimeUnit::Millisecond) => TimeUnit::Second,
+                (TimeUnit::Second, TimeUnit::Microsecond) => TimeUnit::Second,
+                (TimeUnit::Second, TimeUnit::Nanosecond) => TimeUnit::Second,
+                (TimeUnit::Millisecond, TimeUnit::Second) => TimeUnit::Second,
+                (TimeUnit::Millisecond, TimeUnit::Microsecond) => TimeUnit::Millisecond,
+                (TimeUnit::Millisecond, TimeUnit::Nanosecond) => TimeUnit::Millisecond,
+                (TimeUnit::Microsecond, TimeUnit::Second) => TimeUnit::Second,
+                (TimeUnit::Microsecond, TimeUnit::Millisecond) => TimeUnit::Millisecond,
+                (TimeUnit::Microsecond, TimeUnit::Nanosecond) => TimeUnit::Microsecond,
+                (TimeUnit::Nanosecond, TimeUnit::Second) => TimeUnit::Second,
+                (TimeUnit::Nanosecond, TimeUnit::Millisecond) => TimeUnit::Millisecond,
+                (TimeUnit::Nanosecond, TimeUnit::Microsecond) => TimeUnit::Microsecond,
+                (l, r) => {
+                    assert_eq!(l, r);
+                    *l
+                }
+            };
+
+            Some(Timestamp(unit, tz))
+        }
+        _ => None,
+    }
+}
+
+pub(crate) fn is_dictionary(t: &DataType) -> bool {
+    matches!(t, DataType::Dictionary(_, _, _))
+}
+
+/// Coercion rule for numerical types: The type that both lhs and rhs
+/// can be cast to for numerical calculation, while maintaining
+/// maximum precision
+fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
+    use arrow::datatypes::DataType::*;
+
+    // error on any non-numeric type
+    if !is_numeric(lhs_type) || !is_numeric(rhs_type) {
+        return None;
+    };
+
+    // can't compare dictionaries directly due to
+    // https://github.com/apache/arrow-rs/issues/1201
+    if lhs_type == rhs_type && !is_dictionary(lhs_type) {
+        // same type => all good
+        return Some(lhs_type.clone());
+    }
+
+    // these are ordered from most informative to least informative so
+    // that the coercion removes the least amount of information
+    match (lhs_type, rhs_type) {
+        (Float64, _) | (_, Float64) => Some(Float64),
+        (_, Float32) | (Float32, _) => Some(Float32),
+        (Int64, _) | (_, Int64) => Some(Int64),
+        (Int32, _) | (_, Int32) => Some(Int32),
+        (Int16, _) | (_, Int16) => Some(Int16),
+        (Int8, _) | (_, Int8) => Some(Int8),
+        (UInt64, _) | (_, UInt64) => Some(UInt64),
+        (UInt32, _) | (_, UInt32) => Some(UInt32),
+        (UInt16, _) | (_, UInt16) => Some(UInt16),
+        (UInt8, _) | (_, UInt8) => Some(UInt8),
+        _ => None,
+    }
+}
+
+/// Coercion rules for equality operations. This is a superset of all numerical coercion rules.
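+///
+/// Illustrative examples derived from the rules above (not an exhaustive list):
+/// * `Int32 = Int64` coerces both sides to `Int64`
+/// * `Dictionary<Int8, Utf8> = Utf8` unpacks the dictionary and compares as `Utf8`
+/// * `Utf8 = Int32` has no common type, so the coercion yields `None`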
+fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
+    // can't compare dictionaries directly due to
+    // https://github.com/apache/arrow-rs/issues/1201
+    if lhs_type == rhs_type && !is_dictionary(lhs_type) {
+        // same type => equality is possible
+        return Some(lhs_type.clone());
+    }
+    numerical_coercion(lhs_type, rhs_type)
+        .or_else(|| dictionary_coercion(lhs_type, rhs_type))
+        .or_else(|| temporal_coercion(lhs_type, rhs_type))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::arrow::datatypes::DataType;
+    use crate::error::{DataFusionError, Result};
+    use crate::logical_plan::Operator;
+    use arrow::datatypes::IntegerType;
+
+    #[test]
+    fn test_coercion_error() -> Result<()> {
+        let result_type =
+            coerce_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8);
+
+        if let Err(DataFusionError::Plan(e)) = result_type {
+            assert_eq!(e, "'Float32 + Utf8' can't be evaluated because there isn't a common type to coerce the types to");
+            Ok(())
+        } else {
+            Err(DataFusionError::Internal(
+                "Coercion should have returned a DataFusionError::Internal".to_string(),
+            ))
+        }
+    }
+
+    #[test]
+    fn test_decimal_binary_comparison_coercion() -> Result<()> {
+        let input_decimal = DataType::Decimal(20, 3);
+        let input_types = [
+            DataType::Int8,
+            DataType::Int16,
+            DataType::Int32,
+            DataType::Int64,
+            DataType::Float32,
+            DataType::Float64,
+            DataType::Decimal(38, 10),
+        ];
+        let result_types = [
+            DataType::Decimal(20, 3),
+            DataType::Decimal(20, 3),
+            DataType::Decimal(20, 3),
+            DataType::Decimal(20, 3),
+            DataType::Decimal(20, 7),
+            DataType::Decimal(30, 15),
+            DataType::Decimal(38, 10),
+        ];
+        let comparison_op_types = [
+            Operator::NotEq,
+            Operator::Eq,
+            Operator::Gt,
+            Operator::GtEq,
+            Operator::Lt,
+            Operator::LtEq,
+        ];
+        for (i, input_type) in input_types.iter().enumerate() {
+            let expect_type = &result_types[i];
+            for op in comparison_op_types {
+                let result_type = coerce_types(&input_decimal, &op, input_type)?;
+                assert_eq!(expect_type, &result_type);
+            }
+        }
+        // negative test
+        let result_type = coerce_types(&input_decimal, &Operator::Eq, &DataType::Boolean);
+        assert!(result_type.is_err());
+        Ok(())
+    }
+
+    #[test]
+    fn test_decimal_mathematics_op_type() {
+        assert_eq!(
+            coerce_numeric_type_to_decimal(&DataType::Int8).unwrap(),
+            DataType::Decimal(3, 0)
+        );
+        assert_eq!(
+            coerce_numeric_type_to_decimal(&DataType::Int16).unwrap(),
+            DataType::Decimal(5, 0)
+        );
+        assert_eq!(
+            coerce_numeric_type_to_decimal(&DataType::Int32).unwrap(),
+            DataType::Decimal(10, 0)
+        );
+        assert_eq!(
+            coerce_numeric_type_to_decimal(&DataType::Int64).unwrap(),
+            DataType::Decimal(20, 0)
+        );
+        assert_eq!(
+            coerce_numeric_type_to_decimal(&DataType::Float32).unwrap(),
+            DataType::Decimal(14, 7)
+        );
+        assert_eq!(
+            coerce_numeric_type_to_decimal(&DataType::Float64).unwrap(),
+            DataType::Decimal(30, 15)
+        );
+
+        let op = Operator::Plus;
+        let left_decimal_type = DataType::Decimal(10, 3);
+        let right_decimal_type = DataType::Decimal(20, 4);
+        let result = coercion_decimal_mathematics_type(
+            &op,
+            &left_decimal_type,
+            &right_decimal_type,
+        );
+        assert_eq!(DataType::Decimal(21, 4), result.unwrap());
+        let op = Operator::Minus;
+        let result = coercion_decimal_mathematics_type(
+            &op,
+            &left_decimal_type,
+            &right_decimal_type,
+        );
+        assert_eq!(DataType::Decimal(21, 4), result.unwrap());
+        let op = Operator::Multiply;
+        let result = coercion_decimal_mathematics_type(
+            &op,
+            &left_decimal_type,
+            &right_decimal_type,
+        );
+        assert_eq!(DataType::Decimal(31, 7),
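+        // worked example for the Multiply rule above: Decimal(10, 3) * Decimal(20, 4)
+        // gives precision = p1 + p2 + 1 = 10 + 20 + 1 = 31 and scale = s1 + s2 = 3 + 4 = 7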
result.unwrap()); + let op = Operator::Divide; + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(35, 24), result.unwrap()); + let op = Operator::Modulo; + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(11, 4), result.unwrap()); + } + + #[test] + fn test_dictionary_type_coersion() { + use DataType::*; + + // TODO: In the future, this would ideally return Dictionary types and avoid unpacking + let lhs_type = Dictionary(IntegerType::Int8, Box::new(Int32), false); + let rhs_type = Dictionary(IntegerType::Int8, Box::new(Int16), false); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32)); + + let lhs_type = Dictionary(IntegerType::Int8, Box::new(Utf8), false); + let rhs_type = Dictionary(IntegerType::Int8, Box::new(Int16), false); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None); + + let lhs_type = Dictionary(IntegerType::Int8, Box::new(Utf8), false); + let rhs_type = Utf8; + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8)); + + let lhs_type = Utf8; + let rhs_type = Dictionary(IntegerType::Int8, Box::new(Utf8), false); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8)); + } +} diff --git a/datafusion/src/physical_plan/coercion_rule/mod.rs b/datafusion/src/physical_plan/coercion_rule/mod.rs index 1aeabda793b1..83e091aa5e91 100644 --- a/datafusion/src/physical_plan/coercion_rule/mod.rs +++ b/datafusion/src/physical_plan/coercion_rule/mod.rs @@ -18,5 +18,7 @@ //! Define coercion rules for different Expr type. //! //! Aggregate function rule +//! Binary operation rule pub(crate) mod aggregate_rule; +pub(crate) mod binary_rule; diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs index 94d53438e736..733b1ee92fbf 100644 --- a/datafusion/src/physical_plan/common.rs +++ b/datafusion/src/physical_plan/common.rs @@ -19,18 +19,23 @@ use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::SchemaExt; +use crate::physical_plan::metrics::BaselineMetrics; use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; +use crate::record_batch::RecordBatch; use arrow::compute::aggregate::estimated_bytes_size; use arrow::compute::concatenate; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::ArrowError; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; +use arrow::io::ipc::write::{FileWriter, WriteOptions}; use futures::channel::mpsc; use futures::{Future, SinkExt, Stream, StreamExt, TryStreamExt}; use pin_project_lite::pin_project; use std::fs; -use std::fs::metadata; +use std::fs::{metadata, File}; +use std::path::{Path, PathBuf}; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::task::JoinHandle; @@ -40,15 +45,21 @@ pub struct SizedRecordBatchStream { schema: SchemaRef, batches: Vec>, index: usize, + baseline_metrics: BaselineMetrics, } impl SizedRecordBatchStream { /// Create a new RecordBatchIterator - pub fn new(schema: SchemaRef, batches: Vec>) -> Self { + pub fn new( + schema: SchemaRef, + batches: Vec>, + baseline_metrics: BaselineMetrics, + ) -> Self { SizedRecordBatchStream { schema, index: 0, batches, + baseline_metrics, } } } @@ -60,12 +71,13 @@ impl Stream for SizedRecordBatchStream { mut self: std::pin::Pin<&mut Self>, _: &mut Context<'_>, ) 
-> Poll> { - Poll::Ready(if self.index < self.batches.len() { + let poll = Poll::Ready(if self.index < self.batches.len() { self.index += 1; Some(Ok(self.batches[self.index - 1].as_ref().clone())) } else { None - }) + }); + self.baseline_metrics.record_poll(poll) } } @@ -165,9 +177,10 @@ pub(crate) fn spawn_execution( input: Arc, mut output: mpsc::Sender>, partition: usize, + runtime: Arc, ) -> JoinHandle<()> { tokio::spawn(async move { - let mut stream = match input.execute(partition).await { + let mut stream = match input.execute(partition, runtime).await { Err(e) => { // If send fails, plan being torn // down, no place to send the error @@ -197,12 +210,7 @@ pub fn compute_record_batch_statistics( ) -> Statistics { let nb_rows = batches.iter().flatten().map(RecordBatch::num_rows).sum(); - let total_byte_size = batches - .iter() - .flatten() - .flat_map(RecordBatch::columns) - .map(|a| estimated_bytes_size(a.as_ref())) - .sum(); + let total_byte_size = batches.iter().flatten().map(batch_byte_size).sum(); let projection = match projection { Some(p) => p, @@ -278,10 +286,10 @@ impl Drop for AbortOnDropMany { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use arrow::{ array::{Float32Array, Float64Array}, datatypes::{DataType, Field, Schema}, - record_batch::RecordBatch, }; #[test] @@ -379,3 +387,65 @@ mod tests { Ok(()) } } + +/// Write in Arrow IPC format. +pub struct IPCWriter { + /// path + pub path: PathBuf, + /// Inner writer + pub writer: FileWriter, + /// bathes written + pub num_batches: u64, + /// rows written + pub num_rows: u64, + /// bytes written + pub num_bytes: u64, +} + +impl IPCWriter { + /// Create new writer + pub fn new(path: &Path, schema: &Schema) -> Result { + let file = File::create(path).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to create partition file at {:?}: {:?}", + path, e + )) + })?; + Ok(Self { + num_batches: 0, + num_rows: 0, + num_bytes: 0, + path: path.into(), + writer: FileWriter::try_new(file, schema, None, WriteOptions::default())?, + }) + } + + /// Write one single batch + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + self.writer.write(&batch.into(), None)?; + self.num_batches += 1; + self.num_rows += batch.num_rows() as u64; + let num_bytes: usize = batch_byte_size(batch); + self.num_bytes += num_bytes as u64; + Ok(()) + } + + /// Finish the writer + pub fn finish(&mut self) -> Result<()> { + self.writer.finish().map_err(DataFusionError::ArrowError) + } + + /// Path write to + pub fn path(&self) -> &Path { + &self.path + } +} + +/// Returns the total number of bytes of memory occupied physically by this batch. 
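+///
+/// Note: this is the sum of `estimated_bytes_size` over the batch's columns, i.e. the
+/// size of the in-memory Arrow buffers rather than any serialized size. For example, a
+/// batch with a single non-null `Int32` column of 1024 rows reports roughly
+/// 4 * 1024 = 4096 bytes.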
+pub fn batch_byte_size(batch: &RecordBatch) -> usize { + batch + .columns() + .iter() + .map(|a| estimated_bytes_size(a.as_ref())) + .sum() +} diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/cross_join.rs index 7c6d7e4d7d59..ed750eafeb62 100644 --- a/datafusion/src/physical_plan/cross_join.rs +++ b/datafusion/src/physical_plan/cross_join.rs @@ -22,9 +22,9 @@ use futures::{lock::Mutex, StreamExt}; use std::{any::Any, sync::Arc, task::Poll}; use crate::physical_plan::memory::MemoryStream; +use crate::record_batch::RecordBatch; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; use futures::{Stream, TryStreamExt}; @@ -43,6 +43,8 @@ use super::{ coalesce_batches::concat_batches, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::SchemaExt; use log::debug; /// Data of the left side @@ -137,7 +139,11 @@ impl ExecutionPlan for CrossJoinExec { self.right.output_partitioning() } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { // we only want to compute the build side once let left_data = { let mut build_side = self.build_side.lock().await; @@ -149,7 +155,7 @@ impl ExecutionPlan for CrossJoinExec { // merge all left parts into a single stream let merge = CoalescePartitionsExec::new(self.left.clone()); - let stream = merge.execute(0).await?; + let stream = merge.execute(0, runtime.clone()).await?; // Load all batches and count the rows let (batches, num_rows) = stream @@ -174,7 +180,7 @@ impl ExecutionPlan for CrossJoinExec { } }; - let stream = self.right.execute(partition).await?; + let stream = self.right.execute(partition, runtime.clone()).await?; if left_data.num_rows() == 0 { return Ok(Box::pin(MemoryStream::try_new( @@ -327,8 +333,7 @@ fn build_batch( let scalar = ScalarValue::try_from_array(arr, left_index)?; Ok(scalar.to_array_of_size(batch.num_rows())) }) - .collect::>>() - .map_err(|x| x.into_arrow_external_error())?; + .collect::>>()?; RecordBatch::try_new( Arc::new(schema.clone()), diff --git a/datafusion/src/physical_plan/crypto_expressions.rs b/datafusion/src/physical_plan/crypto_expressions.rs index c3e802d850d2..36e3dc08872c 100644 --- a/datafusion/src/physical_plan/crypto_expressions.rs +++ b/datafusion/src/physical_plan/crypto_expressions.rs @@ -25,7 +25,7 @@ use arrow::{ array::{Array, BinaryArray, Offset, Utf8Array}, datatypes::DataType, }; -use blake2::{Blake2b, Blake2s, Digest}; +use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; use md5::Md5; use sha2::{Sha224, Sha256, Sha384, Sha512}; @@ -112,8 +112,8 @@ impl DigestAlgorithm { Self::Sha256 => digest_to_scalar!(Sha256, value), Self::Sha384 => digest_to_scalar!(Sha384, value), Self::Sha512 => digest_to_scalar!(Sha512, value), - Self::Blake2b => digest_to_scalar!(Blake2b, value), - Self::Blake2s => digest_to_scalar!(Blake2s, value), + Self::Blake2b => digest_to_scalar!(Blake2b512, value), + Self::Blake2s => digest_to_scalar!(Blake2s256, value), Self::Blake3 => ScalarValue::Binary(value.as_ref().map(|v| { let mut digest = Blake3::default(); digest.update(v.as_bytes()); @@ -143,8 +143,8 @@ impl DigestAlgorithm { Self::Sha256 => digest_to_array!(Sha256, input_value), Self::Sha384 => digest_to_array!(Sha384, input_value), Self::Sha512 => digest_to_array!(Sha512, input_value), - Self::Blake2b => 
digest_to_array!(Blake2b, input_value), - Self::Blake2s => digest_to_array!(Blake2s, input_value), + Self::Blake2b => digest_to_array!(Blake2b512, input_value), + Self::Blake2s => digest_to_array!(Blake2s256, input_value), Self::Blake3 => { let binary_array: BinaryArray = input_value .iter() diff --git a/datafusion/src/physical_plan/datetime_expressions.rs b/datafusion/src/physical_plan/datetime_expressions.rs index 2879378c6331..28cf1a570179 100644 --- a/datafusion/src/physical_plan/datetime_expressions.rs +++ b/datafusion/src/physical_plan/datetime_expressions.rs @@ -33,9 +33,9 @@ use arrow::{ }; use arrow::{compute::temporal, temporal_conversions::timestamp_ns_to_datetime}; use chrono::prelude::{DateTime, Utc}; -use chrono::Datelike; use chrono::Duration; use chrono::Timelike; +use chrono::{Datelike, NaiveDateTime}; use std::borrow::Borrow; /// given a function `op` that maps a `&str` to a Result of an arrow native type, @@ -113,7 +113,9 @@ where let s = PrimitiveScalar::::new(data_type, Some((op)(s)?)); ColumnarValue::Scalar(s.try_into()?) } - None => ColumnarValue::Scalar(ScalarValue::new_null(data_type)), + None => ColumnarValue::Scalar( + PrimitiveScalar::::new(data_type, None).try_into()?, + ), }), other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function {}", @@ -186,6 +188,10 @@ pub fn make_now( } } +fn quarter_month(date: &NaiveDateTime) -> u32 { + 1 + 3 * ((date.month() - 1) / 3) +} + fn date_trunc_single(granularity: &str, value: i64) -> Result { let value = timestamp_ns_to_datetime(value).with_nanosecond(0); let value = match granularity { @@ -208,6 +214,12 @@ fn date_trunc_single(granularity: &str, value: i64) -> Result { .and_then(|d| d.with_minute(0)) .and_then(|d| d.with_hour(0)) .and_then(|d| d.with_day0(0)), + "quarter" => value + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) + .and_then(|d| d.with_day0(0)) + .and_then(|d| d.with_month(quarter_month(&d))), "year" => value .and_then(|d| d.with_second(0)) .and_then(|d| d.with_minute(0)) @@ -380,6 +392,7 @@ mod tests { "year", "2020-01-01T00:00:00.000000Z", ), + // week ( "2021-01-01T13:42:29.190855Z", "week", @@ -390,13 +403,49 @@ mod tests { "week", "2019-12-30T00:00:00.000000Z", ), + // quarter + ( + "2020-01-01T13:42:29.190855Z", + "quarter", + "2020-01-01T00:00:00.000000Z", + ), + ( + "2020-02-01T13:42:29.190855Z", + "quarter", + "2020-01-01T00:00:00.000000Z", + ), + ( + "2020-03-01T13:42:29.190855Z", + "quarter", + "2020-01-01T00:00:00.000000Z", + ), + ( + "2020-04-01T13:42:29.190855Z", + "quarter", + "2020-04-01T00:00:00.000000Z", + ), + ( + "2020-08-01T13:42:29.190855Z", + "quarter", + "2020-07-01T00:00:00.000000Z", + ), + ( + "2020-11-01T13:42:29.190855Z", + "quarter", + "2020-10-01T00:00:00.000000Z", + ), + ( + "2020-12-01T13:42:29.190855Z", + "quarter", + "2020-10-01T00:00:00.000000Z", + ), ]; cases.iter().for_each(|(original, granularity, expected)| { - let original = string_to_timestamp_nanos(original).unwrap(); - let expected = string_to_timestamp_nanos(expected).unwrap(); - let result = date_trunc_single(granularity, original).unwrap(); - assert_eq!(result, expected); + let left = string_to_timestamp_nanos(original).unwrap(); + let right = string_to_timestamp_nanos(expected).unwrap(); + let result = date_trunc_single(granularity, left).unwrap(); + assert_eq!(result, right, "{} = {}", original, expected); }); } diff --git a/datafusion/src/physical_plan/empty.rs b/datafusion/src/physical_plan/empty.rs index 
a8dead391ec8..9d78f52b3cd7 100644 --- a/datafusion/src/physical_plan/empty.rs +++ b/datafusion/src/physical_plan/empty.rs @@ -25,12 +25,14 @@ use crate::physical_plan::{ memory::MemoryStream, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, }; +use crate::record_batch::RecordBatch; use arrow::array::NullArray; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; use super::{common, SendableRecordBatchStream, Statistics}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::SchemaExt; use async_trait::async_trait; /// Execution plan for empty relation (produces no rows) @@ -110,7 +112,11 @@ impl ExecutionPlan for EmptyExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -153,13 +159,14 @@ mod tests { #[tokio::test] async fn empty() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let empty = EmptyExec::new(false, schema.clone()); assert_eq!(empty.schema(), schema); // we should have no results - let iter = empty.execute(0).await?; + let iter = empty.execute(0, runtime).await?; let batches = common::collect(iter).await?; assert!(batches.is_empty()); @@ -184,21 +191,23 @@ mod tests { #[tokio::test] async fn invalid_execute() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let empty = EmptyExec::new(false, schema); // ask for the wrong partition - assert!(empty.execute(1).await.is_err()); - assert!(empty.execute(20).await.is_err()); + assert!(empty.execute(1, runtime.clone()).await.is_err()); + assert!(empty.execute(20, runtime.clone()).await.is_err()); Ok(()) } #[tokio::test] async fn produce_one_row() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let empty = EmptyExec::new(true, schema); - let iter = empty.execute(0).await?; + let iter = empty.execute(0, runtime).await?; let batches = common::collect(iter).await?; // should have one item diff --git a/datafusion/src/physical_plan/explain.rs b/datafusion/src/physical_plan/explain.rs index 712780a4e340..eb1a3e09e5eb 100644 --- a/datafusion/src/physical_plan/explain.rs +++ b/datafusion/src/physical_plan/explain.rs @@ -20,6 +20,9 @@ use std::any::Any; use std::sync::Arc; +use super::SendableRecordBatchStream; +use crate::execution::runtime_env::RuntimeEnv; +use crate::record_batch::RecordBatch; use crate::{ error::{DataFusionError, Result}, logical_plan::StringifiedPlan, @@ -28,9 +31,9 @@ use crate::{ Statistics, }, }; -use arrow::{array::*, datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow::{array::*, datatypes::SchemaRef}; -use super::SendableRecordBatchStream; +use crate::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; use async_trait::async_trait; /// Explain execution plan operator. 
This operator contains the string @@ -101,7 +104,11 @@ impl ExecutionPlan for ExplainExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( "ExplainExec invalid partition {}", @@ -140,9 +147,13 @@ impl ExecutionPlan for ExplainExec { vec![type_builder.into_arc(), plan_builder.into_arc()], )?; + let metrics = ExecutionPlanMetricsSet::new(); + let baseline_metrics = BaselineMetrics::new(&metrics, partition); + Ok(Box::pin(SizedRecordBatchStream::new( self.schema.clone(), vec![Arc::new(record_batch)], + baseline_metrics, ))) } diff --git a/datafusion/src/physical_plan/expressions/approx_distinct.rs b/datafusion/src/physical_plan/expressions/approx_distinct.rs index 0e4ba9c398ba..677cd7395ca0 100644 --- a/datafusion/src/physical_plan/expressions/approx_distinct.rs +++ b/datafusion/src/physical_plan/expressions/approx_distinct.rs @@ -210,23 +210,6 @@ impl TryFrom<&ScalarValue> for HyperLogLog { macro_rules! default_accumulator_impl { () => { - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - self.update_batch( - values - .iter() - .map(|s| s.to_array() as ArrayRef) - .collect::>() - .as_slice(), - ) - } - - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - assert_eq!(1, states.len(), "expect only 1 element in the states"); - let other = HyperLogLog::try_from(&states[0])?; - self.hll.merge(&other); - Ok(()) - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { assert_eq!(1, states.len(), "expect only 1 element in the states"); let binary_array = states[0] diff --git a/datafusion/src/physical_plan/expressions/array_agg.rs b/datafusion/src/physical_plan/expressions/array_agg.rs index c86a08ba8aa3..be49408bdf16 100644 --- a/datafusion/src/physical_plan/expressions/array_agg.rs +++ b/datafusion/src/physical_plan/expressions/array_agg.rs @@ -18,9 +18,10 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution use super::format_state_name; -use crate::error::Result; +use crate::error::{DataFusionError, Result}; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; use crate::scalar::ScalarValue; +use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; use std::any::Any; use std::sync::Arc; @@ -94,7 +95,7 @@ impl AggregateExpr for ArrayAgg { #[derive(Debug)] pub(crate) struct ArrayAggAccumulator { - array: Vec, + values: Vec, datatype: DataType, } @@ -102,45 +103,52 @@ impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type pub fn try_new(datatype: &DataType) -> Result { Ok(Self { - array: vec![], + values: vec![], datatype: datatype.clone(), }) } } impl Accumulator for ArrayAggAccumulator { - fn state(&self) -> Result> { - Ok(vec![ScalarValue::List( - Some(Box::new(self.array.clone())), - Box::new(self.datatype.clone()), - )]) - } - - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - let value = &values[0]; - self.array.push(value.clone()); - - Ok(()) + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + }; + assert!(values.len() == 1, "array_agg can only take 1 param!"); + let arr = &values[0]; + (0..arr.len()).try_for_each(|index| { + let scalar = ScalarValue::try_from_array(arr, index)?; + self.values.push(scalar); + Ok(()) + }) } - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); }; - - assert!(states.len() == 1, "states length should be 1!"); - match &states[0] { - ScalarValue::List(Some(array), _) => { - self.array.extend((&**array).clone()); + assert!(states.len() == 1, "array_agg states must be singleton!"); + let arr = &states[0]; + (0..arr.len()).try_for_each(|index| { + let scalar = ScalarValue::try_from_array(arr, index)?; + if let ScalarValue::List(Some(values), _) = scalar { + self.values.extend(*values); + Ok(()) + } else { + Err(DataFusionError::Internal( + "array_agg state must be list!".into(), + )) } - _ => unreachable!(), - } - Ok(()) + }) + } + + fn state(&self) -> Result> { + Ok(vec![self.evaluate()?]) } fn evaluate(&self) -> Result { Ok(ScalarValue::List( - Some(Box::new(self.array.clone())), + Some(Box::new(self.values.clone())), Box::new(self.datatype.clone()), )) } @@ -149,13 +157,14 @@ impl Accumulator for ArrayAggAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; use crate::physical_plan::expressions::tests::aggregate; + use crate::record_batch::RecordBatch; use crate::{error::Result, generic_test_op}; use arrow::array::ArrayRef; use arrow::array::Int32Array; use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; #[test] fn array_agg_i32() -> Result<()> { @@ -244,8 +253,7 @@ mod tests { )))), ); - let array: ArrayRef = - ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap().into(); + let array: ArrayRef = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); generic_test_op!( array, diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 25b16af4aae5..5ee63e68b181 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -168,15 +168,6 @@ impl Accumulator for AvgAccumulator { Ok(vec![ScalarValue::from(self.count), 
self.sum.clone()]) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - let values = &values[0]; - - self.count += (!values.is_null()) as u64; - self.sum = sum::sum(&self.sum, values)?; - - Ok(()) - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; @@ -185,20 +176,6 @@ impl Accumulator for AvgAccumulator { Ok(()) } - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - let count = &states[0]; - // counts are summed - if let ScalarValue::UInt64(Some(c)) = count { - self.count += c - } else { - unreachable!() - }; - - // sums are summed - self.sum = sum::sum(&self.sum, &states[1])?; - Ok(()) - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let counts = states[0].as_any().downcast_ref::().unwrap(); // counts are summed @@ -235,10 +212,11 @@ impl Accumulator for AvgAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; + use crate::record_batch::RecordBatch; use crate::{error::Result, generic_test_op}; use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; #[test] fn test_avg_return_data_type() -> Result<()> { diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index c345495ca08a..d902544a96df 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -17,20 +17,18 @@ use std::{any::Any, convert::TryInto, sync::Arc}; +use crate::record_batch::RecordBatch; use arrow::array::*; use arrow::compute; use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; use crate::error::{DataFusionError, Result}; use crate::logical_plan::Operator; +use crate::physical_plan::coercion_rule::binary_rule::coerce_types; use crate::physical_plan::expressions::try_cast; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; use crate::scalar::ScalarValue; -use super::coercion::{ - eq_coercion, like_coercion, numerical_coercion, order_coercion, string_coercion, -}; use arrow::scalar::Scalar; use arrow::types::NativeType; @@ -430,56 +428,6 @@ fn evaluate_inverse_scalar( } } -/// Coercion rules for all binary operators. Returns the output type -/// of applying `op` to an argument of `lhs_type` and `rhs_type`. 
-fn common_binary_type( - lhs_type: &DataType, - op: &Operator, - rhs_type: &DataType, -) -> Result { - // This result MUST be compatible with `binary_coerce` - let result = match op { - Operator::And | Operator::Or => match (lhs_type, rhs_type) { - // logical binary boolean operators can only be evaluated in bools - (DataType::Boolean, DataType::Boolean) => Some(DataType::Boolean), - _ => None, - }, - // logical equality operators have their own rules, and always return a boolean - Operator::Eq | Operator::NotEq => eq_coercion(lhs_type, rhs_type), - // "like" operators operate on strings and always return a boolean - Operator::Like | Operator::NotLike => like_coercion(lhs_type, rhs_type), - // order-comparison operators have their own rules - Operator::Lt | Operator::Gt | Operator::GtEq | Operator::LtEq => { - order_coercion(lhs_type, rhs_type) - } - // for math expressions, the final value of the coercion is also the return type - // because coercion favours higher information types - Operator::Plus - | Operator::Minus - | Operator::Modulo - | Operator::Divide - | Operator::Multiply => numerical_coercion(lhs_type, rhs_type), - Operator::RegexMatch - | Operator::RegexIMatch - | Operator::RegexNotMatch - | Operator::RegexNotIMatch => string_coercion(lhs_type, rhs_type), - Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => { - eq_coercion(lhs_type, rhs_type) - } - }; - - // re-write the error message of failed coercions to include the operator's information - match result { - None => Err(DataFusionError::Plan( - format!( - "'{:?} {} {:?}' can't be evaluated because there isn't a common type to coerce the types to", - lhs_type, op, rhs_type - ), - )), - Some(t) => Ok(t) - } -} - /// Returns the return type of a binary operator or an error when the binary operator cannot /// perform the computation between the argument's types, even after type coercion. /// @@ -491,7 +439,7 @@ pub fn binary_operator_data_type( ) -> Result { // validate that it is possible to perform the operation on incoming types. 
// (or the return datatype cannot be inferred) - let common_type = common_binary_type(lhs_type, op, rhs_type)?; + let result_type = coerce_types(lhs_type, op, rhs_type)?; match op { // operators that return a boolean @@ -516,7 +464,7 @@ pub fn binary_operator_data_type( | Operator::Minus | Operator::Divide | Operator::Multiply - | Operator::Modulo => Ok(common_type), + | Operator::Modulo => Ok(result_type), } } @@ -747,11 +695,11 @@ fn binary_cast( let lhs_type = &lhs.data_type(input_schema)?; let rhs_type = &rhs.data_type(input_schema)?; - let cast_type = common_binary_type(lhs_type, op, rhs_type)?; + let result_type = coerce_types(lhs_type, op, rhs_type)?; Ok(( - try_cast(lhs, input_schema, cast_type.clone())?, - try_cast(rhs, input_schema, cast_type)?, + try_cast(lhs, input_schema, result_type.clone())?, + try_cast(rhs, input_schema, result_type)?, )) } @@ -775,7 +723,275 @@ mod tests { use super::*; use crate::error::Result; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::{col, lit}; + use arrow::datatypes::{Field, SchemaRef}; + use arrow::error::ArrowError; + + // TODO add iter for decimal array + // TODO move this to arrow-rs + // https://github.com/apache/arrow-rs/issues/1083 + pub(super) fn eq_decimal_scalar( + left: &Int128Array, + right: i128, + ) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) == right))?; + } + } + Ok(bool_builder.into()) + } + + pub(super) fn eq_decimal( + left: &Int128Array, + right: &Int128Array, + ) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) == right.value(i)))?; + } + } + Ok(bool_builder.into()) + } + + fn neq_decimal_scalar(left: &Int128Array, right: i128) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) != right))?; + } + } + Ok(bool_builder.into()) + } + + fn neq_decimal(left: &Int128Array, right: &Int128Array) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) != right.value(i)))?; + } + } + Ok(bool_builder.into()) + } + + fn lt_decimal_scalar(left: &Int128Array, right: i128) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) < right))?; + } + } + Ok(bool_builder.into()) + } + + fn lt_decimal(left: &Int128Array, right: &Int128Array) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) < right.value(i)))?; + } + } + Ok(bool_builder.into()) + } + + fn lt_eq_decimal_scalar(left: &Int128Array, right: i128) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) { + bool_builder.push(None); + } else { + 
bool_builder.try_push(Some(left.value(i) <= right))?; + } + } + Ok(bool_builder.into()) + } + + fn lt_eq_decimal(left: &Int128Array, right: &Int128Array) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) <= right.value(i)))?; + } + } + Ok(bool_builder.into()) + } + + fn gt_decimal_scalar(left: &Int128Array, right: i128) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) > right))?; + } + } + Ok(bool_builder.into()) + } + + fn gt_decimal(left: &Int128Array, right: &Int128Array) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) > right.value(i)))?; + } + } + Ok(bool_builder.into()) + } + + fn gt_eq_decimal_scalar(left: &Int128Array, right: i128) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) >= right))?; + } + } + Ok(bool_builder.into()) + } + + fn gt_eq_decimal(left: &Int128Array, right: &Int128Array) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + bool_builder.push(None); + } else { + bool_builder.try_push(Some(left.value(i) >= right.value(i)))?; + } + } + Ok(bool_builder.into()) + } + + fn is_distinct_from_decimal( + left: &Int128Array, + right: &Int128Array, + ) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + match (left.is_null(i), right.is_null(i)) { + (true, true) => bool_builder.try_push(Some(false))?, + (true, false) | (false, true) => bool_builder.try_push(Some(true))?, + (_, _) => bool_builder.try_push(Some(left.value(i) != right.value(i)))?, + } + } + Ok(bool_builder.into()) + } + + fn is_not_distinct_from_decimal( + left: &Int128Array, + right: &Int128Array, + ) -> Result { + let mut bool_builder = MutableBooleanArray::with_capacity(left.len()); + for i in 0..left.len() { + match (left.is_null(i), right.is_null(i)) { + (true, true) => bool_builder.try_push(Some(true))?, + (true, false) | (false, true) => bool_builder.try_push(Some(false))?, + (_, _) => bool_builder.try_push(Some(left.value(i) == right.value(i)))?, + } + } + Ok(bool_builder.into()) + } + + fn add_decimal(left: &Int128Array, right: &Int128Array) -> Result { + let mut decimal_builder = Int128Vec::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.push(None); + } else { + decimal_builder.try_push(Some(left.value(i) + right.value(i)))?; + } + } + Ok(decimal_builder.into()) + } + + fn subtract_decimal(left: &Int128Array, right: &Int128Array) -> Result { + let mut decimal_builder = Int128Vec::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.push(None); + } else { + decimal_builder.try_push(Some(left.value(i) - right.value(i)))?; + } + } + Ok(decimal_builder.into()) + } + + fn multiply_decimal( + left: &Int128Array, + right: 
&Int128Array, + scale: u32, + ) -> Result { + let mut decimal_builder = Int128Vec::with_capacity(left.len()); + let divide = 10_i128.pow(scale); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.push(None); + } else { + decimal_builder + .try_push(Some(left.value(i) * right.value(i) / divide))?; + } + } + Ok(decimal_builder.into()) + } + + fn divide_decimal( + left: &Int128Array, + right: &Int128Array, + scale: i32, + ) -> Result { + let mut decimal_builder = Int128Vec::with_capacity(left.len()); + let mul = 10_f64.powi(scale); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.push(None); + } else if right.value(i) == 0 { + return Err(DataFusionError::ArrowError( + ArrowError::InvalidArgumentError("Cannot divide by zero".to_string()), + )); + } else { + let l_value = left.value(i) as f64; + let r_value = right.value(i) as f64; + let result = ((l_value / r_value) * mul) as i128; + decimal_builder.try_push(Some(result))?; + } + } + Ok(decimal_builder.into()) + } + + fn modulus_decimal(left: &Int128Array, right: &Int128Array) -> Result { + let mut decimal_builder = Int128Vec::with_capacity(left.len()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.push(None); + } else if right.value(i) == 0 { + return Err(DataFusionError::ArrowError( + ArrowError::InvalidArgumentError("Cannot divide by zero".to_string()), + )); + } else { + decimal_builder.try_push(Some(left.value(i) % right.value(i)))?; + } + } + Ok(decimal_builder.into()) + } // Create a binary expression without coercion. Used here when we do not want to coerce the expressions // to valid types. Usage can result in an execution (after plan) error. @@ -783,8 +999,9 @@ mod tests { l: Arc, op: Operator, r: Arc, + input_schema: &Schema, ) -> Arc { - Arc::new(BinaryExpr::new(l, op, r)) + binary(l, op, r, input_schema).unwrap() } #[test] @@ -797,7 +1014,12 @@ mod tests { let b = Int32Array::from_slice(&[1, 2, 4, 8, 16]); // expression: "a < b" - let lt = binary_simple(col("a", &schema)?, Operator::Lt, col("b", &schema)?); + let lt = binary_simple( + col("a", &schema)?, + Operator::Lt, + col("b", &schema)?, + &schema, + ); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; @@ -827,9 +1049,20 @@ mod tests { // expression: "a < b OR a == b" let expr = binary_simple( - binary_simple(col("a", &schema)?, Operator::Lt, col("b", &schema)?), + binary_simple( + col("a", &schema)?, + Operator::Lt, + col("b", &schema)?, + &schema, + ), Operator::Or, - binary_simple(col("a", &schema)?, Operator::Eq, col("b", &schema)?), + binary_simple( + col("a", &schema)?, + Operator::Eq, + col("b", &schema)?, + &schema, + ), + &schema, ); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; @@ -1139,7 +1372,8 @@ mod tests { op: Operator, expected: PrimitiveArray, ) -> Result<()> { - let arithmetic_op = binary_simple(col("a", &schema)?, op, col("b", &schema)?); + let arithmetic_op = + binary_simple(col("a", &schema)?, op, col("b", &schema)?, &schema); let batch = RecordBatch::try_new(schema, data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); @@ -1148,18 +1382,19 @@ mod tests { } fn apply_logic_op( - schema: Arc, - left: BooleanArray, - right: BooleanArray, + schema: &Arc, + left: &ArrayRef, + right: &ArrayRef, op: Operator, - expected: BooleanArray, + expected: ArrayRef, ) -> Result<()> { - let arithmetic_op = binary_simple(col("a", &schema)?, op, col("b", 
&schema)?); - let data: Vec = vec![Arc::new(left), Arc::new(right)]; - let batch = RecordBatch::try_new(schema, data)?; + let arithmetic_op = + binary_simple(col("a", schema)?, op, col("b", schema)?, schema); + let data: Vec = vec![left.clone(), right.clone()]; + let batch = RecordBatch::try_new(schema.clone(), data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); - assert_eq!(expected, result.as_ref()); + assert_eq!(expected, result); Ok(()) } @@ -1185,14 +1420,14 @@ mod tests { // Test `scalar arr` produces expected fn apply_logic_op_scalar_arr( schema: &SchemaRef, - scalar: bool, + scalar: &ScalarValue, arr: &ArrayRef, op: Operator, expected: &BooleanArray, ) -> Result<()> { - let scalar = lit(scalar.into()); + let scalar = lit(scalar.clone()); - let arithmetic_op = binary_simple(scalar, op, col("a", schema)?); + let arithmetic_op = binary_simple(scalar, op, col("a", schema)?, schema); let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); assert_eq!(result.as_ref(), expected as &dyn Array); @@ -1204,13 +1439,13 @@ mod tests { fn apply_logic_op_arr_scalar( schema: &SchemaRef, arr: &ArrayRef, - scalar: bool, + scalar: &ScalarValue, op: Operator, expected: &BooleanArray, ) -> Result<()> { - let scalar = lit(scalar.into()); + let scalar = lit(scalar.clone()); - let arithmetic_op = binary_simple(col("a", schema)?, op, scalar); + let arithmetic_op = binary_simple(col("a", schema)?, op, scalar, schema); let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); assert_eq!(result.as_ref(), expected as &dyn Array); @@ -1224,7 +1459,7 @@ mod tests { Field::new("a", DataType::Boolean, true), Field::new("b", DataType::Boolean, true), ]); - let a = BooleanArray::from(vec![ + let a = Arc::new(BooleanArray::from(vec![ Some(true), Some(false), None, @@ -1234,8 +1469,8 @@ mod tests { Some(true), Some(false), None, - ]); - let b = BooleanArray::from(vec![ + ])) as ArrayRef; + let b = Arc::new(BooleanArray::from(vec![ Some(true), Some(true), Some(true), @@ -1245,7 +1480,7 @@ mod tests { None, None, None, - ]); + ])) as ArrayRef; let expected = BooleanArray::from(vec![ Some(true), @@ -1258,7 +1493,7 @@ mod tests { Some(false), None, ]); - apply_logic_op(Arc::new(schema), a, b, Operator::And, expected)?; + apply_logic_op(&Arc::new(schema), &a, &b, Operator::And, Arc::new(expected))?; Ok(()) } @@ -1269,7 +1504,7 @@ mod tests { Field::new("a", DataType::Boolean, true), Field::new("b", DataType::Boolean, true), ]); - let a = BooleanArray::from(vec![ + let a = Arc::new(BooleanArray::from(vec![ Some(true), Some(false), None, @@ -1279,8 +1514,8 @@ mod tests { Some(true), Some(false), None, - ]); - let b = BooleanArray::from(vec![ + ])) as ArrayRef; + let b = Arc::new(BooleanArray::from(vec![ Some(true), Some(true), Some(true), @@ -1290,7 +1525,7 @@ mod tests { None, None, None, - ]); + ])) as ArrayRef; let expected = BooleanArray::from(vec![ Some(true), @@ -1303,7 +1538,7 @@ mod tests { None, None, ]); - apply_logic_op(Arc::new(schema), a, b, Operator::Or, expected)?; + apply_logic_op(&Arc::new(schema), &a, &b, Operator::Or, Arc::new(expected))?; Ok(()) } @@ -1312,12 +1547,12 @@ mod tests { /// /// a: [true, true, true, NULL, NULL, NULL, false, false, false] /// b: [true, NULL, false, true, NULL, false, true, NULL, false] - fn bool_test_arrays() -> (SchemaRef, BooleanArray, BooleanArray) { + fn 
bool_test_arrays() -> (SchemaRef, ArrayRef, ArrayRef) { let schema = Schema::new(vec![ Field::new("a", DataType::Boolean, false), Field::new("b", DataType::Boolean, false), ]); - let a = [ + let a: BooleanArray = [ Some(true), Some(true), Some(true), @@ -1330,7 +1565,7 @@ mod tests { ] .iter() .collect(); - let b = [ + let b: BooleanArray = [ Some(true), None, Some(false), @@ -1343,7 +1578,7 @@ mod tests { ] .iter() .collect(); - (Arc::new(schema), a, b) + (Arc::new(schema), Arc::new(a), Arc::new(b)) } /// Returns (schema, BooleanArray) with [true, NULL, false] @@ -1356,7 +1591,7 @@ mod tests { #[test] fn eq_op_bool() { let (schema, a, b) = bool_test_arrays(); - let expected = vec![ + let expected = BooleanArray::from_iter(vec![ Some(true), None, Some(false), @@ -1366,28 +1601,54 @@ mod tests { Some(false), None, Some(true), - ] - .iter() - .collect(); - apply_logic_op(schema, a, b, Operator::Eq, expected).unwrap(); + ]); + apply_logic_op(&schema, &a, &b, Operator::Eq, Arc::new(expected)).unwrap(); } #[test] fn eq_op_bool_scalar() { let (schema, a) = scalar_bool_test_array(); let expected = [Some(true), None, Some(false)].iter().collect(); - apply_logic_op_scalar_arr(&schema, true, &a, Operator::Eq, &expected).unwrap(); - apply_logic_op_arr_scalar(&schema, &a, true, Operator::Eq, &expected).unwrap(); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(true), + &a, + Operator::Eq, + &expected, + ) + .unwrap(); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(true), + Operator::Eq, + &expected, + ) + .unwrap(); let expected = [Some(false), None, Some(true)].iter().collect(); - apply_logic_op_scalar_arr(&schema, false, &a, Operator::Eq, &expected).unwrap(); - apply_logic_op_arr_scalar(&schema, &a, false, Operator::Eq, &expected).unwrap(); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(false), + &a, + Operator::Eq, + &expected, + ) + .unwrap(); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(false), + Operator::Eq, + &expected, + ) + .unwrap(); } #[test] fn neq_op_bool() { let (schema, a, b) = bool_test_arrays(); - let expected = [ + let expected = BooleanArray::from_iter([ Some(false), None, Some(true), @@ -1397,30 +1658,54 @@ mod tests { Some(true), None, Some(false), - ] - .iter() - .collect(); - apply_logic_op(schema, a, b, Operator::NotEq, expected).unwrap(); + ]); + apply_logic_op(&schema, &a, &b, Operator::NotEq, Arc::new(expected)).unwrap(); } #[test] fn neq_op_bool_scalar() { let (schema, a) = scalar_bool_test_array(); let expected = [Some(false), None, Some(true)].iter().collect(); - apply_logic_op_scalar_arr(&schema, true, &a, Operator::NotEq, &expected).unwrap(); - apply_logic_op_arr_scalar(&schema, &a, true, Operator::NotEq, &expected).unwrap(); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(true), + &a, + Operator::NotEq, + &expected, + ) + .unwrap(); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(true), + Operator::NotEq, + &expected, + ) + .unwrap(); let expected = [Some(true), None, Some(false)].iter().collect(); - apply_logic_op_scalar_arr(&schema, false, &a, Operator::NotEq, &expected) - .unwrap(); - apply_logic_op_arr_scalar(&schema, &a, false, Operator::NotEq, &expected) - .unwrap(); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(false), + &a, + Operator::NotEq, + &expected, + ) + .unwrap(); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(false), + Operator::NotEq, + &expected, + ) + .unwrap(); } #[test] fn lt_op_bool() { let (schema, a, b) = 
bool_test_arrays(); - let expected = [ + let expected = BooleanArray::from_iter([ Some(false), None, Some(false), @@ -1430,32 +1715,58 @@ mod tests { Some(true), None, Some(false), - ] - .iter() - .collect(); - apply_logic_op(schema, a, b, Operator::Lt, expected).unwrap(); + ]); + apply_logic_op(&schema, &a, &b, Operator::Lt, Arc::new(expected)).unwrap(); } #[test] fn lt_op_bool_scalar() { let (schema, a) = scalar_bool_test_array(); let expected = [Some(false), None, Some(false)].iter().collect(); - apply_logic_op_scalar_arr(&schema, true, &a, Operator::Lt, &expected).unwrap(); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(true), + &a, + Operator::Lt, + &expected, + ) + .unwrap(); let expected = [Some(false), None, Some(true)].iter().collect(); - apply_logic_op_arr_scalar(&schema, &a, true, Operator::Lt, &expected).unwrap(); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(true), + Operator::Lt, + &expected, + ) + .unwrap(); let expected = [Some(true), None, Some(false)].iter().collect(); - apply_logic_op_scalar_arr(&schema, false, &a, Operator::Lt, &expected).unwrap(); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(false), + &a, + Operator::Lt, + &expected, + ) + .unwrap(); let expected = [Some(false), None, Some(false)].iter().collect(); - apply_logic_op_arr_scalar(&schema, &a, false, Operator::Lt, &expected).unwrap(); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(false), + Operator::Lt, + &expected, + ) + .unwrap(); } #[test] fn lt_eq_op_bool() { let (schema, a, b) = bool_test_arrays(); - let expected = [ + let expected = BooleanArray::from_iter([ Some(true), None, Some(false), @@ -1465,32 +1776,58 @@ mod tests { Some(true), None, Some(true), - ] - .iter() - .collect(); - apply_logic_op(schema, a, b, Operator::LtEq, expected).unwrap(); + ]); + apply_logic_op(&schema, &a, &b, Operator::LtEq, Arc::new(expected)).unwrap(); } #[test] fn lt_eq_op_bool_scalar() { let (schema, a) = scalar_bool_test_array(); let expected = [Some(true), None, Some(false)].iter().collect(); - apply_logic_op_scalar_arr(&schema, true, &a, Operator::LtEq, &expected).unwrap(); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(true), + &a, + Operator::LtEq, + &expected, + ) + .unwrap(); let expected = [Some(true), None, Some(true)].iter().collect(); - apply_logic_op_arr_scalar(&schema, &a, true, Operator::LtEq, &expected).unwrap(); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(true), + Operator::LtEq, + &expected, + ) + .unwrap(); let expected = [Some(true), None, Some(true)].iter().collect(); - apply_logic_op_scalar_arr(&schema, false, &a, Operator::LtEq, &expected).unwrap(); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(false), + &a, + Operator::LtEq, + &expected, + ) + .unwrap(); let expected = [Some(false), None, Some(true)].iter().collect(); - apply_logic_op_arr_scalar(&schema, &a, false, Operator::LtEq, &expected).unwrap(); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(false), + Operator::LtEq, + &expected, + ) + .unwrap(); } #[test] fn gt_op_bool() { let (schema, a, b) = bool_test_arrays(); - let expected = [ + let expected = BooleanArray::from_iter([ Some(false), None, Some(true), @@ -1500,32 +1837,58 @@ mod tests { Some(false), None, Some(false), - ] - .iter() - .collect(); - apply_logic_op(schema, a, b, Operator::Gt, expected).unwrap(); + ]); + apply_logic_op(&schema, &a, &b, Operator::Gt, Arc::new(expected)).unwrap(); } #[test] fn gt_op_bool_scalar() { let (schema, a) = 
scalar_bool_test_array(); - let expected = [Some(false), None, Some(true)].iter().collect(); - apply_logic_op_scalar_arr(&schema, true, &a, Operator::Gt, &expected).unwrap(); + let expected = BooleanArray::from_iter([Some(false), None, Some(true)]); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(true), + &a, + Operator::Gt, + &expected, + ) + .unwrap(); - let expected = [Some(false), None, Some(false)].iter().collect(); - apply_logic_op_arr_scalar(&schema, &a, true, Operator::Gt, &expected).unwrap(); + let expected = BooleanArray::from_iter([Some(false), None, Some(false)]); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(true), + Operator::Gt, + &expected, + ) + .unwrap(); - let expected = [Some(false), None, Some(false)].iter().collect(); - apply_logic_op_scalar_arr(&schema, false, &a, Operator::Gt, &expected).unwrap(); + let expected = BooleanArray::from_iter([Some(false), None, Some(false)]); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(false), + &a, + Operator::Gt, + &expected, + ) + .unwrap(); - let expected = [Some(true), None, Some(false)].iter().collect(); - apply_logic_op_arr_scalar(&schema, &a, false, Operator::Gt, &expected).unwrap(); + let expected = BooleanArray::from_iter([Some(true), None, Some(false)]); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(false), + Operator::Gt, + &expected, + ) + .unwrap(); } #[test] fn gt_eq_op_bool() { let (schema, a, b) = bool_test_arrays(); - let expected = [ + let expected = BooleanArray::from_iter([ Some(true), None, Some(true), @@ -1535,32 +1898,58 @@ mod tests { Some(false), None, Some(true), - ] - .iter() - .collect(); - apply_logic_op(schema, a, b, Operator::GtEq, expected).unwrap(); + ]); + apply_logic_op(&schema, &a, &b, Operator::GtEq, Arc::new(expected)).unwrap(); } #[test] fn gt_eq_op_bool_scalar() { let (schema, a) = scalar_bool_test_array(); - let expected = [Some(true), None, Some(true)].iter().collect(); - apply_logic_op_scalar_arr(&schema, true, &a, Operator::GtEq, &expected).unwrap(); + let expected = BooleanArray::from_iter([Some(true), None, Some(true)]); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(true), + &a, + Operator::GtEq, + &expected, + ) + .unwrap(); - let expected = [Some(true), None, Some(false)].iter().collect(); - apply_logic_op_arr_scalar(&schema, &a, true, Operator::GtEq, &expected).unwrap(); + let expected = BooleanArray::from_iter([Some(true), None, Some(false)]); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(true), + Operator::GtEq, + &expected, + ) + .unwrap(); - let expected = [Some(false), None, Some(true)].iter().collect(); - apply_logic_op_scalar_arr(&schema, false, &a, Operator::GtEq, &expected).unwrap(); + let expected = BooleanArray::from_iter([Some(false), None, Some(true)]); + apply_logic_op_scalar_arr( + &schema, + &ScalarValue::from(false), + &a, + Operator::GtEq, + &expected, + ) + .unwrap(); - let expected = [Some(true), None, Some(true)].iter().collect(); - apply_logic_op_arr_scalar(&schema, &a, false, Operator::GtEq, &expected).unwrap(); + let expected = BooleanArray::from_iter([Some(true), None, Some(true)]); + apply_logic_op_arr_scalar( + &schema, + &a, + &ScalarValue::from(false), + Operator::GtEq, + &expected, + ) + .unwrap(); } #[test] fn is_distinct_from_op_bool() { let (schema, a, b) = bool_test_arrays(); - let expected = [ + let expected = BooleanArray::from_iter([ Some(false), Some(true), Some(true), @@ -1570,16 +1959,21 @@ mod tests { Some(true), Some(true), Some(false), - ] 
- .iter() - .collect(); - apply_logic_op(schema, a, b, Operator::IsDistinctFrom, expected).unwrap(); + ]); + apply_logic_op( + &schema, + &a, + &b, + Operator::IsDistinctFrom, + Arc::new(expected), + ) + .unwrap(); } #[test] fn is_not_distinct_from_op_bool() { let (schema, a, b) = bool_test_arrays(); - let expected = [ + let expected = BooleanArray::from_iter([ Some(true), Some(false), Some(false), @@ -1589,25 +1983,15 @@ mod tests { Some(false), Some(false), Some(true), - ] - .iter() - .collect(); - apply_logic_op(schema, a, b, Operator::IsNotDistinctFrom, expected).unwrap(); - } - - #[test] - fn test_coersion_error() -> Result<()> { - let expr = - common_binary_type(&DataType::Float32, &Operator::Plus, &DataType::Utf8); - - if let Err(DataFusionError::Plan(e)) = expr { - assert_eq!(e, "'Float32 + Utf8' can't be evaluated because there isn't a common type to coerce the types to"); - Ok(()) - } else { - Err(DataFusionError::Internal( - "Coercion should have returned an DataFusionError::Internal".to_string(), - )) - } + ]); + apply_logic_op( + &schema, + &a, + &b, + Operator::IsNotDistinctFrom, + Arc::new(expected), + ) + .unwrap(); } #[test] @@ -1628,7 +2012,7 @@ mod tests { let expr = (0..tree_depth) .into_iter() .map(|_| col("a", schema.as_ref()).unwrap()) - .reduce(|l, r| binary_simple(l, Operator::Plus, r)) + .reduce(|l, r| binary_simple(l, Operator::Plus, r, schema)) .unwrap(); let result = expr @@ -1642,4 +2026,582 @@ mod tests { .collect(); assert_eq!(result.as_ref(), &expected as &dyn Array); } + + fn create_decimal_array( + array: &[Option], + _precision: usize, + _scale: usize, + ) -> Result { + let mut decimal_builder = Int128Vec::with_capacity(array.len()); + for value in array { + match value { + None => { + decimal_builder.push(None); + } + Some(v) => { + decimal_builder.try_push(Some(*v))?; + } + } + } + Ok(decimal_builder.into()) + } + + #[test] + fn comparison_decimal_op_test() -> Result<()> { + let value_i128: i128 = 123; + let decimal_array = create_decimal_array( + &[ + Some(value_i128), + None, + Some(value_i128 - 1), + Some(value_i128 + 1), + ], + 25, + 3, + )?; + // eq: array = i128 + let result = eq_decimal_scalar(&decimal_array, value_i128)?; + assert_eq!( + BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]), + result + ); + // neq: array != i128 + let result = neq_decimal_scalar(&decimal_array, value_i128)?; + assert_eq!( + BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]), + result + ); + // lt: array < i128 + let result = lt_decimal_scalar(&decimal_array, value_i128)?; + assert_eq!( + BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]), + result + ); + // lt_eq: array <= i128 + let result = lt_eq_decimal_scalar(&decimal_array, value_i128)?; + assert_eq!( + BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]), + result + ); + // gt: array > i128 + let result = gt_decimal_scalar(&decimal_array, value_i128)?; + assert_eq!( + BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]), + result + ); + // gt_eq: array >= i128 + let result = gt_eq_decimal_scalar(&decimal_array, value_i128)?; + assert_eq!( + BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]), + result + ); + + let left_decimal_array = decimal_array; + let right_decimal_array = create_decimal_array( + &[ + Some(value_i128 - 1), + Some(value_i128), + Some(value_i128 + 1), + Some(value_i128 + 1), + ], + 25, + 3, + )?; + // eq: left == right + let result = eq_decimal(&left_decimal_array, &right_decimal_array)?; + 
assert_eq!( + BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]), + result + ); + // neq: left != right + let result = neq_decimal(&left_decimal_array, &right_decimal_array)?; + assert_eq!( + BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]), + result + ); + // lt: left < right + let result = lt_decimal(&left_decimal_array, &right_decimal_array)?; + assert_eq!( + BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]), + result + ); + // lt_eq: left <= right + let result = lt_eq_decimal(&left_decimal_array, &right_decimal_array)?; + assert_eq!( + BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]), + result + ); + // gt: left > right + let result = gt_decimal(&left_decimal_array, &right_decimal_array)?; + assert_eq!( + BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]), + result + ); + // gt_eq: left >= right + let result = gt_eq_decimal(&left_decimal_array, &right_decimal_array)?; + assert_eq!( + BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]), + result + ); + // is_distinct: left distinct right + let result = is_distinct_from_decimal(&left_decimal_array, &right_decimal_array)?; + assert_eq!( + BooleanArray::from(vec![Some(true), Some(true), Some(true), Some(false)]), + result + ); + // is_distinct: left distinct right + let result = + is_not_distinct_from_decimal(&left_decimal_array, &right_decimal_array)?; + assert_eq!( + BooleanArray::from(vec![Some(false), Some(false), Some(false), Some(true)]), + result + ); + Ok(()) + } + + #[test] + fn comparison_decimal_expr_test() -> Result<()> { + let decimal_scalar = ScalarValue::Decimal128(Some(123_456), 10, 3); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + // scalar == array + apply_logic_op_scalar_arr( + &schema, + &decimal_scalar, + &(Arc::new(Int64Array::from(vec![Some(124), None])) as ArrayRef), + Operator::Eq, + &BooleanArray::from(vec![Some(false), None]), + ) + .unwrap(); + + // array != scalar + apply_logic_op_arr_scalar( + &schema, + &(Arc::new(Int64Array::from(vec![Some(123), None, Some(1)])) as ArrayRef), + &decimal_scalar, + Operator::NotEq, + &BooleanArray::from(vec![Some(true), None, Some(true)]), + ) + .unwrap(); + + // array < scalar + apply_logic_op_arr_scalar( + &schema, + &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef), + &decimal_scalar, + Operator::Lt, + &BooleanArray::from(vec![Some(true), None, Some(false)]), + ) + .unwrap(); + + // array > scalar + apply_logic_op_arr_scalar( + &schema, + &(Arc::new(Int64Array::from(vec![Some(123), None, Some(124)])) as ArrayRef), + &decimal_scalar, + Operator::Gt, + &BooleanArray::from(vec![Some(false), None, Some(true)]), + ) + .unwrap(); + + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)])); + // array == scalar + apply_logic_op_arr_scalar( + &schema, + &(Arc::new(Float64Array::from(vec![Some(123.456), None, Some(123.457)])) + as ArrayRef), + &decimal_scalar, + Operator::Eq, + &BooleanArray::from(vec![Some(true), None, Some(false)]), + ) + .unwrap(); + + // array <= scalar + apply_logic_op_arr_scalar( + &schema, + &(Arc::new(Float64Array::from(vec![ + Some(123.456), + None, + Some(123.457), + Some(123.45), + ])) as ArrayRef), + &decimal_scalar, + Operator::LtEq, + &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]), + ) + .unwrap(); + // array >= scalar + apply_logic_op_arr_scalar( + &schema, + &(Arc::new(Float64Array::from(vec![ + Some(123.456), + None, + 
Some(123.457), + Some(123.45), + ])) as ArrayRef), + &decimal_scalar, + Operator::GtEq, + &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]), + ) + .unwrap(); + + // compare decimal array with other array type + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal(10, 0), true), + ])); + + let value: i64 = 123; + + let decimal_array = Arc::new(create_decimal_array( + &[ + Some(value as i128), + None, + Some((value - 1) as i128), + Some((value + 1) as i128), + ], + 10, + 0, + )?) as ArrayRef; + + let int64_array = Arc::new(Int64Array::from(vec![ + Some(value), + Some(value - 1), + Some(value), + Some(value + 1), + ])) as ArrayRef; + + // eq: int64array == decimal array + apply_logic_op( + &schema, + &int64_array, + &decimal_array, + Operator::Eq, + Arc::new(BooleanArray::from_iter(vec![ + Some(true), + None, + Some(false), + Some(true), + ])), + ) + .unwrap(); + // neq: int64array != decimal array + apply_logic_op( + &schema, + &int64_array, + &decimal_array, + Operator::NotEq, + Arc::new(BooleanArray::from_iter(vec![ + Some(false), + None, + Some(true), + Some(false), + ])), + ) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Decimal(10, 2), true), + ])); + + let value: i128 = 123; + let decimal_array = Arc::new(create_decimal_array( + &[ + Some(value as i128), // 1.23 + None, + Some((value - 1) as i128), // 1.22 + Some((value + 1) as i128), // 1.24 + ], + 10, + 2, + )?) as ArrayRef; + let float64_array = Arc::new(Float64Array::from(vec![ + Some(1.23), + Some(1.22), + Some(1.23), + Some(1.24), + ])) as ArrayRef; + // lt: float64array < decimal array + apply_logic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Lt, + Arc::new(BooleanArray::from_iter(vec![ + Some(false), + None, + Some(false), + Some(false), + ])), + ) + .unwrap(); + // lt_eq: float64array <= decimal array + apply_logic_op( + &schema, + &float64_array, + &decimal_array, + Operator::LtEq, + Arc::new(BooleanArray::from_iter(vec![ + Some(true), + None, + Some(false), + Some(true), + ])), + ) + .unwrap(); + // gt: float64array > decimal array + apply_logic_op( + &schema, + &float64_array, + &decimal_array, + Operator::Gt, + Arc::new(BooleanArray::from_iter(vec![ + Some(false), + None, + Some(true), + Some(false), + ])), + ) + .unwrap(); + apply_logic_op( + &schema, + &float64_array, + &decimal_array, + Operator::GtEq, + Arc::new(BooleanArray::from_iter(vec![ + Some(true), + None, + Some(true), + Some(true), + ])), + ) + .unwrap(); + // is distinct: float64array is distinct decimal array + // TODO: now we do not refactor the `is distinct or is not distinct` rule of coercion. 
+ // traced by https://github.com/apache/arrow-datafusion/issues/1590 + // the decimal array will be casted to float64array + apply_logic_op( + &schema, + &float64_array, + &decimal_array, + Operator::IsDistinctFrom, + Arc::new(BooleanArray::from_iter(vec![ + Some(false), + Some(true), + Some(true), + Some(false), + ])), + ) + .unwrap(); + // is not distinct + apply_logic_op( + &schema, + &float64_array, + &decimal_array, + Operator::IsNotDistinctFrom, + Arc::new(BooleanArray::from_iter(vec![ + Some(true), + Some(false), + Some(false), + Some(true), + ])), + ) + .unwrap(); + + Ok(()) + } + + #[test] + fn arithmetic_decimal_op_test() -> Result<()> { + let value_i128: i128 = 123; + let left_decimal_array = create_decimal_array( + &[ + Some(value_i128), + None, + Some(value_i128 - 1), + Some(value_i128 + 1), + ], + 25, + 3, + )?; + let right_decimal_array = create_decimal_array( + &[ + Some(value_i128), + Some(value_i128), + Some(value_i128), + Some(value_i128), + ], + 25, + 3, + )?; + // add + let result = add_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = + create_decimal_array(&[Some(246), None, Some(245), Some(247)], 25, 3)?; + assert_eq!(expect, result); + // subtract + let result = subtract_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)], 25, 3)?; + assert_eq!(expect, result); + // multiply + let result = multiply_decimal(&left_decimal_array, &right_decimal_array, 3)?; + let expect = create_decimal_array(&[Some(15), None, Some(15), Some(15)], 25, 3)?; + assert_eq!(expect, result); + // divide + let left_decimal_array = create_decimal_array( + &[Some(1234567), None, Some(1234567), Some(1234567)], + 25, + 3, + )?; + let right_decimal_array = + create_decimal_array(&[Some(10), Some(100), Some(55), Some(-123)], 25, 3)?; + let result = divide_decimal(&left_decimal_array, &right_decimal_array, 3)?; + let expect = create_decimal_array( + &[Some(123456700), None, Some(22446672), Some(-10037130)], + 25, + 3, + )?; + assert_eq!(expect, result); + // modulus + let result = modulus_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = create_decimal_array(&[Some(7), None, Some(37), Some(16)], 25, 3)?; + assert_eq!(expect, result); + + Ok(()) + } + + fn apply_arithmetic_op( + schema: &SchemaRef, + left: &ArrayRef, + right: &ArrayRef, + op: Operator, + expected: ArrayRef, + ) -> Result<()> { + let arithmetic_op = + binary_simple(col("a", schema)?, op, col("b", schema)?, schema); + let data: Vec = vec![left.clone(), right.clone()]; + let batch = RecordBatch::try_new(schema.clone(), data)?; + let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + + assert_eq!(result.as_ref(), expected.as_ref()); + Ok(()) + } + + #[test] + fn arithmetic_decimal_expr_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Decimal(10, 2), true), + ])); + let value: i128 = 123; + let decimal_array = Arc::new(create_decimal_array( + &[ + Some(value as i128), // 1.23 + None, + Some((value - 1) as i128), // 1.22 + Some((value + 1) as i128), // 1.24 + ], + 10, + 2, + )?) as ArrayRef; + let int32_array = Arc::new(Int32Array::from(vec![ + Some(123), + Some(122), + Some(123), + Some(124), + ])) as ArrayRef; + + // add: Int32array add decimal array + let expect = Arc::new(create_decimal_array( + &[Some(12423), None, Some(12422), Some(12524)], + 13, + 2, + )?) 
as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Plus, + expect, + ) + .unwrap(); + + // subtract: decimal array subtract int32 array + let schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int32, true), + Field::new("a", DataType::Decimal(10, 2), true), + ])); + let expect = Arc::new(create_decimal_array( + &[Some(-12177), None, Some(-12178), Some(-12276)], + 13, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Minus, + expect, + ) + .unwrap(); + + // multiply: decimal array multiply int32 array + let expect = Arc::new(create_decimal_array( + &[Some(15129), None, Some(15006), Some(15376)], + 21, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Multiply, + expect, + ) + .unwrap(); + // divide: int32 array divide decimal array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Decimal(10, 2), true), + ])); + let expect = Arc::new(create_decimal_array( + &[ + Some(10000000000000), + None, + Some(10081967213114), + Some(10000000000000), + ], + 23, + 11, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Divide, + expect, + ) + .unwrap(); + // modulus: int32 array modulus decimal array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Decimal(10, 2), true), + ])); + let expect = Arc::new(create_decimal_array( + &[Some(000), None, Some(100), Some(000)], + 10, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Modulo, + expect, + ) + .unwrap(); + + Ok(()) + } } diff --git a/datafusion/src/physical_plan/expressions/case.rs b/datafusion/src/physical_plan/expressions/case.rs index 25136e8cb853..196da07e11b6 100644 --- a/datafusion/src/physical_plan/expressions/case.rs +++ b/datafusion/src/physical_plan/expressions/case.rs @@ -17,15 +17,18 @@ use std::{any::Any, sync::Arc}; +use crate::record_batch::RecordBatch; use arrow::array::*; use arrow::compute::comparison; use arrow::compute::if_then_else; use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; use crate::error::{DataFusionError, Result}; +use crate::physical_plan::expressions::try_cast; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; +type WhenThen = (Arc, Arc); + /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. 
@@ -48,7 +51,7 @@ pub struct CaseExpr { /// Optional base expression that can be compared to literal values in the "when" expressions expr: Option>, /// One or more when/then expressions - when_then_expr: Vec<(Arc, Arc)>, + when_then_expr: Vec, /// Optional "else" expression else_expr: Option>, } @@ -73,7 +76,7 @@ impl CaseExpr { /// Create a new CASE WHEN expression pub fn try_new( expr: Option>, - when_then_expr: &[(Arc, Arc)], + when_then_expr: &[WhenThen], else_expr: Option>, ) -> Result { if when_then_expr.is_empty() { @@ -95,7 +98,7 @@ impl CaseExpr { } /// One or more when/then expressions - pub fn when_then_expr(&self) -> &[(Arc, Arc)] { + pub fn when_then_expr(&self) -> &[WhenThen] { &self.when_then_expr } @@ -121,7 +124,10 @@ impl CaseExpr { // start with the else condition, or nulls let mut current_value = if let Some(e) = &self.else_expr { - e.evaluate(batch)?.into_array(batch.num_rows()) + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(e.clone(), &*batch.schema(), return_type) + .unwrap_or_else(|_| e.clone()); + expr.evaluate(batch)?.into_array(batch.num_rows()) } else { new_null_array(return_type, batch.num_rows()).into() }; @@ -172,7 +178,9 @@ impl CaseExpr { // start with the else condition, or nulls let mut current_value = if let Some(e) = &self.else_expr { - e.evaluate(batch)?.into_array(batch.num_rows()) + let expr = try_cast(e.clone(), &*batch.schema(), return_type) + .unwrap_or_else(|_| e.clone()); + expr.evaluate(batch)?.into_array(batch.num_rows()) } else { new_null_array(return_type, batch.num_rows()).into() }; @@ -257,7 +265,7 @@ impl PhysicalExpr for CaseExpr { /// Create a CASE expression pub fn case( expr: Option>, - when_thens: &[(Arc, Arc)], + when_thens: &[WhenThen], else_expr: Option>, ) -> Result> { Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?)) @@ -266,6 +274,7 @@ pub fn case( #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::{ error::Result, logical_plan::Operator, @@ -407,6 +416,35 @@ mod tests { Ok(()) } + #[test] + fn case_with_type_cast() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123.3 ELSE 999 END + let when = binary( + col("a", schema)?, + Operator::Eq, + lit(ScalarValue::Utf8(Some("foo".to_string()))), + batch.schema(), + )?; + let then = lit(ScalarValue::Float64(Some(123.3))); + let else_value = lit(ScalarValue::Int32(Some(999))); + + let expr = case(None, &[(when, then)], Some(else_value))?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to Float64Array"); + + let expected = + &Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]); + + assert_eq!(expected, result); + + Ok(()) + } fn case_test_batch() -> Result { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let a = Utf8Array::::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); diff --git a/datafusion/src/physical_plan/expressions/cast.rs b/datafusion/src/physical_plan/expressions/cast.rs index 789ab582a7a0..2e4a9158eeca 100644 --- a/datafusion/src/physical_plan/expressions/cast.rs +++ b/datafusion/src/physical_plan/expressions/cast.rs @@ -22,12 +22,19 @@ use std::sync::Arc; use super::ColumnarValue; use crate::error::{DataFusionError, Result}; use crate::physical_plan::PhysicalExpr; +use crate::record_batch::RecordBatch; use crate::scalar::ScalarValue; use arrow::array::{Array, 
Int32Array}; use arrow::compute::cast; +use arrow::compute::cast::CastOptions; use arrow::compute::take; use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; + +/// provide Datafusion default cast options +pub const DEFAULT_DATAFUSION_CAST_OPTIONS: CastOptions = CastOptions { + wrapped: false, + partial: false, +}; /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast #[derive(Debug)] @@ -36,12 +43,22 @@ pub struct CastExpr { expr: Arc, /// The data type to cast to cast_type: DataType, + /// Cast options + cast_options: CastOptions, } impl CastExpr { /// Create a new CastExpr - pub fn new(expr: Arc, cast_type: DataType) -> Self { - Self { expr, cast_type } + pub fn new( + expr: Arc, + cast_type: DataType, + cast_options: CastOptions, + ) -> Self { + Self { + expr, + cast_type, + cast_options, + } } /// The expression to cast @@ -77,12 +94,16 @@ impl PhysicalExpr for CastExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; - cast_column(&value, &self.cast_type) + cast_column(&value, &self.cast_type, self.cast_options) } } -fn cast_with_error(array: &dyn Array, cast_type: &DataType) -> Result> { - let result = cast::cast(array, cast_type, cast::CastOptions::default())?; +pub fn cast_with_error( + array: &dyn Array, + cast_type: &DataType, + options: CastOptions, +) -> Result> { + let result = cast::cast(array, cast_type, options)?; if result.null_count() != array.null_count() { let casted_valids = result.validity().unwrap(); let failed_casts = match array.validity() { @@ -105,15 +126,20 @@ fn cast_with_error(array: &dyn Array, cast_type: &DataType) -> Result ColumnarValue for cast_type -pub fn cast_column(value: &ColumnarValue, cast_type: &DataType) -> Result { +pub fn cast_column( + value: &ColumnarValue, + cast_type: &DataType, + cast_options: CastOptions, +) -> Result { match value { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array( - cast_with_error(array.as_ref(), cast_type)?.into(), - )), + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::from( + cast_with_error(array.as_ref(), cast_type, cast_options)?, + ))), ColumnarValue::Scalar(scalar) => { let scalar_array = scalar.to_array(); - let cast_array = cast_with_error(scalar_array.as_ref(), cast_type)?.into(); - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; + let cast_array = + cast_with_error(scalar_array.as_ref(), cast_type, cast_options)?; + let cast_scalar = ScalarValue::try_from_array(&Arc::from(cast_array), 0)?; Ok(ColumnarValue::Scalar(cast_scalar)) } } @@ -127,12 +153,13 @@ pub fn cast_with_options( expr: Arc, input_schema: &Schema, cast_type: DataType, + cast_options: CastOptions, ) -> Result> { let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { Ok(expr.clone()) } else if cast::can_cast_types(&expr_type, &cast_type) { - Ok(Arc::new(CastExpr::new(expr, cast_type))) + Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { Err(DataFusionError::Internal(format!( "Unsupported CAST from {:?} to {:?}", @@ -150,18 +177,72 @@ pub fn cast( input_schema: &Schema, cast_type: DataType, ) -> Result> { - cast_with_options(expr, input_schema, cast_type) + cast_with_options( + expr, + input_schema, + cast_type, + DEFAULT_DATAFUSION_CAST_OPTIONS, + ) } #[cfg(test)] mod tests { use super::*; use crate::error::Result; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; use arrow::{array::*, datatypes::*}; type StringArray 
= Utf8Array; + // runs an end-to-end test of physical type cast + // 1. construct a record batch with a column "a" of type A + // 2. construct a physical expression of CAST(a AS B) + // 3. evaluate the expression + // 4. verify that the resulting expression is of type B + // 5. verify that the resulting values are downcastable and correct + macro_rules! generic_decimal_to_other_test_cast { + ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr,$CAST_OPTIONS:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $A_TYPE, false)]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new($DECIMAL_ARRAY)], + )?; + // verify that we can construct the expression + let expression = + cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; + + // verify that its display is correct + assert_eq!( + format!("CAST(a@0 AS {:?})", $TYPE), + format!("{}", expression) + ); + + // verify that the expression's type is correct + assert_eq!(expression.data_type(&schema)?, $TYPE); + + // compute + let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + + // verify that the array's data_type is correct + assert_eq!(*result.data_type(), $TYPE); + + // verify that the data itself is downcastable + let result = result + .as_any() + .downcast_ref::<$TYPEARRAY>() + .expect("failed to downcast"); + + // verify that the result itself is correct + for (i, x) in $VEC.iter().enumerate() { + match x { + Some(x) => assert_eq!(result.value(i), *x), + None => assert!(!result.is_valid(i)), + } + } + }}; + } + // runs an end-to-end test of physical type cast // 1. construct a record batch with a column "a" of type A // 2. construct a physical expression of CAST(a AS B) @@ -169,14 +250,15 @@ mod tests { // 4. verify that the resulting expression is of type B // 5. verify that the resulting values are downcastable and correct macro_rules! 
generic_test_cast { - ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{ + ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{ let schema = Schema::new(vec![Field::new("a", $A_TYPE, false)]); let a = $A_ARRAY::from_slice($A_VEC); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; // verify that we can construct the expression - let expression = cast_with_options(col("a", &schema)?, &schema, $TYPE)?; + let expression = + cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; // verify that its display is correct assert_eq!( @@ -212,6 +294,278 @@ mod tests { }}; } + #[test] + fn test_cast_decimal_to_decimal() -> Result<()> { + let array: Vec = vec![1234, 2222, 3, 4000, 5000]; + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 3), + Int128Array, + DataType::Decimal(20, 6), + vec![ + Some(1_234_000_i128), + Some(2_222_000_i128), + Some(3_000_i128), + Some(4_000_000_i128), + Some(5_000_000_i128), + None, + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 3), + Int128Array, + DataType::Decimal(10, 2), + vec![ + Some(123_i128), + Some(222_i128), + Some(0_i128), + Some(400_i128), + Some(500_i128), + None, + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + + Ok(()) + } + + #[test] + fn test_cast_decimal_to_numeric() -> Result<()> { + let array: Vec = vec![1, 2, 3, 4, 5]; + // decimal to i8 + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 0), + Int8Array, + DataType::Int8, + vec![ + Some(1_i8), + Some(2_i8), + Some(3_i8), + Some(4_i8), + Some(5_i8), + None, + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + // decimal to i16 + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 0), + Int16Array, + DataType::Int16, + vec![ + Some(1_i16), + Some(2_i16), + Some(3_i16), + Some(4_i16), + Some(5_i16), + None, + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + // decimal to i32 + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 0), + Int32Array, + DataType::Int32, + vec![ + Some(1_i32), + Some(2_i32), + Some(3_i32), + Some(4_i32), + Some(5_i32), + None, + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + // decimal to i64 + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 0), + Int64Array, + DataType::Int64, + vec![ + Some(1_i64), + Some(2_i64), + Some(3_i64), + Some(4_i64), + Some(5_i64), + None, + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + // decimal to float32 + let array: Vec = vec![1234, 2222, 3, 4000, 5000]; + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 3), + Float32Array, + DataType::Float32, + vec![ + Some(1.234_f32), + Some(2.222_f32), + Some(0.003_f32), + Some(4.0_f32), + Some(5.0_f32), + None, + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + // decimal to float64 + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(20, 6), + Float64Array, + DataType::Float64, + vec![ + 
Some(0.001234_f64), + Some(0.002222_f64), + Some(0.000003_f64), + Some(0.004_f64), + Some(0.005_f64), + None, + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + Ok(()) + } + + #[test] + fn test_cast_numeric_to_decimal() -> Result<()> { + // int8 + generic_test_cast!( + Int8Array, + DataType::Int8, + vec![1, 2, 3, 4, 5], + Int128Array, + DataType::Decimal(3, 0), + vec![ + Some(1_i128), + Some(2_i128), + Some(3_i128), + Some(4_i128), + Some(5_i128), + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + + // int16 + generic_test_cast!( + Int16Array, + DataType::Int16, + vec![1, 2, 3, 4, 5], + Int128Array, + DataType::Decimal(5, 0), + vec![ + Some(1_i128), + Some(2_i128), + Some(3_i128), + Some(4_i128), + Some(5_i128), + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + + // int32 + generic_test_cast!( + Int32Array, + DataType::Int32, + vec![1, 2, 3, 4, 5], + Int128Array, + DataType::Decimal(10, 0), + vec![ + Some(1_i128), + Some(2_i128), + Some(3_i128), + Some(4_i128), + Some(5_i128), + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + + // int64 + generic_test_cast!( + Int64Array, + DataType::Int64, + vec![1, 2, 3, 4, 5], + Int128Array, + DataType::Decimal(20, 0), + vec![ + Some(1_i128), + Some(2_i128), + Some(3_i128), + Some(4_i128), + Some(5_i128), + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + + // int64 to different scale + generic_test_cast!( + Int64Array, + DataType::Int64, + vec![1, 2, 3, 4, 5], + Int128Array, + DataType::Decimal(20, 2), + vec![ + Some(100_i128), + Some(200_i128), + Some(300_i128), + Some(400_i128), + Some(500_i128), + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + + // float32 + generic_test_cast!( + Float32Array, + DataType::Float32, + vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], + Int128Array, + DataType::Decimal(10, 2), + vec![ + Some(150_i128), + Some(250_i128), + Some(300_i128), + Some(112_i128), + Some(550_i128), + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + + // float64 + generic_test_cast!( + Float64Array, + DataType::Float64, + vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], + Int128Array, + DataType::Decimal(20, 4), + vec![ + Some(15000_i128), + Some(25000_i128), + Some(30000_i128), + Some(11234_i128), + Some(55000_i128), + ], + DEFAULT_DATAFUSION_CAST_OPTIONS + ); + Ok(()) + } + #[test] fn test_cast_i32_u32() -> Result<()> { generic_test_cast!( @@ -226,7 +580,8 @@ mod tests { Some(3_u32), Some(4_u32), Some(5_u32) - ] + ], + DEFAULT_DATAFUSION_CAST_OPTIONS ); Ok(()) } @@ -239,7 +594,8 @@ mod tests { &[1, 2, 3, 4, 5], StringArray, DataType::Utf8, - vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")] + vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")], + DEFAULT_DATAFUSION_CAST_OPTIONS ); Ok(()) } @@ -255,7 +611,8 @@ mod tests { original, Int64Array, DataType::Timestamp(TimeUnit::Nanosecond, None), - expected + expected, + DEFAULT_DATAFUSION_CAST_OPTIONS ); Ok(()) } @@ -272,7 +629,9 @@ mod tests { #[test] fn invalid_str_cast() { let arr = Utf8Array::::from_slice(&["a", "b", "123", "!", "456"]); - let err = cast_with_error(&arr, &DataType::Int64).unwrap_err(); + let err = + cast_with_error(&arr, &DataType::Int64, DEFAULT_DATAFUSION_CAST_OPTIONS) + .unwrap_err(); assert_eq!( err.to_string(), "Execution error: Could not cast Utf8[a, b, !] 
to value of type Int64" diff --git a/datafusion/src/physical_plan/expressions/coercion.rs b/datafusion/src/physical_plan/expressions/coercion.rs deleted file mode 100644 index a04f11f263cd..000000000000 --- a/datafusion/src/physical_plan/expressions/coercion.rs +++ /dev/null @@ -1,247 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Coercion rules used to coerce types to match existing expressions' implementations - -use arrow::datatypes::DataType; - -/// Determine if a DataType is signed numeric or not -pub fn is_signed_numeric(dt: &DataType) -> bool { - matches!( - dt, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float16 - | DataType::Float32 - | DataType::Float64 - ) -} - -/// Determine if a DataType is numeric or not -pub fn is_numeric(dt: &DataType) -> bool { - is_signed_numeric(dt) - || match dt { - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - true - } - _ => false, - } -} - -/// Coercion rules for dictionary values (aka the type of the dictionary itself) -fn dictionary_value_coercion( - lhs_type: &DataType, - rhs_type: &DataType, -) -> Option { - numerical_coercion(lhs_type, rhs_type).or_else(|| string_coercion(lhs_type, rhs_type)) -} - -/// Coercion rules for Dictionaries: the type that both lhs and rhs -/// can be casted to for the purpose of a computation. -/// -/// It would likely be preferable to cast primitive values to -/// dictionaries, and thus avoid unpacking dictionary as well as doing -/// faster comparisons. However, the arrow compute kernels (e.g. eq) -/// don't have DictionaryArray support yet, so fall back to unpacking -/// the dictionaries -pub fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - match (lhs_type, rhs_type) { - ( - DataType::Dictionary(_lhs_index_type, lhs_value_type, _), - DataType::Dictionary(_rhs_index_type, rhs_value_type, _), - ) => dictionary_value_coercion(lhs_value_type, rhs_value_type), - (DataType::Dictionary(_index_type, value_type, _), _) => { - dictionary_value_coercion(value_type, rhs_type) - } - (_, DataType::Dictionary(_index_type, value_type, _)) => { - dictionary_value_coercion(lhs_type, value_type) - } - _ => None, - } -} - -/// Coercion rules for Strings: the type that both lhs and rhs can be -/// casted to for the purpose of a string computation -pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - (Utf8, Utf8) => Some(Utf8), - (LargeUtf8, Utf8) => Some(LargeUtf8), - (Utf8, LargeUtf8) => Some(LargeUtf8), - (LargeUtf8, LargeUtf8) => Some(LargeUtf8), - _ => None, - } -} - -/// coercion rules for like operations. 
-/// This is a union of string coercion rules and dictionary coercion rules -pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - string_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_coercion(lhs_type, rhs_type)) -} - -/// Coercion rules for Temporal columns: the type that both lhs and rhs can be -/// casted to for the purpose of a date computation -pub fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - use arrow::datatypes::DataType::*; - use arrow::datatypes::TimeUnit; - match (lhs_type, rhs_type) { - (Utf8, Date32) => Some(Date32), - (Date32, Utf8) => Some(Date32), - (Utf8, Date64) => Some(Date64), - (Date64, Utf8) => Some(Date64), - (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => { - let tz = match (lhs_tz, rhs_tz) { - // can't cast across timezones - (Some(lhs_tz), Some(rhs_tz)) => { - if lhs_tz != rhs_tz { - return None; - } else { - Some(lhs_tz.clone()) - } - } - (Some(lhs_tz), None) => Some(lhs_tz.clone()), - (None, Some(rhs_tz)) => Some(rhs_tz.clone()), - (None, None) => None, - }; - - let unit = match (lhs_unit, rhs_unit) { - (TimeUnit::Second, TimeUnit::Millisecond) => TimeUnit::Second, - (TimeUnit::Second, TimeUnit::Microsecond) => TimeUnit::Second, - (TimeUnit::Second, TimeUnit::Nanosecond) => TimeUnit::Second, - (TimeUnit::Millisecond, TimeUnit::Second) => TimeUnit::Second, - (TimeUnit::Millisecond, TimeUnit::Microsecond) => TimeUnit::Millisecond, - (TimeUnit::Millisecond, TimeUnit::Nanosecond) => TimeUnit::Millisecond, - (TimeUnit::Microsecond, TimeUnit::Second) => TimeUnit::Second, - (TimeUnit::Microsecond, TimeUnit::Millisecond) => TimeUnit::Millisecond, - (TimeUnit::Microsecond, TimeUnit::Nanosecond) => TimeUnit::Microsecond, - (TimeUnit::Nanosecond, TimeUnit::Second) => TimeUnit::Second, - (TimeUnit::Nanosecond, TimeUnit::Millisecond) => TimeUnit::Millisecond, - (TimeUnit::Nanosecond, TimeUnit::Microsecond) => TimeUnit::Microsecond, - (l, r) => { - assert_eq!(l, r); - *l - } - }; - - Some(Timestamp(unit, tz)) - } - _ => None, - } -} - -/// Coercion rule for numerical types: The type that both lhs and rhs -/// can be casted to for numerical calculation, while maintaining -/// maximum precision -pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - use arrow::datatypes::DataType::*; - - // error on any non-numeric type - if !is_numeric(lhs_type) || !is_numeric(rhs_type) { - return None; - }; - - // same type => all good - if lhs_type == rhs_type { - return Some(lhs_type.clone()); - } - - // these are ordered from most informative to least informative so - // that the coercion removes the least amount of information - match (lhs_type, rhs_type) { - (Float64, _) | (_, Float64) => Some(Float64), - (_, Float32) | (Float32, _) => Some(Float32), - (Int64, _) | (_, Int64) => Some(Int64), - (Int32, _) | (_, Int32) => Some(Int32), - (Int16, _) | (_, Int16) => Some(Int16), - (Int8, _) | (_, Int8) => Some(Int8), - (UInt64, _) | (_, UInt64) => Some(UInt64), - (UInt32, _) | (_, UInt32) => Some(UInt32), - (UInt16, _) | (_, UInt16) => Some(UInt16), - (UInt8, _) | (_, UInt8) => Some(UInt8), - _ => None, - } -} - -// coercion rules for equality operations. This is a superset of all numerical coercion rules. 
-pub fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - if lhs_type == rhs_type { - // same type => equality is possible - return Some(lhs_type.clone()); - } - numerical_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_coercion(lhs_type, rhs_type)) - .or_else(|| temporal_coercion(lhs_type, rhs_type)) -} - -// coercion rules that assume an ordered set, such as "less than". -// These are the union of all numerical coercion rules and all string coercion rules -pub fn order_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - if lhs_type == rhs_type { - // same type => all good - return Some(lhs_type.clone()); - } - - numerical_coercion(lhs_type, rhs_type) - .or_else(|| string_coercion(lhs_type, rhs_type)) - .or_else(|| dictionary_coercion(lhs_type, rhs_type)) - .or_else(|| temporal_coercion(lhs_type, rhs_type)) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_dictionary_type_coersion() { - use arrow::datatypes::IntegerType; - - // TODO: In the future, this would ideally return Dictionary types and avoid unpacking - let lhs_type = - DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Int32), false); - let rhs_type = - DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Int16), false); - assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type), - Some(DataType::Int32) - ); - - let lhs_type = - DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8), false); - let rhs_type = - DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Int16), false); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None); - - let lhs_type = - DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8), false); - let rhs_type = DataType::Utf8; - assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type), - Some(DataType::Utf8) - ); - - let lhs_type = DataType::Utf8; - let rhs_type = - DataType::Dictionary(IntegerType::Int8, Box::new(DataType::Utf8), false); - assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type), - Some(DataType::Utf8) - ); - } -} diff --git a/datafusion/src/physical_plan/expressions/column.rs b/datafusion/src/physical_plan/expressions/column.rs index d6eafbb05384..81f7a1d75ab4 100644 --- a/datafusion/src/physical_plan/expressions/column.rs +++ b/datafusion/src/physical_plan/expressions/column.rs @@ -19,12 +19,11 @@ use std::sync::Arc; -use arrow::{ - datatypes::{DataType, Schema}, - record_batch::RecordBatch, -}; +use crate::record_batch::RecordBatch; +use arrow::datatypes::{DataType, Schema}; use crate::error::Result; +use crate::field_util::{FieldExt, SchemaExt}; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; /// Represents the column at a given index in a RecordBatch diff --git a/datafusion/src/physical_plan/expressions/correlation.rs b/datafusion/src/physical_plan/expressions/correlation.rs new file mode 100644 index 000000000000..9e973b193974 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/correlation.rs @@ -0,0 +1,544 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{ + expressions::{covariance::CovarianceAccumulator, stddev::StddevAccumulator}, + Accumulator, AggregateExpr, PhysicalExpr, +}; +use crate::scalar::ScalarValue; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; + +use super::{format_state_name, StatsType}; + +/// CORR aggregate expression +#[derive(Debug)] +pub struct Correlation { + name: String, + expr1: Arc, + expr2: Arc, +} + +/// function return type of correlation +pub(crate) fn correlation_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "CORR does not support {:?}", + other + ))), + } +} + +pub(crate) fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl Correlation { + /// Create a new COVAR_POP aggregate function + pub fn new( + expr1: Arc, + expr2: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of correlation just support FLOAT64 data type. 
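+        // (CorrelationAccumulator divides a population covariance by the product of the two
+        // population standard deviations, so the result is always a Float64 — see `evaluate` below)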
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr1, + expr2, + } + } +} + +impl AggregateExpr for Correlation { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(CorrelationAccumulator::try_new()?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean1"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2_1"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean2"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2_2"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr1.clone(), self.expr2.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +/// An accumulator to compute correlation +#[derive(Debug)] +pub struct CorrelationAccumulator { + covar: CovarianceAccumulator, + stddev1: StddevAccumulator, + stddev2: StddevAccumulator, +} + +impl CorrelationAccumulator { + /// Creates a new `CorrelationAccumulator` + pub fn try_new() -> Result { + Ok(Self { + covar: CovarianceAccumulator::try_new(StatsType::Population)?, + stddev1: StddevAccumulator::try_new(StatsType::Population)?, + stddev2: StddevAccumulator::try_new(StatsType::Population)?, + }) + } +} + +impl Accumulator for CorrelationAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.covar.get_count()), + ScalarValue::from(self.covar.get_mean1()), + ScalarValue::from(self.stddev1.get_m2()), + ScalarValue::from(self.covar.get_mean2()), + ScalarValue::from(self.stddev2.get_m2()), + ScalarValue::from(self.covar.get_algo_const()), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.covar.update_batch(values)?; + self.stddev1.update_batch(&[values[0].clone()])?; + self.stddev2.update_batch(&[values[1].clone()])?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let states_c = [ + states[0].clone(), + states[1].clone(), + states[3].clone(), + states[5].clone(), + ]; + let states_s1 = [states[0].clone(), states[1].clone(), states[2].clone()]; + let states_s2 = [states[0].clone(), states[3].clone(), states[4].clone()]; + + self.covar.merge_batch(&states_c)?; + self.stddev1.merge_batch(&states_s1)?; + self.stddev2.merge_batch(&states_s2)?; + Ok(()) + } + + fn evaluate(&self) -> Result { + let covar = self.covar.evaluate()?; + let stddev1 = self.stddev1.evaluate()?; + let stddev2 = self.stddev2.evaluate()?; + + if let ScalarValue::Float64(Some(c)) = covar { + if let ScalarValue::Float64(Some(s1)) = stddev1 { + if let ScalarValue::Float64(Some(s2)) = stddev2 { + if s1 == 0_f64 || s2 == 0_f64 { + return Ok(ScalarValue::Float64(Some(0_f64))); + } else { + return Ok(ScalarValue::Float64(Some(c / s1 / s2))); + } + } + } + } + + Ok(ScalarValue::Float64(None)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field_util::SchemaExt; + use crate::physical_plan::expressions::col; + use crate::record_batch::RecordBatch; + use crate::{error::Result, generic_test_op2}; + use arrow::{array::*, datatypes::*}; + + #[test] 
+ fn correlation_f64_1() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64, 3_f64])); + let b: ArrayRef = Arc::new(Float64Array::from_slice(&[4_f64, 5_f64, 7_f64])); + + generic_test_op2!( + a, + b, + DataType::Float64, + DataType::Float64, + Correlation, + ScalarValue::from(0.9819805060619659), + DataType::Float64 + ) + } + + #[test] + fn correlation_f64_2() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64, 3_f64])); + let b: ArrayRef = Arc::new(Float64Array::from_slice(&[4_f64, -5_f64, 6_f64])); + + generic_test_op2!( + a, + b, + DataType::Float64, + DataType::Float64, + Correlation, + ScalarValue::from(0.17066403719657236), + DataType::Float64 + ) + } + + #[test] + fn correlation_f64_4() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); + let b: ArrayRef = Arc::new(Float64Array::from_slice(&[4.1_f64, 5_f64, 6_f64])); + + generic_test_op2!( + a, + b, + DataType::Float64, + DataType::Float64, + Correlation, + ScalarValue::from(1_f64), + DataType::Float64 + ) + } + + #[test] + fn correlation_f64_6() -> Result<()> { + let a = Arc::new(Float64Array::from_slice(vec![ + 1_f64, 2_f64, 3_f64, 1.1_f64, 2.2_f64, 3.3_f64, + ])); + let b = Arc::new(Float64Array::from_slice(vec![ + 4_f64, 5_f64, 6_f64, 4.4_f64, 5.5_f64, 6.6_f64, + ])); + + generic_test_op2!( + a, + b, + DataType::Float64, + DataType::Float64, + Correlation, + ScalarValue::from(0.9860135594710389), + DataType::Float64 + ) + } + + #[test] + fn correlation_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3])); + let b: ArrayRef = Arc::new(Int32Array::from_slice(&[4, 5, 6])); + + generic_test_op2!( + a, + b, + DataType::Int32, + DataType::Int32, + Correlation, + ScalarValue::from(1_f64), + DataType::Float64 + ) + } + + #[test] + fn correlation_u32() -> Result<()> { + let a: ArrayRef = Arc::new(UInt32Array::from_slice(&[1_u32, 2_u32, 3_u32])); + let b: ArrayRef = Arc::new(UInt32Array::from_slice(&[4_u32, 5_u32, 6_u32])); + generic_test_op2!( + a, + b, + DataType::UInt32, + DataType::UInt32, + Correlation, + ScalarValue::from(1_f64), + DataType::Float64 + ) + } + + #[test] + fn correlation_f32() -> Result<()> { + let a: ArrayRef = Arc::new(Float32Array::from_slice(&[1_f32, 2_f32, 3_f32])); + let b: ArrayRef = Arc::new(Float32Array::from_slice(&[4_f32, 5_f32, 6_f32])); + generic_test_op2!( + a, + b, + DataType::Float32, + DataType::Float32, + Correlation, + ScalarValue::from(1_f64), + DataType::Float64 + ) + } + + #[test] + fn test_correlation_return_data_type() -> Result<()> { + let data_type = DataType::Float64; + let result_type = correlation_return_type(&data_type)?; + assert_eq!(DataType::Float64, result_type); + + let data_type = DataType::Decimal(36, 10); + assert!(correlation_return_type(&data_type).is_err()); + Ok(()) + } + + #[test] + fn correlation_i32_with_nulls_1() -> Result<()> { + let a: ArrayRef = + Arc::new(Int32Array::from_iter(vec![Some(1), None, Some(3), Some(3)])); + let b: ArrayRef = + Arc::new(Int32Array::from_iter(vec![Some(4), None, Some(6), Some(3)])); + + generic_test_op2!( + a, + b, + DataType::Int32, + DataType::Int32, + Correlation, + ScalarValue::from(0.1889822365046137), + DataType::Float64 + ) + } + + #[test] + fn correlation_i32_with_nulls_2() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![Some(1), None, Some(3)])); + let b: ArrayRef = + Arc::new(Int32Array::from_iter(vec![Some(4), Some(5), Some(6)])); + + let schema = 
Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; + + let agg = Arc::new(Correlation::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + #[test] + fn correlation_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from_iter(vec![None, None])); + let b: ArrayRef = Arc::new(Int32Array::from_iter(vec![None, None])); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; + + let agg = Arc::new(Correlation::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + #[test] + fn correlation_f64_merge_1() -> Result<()> { + let a = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64, 3_f64])); + let b = Arc::new(Float64Array::from_slice(&[4_f64, 5_f64, 6_f64])); + let c = Arc::new(Float64Array::from_slice(&[1.1_f64, 2.2_f64, 3.3_f64])); + let d = Arc::new(Float64Array::from_slice(&[4.4_f64, 5.5_f64, 9.9_f64])); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Float64, false), + Field::new("b", DataType::Float64, false), + ]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; + + let agg1 = Arc::new(Correlation::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let agg2 = Arc::new(Correlation::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let actual = merge(&batch1, &batch2, agg1, agg2)?; + assert!(actual == ScalarValue::from(0.8443707186481967)); + + Ok(()) + } + + #[test] + fn correlation_f64_merge_2() -> Result<()> { + let a = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64, 3_f64])); + let b = Arc::new(Float64Array::from_slice(&[4_f64, 5_f64, 6_f64])); + let c = Arc::new(Float64Array::from_iter(vec![None])); + let d = Arc::new(Float64Array::from_iter(vec![None])); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Float64, false), + Field::new("b", DataType::Float64, false), + ]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; + + let agg1 = Arc::new(Correlation::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let agg2 = Arc::new(Correlation::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let actual = merge(&batch1, &batch2, agg1, agg2)?; + assert!(actual == ScalarValue::from(1_f64)); + + Ok(()) + } + + fn aggregate( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() + } + + fn merge( + batch1: &RecordBatch, + batch2: &RecordBatch, + agg1: Arc, + agg2: Arc, + ) -> Result { + let mut accum1 = agg1.create_accumulator()?; + let mut accum2 = 
agg2.create_accumulator()?; + let expr1 = agg1.expressions(); + let expr2 = agg2.expressions(); + + let values1 = expr1 + .iter() + .map(|e| e.evaluate(batch1)) + .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .collect::>>()?; + let values2 = expr2 + .iter() + .map(|e| e.evaluate(batch2)) + .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .collect::>>()?; + accum1.update_batch(&values1)?; + accum2.update_batch(&values2)?; + let state2 = accum2 + .state()? + .iter() + .map(|v| vec![v.clone()]) + .map(|x| ScalarValue::iter_to_array(x).unwrap()) + .collect::>(); + accum1.merge_batch(&state2)?; + accum1.evaluate() + } +} diff --git a/datafusion/src/physical_plan/expressions/count.rs b/datafusion/src/physical_plan/expressions/count.rs index 255e1767376e..0a5638705096 100644 --- a/datafusion/src/physical_plan/expressions/count.rs +++ b/datafusion/src/physical_plan/expressions/count.rs @@ -113,24 +113,6 @@ impl Accumulator for CountAccumulator { Ok(()) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - let value = &values[0]; - if !value.is_null() { - self.count += 1; - } - Ok(()) - } - - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - let count = &states[0]; - if let ScalarValue::UInt64(Some(delta)) = count { - self.count += *delta; - } else { - unreachable!() - } - Ok(()) - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let counts = states[0].as_any().downcast_ref::().unwrap(); let delta = &compute::aggregate::sum_primitive(counts); @@ -152,10 +134,11 @@ impl Accumulator for CountAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; use crate::physical_plan::expressions::tests::aggregate; + use crate::record_batch::RecordBatch; use crate::{error::Result, generic_test_op}; - use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; #[test] diff --git a/datafusion/src/physical_plan/expressions/covariance.rs b/datafusion/src/physical_plan/expressions/covariance.rs new file mode 100644 index 000000000000..d89d5736129b --- /dev/null +++ b/datafusion/src/physical_plan/expressions/covariance.rs @@ -0,0 +1,725 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::expressions::cast::{ + cast_with_error, DEFAULT_DATAFUSION_CAST_OPTIONS, +}; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::scalar::ScalarValue; +use arrow::array::Float64Array; +use arrow::{ + array::{ArrayRef, UInt64Array}, + datatypes::DataType, + datatypes::Field, +}; + +use super::{format_state_name, StatsType}; + +/// COVAR and COVAR_SAMP aggregate expression +#[derive(Debug)] +pub struct Covariance { + name: String, + expr1: Arc, + expr2: Arc, +} + +/// COVAR_POP aggregate expression +#[derive(Debug)] +pub struct CovariancePop { + name: String, + expr1: Arc, + expr2: Arc, +} + +/// function return type of covariance +pub(crate) fn covariance_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "COVAR does not support {:?}", + other + ))), + } +} + +pub(crate) fn is_covariance_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl Covariance { + /// Create a new COVAR aggregate function + pub fn new( + expr1: Arc, + expr2: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of covariance just support FLOAT64 data type. + assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr1, + expr2, + } + } +} + +impl AggregateExpr for Covariance { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean1"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean2"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr1.clone(), self.expr2.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl CovariancePop { + /// Create a new COVAR_POP aggregate function + pub fn new( + expr1: Arc, + expr2: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of covariance just support FLOAT64 data type. 
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr1, + expr2, + } + } +} + +impl AggregateExpr for CovariancePop { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(CovarianceAccumulator::try_new( + StatsType::Population, + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean1"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean2"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr1.clone(), self.expr2.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +/// An accumulator to compute covariance +/// The algrithm used is an online implementation and numerically stable. It is derived from the following paper +/// for calculating variance: +/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". +/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. +/// +/// The algorithm has been analyzed here: +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. +/// +/// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online, +/// parallelizable and numerically stable. 
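To make the doc comment above concrete, the following is a minimal standalone sketch of the same online/parallel covariance recurrences, written against plain `f64` pairs rather than Arrow arrays; the `OnlineCov` type and its method names are illustrative only and are not part of the accumulator introduced in this diff.

```rust
// Sketch of the online covariance recurrences described above (illustrative only).
// `c` accumulates the co-moment sum((x - mean_x) * (y - mean_y)).
#[derive(Debug, Default, Clone, Copy)]
struct OnlineCov {
    count: u64,
    mean1: f64,
    mean2: f64,
    c: f64,
}

impl OnlineCov {
    /// Welford-style single-pass update for one (x, y) pair.
    fn update(&mut self, x: f64, y: f64) {
        self.count += 1;
        let delta1 = x - self.mean1;
        self.mean1 += delta1 / self.count as f64;
        let delta2 = y - self.mean2;
        self.mean2 += delta2 / self.count as f64;
        // Uses the already-updated mean2, matching the accumulator's new_c.
        self.c += delta1 * (y - self.mean2);
    }

    /// Combine two partial states (the parallel merge used by merge_batch).
    fn merge(self, other: OnlineCov) -> OnlineCov {
        if other.count == 0 {
            return self;
        }
        if self.count == 0 {
            return other;
        }
        let (na, nb) = (self.count as f64, other.count as f64);
        let n = na + nb;
        let d1 = self.mean1 - other.mean1;
        let d2 = self.mean2 - other.mean2;
        OnlineCov {
            count: self.count + other.count,
            mean1: (self.mean1 * na + other.mean1 * nb) / n,
            mean2: (self.mean2 * na + other.mean2 * nb) / n,
            c: self.c + other.c + d1 * d2 * na * nb / n,
        }
    }

    /// Population covariance; divide by count - 1 instead for the sample estimate.
    fn covar_pop(&self) -> Option<f64> {
        (self.count > 0).then(|| self.c / self.count as f64)
    }
}
```

Splitting the input across two such states, merging them, and reading `covar_pop()` gives the same result as one sequential pass over all rows, which is the property the `covariance_f64_merge_*` tests later in this file exercise.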
+ +#[derive(Debug)] +pub struct CovarianceAccumulator { + algo_const: f64, + mean1: f64, + mean2: f64, + count: u64, + stats_type: StatsType, +} + +impl CovarianceAccumulator { + /// Creates a new `CovarianceAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + algo_const: 0_f64, + mean1: 0_f64, + mean2: 0_f64, + count: 0_u64, + stats_type: s_type, + }) + } + + pub fn get_count(&self) -> u64 { + self.count + } + + pub fn get_mean1(&self) -> f64 { + self.mean1 + } + + pub fn get_mean2(&self) -> f64 { + self.mean2 + } + + pub fn get_algo_const(&self) -> f64 { + self.algo_const + } +} + +impl Accumulator for CovarianceAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean1), + ScalarValue::from(self.mean2), + ScalarValue::from(self.algo_const), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values1 = &cast_with_error( + values[0].as_ref(), + &DataType::Float64, + DEFAULT_DATAFUSION_CAST_OPTIONS, + )?; + let values2 = &cast_with_error( + values[1].as_ref(), + &DataType::Float64, + DEFAULT_DATAFUSION_CAST_OPTIONS, + )?; + + let mut arr1 = values1 + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .flatten(); + let mut arr2 = values2 + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .flatten(); + + for _i in 0..values1.len() { + let value1 = arr1.next(); + let value2 = arr2.next(); + + if value1 == None || value2 == None { + if value1 == None && value2 == None { + continue; + } else { + return Err(DataFusionError::Internal( + "The two columns are not aligned".to_string(), + )); + } + } + + let new_count = self.count + 1; + let delta1 = value1.unwrap() - self.mean1; + let new_mean1 = delta1 / new_count as f64 + self.mean1; + let delta2 = value2.unwrap() - self.mean2; + let new_mean2 = delta2 / new_count as f64 + self.mean2; + let new_c = delta1 * (value2.unwrap() - new_mean2) + self.algo_const; + + self.count += 1; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = states[0].as_any().downcast_ref::().unwrap(); + let means1 = states[1].as_any().downcast_ref::().unwrap(); + let means2 = states[2].as_any().downcast_ref::().unwrap(); + let cs = states[3].as_any().downcast_ref::().unwrap(); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0_u64 { + continue; + } + let new_count = self.count + c; + let new_mean1 = self.mean1 * self.count as f64 / new_count as f64 + + means1.value(i) * c as f64 / new_count as f64; + let new_mean2 = self.mean2 * self.count as f64 / new_count as f64 + + means2.value(i) * c as f64 / new_count as f64; + let delta1 = self.mean1 - means1.value(i); + let delta2 = self.mean2 - means2.value(i); + let new_c = self.algo_const + + cs.value(i) + + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64; + + self.count = new_count; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + Ok(()) + } + + fn evaluate(&self) -> Result { + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample => { + if self.count > 0 { + self.count - 1 + } else { + self.count + } + } + }; + + if count <= 1 { + return Err(DataFusionError::Internal( + "At least two values are needed to calculate covariance".to_string(), + )); + } + + if self.count == 0 { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(Some(self.algo_const / count as 
f64))) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field_util::SchemaExt; + use crate::physical_plan::expressions::col; + use crate::record_batch::RecordBatch; + use crate::{error::Result, generic_test_op2}; + use arrow::{array::*, datatypes::*}; + + #[test] + fn covariance_f64_1() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64, 3_f64])); + let b: ArrayRef = Arc::new(Float64Array::from_slice(&[4_f64, 5_f64, 6_f64])); + + generic_test_op2!( + a, + b, + DataType::Float64, + DataType::Float64, + CovariancePop, + ScalarValue::from(0.6666666666666666), + DataType::Float64 + ) + } + + #[test] + fn covariance_f64_2() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64, 3_f64])); + let b: ArrayRef = Arc::new(Float64Array::from_slice(&[4_f64, 5_f64, 6_f64])); + + generic_test_op2!( + a, + b, + DataType::Float64, + DataType::Float64, + Covariance, + ScalarValue::from(1_f64), + DataType::Float64 + ) + } + + #[test] + fn covariance_f64_4() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); + let b: ArrayRef = Arc::new(Float64Array::from_slice(&[4.1_f64, 5_f64, 6_f64])); + + generic_test_op2!( + a, + b, + DataType::Float64, + DataType::Float64, + Covariance, + ScalarValue::from(0.9033333333333335_f64), + DataType::Float64 + ) + } + + #[test] + fn covariance_f64_5() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); + let b: ArrayRef = Arc::new(Float64Array::from_slice(&[4.1_f64, 5_f64, 6_f64])); + + generic_test_op2!( + a, + b, + DataType::Float64, + DataType::Float64, + CovariancePop, + ScalarValue::from(0.6022222222222223_f64), + DataType::Float64 + ) + } + + #[test] + fn covariance_f64_6() -> Result<()> { + let a = Arc::new(Float64Array::from_slice(vec![ + 1_f64, 2_f64, 3_f64, 1.1_f64, 2.2_f64, 3.3_f64, + ])); + let b = Arc::new(Float64Array::from_slice(vec![ + 4_f64, 5_f64, 6_f64, 4.4_f64, 5.5_f64, 6.6_f64, + ])); + + generic_test_op2!( + a, + b, + DataType::Float64, + DataType::Float64, + CovariancePop, + ScalarValue::from(0.7616666666666666), + DataType::Float64 + ) + } + + #[test] + fn covariance_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3])); + let b: ArrayRef = Arc::new(Int32Array::from_slice(&[4, 5, 6])); + + generic_test_op2!( + a, + b, + DataType::Int32, + DataType::Int32, + CovariancePop, + ScalarValue::from(0.6666666666666666_f64), + DataType::Float64 + ) + } + + #[test] + fn covariance_u32() -> Result<()> { + let a: ArrayRef = Arc::new(UInt32Array::from_slice(&[1_u32, 2_u32, 3_u32])); + let b: ArrayRef = Arc::new(UInt32Array::from_slice(&[4_u32, 5_u32, 6_u32])); + generic_test_op2!( + a, + b, + DataType::UInt32, + DataType::UInt32, + CovariancePop, + ScalarValue::from(0.6666666666666666_f64), + DataType::Float64 + ) + } + + #[test] + fn covariance_f32() -> Result<()> { + let a: ArrayRef = Arc::new(Float32Array::from_slice(&[1_f32, 2_f32, 3_f32])); + let b: ArrayRef = Arc::new(Float32Array::from_slice(&[4_f32, 5_f32, 6_f32])); + generic_test_op2!( + a, + b, + DataType::Float32, + DataType::Float32, + CovariancePop, + ScalarValue::from(0.6666666666666666_f64), + DataType::Float64 + ) + } + + #[test] + fn test_covariance_return_data_type() -> Result<()> { + let data_type = DataType::Float64; + let result_type = covariance_return_type(&data_type)?; + assert_eq!(DataType::Float64, result_type); + + let data_type = DataType::Decimal(36, 10); + 
assert!(covariance_return_type(&data_type).is_err()); + Ok(()) + } + + #[test] + fn covariance_i32_with_nulls_1() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); + let b: ArrayRef = Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])); + + generic_test_op2!( + a, + b, + DataType::Int32, + DataType::Int32, + CovariancePop, + ScalarValue::from(1_f64), + DataType::Float64 + ) + } + + #[test] + fn covariance_i32_with_nulls_2() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); + let b: ArrayRef = Arc::new(Int32Array::from(vec![Some(4), Some(5), Some(6)])); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; + + let agg = Arc::new(Covariance::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + #[test] + fn covariance_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; + + let agg = Arc::new(Covariance::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + #[test] + fn covariance_f64_merge_1() -> Result<()> { + let a = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64, 3_f64])); + let b = Arc::new(Float64Array::from_slice(&[4_f64, 5_f64, 6_f64])); + let c = Arc::new(Float64Array::from_slice(&[1.1_f64, 2.2_f64, 3.3_f64])); + let d = Arc::new(Float64Array::from_slice(&[4.4_f64, 5.5_f64, 6.6_f64])); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Float64, false), + Field::new("b", DataType::Float64, false), + ]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; + + let agg1 = Arc::new(CovariancePop::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let agg2 = Arc::new(CovariancePop::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let actual = merge(&batch1, &batch2, agg1, agg2)?; + assert!(actual == ScalarValue::from(0.7616666666666666)); + + Ok(()) + } + + #[test] + fn covariance_f64_merge_2() -> Result<()> { + let a = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64, 3_f64])); + let b = Arc::new(Float64Array::from_slice(&[4_f64, 5_f64, 6_f64])); + let c = Arc::new(Float64Array::from(vec![None])); + let d = Arc::new(Float64Array::from(vec![None])); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Float64, false), + Field::new("b", DataType::Float64, false), + ]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; + + let agg1 = Arc::new(CovariancePop::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let agg2 = Arc::new(CovariancePop::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + 
DataType::Float64, + )); + + let actual = merge(&batch1, &batch2, agg1, agg2)?; + assert!(actual == ScalarValue::from(0.6666666666666666)); + + Ok(()) + } + + fn aggregate( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() + } + + fn merge( + batch1: &RecordBatch, + batch2: &RecordBatch, + agg1: Arc, + agg2: Arc, + ) -> Result { + let mut accum1 = agg1.create_accumulator()?; + let mut accum2 = agg2.create_accumulator()?; + let expr1 = agg1.expressions(); + let expr2 = agg2.expressions(); + + let values1 = expr1 + .iter() + .map(|e| e.evaluate(batch1)) + .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .collect::>>()?; + let values2 = expr2 + .iter() + .map(|e| e.evaluate(batch2)) + .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .collect::>>()?; + accum1.update_batch(&values1)?; + accum2.update_batch(&values2)?; + let state2 = accum2 + .state()? + .iter() + .map(|v| vec![v.clone()]) + .map(|x| ScalarValue::iter_to_array(x).unwrap()) + .collect::>(); + accum1.merge_batch(&state2)?; + accum1.evaluate() + } +} diff --git a/datafusion/src/physical_plan/expressions/cume_dist.rs b/datafusion/src/physical_plan/expressions/cume_dist.rs index b70b4fc33967..40d9c7be4bbe 100644 --- a/datafusion/src/physical_plan/expressions/cume_dist.rs +++ b/datafusion/src/physical_plan/expressions/cume_dist.rs @@ -21,10 +21,10 @@ use crate::error::Result; use crate::physical_plan::window_functions::PartitionEvaluator; use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr}; +use crate::record_batch::RecordBatch; use arrow::array::ArrayRef; use arrow::array::Float64Array; use arrow::datatypes::{DataType, Field}; -use arrow::record_batch::RecordBatch; use std::any::Any; use std::iter; use std::ops::Range; @@ -107,6 +107,7 @@ impl PartitionEvaluator for CumeDistEvaluator { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use arrow::{array::*, datatypes::*}; fn test_i32_result( diff --git a/datafusion/src/physical_plan/distinct_expressions.rs b/datafusion/src/physical_plan/expressions/distinct_expressions.rs similarity index 70% rename from datafusion/src/physical_plan/distinct_expressions.rs rename to datafusion/src/physical_plan/expressions/distinct_expressions.rs index 40f6d58dc051..f0a741bb2f0b 100644 --- a/datafusion/src/physical_plan/distinct_expressions.rs +++ b/datafusion/src/physical_plan/expressions/distinct_expressions.rs @@ -22,6 +22,7 @@ use std::fmt::Debug; use std::sync::Arc; use ahash::RandomState; +use arrow::array::ArrayRef; use std::collections::HashSet; use arrow::{ @@ -128,8 +129,7 @@ struct DistinctCountAccumulator { state_data_types: Vec, count_data_type: DataType, } - -impl Accumulator for DistinctCountAccumulator { +impl DistinctCountAccumulator { fn update(&mut self, values: &[ScalarValue]) -> Result<()> { // If a row has a NULL, it is not included in the final count. 
if !values.iter().any(|v| v.is_null()) { @@ -163,7 +163,33 @@ impl Accumulator for DistinctCountAccumulator { self.update(&row_values) }) } +} +impl Accumulator for DistinctCountAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + }; + (0..values[0].len()).try_for_each(|index| { + let v = values + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + self.update(&v) + }) + } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + }; + (0..states[0].len()).try_for_each(|index| { + let v = states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + self.merge(&v) + }) + } fn state(&self) -> Result> { let mut cols_out = self .state_data_types @@ -205,11 +231,143 @@ impl Accumulator for DistinctCountAccumulator { } } +/// Expression for a ARRAY_AGG(DISTINCT) aggregation. +#[derive(Debug)] +pub struct DistinctArrayAgg { + /// Column name + name: String, + /// The DataType for the input expression + input_data_type: DataType, + /// The input expression + expr: Arc, +} + +impl DistinctArrayAgg { + /// Create a new DistinctArrayAgg aggregate function + pub fn new( + expr: Arc, + name: impl Into, + input_data_type: DataType, + ) -> Self { + let name = name.into(); + Self { + name, + expr, + input_data_type, + } + } +} + +impl AggregateExpr for DistinctArrayAgg { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new( + &self.name, + DataType::List(Box::new(Field::new( + "item", + self.input_data_type.clone(), + true, + ))), + false, + )) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(DistinctArrayAggAccumulator::try_new( + &self.input_data_type, + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![Field::new( + &format_state_name(&self.name, "distinct_array_agg"), + DataType::List(Box::new(Field::new( + "item", + self.input_data_type.clone(), + true, + ))), + false, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +#[derive(Debug)] +struct DistinctArrayAggAccumulator { + values: HashSet, + datatype: DataType, +} + +impl DistinctArrayAggAccumulator { + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: HashSet::new(), + datatype: datatype.clone(), + }) + } +} + +impl Accumulator for DistinctArrayAggAccumulator { + fn state(&self) -> Result> { + Ok(vec![ScalarValue::List( + Some(Box::new(self.values.clone().into_iter().collect())), + Box::new(self.datatype.clone()), + )]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + assert_eq!(values.len(), 1, "batch input should only include 1 column!"); + + let arr = &values[0]; + for i in 0..arr.len() { + self.values.insert(ScalarValue::try_from_array(arr, i)?); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + }; + + for array in states { + for j in 0..array.len() { + self.values.insert(ScalarValue::try_from_array(array, j)?); + } + } + + Ok(()) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::List( + Some(Box::new(self.values.clone().into_iter().collect())), + Box::new(self.datatype.clone()), + )) + } +} + #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::DataType; + use crate::physical_plan::expressions::col; + use 
crate::physical_plan::expressions::tests::aggregate; + + use crate::field_util::SchemaExt; + use crate::record_batch::RecordBatch; + use arrow::datatypes::{DataType, Schema}; macro_rules! state_to_vec { ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ @@ -282,9 +440,20 @@ mod tests { let mut accum = agg.create_accumulator()?; - for row in rows.iter() { - accum.update(row)? - } + let cols = (0..rows[0].len()) + .map(|i| { + rows.iter() + .map(|inner| inner[i].clone()) + .collect::>() + }) + .collect::>(); + + let arrays: Vec = cols + .iter() + .map(|c| ScalarValue::iter_to_array(c.clone())) + .collect::>>()?; + + accum.update_batch(&arrays)?; Ok((accum.state()?, accum.evaluate()?)) } @@ -671,4 +840,143 @@ mod tests { Ok(()) } + + fn check_distinct_array_agg( + input: ArrayRef, + expected: ScalarValue, + datatype: DataType, + ) -> Result<()> { + let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?; + + let agg = Arc::new(DistinctArrayAgg::new( + col("a", &schema)?, + "bla".to_string(), + datatype, + )); + let actual = aggregate(&batch, agg)?; + + match (expected, actual) { + (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), _)) => { + // workaround lack of Ord of ScalarValue + let cmp = |a: &ScalarValue, b: &ScalarValue| { + a.partial_cmp(b).expect("Can compare ScalarValues") + }; + + e.sort_by(cmp); + a.sort_by(cmp); + // Check that the inputs are the same + assert_eq!(e, a); + } + _ => { + unreachable!() + } + } + + Ok(()) + } + + #[test] + fn distinct_array_agg_i32() -> Result<()> { + let col: ArrayRef = Arc::new(Int32Array::from_slice(vec![1, 2, 7, 4, 5, 2])); + + let out = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(2)), + ScalarValue::Int32(Some(7)), + ScalarValue::Int32(Some(4)), + ScalarValue::Int32(Some(5)), + ])), + Box::new(DataType::Int32), + ); + + check_distinct_array_agg(col, out, DataType::Int32) + } + + #[test] + fn distinct_array_agg_nested() -> Result<()> { + // [[1, 2, 3], [4, 5]] + let l1 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ])), + Box::new(DataType::Int32), + ), + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(4i32), + ScalarValue::from(5i32), + ])), + Box::new(DataType::Int32), + ), + ])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + // [[6], [7, 8]] + let l2 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::List( + Some(Box::new(vec![ScalarValue::from(6i32)])), + Box::new(DataType::Int32), + ), + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(7i32), + ScalarValue::from(8i32), + ])), + Box::new(DataType::Int32), + ), + ])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + // [[9]] + let l3 = ScalarValue::List( + Some(Box::new(vec![ScalarValue::List( + Some(Box::new(vec![ScalarValue::from(9i32)])), + Box::new(DataType::Int32), + )])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + let list = ScalarValue::List( + Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + // Duplicate l1 in the input array and check that it is deduped in the output. 
+ let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, l1]).unwrap(); + + check_distinct_array_agg( + array, + list, + DataType::List(Box::new(Field::new( + "item", + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + true, + ))), + ) + } } diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index ba16f50127cf..344833e962cf 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -20,13 +20,12 @@ use std::convert::TryInto; use std::{any::Any, sync::Arc}; -use arrow::{ - datatypes::{DataType, Schema}, - record_batch::RecordBatch, -}; +use crate::record_batch::RecordBatch; +use arrow::datatypes::{DataType, Schema}; use crate::arrow::array::Array; use crate::arrow::compute::concatenate::concatenate; +use crate::field_util::FieldExt; use crate::scalar::ScalarValue; use crate::{ error::DataFusionError, @@ -120,6 +119,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { mod tests { use super::*; use crate::error::Result; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::{col, lit}; use arrow::array::{ Int64Array, MutableListArray, MutableUtf8Array, StructArray, Utf8Array, diff --git a/datafusion/src/physical_plan/expressions/in_list.rs b/datafusion/src/physical_plan/expressions/in_list.rs index 1be5a9c50fcd..1efe749ef4e2 100644 --- a/datafusion/src/physical_plan/expressions/in_list.rs +++ b/datafusion/src/physical_plan/expressions/in_list.rs @@ -20,18 +20,17 @@ use std::any::Any; use std::sync::Arc; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{ColumnarValue, PhysicalExpr}; +use crate::record_batch::RecordBatch; +use crate::scalar::ScalarValue; use arrow::{ array::*, bitmap::Bitmap, datatypes::{DataType, Schema}, - record_batch::RecordBatch, types::NativeType, }; -use crate::error::{DataFusionError, Result}; -use crate::physical_plan::{ColumnarValue, PhysicalExpr}; -use crate::scalar::ScalarValue; - macro_rules! 
compare_op_scalar { ($left: expr, $right:expr, $op:expr) => {{ let validity = $left.validity(); @@ -456,6 +455,7 @@ pub fn in_list( #[cfg(test)] mod tests { + use crate::field_util::SchemaExt; use arrow::{array::Utf8Array, datatypes::Field}; type StringArray = Utf8Array; diff --git a/datafusion/src/physical_plan/expressions/is_not_null.rs b/datafusion/src/physical_plan/expressions/is_not_null.rs index fffae683432f..0df066460c4d 100644 --- a/datafusion/src/physical_plan/expressions/is_not_null.rs +++ b/datafusion/src/physical_plan/expressions/is_not_null.rs @@ -19,11 +19,9 @@ use std::{any::Any, sync::Arc}; +use crate::record_batch::RecordBatch; use arrow::compute; -use arrow::{ - datatypes::{DataType, Schema}, - record_batch::RecordBatch, -}; +use arrow::datatypes::{DataType, Schema}; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; use crate::{error::Result, scalar::ScalarValue}; @@ -88,11 +86,12 @@ pub fn is_not_null(arg: Arc) -> Result> #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; + use crate::record_batch::RecordBatch; use arrow::{ array::{BooleanArray, Utf8Array}, datatypes::*, - record_batch::RecordBatch, }; use std::sync::Arc; diff --git a/datafusion/src/physical_plan/expressions/is_null.rs b/datafusion/src/physical_plan/expressions/is_null.rs index f364067bc955..92b202955f9f 100644 --- a/datafusion/src/physical_plan/expressions/is_null.rs +++ b/datafusion/src/physical_plan/expressions/is_null.rs @@ -19,11 +19,9 @@ use std::{any::Any, sync::Arc}; +use crate::record_batch::RecordBatch; use arrow::compute; -use arrow::{ - datatypes::{DataType, Schema}, - record_batch::RecordBatch, -}; +use arrow::datatypes::{DataType, Schema}; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; use crate::{error::Result, scalar::ScalarValue}; @@ -88,11 +86,11 @@ pub fn is_null(arg: Arc) -> Result> { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; use arrow::{ array::{BooleanArray, Utf8Array}, datatypes::*, - record_batch::RecordBatch, }; use std::sync::Arc; diff --git a/datafusion/src/physical_plan/expressions/lead_lag.rs b/datafusion/src/physical_plan/expressions/lead_lag.rs index 02cc5f49a510..a05059623d84 100644 --- a/datafusion/src/physical_plan/expressions/lead_lag.rs +++ b/datafusion/src/physical_plan/expressions/lead_lag.rs @@ -19,13 +19,14 @@ //! 
at runtime during query execution use crate::error::{DataFusionError, Result}; +use crate::physical_plan::expressions::cast::cast_with_error; use crate::physical_plan::window_functions::PartitionEvaluator; use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr}; +use crate::record_batch::RecordBatch; use crate::scalar::ScalarValue; use arrow::array::ArrayRef; use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; -use arrow::record_batch::RecordBatch; use std::any::Any; use std::borrow::Borrow; use std::ops::Neg; @@ -130,8 +131,7 @@ fn create_empty_array( .map(|scalar| scalar.to_array_of_size(size)) .unwrap_or_else(|| ArrayRef::from(new_null_array(data_type.clone(), size))); if array.data_type() != data_type { - cast::cast(array.borrow(), data_type, cast::CastOptions::default()) - .map_err(DataFusionError::ArrowError) + cast_with_error(array.borrow(), data_type, cast::CastOptions::default()) .map(ArrayRef::from) } else { Ok(array) @@ -188,8 +188,9 @@ impl PartitionEvaluator for WindowShiftEvaluator { mod tests { use super::*; use crate::error::Result; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::Column; - use arrow::record_batch::RecordBatch; + use crate::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> { diff --git a/datafusion/src/physical_plan/expressions/literal.rs b/datafusion/src/physical_plan/expressions/literal.rs index 45ecf5c9f9fe..b121306ba9a9 100644 --- a/datafusion/src/physical_plan/expressions/literal.rs +++ b/datafusion/src/physical_plan/expressions/literal.rs @@ -20,10 +20,8 @@ use std::any::Any; use std::sync::Arc; -use arrow::{ - datatypes::{DataType, Schema}, - record_batch::RecordBatch, -}; +use crate::record_batch::RecordBatch; +use arrow::datatypes::{DataType, Schema}; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; use crate::{error::Result, scalar::ScalarValue}; @@ -80,6 +78,7 @@ pub fn lit(value: ScalarValue) -> Arc { mod tests { use super::*; use crate::error::Result; + use crate::field_util::SchemaExt; use arrow::array::*; use arrow::datatypes::*; diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs index 1d1ba506acba..1b80ac0a570c 100644 --- a/datafusion/src/physical_plan/expressions/min_max.rs +++ b/datafusion/src/physical_plan/expressions/min_max.rs @@ -439,16 +439,6 @@ impl Accumulator for MaxAccumulator { Ok(()) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - let value = &values[0]; - self.max = max(&self.max, value)?; - Ok(()) - } - - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - self.update(states) - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { self.update_batch(states) } @@ -542,12 +532,6 @@ impl Accumulator for MinAccumulator { Ok(vec![self.min.clone()]) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - let value = &values[0]; - self.min = min(&self.min, value)?; - Ok(()) - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; let delta = &min_batch(values)?; @@ -555,10 +539,6 @@ impl Accumulator for MinAccumulator { Ok(()) } - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - self.update(states) - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { self.update_batch(states) } @@ -571,11 +551,12 @@ impl Accumulator for MinAccumulator { #[cfg(test)] mod tests { use super::*; + use 
crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; use crate::physical_plan::expressions::tests::aggregate; + use crate::record_batch::RecordBatch; use crate::scalar::ScalarValue::Decimal128; use crate::{error::Result, generic_test_op}; - use arrow::record_batch::RecordBatch; #[test] fn min_decimal() -> Result<()> { diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 04127718f961..c83fd492932e 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -19,30 +19,12 @@ use std::sync::Arc; +use super::sorts::SortColumn; use super::ColumnarValue; use crate::error::{DataFusionError, Result}; use crate::physical_plan::PhysicalExpr; -use arrow::array::*; -use arrow::compute::sort::{SortColumn as ArrowSortColumn, SortOptions}; -use arrow::record_batch::RecordBatch; - -/// One column to be used in lexicographical sort -#[derive(Clone, Debug)] -pub struct SortColumn { - /// The array to be sorted - pub values: ArrayRef, - /// The options to sort the array - pub options: Option, -} - -impl<'a> From<&'a SortColumn> for ArrowSortColumn<'a> { - fn from(c: &'a SortColumn) -> Self { - Self { - values: c.values.as_ref(), - options: c.options, - } - } -} +use crate::record_batch::RecordBatch; +use arrow::compute::sort::SortOptions; mod approx_distinct; mod array_agg; @@ -50,8 +32,7 @@ mod average; #[macro_use] mod binary; mod case; -mod cast; -mod coercion; +pub(crate) mod cast; mod column; mod count; mod cume_dist; @@ -63,6 +44,9 @@ mod lead_lag; mod literal; #[macro_use] mod min_max; +mod correlation; +mod covariance; +mod distinct_expressions; mod negative; mod not; mod nth_value; @@ -88,8 +72,15 @@ pub use binary::{binary, binary_operator_data_type, BinaryExpr}; pub use case::{case, CaseExpr}; pub use cast::{cast, cast_column, cast_with_options, CastExpr}; pub use column::{col, Column}; +pub(crate) use correlation::{ + correlation_return_type, is_correlation_support_arg_type, Correlation, +}; pub use count::Count; +pub(crate) use covariance::{ + covariance_return_type, is_covariance_support_arg_type, Covariance, CovariancePop, +}; pub use cume_dist::cume_dist; +pub use distinct_expressions::{DistinctArrayAgg, DistinctCount}; pub use get_indexed_field::GetIndexedFieldExpr; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; @@ -130,7 +121,7 @@ pub struct PhysicalSortExpr { } impl std::fmt::Display for PhysicalSortExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let opts_string = match (self.options.descending, self.options.nulls_first) { (true, true) => "DESC", (true, false) => "DESC NULLS LAST", @@ -189,6 +180,32 @@ mod tests { }}; } + /// macro to perform an aggregation with two inputs and verify the result. + #[macro_export] + macro_rules! 
generic_test_op2 { + ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![ + Field::new("a", $DATATYPE1, false), + Field::new("b", $DATATYPE2, false), + ]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY1, $ARRAY2])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + col("b", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + )); + let actual = aggregate(&batch, agg)?; + let expected = ScalarValue::from($EXPECTED); + + assert_eq!(expected, actual); + + Ok(()) + }}; + } + pub fn aggregate( batch: &RecordBatch, agg: Arc, diff --git a/datafusion/src/physical_plan/expressions/negative.rs b/datafusion/src/physical_plan/expressions/negative.rs index a8e4bb113d02..75370ee1a83d 100644 --- a/datafusion/src/physical_plan/expressions/negative.rs +++ b/datafusion/src/physical_plan/expressions/negative.rs @@ -20,18 +20,17 @@ use std::any::Any; use std::sync::Arc; +use crate::record_batch::RecordBatch; use arrow::{ array::*, compute::arithmetics::basic::negate, datatypes::{DataType, Schema}, - record_batch::RecordBatch, }; use crate::error::{DataFusionError, Result}; +use crate::physical_plan::coercion_rule::binary_rule::is_signed_numeric; use crate::physical_plan::{ColumnarValue, PhysicalExpr}; -use super::coercion; - /// Invoke a compute kernel on array(s) macro_rules! compute_op { // invoke unary operator @@ -119,7 +118,7 @@ pub fn negative( input_schema: &Schema, ) -> Result> { let data_type = arg.data_type(input_schema)?; - if !coercion::is_signed_numeric(&data_type) { + if !is_signed_numeric(&data_type) { Err(DataFusionError::Internal( format!( "(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed numeric", diff --git a/datafusion/src/physical_plan/expressions/not.rs b/datafusion/src/physical_plan/expressions/not.rs index d0d275e90c21..654856c6cb33 100644 --- a/datafusion/src/physical_plan/expressions/not.rs +++ b/datafusion/src/physical_plan/expressions/not.rs @@ -24,10 +24,10 @@ use std::sync::Arc; use super::ColumnarValue; use crate::error::{DataFusionError, Result}; use crate::physical_plan::PhysicalExpr; +use crate::record_batch::RecordBatch; use crate::scalar::ScalarValue; use arrow::array::BooleanArray; use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; /// Not expression #[derive(Debug)] @@ -119,6 +119,7 @@ pub fn not( mod tests { use super::*; use crate::error::Result; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; use arrow::datatypes::*; diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs index 9ede495f0e10..125bf82f36f4 100644 --- a/datafusion/src/physical_plan/expressions/nth_value.rs +++ b/datafusion/src/physical_plan/expressions/nth_value.rs @@ -21,11 +21,11 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::window_functions::PartitionEvaluator; use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr}; +use crate::record_batch::RecordBatch; use crate::scalar::ScalarValue; use arrow::array::{new_null_array, ArrayRef}; use arrow::compute::window::shift; use arrow::datatypes::{DataType, Field}; -use arrow::record_batch::RecordBatch; use std::any::Any; use std::iter; use std::ops::Range; @@ -203,8 +203,9 @@ impl PartitionEvaluator for NthValueEvaluator { mod tests { use super::*; use crate::error::Result; + use 
crate::field_util::SchemaExt; use crate::physical_plan::expressions::Column; - use arrow::record_batch::RecordBatch; + use crate::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; fn test_i32_result(expr: NthValue, expected: Int32Array) -> Result<()> { diff --git a/datafusion/src/physical_plan/expressions/nullif.rs b/datafusion/src/physical_plan/expressions/nullif.rs index e6be0a8c8e90..95df9ec12330 100644 --- a/datafusion/src/physical_plan/expressions/nullif.rs +++ b/datafusion/src/physical_plan/expressions/nullif.rs @@ -37,12 +37,12 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result { match (lhs, rhs) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { Ok(ColumnarValue::Array( - nullif::nullif(lhs.as_ref(), rhs.to_array_of_size(lhs.len()).as_ref())? + nullif::nullif(lhs.as_ref(), rhs.to_array_of_size(lhs.len()).as_ref()) .into(), )) } (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => Ok( - ColumnarValue::Array(nullif::nullif(lhs.as_ref(), rhs.as_ref())?.into()), + ColumnarValue::Array(nullif::nullif(lhs.as_ref(), rhs.as_ref()).into()), ), _ => Err(DataFusionError::NotImplemented( "nullif does not support a literal as first argument".to_string(), diff --git a/datafusion/src/physical_plan/expressions/rank.rs b/datafusion/src/physical_plan/expressions/rank.rs index 47b36ebfe676..dffd22420fc2 100644 --- a/datafusion/src/physical_plan/expressions/rank.rs +++ b/datafusion/src/physical_plan/expressions/rank.rs @@ -21,10 +21,10 @@ use crate::error::Result; use crate::physical_plan::window_functions::PartitionEvaluator; use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr}; +use crate::record_batch::RecordBatch; use arrow::array::ArrayRef; use arrow::array::{Float64Array, UInt64Array}; use arrow::datatypes::{DataType, Field}; -use arrow::record_batch::RecordBatch; use std::any::Any; use std::iter; use std::ops::Range; @@ -40,16 +40,16 @@ pub struct Rank { #[derive(Debug, Copy, Clone)] #[allow(clippy::enum_variant_names)] pub(crate) enum RankType { - Rank, - DenseRank, - PercentRank, + Basic, + Dense, + Percent, } /// Create a rank window function pub fn rank(name: String) -> Rank { Rank { name, - rank_type: RankType::Rank, + rank_type: RankType::Basic, } } @@ -57,7 +57,7 @@ pub fn rank(name: String) -> Rank { pub fn dense_rank(name: String) -> Rank { Rank { name, - rank_type: RankType::DenseRank, + rank_type: RankType::Dense, } } @@ -65,7 +65,7 @@ pub fn dense_rank(name: String) -> Rank { pub fn percent_rank(name: String) -> Rank { Rank { name, - rank_type: RankType::PercentRank, + rank_type: RankType::Percent, } } @@ -78,8 +78,8 @@ impl BuiltInWindowFunctionExpr for Rank { fn field(&self) -> Result { let nullable = false; let data_type = match self.rank_type { - RankType::Rank | RankType::DenseRank => DataType::UInt64, - RankType::PercentRank => DataType::Float64, + RankType::Basic | RankType::Dense => DataType::UInt64, + RankType::Percent => DataType::Float64, }; Ok(Field::new(self.name(), data_type, nullable)) } @@ -122,7 +122,7 @@ impl PartitionEvaluator for RankEvaluator { ) -> Result { // see https://www.postgresql.org/docs/current/functions-window.html let result: ArrayRef = match self.rank_type { - RankType::DenseRank => Arc::new(UInt64Array::from_values( + RankType::Dense => Arc::new(UInt64Array::from_values( ranks_in_partition .iter() .zip(1u64..) 
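The hunks above and below rename the `RankType` variants to `Basic`, `Dense`, and `Percent`, and the evaluator derives each flavor from the ranges of peer rows (`ranks_in_partition`) within a partition. As a point of reference, here is a small self-contained sketch of those three semantics over plain `Range<usize>` peer groups; the function names and the example in `main` are illustrative assumptions, not DataFusion APIs.

```rust
use std::iter;
use std::ops::Range;

// RANK(): tied rows share a rank; the next rank skips by the size of the tie.
fn basic_rank(peers: &[Range<usize>]) -> Vec<u64> {
    peers
        .iter()
        .scan(1_u64, |next, r| {
            let rank = *next;
            *next += r.len() as u64;
            Some(iter::repeat(rank).take(r.len()))
        })
        .flatten()
        .collect()
}

// DENSE_RANK(): tied rows share a rank and there are no gaps.
fn dense_rank(peers: &[Range<usize>]) -> Vec<u64> {
    peers
        .iter()
        .zip(1_u64..)
        .flat_map(|(r, rank)| iter::repeat(rank).take(r.len()))
        .collect()
}

// PERCENT_RANK(): (rank - 1) / (partition rows - 1), ranging from 0 to 1.
fn percent_rank(peers: &[Range<usize>], partition_len: usize) -> Vec<f64> {
    let denominator = (partition_len - 1) as f64;
    basic_rank(peers)
        .into_iter()
        .map(|rank| (rank - 1) as f64 / denominator)
        .collect()
}

fn main() {
    // A partition of four rows where the second and third rows tie.
    let peers = vec![0..1, 1..3, 3..4];
    assert_eq!(basic_rank(&peers), vec![1, 2, 2, 4]);
    assert_eq!(dense_rank(&peers), vec![1, 2, 2, 3]);
    assert_eq!(percent_rank(&peers, 4), vec![0.0, 1.0 / 3.0, 1.0 / 3.0, 1.0]);
}
```

The `RankType::Dense` arm in the hunk above corresponds to `dense_rank` here (zipping peer groups with `1u64..`), while the next hunk covers the `Percent` and `Basic` arms.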
@@ -131,7 +131,7 @@ impl PartitionEvaluator for RankEvaluator { iter::repeat(rank).take(len) }), )), - RankType::PercentRank => { + RankType::Percent => { // Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1). The value thus ranges from 0 to 1 inclusive. let denominator = (partition.end - partition.start) as f64; Arc::new(Float64Array::from_values( @@ -147,7 +147,7 @@ impl PartitionEvaluator for RankEvaluator { .flatten(), )) } - RankType::Rank => Arc::new(UInt64Array::from_values( + RankType::Basic => Arc::new(UInt64Array::from_values( ranks_in_partition .iter() .scan(1_u64, |acc, range| { @@ -166,6 +166,7 @@ impl PartitionEvaluator for RankEvaluator { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use arrow::{array::*, datatypes::*}; fn test_with_rank(expr: &Rank, expected: Vec) -> Result<()> { diff --git a/datafusion/src/physical_plan/expressions/row_number.rs b/datafusion/src/physical_plan/expressions/row_number.rs index abcb2df3b913..1566963750ec 100644 --- a/datafusion/src/physical_plan/expressions/row_number.rs +++ b/datafusion/src/physical_plan/expressions/row_number.rs @@ -20,9 +20,9 @@ use crate::error::Result; use crate::physical_plan::window_functions::PartitionEvaluator; use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr}; +use crate::record_batch::RecordBatch; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; -use arrow::record_batch::RecordBatch; use std::any::Any; use std::ops::Range; use std::sync::Arc; @@ -82,7 +82,8 @@ impl PartitionEvaluator for NumRowsEvaluator { mod tests { use super::*; use crate::error::Result; - use arrow::record_batch::RecordBatch; + use crate::field_util::SchemaExt; + use crate::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; #[test] diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index 2c8538b28ef4..72106f4cdded 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -25,8 +25,7 @@ use crate::physical_plan::{ expressions::variance::VarianceAccumulator, Accumulator, AggregateExpr, PhysicalExpr, }; use crate::scalar::ScalarValue; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use super::{format_state_name, StatsType}; @@ -210,23 +209,27 @@ impl StddevAccumulator { variance: VarianceAccumulator::try_new(s_type)?, }) } + + pub fn get_m2(&self) -> f64 { + self.variance.get_m2() + } } impl Accumulator for StddevAccumulator { fn state(&self) -> Result> { Ok(vec![ ScalarValue::from(self.variance.get_count()), - self.variance.get_mean(), - self.variance.get_m2(), + ScalarValue::from(self.variance.get_mean()), + ScalarValue::from(self.variance.get_m2()), ]) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - self.variance.update(values) + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.update_batch(values) } - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - self.variance.merge(states) + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.variance.merge_batch(states) } fn evaluate(&self) -> Result { @@ -249,9 +252,10 @@ impl Accumulator for StddevAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; + use 
crate::record_batch::RecordBatch; use crate::{error::Result, generic_test_op}; - use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; #[test] @@ -407,6 +411,64 @@ mod tests { Ok(()) } + #[test] + fn stddev_f64_merge_1() -> Result<()> { + let a = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64, 3_f64])); + let b = Arc::new(Float64Array::from_slice(&[4_f64, 5_f64])); + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; + + let agg1 = Arc::new(StddevPop::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let agg2 = Arc::new(StddevPop::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let actual = merge(&batch1, &batch2, agg1, agg2)?; + assert!(actual == ScalarValue::from(std::f64::consts::SQRT_2)); + + Ok(()) + } + + #[test] + fn stddev_f64_merge_2() -> Result<()> { + let a = Arc::new(Float64Array::from_slice(&[ + 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, + ])); + let b = Arc::new(Float64Array::from(vec![None])); + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; + + let agg1 = Arc::new(StddevPop::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let agg2 = Arc::new(StddevPop::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let actual = merge(&batch1, &batch2, agg1, agg2)?; + assert!(actual == ScalarValue::from(std::f64::consts::SQRT_2)); + + Ok(()) + } + fn aggregate( batch: &RecordBatch, agg: Arc, @@ -421,4 +483,37 @@ mod tests { accum.update_batch(&values)?; accum.evaluate() } + + fn merge( + batch1: &RecordBatch, + batch2: &RecordBatch, + agg1: Arc, + agg2: Arc, + ) -> Result { + let mut accum1 = agg1.create_accumulator()?; + let mut accum2 = agg2.create_accumulator()?; + let expr1 = agg1.expressions(); + let expr2 = agg2.expressions(); + + let values1 = expr1 + .iter() + .map(|e| e.evaluate(batch1)) + .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .collect::>>()?; + let values2 = expr2 + .iter() + .map(|e| e.evaluate(batch2)) + .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .collect::>>()?; + accum1.update_batch(&values1)?; + accum2.update_batch(&values2)?; + let state2 = accum2 + .state()? 
+ .iter() + .map(|v| vec![v.clone()]) + .map(|x| ScalarValue::iter_to_array(x).unwrap()) + .collect::>(); + accum1.merge_batch(&state2)?; + accum1.evaluate() + } } diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index 12d4b10864c3..207dc2e96603 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -348,23 +348,12 @@ impl Accumulator for SumAccumulator { Ok(vec![self.sum.clone()]) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - // sum(v1, v2, v3) = v1 + v2 + v3 - self.sum = sum(&self.sum, &values[0])?; - Ok(()) - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; self.sum = sum(&self.sum, &sum_batch(values)?)?; Ok(()) } - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - // sum(sum1, sum2) = sum1 + sum2 - self.update(states) - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { // sum(sum1, sum2, sum3, ...) = sum1 + sum2 + sum3 + ... self.update_batch(states) @@ -380,10 +369,11 @@ impl Accumulator for SumAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; + use crate::record_batch::RecordBatch; use crate::{error::Result, generic_test_op}; use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; #[test] fn test_sum_return_data_type() -> Result<()> { diff --git a/datafusion/src/physical_plan/expressions/try_cast.rs b/datafusion/src/physical_plan/expressions/try_cast.rs index 453a77c7debd..d47270c8a3a9 100644 --- a/datafusion/src/physical_plan/expressions/try_cast.rs +++ b/datafusion/src/physical_plan/expressions/try_cast.rs @@ -21,11 +21,12 @@ use std::sync::Arc; use super::ColumnarValue; use crate::error::{DataFusionError, Result}; +use crate::physical_plan::expressions::cast::cast_with_error; use crate::physical_plan::PhysicalExpr; +use crate::record_batch::RecordBatch; use crate::scalar::ScalarValue; use arrow::compute; use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; use compute::cast; /// TRY_CAST expression casts an expression to a specific data type and retuns NULL on invalid cast @@ -78,7 +79,7 @@ impl PhysicalExpr for TryCastExpr { let value = self.expr.evaluate(batch)?; match value { ColumnarValue::Array(array) => Ok(ColumnarValue::Array( - cast::cast( + cast_with_error( array.as_ref(), &self.cast_type, cast::CastOptions::default(), @@ -87,7 +88,7 @@ impl PhysicalExpr for TryCastExpr { )), ColumnarValue::Scalar(scalar) => { let scalar_array = scalar.to_array(); - let cast_array = cast::cast( + let cast_array = cast_with_error( scalar_array.as_ref(), &self.cast_type, cast::CastOptions::default(), @@ -126,11 +127,59 @@ pub fn try_cast( mod tests { use super::*; use crate::error::Result; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; use arrow::{array::*, datatypes::*}; type StringArray = Utf8Array; + // runs an end-to-end test of physical type cast + // 1. construct a record batch with a column "a" of type A + // 2. construct a physical expression of CAST(a AS B) + // 3. evaluate the expression + // 4. verify that the resulting expression is of type B + // 5. verify that the resulting values are downcastable and correct + macro_rules! 
generic_decimal_to_other_test_cast { + ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $A_TYPE, false)]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new($DECIMAL_ARRAY)], + )?; + // verify that we can construct the expression + let expression = try_cast(col("a", &schema)?, &schema, $TYPE)?; + + // verify that its display is correct + assert_eq!( + format!("CAST(a@0 AS {:?})", $TYPE), + format!("{}", expression) + ); + + // verify that the expression's type is correct + assert_eq!(expression.data_type(&schema)?, $TYPE); + + // compute + let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + + // verify that the array's data_type is correct + assert_eq!(*result.data_type(), $TYPE); + + // verify that the data itself is downcastable + let result = result + .as_any() + .downcast_ref::<$TYPEARRAY>() + .expect("failed to downcast"); + + // verify that the result itself is correct + for (i, x) in $VEC.iter().enumerate() { + match x { + Some(x) => assert_eq!(result.value(i), *x), + None => assert!(!result.is_valid(i)), + } + } + }}; + } + // runs an end-to-end test of physical type cast // 1. construct a record batch with a column "a" of type A // 2. construct a physical expression of CAST(a AS B) @@ -181,6 +230,271 @@ mod tests { }}; } + #[test] + fn test_try_cast_decimal_to_decimal() -> Result<()> { + // try cast one decimal data type to another decimal data type + let array: Vec = vec![1234, 2222, 3, 4000, 5000]; + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 3), + Int128Array, + DataType::Decimal(20, 6), + vec![ + Some(1_234_000_i128), + Some(2_222_000_i128), + Some(3_000_i128), + Some(4_000_000_i128), + Some(5_000_000_i128), + None, + ] + ); + + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 3), + Int128Array, + DataType::Decimal(10, 2), + vec![ + Some(123_i128), + Some(222_i128), + Some(0_i128), + Some(400_i128), + Some(500_i128), + None, + ] + ); + + Ok(()) + } + + #[test] + fn test_try_cast_decimal_to_numeric() -> Result<()> { + // TODO we should add function to create Int128Array with value and metadata + // https://github.com/apache/arrow-rs/issues/1009 + let array: Vec = vec![1, 2, 3, 4, 5]; + let decimal_array = Int128Array::from_slice(&array); + // decimal to i8 + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 0), + Int8Array, + DataType::Int8, + vec![ + Some(1_i8), + Some(2_i8), + Some(3_i8), + Some(4_i8), + Some(5_i8), + None, + ] + ); + + // decimal to i16 + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 0), + Int16Array, + DataType::Int16, + vec![ + Some(1_i16), + Some(2_i16), + Some(3_i16), + Some(4_i16), + Some(5_i16), + None, + ] + ); + + // decimal to i32 + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 0), + Int32Array, + DataType::Int32, + vec![ + Some(1_i32), + Some(2_i32), + Some(3_i32), + Some(4_i32), + Some(5_i32), + None, + ] + ); + + // decimal to i64 + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 0), + Int64Array, + DataType::Int64, + vec![ + Some(1_i64), + Some(2_i64), + 
Some(3_i64), + Some(4_i64), + Some(5_i64), + None, + ] + ); + + // decimal to float32 + let array: Vec = vec![1234, 2222, 3, 4000, 5000]; + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(10, 3), + Float32Array, + DataType::Float32, + vec![ + Some(1.234_f32), + Some(2.222_f32), + Some(0.003_f32), + Some(4.0_f32), + Some(5.0_f32), + None, + ] + ); + // decimal to float64 + let decimal_array = Int128Array::from_slice(&array); + generic_decimal_to_other_test_cast!( + decimal_array, + DataType::Decimal(20, 6), + Float64Array, + DataType::Float64, + vec![ + Some(0.001234_f64), + Some(0.002222_f64), + Some(0.000003_f64), + Some(0.004_f64), + Some(0.005_f64), + None, + ] + ); + + Ok(()) + } + + #[test] + fn test_try_cast_numeric_to_decimal() -> Result<()> { + // int8 + generic_test_cast!( + Int8Array, + DataType::Int8, + vec![1, 2, 3, 4, 5], + Int128Array, + DataType::Decimal(3, 0), + vec![ + Some(1_i128), + Some(2_i128), + Some(3_i128), + Some(4_i128), + Some(5_i128), + ] + ); + + // int16 + generic_test_cast!( + Int16Array, + DataType::Int16, + vec![1, 2, 3, 4, 5], + Int128Array, + DataType::Decimal(5, 0), + vec![ + Some(1_i128), + Some(2_i128), + Some(3_i128), + Some(4_i128), + Some(5_i128), + ] + ); + + // int32 + generic_test_cast!( + Int32Array, + DataType::Int32, + vec![1, 2, 3, 4, 5], + Int128Array, + DataType::Decimal(10, 0), + vec![ + Some(1_i128), + Some(2_i128), + Some(3_i128), + Some(4_i128), + Some(5_i128), + ] + ); + + // int64 + generic_test_cast!( + Int64Array, + DataType::Int64, + vec![1, 2, 3, 4, 5], + Int128Array, + DataType::Decimal(20, 0), + vec![ + Some(1_i128), + Some(2_i128), + Some(3_i128), + Some(4_i128), + Some(5_i128), + ] + ); + + // int64 to different scale + generic_test_cast!( + Int64Array, + DataType::Int64, + vec![1, 2, 3, 4, 5], + Int128Array, + DataType::Decimal(20, 2), + vec![ + Some(100_i128), + Some(200_i128), + Some(300_i128), + Some(400_i128), + Some(500_i128), + ] + ); + + // float32 + generic_test_cast!( + Float32Array, + DataType::Float32, + vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], + Int128Array, + DataType::Decimal(10, 2), + vec![ + Some(150_i128), + Some(250_i128), + Some(300_i128), + Some(112_i128), + Some(550_i128), + ] + ); + + // float64 + generic_test_cast!( + Float64Array, + DataType::Float64, + vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], + Int128Array, + DataType::Decimal(20, 4), + vec![ + Some(15000_i128), + Some(25000_i128), + Some(30000_i128), + Some(11234_i128), + Some(55000_i128), + ] + ); + Ok(()) + } + #[test] fn test_cast_i32_u32() -> Result<()> { generic_test_cast!( diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 1786c388e758..0ab9aa3482b4 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -21,10 +21,17 @@ use std::any::Any; use std::sync::Arc; use crate::error::{DataFusionError, Result}; +use crate::physical_plan::expressions::cast::{ + cast_with_error, DEFAULT_DATAFUSION_CAST_OPTIONS, +}; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; use crate::scalar::ScalarValue; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; +use arrow::array::Float64Array; +use arrow::{ + array::{ArrayRef, UInt64Array}, + datatypes::DataType, + datatypes::Field, +}; use super::{format_state_name, StatsType}; @@ -56,7 +63,7 @@ pub(crate) fn variance_return_type(arg_type: &DataType) -> Result { | 
DataType::Float32 | DataType::Float64 => Ok(DataType::Float64), other => Err(DataFusionError::Plan(format!( - "VARIANCE does not support {:?}", + "VAR does not support {:?}", other ))), } @@ -209,8 +216,8 @@ impl AggregateExpr for VariancePop { #[derive(Debug)] pub struct VarianceAccumulator { - m2: ScalarValue, - mean: ScalarValue, + m2: f64, + mean: f64, count: u64, stats_type: StatsType, } @@ -219,9 +226,9 @@ impl VarianceAccumulator { /// Creates a new `VarianceAccumulator` pub fn try_new(s_type: StatsType) -> Result { Ok(Self { - m2: ScalarValue::from(0 as f64), - mean: ScalarValue::from(0 as f64), - count: 0, + m2: 0_f64, + mean: 0_f64, + count: 0_u64, stats_type: s_type, }) } @@ -230,12 +237,12 @@ impl VarianceAccumulator { self.count } - pub fn get_mean(&self) -> ScalarValue { - self.mean.clone() + pub fn get_mean(&self) -> f64 { + self.mean } - pub fn get_m2(&self) -> ScalarValue { - self.m2.clone() + pub fn get_m2(&self) -> f64 { + self.m2 } } @@ -243,26 +250,31 @@ impl Accumulator for VarianceAccumulator { fn state(&self) -> Result> { Ok(vec![ ScalarValue::from(self.count), - self.mean.clone(), - self.m2.clone(), + ScalarValue::from(self.mean), + ScalarValue::from(self.m2), ]) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - let values = &values[0]; - let is_empty = values.is_null(); + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &cast_with_error( + values[0].as_ref(), + &DataType::Float64, + DEFAULT_DATAFUSION_CAST_OPTIONS, + )?; + let arr = values + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .flatten(); - if !is_empty { + for value in arr { let new_count = self.count + 1; - let delta1 = ScalarValue::add(values, &self.mean.arithmetic_negate())?; - let new_mean = ScalarValue::add( - &ScalarValue::div(&delta1, &ScalarValue::from(new_count as f64))?, - &self.mean, - )?; - let delta2 = ScalarValue::add(values, &new_mean.arithmetic_negate())?; - let tmp = ScalarValue::mul(&delta1, &delta2)?; - - let new_m2 = ScalarValue::add(&self.m2, &tmp)?; + let delta1 = value - self.mean; + let new_mean = delta1 / new_count as f64 + self.mean; + let delta2 = value - new_mean; + let new_m2 = self.m2 + delta1 * delta2; + self.count += 1; self.mean = new_mean; self.m2 = new_m2; @@ -271,53 +283,28 @@ impl Accumulator for VarianceAccumulator { Ok(()) } - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - let count = &states[0]; - let mean = &states[1]; - let m2 = &states[2]; - let mut new_count: u64 = self.count; - - // counts are summed - if let ScalarValue::UInt64(Some(c)) = count { - if *c == 0_u64 { - return Ok(()); - } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = states[0].as_any().downcast_ref::().unwrap(); + let means = states[1].as_any().downcast_ref::().unwrap(); + let m2s = states[2].as_any().downcast_ref::().unwrap(); - if self.count == 0 { - self.count = *c; - self.mean = mean.clone(); - self.m2 = m2.clone(); - return Ok(()); + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0_u64 { + continue; } - new_count += c - } else { - unreachable!() - }; - - let new_mean = ScalarValue::div( - &ScalarValue::add(&self.mean, mean)?, - &ScalarValue::from(2_f64), - )?; - let delta = ScalarValue::add(&mean.arithmetic_negate(), &self.mean)?; - let delta_sqrt = ScalarValue::mul(&delta, &delta)?; - let new_m2 = ScalarValue::add( - &ScalarValue::add( - &ScalarValue::mul( - &delta_sqrt, - &ScalarValue::div( - &ScalarValue::mul(&ScalarValue::from(self.count), count)?, - 
&ScalarValue::from(new_count as f64), - )?, - )?, - &self.m2, - )?, - m2, - )?; - - self.count = new_count; - self.mean = new_mean; - self.m2 = new_m2; - + let new_count = self.count + c; + let new_mean = self.mean * self.count as f64 / new_count as f64 + + means.value(i) * c as f64 / new_count as f64; + let delta = self.mean - means.value(i); + let new_m2 = self.m2 + + m2s.value(i) + + delta * delta * self.count as f64 * c as f64 / new_count as f64; + + self.count = new_count; + self.mean = new_mean; + self.m2 = new_m2; + } Ok(()) } @@ -339,17 +326,10 @@ impl Accumulator for VarianceAccumulator { )); } - match self.m2 { - ScalarValue::Float64(e) => { - if self.count == 0 { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(e.map(|f| f / count as f64))) - } - } - _ => Err(DataFusionError::Internal( - "M2 should be f64 for variance".to_string(), - )), + if self.count == 0 { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(Some(self.m2 / count as f64))) } } } @@ -357,9 +337,10 @@ impl Accumulator for VarianceAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::physical_plan::expressions::col; + use crate::record_batch::RecordBatch; use crate::{error::Result, generic_test_op}; - use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; #[test] @@ -511,6 +492,64 @@ mod tests { Ok(()) } + #[test] + fn variance_f64_merge_1() -> Result<()> { + let a = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64, 3_f64])); + let b = Arc::new(Float64Array::from_slice(&[4_f64, 5_f64])); + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; + + let agg1 = Arc::new(VariancePop::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let agg2 = Arc::new(VariancePop::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let actual = merge(&batch1, &batch2, agg1, agg2)?; + assert!(actual == ScalarValue::from(2_f64)); + + Ok(()) + } + + #[test] + fn variance_f64_merge_2() -> Result<()> { + let a = Arc::new(Float64Array::from_slice(&[ + 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, + ])); + let b = Arc::new(Float64Array::from(vec![None])); + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; + + let agg1 = Arc::new(VariancePop::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let agg2 = Arc::new(VariancePop::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + + let actual = merge(&batch1, &batch2, agg1, agg2)?; + assert!(actual == ScalarValue::from(2_f64)); + + Ok(()) + } + fn aggregate( batch: &RecordBatch, agg: Arc, @@ -525,4 +564,37 @@ mod tests { accum.update_batch(&values)?; accum.evaluate() } + + fn merge( + batch1: &RecordBatch, + batch2: &RecordBatch, + agg1: Arc, + agg2: Arc, + ) -> Result { + let mut accum1 = agg1.create_accumulator()?; + let mut accum2 = agg2.create_accumulator()?; + let expr1 = agg1.expressions(); + let expr2 = agg2.expressions(); + + let values1 = expr1 + .iter() + .map(|e| e.evaluate(batch1)) + .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .collect::>>()?; + let values2 = expr2 + .iter() + .map(|e| e.evaluate(batch2)) + .map(|r| r.map(|v| 
v.into_array(batch2.num_rows())))
+            .collect::<Result<Vec<_>>>()?;
+        accum1.update_batch(&values1)?;
+        accum2.update_batch(&values2)?;
+        let state2 = accum2
+            .state()?
+            .iter()
+            .map(|v| vec![v.clone()])
+            .map(|x| ScalarValue::iter_to_array(x).unwrap())
+            .collect::<Vec<_>>();
+        accum1.merge_batch(&state2)?;
+        accum1.evaluate()
+    }
 }
diff --git a/datafusion/src/physical_plan/file_format/avro.rs b/datafusion/src/physical_plan/file_format/avro.rs
index 38be1142c4b7..b5b5a829034f 100644
--- a/datafusion/src/physical_plan/file_format/avro.rs
+++ b/datafusion/src/physical_plan/file_format/avro.rs
@@ -25,25 +25,27 @@ use crate::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
 };
 use arrow::datatypes::SchemaRef;
+
+use crate::execution::runtime_env::RuntimeEnv;
 use async_trait::async_trait;
 use std::any::Any;
 use std::sync::Arc;

 #[cfg(feature = "avro")]
 use super::file_stream::{BatchIter, FileStream};
-use super::PhysicalPlanConfig;
+use super::FileScanConfig;

 /// Execution plan for scanning Avro data source
 #[derive(Debug, Clone)]
 pub struct AvroExec {
-    base_config: PhysicalPlanConfig,
+    base_config: FileScanConfig,
     projected_statistics: Statistics,
     projected_schema: SchemaRef,
 }

 impl AvroExec {
     /// Create a new Avro reader execution plan provided base configurations
-    pub fn new(base_config: PhysicalPlanConfig) -> Self {
+    pub fn new(base_config: FileScanConfig) -> Self {
         let (projected_schema, projected_statistics) = base_config.project();

         Self {
@@ -53,7 +55,7 @@ impl AvroExec {
         }
     }
     /// Ref to the base configs
-    pub fn base_config(&self) -> &PhysicalPlanConfig {
+    pub fn base_config(&self) -> &FileScanConfig {
         &self.base_config
     }
 }
@@ -91,17 +93,25 @@ impl ExecutionPlan for AvroExec {
     }

     #[cfg(not(feature = "avro"))]
-    async fn execute(&self, _partition: usize) -> Result<SendableRecordBatchStream> {
+    async fn execute(
+        &self,
+        _partition: usize,
+        _runtime: Arc<RuntimeEnv>,
+    ) -> Result<SendableRecordBatchStream> {
         Err(DataFusionError::NotImplemented(
             "Cannot execute avro plan without avro feature enabled".to_string(),
         ))
     }

     #[cfg(feature = "avro")]
-    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+    async fn execute(
+        &self,
+        partition: usize,
+        runtime: Arc<RuntimeEnv>,
+    ) -> Result<SendableRecordBatchStream> {
         let proj = self.base_config.projected_file_column_names();
-        let batch_size = self.base_config.batch_size;
+        let batch_size = runtime.batch_size();
         let file_schema = Arc::clone(&self.base_config.file_schema);

         // The avro reader cannot limit the number of records, so `remaining` is ignored.
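// A minimal, standalone sketch of the pairwise state combination that the rewritten
// `VarianceAccumulator::merge_batch` above relies on, assuming Welford-style
// `(count, mean, m2)` partial states. The helper name below is hypothetical and is not
// part of this patch; it only illustrates the arithmetic.
fn merge_variance_states(
    (count_a, mean_a, m2_a): (u64, f64, f64),
    (count_b, mean_b, m2_b): (u64, f64, f64),
) -> (u64, f64, f64) {
    // Empty partial states contribute nothing.
    if count_b == 0 {
        return (count_a, mean_a, m2_a);
    }
    if count_a == 0 {
        return (count_b, mean_b, m2_b);
    }
    let count = count_a + count_b;
    // Weighted mean of the two partial means.
    let mean = mean_a * count_a as f64 / count as f64
        + mean_b * count_b as f64 / count as f64;
    // Parallel update of the sum of squared deviations (Chan et al. style).
    let delta = mean_a - mean_b;
    let m2 =
        m2_a + m2_b + delta * delta * count_a as f64 * count_b as f64 / count as f64;
    (count, mean, m2)
}

fn main() {
    // [1, 2, 3] -> (3, 2.0, 2.0) and [4, 5] -> (2, 4.5, 0.5); the merged population
    // variance is m2 / count = 2.0, matching the `variance_f64_merge_1` test above.
    let (count, _mean, m2) = merge_variance_states((3, 2.0, 2.0), (2, 4.5, 0.5));
    assert!((m2 / count as f64 - 2.0).abs() < f64::EPSILON);
}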
@@ -136,9 +146,8 @@ impl ExecutionPlan for AvroExec { DisplayFormatType::Default => { write!( f, - "AvroExec: files={}, batch_size={}, limit={:?}", + "AvroExec: files={}, limit={:?}", super::FileGroupsDisplay(&self.base_config.file_groups), - self.base_config.batch_size, self.base_config.limit, ) } @@ -158,6 +167,7 @@ mod tests { use crate::datasource::object_store::local::{ local_object_reader_stream, local_unpartitioned_file, LocalFileSystem, }; + use crate::field_util::SchemaExt; use crate::scalar::ScalarValue; use futures::StreamExt; @@ -165,9 +175,10 @@ mod tests { #[tokio::test] async fn avro_exec_without_partition() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let testdata = crate::test_util::arrow_test_data(); let filename = format!("{}/avro/alltypes_plain.avro", testdata); - let avro_exec = AvroExec::new(PhysicalPlanConfig { + let avro_exec = AvroExec::new(FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_groups: vec![vec![local_unpartitioned_file(filename.clone())]], file_schema: AvroFormat {} @@ -175,13 +186,15 @@ mod tests { .await?, statistics: Statistics::default(), projection: Some(vec![0, 1, 2]), - batch_size: 1024, limit: None, table_partition_cols: vec![], }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); - let mut results = avro_exec.execute(0).await.expect("plan execution failed"); + let mut results = avro_exec + .execute(0, runtime) + .await + .expect("plan execution failed"); let batch = results .next() .await @@ -219,6 +232,7 @@ mod tests { #[tokio::test] async fn avro_exec_with_partition() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let testdata = crate::test_util::arrow_test_data(); let filename = format!("{}/avro/alltypes_plain.avro", testdata); let mut partitioned_file = local_unpartitioned_file(filename.clone()); @@ -228,7 +242,7 @@ mod tests { .infer_schema(local_object_reader_stream(vec![filename])) .await?; - let avro_exec = AvroExec::new(PhysicalPlanConfig { + let avro_exec = AvroExec::new(FileScanConfig { // select specific columns of the files as well as the partitioning // column which is supposed to be the last column in the table schema. 
projection: Some(vec![0, 1, file_schema.fields().len(), 2]), @@ -236,13 +250,15 @@ mod tests { file_groups: vec![vec![partitioned_file]], file_schema, statistics: Statistics::default(), - batch_size: 1024, limit: None, table_partition_cols: vec!["date".to_owned()], }); assert_eq!(avro_exec.output_partitioning().partition_count(), 1); - let mut results = avro_exec.execute(0).await.expect("plan execution failed"); + let mut results = avro_exec + .execute(0, runtime) + .await + .expect("plan execution failed"); let batch = results .next() .await diff --git a/datafusion/src/physical_plan/file_format/csv.rs b/datafusion/src/physical_plan/file_format/csv.rs index 00b303575b5d..bf7c21fa567a 100644 --- a/datafusion/src/physical_plan/file_format/csv.rs +++ b/datafusion/src/physical_plan/file_format/csv.rs @@ -22,23 +22,25 @@ use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; +use crate::record_batch::RecordBatch; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::io::csv; -use arrow::record_batch::RecordBatch; use std::any::Any; use std::io::Read; use std::sync::Arc; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::SchemaExt; use async_trait::async_trait; use super::file_stream::{BatchIter, FileStream}; -use super::PhysicalPlanConfig; +use super::FileScanConfig; /// Execution plan for scanning a CSV file #[derive(Debug, Clone)] pub struct CsvExec { - base_config: PhysicalPlanConfig, + base_config: FileScanConfig, projected_statistics: Statistics, projected_schema: SchemaRef, has_header: bool, @@ -47,7 +49,7 @@ pub struct CsvExec { impl CsvExec { /// Create a new CSV reader execution plan provided base and specific configurations - pub fn new(base_config: PhysicalPlanConfig, has_header: bool, delimiter: u8) -> Self { + pub fn new(base_config: FileScanConfig, has_header: bool, delimiter: u8) -> Self { let (projected_schema, projected_statistics) = base_config.project(); Self { @@ -60,7 +62,7 @@ impl CsvExec { } /// Ref to the base configs - pub fn base_config(&self) -> &PhysicalPlanConfig { + pub fn base_config(&self) -> &FileScanConfig { &self.base_config } /// true if the first line of each file is a header @@ -86,6 +88,7 @@ fn deserialize( 0, csv::read::deserialize_column, ) + .map(|chunk| RecordBatch::new_with_chunk(schema, chunk)) } struct CsvBatchReader { @@ -191,23 +194,29 @@ impl ExecutionPlan for CsvExec { } } - async fn execute(&self, partition: usize) -> Result { - let batch_size = self.base_config.batch_size; - let file_schema = self.base_config.file_schema.clone(); + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { + let batch_size = runtime.batch_size(); + let file_schema = Arc::clone(&self.base_config.file_schema); let file_projection = self.base_config.file_column_projection_indices(); let has_header = self.has_header; let delimiter = self.delimiter; + let start_line = if has_header { 1 } else { 0 }; - let fun = move |freader, remaining: &Option| { + let fun = move |file, remaining: &Option| { + let bounds = remaining.map(|x| x + start_line); let reader = csv::read::ReaderBuilder::new() .delimiter(delimiter) .has_headers(has_header) - .from_reader(freader); + .from_reader(file); Box::new(CsvBatchReader::new( reader, file_schema.clone(), batch_size, - *remaining, + bounds, file_projection.clone(), )) as BatchIter }; @@ -231,10 +240,9 @@ impl ExecutionPlan for CsvExec { DisplayFormatType::Default => { write!( f, - "CsvExec: files={}, 
has_header={}, batch_size={}, limit={:?}", + "CsvExec: files={}, has_header={}, limit={:?}", super::FileGroupsDisplay(&self.base_config.file_groups), self.has_header, - self.base_config.batch_size, self.base_config.limit, ) } @@ -249,6 +257,7 @@ impl ExecutionPlan for CsvExec { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::{ assert_batches_eq, datasource::object_store::local::{local_unpartitioned_file, LocalFileSystem}, @@ -259,18 +268,18 @@ mod tests { #[tokio::test] async fn csv_exec_with_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; let path = format!("{}/csv/{}", testdata, filename); let csv = CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema, file_groups: vec![vec![local_unpartitioned_file(path)]], statistics: Statistics::default(), projection: Some(vec![0, 2, 4]), - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -281,7 +290,7 @@ mod tests { assert_eq!(3, csv.projected_schema.fields().len()); assert_eq!(3, csv.schema().fields().len()); - let mut stream = csv.execute(0).await?; + let mut stream = csv.execute(0, runtime).await?; let batch = stream.next().await.unwrap()?; assert_eq!(3, batch.num_columns()); assert_eq!(100, batch.num_rows()); @@ -305,18 +314,18 @@ mod tests { #[tokio::test] async fn csv_exec_with_limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; let path = format!("{}/csv/{}", testdata, filename); let csv = CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema, file_groups: vec![vec![local_unpartitioned_file(path)]], statistics: Statistics::default(), projection: None, - batch_size: 1024, limit: Some(5), table_partition_cols: vec![], }, @@ -327,7 +336,7 @@ mod tests { assert_eq!(13, csv.projected_schema.fields().len()); assert_eq!(13, csv.schema().fields().len()); - let mut it = csv.execute(0).await?; + let mut it = csv.execute(0, runtime).await?; let batch = it.next().await.unwrap()?; assert_eq!(13, batch.num_columns()); assert_eq!(5, batch.num_rows()); @@ -351,6 +360,7 @@ mod tests { #[tokio::test] async fn csv_exec_with_partition() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; @@ -361,7 +371,7 @@ mod tests { partitioned_file.partition_values = vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; let csv = CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { // we should be able to project on the partition column // wich is supposed to be after the file fields projection: Some(vec![0, file_schema.fields().len()]), @@ -369,7 +379,6 @@ mod tests { file_schema, file_groups: vec![vec![partitioned_file]], statistics: Statistics::default(), - batch_size: 1024, limit: None, table_partition_cols: vec!["date".to_owned()], }, @@ -380,7 +389,7 @@ mod tests { assert_eq!(2, csv.projected_schema.fields().len()); assert_eq!(2, csv.schema().fields().len()); - let mut it = csv.execute(0).await?; + let mut it = csv.execute(0, runtime).await?; let batch = it.next().await.unwrap()?; assert_eq!(2, batch.num_columns()); assert_eq!(100, 
batch.num_rows()); diff --git a/datafusion/src/physical_plan/file_format/file_stream.rs b/datafusion/src/physical_plan/file_format/file_stream.rs index c90df7e0b009..aaa89e36f84f 100644 --- a/datafusion/src/physical_plan/file_format/file_stream.rs +++ b/datafusion/src/physical_plan/file_format/file_stream.rs @@ -22,6 +22,7 @@ //! compliant with the `SendableRecordBatchStream` trait. use crate::datasource::object_store::ReadSeek; +use crate::record_batch::RecordBatch; use crate::{ datasource::{object_store::ObjectStore, PartitionedFile}, physical_plan::RecordBatchStream, @@ -30,7 +31,6 @@ use crate::{ use arrow::{ datatypes::SchemaRef, error::{ArrowError, Result as ArrowResult}, - record_batch::RecordBatch, }; use futures::Stream; use std::{ diff --git a/datafusion/src/physical_plan/file_format/json.rs b/datafusion/src/physical_plan/file_format/json.rs index 693e02a18a5b..fca810bc198a 100644 --- a/datafusion/src/physical_plan/file_format/json.rs +++ b/datafusion/src/physical_plan/file_format/json.rs @@ -19,31 +19,32 @@ use async_trait::async_trait; use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; +use crate::record_batch::RecordBatch; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::io::json; -use arrow::record_batch::RecordBatch; use std::any::Any; use std::io::{BufRead, BufReader, Read}; use std::sync::Arc; use super::file_stream::{BatchIter, FileStream}; -use super::PhysicalPlanConfig; +use super::FileScanConfig; /// Execution plan for scanning NdJson data source #[derive(Debug, Clone)] pub struct NdJsonExec { - base_config: PhysicalPlanConfig, + base_config: FileScanConfig, projected_statistics: Statistics, projected_schema: SchemaRef, } impl NdJsonExec { /// Create a new JSON reader execution plan provided base configurations - pub fn new(base_config: PhysicalPlanConfig) -> Self { + pub fn new(base_config: FileScanConfig) -> Self { let (projected_schema, projected_statistics) = base_config.project(); Self { @@ -97,7 +98,8 @@ impl Iterator for JsonBatchReader { self.schema.fields.clone() }; self.rows.truncate(records_read); - json::read::deserialize(&self.rows, fields).map(Some) + json::read::deserialize(&self.rows, &fields) + .map(|chunk| Some(RecordBatch::new_with_chunk(&self.schema, chunk))) } else { Ok(None) } @@ -138,10 +140,14 @@ impl ExecutionPlan for NdJsonExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { let proj = self.base_config.projected_file_column_names(); - let batch_size = self.base_config.batch_size; + let batch_size = runtime.batch_size(); let file_schema = Arc::clone(&self.base_config.file_schema); // The json reader cannot limit the number of records, so `remaining` is ignored. 
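// The hunks above (Avro, CSV, NDJSON) all move `batch_size` out of the per-scan config
// and into the `RuntimeEnv` handed to `execute`. A hedged, condensed sketch of the
// resulting caller-side pattern, assuming it sits inside this crate's test code like the
// examples above; `path` and the `infer_schema` helper are placeholders taken from those
// tests, not a stable public API.
async fn scan_json_with_runtime(path: String) -> Result<()> {
    // The runtime now supplies batch_size(); FileScanConfig no longer has that field.
    let runtime = Arc::new(RuntimeEnv::default());
    let exec = NdJsonExec::new(FileScanConfig {
        object_store: Arc::new(LocalFileSystem {}),
        file_groups: vec![vec![local_unpartitioned_file(path.clone())]],
        file_schema: infer_schema(path).await?,
        statistics: Statistics::default(),
        projection: None,
        limit: None,
        table_partition_cols: vec![],
    });
    // Both the partition index and the runtime are passed to execute().
    let mut stream = exec.execute(0, runtime).await?;
    while let Some(batch) = stream.next().await {
        let _batch = batch?;
    }
    Ok(())
}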
@@ -173,8 +179,7 @@ impl ExecutionPlan for NdJsonExec { DisplayFormatType::Default => { write!( f, - "JsonExec: batch_size={}, limit={:?}, files={}", - self.base_config.batch_size, + "JsonExec: limit={:?}, files={}", self.base_config.limit, super::FileGroupsDisplay(&self.base_config.file_groups), ) @@ -197,6 +202,7 @@ mod tests { local_object_reader_stream, local_unpartitioned_file, LocalFileSystem, }, }; + use crate::field_util::SchemaExt; use super::*; @@ -210,15 +216,15 @@ mod tests { #[tokio::test] async fn nd_json_exec_file_without_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); use arrow::datatypes::DataType; let path = format!("{}/1.json", TEST_DATA_BASE); - let exec = NdJsonExec::new(PhysicalPlanConfig { + let exec = NdJsonExec::new(FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_groups: vec![vec![local_unpartitioned_file(path.clone())]], file_schema: infer_schema(path).await?, statistics: Statistics::default(), projection: None, - batch_size: 1024, limit: Some(3), table_partition_cols: vec![], }); @@ -247,7 +253,7 @@ mod tests { &DataType::Utf8 ); - let mut it = exec.execute(0).await?; + let mut it = exec.execute(0, runtime).await?; let batch = it.next().await.unwrap()?; assert_eq!(batch.num_rows(), 3); @@ -265,14 +271,14 @@ mod tests { #[tokio::test] async fn nd_json_exec_file_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let path = format!("{}/1.json", TEST_DATA_BASE); - let exec = NdJsonExec::new(PhysicalPlanConfig { + let exec = NdJsonExec::new(FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_groups: vec![vec![local_unpartitioned_file(path.clone())]], file_schema: infer_schema(path).await?, statistics: Statistics::default(), projection: Some(vec![0, 2]), - batch_size: 1024, limit: None, table_partition_cols: vec![], }); @@ -284,7 +290,7 @@ mod tests { inferred_schema.field_with_name("c").unwrap(); inferred_schema.field_with_name("d").unwrap_err(); - let mut it = exec.execute(0).await?; + let mut it = exec.execute(0, runtime).await?; let batch = it.next().await.unwrap()?; assert_eq!(batch.num_rows(), 4); diff --git a/datafusion/src/physical_plan/file_format/mod.rs b/datafusion/src/physical_plan/file_format/mod.rs index 036b605154af..9b34e9df723c 100644 --- a/datafusion/src/physical_plan/file_format/mod.rs +++ b/datafusion/src/physical_plan/file_format/mod.rs @@ -24,17 +24,18 @@ mod json; mod parquet; pub use self::parquet::ParquetExec; +use crate::record_batch::RecordBatch; use arrow::{ array::{ArrayRef, DictionaryArray}, datatypes::{DataType, Field, Schema, SchemaRef}, error::{ArrowError, Result as ArrowResult}, - record_batch::RecordBatch, }; pub use avro::AvroExec; pub use csv::CsvExec; pub use json::NdJsonExec; use std::iter; +use crate::field_util::{FieldExt, SchemaExt}; use crate::{ datasource::{object_store::ObjectStore, PartitionedFile}, scalar::ScalarValue, @@ -60,7 +61,7 @@ lazy_static! { /// The base configurations to provide when creating a physical plan for /// any given file format. #[derive(Debug, Clone)] -pub struct PhysicalPlanConfig { +pub struct FileScanConfig { /// Store from which the `files` should be fetched pub object_store: Arc, /// Schema before projection. It contains the columns that are expected @@ -73,15 +74,13 @@ pub struct PhysicalPlanConfig { /// Columns on which to project the data. Indexes that are higher than the /// number of columns of `file_schema` refer to `table_partition_cols`. 
pub projection: Option>, - /// The maximum number of records per arrow column - pub batch_size: usize, /// The minimum number of records required from this source plan pub limit: Option, /// The partitioning column names pub table_partition_cols: Vec, } -impl PhysicalPlanConfig { +impl FileScanConfig { /// Project the schema and the statistics on the given column indices fn project(&self) -> (SchemaRef, Statistics) { if self.projection.is_none() && self.table_partition_cols.is_empty() { @@ -134,8 +133,7 @@ impl PhysicalPlanConfig { self.projection.as_ref().map(|p| { p.iter() .filter(|col_idx| **col_idx < self.file_schema.fields().len()) - .map(|col_idx| self.file_schema.field(*col_idx).name()) - .cloned() + .map(|col_idx| self.file_schema.field(*col_idx).name().to_string()) .collect() }) } @@ -467,9 +465,8 @@ mod tests { projection: Option>, statistics: Statistics, table_partition_cols: Vec, - ) -> PhysicalPlanConfig { - PhysicalPlanConfig { - batch_size: 1024, + ) -> FileScanConfig { + FileScanConfig { file_schema, file_groups: vec![vec![]], limit: None, diff --git a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs index 633343c5f76f..1903a8b7425d 100644 --- a/datafusion/src/physical_plan/file_format/parquet.rs +++ b/datafusion/src/physical_plan/file_format/parquet.rs @@ -25,12 +25,14 @@ use std::{any::Any, convert::TryInto}; use crate::datasource::object_store::ObjectStore; use crate::datasource::PartitionedFile; +use crate::field_util::{FieldExt, SchemaExt}; +use crate::record_batch::RecordBatch; use crate::{ error::{DataFusionError, Result}, logical_plan::{Column, Expr}, physical_optimizer::pruning::{PruningPredicate, PruningStatistics}, physical_plan::{ - file_format::PhysicalPlanConfig, + file_format::FileScanConfig, metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, stream::RecordBatchReceiverStream, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, @@ -38,15 +40,13 @@ use crate::{ }, scalar::ScalarValue, }; - use arrow::{ array::ArrayRef, datatypes::*, error::Result as ArrowResult, io::parquet::read::{self, RowGroupMetaData}, - record_batch::RecordBatch, }; -use log::debug; +use log::{debug, info}; use parquet::statistics::{ BinaryStatistics as ParquetBinaryStatistics, @@ -59,6 +59,7 @@ use tokio::{ task, }; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; use super::PartitionColumnProjector; @@ -66,7 +67,7 @@ use super::PartitionColumnProjector; /// Execution plan for scanning one or more Parquet partitions #[derive(Debug, Clone)] pub struct ParquetExec { - base_config: PhysicalPlanConfig, + base_config: FileScanConfig, projected_statistics: Statistics, projected_schema: SchemaRef, /// Execution metrics @@ -87,7 +88,7 @@ struct ParquetFileMetrics { impl ParquetExec { /// Create a new Parquet reader execution plan provided file list and schema. /// Even if `limit` is set, ParquetExec rounds up the number of records to the next `batch_size`. 
- pub fn new(base_config: PhysicalPlanConfig, predicate: Option) -> Self { + pub fn new(base_config: FileScanConfig, predicate: Option) -> Self { debug!("Creating ParquetExec, files: {:?}, projection {:?}, predicate: {:?}, limit: {:?}", base_config.file_groups, base_config.projection, predicate, base_config.limit); @@ -124,7 +125,7 @@ impl ParquetExec { } /// Ref to the base configs - pub fn base_config(&self) -> &PhysicalPlanConfig { + pub fn base_config(&self) -> &FileScanConfig { &self.base_config } } @@ -188,7 +189,11 @@ impl ExecutionPlan for ParquetExec { } } - async fn execute(&self, partition_index: usize) -> Result { + async fn execute( + &self, + partition_index: usize, + runtime: Arc, + ) -> Result { // because the parquet implementation is not thread-safe, it is necessary to execute // on a thread and communicate with channels let (response_tx, response_rx): (Sender, Receiver) = channel(2); @@ -200,7 +205,7 @@ impl ExecutionPlan for ParquetExec { None => (0..self.base_config.file_schema.fields().len()).collect(), }; let pruning_predicate = self.pruning_predicate.clone(); - let batch_size = self.base_config.batch_size; + let batch_size = runtime.batch_size(); let limit = self.base_config.limit; let object_store = Arc::clone(&self.base_config.object_store); let partition_col_proj = PartitionColumnProjector::new( @@ -208,9 +213,11 @@ impl ExecutionPlan for ParquetExec { &self.base_config.table_partition_cols, ); + let file_schema_ref = self.base_config().file_schema.clone(); let join_handle = task::spawn_blocking(move || { if let Err(e) = read_partition( object_store.as_ref(), + file_schema_ref, partition_index, partition, metrics, @@ -241,8 +248,7 @@ impl ExecutionPlan for ParquetExec { DisplayFormatType::Default => { write!( f, - "ParquetExec: batch_size={}, limit={:?}, partitions={}", - self.base_config.batch_size, + "ParquetExec: limit={:?}, partitions={}", self.base_config.limit, super::FileGroupsDisplay(&self.base_config.file_groups) ) @@ -341,6 +347,7 @@ macro_rules! get_min_max_values { }; let data_type = field.data_type(); + // The result may be None, because DataFusion doesn't have support for ScalarValues of the column type let null_scalar: ScalarValue = data_type.try_into().ok()?; let scalar_values : Vec = $self.row_group_metadata @@ -362,6 +369,30 @@ macro_rules! get_min_max_values { }} } +// Extract the null count value on the ParquetStatistics +macro_rules! get_null_count_values { + ($self:expr, $column:expr) => {{ + let column_index = + if let Some((v, _)) = $self.parquet_schema.column_with_name(&$column.name) { + v + } else { + // Named column was not present + return None; + }; + + let scalar_values: Vec = $self + .row_group_metadata + .iter() + .flat_map(|meta| meta.column(column_index).statistics()) + .flatten() + .map(|stats| ScalarValue::Int64(stats.null_count())) + .collect(); + + // ignore errors converting to arrays (e.g. 
different types) + ScalarValue::iter_to_array(scalar_values).ok() + }}; +} + impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { fn min_values(&self, column: &Column) -> Option { get_min_max_values!(self, column, min_value) @@ -374,6 +405,10 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { fn num_containers(&self) -> usize { self.row_group_metadata.len() } + + fn null_counts(&self, column: &Column) -> Option { + get_null_count_values!(self, column) + } } fn build_row_group_predicate( @@ -406,9 +441,33 @@ fn build_row_group_predicate( } } +// Map projections from the schema which merges all file schemas to projections on a particular +// file +fn map_projections( + merged_schema: &Schema, + file_schema: &Schema, + projections: &[usize], +) -> Result> { + let mut mapped: Vec = vec![]; + for idx in projections { + let field = merged_schema.field(*idx); + if let Ok(mapped_idx) = file_schema.index_of(field.name()) { + if file_schema.field(mapped_idx).data_type() == field.data_type() { + mapped.push(mapped_idx) + } else { + let msg = format!("Failed to map column projection for field {}. Incompatible data types {:?} and {:?}", field.name(), file_schema.field(mapped_idx).data_type(), field.data_type()); + info!("{}", msg); + return Err(DataFusionError::Execution(msg)); + } + } + } + Ok(mapped) +} + #[allow(clippy::too_many_arguments)] fn read_partition( object_store: &dyn ObjectStore, + file_schema: SchemaRef, partition_index: usize, partition: Vec, metrics: ExecutionPlanMetricsSet, @@ -420,6 +479,8 @@ fn read_partition( mut partition_column_projector: PartitionColumnProjector, ) -> Result<()> { for partitioned_file in partition { + debug!("Reading file {}", &partitioned_file.file_meta.path()); + let file_metrics = ParquetFileMetrics::new( partition_index, &*partitioned_file.file_meta.path(), @@ -428,6 +489,7 @@ fn read_partition( let object_reader = object_store.file_reader(partitioned_file.file_meta.sized_file.clone())?; let reader = object_reader.sync_reader()?; + let mut record_reader = read::RecordReader::try_new( reader, Some(projection.to_vec()), @@ -435,6 +497,9 @@ fn read_partition( None, None, )?; + // TODO : ??? 
+ let _mapped_projections = + map_projections(&file_schema, record_reader.schema(), projection)?; if let Some(pruning_predicate) = pruning_predicate { record_reader.set_groups_filter(Arc::new(build_row_group_predicate( pruning_predicate, @@ -443,9 +508,11 @@ fn read_partition( ))); } - for batch in record_reader { + let schema = record_reader.schema().clone(); + for chunk in record_reader { + let batch = RecordBatch::new_with_chunk(&schema, chunk?); let proj_batch = partition_column_projector - .project(batch?, &partitioned_file.partition_values); + .project(batch, &partitioned_file.partition_values); response_tx .blocking_send(proj_batch) .map_err(|x| DataFusionError::Execution(format!("{}", x)))?; @@ -458,29 +525,338 @@ fn read_partition( #[cfg(test)] mod tests { - use crate::assert_batches_eq; use crate::datasource::{ file_format::{parquet::ParquetFormat, FileFormat}, object_store::local::{ local_object_reader_stream, local_unpartitioned_file, LocalFileSystem, }, }; + use crate::{assert_batches_eq, assert_batches_sorted_eq}; + use arrow::array::*; use super::*; + use crate::field_util::FieldExt; + use crate::physical_plan::collect; use arrow::datatypes::{DataType, Field}; - use arrow::io::parquet::write::to_parquet_schema; + use arrow::io::parquet::write::{to_parquet_schema, write_file, RowGroupIterator}; use arrow::io::parquet::write::{ColumnDescriptor, SchemaDescriptor}; use futures::StreamExt; + use parquet::compression::Compression; + use parquet::encoding::Encoding; use parquet::metadata::ColumnChunkMetaData; use parquet::statistics::Statistics as ParquetStatistics; + use parquet::write::{Version, WriteOptions}; use parquet_format_async_temp::RowGroup; + /// writes each RecordBatch as an individual parquet file and then + /// reads it back in to the named location. 
+ async fn round_trip_to_parquet( + batches: Vec, + projection: Option>, + schema: Option, + ) -> Vec { + let runtime = Arc::new(RuntimeEnv::default()); + + // When vec is dropped, temp files are deleted + let files: Vec<_> = batches + .into_iter() + .map(|batch| { + let output = tempfile::NamedTempFile::new().expect("creating temp file"); + + let mut file: std::fs::File = (*output.as_file()) + .try_clone() + .expect("cloning file descriptor"); + let options = WriteOptions { + write_statistics: true, + compression: Compression::Uncompressed, + version: Version::V2, + }; + let schema_ref = &batch.schema().clone(); + let parquet_schema = to_parquet_schema(schema_ref).unwrap(); + + let iter = vec![Ok(batch.into())]; + let row_groups = RowGroupIterator::try_new( + iter.into_iter(), + schema_ref, + options, + vec![Encoding::Plain, Encoding::Plain], + ) + .unwrap(); + + write_file( + &mut file, + row_groups, + schema_ref, + parquet_schema, + options, + None, + ) + .expect("Writing batch"); + output + }) + .collect(); + + let file_names: Vec<_> = files + .iter() + .map(|t| t.path().to_string_lossy().to_string()) + .collect(); + + // Now, read the files back in + let file_groups: Vec<_> = file_names + .iter() + .map(|name| local_unpartitioned_file(name.clone())) + .collect(); + + // Infer the schema (if not provided) + let file_schema = match schema { + Some(provided_schema) => provided_schema, + None => ParquetFormat::default() + .infer_schema(local_object_reader_stream(file_names)) + .await + .expect("inferring schema"), + }; + + // prepare the scan + let parquet_exec = ParquetExec::new( + FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_groups: vec![file_groups], + file_schema, + statistics: Statistics::default(), + projection, + limit: None, + table_partition_cols: vec![], + }, + None, + ); + + collect(Arc::new(parquet_exec), runtime) + .await + .expect("reading parquet data") + } + + // Add a new column with the specified field name to the RecordBatch + fn add_to_batch( + batch: &RecordBatch, + field_name: &str, + array: ArrayRef, + ) -> RecordBatch { + let mut fields = batch.schema().fields().to_vec(); + fields.push(Field::new(field_name, array.data_type().clone(), true)); + let schema = Arc::new(Schema::new(fields)); + + let mut columns = batch.columns().to_vec(); + columns.push(array); + RecordBatch::try_new(schema, columns).expect("error; creating record batch") + } + + fn create_batch(columns: Vec<(&str, ArrayRef)>) -> RecordBatch { + columns.into_iter().fold( + RecordBatch::new_empty(Arc::new(Schema::new(vec![]))), + |batch, (field_name, arr)| add_to_batch(&batch, field_name, arr.clone()), + ) + } + + #[tokio::test] + async fn evolved_schema() { + let c1: ArrayRef = + Arc::new(Utf8Array::::from(vec![Some("Foo"), None, Some("bar")])); + // batch1: c1(string) + let batch1 = add_to_batch( + &RecordBatch::new_empty(Arc::new(Schema::new(vec![]))), + "c1", + c1, + ); + + // batch2: c1(string) and c2(int64) + let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + let batch2 = add_to_batch(&batch1, "c2", c2); + + // batch3: c1(string) and c3(int8) + let c3: ArrayRef = Arc::new(Int8Array::from(vec![Some(10), Some(20), None])); + let batch3 = add_to_batch(&batch1, "c3", c3); + + // read/write them files: + let read = round_trip_to_parquet(vec![batch1, batch2, batch3], None, None).await; + let expected = vec![ + "+-----+----+----+", + "| c1 | c2 | c3 |", + "+-----+----+----+", + "| | | |", + "| | | 20 |", + "| | 2 | |", + "| Foo | | |", + "| Foo | | 10 
|", + "| Foo | 1 | |", + "| bar | | |", + "| bar | | |", + "| bar | | |", + "+-----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &read); + } + + #[tokio::test] + async fn evolved_schema_inconsistent_order() { + let c1: ArrayRef = + Arc::new(Utf8Array::::from(vec![Some("Foo"), None, Some("bar")])); + + let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + + let c3: ArrayRef = Arc::new(Int8Array::from(vec![Some(10), Some(20), None])); + + // batch1: c1(string), c2(int64), c3(int8) + let batch1 = create_batch(vec![ + ("c1", c1.clone()), + ("c2", c2.clone()), + ("c3", c3.clone()), + ]); + + // batch2: c3(int8), c2(int64), c1(string) + let batch2 = create_batch(vec![("c3", c3), ("c2", c2), ("c1", c1)]); + + // read/write them files: + let read = round_trip_to_parquet(vec![batch1, batch2], None, None).await; + let expected = vec![ + "+-----+----+----+", + "| c1 | c2 | c3 |", + "+-----+----+----+", + "| Foo | 1 | 10 |", + "| | 2 | 20 |", + "| bar | | |", + "| Foo | 1 | 10 |", + "| | 2 | 20 |", + "| bar | | |", + "+-----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &read); + } + + #[tokio::test] + async fn evolved_schema_intersection() { + let c1: ArrayRef = + Arc::new(Utf8Array::::from(vec![Some("Foo"), None, Some("bar")])); + + let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + + let c3: ArrayRef = Arc::new(Int8Array::from(vec![Some(10), Some(20), None])); + + // batch1: c1(string), c2(int64), c3(int8) + let batch1 = create_batch(vec![("c1", c1), ("c3", c3.clone())]); + + // batch2: c3(int8), c2(int64), c1(string) + let batch2 = create_batch(vec![("c3", c3), ("c2", c2)]); + + // read/write them files: + let read = round_trip_to_parquet(vec![batch1, batch2], None, None).await; + let expected = vec![ + "+-----+----+----+", + "| c1 | c3 | c2 |", + "+-----+----+----+", + "| Foo | 10 | |", + "| | 20 | |", + "| bar | | |", + "| | 10 | 1 |", + "| | 20 | 2 |", + "| | | |", + "+-----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &read); + } + + #[tokio::test] + async fn evolved_schema_projection() { + let c1: ArrayRef = + Arc::new(Utf8Array::::from(vec![Some("Foo"), None, Some("bar")])); + + let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + + let c3: ArrayRef = Arc::new(Int8Array::from(vec![Some(10), Some(20), None])); + + let c4: ArrayRef = + Arc::new(Utf8Array::::from(vec![Some("baz"), Some("boo"), None])); + + // batch1: c1(string), c2(int64), c3(int8) + let batch1 = create_batch(vec![ + ("c1", c1.clone()), + ("c2", c2.clone()), + ("c3", c3.clone()), + ]); + + // batch2: c3(int8), c2(int64), c1(string), c4(string) + let batch2 = create_batch(vec![("c3", c3), ("c2", c2), ("c1", c1), ("c4", c4)]); + + // read/write them files: + let read = + round_trip_to_parquet(vec![batch1, batch2], Some(vec![0, 3]), None).await; + let expected = vec![ + "+-----+-----+", + "| c1 | c4 |", + "+-----+-----+", + "| Foo | baz |", + "| | boo |", + "| bar | |", + "| Foo | |", + "| | |", + "| bar | |", + "+-----+-----+", + ]; + assert_batches_sorted_eq!(expected, &read); + } + + #[tokio::test] + async fn evolved_schema_incompatible_types() { + let c1: ArrayRef = + Arc::new(Utf8Array::::from(vec![Some("Foo"), None, Some("bar")])); + + let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + + let c3: ArrayRef = Arc::new(Int8Array::from(vec![Some(10), Some(20), None])); + + let c4: ArrayRef = + Arc::new(Float32Array::from(vec![Some(1.0_f32), Some(2.0_f32), None])); + + // batch1: c1(string), 
c2(int64), c3(int8) + let batch1 = create_batch(vec![ + ("c1", c1.clone()), + ("c2", c2.clone()), + ("c3", c3.clone()), + ]); + + // batch2: c3(int8), c2(int64), c1(string), c4(string) + let batch2 = create_batch(vec![("c3", c4), ("c2", c2), ("c1", c1)]); + + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Int64, true), + Field::new("c3", DataType::Int8, true), + ]); + + // read/write them files: + let read = + round_trip_to_parquet(vec![batch1, batch2], None, Some(Arc::new(schema))) + .await; + + // expect only the first batch to be read + let expected = vec![ + "+-----+----+----+", + "| c1 | c2 | c3 |", + "+-----+----+----+", + "| Foo | 1 | 10 |", + "| | 2 | 20 |", + "| bar | | |", + "+-----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &read); + } + #[tokio::test] async fn parquet_exec_with_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/alltypes_plain.parquet", testdata); let parquet_exec = ParquetExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_groups: vec![vec![local_unpartitioned_file(filename.clone())]], file_schema: ParquetFormat::default() @@ -488,7 +864,6 @@ mod tests { .await?, statistics: Statistics::default(), projection: Some(vec![0, 1, 2]), - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -496,15 +871,14 @@ mod tests { ); assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); - let mut results = parquet_exec.execute(0).await?; + let mut results = parquet_exec.execute(0, runtime).await?; let batch = results.next().await.unwrap()?; assert_eq!(8, batch.num_rows()); assert_eq!(3, batch.num_columns()); let schema = batch.schema(); - let field_names: Vec<&str> = - schema.fields().iter().map(|f| f.name().as_str()).collect(); + let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name()).collect(); assert_eq!(vec!["id", "bool_col", "tinyint_col"], field_names); let batch = results.next().await; @@ -521,6 +895,7 @@ mod tests { #[tokio::test] async fn parquet_exec_with_partition() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/alltypes_plain.parquet", testdata); let mut partitioned_file = local_unpartitioned_file(filename.clone()); @@ -530,7 +905,7 @@ mod tests { ScalarValue::Utf8(Some("26".to_owned())), ]; let parquet_exec = ParquetExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_groups: vec![vec![partitioned_file]], file_schema: ParquetFormat::default() @@ -539,7 +914,6 @@ mod tests { statistics: Statistics::default(), // file has 10 cols so index 12 should be month projection: Some(vec![0, 1, 2, 12]), - batch_size: 1024, limit: None, table_partition_cols: vec![ "year".to_owned(), @@ -551,7 +925,7 @@ mod tests { ); assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); - let mut results = parquet_exec.execute(0).await?; + let mut results = parquet_exec.execute(0, runtime).await?; let batch = results.next().await.unwrap()?; let expected = vec![ "+----+----------+-------------+-------+", @@ -775,22 +1149,12 @@ mod tests { Ok(()) } - #[test] - fn row_group_pruning_predicate_null_expr() -> Result<()> { - use crate::logical_plan::{col, lit}; - // test row group predicate with an unknown (Null) expr - // - // int > 15 and bool = NULL => c1_max > 15 and null - 
let expr = col("c1") - .gt(lit(15)) - .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); - let schema = Arc::new(Schema::new(vec![ + fn gen_row_group_meta_data_for_pruning_predicate() -> Vec { + let schema = Schema::new(vec![ Field::new("c1", DataType::Int32, false), - Field::new("c2", DataType::Boolean, false), - ])); - let pruning_predicate = PruningPredicate::try_new(&expr, schema.clone())?; - - let schema_descr = to_parquet_schema(&schema)?; + Field::new("c2", DataType::Boolean, true), + ]); + let schema_descr = to_parquet_schema(&schema).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, vec![ @@ -823,11 +1187,57 @@ mod tests { min_value: Some(false), max_value: Some(true), distinct_count: None, - null_count: Some(0), + null_count: Some(1), }, ], ); - let row_group_metadata = vec![rgm1, rgm2]; + vec![rgm1, rgm2] + } + + #[test] + fn row_group_pruning_predicate_null_expr() -> Result<()> { + use crate::logical_plan::{col, lit}; + // int > 1 and IsNull(bool) => c1_max > 1 and bool_null_count > 0 + let expr = col("c1").gt(lit::(15)).and(col("c2").is_null()); + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Boolean, false), + ])); + let pruning_predicate = PruningPredicate::try_new(&expr, schema)?; + let row_group_metadata = gen_row_group_meta_data_for_pruning_predicate(); + + let row_group_predicate = build_row_group_predicate( + &pruning_predicate, + parquet_file_metrics(), + &row_group_metadata, + ); + let row_group_filter = row_group_metadata + .iter() + .enumerate() + .map(|(i, g)| row_group_predicate(i, g)) + .collect::>(); + // First row group was filtered out because it contains no null value on "c2". + assert_eq!(row_group_filter, vec![false, true]); + + Ok(()) + } + + #[test] + fn row_group_pruning_predicate_eq_null_expr() -> Result<()> { + use crate::logical_plan::{col, lit}; + // test row group predicate with an unknown (Null) expr + // + // int > 15 and bool = NULL => c1_max > 15 and null + let expr = col("c1") + .gt(lit(15)) + .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Boolean, false), + ])); + let pruning_predicate = PruningPredicate::try_new(&expr, schema)?; + let row_group_metadata = gen_row_group_meta_data_for_pruning_predicate(); + let row_group_predicate = build_row_group_predicate( &pruning_predicate, parquet_file_metrics(), diff --git a/datafusion/src/physical_plan/filter.rs b/datafusion/src/physical_plan/filter.rs index cf3c28bf9051..27b1101fe2f1 100644 --- a/datafusion/src/physical_plan/filter.rs +++ b/datafusion/src/physical_plan/filter.rs @@ -30,15 +30,15 @@ use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, }; +use crate::record_batch::{filter_record_batch, RecordBatch}; use arrow::array::{Array, BooleanArray}; -use arrow::compute::filter::filter_record_batch; use arrow::datatypes::{DataType, SchemaRef}; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; +use arrow::compute::boolean::{and, is_not_null}; use async_trait::async_trait; -use arrow::compute::boolean::{and, is_not_null}; +use crate::execution::runtime_env::RuntimeEnv; use futures::stream::{Stream, StreamExt}; /// FilterExec evaluates a boolean predicate against all input batches to determine which rows to @@ -120,13 +120,17 @@ impl ExecutionPlan for FilterExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn 
execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); Ok(Box::pin(FilterExecStream { schema: self.input.schema().clone(), predicate: self.predicate.clone(), - input: self.input.execute(partition).await?, + input: self.input.execute(partition, runtime).await?, baseline_metrics, })) } @@ -173,7 +177,7 @@ fn batch_filter( predicate .evaluate(batch) .map(|v| v.into_array(batch.num_rows())) - .map_err(DataFusionError::into_arrow_external_error) + .map_err(DataFusionError::into) .and_then(|array| { array .as_any() @@ -182,7 +186,7 @@ fn batch_filter( DataFusionError::Internal( "Filter predicate evaluated to non-boolean value".to_string(), ) - .into_arrow_external_error() + .into() }) // apply filter array to record batch .and_then(|filter_array| { @@ -230,7 +234,7 @@ mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; use crate::physical_plan::expressions::*; - use crate::physical_plan::file_format::{CsvExec, PhysicalPlanConfig}; + use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; use crate::scalar::ScalarValue; use crate::test; @@ -240,6 +244,7 @@ mod tests { #[tokio::test] async fn simple_predicate() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -247,13 +252,12 @@ mod tests { test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; let csv = CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema: Arc::clone(&schema), file_groups: files, statistics: Statistics::default(), projection: None, - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -281,7 +285,7 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, Arc::new(csv))?); - let results = collect(filter).await?; + let results = collect(filter, runtime).await?; results .iter() diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index a743359d83ae..ba5dc87f99b1 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -36,11 +36,13 @@ use super::{ use crate::execution::context::ExecutionContextState; use crate::physical_plan::array_expressions; use crate::physical_plan::datetime_expressions; +use crate::physical_plan::expressions::cast::DEFAULT_DATAFUSION_CAST_OPTIONS; use crate::physical_plan::expressions::{ cast_column, nullif_func, SUPPORTED_NULLIF_TYPES, }; use crate::physical_plan::math_expressions; use crate::physical_plan::string_expressions; +use crate::record_batch::RecordBatch; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, @@ -51,7 +53,6 @@ use arrow::{ datatypes::TimeUnit, datatypes::{DataType, Field, Schema}, error::{ArrowError, Result as ArrowResult}, - record_batch::RecordBatch, types::NativeType, }; use fmt::{Debug, Formatter}; @@ -172,7 +173,7 @@ pub type ReturnTypeFunction = Arc Result> + Send + Sync>; /// Enum of all built-in scalar functions -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum BuiltinScalarFunction { // math functions /// abs @@ -1224,6 +1225,7 @@ pub fn create_physical_expr( cast_column( &col_values[0], &DataType::Timestamp(TimeUnit::Nanosecond, None), + DEFAULT_DATAFUSION_CAST_OPTIONS, ) } } @@ -1243,6 +1245,7 @@ pub fn create_physical_expr( cast_column( &col_values[0], 
&DataType::Timestamp(TimeUnit::Millisecond, None), + DEFAULT_DATAFUSION_CAST_OPTIONS, ) } } @@ -1262,6 +1265,7 @@ pub fn create_physical_expr( cast_column( &col_values[0], &DataType::Timestamp(TimeUnit::Microsecond, None), + DEFAULT_DATAFUSION_CAST_OPTIONS, ) } } @@ -1281,6 +1285,7 @@ pub fn create_physical_expr( cast_column( &col_values[0], &DataType::Timestamp(TimeUnit::Second, None), + DEFAULT_DATAFUSION_CAST_OPTIONS, ) } } @@ -1587,7 +1592,7 @@ pub struct ScalarFunctionExpr { } impl Debug for ScalarFunctionExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_struct("ScalarFunctionExpr") .field("fun", &"") .field("name", &self.name) @@ -1735,12 +1740,14 @@ where #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; + use crate::record_batch::RecordBatch; use crate::{ error::Result, physical_plan::expressions::{col, lit}, scalar::ScalarValue, }; - use arrow::{datatypes::Field, record_batch::RecordBatch}; + use arrow::datatypes::Field; type StringArray = Utf8Array; @@ -3749,7 +3756,7 @@ mod tests { lit(ScalarValue::Int64(Some(-1))), ], Err(DataFusionError::Execution( - "negative substring length not allowed".to_string(), + "negative substring length not allowed: substr(, 1, -1)".to_string(), )), &str, Utf8, diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 900a29c32de8..b4184ebe9d33 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -37,16 +37,19 @@ use crate::{ scalar::ScalarValue, }; +use crate::record_batch::RecordBatch; use arrow::{ array::*, compute::{cast, concatenate, take}, datatypes::{DataType, Field, Schema, SchemaRef}, error::{ArrowError, Result as ArrowResult}, - record_batch::RecordBatch, }; use hashbrown::raw::RawTable; use pin_project_lite::pin_project; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::{FieldExt, SchemaExt}; +use crate::physical_plan::expressions::cast::cast_with_error; use async_trait::async_trait; use super::common::AbortOnDropSingle; @@ -206,8 +209,12 @@ impl ExecutionPlan for HashAggregateExec { self.input.output_partitioning() } - async fn execute(&self, partition: usize) -> Result { - let input = self.input.execute(partition).await?; + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { + let input = self.input.execute(partition, runtime).await?; let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect(); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -397,8 +404,7 @@ fn group_aggregate_batch( } // 1.2 Need to create new entry None => { - let accumulator_set = create_accumulators(aggr_expr) - .map_err(DataFusionError::into_arrow_external_error)?; + let accumulator_set = create_accumulators(aggr_expr)?; // Copy group values out of arrays into `ScalarValue`s let group_by_values = group_values @@ -504,8 +510,7 @@ async fn compute_grouped_hash_aggregate( // Assume create_schema() always put group columns in front of aggr columns, we set // col_idx_base to group expression count. 
let aggregate_expressions = - aggregate_expressions(&aggr_expr, &mode, group_expr.len()) - .map_err(DataFusionError::into_arrow_external_error)?; + aggregate_expressions(&aggr_expr, &mode, group_expr.len())?; let random_state = RandomState::new(); @@ -523,8 +528,7 @@ async fn compute_grouped_hash_aggregate( batch, accumulators, &aggregate_expressions, - ) - .map_err(DataFusionError::into_arrow_external_error)?; + )?; timer.done(); } @@ -607,7 +611,7 @@ struct Accumulators { } impl std::fmt::Debug for Accumulators { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { // hashes are not store inline, so could only get values let map_string = "RawTable"; f.debug_struct("Accumulators") @@ -741,10 +745,8 @@ async fn compute_hash_aggregate( elapsed_compute: metrics::Time, ) -> ArrowResult { let timer = elapsed_compute.timer(); - let mut accumulators = create_accumulators(&aggr_expr) - .map_err(DataFusionError::into_arrow_external_error)?; - let expressions = aggregate_expressions(&aggr_expr, &mode, 0) - .map_err(DataFusionError::into_arrow_external_error)?; + let mut accumulators = create_accumulators(&aggr_expr)?; + let expressions = aggregate_expressions(&aggr_expr, &mode, 0)?; let expressions = Arc::new(expressions); timer.done(); @@ -753,16 +755,14 @@ async fn compute_hash_aggregate( while let Some(batch) = input.next().await { let batch = batch?; let timer = elapsed_compute.timer(); - aggregate_batch(&mode, &batch, &mut accumulators, &expressions) - .map_err(DataFusionError::into_arrow_external_error)?; + aggregate_batch(&mode, &batch, &mut accumulators, &expressions)?; timer.done(); } // 2. convert values to a record batch let timer = elapsed_compute.timer(); let batch = finalize_aggregation(&accumulators, &mode) - .map(|columns| RecordBatch::try_new(schema.clone(), columns)) - .map_err(DataFusionError::into_arrow_external_error)?; + .map(|columns| RecordBatch::try_new(schema.clone(), columns))?; timer.done(); batch } @@ -906,9 +906,7 @@ fn create_batch_from_map( match mode { AggregateMode::Partial => { for acc in accs.iter() { - let state = acc - .state() - .map_err(DataFusionError::into_arrow_external_error)?; + let state = acc.state()?; acc_data_types.push(state.len()); } } @@ -926,8 +924,7 @@ fn create_batch_from_map( .map(|group_state| group_state.group_by_values[i].clone()), ) }) - .collect::>>() - .map_err(|x| x.into_arrow_external_error())?; + .collect::>>()?; // add state / evaluated arrays for (x, &state_len) in acc_data_types.iter().enumerate() { @@ -939,8 +936,7 @@ fn create_batch_from_map( let x = group_state.accumulator_set[x].state().unwrap(); x[y].clone() }), - ) - .map_err(DataFusionError::into_arrow_external_error)?; + )?; columns.push(res); } @@ -949,8 +945,7 @@ fn create_batch_from_map( accumulators.group_states.iter().map(|group_state| { group_state.accumulator_set[x].evaluate().unwrap() }), - ) - .map_err(DataFusionError::into_arrow_external_error)?; + )?; columns.push(res); } } @@ -964,11 +959,12 @@ fn create_batch_from_map( .iter() .zip(output_schema.fields().iter()) .map(|(col, desired_field)| { - cast::cast( + cast_with_error( col.as_ref(), desired_field.data_type(), cast::CastOptions::default(), ) + .map_err(|e| e.into()) .map(Arc::from) }) .collect::>>()?; @@ -1018,16 +1014,13 @@ fn finalize_aggregation( #[cfg(test)] mod tests { - use arrow::array::{Float64Array, UInt32Array}; - use arrow::datatypes::DataType; - use futures::FutureExt; - use super::*; use 
crate::assert_batches_sorted_eq; use crate::physical_plan::common; use crate::physical_plan::expressions::{col, Avg}; use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use futures::FutureExt; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -1076,6 +1069,8 @@ mod tests { DataType::Float64, ))]; + let runtime = Arc::new(RuntimeEnv::default()); + let partial_aggregate = Arc::new(HashAggregateExec::try_new( AggregateMode::Partial, groups.clone(), @@ -1084,7 +1079,8 @@ mod tests { input_schema.clone(), )?); - let result = common::collect(partial_aggregate.execute(0).await?).await?; + let result = + common::collect(partial_aggregate.execute(0, runtime.clone()).await?).await?; let expected = vec![ "+---+---------------+-------------+", @@ -1115,7 +1111,8 @@ mod tests { input_schema, )?); - let result = common::collect(merged_aggregate.execute(0).await?).await?; + let result = + common::collect(merged_aggregate.execute(0, runtime.clone()).await?).await?; assert_eq!(result.len(), 1); let batch = &result[0]; @@ -1176,13 +1173,16 @@ mod tests { ))) } - async fn execute(&self, _partition: usize) -> Result { - let stream; - if self.yield_first { - stream = TestYieldingStream::New; + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { + let stream = if self.yield_first { + TestYieldingStream::New } else { - stream = TestYieldingStream::Yielded; - } + TestYieldingStream::Yielded + }; Ok(Box::pin(stream)) } @@ -1252,6 +1252,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -1273,7 +1274,7 @@ mod tests { schema, )?); - let fut = crate::physical_plan::collect(hash_aggregate_exec); + let fut = crate::physical_plan::collect(hash_aggregate_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1285,6 +1286,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel_with_groups() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float32, true), @@ -1309,7 +1311,7 @@ mod tests { schema, )?); - let fut = crate::physical_plan::collect(hash_aggregate_exec); + let fut = crate::physical_plan::collect(hash_aggregate_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 07144d74a34d..983dbf919a5c 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -29,10 +29,10 @@ use async_trait::async_trait; use futures::{Stream, StreamExt, TryStreamExt}; use tokio::sync::Mutex; +use crate::record_batch::RecordBatch; use arrow::array::*; use arrow::datatypes::*; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; use arrow::compute::take; @@ -54,6 +54,8 @@ use super::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::SchemaExt; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::PhysicalExpr; use arrow::bitmap::MutableBitmap; @@ -79,7 +81,7 @@ type LargeStringArray = Utf8Array; struct JoinHashMap(RawTable<(u64, SmallVec<[u64; 1]>)>); impl fmt::Debug 
for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { Ok(()) } } @@ -266,7 +268,11 @@ impl ExecutionPlan for HashJoinExec { self.right.output_partitioning() } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); // we only want to compute the build side once for PartitionMode::CollectLeft let left_data = { @@ -281,7 +287,7 @@ impl ExecutionPlan for HashJoinExec { // merge all left parts into a single stream let merge = CoalescePartitionsExec::new(self.left.clone()); - let stream = merge.execute(0).await?; + let stream = merge.execute(0, runtime.clone()).await?; // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream @@ -334,7 +340,7 @@ impl ExecutionPlan for HashJoinExec { let start = Instant::now(); // Load 1 partition of left side in memory - let stream = self.left.execute(partition).await?; + let stream = self.left.execute(partition, runtime.clone()).await?; // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream @@ -385,7 +391,7 @@ impl ExecutionPlan for HashJoinExec { // we have the batches and the hash map with their keys. We can how create a stream // over the right that uses this information to issue new batches. - let right_stream = self.right.execute(partition).await?; + let right_stream = self.right.execute(partition, runtime.clone()).await?; let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); let num_rows = left_data.1.num_rows(); @@ -1041,11 +1047,12 @@ mod tests { on: JoinOn, join_type: &JoinType, null_equals_null: bool, + runtime: Arc, ) -> Result<(Vec, Vec)> { let join = join(left, right, on, join_type, null_equals_null)?; let columns = columns(&join.schema()); - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime).await?; let batches = common::collect(stream).await?; Ok((columns, batches)) @@ -1057,6 +1064,7 @@ mod tests { on: JoinOn, join_type: &JoinType, null_equals_null: bool, + runtime: Arc, ) -> Result<(Vec, Vec)> { let partition_count = 4; @@ -1089,7 +1097,7 @@ mod tests { let mut batches = vec![]; for i in 0..partition_count { - let stream = join.execute(i).await?; + let stream = join.execute(i, runtime.clone()).await?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -1104,6 +1112,7 @@ mod tests { #[tokio::test] async fn join_inner_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1126,6 +1135,7 @@ mod tests { on.clone(), &JoinType::Inner, false, + runtime, ) .await?; @@ -1147,6 +1157,7 @@ mod tests { #[tokio::test] async fn partitioned_join_inner_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1168,6 +1179,7 @@ mod tests { on.clone(), &JoinType::Inner, false, + runtime, ) .await?; @@ -1189,6 +1201,7 @@ mod tests { #[tokio::test] async fn join_inner_one_no_shared_column_names() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1205,7 +1218,7 @@ mod tests { )]; let (columns, batches) = - join_collect(left, right, 
on, &JoinType::Inner, false).await?; + join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1226,6 +1239,7 @@ mod tests { #[tokio::test] async fn join_inner_two() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), @@ -1248,7 +1262,7 @@ mod tests { ]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false).await?; + join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -1272,6 +1286,7 @@ mod tests { /// Test where the left has 2 parts, the right with 1 part => 1 part #[tokio::test] async fn join_inner_one_two_parts_left() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -1301,7 +1316,7 @@ mod tests { ]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false).await?; + join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -1325,6 +1340,7 @@ mod tests { /// Test where the left has 1 part, the right has 2 parts => 2 parts #[tokio::test] async fn join_inner_one_two_parts_right() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1354,7 +1370,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); // first part - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime.clone()).await?; let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); @@ -1368,7 +1384,7 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); // second part - let stream = join.execute(1).await?; + let stream = join.execute(1, runtime.clone()).await?; let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -1399,6 +1415,7 @@ mod tests { #[tokio::test] async fn join_left_multi_batch() { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1419,7 +1436,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let stream = join.execute(0).await.unwrap(); + let stream = join.execute(0, runtime).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1439,6 +1456,7 @@ mod tests { #[tokio::test] async fn join_full_multi_batch() { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1460,7 +1478,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0).await.unwrap(); + let stream = join.execute(0, runtime).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1482,6 +1500,7 @@ mod tests { #[tokio::test] async fn join_left_empty_right() { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1499,7 +1518,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let 
stream = join.execute(0).await.unwrap(); + let stream = join.execute(0, runtime).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1517,6 +1536,7 @@ mod tests { #[tokio::test] async fn join_full_empty_right() { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1534,7 +1554,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0).await.unwrap(); + let stream = join.execute(0, runtime).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1552,6 +1572,7 @@ mod tests { #[tokio::test] async fn join_left_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1573,6 +1594,7 @@ mod tests { on.clone(), &JoinType::Left, false, + runtime, ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1593,6 +1615,7 @@ mod tests { #[tokio::test] async fn partitioned_join_left_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1614,6 +1637,7 @@ mod tests { on.clone(), &JoinType::Left, false, + runtime, ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1634,6 +1658,7 @@ mod tests { #[tokio::test] async fn join_semi() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 2, 3]), ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right @@ -1654,7 +1679,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1673,6 +1698,7 @@ mod tests { #[tokio::test] async fn join_anti() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 2, 3, 5]), ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right @@ -1693,7 +1719,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1710,6 +1736,7 @@ mod tests { #[tokio::test] async fn join_right_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1726,7 +1753,7 @@ mod tests { )]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Right, false).await?; + join_collect(left, right, on, &JoinType::Right, false, runtime).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1747,6 +1774,7 @@ mod tests { #[tokio::test] async fn partitioned_join_right_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1763,7 +1791,8 @@ mod tests { )]; let (columns, batches) = - partitioned_join_collect(left, right, on, &JoinType::Right, false).await?; + partitioned_join_collect(left, right, on, &JoinType::Right, false, runtime) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1784,6 +1813,7 
@@ mod tests { #[tokio::test] async fn join_full_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1804,7 +1834,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1868,6 +1898,7 @@ mod tests { #[tokio::test] async fn join_with_duplicated_column_names() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a", &vec![1, 2, 3]), ("b", &vec![4, 5, 7]), @@ -1889,7 +1920,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime).await?; let batches = common::collect(stream).await?; let expected = vec![ diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 4365c8af0a4c..2f063f3577f4 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -25,7 +25,7 @@ use arrow::array::ArrayRef; mod noforce_hash_collisions { use super::{ArrayRef, CallHasher, RandomState, Result}; use crate::error::DataFusionError; - use arrow::array::{Array, DictionaryArray, DictionaryKey}; + use arrow::array::{Array, DictionaryArray, DictionaryKey, Int128Array}; use arrow::array::{ BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, @@ -36,6 +36,44 @@ mod noforce_hash_collisions { type StringArray = Utf8Array; type LargeStringArray = Utf8Array; + fn hash_decimal128<'a>( + array: &ArrayRef, + random_state: &RandomState, + hashes_buffer: &'a mut Vec, + mul_col: bool, + ) { + let array = array.as_any().downcast_ref::().unwrap(); + if array.null_count() == 0 { + if mul_col { + for (i, hash) in hashes_buffer.iter_mut().enumerate() { + *hash = combine_hashes( + i128::get_hash(&array.value(i), random_state), + *hash, + ); + } + } else { + for (i, hash) in hashes_buffer.iter_mut().enumerate() { + *hash = i128::get_hash(&array.value(i), random_state); + } + } + } else if mul_col { + for (i, hash) in hashes_buffer.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = combine_hashes( + i128::get_hash(&array.value(i), random_state), + *hash, + ); + } + } + } else { + for (i, hash) in hashes_buffer.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = i128::get_hash(&array.value(i), random_state); + } + } + } + } + macro_rules! 
hash_array_float { ($array_type:ident, $column: ident, $ty: ident, $hashes: ident, $random_state: ident, $multi_col: ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); @@ -240,6 +278,9 @@ mod noforce_hash_collisions { for col in arrays { match col.data_type() { + DataType::Decimal(_, _) => { + hash_decimal128(col, random_state, hashes_buffer, multi_col); + } DataType::UInt8 => { hash_array_primitive!( UInt8Array, @@ -523,12 +564,28 @@ mod tests { use crate::error::Result; use std::sync::Arc; - use arrow::array::{Float32Array, Float64Array}; + use arrow::array::{Float32Array, Float64Array, Int128Vec, PrimitiveArray, TryPush}; #[cfg(not(feature = "force_hash_collisions"))] use arrow::array::{MutableDictionaryArray, MutableUtf8Array, TryExtend, Utf8Array}; use super::*; + #[test] + fn create_hashes_for_decimal_array() -> Result<()> { + let mut builder = Int128Vec::with_capacity(4); + let array: Vec = vec![1, 2, 3, 4]; + for value in &array { + builder.try_push(Some(*value))?; + } + let array: PrimitiveArray = builder.into(); + let array_ref = Arc::new(array); + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let hashes_buff = &mut vec![0; array_ref.len()]; + let hashes = create_hashes(&[array_ref], &random_state, hashes_buff)?; + assert_eq!(hashes.len(), 4); + Ok(()) + } + #[test] fn create_hashes_for_float_arrays() -> Result<()> { let f32_arr = Arc::new(Float32Array::from_slice(&[0.12, 0.5, 1f32, 444.7])); diff --git a/datafusion/src/physical_plan/join_utils.rs b/datafusion/src/physical_plan/join_utils.rs index 8359bbc4e9f7..8e903e6c008e 100644 --- a/datafusion/src/physical_plan/join_utils.rs +++ b/datafusion/src/physical_plan/join_utils.rs @@ -18,6 +18,7 @@ //! Join related functionality used both on logical and physical plans use crate::error::{DataFusionError, Result}; +use crate::field_util::{FieldExt, SchemaExt}; use crate::logical_plan::JoinType; use crate::physical_plan::expressions::Column; use arrow::datatypes::{Field, Schema}; diff --git a/datafusion/src/physical_plan/limit.rs b/datafusion/src/physical_plan/limit.rs index 546f36cb60e0..762c598d46c7 100644 --- a/datafusion/src/physical_plan/limit.rs +++ b/datafusion/src/physical_plan/limit.rs @@ -30,17 +30,18 @@ use crate::physical_plan::{ DisplayFormatType, Distribution, ExecutionPlan, Partitioning, }; +use crate::record_batch::RecordBatch; use arrow::array::ArrayRef; use arrow::compute::limit::limit; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; use super::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; /// Limit execution plan @@ -114,7 +115,11 @@ impl ExecutionPlan for GlobalLimitExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -131,7 +136,7 @@ impl ExecutionPlan for GlobalLimitExec { } let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let stream = self.input.execute(0).await?; + let stream = self.input.execute(0, runtime).await?; Ok(Box::pin(LimitStream::new( stream, self.limit, @@ -243,9 +248,13 @@ impl ExecutionPlan for LocalLimitExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + 
partition: usize, + runtime: Arc, + ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let stream = self.input.execute(partition).await?; + let stream = self.input.execute(partition, runtime).await?; Ok(Box::pin(LimitStream::new( stream, self.limit, @@ -388,11 +397,12 @@ mod tests { use crate::datasource::object_store::local::LocalFileSystem; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; - use crate::physical_plan::file_format::{CsvExec, PhysicalPlanConfig}; + use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::{test, test_util}; #[tokio::test] async fn limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let num_partitions = 4; @@ -400,13 +410,12 @@ mod tests { test::create_partitioned_csv("aggregate_test_100.csv", num_partitions)?; let csv = CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema: schema, file_groups: files, statistics: Statistics::default(), projection: None, - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -421,7 +430,7 @@ mod tests { GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), 7); // the result should contain 4 batches (one per input partition) - let iter = limit.execute(0).await?; + let iter = limit.execute(0, runtime).await?; let batches = common::collect(iter).await?; // there should be a total of 100 rows diff --git a/datafusion/src/physical_plan/memory.rs b/datafusion/src/physical_plan/memory.rs index ecd7f254ff6f..7ce0c762bb1d 100644 --- a/datafusion/src/physical_plan/memory.rs +++ b/datafusion/src/physical_plan/memory.rs @@ -23,14 +23,15 @@ use std::sync::Arc; use std::task::{Context, Poll}; use super::{ - common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, + common, project_schema, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; use crate::error::{DataFusionError, Result}; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use crate::record_batch::RecordBatch; +use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; use futures::Stream; @@ -47,7 +48,7 @@ pub struct MemoryExec { } impl fmt::Debug for MemoryExec { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "partitions: [...]")?; write!(f, "schema: {:?}", self.projected_schema)?; write!(f, "projection: {:?}", self.projection) @@ -86,7 +87,11 @@ impl ExecutionPlan for MemoryExec { ))) } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { Ok(Box::pin(MemoryStream::try_new( self.partitions[partition].clone(), self.projected_schema.clone(), @@ -131,24 +136,7 @@ impl MemoryExec { schema: SchemaRef, projection: Option>, ) -> Result { - let projected_schema = match &projection { - Some(columns) => { - let fields: Result> = columns - .iter() - .map(|i| { - if *i < schema.fields().len() { - Ok(schema.field(*i).clone()) - } else { - Err(DataFusionError::Internal( - "Projection index out of range".to_string(), - )) - } - }) - .collect(); - Arc::new(Schema::new(fields?)) - } - None => Arc::clone(&schema), - }; + 
let projected_schema = project_schema(&schema, projection.as_ref())?; Ok(Self { partitions: partitions.to_vec(), schema, @@ -159,7 +147,7 @@ impl MemoryExec { } /// Iterator over batches -pub(crate) struct MemoryStream { +pub struct MemoryStream { /// Vector of record batches data: Vec, /// Schema representing the data @@ -196,14 +184,14 @@ impl Stream for MemoryStream { Poll::Ready(if self.index < self.data.len() { self.index += 1; let batch = &self.data[self.index - 1]; - // apply projection - match &self.projection { - Some(columns) => Some(RecordBatch::try_new( - self.schema.clone(), - columns.iter().map(|i| batch.column(*i).clone()).collect(), - )), - None => Some(Ok(batch.clone())), - } + + // return just the columns requested + let batch = match self.projection.as_ref() { + Some(columns) => batch.project(columns)?, + None => batch.clone(), + }; + + Some(Ok(batch)) } else { None }) @@ -224,6 +212,7 @@ impl RecordBatchStream for MemoryStream { #[cfg(test)] mod tests { use super::*; + use crate::field_util::{FieldExt, SchemaExt}; use crate::physical_plan::ColumnStatistics; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; @@ -252,6 +241,7 @@ mod tests { #[tokio::test] async fn test_with_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let (schema, batch) = mock_data()?; let executor = MemoryExec::try_new(&[vec![batch]], schema, Some(vec![2, 1]))?; @@ -277,7 +267,7 @@ mod tests { ); // scan with projection - let mut it = executor.execute(0).await?; + let mut it = executor.execute(0, runtime).await?; let batch2 = it.next().await.unwrap()?; assert_eq!(2, batch2.schema().fields().len()); assert_eq!("c", batch2.schema().field(0).name()); @@ -289,6 +279,7 @@ mod tests { #[tokio::test] async fn test_without_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let (schema, batch) = mock_data()?; let executor = MemoryExec::try_new(&[vec![batch]], schema, None)?; @@ -325,7 +316,7 @@ mod tests { ]) ); - let mut it = executor.execute(0).await?; + let mut it = executor.execute(0, runtime).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(4, batch1.schema().fields().len()); assert_eq!(4, batch1.num_columns()); diff --git a/datafusion/src/physical_plan/metrics/aggregated.rs b/datafusion/src/physical_plan/metrics/aggregated.rs new file mode 100644 index 000000000000..c55cc1601768 --- /dev/null +++ b/datafusion/src/physical_plan/metrics/aggregated.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Metrics common for complex operators with multiple steps. 
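+//!
+//! A rough usage sketch (illustrative only; it exercises the APIs added in this file):
+//!
+//! ```
+//! use datafusion::physical_plan::metrics::AggregatedMetricsSet;
+//!
+//! let metrics = AggregatedMetricsSet::new();
+//! // one baseline per intermediate step, e.g. each in-memory sort run
+//! let run_metrics = metrics.new_intermediate_baseline(0);
+//! // one baseline for the final step, e.g. the merge of the sorted runs
+//! let final_metrics = metrics.new_final_baseline(0);
+//! // ... operators record elapsed_compute, spills and output rows as usual ...
+//! run_metrics.elapsed_compute().add_duration(std::time::Duration::from_nanos(100));
+//! final_metrics.output_rows().add(1);
+//! // collapse everything into a single `MetricsSet` for reporting
+//! let combined = metrics.aggregate_all();
+//! assert_eq!(combined.output_rows(), Some(1));
+//! ```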
+ +use crate::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricsSet, Time, +}; +use std::sync::Arc; +use std::time::Duration; + +#[derive(Debug, Clone)] +/// Aggregates all metrics during a complex operation that is composed of multiple steps, +/// where each step reports its statistics separately. +/// Take sort as an example: when the dataset is larger than the available memory, it reports +/// multiple in-memory sort metrics plus the final merge-sort metrics from `SortPreservingMergeStream`. +/// Therefore, we need to distinguish the final metrics (which accumulate output_rows) +/// from the intermediate metrics, which only contribute to the elapsed_compute time. +pub struct AggregatedMetricsSet { + intermediate: Arc<std::sync::Mutex<Vec<ExecutionPlanMetricsSet>>>, + final_: Arc<std::sync::Mutex<Vec<ExecutionPlanMetricsSet>>>, +} + +impl Default for AggregatedMetricsSet { + fn default() -> Self { + Self::new() + } +} + +impl AggregatedMetricsSet { + /// Create a new aggregated set + pub fn new() -> Self { + Self { + intermediate: Arc::new(std::sync::Mutex::new(vec![])), + final_: Arc::new(std::sync::Mutex::new(vec![])), + } + } + + /// create a new intermediate baseline + pub fn new_intermediate_baseline(&self, partition: usize) -> BaselineMetrics { + let ms = ExecutionPlanMetricsSet::new(); + let result = BaselineMetrics::new(&ms, partition); + self.intermediate.lock().unwrap().push(ms); + result + } + + /// create a new final baseline + pub fn new_final_baseline(&self, partition: usize) -> BaselineMetrics { + let ms = ExecutionPlanMetricsSet::new(); + let result = BaselineMetrics::new(&ms, partition); + self.final_.lock().unwrap().push(ms); + result + } + + fn merge_compute_time(&self, dest: &Time) { + let time1 = self + .intermediate + .lock() + .unwrap() + .iter() + .map(|es| { + es.clone_inner() + .elapsed_compute() + .map_or(0u64, |v| v as u64) + }) + .sum(); + let time2 = self + .final_ + .lock() + .unwrap() + .iter() + .map(|es| { + es.clone_inner() + .elapsed_compute() + .map_or(0u64, |v| v as u64) + }) + .sum(); + dest.add_duration(Duration::from_nanos(time1)); + dest.add_duration(Duration::from_nanos(time2)); + } + + fn merge_spill_count(&self, dest: &Count) { + let count1 = self + .intermediate + .lock() + .unwrap() + .iter() + .map(|es| es.clone_inner().spill_count().map_or(0, |v| v)) + .sum(); + let count2 = self + .final_ + .lock() + .unwrap() + .iter() + .map(|es| es.clone_inner().spill_count().map_or(0, |v| v)) + .sum(); + dest.add(count1); + dest.add(count2); + } + + fn merge_spilled_bytes(&self, dest: &Count) { + let count1 = self + .intermediate + .lock() + .unwrap() + .iter() + .map(|es| es.clone_inner().spilled_bytes().map_or(0, |v| v)) + .sum(); + let count2 = self + .final_ + .lock() + .unwrap() + .iter() + .map(|es| es.clone_inner().spilled_bytes().map_or(0, |v| v)) + .sum(); + dest.add(count1); + dest.add(count2); + } + + fn merge_output_count(&self, dest: &Count) { + let count = self + .final_ + .lock() + .unwrap() + .iter() + .map(|es| es.clone_inner().output_rows().map_or(0, |v| v)) + .sum(); + dest.add(count); + } + + /// Aggregate all metrics into one + pub fn aggregate_all(&self) -> MetricsSet { + let metrics = ExecutionPlanMetricsSet::new(); + let baseline = BaselineMetrics::new(&metrics, 0); + self.merge_compute_time(baseline.elapsed_compute()); + self.merge_spill_count(baseline.spill_count()); + self.merge_spilled_bytes(baseline.spilled_bytes()); + self.merge_output_count(baseline.output_rows()); + metrics.clone_inner() + } +} diff --git
a/datafusion/src/physical_plan/metrics/baseline.rs b/datafusion/src/physical_plan/metrics/baseline.rs index b007d074f624..a095360ef54c 100644 --- a/datafusion/src/physical_plan/metrics/baseline.rs +++ b/datafusion/src/physical_plan/metrics/baseline.rs @@ -19,9 +19,9 @@ use std::task::Poll; -use arrow::{error::ArrowError, record_batch::RecordBatch}; - -use super::{Count, ExecutionPlanMetricsSet, MetricBuilder, Time, Timestamp}; +use super::{Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, Time, Timestamp}; +use crate::record_batch::RecordBatch; +use arrow::error::ArrowError; /// Helper for creating and tracking common "baseline" metrics for /// each operator @@ -50,6 +50,15 @@ pub struct BaselineMetrics { /// amount of time the operator was actively trying to use the CPU elapsed_compute: Time, + /// count of spills during the execution of the operator + spill_count: Count, + + /// total spilled bytes during the execution of the operator + spilled_bytes: Count, + + /// current memory usage for the operator + mem_used: Gauge, + /// output rows: the total output rows output_rows: Count, } @@ -63,6 +72,9 @@ impl BaselineMetrics { Self { end_time: MetricBuilder::new(metrics).end_timestamp(partition), elapsed_compute: MetricBuilder::new(metrics).elapsed_compute(partition), + spill_count: MetricBuilder::new(metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition), + mem_used: MetricBuilder::new(metrics).mem_used(partition), output_rows: MetricBuilder::new(metrics).output_rows(partition), } } @@ -72,6 +84,27 @@ impl BaselineMetrics { &self.elapsed_compute } + /// return the metric for the total number of spills triggered during execution + pub fn spill_count(&self) -> &Count { + &self.spill_count + } + + /// return the metric for the total spilled bytes during execution + pub fn spilled_bytes(&self) -> &Count { + &self.spilled_bytes + } + + /// return the metric for current memory usage + pub fn mem_used(&self) -> &Gauge { + &self.mem_used + } + + /// Record a spill of `spilled_bytes` size. + pub fn record_spill(&self, spilled_bytes: usize) { + self.spill_count.add(1); + self.spilled_bytes.add(spilled_bytes); + } + /// return the metric for the total number of output rows produced pub fn output_rows(&self) -> &Count { &self.output_rows diff --git a/datafusion/src/physical_plan/metrics/builder.rs b/datafusion/src/physical_plan/metrics/builder.rs index 510366bb3565..30e9764c6446 100644 --- a/datafusion/src/physical_plan/metrics/builder.rs +++ b/datafusion/src/physical_plan/metrics/builder.rs @@ -20,7 +20,7 @@ use std::{borrow::Cow, sync::Arc}; use super::{ - Count, ExecutionPlanMetricsSet, Label, Metric, MetricValue, Time, Timestamp, + Count, ExecutionPlanMetricsSet, Gauge, Label, Metric, MetricValue, Time, Timestamp, }; /// Structure for constructing metrics, counters, timers, etc. 
@@ -105,6 +105,32 @@ impl<'a> MetricBuilder<'a> { count } + /// Consume self and create a new counter for recording the number of spills + /// triggered by an operator + pub fn spill_count(self, partition: usize) -> Count { + let count = Count::new(); + self.with_partition(partition) + .build(MetricValue::SpillCount(count.clone())); + count + } + + /// Consume self and create a new counter for recording the total spilled bytes + /// triggered by an operator + pub fn spilled_bytes(self, partition: usize) -> Count { + let count = Count::new(); + self.with_partition(partition) + .build(MetricValue::SpilledBytes(count.clone())); + count + } + + /// Consume self and create a new gauge for reporting current memory usage + pub fn mem_used(self, partition: usize) -> Gauge { + let gauge = Gauge::new(); + self.with_partition(partition) + .build(MetricValue::CurrentMemoryUsage(gauge.clone())); + gauge + } + /// Consumes self and creates a new [`Count`] for recording some /// arbitrary metric of an operator. pub fn counter( @@ -115,6 +141,16 @@ impl<'a> MetricBuilder<'a> { self.with_partition(partition).global_counter(counter_name) } + /// Consumes self and creates a new [`Gauge`] for reporting some + /// arbitrary metric of an operator. + pub fn gauge( + self, + gauge_name: impl Into>, + partition: usize, + ) -> Gauge { + self.with_partition(partition).global_gauge(gauge_name) + } + /// Consumes self and creates a new [`Count`] for recording a /// metric of an overall operator (not per partition) pub fn global_counter(self, counter_name: impl Into>) -> Count { @@ -126,6 +162,17 @@ impl<'a> MetricBuilder<'a> { count } + /// Consumes self and creates a new [`Gauge`] for reporting a + /// metric of an overall operator (not per partition) + pub fn global_gauge(self, gauge_name: impl Into>) -> Gauge { + let gauge = Gauge::new(); + self.build(MetricValue::Gauge { + name: gauge_name.into(), + gauge: gauge.clone(), + }); + gauge + } + /// Consume self and create a new Timer for recording the elapsed /// CPU time spent by an operator pub fn elapsed_compute(self, partition: usize) -> Time { diff --git a/datafusion/src/physical_plan/metrics/mod.rs b/datafusion/src/physical_plan/metrics/mod.rs index 7c59c8dddd76..d48959974e8d 100644 --- a/datafusion/src/physical_plan/metrics/mod.rs +++ b/datafusion/src/physical_plan/metrics/mod.rs @@ -17,6 +17,7 @@ //! Metrics for recording information about execution +mod aggregated; mod baseline; mod builder; mod value; @@ -30,9 +31,10 @@ use std::{ use hashbrown::HashMap; // public exports +pub use aggregated::AggregatedMetricsSet; pub use baseline::{BaselineMetrics, RecordOutput}; pub use builder::MetricBuilder; -pub use value::{Count, MetricValue, ScopedTimerGuard, Time, Timestamp}; +pub use value::{Count, Gauge, MetricValue, ScopedTimerGuard, Time, Timestamp}; /// Something that tracks a value of interest (metric) of a DataFusion /// [`ExecutionPlan`] execution. 
@@ -76,7 +78,7 @@ pub struct Metric { } impl Display for Metric { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}", self.value.name())?; let mut iter = self @@ -191,6 +193,20 @@ impl MetricsSet { .map(|v| v.as_usize()) } + /// convenience: return the count of spills, aggregated + /// across partitions or None if no metric is present + pub fn spill_count(&self) -> Option { + self.sum(|metric| matches!(metric.value(), MetricValue::SpillCount(_))) + .map(|v| v.as_usize()) + } + + /// convenience: return the total byte size of spills, aggregated + /// across partitions or None if no metric is present + pub fn spilled_bytes(&self) -> Option { + self.sum(|metric| matches!(metric.value(), MetricValue::SpilledBytes(_))) + .map(|v| v.as_usize()) + } + /// convenience: return the amount of elapsed CPU time spent, /// aggregated across partitions or None if no metric is present pub fn elapsed_compute(&self) -> Option { @@ -282,7 +298,7 @@ impl MetricsSet { impl Display for MetricsSet { /// format the MetricsSet as a single string - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let mut is_first = true; for i in self.metrics.iter() { if !is_first { @@ -363,7 +379,7 @@ impl Label { } impl Display for Label { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}={}", self.name, self.value) } } diff --git a/datafusion/src/physical_plan/metrics/value.rs b/datafusion/src/physical_plan/metrics/value.rs index 6f6358339bdb..6ac282a496ee 100644 --- a/datafusion/src/physical_plan/metrics/value.rs +++ b/datafusion/src/physical_plan/metrics/value.rs @@ -45,11 +45,17 @@ impl PartialEq for Count { } impl Display for Count { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}", self.value()) } } +impl Default for Count { + fn default() -> Self { + Self::new() + } +} + impl Count { /// create a new counter pub fn new() -> Self { @@ -71,6 +77,62 @@ impl Count { } } +/// A gauge is the simplest metrics type. It just returns a value. +/// For example, you can easily expose current memory consumption with a gauge. 
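+///
+/// A minimal, illustrative sketch (using only the methods defined below and the
+/// `Gauge` re-export added in `metrics/mod.rs`):
+///
+/// ```
+/// use datafusion::physical_plan::metrics::Gauge;
+///
+/// let mem_used = Gauge::new();
+/// mem_used.add(1024);              // memory grew by 1 KiB
+/// let previous = mem_used.set(0);  // reset and read back the old value
+/// assert_eq!(previous, 1024);
+/// assert_eq!(mem_used.value(), 0);
+/// ```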
+/// +/// Note `clone`ing gauge update the same underlying metrics +#[derive(Debug, Clone)] +pub struct Gauge { + /// value of the metric gauge + value: std::sync::Arc, +} + +impl PartialEq for Gauge { + fn eq(&self, other: &Self) -> bool { + self.value().eq(&other.value()) + } +} + +impl Display for Gauge { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.value()) + } +} + +impl Default for Gauge { + fn default() -> Self { + Self::new() + } +} + +impl Gauge { + /// create a new gauge + pub fn new() -> Self { + Self { + value: Arc::new(AtomicUsize::new(0)), + } + } + + /// Add `n` to the metric's value + pub fn add(&self, n: usize) { + // relaxed ordering for operations on `value` poses no issues + // we're purely using atomic ops with no associated memory ops + self.value.fetch_add(n, Ordering::Relaxed); + } + + /// Set the metric's value to `n` and return the previous value + pub fn set(&self, n: usize) -> usize { + // relaxed ordering for operations on `value` poses no issues + // we're purely using atomic ops with no associated memory ops + self.value.swap(n, Ordering::Relaxed) + } + + /// Get the current value + pub fn value(&self) -> usize { + self.value.load(Ordering::Relaxed) + } +} + /// Measure a potentially non contiguous duration of time #[derive(Debug, Clone)] pub struct Time { @@ -78,6 +140,12 @@ pub struct Time { nanos: Arc, } +impl Default for Time { + fn default() -> Self { + Self::new() + } +} + impl PartialEq for Time { fn eq(&self, other: &Self) -> bool { self.value().eq(&other.value()) @@ -85,7 +153,7 @@ impl PartialEq for Time { } impl Display for Time { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let duration = std::time::Duration::from_nanos(self.value() as u64); write!(f, "{:?}", duration) } @@ -140,6 +208,12 @@ pub struct Timestamp { timestamp: Arc>>>, } +impl Default for Timestamp { + fn default() -> Self { + Self::new() + } +} + impl Timestamp { /// Create a new timestamp and sets its value to 0 pub fn new() -> Self { @@ -198,7 +272,7 @@ impl PartialEq for Timestamp { } impl Display for Timestamp { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self.value() { None => write!(f, "NONE"), Some(v) => { @@ -265,6 +339,12 @@ pub enum MetricValue { /// classical defintion of "cpu_time", which is the time reported /// from `clock_gettime(CLOCK_THREAD_CPUTIME_ID, ..)`. ElapsedCompute(Time), + /// Number of spills produced: "spill_count" metric + SpillCount(Count), + /// Total size of spilled bytes produced: "spilled_bytes" metric + SpilledBytes(Count), + /// Current memory used + CurrentMemoryUsage(Gauge), /// Operator defined count. Count { /// The provided name of this metric @@ -272,6 +352,13 @@ pub enum MetricValue { /// The value of the metric count: Count, }, + /// Operator defined gauge. + Gauge { + /// The provided name of this metric + name: Cow<'static, str>, + /// The value of the metric + gauge: Gauge, + }, /// Operator defined time Time { /// The provided name of this metric @@ -290,8 +377,12 @@ impl MetricValue { pub fn name(&self) -> &str { match self { Self::OutputRows(_) => "output_rows", + Self::SpillCount(_) => "spill_count", + Self::SpilledBytes(_) => "spilled_bytes", + Self::CurrentMemoryUsage(_) => "mem_used", Self::ElapsedCompute(_) => "elapsed_compute", Self::Count { name, .. } => name.borrow(), + Self::Gauge { name, .. 
} => name.borrow(), Self::Time { name, .. } => name.borrow(), Self::StartTimestamp(_) => "start_timestamp", Self::EndTimestamp(_) => "end_timestamp", @@ -302,8 +393,12 @@ impl MetricValue { pub fn as_usize(&self) -> usize { match self { Self::OutputRows(count) => count.value(), + Self::SpillCount(count) => count.value(), + Self::SpilledBytes(bytes) => bytes.value(), + Self::CurrentMemoryUsage(used) => used.value(), Self::ElapsedCompute(time) => time.value(), Self::Count { count, .. } => count.value(), + Self::Gauge { gauge, .. } => gauge.value(), Self::Time { time, .. } => time.value(), Self::StartTimestamp(timestamp) => timestamp .value() @@ -321,11 +416,18 @@ impl MetricValue { pub fn new_empty(&self) -> Self { match self { Self::OutputRows(_) => Self::OutputRows(Count::new()), + Self::SpillCount(_) => Self::SpillCount(Count::new()), + Self::SpilledBytes(_) => Self::SpilledBytes(Count::new()), + Self::CurrentMemoryUsage(_) => Self::CurrentMemoryUsage(Gauge::new()), Self::ElapsedCompute(_) => Self::ElapsedCompute(Time::new()), Self::Count { name, .. } => Self::Count { name: name.clone(), count: Count::new(), }, + Self::Gauge { name, .. } => Self::Gauge { + name: name.clone(), + gauge: Gauge::new(), + }, Self::Time { name, .. } => Self::Time { name: name.clone(), time: Time::new(), @@ -347,12 +449,21 @@ impl MetricValue { pub fn aggregate(&mut self, other: &Self) { match (self, other) { (Self::OutputRows(count), Self::OutputRows(other_count)) + | (Self::SpillCount(count), Self::SpillCount(other_count)) + | (Self::SpilledBytes(count), Self::SpilledBytes(other_count)) | ( Self::Count { count, .. }, Self::Count { count: other_count, .. }, ) => count.add(other_count.value()), + (Self::CurrentMemoryUsage(gauge), Self::CurrentMemoryUsage(other_gauge)) + | ( + Self::Gauge { gauge, .. }, + Self::Gauge { + gauge: other_gauge, .. + }, + ) => gauge.add(other_gauge.value()), (Self::ElapsedCompute(time), Self::ElapsedCompute(other_time)) | ( Self::Time { time, .. }, @@ -383,10 +494,14 @@ impl MetricValue { match self { Self::OutputRows(_) => 0, // show first Self::ElapsedCompute(_) => 1, // show second - Self::Count { .. } => 2, - Self::Time { .. } => 3, - Self::StartTimestamp(_) => 4, // show timestamps last - Self::EndTimestamp(_) => 5, + Self::SpillCount(_) => 2, + Self::SpilledBytes(_) => 3, + Self::CurrentMemoryUsage(_) => 4, + Self::Count { .. } => 5, + Self::Gauge { .. } => 6, + Self::Time { .. } => 7, + Self::StartTimestamp(_) => 8, // show timestamps last + Self::EndTimestamp(_) => 9, } } @@ -398,11 +513,17 @@ impl MetricValue { impl std::fmt::Display for MetricValue { /// Prints the value of this metric - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - Self::OutputRows(count) | Self::Count { count, .. } => { + Self::OutputRows(count) + | Self::SpillCount(count) + | Self::SpilledBytes(count) + | Self::Count { count, .. } => { write!(f, "{}", count) } + Self::CurrentMemoryUsage(gauge) | Self::Gauge { gauge, .. } => { + write!(f, "{}", gauge) + } Self::ElapsedCompute(time) | Self::Time { time, .. 
} => { // distinguish between no time recorded and very small // amount of time recorded diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 769e88bad5a9..79f0aa499c33 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -22,20 +22,24 @@ use self::metrics::MetricsSet; use self::{ coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan, }; -use crate::physical_plan::expressions::{PhysicalSortExpr, SortColumn}; +use crate::field_util::SchemaExt; +use crate::physical_plan::expressions::PhysicalSortExpr; +use crate::record_batch::RecordBatch; use crate::{ error::{DataFusionError, Result}, + execution::runtime_env::RuntimeEnv, scalar::ScalarValue, }; use arrow::array::ArrayRef; use arrow::compute::merge_sort::SortOptions; use arrow::compute::partition::lexicographical_partition_ranges; +use arrow::compute::sort::SortColumn as ArrowSortColumn; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; use async_trait::async_trait; pub use display::DisplayFormatType; use futures::stream::Stream; +use sorts::SortColumn; use std::fmt; use std::fmt::{Debug, Display}; use std::ops::Range; @@ -154,7 +158,11 @@ pub trait ExecutionPlan: Debug + Send + Sync { ) -> Result>; /// creates an iterator - async fn execute(&self, partition: usize) -> Result; + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result; /// Return a snapshot of the set of [`Metric`]s for this /// [`ExecutionPlan`]. @@ -218,7 +226,7 @@ pub trait ExecutionPlan: Debug + Send + Sync { /// \n CoalesceBatchesExec: target_batch_size=4096\ /// \n FilterExec: a@0 < 5\ /// \n RepartitionExec: partitioning=RoundRobinBatch(3)\ -/// \n CsvExec: files=[tests/example.csv], has_header=true, batch_size=8192, limit=None", +/// \n CsvExec: files=[tests/example.csv], has_header=true, limit=None", /// plan_string.trim()); /// } /// ``` @@ -310,24 +318,28 @@ pub fn visit_execution_plan( } /// Execute the [ExecutionPlan] and collect the results in memory -pub async fn collect(plan: Arc) -> Result> { - let stream = execute_stream(plan).await?; +pub async fn collect( + plan: Arc, + runtime: Arc, +) -> Result> { + let stream = execute_stream(plan, runtime).await?; common::collect(stream).await } /// Execute the [ExecutionPlan] and return a single stream of results pub async fn execute_stream( plan: Arc, + runtime: Arc, ) -> Result { match plan.output_partitioning().partition_count() { 0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))), - 1 => plan.execute(0).await, + 1 => plan.execute(0, runtime).await, _ => { // merge into a single partition let plan = CoalescePartitionsExec::new(plan.clone()); // CoalescePartitionsExec must produce a single partition assert_eq!(1, plan.output_partitioning().partition_count()); - plan.execute(0).await + plan.execute(0, runtime).await } } } @@ -335,8 +347,9 @@ pub async fn execute_stream( /// Execute the [ExecutionPlan] and collect the results in memory pub async fn collect_partitioned( plan: Arc, + runtime: Arc, ) -> Result>> { - let streams = execute_stream_partitioned(plan).await?; + let streams = execute_stream_partitioned(plan, runtime).await?; let mut batches = Vec::with_capacity(streams.len()); for stream in streams { batches.push(common::collect(stream).await?); @@ -347,11 +360,12 @@ pub async fn collect_partitioned( /// Execute the [ExecutionPlan] and return a vec with one stream per output partition 
pub async fn execute_stream_partitioned( plan: Arc, + runtime: Arc, ) -> Result> { let num_partitions = plan.output_partitioning().partition_count(); let mut streams = Vec::with_capacity(num_partitions); for i in 0..num_partitions { - streams.push(plan.execute(i).await?); + streams.push(plan.execute(i, runtime.clone()).await?); } Ok(streams) } @@ -512,14 +526,13 @@ pub trait WindowExpr: Send + Sync + Debug { end: num_rows, }]) } else { - Ok(lexicographical_partition_ranges( - &partition_columns - .iter() - .map(|x| x.into()) - .collect::>(), - ) - .map_err(DataFusionError::ArrowError)? - .collect()) + let v = partition_columns + .iter() + .map(|sc| sc.into()) + .collect::>(); + Ok(lexicographical_partition_ranges(v.as_slice()) + .map_err(DataFusionError::ArrowError)? + .collect()) } } @@ -560,9 +573,9 @@ pub trait WindowExpr: Send + Sync + Debug { /// generically accumulates values. /// /// An accumulator knows how to: -/// * update its state from inputs via `update` +/// * update its state from inputs via `update_batch` /// * convert its internal state to a vector of scalar values -/// * update its state from multiple accumulators' states via `merge` +/// * update its state from multiple accumulators' states via `merge_batch` /// * compute the final value from its internal state via `evaluate` pub trait Accumulator: Send + Sync + Debug { /// Returns the state of the accumulator at the end of the accumulation. @@ -570,44 +583,57 @@ pub trait Accumulator: Send + Sync + Debug { // of two values, sum and n. fn state(&self) -> Result>; - /// updates the accumulator's state from a vector of scalars. - fn update(&mut self, values: &[ScalarValue]) -> Result<()>; - /// updates the accumulator's state from a vector of arrays. - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - }; - (0..values[0].len()).try_for_each(|index| { - let v = values - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::>>()?; - self.update(&v) - }) - } - - /// updates the accumulator's state from a vector of scalars. - fn merge(&mut self, states: &[ScalarValue]) -> Result<()>; + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; /// updates the accumulator's state from a vector of states. - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - }; - (0..states[0].len()).try_for_each(|index| { - let v = states - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::>>()?; - self.merge(&v) - }) - } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>; /// returns its value based on its current state. 
fn evaluate(&self) -> Result; } +/// Applies an optional projection to a [`SchemaRef`], returning the +/// projected schema +/// +/// Example: +/// ``` +/// use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; +/// use datafusion::field_util::SchemaExt; +/// use datafusion::physical_plan::project_schema; +/// +/// // Schema with columns 'a', 'b', and 'c' +/// let schema = SchemaRef::new(Schema::new(vec![ +/// Field::new("a", DataType::Int32, true), +/// Field::new("b", DataType::Int64, true), +/// Field::new("c", DataType::Utf8, true), +/// ])); +/// +/// // Pick columns 'c' and 'b' +/// let projection = Some(vec![2,1]); +/// let projected_schema = project_schema( +/// &schema, +/// projection.as_ref() +/// ).unwrap(); +/// +/// let expected_schema = SchemaRef::new(Schema::new(vec![ +/// Field::new("c", DataType::Utf8, true), +/// Field::new("b", DataType::Int64, true), +/// ])); +/// +/// assert_eq!(projected_schema, expected_schema); +/// ``` +pub fn project_schema( + schema: &SchemaRef, + projection: Option<&Vec>, +) -> Result { + let schema = match projection { + Some(columns) => Arc::new(schema.project(columns)?), + None => Arc::clone(schema), + }; + Ok(schema) +} + pub mod aggregates; pub mod analyze; pub mod array_expressions; @@ -620,7 +646,6 @@ pub mod cross_join; pub mod crypto_expressions; pub mod datetime_expressions; pub mod display; -pub mod distinct_expressions; pub mod empty; pub mod explain; pub mod expressions; @@ -641,8 +666,7 @@ pub mod projection; #[cfg(feature = "regex_expressions")] pub mod regex_expressions; pub mod repartition; -pub mod sort; -pub mod sort_preserving_merge; +pub mod sorts; pub mod stream; pub mod string_expressions; pub mod type_coercion; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index c25bdac868db..84821a067179 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -23,6 +23,7 @@ use super::{ hash_join::PartitionMode, udaf, union::UnionExec, values::ValuesExec, windows, }; use crate::execution::context::ExecutionContextState; +use crate::field_util::{FieldExt, SchemaExt}; use crate::logical_plan::plan::{ Aggregate, EmptyRelation, Filter, Join, Projection, Sort, TableScan, Window, }; @@ -45,7 +46,7 @@ use crate::physical_plan::hash_join::HashJoinExec; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sort::SortExec; +use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::udf; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{join_utils, Partitioning}; @@ -327,8 +328,6 @@ impl DefaultPhysicalPlanner { ctx_state: &'a ExecutionContextState, ) -> BoxFuture<'a, Result>> { async move { - let batch_size = ctx_state.config.batch_size; - let exec_plan: Result> = match logical_plan { LogicalPlan::TableScan (TableScan { source, @@ -342,7 +341,7 @@ impl DefaultPhysicalPlanner { // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); let unaliased: Vec = filters.into_iter().map(unalias).collect(); - source.scan(projection, batch_size, &unaliased, *limit).await + source.scan(projection, &unaliased, *limit).await } LogicalPlan::Values(Values { values, @@ -1461,6 +1460,7 @@ mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; use crate::execution::options::CsvReadOptions; + use 
crate::execution::runtime_env::RuntimeEnv; use crate::logical_plan::plan::Extension; use crate::logical_plan::{DFField, DFSchema, DFSchemaRef}; use crate::physical_plan::{ @@ -1915,7 +1915,11 @@ mod tests { unimplemented!("NoOpExecutionPlan::with_new_children"); } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { unimplemented!("NoOpExecutionPlan::execute"); } diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs index 824b44cea8bd..a150341a6184 100644 --- a/datafusion/src/physical_plan/projection.rs +++ b/datafusion/src/physical_plan/projection.rs @@ -30,13 +30,15 @@ use crate::physical_plan::{ ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, }; +use crate::record_batch::RecordBatch; use arrow::datatypes::{Field, Metadata, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; use super::expressions::Column; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::{FieldExt, SchemaExt}; use async_trait::async_trait; use futures::stream::Stream; use futures::stream::StreamExt; @@ -78,7 +80,8 @@ impl ProjectionExec { }) .collect(); - let schema = Arc::new(Schema::new_from(fields?, input_schema.metadata().clone())); + let schema = + Arc::new(Schema::new(fields?).with_metadata(input_schema.metadata().clone())); Ok(Self { expr, @@ -135,11 +138,15 @@ impl ExecutionPlan for ProjectionExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { Ok(Box::pin(ProjectionStream { schema: self.schema.clone(), expr: self.expr.iter().map(|x| x.0.clone()).collect(), - input: self.input.execute(partition).await?, + input: self.input.execute(partition, runtime).await?, baseline_metrics: BaselineMetrics::new(&self.metrics, partition), })) } @@ -230,15 +237,14 @@ impl ProjectionStream { fn batch_project(&self, batch: &RecordBatch) -> ArrowResult { // records time on drop let _timer = self.baseline_metrics.elapsed_compute().timer(); - self.expr + let arrays = self + .expr .iter() .map(|expr| expr.evaluate(batch)) .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>() - .map_or_else( - |e| Err(DataFusionError::into_arrow_external_error(e)), - |arrays| RecordBatch::try_new(self.schema.clone(), arrays), - ) + .collect::>>()?; + + RecordBatch::try_new(self.schema.clone(), arrays) } } @@ -284,7 +290,7 @@ mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; use crate::physical_plan::expressions::{self, col}; - use crate::physical_plan::file_format::{CsvExec, PhysicalPlanConfig}; + use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::scalar::ScalarValue; use crate::test::{self}; use crate::test_util; @@ -292,6 +298,7 @@ mod tests { #[tokio::test] async fn project_first_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -299,13 +306,12 @@ mod tests { test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; let csv = CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema: Arc::clone(&schema), file_groups: files, statistics: 
Statistics::default(), projection: None, - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -328,7 +334,7 @@ mod tests { let mut row_count = 0; for partition in 0..projection.output_partitioning().partition_count() { partition_count += 1; - let stream = projection.execute(partition).await?; + let stream = projection.execute(partition, runtime.clone()).await?; row_count += stream .map(|batch| { diff --git a/datafusion/src/physical_plan/regex_expressions.rs b/datafusion/src/physical_plan/regex_expressions.rs index f06a62c62db0..71c0901a677e 100644 --- a/datafusion/src/physical_plan/regex_expressions.rs +++ b/datafusion/src/physical_plan/regex_expressions.rs @@ -139,7 +139,7 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result { let (pattern, replace_all) = if flags == "g" { (pattern.to_string(), true) } else if flags.contains('g') { - (format!("(?{}){}", flags.to_string().replace("g", ""), pattern), true) + (format!("(?{}){}", flags.to_string().replace('g', ""), pattern), true) } else { (format!("(?{}){}", flags, pattern), false) }; diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index 5bd2f82f07ce..75b857a78361 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -26,7 +26,7 @@ use std::{any::Any, vec}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::hash_utils::create_hashes; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics}; -use arrow::record_batch::RecordBatch; +use crate::record_batch::RecordBatch; use arrow::{ array::{Array, UInt64Array}, error::Result as ArrowResult, @@ -39,6 +39,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{RecordBatchStream, SendableRecordBatchStream}; use async_trait::async_trait; +use crate::execution::runtime_env::RuntimeEnv; use futures::stream::Stream; use futures::StreamExt; use hashbrown::HashMap; @@ -167,7 +168,11 @@ impl ExecutionPlan for RepartitionExec { self.partitioning.clone() } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { // lock mutexes let mut state = self.state.lock().await; @@ -210,6 +215,7 @@ impl ExecutionPlan for RepartitionExec { txs.clone(), self.partitioning.clone(), r_metrics, + runtime.clone(), )); // In a separate task, wait for each input to be done @@ -288,12 +294,13 @@ impl RepartitionExec { mut txs: HashMap>>>, partitioning: Partitioning, r_metrics: RepartitionMetrics, + runtime: Arc, ) -> Result<()> { let num_output_partitions = txs.len(); // execute the child operator let timer = r_metrics.fetch_time.timer(); - let mut stream = input.execute(i).await?; + let mut stream = input.execute(i, runtime).await?; timer.done(); let mut counter = 0; @@ -414,7 +421,7 @@ impl RepartitionExec { Err(e) => { for (_, tx) in txs { let err = DataFusionError::Execution(format!("Join Error: {}", e)); - let err = Err(err.into_arrow_external_error()); + let err = Err(err.into()); tx.send(Some(err)).ok(); } } @@ -423,7 +430,7 @@ impl RepartitionExec { for (_, tx) in txs { // wrap it because need to send error to all output partitions let err = DataFusionError::Execution(e.to_string()); - let err = Err(err.into_arrow_external_error()); + let err = Err(err.into()); tx.send(Some(err)).ok(); } } @@ -494,6 +501,8 @@ mod tests { type StringArray = Utf8Array; use super::*; + use crate::field_util::SchemaExt; + use 
crate::record_batch::RecordBatch; use crate::{ assert_batches_sorted_eq, physical_plan::{collect, expressions::col, memory::MemoryExec}, @@ -508,7 +517,6 @@ mod tests { use arrow::array::{ArrayRef, UInt32Array, Utf8Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; - use arrow::record_batch::RecordBatch; use futures::FutureExt; #[tokio::test] @@ -620,6 +628,7 @@ mod tests { input_partitions: Vec>, partitioning: Partitioning, ) -> Result>> { + let runtime = Arc::new(RuntimeEnv::default()); // create physical plan let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?; @@ -628,7 +637,7 @@ mod tests { let mut output_partitions = vec![]; for i in 0..exec.partitioning.partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i).await?; + let mut stream = exec.execute(i, runtime.clone()).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -668,6 +677,7 @@ mod tests { #[tokio::test] async fn unsupported_partitioning() { + let runtime = Arc::new(RuntimeEnv::default()); // have to send at least one batch through to provoke error let batch = RecordBatch::try_from_iter(vec![( "my_awesome_field", @@ -682,7 +692,7 @@ mod tests { // returned and no results produced let partitioning = Partitioning::UnknownPartitioning(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - let output_stream = exec.execute(0).await.unwrap(); + let output_stream = exec.execute(0, runtime).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -702,13 +712,14 @@ mod tests { // This generates an error on a call to execute. The error // should be returned and no results produced. 
+ let runtime = Arc::new(RuntimeEnv::default()); let input = ErrorExec::new(); let partitioning = Partitioning::RoundRobinBatch(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); // Note: this should pass (the stream can be created) but the // error when the input is executed should get passed back - let output_stream = exec.execute(0).await.unwrap(); + let output_stream = exec.execute(0, runtime).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -724,6 +735,7 @@ mod tests { #[tokio::test] async fn repartition_with_error_in_stream() { + let runtime = Arc::new(RuntimeEnv::default()); let batch = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(Utf8Array::::from_slice(&["foo", "bar"])) as ArrayRef, @@ -743,7 +755,7 @@ mod tests { // Note: this should pass (the stream can be created) but the // error when the input is executed should get passed back - let output_stream = exec.execute(0).await.unwrap(); + let output_stream = exec.execute(0, runtime).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -759,6 +771,7 @@ mod tests { #[tokio::test] async fn repartition_with_delayed_stream() { + let runtime = Arc::new(RuntimeEnv::default()); let batch1 = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(Utf8Array::::from_slice(&["foo", "bar"])) as ArrayRef, @@ -793,7 +806,7 @@ mod tests { assert_batches_sorted_eq!(&expected, &expected_batches); - let output_stream = exec.execute(0).await.unwrap(); + let output_stream = exec.execute(0, runtime).await.unwrap(); let batches = crate::physical_plan::common::collect(output_stream) .await .unwrap(); @@ -803,6 +816,7 @@ mod tests { #[tokio::test] async fn robin_repartition_with_dropping_output_stream() { + let runtime = Arc::new(RuntimeEnv::default()); let partitioning = Partitioning::RoundRobinBatch(2); // The barrier exec waits to be pinged // requires the input to wait at least once) @@ -811,8 +825,8 @@ mod tests { // partition into two output streams let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); - let output_stream0 = exec.execute(0).await.unwrap(); - let output_stream1 = exec.execute(1).await.unwrap(); + let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); + let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced @@ -845,6 +859,7 @@ mod tests { // wiht different compilers, we will compare the same execution with // and without droping the output stream. async fn hash_repartition_with_dropping_output_stream() { + let runtime = Arc::new(RuntimeEnv::default()); let partitioning = Partitioning::Hash( vec![Arc::new(crate::physical_plan::expressions::Column::new( "my_awesome_field", @@ -856,7 +871,7 @@ mod tests { // We first collect the results without droping the output stream. 
let input = Arc::new(make_barrier_exec()); let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap(); - let output_stream1 = exec.execute(1).await.unwrap(); + let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); input.wait().await; let batches_without_drop = crate::physical_plan::common::collect(output_stream1) .await @@ -876,8 +891,8 @@ mod tests { // Now do the same but dropping the stream before waiting for the barrier let input = Arc::new(make_barrier_exec()); let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); - let output_stream0 = exec.execute(0).await.unwrap(); - let output_stream1 = exec.execute(1).await.unwrap(); + let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); + let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced std::mem::drop(output_stream0); @@ -942,6 +957,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -952,7 +968,7 @@ mod tests { Partitioning::UnknownPartitioning(1), )?); - let fut = collect(repartition_exec); + let fut = collect(repartition_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -964,6 +980,7 @@ mod tests { #[tokio::test] async fn hash_repartition_avoid_empty_batch() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let batch = RecordBatch::try_from_iter(vec![( "a", Arc::new(StringArray::from_slice(vec!["foo"])) as ArrayRef, @@ -978,11 +995,11 @@ mod tests { let schema = batch.schema().clone(); let input = MockExec::new(vec![Ok(batch)], schema.clone()); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - let output_stream0 = exec.execute(0).await.unwrap(); + let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); let batch0 = crate::physical_plan::common::collect(output_stream0) .await .unwrap(); - let output_stream1 = exec.execute(1).await.unwrap(); + let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); let batch1 = crate::physical_plan::common::collect(output_stream1) .await .unwrap(); diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs index 3700380fdb72..8b137891791f 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sort.rs @@ -1,566 +1 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -//! 
Defines the SORT plan - -use super::common::AbortOnDropSingle; -use super::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, -}; -use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; -use crate::error::{DataFusionError, Result}; -use crate::physical_plan::expressions::PhysicalSortExpr; -use crate::physical_plan::{ - common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, -}; -pub use arrow::compute::sort::SortOptions; -use arrow::compute::{sort::lexsort_to_indices, take}; -use arrow::datatypes::SchemaRef; -use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; -use arrow::{array::ArrayRef, error::ArrowError}; -use async_trait::async_trait; -use futures::stream::Stream; -use futures::Future; -use pin_project_lite::pin_project; -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -/// Sort execution plan -#[derive(Debug)] -pub struct SortExec { - /// Input schema - input: Arc, - /// Sort expressions - expr: Vec, - /// Execution metrics - metrics: ExecutionPlanMetricsSet, - /// Preserve partitions of input plan - preserve_partitioning: bool, -} - -impl SortExec { - /// Create a new sort execution plan - pub fn try_new( - expr: Vec, - input: Arc, - ) -> Result { - Ok(Self::new_with_partitioning(expr, input, false)) - } - - /// Create a new sort execution plan with the option to preserve - /// the partitioning of the input plan - pub fn new_with_partitioning( - expr: Vec, - input: Arc, - preserve_partitioning: bool, - ) -> Self { - Self { - expr, - input, - metrics: ExecutionPlanMetricsSet::new(), - preserve_partitioning, - } - } - - /// Input schema - pub fn input(&self) -> &Arc { - &self.input - } - - /// Sort expressions - pub fn expr(&self) -> &[PhysicalSortExpr] { - &self.expr - } -} - -#[async_trait] -impl ExecutionPlan for SortExec { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.input.schema() - } - - fn children(&self) -> Vec> { - vec![self.input.clone()] - } - - /// Get the output partitioning of this plan - fn output_partitioning(&self) -> Partitioning { - if self.preserve_partitioning { - self.input.output_partitioning() - } else { - Partitioning::UnknownPartitioning(1) - } - } - - fn required_child_distribution(&self) -> Distribution { - if self.preserve_partitioning { - Distribution::UnspecifiedDistribution - } else { - Distribution::SinglePartition - } - } - - fn with_new_children( - &self, - children: Vec>, - ) -> Result> { - match children.len() { - 1 => Ok(Arc::new(SortExec::try_new( - self.expr.clone(), - children[0].clone(), - )?)), - _ => Err(DataFusionError::Internal( - "SortExec wrong number of children".to_string(), - )), - } - } - - async fn execute(&self, partition: usize) -> Result { - if !self.preserve_partitioning { - if 0 != partition { - return Err(DataFusionError::Internal(format!( - "SortExec invalid partition {}", - partition - ))); - } - - // sort needs to operate on a single partition currently - if 1 != self.input.output_partitioning().partition_count() { - return Err(DataFusionError::Internal( - "SortExec requires a single input partition".to_owned(), - )); - } - } - - let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let input = self.input.execute(partition).await?; - - Ok(Box::pin(SortStream::new( - input, - self.expr.clone(), - baseline_metrics, - ))) - } - - fn fmt_as( - &self, - t: 
DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - let expr: Vec = self.expr.iter().map(|e| e.to_string()).collect(); - write!(f, "SortExec: [{}]", expr.join(",")) - } - } - } - - fn metrics(&self) -> Option { - Some(self.metrics.clone_inner()) - } - - fn statistics(&self) -> Statistics { - self.input.statistics() - } -} - -fn sort_batch( - batch: RecordBatch, - schema: SchemaRef, - expr: &[PhysicalSortExpr], -) -> ArrowResult { - let columns = expr - .iter() - .map(|e| e.evaluate_to_sort_column(&batch)) - .collect::>>() - .map_err(DataFusionError::into_arrow_external_error)?; - let columns = columns.iter().map(|x| x.into()).collect::>(); - - // sort combined record batch - // TODO: pushup the limit expression to sort - let indices = lexsort_to_indices::(&columns, None)?; - - // reorder all rows based on sorted indices - RecordBatch::try_new( - schema, - batch - .columns() - .iter() - .map(|column| take::take(column.as_ref(), &indices).map(|x| x.into())) - .collect::>>()?, - ) -} - -pin_project! { - /// stream for sort plan - struct SortStream { - #[pin] - output: futures::channel::oneshot::Receiver>>, - finished: bool, - schema: SchemaRef, - drop_helper: AbortOnDropSingle<()>, - } -} - -impl SortStream { - fn new( - input: SendableRecordBatchStream, - expr: Vec, - baseline_metrics: BaselineMetrics, - ) -> Self { - let (tx, rx) = futures::channel::oneshot::channel(); - let schema = input.schema(); - let join_handle = tokio::spawn(async move { - let schema = input.schema(); - let sorted_batch = common::collect(input) - .await - .map_err(DataFusionError::into_arrow_external_error) - .and_then(move |batches| { - let timer = baseline_metrics.elapsed_compute().timer(); - // combine all record batches into one for each column - let combined = common::combine_batches(&batches, schema.clone())?; - // sort combined record batch - let result = combined - .map(|batch| sort_batch(batch, schema, &expr)) - .transpose()? - .record_output(&baseline_metrics); - timer.done(); - Ok(result) - }); - - // failing here is OK, the receiver is gone and does not care about the result - tx.send(sorted_batch).ok(); - }); - - Self { - output: rx, - finished: false, - schema, - drop_helper: AbortOnDropSingle::new(join_handle), - } - } -} - -impl Stream for SortStream { - type Item = ArrowResult; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.finished { - return Poll::Ready(None); - } - - // is the output ready? 
- let this = self.project(); - let output_poll = this.output.poll(cx); - - match output_poll { - Poll::Ready(result) => { - *this.finished = true; - - // check for error in receiving channel and unwrap actual result - let result = match result { - Err(e) => { - Some(Err(ArrowError::External("".to_string(), Box::new(e)))) - } // error receiving - Ok(result) => result.transpose(), - }; - - Poll::Ready(result) - } - Poll::Pending => Poll::Pending, - } - } -} - -impl RecordBatchStream for SortStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -#[cfg(test)] -mod tests { - use std::collections::{BTreeMap, HashMap}; - - use super::*; - use crate::datasource::object_store::local::LocalFileSystem; - use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; - use crate::physical_plan::expressions::col; - use crate::physical_plan::memory::MemoryExec; - use crate::physical_plan::{ - collect, - file_format::{CsvExec, PhysicalPlanConfig}, - }; - use crate::test::assert_is_pending; - use crate::test::exec::assert_strong_count_converges_to_zero; - use crate::test::{self, exec::BlockingExec}; - use crate::test_util; - use arrow::array::*; - use arrow::datatypes::*; - use futures::FutureExt; - - #[tokio::test] - async fn test_sort() -> Result<()> { - let schema = test_util::aggr_test_schema(); - let partitions = 4; - let (_, files) = - test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; - - let csv = CsvExec::new( - PhysicalPlanConfig { - object_store: Arc::new(LocalFileSystem {}), - file_schema: Arc::clone(&schema), - file_groups: files, - statistics: Statistics::default(), - projection: None, - batch_size: 1024, - limit: None, - table_partition_cols: vec![], - }, - true, - b',', - ); - - let sort_exec = Arc::new(SortExec::try_new( - vec![ - // c1 string column - PhysicalSortExpr { - expr: col("c1", &schema)?, - options: SortOptions::default(), - }, - // c2 uin32 column - PhysicalSortExpr { - expr: col("c2", &schema)?, - options: SortOptions::default(), - }, - // c7 uin8 column - PhysicalSortExpr { - expr: col("c7", &schema)?, - options: SortOptions::default(), - }, - ], - Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), - )?); - - let result: Vec = collect(sort_exec).await?; - assert_eq!(result.len(), 1); - - let columns = result[0].columns(); - - let c1 = columns[0] - .as_any() - .downcast_ref::>() - .unwrap(); - assert_eq!(c1.value(0), "a"); - assert_eq!(c1.value(c1.len() - 1), "e"); - - let c2 = columns[1].as_any().downcast_ref::().unwrap(); - assert_eq!(c2.value(0), 1); - assert_eq!(c2.value(c2.len() - 1), 5,); - - let c7 = columns[6].as_any().downcast_ref::().unwrap(); - assert_eq!(c7.value(0), 15); - assert_eq!(c7.value(c7.len() - 1), 254,); - - Ok(()) - } - - #[tokio::test] - async fn test_sort_metadata() -> Result<()> { - let field_metadata: BTreeMap = - vec![("foo".to_string(), "bar".to_string())] - .into_iter() - .collect(); - let schema_metadata: HashMap = - vec![("baz".to_string(), "barf".to_string())] - .into_iter() - .collect(); - - let mut field = Field::new("field_name", DataType::UInt64, true); - field = field.with_metadata(field_metadata.clone()); - let schema = Schema::new_from(vec![field], schema_metadata.clone()); - let schema = Arc::new(schema); - - let data: ArrayRef = - Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::()); - - let batch = RecordBatch::try_new(schema.clone(), vec![data]).unwrap(); - let input = - Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None).unwrap()); - - let sort_exec = 
Arc::new(SortExec::try_new( - vec![PhysicalSortExpr { - expr: col("field_name", &schema)?, - options: SortOptions::default(), - }], - input, - )?); - - let result: Vec = collect(sort_exec).await?; - - let expected_data: ArrayRef = - Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); - let expected_batch = - RecordBatch::try_new(schema.clone(), vec![expected_data]).unwrap(); - - // Data is correct - assert_eq!(&vec![expected_batch], &result); - - // explicitlty ensure the metadata is present - assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata); - assert_eq!(result[0].schema().metadata(), &schema_metadata); - - Ok(()) - } - - #[tokio::test] - async fn test_lex_sort_by_float() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, true), - Field::new("b", DataType::Float64, true), - ])); - - // define data. - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Float32Array::from(vec![ - Some(f32::NAN), - None, - None, - Some(f32::NAN), - Some(1.0_f32), - Some(1.0_f32), - Some(2.0_f32), - Some(3.0_f32), - ])), - Arc::new(Float64Array::from(vec![ - Some(200.0_f64), - Some(20.0_f64), - Some(10.0_f64), - Some(100.0_f64), - Some(f64::NAN), - None, - None, - Some(f64::NAN), - ])), - ], - )?; - - let sort_exec = Arc::new(SortExec::try_new( - vec![ - PhysicalSortExpr { - expr: col("a", &schema)?, - options: SortOptions { - descending: true, - nulls_first: true, - }, - }, - PhysicalSortExpr { - expr: col("b", &schema)?, - options: SortOptions { - descending: false, - nulls_first: false, - }, - }, - ], - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None)?), - )?); - - assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type()); - assert_eq!(DataType::Float64, *sort_exec.schema().field(1).data_type()); - - let result: Vec = collect(sort_exec.clone()).await?; - let metrics = sort_exec.metrics().unwrap(); - assert!(metrics.elapsed_compute().unwrap() > 0); - assert_eq!(metrics.output_rows().unwrap(), 8); - assert_eq!(result.len(), 1); - - let columns = result[0].columns(); - - assert_eq!(DataType::Float32, *columns[0].data_type()); - assert_eq!(DataType::Float64, *columns[1].data_type()); - - let a = columns[0].as_any().downcast_ref::().unwrap(); - let b = columns[1].as_any().downcast_ref::().unwrap(); - - // convert result to strings to allow comparing to expected result containing NaN - let result: Vec<(Option, Option)> = (0..result[0].num_rows()) - .map(|i| { - let aval = if a.is_valid(i) { - Some(a.value(i).to_string()) - } else { - None - }; - let bval = if b.is_valid(i) { - Some(b.value(i).to_string()) - } else { - None - }; - (aval, bval) - }) - .collect(); - - let expected: Vec<(Option, Option)> = vec![ - (None, Some("10".to_owned())), - (None, Some("20".to_owned())), - (Some("NaN".to_owned()), Some("100".to_owned())), - (Some("NaN".to_owned()), Some("200".to_owned())), - (Some("3".to_owned()), Some("NaN".to_owned())), - (Some("2".to_owned()), None), - (Some("1".to_owned()), Some("NaN".to_owned())), - (Some("1".to_owned()), None), - ]; - - assert_eq!(expected, result); - - Ok(()) - } - - #[tokio::test] - async fn test_drop_cancel() -> Result<()> { - let schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); - - let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); - let refs = blocking_exec.refs(); - let sort_exec = Arc::new(SortExec::try_new( - vec![PhysicalSortExpr { - expr: col("a", &schema)?, - options: SortOptions::default(), - }], - 
blocking_exec, - )?); - - let fut = collect(sort_exec); - let mut fut = fut.boxed(); - - assert_is_pending(&mut fut); - drop(fut); - assert_strong_count_converges_to_zero(refs).await; - - Ok(()) - } -} diff --git a/datafusion/src/physical_plan/sorts/mod.rs b/datafusion/src/physical_plan/sorts/mod.rs new file mode 100644 index 000000000000..fdde229f9ca9 --- /dev/null +++ b/datafusion/src/physical_plan/sorts/mod.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort functionalities + +use crate::error; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{PhysicalExpr, SendableRecordBatchStream}; +use crate::record_batch::RecordBatch; +use arrow::array::{ord::DynComparator, ArrayRef}; +use arrow::compute::sort::SortColumn as ArrowSortColumn; +use arrow::compute::sort::SortOptions; +use arrow::error::Result as ArrowResult; +use futures::channel::mpsc; +use futures::stream::FusedStream; +use futures::Stream; +use hashbrown::HashMap; +use std::borrow::BorrowMut; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::pin::Pin; +use std::sync::{Arc, RwLock}; +use std::task::{Context, Poll}; + +pub mod sort; +pub mod sort_preserving_merge; + +/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of +/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys. +/// +/// Additionally it maintains a row cursor that can be advanced through the rows +/// of the provided `RecordBatch` +/// +/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to +/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores +/// a row comparator for each other cursor that it is compared to. +struct SortKeyCursor { + stream_idx: usize, + sort_columns: Vec, + cur_row: usize, + num_rows: usize, + + // An id uniquely identifying the record batch scanned by this cursor. + batch_id: usize, + + // A collection of comparators that compare rows in this cursor's batch to + // the cursors in other batches. Other batches are uniquely identified by + // their batch_idx. 
+ batch_comparators: RwLock>>, + sort_options: Arc>, +} + +impl<'a> std::fmt::Debug for SortKeyCursor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("SortKeyCursor") + .field("sort_columns", &self.sort_columns) + .field("cur_row", &self.cur_row) + .field("num_rows", &self.num_rows) + .field("batch_id", &self.batch_id) + .field("batch_comparators", &"") + .finish() + } +} + +impl SortKeyCursor { + fn new( + stream_idx: usize, + batch_id: usize, + batch: &RecordBatch, + sort_key: &[Arc], + sort_options: Arc>, + ) -> error::Result { + let sort_columns = sort_key + .iter() + .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows()))) + .collect::>()?; + Ok(Self { + stream_idx, + cur_row: 0, + num_rows: batch.num_rows(), + sort_columns, + batch_id, + batch_comparators: RwLock::new(HashMap::new()), + sort_options, + }) + } + + fn is_finished(&self) -> bool { + self.num_rows == self.cur_row + } + + fn advance(&mut self) -> usize { + assert!(!self.is_finished()); + let t = self.cur_row; + self.cur_row += 1; + t + } + + /// Compares the sort key pointed to by this instance's row cursor with that of another + fn compare(&self, other: &SortKeyCursor) -> error::Result { + if self.sort_columns.len() != other.sort_columns.len() { + return Err(DataFusionError::Internal(format!( + "SortKeyCursors had inconsistent column counts: {} vs {}", + self.sort_columns.len(), + other.sort_columns.len() + ))); + } + + if self.sort_columns.len() != self.sort_options.len() { + return Err(DataFusionError::Internal(format!( + "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}", + self.sort_columns.len(), + self.sort_options.len() + ))); + } + + let zipped: Vec<((&ArrayRef, &ArrayRef), &SortOptions)> = self + .sort_columns + .iter() + .zip(other.sort_columns.iter()) + .zip(self.sort_options.iter()) + .collect::>(); + + self.init_cmp_if_needed(other, &zipped)?; + let map = self.batch_comparators.read().unwrap(); + let cmp = map.get(&other.batch_id).ok_or_else(|| { + DataFusionError::Execution(format!( + "Failed to find comparator for {} cmp {}", + self.batch_id, other.batch_id + )) + })?; + + for (i, ((l, r), sort_options)) in zipped.iter().enumerate() { + match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) { + (false, true) if sort_options.nulls_first => return Ok(Ordering::Less), + (false, true) => return Ok(Ordering::Greater), + (true, false) if sort_options.nulls_first => { + return Ok(Ordering::Greater) + } + (true, false) => return Ok(Ordering::Less), + (false, false) => {} + (true, true) => match cmp[i](self.cur_row, other.cur_row) { + Ordering::Equal => {} + o if sort_options.descending => return Ok(o.reverse()), + o => return Ok(o), + }, + } + } + + // Break ties using stream_idx to ensure a predictable + // ordering of rows when comparing equal streams. + Ok(self.stream_idx.cmp(&other.stream_idx)) + } + + /// Initialize a collection of comparators for comparing + /// columnar arrays of this cursor and "other" if needed. 
+ fn init_cmp_if_needed( + &self, + other: &SortKeyCursor, + zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)], + ) -> Result<()> { + let hm = self.batch_comparators.read().unwrap(); + if !hm.contains_key(&other.batch_id) { + drop(hm); + let mut map = self.batch_comparators.write().unwrap(); + let cmp = map + .borrow_mut() + .entry(other.batch_id) + .or_insert_with(|| Vec::with_capacity(other.sort_columns.len())); + + for (i, ((l, r), _)) in zipped.iter().enumerate() { + if i >= cmp.len() { + // initialise comparators + cmp.push(arrow::array::ord::build_compare(l.as_ref(), r.as_ref())?); + } + } + } + Ok(()) + } +} + +impl Ord for SortKeyCursor { + /// Needed by min-heap comparison and reverse the order at the same time. + fn cmp(&self, other: &Self) -> Ordering { + other.compare(self).unwrap() + } +} + +impl PartialEq for SortKeyCursor { + fn eq(&self, other: &Self) -> bool { + other.compare(self).unwrap() == Ordering::Equal + } +} + +impl Eq for SortKeyCursor {} + +impl PartialOrd for SortKeyCursor { + fn partial_cmp(&self, other: &Self) -> Option { + other.compare(self).ok() + } +} + +/// A `RowIndex` identifies a specific row from those buffered +/// by a `SortPreservingMergeStream` +#[derive(Debug, Clone)] +struct RowIndex { + /// The index of the stream + stream_idx: usize, + /// The index of the batch within the stream's VecDequeue. + batch_idx: usize, + /// The row index + row_idx: usize, +} + +pub(crate) struct SortedStream { + stream: SendableRecordBatchStream, + mem_used: usize, +} + +impl Debug for SortedStream { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "InMemSorterStream") + } +} + +impl SortedStream { + pub(crate) fn new(stream: SendableRecordBatchStream, mem_used: usize) -> Self { + Self { stream, mem_used } + } +} + +#[derive(Debug)] +enum StreamWrapper { + Receiver(mpsc::Receiver>), + Stream(Option), +} + +impl StreamWrapper { + fn mem_used(&self) -> usize { + match &self { + StreamWrapper::Stream(Some(s)) => s.mem_used, + _ => 0, + } + } +} + +impl Stream for StreamWrapper { + type Item = ArrowResult; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + StreamWrapper::Receiver(ref mut receiver) => Pin::new(receiver).poll_next(cx), + StreamWrapper::Stream(ref mut stream) => { + let inner = match stream { + None => return Poll::Ready(None), + Some(inner) => inner, + }; + + match Pin::new(&mut inner.stream).poll_next(cx) { + Poll::Ready(msg) => { + if msg.is_none() { + *stream = None + } + Poll::Ready(msg) + } + Poll::Pending => Poll::Pending, + } + } + } + } +} + +impl FusedStream for StreamWrapper { + fn is_terminated(&self) -> bool { + match self { + StreamWrapper::Receiver(receiver) => receiver.is_terminated(), + StreamWrapper::Stream(stream) => stream.is_none(), + } + } +} + +/// One column to be used in lexicographical sort +#[derive(Clone, Debug)] +pub struct SortColumn { + /// The array to be sorted + pub values: ArrayRef, + /// The options to sort the array + pub options: Option, +} + +impl<'a> From<&'a SortColumn> for ArrowSortColumn<'a> { + fn from(c: &'a SortColumn) -> Self { + Self { + values: c.values.as_ref(), + options: c.options, + } + } +} diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs new file mode 100644 index 000000000000..a39ddd3950ae --- /dev/null +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -0,0 +1,907 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort that deals with an arbitrary size of the input. +//! It will do in-memory sorting if it has enough memory budget +//! but spills to disk if needed. + +use crate::error::{DataFusionError, Result}; +use crate::execution::memory_manager::{ + ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, +}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::physical_plan::common::{batch_byte_size, IPCWriter, SizedRecordBatchStream}; +use crate::physical_plan::expressions::PhysicalSortExpr; +use crate::physical_plan::metrics::{AggregatedMetricsSet, BaselineMetrics, MetricsSet}; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream; +use crate::physical_plan::sorts::{SortColumn, SortedStream}; +use crate::physical_plan::stream::RecordBatchReceiverStream; +use crate::physical_plan::{ + common, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, + Partitioning, SendableRecordBatchStream, Statistics, +}; +use crate::record_batch::RecordBatch; +use arrow::array::ArrayRef; +pub use arrow::compute::sort::SortOptions; +use arrow::compute::sort::{lexsort_to_indices, SortColumn as ArrowSortColumn}; +use arrow::compute::take::take; +use arrow::datatypes::SchemaRef; +use arrow::error::Result as ArrowResult; +use arrow::io::ipc::read::{read_file_metadata, FileReader}; +use async_trait::async_trait; +use futures::lock::Mutex; +use futures::StreamExt; +use log::{error, info}; +use std::any::Any; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::fs::File; +use std::io::BufReader; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tempfile::NamedTempFile; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::task; + +/// Sort arbitrary size of data to get a total order (may spill several times during sorting based on free memory available). +/// +/// The basic architecture of the algorithm: +/// 1. get a non-empty new batch from input +/// 2. check with the memory manager if we could buffer the batch in memory +/// 2.1 if memory sufficient, then buffer batch in memory, go to 1. +/// 2.2 if the memory threshold is reached, sort all buffered batches and spill to file. +/// buffer the batch in memory, go to 1. +/// 3. when input is exhausted, merge all in memory batches and spills to get a total order. 
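The three-step strategy above (buffer batches while the memory budget allows, spill a sorted run when it is exceeded, merge everything at the end) is the core of the new external sort. The following is only a rough, self-contained sketch of that loop, not the `ExternalSorter` added in this patch: batches are plain `Vec<i32>`, the memory budget is a simple byte counter standing in for the `MemoryManager`, and "spills" are kept as in-memory sorted runs rather than IPC temp files.

```rust
// Toy sketch of the buffer-or-spill loop described in steps 1-3 above.
// All names here are illustrative and do not exist in DataFusion.
struct ToySorter {
    budget_bytes: usize,   // pretend memory budget granted to this sorter
    used_bytes: usize,     // bytes currently buffered in memory
    in_mem: Vec<Vec<i32>>, // buffered (unsorted) batches
    spills: Vec<Vec<i32>>, // each entry stands in for one sorted spill file
}

impl ToySorter {
    fn new(budget_bytes: usize) -> Self {
        Self { budget_bytes, used_bytes: 0, in_mem: Vec::new(), spills: Vec::new() }
    }

    /// Step 2: buffer the batch if it fits; otherwise sort everything buffered
    /// so far, "spill" it as one run, then buffer the new batch.
    fn insert_batch(&mut self, batch: Vec<i32>) {
        let size = batch.len() * std::mem::size_of::<i32>();
        if self.used_bytes + size > self.budget_bytes && !self.in_mem.is_empty() {
            self.spill();
        }
        self.used_bytes += size;
        self.in_mem.push(batch);
    }

    fn spill(&mut self) {
        // Combine the buffered batches into one sorted run and release memory.
        let mut run: Vec<i32> = self.in_mem.drain(..).flatten().collect();
        run.sort_unstable();
        self.spills.push(run);
        self.used_bytes = 0;
    }

    /// Step 3: once input is exhausted, merge the remaining in-memory data
    /// with all spilled runs to produce a total order.
    fn finish(mut self) -> Vec<i32> {
        if !self.in_mem.is_empty() {
            self.spill();
        }
        // A real implementation streams a k-way merge here; re-sorting the
        // concatenation keeps the sketch short while giving the same order.
        let mut all: Vec<i32> = self.spills.into_iter().flatten().collect();
        all.sort_unstable();
        all
    }
}

fn main() {
    let mut sorter = ToySorter::new(32); // tiny budget to force a spill
    for batch in [vec![5, 3, 9], vec![1, 7], vec![8, 2, 6, 4]] {
        sorter.insert_batch(batch);
    }
    assert_eq!(sorter.finish(), vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
```

In the actual operator below, the final step is performed by `SortPreservingMergeStream`, so spilled runs are streamed and merged rather than re-sorted.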
+struct ExternalSorter { + id: MemoryConsumerId, + schema: SchemaRef, + in_mem_batches: Mutex>, + spills: Mutex>, + /// Sort expressions + expr: Vec, + runtime: Arc, + metrics: AggregatedMetricsSet, + inner_metrics: BaselineMetrics, +} + +impl ExternalSorter { + pub fn new( + partition_id: usize, + schema: SchemaRef, + expr: Vec, + metrics: AggregatedMetricsSet, + runtime: Arc, + ) -> Self { + let inner_metrics = metrics.new_intermediate_baseline(partition_id); + Self { + id: MemoryConsumerId::new(partition_id), + schema, + in_mem_batches: Mutex::new(vec![]), + spills: Mutex::new(vec![]), + expr, + runtime, + metrics, + inner_metrics, + } + } + + async fn insert_batch(&self, input: RecordBatch) -> Result<()> { + if input.num_rows() > 0 { + let size = batch_byte_size(&input); + self.try_grow(size).await?; + self.inner_metrics.mem_used().add(size); + let mut in_mem_batches = self.in_mem_batches.lock().await; + in_mem_batches.push(input); + } + Ok(()) + } + + async fn spilled_before(&self) -> bool { + let spills = self.spills.lock().await; + !spills.is_empty() + } + + /// MergeSort in mem batches as well as spills into total order with `SortPreservingMergeStream`. + async fn sort(&self) -> Result { + let partition = self.partition_id(); + let mut in_mem_batches = self.in_mem_batches.lock().await; + + if self.spilled_before().await { + let baseline_metrics = self.metrics.new_intermediate_baseline(partition); + let mut streams: Vec = vec![]; + if in_mem_batches.len() > 0 { + let in_mem_stream = in_mem_partial_sort( + &mut *in_mem_batches, + self.schema.clone(), + &self.expr, + baseline_metrics, + )?; + let prev_used = self.inner_metrics.mem_used().set(0); + streams.push(SortedStream::new(in_mem_stream, prev_used)); + } + + let mut spills = self.spills.lock().await; + + for spill in spills.drain(..) 
{ + let stream = read_spill_as_stream(spill, self.schema.clone())?; + streams.push(SortedStream::new(stream, 0)); + } + let baseline_metrics = self.metrics.new_final_baseline(partition); + Ok(Box::pin(SortPreservingMergeStream::new_from_streams( + streams, + self.schema.clone(), + &self.expr, + baseline_metrics, + partition, + self.runtime.clone(), + ))) + } else if in_mem_batches.len() > 0 { + let baseline_metrics = self.metrics.new_final_baseline(partition); + let result = in_mem_partial_sort( + &mut *in_mem_batches, + self.schema.clone(), + &self.expr, + baseline_metrics, + ); + self.inner_metrics.mem_used().set(0); + // TODO: the result size is not tracked + result + } else { + Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone()))) + } + } + + fn used(&self) -> usize { + self.inner_metrics.mem_used().value() + } + + fn spilled_bytes(&self) -> usize { + self.inner_metrics.spilled_bytes().value() + } + + fn spill_count(&self) -> usize { + self.inner_metrics.spill_count().value() + } +} + +impl Debug for ExternalSorter { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ExternalSorter") + .field("id", &self.id()) + .field("memory_used", &self.used()) + .field("spilled_bytes", &self.spilled_bytes()) + .field("spill_count", &self.spill_count()) + .finish() + } +} + +#[async_trait] +impl MemoryConsumer for ExternalSorter { + fn name(&self) -> String { + "ExternalSorter".to_owned() + } + + fn id(&self) -> &MemoryConsumerId { + &self.id + } + + fn memory_manager(&self) -> Arc { + self.runtime.memory_manager.clone() + } + + fn type_(&self) -> &ConsumerType { + &ConsumerType::Requesting + } + + async fn spill(&self) -> Result { + info!( + "{}[{}] spilling sort data of {} to disk while inserting ({} time(s) so far)", + self.name(), + self.id(), + self.used(), + self.spill_count() + ); + + let partition = self.partition_id(); + let mut in_mem_batches = self.in_mem_batches.lock().await; + // we could always get a chance to free some memory as long as we are holding some + if in_mem_batches.len() == 0 { + return Ok(0); + } + + let baseline_metrics = self.metrics.new_intermediate_baseline(partition); + + let spillfile = self.runtime.disk_manager.create_tmp_file()?; + let stream = in_mem_partial_sort( + &mut *in_mem_batches, + self.schema.clone(), + &*self.expr, + baseline_metrics, + ); + + spill_partial_sorted_stream(&mut stream?, spillfile.path(), self.schema.clone()) + .await?; + let mut spills = self.spills.lock().await; + let used = self.inner_metrics.mem_used().set(0); + self.inner_metrics.record_spill(used); + spills.push(spillfile); + Ok(used) + } + + fn mem_used(&self) -> usize { + self.inner_metrics.mem_used().value() + } +} + +/// consume the non-empty `sorted_bathes` and do in_mem_sort +fn in_mem_partial_sort( + buffered_batches: &mut Vec, + schema: SchemaRef, + expressions: &[PhysicalSortExpr], + baseline_metrics: BaselineMetrics, +) -> Result { + assert_ne!(buffered_batches.len(), 0); + + let result = { + // NB timer records time taken on drop, so there are no + // calls to `timer.done()` below. + let _timer = baseline_metrics.elapsed_compute().timer(); + + let pre_sort = if buffered_batches.len() == 1 { + buffered_batches.pop() + } else { + let batches = buffered_batches.drain(..).collect::>(); + // combine all record batches into one for each column + common::combine_batches(&batches, schema.clone())? + }; + + pre_sort + .map(|batch| sort_batch(batch, schema.clone(), expressions)) + .transpose()? 
+ }; + + Ok(Box::pin(SizedRecordBatchStream::new( + schema, + vec![Arc::new(result.unwrap())], + baseline_metrics, + ))) +} + +async fn spill_partial_sorted_stream( + in_mem_stream: &mut SendableRecordBatchStream, + path: &Path, + schema: SchemaRef, +) -> Result<()> { + let (sender, receiver) = tokio::sync::mpsc::channel(2); + let path: PathBuf = path.into(); + let handle = task::spawn_blocking(move || write_sorted(receiver, path, schema)); + while let Some(item) = in_mem_stream.next().await { + sender.send(item).await.ok(); + } + drop(sender); + match handle.await { + Ok(r) => r, + Err(e) => Err(DataFusionError::Execution(format!( + "Error occurred while spilling {}", + e + ))), + } +} + +fn read_spill_as_stream( + path: NamedTempFile, + schema: SchemaRef, +) -> Result { + let (sender, receiver): ( + Sender>, + Receiver>, + ) = tokio::sync::mpsc::channel(2); + let schema_ref = schema.clone(); + let join_handle = task::spawn_blocking(move || { + if let Err(e) = read_spill(sender, path.path(), schema_ref) { + error!("Failure while reading spill file: {:?}. Error: {}", path, e); + } + }); + Ok(RecordBatchReceiverStream::create( + &schema, + receiver, + join_handle, + )) +} + +fn write_sorted( + mut receiver: Receiver>, + path: PathBuf, + schema: SchemaRef, +) -> Result<()> { + let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?; + while let Some(batch) = receiver.blocking_recv() { + writer.write(&batch?)?; + } + writer.finish()?; + info!( + "Spilled {} batches of total {} rows to disk, memory released {}", + writer.num_batches, writer.num_rows, writer.num_bytes + ); + Ok(()) +} + +fn read_spill( + sender: Sender>, + path: &Path, + schena: SchemaRef, +) -> Result<()> { + let mut file = BufReader::new(File::open(&path)?); + let metadata = read_file_metadata(&mut file)?; + let reader = FileReader::new(file, metadata, None); + for chunk in reader { + let rb = RecordBatch::try_new(schena.clone(), chunk?.into_arrays()); + sender + .blocking_send(rb) + .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; + } + Ok(()) +} + +/// External Sort execution plan +#[derive(Debug)] +pub struct SortExec { + /// Input schema + input: Arc, + /// Sort expressions + expr: Vec, + /// Containing all metrics set created during sort + all_metrics: AggregatedMetricsSet, + /// Preserve partitions of input plan + preserve_partitioning: bool, +} + +impl SortExec { + /// Create a new sort execution plan + pub fn try_new( + expr: Vec, + input: Arc, + ) -> Result { + Ok(Self::new_with_partitioning(expr, input, false)) + } + + /// Create a new sort execution plan with the option to preserve + /// the partitioning of the input plan + pub fn new_with_partitioning( + expr: Vec, + input: Arc, + preserve_partitioning: bool, + ) -> Self { + Self { + expr, + input, + all_metrics: AggregatedMetricsSet::new(), + preserve_partitioning, + } + } + + /// Input schema + pub fn input(&self) -> &Arc { + &self.input + } + + /// Sort expressions + pub fn expr(&self) -> &[PhysicalSortExpr] { + &self.expr + } +} + +#[async_trait] +impl ExecutionPlan for SortExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + if self.preserve_partitioning { + self.input.output_partitioning() + } else { + Partitioning::UnknownPartitioning(1) + } + } + + fn required_child_distribution(&self) -> Distribution { + if self.preserve_partitioning { + Distribution::UnspecifiedDistribution + 
} else { + Distribution::SinglePartition + } + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn with_new_children( + &self, + children: Vec>, + ) -> Result> { + match children.len() { + 1 => Ok(Arc::new(SortExec::try_new( + self.expr.clone(), + children[0].clone(), + )?)), + _ => Err(DataFusionError::Internal( + "SortExec wrong number of children".to_string(), + )), + } + } + + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { + if !self.preserve_partitioning { + if 0 != partition { + return Err(DataFusionError::Internal(format!( + "SortExec invalid partition {}", + partition + ))); + } + + // sort needs to operate on a single partition currently + if 1 != self.input.output_partitioning().partition_count() { + return Err(DataFusionError::Internal( + "SortExec requires a single input partition".to_owned(), + )); + } + } + + let input = self.input.execute(partition, runtime.clone()).await?; + + do_sort( + input, + partition, + self.expr.clone(), + self.all_metrics.clone(), + runtime, + ) + .await + } + + fn metrics(&self) -> Option { + Some(self.all_metrics.aggregate_all()) + } + + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + let expr: Vec = self.expr.iter().map(|e| e.to_string()).collect(); + write!(f, "SortExec: [{}]", expr.join(",")) + } + } + } + + fn statistics(&self) -> Statistics { + self.input.statistics() + } +} + +fn sort_batch( + batch: RecordBatch, + schema: SchemaRef, + expr: &[PhysicalSortExpr], +) -> ArrowResult { + // TODO: pushup the limit expression to sort + let vec = expr + .iter() + .map(|e| e.evaluate_to_sort_column(&batch)) + .collect::>>()?; + let indices = lexsort_to_indices::( + vec.iter() + .map(|sc| sc.into()) + .collect::>() + .as_slice(), + None, + )?; + + // reorder all rows based on sorted indices + RecordBatch::try_new( + schema, + batch + .columns() + .iter() + .map(|column| take(column.as_ref(), &indices).map(Arc::from)) + .collect::>>()?, + ) +} + +async fn do_sort( + mut input: SendableRecordBatchStream, + partition_id: usize, + expr: Vec, + metrics: AggregatedMetricsSet, + runtime: Arc, +) -> Result { + let schema = input.schema(); + let sorter = Arc::new(ExternalSorter::new( + partition_id, + schema.clone(), + expr, + metrics, + runtime.clone(), + )); + runtime.register_consumer(&(sorter.clone() as Arc)); + + while let Some(batch) = input.next().await { + let batch = batch?; + sorter.insert_batch(batch).await?; + } + + let result = sorter.sort().await; + runtime.drop_consumer(sorter.id()); + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cast::{as_primitive_array, as_string_array}; + use crate::datasource::object_store::local::LocalFileSystem; + use crate::execution::context::ExecutionConfig; + use crate::field_util::{FieldExt, SchemaExt}; + use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use crate::physical_plan::expressions::col; + use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::{ + collect, + file_format::{CsvExec, FileScanConfig}, + }; + use crate::test; + use crate::test::assert_is_pending; + use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use crate::test_util; + use arrow::array::*; + use arrow::compute::sort::SortOptions; + use arrow::datatypes::*; + use futures::FutureExt; + use std::collections::BTreeMap; + + #[tokio::test] + async fn test_in_mem_sort() -> Result<()> { + let runtime = 
Arc::new(RuntimeEnv::default()); + let schema = test_util::aggr_test_schema(); + let partitions = 4; + let (_, files) = + test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; + + let csv = CsvExec::new( + FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_schema: Arc::clone(&schema), + file_groups: files, + statistics: Statistics::default(), + projection: None, + limit: None, + table_partition_cols: vec![], + }, + true, + b',', + ); + + let sort_exec = Arc::new(SortExec::try_new( + vec![ + // c1 string column + PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }, + // c2 uin32 column + PhysicalSortExpr { + expr: col("c2", &schema)?, + options: SortOptions::default(), + }, + // c7 uin8 column + PhysicalSortExpr { + expr: col("c7", &schema)?, + options: SortOptions::default(), + }, + ], + Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), + )?); + + let result = collect(sort_exec, runtime).await?; + + assert_eq!(result.len(), 1); + + let columns = result[0].columns(); + + let c1 = as_string_array(columns[0].as_ref()); + assert_eq!(c1.value(0), "a"); + assert_eq!(c1.value(c1.len() - 1), "e"); + + let c2 = as_primitive_array::(columns[1].as_ref()); + assert_eq!(c2.value(0), 1); + assert_eq!(c2.value(c2.len() - 1), 5,); + + let c7 = as_primitive_array::(columns[6].as_ref()); + assert_eq!(c7.value(0), 15); + assert_eq!(c7.value(c7.len() - 1), 254,); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_spill() -> Result<()> { + // trigger spill there will be 4 batches with 5.5KB for each + let config = ExecutionConfig::new().with_memory_limit(12288, 1.0)?; + let runtime = Arc::new(RuntimeEnv::new(config.runtime)?); + + let schema = test_util::aggr_test_schema(); + let partitions = 4; + let (_, files) = + test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; + + let csv = CsvExec::new( + FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_schema: Arc::clone(&schema), + file_groups: files, + statistics: Statistics::default(), + projection: None, + limit: None, + table_partition_cols: vec![], + }, + true, + b',', + ); + + let sort_exec = Arc::new(SortExec::try_new( + vec![ + // c1 string column + PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }, + // c2 uin32 column + PhysicalSortExpr { + expr: col("c2", &schema)?, + options: SortOptions::default(), + }, + // c7 uin8 column + PhysicalSortExpr { + expr: col("c7", &schema)?, + options: SortOptions::default(), + }, + ], + Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), + )?); + + let result = collect(sort_exec.clone(), runtime).await?; + + assert_eq!(result.len(), 1); + + // Now, validate metrics + let metrics = sort_exec.metrics().unwrap(); + + assert_eq!(metrics.output_rows().unwrap(), 100); + assert!(metrics.elapsed_compute().unwrap() > 0); + assert!(metrics.spill_count().unwrap() > 0); + assert!(metrics.spilled_bytes().unwrap() > 0); + + let columns = result[0].columns(); + + let c1 = as_string_array(columns[0].as_ref()); + assert_eq!(c1.value(0), "a"); + assert_eq!(c1.value(c1.len() - 1), "e"); + + let c2 = as_primitive_array::(columns[1].as_ref()); + assert_eq!(c2.value(0), 1); + assert_eq!(c2.value(c2.len() - 1), 5,); + + let c7 = as_primitive_array::(columns[6].as_ref()); + assert_eq!(c7.value(0), 15); + assert_eq!(c7.value(c7.len() - 1), 254,); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_metadata() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let field_metadata: 
BTreeMap = + vec![("foo".to_string(), "bar".to_string())] + .into_iter() + .collect(); + let schema_metadata: BTreeMap = + vec![("baz".to_string(), "barf".to_string())] + .into_iter() + .collect(); + + let mut field = Field::new("field_name", DataType::UInt64, true); + field.set_metadata(Some(field_metadata.clone())); + let schema = Schema::new_with_metadata(vec![field], schema_metadata.clone()); + let schema = Arc::new(schema); + + let data: ArrayRef = + Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::()); + + let batch = RecordBatch::try_new(schema.clone(), vec![data]).unwrap(); + let input = + Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None).unwrap()); + + let sort_exec = Arc::new(SortExec::try_new( + vec![PhysicalSortExpr { + expr: col("field_name", &schema)?, + options: SortOptions::default(), + }], + input, + )?); + + let result: Vec = collect(sort_exec, runtime).await?; + + let expected_data: ArrayRef = + Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); + let expected_batch = + RecordBatch::try_new(schema.clone(), vec![expected_data]).unwrap(); + + // Data is correct + assert_eq!(&vec![expected_batch], &result); + + // explicitlty ensure the metadata is present + assert_eq!(result[0].schema().fields()[0].metadata(), &field_metadata); + assert_eq!(result[0].schema().metadata(), &schema_metadata); + + Ok(()) + } + + #[tokio::test] + async fn test_lex_sort_by_float() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float64, true), + ])); + + // define data. + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![ + Some(f32::NAN), + None, + None, + Some(f32::NAN), + Some(1.0_f32), + Some(1.0_f32), + Some(2.0_f32), + Some(3.0_f32), + ])), + Arc::new(Float64Array::from(vec![ + Some(200.0_f64), + Some(20.0_f64), + Some(10.0_f64), + Some(100.0_f64), + Some(f64::NAN), + None, + None, + Some(f64::NAN), + ])), + ], + )?; + + let sort_exec = Arc::new(SortExec::try_new( + vec![ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + ], + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None)?), + )?); + + assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type()); + assert_eq!(DataType::Float64, *sort_exec.schema().field(1).data_type()); + + let result: Vec = collect(sort_exec.clone(), runtime).await?; + let metrics = sort_exec.metrics().unwrap(); + assert!(metrics.elapsed_compute().unwrap() > 0); + assert_eq!(metrics.output_rows().unwrap(), 8); + assert_eq!(result.len(), 1); + + let columns = result[0].columns(); + + assert_eq!(DataType::Float32, *columns[0].data_type()); + assert_eq!(DataType::Float64, *columns[1].data_type()); + + let a = as_primitive_array::(columns[0].as_ref()); + let b = as_primitive_array::(columns[1].as_ref()); + + // convert result to strings to allow comparing to expected result containing NaN + let result: Vec<(Option, Option)> = (0..result[0].num_rows()) + .map(|i| { + let aval = if a.is_valid(i) { + Some(a.value(i).to_string()) + } else { + None + }; + let bval = if b.is_valid(i) { + Some(b.value(i).to_string()) + } else { + None + }; + (aval, bval) + }) + .collect(); + + let expected: Vec<(Option, Option)> = vec![ + (None, Some("10".to_owned())), + 
(None, Some("20".to_owned())), + (Some("NaN".to_owned()), Some("100".to_owned())), + (Some("NaN".to_owned()), Some("200".to_owned())), + (Some("3".to_owned()), Some("NaN".to_owned())), + (Some("2".to_owned()), None), + (Some("1".to_owned()), Some("NaN".to_owned())), + (Some("1".to_owned()), None), + ]; + + assert_eq!(expected, result); + + Ok(()) + } + + #[tokio::test] + async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let refs = blocking_exec.refs(); + let sort_exec = Arc::new(SortExec::try_new( + vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions::default(), + }], + blocking_exec, + )?); + + let fut = collect(sort_exec, runtime); + let mut fut = fut.boxed(); + + assert_is_pending(&mut fut); + drop(fut); + assert_strong_count_converges_to_zero(refs).await; + + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs similarity index 67% rename from datafusion/src/physical_plan/sort_preserving_merge.rs rename to datafusion/src/physical_plan/sorts/sort_preserving_merge.rs index bc9aada8cee9..f641701bad3a 100644 --- a/datafusion/src/physical_plan/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs @@ -17,29 +17,34 @@ //! Defines the sort preserving merge plan -use super::common::AbortOnDropMany; -use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use crate::physical_plan::common::AbortOnDropMany; +use crate::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, +}; use std::any::Any; -use std::cmp::Ordering; -use std::collections::VecDeque; +use std::collections::{BinaryHeap, VecDeque}; +use std::fmt::{Debug, Formatter}; use std::pin::Pin; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use arrow::array::ord::DynComparator; -use arrow::array::{growable::make_growable, ord::build_compare, ArrayRef}; +use crate::record_batch::RecordBatch; +use arrow::array::growable::make_growable; use arrow::compute::sort::SortOptions; use arrow::datatypes::SchemaRef; use arrow::error::ArrowError; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; use async_trait::async_trait; use futures::channel::mpsc; use futures::stream::FusedStream; use futures::{Stream, StreamExt}; -use hashbrown::HashMap; use crate::error::{DataFusionError, Result}; +use crate::execution::memory_manager::ConsumerType; +use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::{MemoryConsumer, MemoryConsumerId, MemoryManager}; +use crate::field_util::SchemaExt; +use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream, StreamWrapper}; use crate::physical_plan::{ common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, @@ -52,29 +57,43 @@ use crate::physical_plan::{ /// provided each partition of the input plan is sorted with respect to /// these sort expressions, this operator will yield a single partition /// that is also sorted with respect to them +/// +/// ```text +/// ┌─────────────────────────┐ +/// │ ┌───┬───┬───┬───┐ │ +/// │ │ A │ B │ C │ D │ ... 
│──┐ +/// │ └───┴───┴───┴───┘ │ │ +/// └─────────────────────────┘ │ ┌───────────────────┐ ┌───────────────────────────────┐ +/// Stream 1 │ │ │ │ ┌───┬───╦═══╦───┬───╦═══╗ │ +/// ├─▶│SortPreservingMerge│───▶│ │ A │ B ║ B ║ C │ D ║ E ║ ... │ +/// │ │ │ │ └───┴─▲─╩═══╩───┴───╩═══╝ │ +/// ┌─────────────────────────┐ │ └───────────────────┘ └─┬─────┴───────────────────────┘ +/// │ ╔═══╦═══╗ │ │ +/// │ ║ B ║ E ║ ... │──┘ │ +/// │ ╚═══╩═══╝ │ Note Stable Sort: the merged stream +/// └─────────────────────────┘ places equal rows from stream 1 +/// Stream 2 +/// +/// +/// Input Streams Output stream +/// (sorted) (sorted) +/// ``` #[derive(Debug)] pub struct SortPreservingMergeExec { /// Input plan input: Arc, /// Sort expressions expr: Vec, - /// The target size of yielded batches - target_batch_size: usize, /// Execution metrics metrics: ExecutionPlanMetricsSet, } impl SortPreservingMergeExec { /// Create a new sort execution plan - pub fn new( - expr: Vec, - input: Arc, - target_batch_size: usize, - ) -> Self { + pub fn new(expr: Vec, input: Arc) -> Self { Self { input, expr, - target_batch_size, metrics: ExecutionPlanMetricsSet::new(), } } @@ -122,7 +141,6 @@ impl ExecutionPlan for SortPreservingMergeExec { 1 => Ok(Arc::new(SortPreservingMergeExec::new( self.expr.clone(), children[0].clone(), - self.target_batch_size, ))), _ => Err(DataFusionError::Internal( "SortPreservingMergeExec wrong number of children".to_string(), @@ -130,7 +148,11 @@ impl ExecutionPlan for SortPreservingMergeExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( "SortPreservingMergeExec invalid partition {}", @@ -148,26 +170,31 @@ impl ExecutionPlan for SortPreservingMergeExec { )), 1 => { // bypass if there is only one partition to merge (no metrics in this case either) - self.input.execute(0).await + self.input.execute(0, runtime).await } _ => { let (receivers, join_handles) = (0..input_partitions) .into_iter() .map(|part_i| { let (sender, receiver) = mpsc::channel(1); - let join_handle = - spawn_execution(self.input.clone(), sender, part_i); + let join_handle = spawn_execution( + self.input.clone(), + sender, + part_i, + runtime.clone(), + ); (receiver, join_handle) }) .unzip(); - Ok(Box::pin(SortPreservingMergeStream::new( + Ok(Box::pin(SortPreservingMergeStream::new_from_receivers( receivers, AbortOnDropMany(join_handles), self.schema(), &self.expr, - self.target_batch_size, baseline_metrics, + partition, + runtime, ))) } } @@ -195,163 +222,96 @@ impl ExecutionPlan for SortPreservingMergeExec { } } -/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of -/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys. -/// -/// Additionally it maintains a row cursor that can be advanced through the rows -/// of the provided `RecordBatch` -/// -/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to -/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores -/// a row comparator for each other cursor that it is compared to. -struct SortKeyCursor { - columns: Vec, - cur_row: usize, - num_rows: usize, - - // An index uniquely identifying the record batch scanned by this cursor. - batch_idx: usize, - batch: RecordBatch, - - // A collection of comparators that compare rows in this cursor's batch to - // the cursors in other batches. 
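// The following is a standalone, std-only sketch (not the crate's API; all names are
// illustrative) of the k-way merge technique that the new SortPreservingMergeStream
// relies on, and that also applies when the memory-limited sort above merges spilled
// sorted runs back together: each sorted input contributes a cursor, a min-heap always
// yields the smallest current head, and ties are broken by stream index so rows from
// earlier streams win, which is what keeps the merge stable.
use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;

#[derive(Eq, PartialEq)]
struct Head {
    value: i32,
    stream_idx: usize, // tie-breaker: earlier streams come first for equal values
    offset: usize,     // position of `value` within its stream
}

impl Ord for Head {
    fn cmp(&self, other: &Self) -> Ordering {
        self.value
            .cmp(&other.value)
            .then(self.stream_idx.cmp(&other.stream_idx))
            .then(self.offset.cmp(&other.offset))
    }
}

impl PartialOrd for Head {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Merge already-sorted input streams into one sorted output.
fn merge_sorted_streams(streams: &[Vec<i32>]) -> Vec<i32> {
    let mut heap = BinaryHeap::new();
    for (stream_idx, stream) in streams.iter().enumerate() {
        if let Some(&value) = stream.first() {
            // `Reverse` turns the std max-heap into a min-heap.
            heap.push(Reverse(Head { value, stream_idx, offset: 0 }));
        }
    }

    let mut merged = Vec::with_capacity(streams.iter().map(Vec::len).sum());
    while let Some(Reverse(head)) = heap.pop() {
        merged.push(head.value);
        let next = head.offset + 1;
        if let Some(&value) = streams[head.stream_idx].get(next) {
            heap.push(Reverse(Head { value, stream_idx: head.stream_idx, offset: next }));
        }
    }
    merged
}

// e.g. merge_sorted_streams(&[vec![1, 3], vec![2, 4]]) yields vec![1, 2, 3, 4]; only the
// current head of each stream is held in the heap, which is why the operator's memory
// footprint stays proportional to the number of inputs rather than the total row count.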
Other batches are uniquely identified by - // their batch_idx. - batch_comparators: HashMap>, +struct MergingStreams { + /// ConsumerId + id: MemoryConsumerId, + /// The sorted input streams to merge together + streams: Mutex>, + /// number of streams + num_streams: usize, + /// Runtime + runtime: Arc, } -impl<'a> std::fmt::Debug for SortKeyCursor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SortKeyCursor") - .field("columns", &self.columns) - .field("cur_row", &self.cur_row) - .field("num_rows", &self.num_rows) - .field("batch_idx", &self.batch_idx) - .field("batch", &self.batch) - .field("batch_comparators", &"") +impl Debug for MergingStreams { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("MergingStreams") + .field("id", &self.id()) .finish() } } -impl SortKeyCursor { +impl MergingStreams { fn new( - batch_idx: usize, - batch: RecordBatch, - sort_key: &[Arc], - ) -> Result { - let columns = sort_key - .iter() - .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))) - .collect::>()?; - Ok(Self { - cur_row: 0, - num_rows: batch.num_rows(), - columns, - batch, - batch_idx, - batch_comparators: HashMap::new(), - }) + partition: usize, + input_streams: Vec, + runtime: Arc, + ) -> Self { + Self { + id: MemoryConsumerId::new(partition), + num_streams: input_streams.len(), + streams: Mutex::new(input_streams), + runtime, + } } - fn is_finished(&self) -> bool { - self.num_rows == self.cur_row + fn num_streams(&self) -> usize { + self.num_streams } +} - fn advance(&mut self) -> usize { - assert!(!self.is_finished()); - let t = self.cur_row; - self.cur_row += 1; - t +#[async_trait] +impl MemoryConsumer for MergingStreams { + fn name(&self) -> String { + "MergingStreams".to_owned() } - /// Compares the sort key pointed to by this instance's row cursor with that of another - fn compare( - &mut self, - other: &SortKeyCursor, - options: &[SortOptions], - ) -> Result { - if self.columns.len() != other.columns.len() { - return Err(DataFusionError::Internal(format!( - "SortKeyCursors had inconsistent column counts: {} vs {}", - self.columns.len(), - other.columns.len() - ))); - } - - if self.columns.len() != options.len() { - return Err(DataFusionError::Internal(format!( - "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}", - self.columns.len(), - options.len() - ))); - } + fn id(&self) -> &MemoryConsumerId { + &self.id + } - let zipped = self - .columns - .iter() - .zip(other.columns.iter()) - .zip(options.iter()); - - // Recall or initialise a collection of comparators for comparing - // columnar arrays of this cursor and "other". 
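// Std-only sketch of the null/descending handling that this compare logic implements and
// that the new cursor code keeps. It is a simplification for illustration, not the
// crate's API: nulls are ordered purely by `nulls_first`, and `descending` only reverses
// the comparison of two non-null values.
use std::cmp::Ordering;

struct SortFlags {
    descending: bool,
    nulls_first: bool,
}

fn compare_keys<T: Ord>(l: Option<&T>, r: Option<&T>, flags: &SortFlags) -> Ordering {
    match (l, r) {
        (None, None) => Ordering::Equal,
        (None, Some(_)) => {
            if flags.nulls_first { Ordering::Less } else { Ordering::Greater }
        }
        (Some(_), None) => {
            if flags.nulls_first { Ordering::Greater } else { Ordering::Less }
        }
        (Some(l), Some(r)) => {
            let ord = l.cmp(r);
            if flags.descending { ord.reverse() } else { ord }
        }
    }
}

/// Lexicographic comparison across sort columns; assumes the three slices have equal
/// length, which the real code validates before comparing.
fn compare_rows<T: Ord>(
    left: &[Option<T>],
    right: &[Option<T>],
    flags: &[SortFlags],
) -> Ordering {
    left.iter()
        .zip(right)
        .zip(flags)
        .map(|((l, r), f)| compare_keys(l.as_ref(), r.as_ref(), f))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}

// e.g. with nulls_first = true, compare_keys(None, Some(&1), ...) is Ordering::Less;
// with descending = true, Some(&3) compares before Some(&1).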
- let cmp = self - .batch_comparators - .entry(other.batch_idx) - .or_insert_with(|| Vec::with_capacity(other.columns.len())); - - for (i, ((l, r), sort_options)) in zipped.enumerate() { - if i >= cmp.len() { - // initialise comparators as potentially needed - cmp.push(build_compare(l.as_ref(), r.as_ref())?); - } + fn memory_manager(&self) -> Arc { + self.runtime.memory_manager.clone() + } - match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) { - (false, true) if sort_options.nulls_first => return Ok(Ordering::Less), - (false, true) => return Ok(Ordering::Greater), - (true, false) if sort_options.nulls_first => { - return Ok(Ordering::Greater) - } - (true, false) => return Ok(Ordering::Less), - (false, false) => {} - (true, true) => match cmp[i](self.cur_row, other.cur_row) { - Ordering::Equal => {} - o if sort_options.descending => return Ok(o.reverse()), - o => return Ok(o), - }, - } - } + fn type_(&self) -> &ConsumerType { + &ConsumerType::Tracking + } - Ok(Ordering::Equal) + async fn spill(&self) -> Result { + return Err(DataFusionError::Internal(format!( + "Calling spill on a tracking only consumer {}, {}", + self.name(), + self.id, + ))); } -} -/// A `RowIndex` identifies a specific row from those buffered -/// by a `SortPreservingMergeStream` -#[derive(Debug, Clone)] -struct RowIndex { - /// The index of the stream - stream_idx: usize, - /// The index of the cursor within the stream's VecDequeue - cursor_idx: usize, - /// The row index - row_idx: usize, + fn mem_used(&self) -> usize { + let streams = self.streams.lock().unwrap(); + streams.iter().map(StreamWrapper::mem_used).sum::() + } } #[derive(Debug)] -struct SortPreservingMergeStream { +pub(crate) struct SortPreservingMergeStream { /// The schema of the RecordBatches yielded by this stream schema: SchemaRef, /// The sorted input streams to merge together - receivers: Vec>>, + streams: Arc, /// Drop helper for tasks feeding the [`receivers`](Self::receivers) _drop_helper: AbortOnDropMany<()>, - /// For each input stream maintain a dequeue of SortKeyCursor + /// For each input stream maintain a dequeue of RecordBatches /// - /// Exhausted cursors will be popped off the front once all + /// Exhausted batches will be popped off the front once all /// their rows have been yielded to the output - cursors: Vec>, + batches: Vec>, + + /// Maintain a flag for each stream denoting if the current cursor + /// has finished and needs to poll from the stream + cursor_finished: Vec, /// The accumulated row indexes for the next record batch in_progress: Vec, @@ -360,10 +320,7 @@ struct SortPreservingMergeStream { column_expressions: Vec>, /// The sort options for each expression - sort_options: Vec, - - /// The desired RecordBatch size to yield - target_batch_size: usize, + sort_options: Arc>, /// used to record execution metrics baseline_metrics: BaselineMetrics, @@ -371,36 +328,93 @@ struct SortPreservingMergeStream { /// If the stream has encountered an error aborted: bool, - /// An index to uniquely identify the input stream batch - next_batch_index: usize, + /// An id to uniquely identify the input stream batch + next_batch_id: usize, + + /// min heap for record comparison + min_heap: BinaryHeap, + + /// runtime + runtime: Arc, +} + +impl Drop for SortPreservingMergeStream { + fn drop(&mut self) { + self.runtime.drop_consumer(self.streams.id()) + } } impl SortPreservingMergeStream { - fn new( + #[allow(clippy::too_many_arguments)] + pub(crate) fn new_from_receivers( receivers: Vec>>, _drop_helper: AbortOnDropMany<()>, schema: 
SchemaRef, expressions: &[PhysicalSortExpr], - target_batch_size: usize, baseline_metrics: BaselineMetrics, + partition: usize, + runtime: Arc, ) -> Self { - let cursors = (0..receivers.len()) + let stream_count = receivers.len(); + let batches = (0..stream_count) .into_iter() .map(|_| VecDeque::new()) .collect(); + let wrappers = receivers.into_iter().map(StreamWrapper::Receiver).collect(); + let streams = Arc::new(MergingStreams::new(partition, wrappers, runtime.clone())); + runtime.register_consumer(&(streams.clone() as Arc)); - Self { + SortPreservingMergeStream { schema, - cursors, - receivers, + batches, + cursor_finished: vec![true; stream_count], + streams, _drop_helper, column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), - sort_options: expressions.iter().map(|x| x.options).collect(), - target_batch_size, + sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()), + baseline_metrics, + aborted: false, + in_progress: vec![], + next_batch_id: 0, + min_heap: BinaryHeap::with_capacity(stream_count), + runtime, + } + } + + pub(crate) fn new_from_streams( + streams: Vec, + schema: SchemaRef, + expressions: &[PhysicalSortExpr], + baseline_metrics: BaselineMetrics, + partition: usize, + runtime: Arc, + ) -> Self { + let stream_count = streams.len(); + let batches = (0..stream_count) + .into_iter() + .map(|_| VecDeque::new()) + .collect(); + let wrappers = streams + .into_iter() + .map(|s| StreamWrapper::Stream(Some(s))) + .collect(); + let streams = Arc::new(MergingStreams::new(partition, wrappers, runtime.clone())); + runtime.register_consumer(&(streams.clone() as Arc)); + + Self { + schema, + batches, + cursor_finished: vec![true; stream_count], + streams, + _drop_helper: AbortOnDropMany(vec![]), + column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), + sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()), baseline_metrics, aborted: false, in_progress: vec![], - next_batch_index: 0, + next_batch_id: 0, + min_heap: BinaryHeap::with_capacity(stream_count), + runtime, } } @@ -412,83 +426,71 @@ impl SortPreservingMergeStream { cx: &mut Context<'_>, idx: usize, ) -> Poll> { - if let Some(cursor) = &self.cursors[idx].back() { - if !cursor.is_finished() { - // Cursor is not finished - don't need a new RecordBatch yet - return Poll::Ready(Ok(())); - } - } - - let stream = &mut self.receivers[idx]; - if stream.is_terminated() { + if !self.cursor_finished[idx] { + // Cursor is not finished - don't need a new RecordBatch yet return Poll::Ready(Ok(())); } + let mut empty_batch = false; + { + let mut streams = self.streams.streams.lock().unwrap(); - // Fetch a new input record and create a cursor from it - match futures::ready!(stream.poll_next_unpin(cx)) { - None => return Poll::Ready(Ok(())), - Some(Err(e)) => { - return Poll::Ready(Err(e)); - } - Some(Ok(batch)) => { - let cursor = match SortKeyCursor::new( - self.next_batch_index, // assign this batch an ID - batch, - &self.column_expressions, - ) { - Ok(cursor) => cursor, - Err(e) => { - return Poll::Ready(Err(ArrowError::External( - "".to_string(), - Box::new(e), - ))); - } - }; - self.next_batch_index += 1; - self.cursors[idx].push_back(cursor) + let stream = &mut streams[idx]; + if stream.is_terminated() { + return Poll::Ready(Ok(())); } - } - Poll::Ready(Ok(())) - } - - /// Returns the index of the next stream to pull a row from, or None - /// if all cursors for all streams are exhausted - fn next_stream_idx(&mut self) -> Result> { - let mut min_cursor: 
Option<(usize, &mut SortKeyCursor)> = None; - for (idx, candidate) in self.cursors.iter_mut().enumerate() { - if let Some(candidate) = candidate.back_mut() { - if candidate.is_finished() { - continue; + // Fetch a new input record and create a cursor from it + match futures::ready!(stream.poll_next_unpin(cx)) { + None => return Poll::Ready(Ok(())), + Some(Err(e)) => { + return Poll::Ready(Err(e)); } - - match min_cursor { - None => min_cursor = Some((idx, candidate)), - Some((_, ref mut min)) => { - if min.compare(candidate, &self.sort_options)? - == Ordering::Greater - { - min_cursor = Some((idx, candidate)) - } + Some(Ok(batch)) => { + if batch.num_rows() > 0 { + let cursor = match SortKeyCursor::new( + idx, + self.next_batch_id, // assign this batch an ID + &batch, + &self.column_expressions, + self.sort_options.clone(), + ) { + Ok(cursor) => cursor, + Err(e) => { + return Poll::Ready(Err(ArrowError::External( + "datafusion".to_string(), + Box::new(e), + ))); + } + }; + self.next_batch_id += 1; + self.min_heap.push(cursor); + self.cursor_finished[idx] = false; + self.batches[idx].push_back(batch) + } else { + empty_batch = true; } } } } - Ok(min_cursor.map(|(idx, _)| idx)) + if empty_batch { + self.maybe_poll_stream(cx, idx) + } else { + Poll::Ready(Ok(())) + } } /// Drains the in_progress row indexes, and builds a new RecordBatch from them /// - /// Will then drop any cursors for which all rows have been yielded to the output + /// Will then drop any batches for which all rows have been yielded to the output fn build_record_batch(&mut self) -> ArrowResult { // Mapping from stream index to the index of the first buffer from that stream let mut buffer_idx = 0; - let mut stream_to_buffer_idx = Vec::with_capacity(self.cursors.len()); + let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len()); - for cursors in &self.cursors { + for batches in &self.batches { stream_to_buffer_idx.push(buffer_idx); - buffer_idx += cursors.len(); + buffer_idx += batches.len(); } let columns = self @@ -498,12 +500,10 @@ impl SortPreservingMergeStream { .enumerate() .map(|(column_idx, _)| { let arrays = self - .cursors + .batches .iter() - .flat_map(|cursor| { - cursor - .iter() - .map(|cursor| cursor.batch.column(column_idx).as_ref()) + .flat_map(|batch| { + batch.iter().map(|batch| batch.column(column_idx).as_ref()) }) .collect::>(); @@ -516,13 +516,13 @@ impl SortPreservingMergeStream { let first = &self.in_progress[0]; let mut buffer_idx = - stream_to_buffer_idx[first.stream_idx] + first.cursor_idx; + stream_to_buffer_idx[first.stream_idx] + first.batch_idx; let mut start_row_idx = first.row_idx; let mut end_row_idx = start_row_idx + 1; for row_index in self.in_progress.iter().skip(1) { let next_buffer_idx = - stream_to_buffer_idx[row_index.stream_idx] + row_index.cursor_idx; + stream_to_buffer_idx[row_index.stream_idx] + row_index.batch_idx; if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx { // subsequent row in same batch @@ -552,17 +552,17 @@ impl SortPreservingMergeStream { self.in_progress.clear(); // New cursors are only created once the previous cursor for the stream - // is finished. This means all remaining rows from all but the last cursor + // is finished. 
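// A much-simplified, std-only illustration of the "tracking-only" consumer idea used by
// MergingStreams above: it reports how much memory the buffered streams hold so the
// memory manager can account for it, but it refuses to spill because a merge has no
// self-contained state it can usefully push to disk. The names below are illustrative
// and are not the crate's MemoryConsumer / MemoryManager API.
trait TrackingConsumer {
    fn name(&self) -> &str;
    fn mem_used(&self) -> usize;
    /// Tracking-only consumers cannot free memory on demand, so `spill` is an error.
    fn spill(&self) -> Result<usize, String> {
        Err(format!("{} is tracking-only and cannot spill", self.name()))
    }
}

struct BufferedStreams {
    /// Bytes currently buffered per input stream.
    buffered_batch_bytes: Vec<usize>,
}

impl TrackingConsumer for BufferedStreams {
    fn name(&self) -> &str {
        "BufferedStreams"
    }
    fn mem_used(&self) -> usize {
        self.buffered_batch_bytes.iter().sum()
    }
}

// A manager that sums mem_used() across registered consumers can then decide when the
// spillable consumers (such as the external sort) must spill, while tracking-only
// consumers simply contribute to the running total.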
This means all remaining rows from all but the last batch // for each stream have been yielded to the newly created record batch // // Additionally as `in_progress` has been drained, there are no longer - // any RowIndex's reliant on the cursor indexes + // any RowIndex's reliant on the batch indexes // - // We can therefore drop all but the last cursor for each stream - for cursors in &mut self.cursors { - if cursors.len() > 1 { - // Drain all but the last cursor - cursors.drain(0..(cursors.len() - 1)); + // We can therefore drop all but the last batch for each stream + for batches in &mut self.batches { + if batches.len() > 1 { + // Drain all but the last batch + batches.drain(0..(batches.len() - 1)); } } @@ -594,7 +594,7 @@ impl SortPreservingMergeStream { // Ensure all non-exhausted streams have a cursor from which // rows can be pulled - for i in 0..self.cursors.len() { + for i in 0..self.streams.num_streams() { match futures::ready!(self.maybe_poll_stream(cx, i)) { Ok(_) => {} Err(e) => { @@ -610,45 +610,45 @@ impl SortPreservingMergeStream { let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let _timer = elapsed_compute.timer(); - let stream_idx = match self.next_stream_idx() { - Ok(Some(idx)) => idx, - Ok(None) if self.in_progress.is_empty() => return Poll::Ready(None), - Ok(None) => return Poll::Ready(Some(self.build_record_batch())), - Err(e) => { - self.aborted = true; - return Poll::Ready(Some(Err(ArrowError::External( - "".to_string(), - Box::new(e), - )))); - } - }; - - let cursors = &mut self.cursors[stream_idx]; - let cursor_idx = cursors.len() - 1; - let cursor = cursors.back_mut().unwrap(); - let row_idx = cursor.advance(); - let cursor_finished = cursor.is_finished(); - - self.in_progress.push(RowIndex { - stream_idx, - cursor_idx, - row_idx, - }); + match self.min_heap.pop() { + Some(mut cursor) => { + let stream_idx = cursor.stream_idx; + let batch_idx = self.batches[stream_idx].len() - 1; + let row_idx = cursor.advance(); + + let mut cursor_finished = false; + // insert the cursor back to min_heap if the record batch is not exhausted + if !cursor.is_finished() { + self.min_heap.push(cursor); + } else { + cursor_finished = true; + self.cursor_finished[stream_idx] = true; + } - if self.in_progress.len() == self.target_batch_size { - return Poll::Ready(Some(self.build_record_batch())); - } + self.in_progress.push(RowIndex { + stream_idx, + batch_idx, + row_idx, + }); - // If removed the last row from the cursor, need to fetch a new record - // batch if possible, before looping round again - if cursor_finished { - match futures::ready!(self.maybe_poll_stream(cx, stream_idx)) { - Ok(_) => {} - Err(e) => { - self.aborted = true; - return Poll::Ready(Some(Err(e))); + if self.in_progress.len() == self.runtime.batch_size() { + return Poll::Ready(Some(self.build_record_batch())); + } + + // If removed the last row from the cursor, need to fetch a new record + // batch if possible, before looping round again + if cursor_finished { + match futures::ready!(self.maybe_poll_stream(cx, stream_idx)) { + Ok(_) => {} + Err(e) => { + self.aborted = true; + return Poll::Ready(Some(Err(e))); + } + } } } + None if self.in_progress.is_empty() => return Poll::Ready(None), + None => return Poll::Ready(Some(self.build_record_batch())), } } } @@ -663,8 +663,10 @@ impl RecordBatchStream for SortPreservingMergeStream { #[cfg(test)] mod tests { use crate::datasource::object_store::local::LocalFileSystem; + use crate::physical_plan::metrics::MetricValue; use 
crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use arrow::array::ArrayRef; use std::iter::FromIterator; use crate::arrow::array::*; @@ -673,20 +675,22 @@ mod tests { use crate::assert_batches_eq; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::expressions::col; - use crate::physical_plan::file_format::{CsvExec, PhysicalPlanConfig}; + use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::memory::MemoryExec; - use crate::physical_plan::sort::SortExec; + use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{collect, common}; use crate::test::{self, assert_is_pending}; use crate::test_util; use super::*; + use crate::execution::runtime_env::RuntimeConfig; use arrow::datatypes::{DataType, Field, Schema}; use futures::{FutureExt, SinkExt}; use tokio_stream::StreamExt; #[tokio::test] async fn test_merge_interleave() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(Utf8Array::::from(&[ Some("a"), @@ -734,12 +738,14 @@ mod tests { "| 3 | j | 1970-01-01 00:00:00.000000008 |", "+----+---+-------------------------------+", ], + runtime, ) .await; } #[tokio::test] async fn test_merge_some_overlap() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(Utf8Array::::from_iter(vec![ Some("a"), @@ -786,12 +792,14 @@ mod tests { "| 110 | g | 1970-01-01 00:00:00.000000006 |", "+-----+---+-------------------------------+", ], + runtime, ) .await; } #[tokio::test] async fn test_merge_no_overlap() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(Utf8Array::::from(&[ Some("a"), @@ -838,12 +846,14 @@ mod tests { "| 30 | j | 1970-01-01 00:00:00.000000006 |", "+----+---+-------------------------------+", ], + runtime, ) .await; } #[tokio::test] async fn test_merge_three_partitions() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(Utf8Array::::from(&[ Some("a"), @@ -909,11 +919,16 @@ mod tests { "| 30 | j | 1970-01-01 00:00:00.000000060 |", "+-----+---+-------------------------------+", ], + runtime, ) .await; } - async fn _test_merge(partitions: &[Vec], exp: &[&str]) { + async fn _test_merge( + partitions: &[Vec], + exp: &[&str], + runtime: Arc, + ) { let schema = partitions[0][0].schema(); let sort = vec![ PhysicalSortExpr { @@ -926,18 +941,19 @@ mod tests { }, ]; let exec = MemoryExec::try_new(partitions, schema.clone(), None).unwrap(); - let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024)); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge).await.unwrap(); + let collected = collect(merge, runtime).await.unwrap(); assert_batches_eq!(exp, collected.as_slice()); } async fn sorted_merge( input: Arc, sort: Vec, + runtime: Arc, ) -> RecordBatch { - let merge = Arc::new(SortPreservingMergeExec::new(sort, input, 1024)); - let mut result = collect(merge).await.unwrap(); + let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); + let mut result = collect(merge, runtime).await.unwrap(); assert_eq!(result.len(), 1); result.remove(0) } @@ -945,38 +961,40 @@ mod tests { async fn partition_sort( input: Arc, sort: 
Vec, + runtime: Arc, ) -> RecordBatch { let sort_exec = Arc::new(SortExec::new_with_partitioning(sort.clone(), input, true)); - sorted_merge(sort_exec, sort).await + sorted_merge(sort_exec, sort, runtime).await } async fn basic_sort( src: Arc, sort: Vec, + runtime: Arc, ) -> RecordBatch { let merge = Arc::new(CoalescePartitionsExec::new(src)); let sort_exec = Arc::new(SortExec::try_new(sort, merge).unwrap()); - let mut result = collect(sort_exec).await.unwrap(); + let mut result = collect(sort_exec, runtime).await.unwrap(); assert_eq!(result.len(), 1); result.remove(0) } #[tokio::test] async fn test_partition_sort() { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", partitions).unwrap(); let csv = Arc::new(CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema: Arc::clone(&schema), file_groups: files, statistics: Statistics::default(), projection: None, - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -1006,8 +1024,8 @@ mod tests { }, ]; - let basic = basic_sort(csv.clone(), sort.clone()).await; - let partition = partition_sort(csv, sort).await; + let basic = basic_sort(csv.clone(), sort.clone(), runtime.clone()).await; + let partition = partition_sort(csv, sort, runtime.clone()).await; let basic = arrow_print::write(&[basic]); let partition = arrow_print::write(&[partition]); @@ -1047,6 +1065,7 @@ mod tests { async fn sorted_partitioned_input( sort: Vec, sizes: &[usize], + runtime: Arc, ) -> Arc { let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -1054,13 +1073,12 @@ mod tests { test::create_partitioned_csv("aggregate_test_100.csv", partitions).unwrap(); let csv = Arc::new(CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema: schema, file_groups: files, statistics: Statistics::default(), projection: None, - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -1068,7 +1086,7 @@ mod tests { b',', )); - let sorted = basic_sort(csv, sort).await; + let sorted = basic_sort(csv, sort, runtime).await; let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect(); Arc::new(MemoryExec::try_new(&split, sorted.schema().clone(), None).unwrap()) @@ -1076,6 +1094,7 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input() { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let sort = vec![ // uint8 @@ -1100,9 +1119,10 @@ mod tests { }, ]; - let input = sorted_partitioned_input(sort.clone(), &[10, 3, 11]).await; - let basic = basic_sort(input.clone(), sort.clone()).await; - let partition = sorted_merge(input, sort).await; + let input = + sorted_partitioned_input(sort.clone(), &[10, 3, 11], runtime.clone()).await; + let basic = basic_sort(input.clone(), sort.clone(), runtime.clone()).await; + let partition = sorted_merge(input, sort, runtime.clone()).await; assert_eq!(basic.num_rows(), 300); assert_eq!(partition.num_rows(), 300); @@ -1130,11 +1150,15 @@ mod tests { }, ]; - let input = sorted_partitioned_input(sort.clone(), &[10, 5, 13]).await; - let basic = basic_sort(input.clone(), sort.clone()).await; + let runtime = Arc::new(RuntimeEnv::default()); + let input = + sorted_partitioned_input(sort.clone(), &[10, 5, 13], runtime.clone()).await; + let basic = basic_sort(input.clone(), sort.clone(), runtime).await; - let merge = 
Arc::new(SortPreservingMergeExec::new(sort, input, 23)); - let merged = collect(merge).await.unwrap(); + let runtime_bs_23 = + Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(23)).unwrap()); + let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); + let merged = collect(merge, runtime_bs_23).await.unwrap(); assert_eq!(merged.len(), 14); @@ -1149,6 +1173,7 @@ mod tests { #[tokio::test] async fn test_nulls() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(Utf8Array::::from(&[ None, @@ -1195,9 +1220,9 @@ mod tests { }, ]; let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); - let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024)); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge).await.unwrap(); + let collected = collect(merge, runtime).await.unwrap(); assert_eq!(collected.len(), 1); assert_batches_eq!( @@ -1223,13 +1248,15 @@ mod tests { #[tokio::test] async fn test_async() { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let sort = vec![PhysicalSortExpr { expr: col("c12", &schema).unwrap(), options: SortOptions::default(), }]; - let batches = sorted_partitioned_input(sort.clone(), &[5, 7, 3]).await; + let batches = + sorted_partitioned_input(sort.clone(), &[5, 7, 3], runtime.clone()).await; let partition_count = batches.output_partitioning().partition_count(); let mut join_handles = Vec::with_capacity(partition_count); @@ -1237,7 +1264,7 @@ mod tests { for partition in 0..partition_count { let (mut sender, receiver) = mpsc::channel(1); - let mut stream = batches.execute(partition).await.unwrap(); + let mut stream = batches.execute(partition, runtime.clone()).await.unwrap(); let join_handle = tokio::spawn(async move { while let Some(batch) = stream.next().await { sender.send(batch).await.unwrap(); @@ -1252,14 +1279,15 @@ mod tests { let metrics = ExecutionPlanMetricsSet::new(); let baseline_metrics = BaselineMetrics::new(&metrics, 0); - let merge_stream = SortPreservingMergeStream::new( + let merge_stream = SortPreservingMergeStream::new_from_receivers( receivers, // Use empty vector since we want to use the join handles ourselves AbortOnDropMany(vec![]), batches.schema(), sort.as_slice(), - 1024, baseline_metrics, + 0, + runtime.clone(), ); let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap(); @@ -1271,7 +1299,7 @@ mod tests { assert_eq!(merged.len(), 1); let merged = merged.remove(0); - let basic = basic_sort(batches, sort.clone()).await; + let basic = basic_sort(batches, sort.clone(), runtime.clone()).await; let basic = arrow_print::write(&[basic]); let partition = arrow_print::write(&[merged]); @@ -1285,6 +1313,7 @@ mod tests { #[tokio::test] async fn test_merge_metrics() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2])); let b: ArrayRef = Arc::new(Utf8Array::::from_iter(vec![Some("a"), Some("c")])); @@ -1302,9 +1331,9 @@ mod tests { }]; let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema.clone(), None).unwrap(); - let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024)); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge.clone()).await.unwrap(); + let collected = collect(merge.clone(), runtime).await.unwrap(); let expected = vec![ "+----+---+", "| a 
| b |", @@ -1343,6 +1372,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -1354,10 +1384,9 @@ mod tests { options: SortOptions::default(), }], blocking_exec, - 1, )); - let fut = collect(sort_preserving_merge_exec); + let fut = collect(sort_preserving_merge_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1366,4 +1395,84 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_stable_sort() { + let runtime = Arc::new(RuntimeEnv::default()); + + // Create record batches like: + // batch_number |value + // -------------+------ + // 1 | A + // 1 | B + // + // Ensure that the output is in the same order the batches were fed + let partitions: Vec> = (0..10) + .map(|batch_number| { + let batch_number: Int32Array = + vec![Some(batch_number), Some(batch_number)] + .into_iter() + .collect(); + let value: Utf8Array = + vec![Some("A"), Some("B")].into_iter().collect(); + + let batch = RecordBatch::try_from_iter(vec![ + ("batch_number", Arc::new(batch_number) as ArrayRef), + ("value", Arc::new(value) as ArrayRef), + ]) + .unwrap(); + + vec![batch] + }) + .collect(); + + let schema = partitions[0][0].schema(); + + let sort = vec![PhysicalSortExpr { + expr: col("value", schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + + let exec = MemoryExec::try_new(&partitions, schema.clone(), None).unwrap(); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); + + let collected = collect(merge, runtime).await.unwrap(); + assert_eq!(collected.len(), 1); + + // Expect the data to be sorted first by "batch_number" (because + // that was the order it was fed in, even though only "value" + // is in the sort key) + assert_batches_eq!( + &[ + "+--------------+-------+", + "| batch_number | value |", + "+--------------+-------+", + "| 0 | A |", + "| 1 | A |", + "| 2 | A |", + "| 3 | A |", + "| 4 | A |", + "| 5 | A |", + "| 6 | A |", + "| 7 | A |", + "| 8 | A |", + "| 9 | A |", + "| 0 | B |", + "| 1 | B |", + "| 2 | B |", + "| 3 | B |", + "| 4 | B |", + "| 5 | B |", + "| 6 | B |", + "| 7 | B |", + "| 8 | B |", + "| 9 | B |", + "+--------------+-------+", + ], + collected.as_slice() + ); + } } diff --git a/datafusion/src/physical_plan/stream.rs b/datafusion/src/physical_plan/stream.rs index 67b709040690..cf590c99bd7f 100644 --- a/datafusion/src/physical_plan/stream.rs +++ b/datafusion/src/physical_plan/stream.rs @@ -17,9 +17,8 @@ //! 
Stream wrappers for physical operators -use arrow::{ - datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch, -}; +use crate::record_batch::RecordBatch; +use arrow::{datatypes::SchemaRef, error::Result as ArrowResult}; use futures::{Stream, StreamExt}; use tokio::task::JoinHandle; use tokio_stream::wrappers::ReceiverStream; diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs index b4133565aebf..bc98663b157f 100644 --- a/datafusion/src/physical_plan/type_coercion.rs +++ b/datafusion/src/physical_plan/type_coercion.rs @@ -210,6 +210,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { #[cfg(test)] mod tests { use super::*; + use crate::field_util::SchemaExt; use crate::physical_plan::{ expressions::col, functions::{TypeSignature, Volatility}, diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs index 33bc5b939b81..0de696d61172 100644 --- a/datafusion/src/physical_plan/udaf.rs +++ b/datafusion/src/physical_plan/udaf.rs @@ -56,7 +56,7 @@ pub struct AggregateUDF { } impl Debug for AggregateUDF { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_struct("AggregateUDF") .field("name", &self.name) .field("signature", &self.signature) diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index ae85a7feae4c..7355746a368b 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -54,7 +54,7 @@ pub struct ScalarUDF { } impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_struct("ScalarUDF") .field("name", &self.name) .field("signature", &self.signature) diff --git a/datafusion/src/physical_plan/unicode_expressions.rs b/datafusion/src/physical_plan/unicode_expressions.rs index ae7dfab990af..c55eb7e0e4df 100644 --- a/datafusion/src/physical_plan/unicode_expressions.rs +++ b/datafusion/src/physical_plan/unicode_expressions.rs @@ -442,9 +442,11 @@ pub fn substr(args: &[ArrayRef]) -> Result { .map(|((string, start), count)| match (string, start, count) { (Some(string), Some(&start), Some(&count)) => { if count < 0 { - Err(DataFusionError::Execution( - "negative substring length not allowed".to_string(), - )) + Err(DataFusionError::Execution(format!( + "negative substring length not allowed: substr(, {}, {})", + start, + count + ))) } else if start <= 0 { Ok(Some(string.to_string())) } else { diff --git a/datafusion/src/physical_plan/union.rs b/datafusion/src/physical_plan/union.rs index 79c50720496d..96dbc2eb448c 100644 --- a/datafusion/src/physical_plan/union.rs +++ b/datafusion/src/physical_plan/union.rs @@ -23,7 +23,8 @@ use std::{any::Any, sync::Arc}; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use crate::record_batch::RecordBatch; +use arrow::datatypes::SchemaRef; use futures::StreamExt; use super::{ @@ -31,6 +32,7 @@ use super::{ ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::execution::runtime_env::RuntimeEnv; use crate::{ error::Result, physical_plan::{expressions, metrics::BaselineMetrics}, @@ -91,7 +93,11 @@ impl ExecutionPlan for UnionExec { Ok(Arc::new(UnionExec::new(children))) } - async fn execute(&self, mut partition: usize) -> Result { + async fn execute( + &self, + mut partition: usize, + runtime: Arc, + ) -> Result { 
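// Std-only sketch of how the execute implementation just below routes a flat partition
// index to one of the union's children: walk the children in order and subtract each
// child's partition count until the index falls inside one of them. Purely illustrative;
// the function and parameter names are not part of the crate.
fn resolve_partition(
    child_partition_counts: &[usize],
    mut partition: usize,
) -> Option<(usize, usize)> {
    for (child_idx, &count) in child_partition_counts.iter().enumerate() {
        if partition < count {
            // (which child, which partition within that child)
            return Some((child_idx, partition));
        }
        partition -= count;
    }
    None // index out of range across all children
}

// e.g. resolve_partition(&[4, 5], 6) == Some((1, 2)): partition 6 of the union is
// partition 2 of the second child.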
let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); // record the tiny amount of work done in this function so // elapsed_compute is reported as non zero @@ -102,7 +108,7 @@ impl ExecutionPlan for UnionExec { for input in self.inputs.iter() { // Calculate whether partition belongs to the current partition if partition < input.output_partitioning().partition_count() { - let stream = input.execute(partition).await?; + let stream = input.execute(partition, runtime.clone()).await?; return Ok(Box::pin(ObservedStream::new(stream, baseline_metrics))); } else { partition -= input.output_partitioning().partition_count(); @@ -183,14 +189,12 @@ fn col_stats_union( .min_value .zip(right.min_value) .map(|(a, b)| expressions::helpers::min(&a, &b)) - .map(Result::ok) - .flatten(); + .and_then(Result::ok); left.max_value = left .max_value .zip(right.max_value) .map(|(a, b)| expressions::helpers::max(&a, &b)) - .map(Result::ok) - .flatten(); + .and_then(Result::ok); left.null_count = left.null_count.zip(right.null_count).map(|(a, b)| a + b); left @@ -221,17 +225,18 @@ mod tests { use crate::datasource::object_store::{local::LocalFileSystem, ObjectStore}; use crate::{test, test_util}; + use crate::record_batch::RecordBatch; use crate::{ physical_plan::{ collect, - file_format::{CsvExec, PhysicalPlanConfig}, + file_format::{CsvExec, FileScanConfig}, }, scalar::ScalarValue, }; - use arrow::record_batch::RecordBatch; #[tokio::test] async fn test_union_partitions() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let fs: Arc = Arc::new(LocalFileSystem {}); @@ -240,13 +245,12 @@ mod tests { let (_, files2) = test::create_partitioned_csv("aggregate_test_100.csv", 5)?; let csv = CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::clone(&fs), file_schema: Arc::clone(&schema), file_groups: files, statistics: Statistics::default(), projection: None, - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -255,13 +259,12 @@ mod tests { ); let csv2 = CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::clone(&fs), file_schema: Arc::clone(&schema), file_groups: files2, statistics: Statistics::default(), projection: None, - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -274,7 +277,7 @@ mod tests { // Should have 9 partitions and 9 output batches assert_eq!(union_exec.output_partitioning().partition_count(), 9); - let result: Vec = collect(union_exec).await?; + let result: Vec = collect(union_exec, runtime).await?; assert_eq!(result.len(), 9); Ok(()) diff --git a/datafusion/src/physical_plan/values.rs b/datafusion/src/physical_plan/values.rs index fe66125c077f..8672b5f93c8e 100644 --- a/datafusion/src/physical_plan/values.rs +++ b/datafusion/src/physical_plan/values.rs @@ -19,14 +19,16 @@ use super::{common, SendableRecordBatchStream, Statistics}; use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::SchemaExt; use crate::physical_plan::{ memory::MemoryStream, ColumnarValue, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, }; +use crate::record_batch::RecordBatch; use crate::scalar::ScalarValue; use arrow::array::new_null_array; use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; use async_trait::async_trait; use std::any::Any; use std::sync::Arc; @@ -134,7 +136,11 @@ impl ExecutionPlan for ValuesExec { } } - async fn execute(&self, partition: usize) -> 
Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -177,7 +183,7 @@ mod tests { async fn values_empty_case() -> Result<()> { let schema = test_util::aggr_test_schema(); let empty = ValuesExec::try_new(schema, vec![]); - assert!(!empty.is_ok()); + assert!(empty.is_err()); Ok(()) } } diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 5b34f672cbac..f281d7d01837 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -26,9 +26,9 @@ use crate::physical_plan::{ aggregates, aggregates::AggregateFunction, functions::Signature, type_coercion::data_types, windows::find_ranges_in_range, PhysicalExpr, }; +use crate::record_batch::RecordBatch; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow::record_batch::RecordBatch; use std::any::Any; use std::ops::Range; use std::sync::Arc; diff --git a/datafusion/src/physical_plan/windows/aggregate.rs b/datafusion/src/physical_plan/windows/aggregate.rs index fda1290016dc..e582b4929c3c 100644 --- a/datafusion/src/physical_plan/windows/aggregate.rs +++ b/datafusion/src/physical_plan/windows/aggregate.rs @@ -23,8 +23,8 @@ use crate::physical_plan::windows::find_ranges_in_range; use crate::physical_plan::{ expressions::PhysicalSortExpr, Accumulator, AggregateExpr, PhysicalExpr, WindowExpr, }; +use crate::record_batch::RecordBatch; use arrow::compute::concatenate; -use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use std::any::Any; use std::iter::IntoIterator; diff --git a/datafusion/src/physical_plan/windows/built_in.rs b/datafusion/src/physical_plan/windows/built_in.rs index a3197994be55..0c62dc8172c0 100644 --- a/datafusion/src/physical_plan/windows/built_in.rs +++ b/datafusion/src/physical_plan/windows/built_in.rs @@ -22,8 +22,8 @@ use crate::physical_plan::{ expressions::PhysicalSortExpr, window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowExpr, }; +use crate::record_batch::RecordBatch; use arrow::compute::concatenate; -use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/src/physical_plan/windows/mod.rs b/datafusion/src/physical_plan/windows/mod.rs index d2ab49cf4676..6ee052082ac8 100644 --- a/datafusion/src/physical_plan/windows/mod.rs +++ b/datafusion/src/physical_plan/windows/mod.rs @@ -174,16 +174,18 @@ pub(crate) fn find_ranges_in_range<'a>( mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; + use crate::execution::runtime_env::RuntimeEnv; + use crate::field_util::SchemaExt; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::col; - use crate::physical_plan::file_format::{CsvExec, PhysicalPlanConfig}; + use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::{collect, Statistics}; + use crate::record_batch::RecordBatch; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::{self, assert_is_pending}; use crate::test_util::{self, aggr_test_schema}; use arrow::array::*; use arrow::datatypes::{DataType, Field, SchemaRef}; - use arrow::record_batch::RecordBatch; use futures::FutureExt; fn create_test_schema(partitions: usize) -> Result<(Arc, 
SchemaRef)> { @@ -191,13 +193,12 @@ mod tests { let (_, files) = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; let csv = CsvExec::new( - PhysicalPlanConfig { + FileScanConfig { object_store: Arc::new(LocalFileSystem {}), file_schema: aggr_test_schema(), file_groups: files, statistics: Statistics::default(), projection: None, - batch_size: 1024, limit: None, table_partition_cols: vec![], }, @@ -211,6 +212,7 @@ mod tests { #[tokio::test] async fn window_function() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let (input, schema) = create_test_schema(1)?; let window_exec = Arc::new(WindowAggExec::try_new( @@ -247,7 +249,7 @@ mod tests { schema.clone(), )?); - let result: Vec = collect(window_exec).await?; + let result: Vec = collect(window_exec, runtime).await?; assert_eq!(result.len(), 1); let columns = result[0].columns(); @@ -271,6 +273,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -290,7 +293,7 @@ mod tests { schema, )?); - let fut = collect(window_agg_exec); + let fut = collect(window_agg_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/windows/window_agg_exec.rs b/datafusion/src/physical_plan/windows/window_agg_exec.rs index 9c1a83abc98e..379fe2307ed0 100644 --- a/datafusion/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/src/physical_plan/windows/window_agg_exec.rs @@ -18,6 +18,8 @@ //! Stream and channel implementations for window function expressions. use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::SchemaExt; use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, @@ -26,11 +28,11 @@ use crate::physical_plan::{ common, ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; +use crate::record_batch::RecordBatch; use arrow::{ array::ArrayRef, datatypes::{Schema, SchemaRef}, error::{ArrowError, Result as ArrowResult}, - record_batch::RecordBatch, }; use async_trait::async_trait; use futures::stream::Stream; @@ -141,8 +143,12 @@ impl ExecutionPlan for WindowAggExec { } } - async fn execute(&self, partition: usize) -> Result { - let input = self.input.execute(partition).await?; + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { + let input = self.input.execute(partition, runtime).await?; let stream = Box::pin(WindowAggStream::new( self.schema.clone(), self.window_expr.clone(), @@ -268,9 +274,7 @@ impl WindowAggStream { elapsed_compute: crate::physical_plan::metrics::Time, ) -> ArrowResult { let input_schema = input.schema(); - let batches = common::collect(input) - .await - .map_err(DataFusionError::into_arrow_external_error)?; + let batches = common::collect(input).await?; // record compute time on drop let _timer = elapsed_compute.timer(); @@ -278,8 +282,7 @@ impl WindowAggStream { let batch = common::combine_batches(&batches, input_schema.clone())?; if let Some(batch) = batch { // calculate window cols - let mut columns = compute_window_aggregates(window_expr, &batch) - .map_err(DataFusionError::into_arrow_external_error)?; + let mut columns = compute_window_aggregates(window_expr, &batch)?; // combine with the original cols 
// note the setup of window aggregates is that they newly calculated window // expressions are always prepended to the columns diff --git a/datafusion/src/record_batch.rs b/datafusion/src/record_batch.rs new file mode 100644 index 000000000000..8fba09e73e33 --- /dev/null +++ b/datafusion/src/record_batch.rs @@ -0,0 +1,432 @@ +//! Contains [`RecordBatch`]. +use std::sync::Arc; + +use crate::field_util::SchemaExt; +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::compute::filter::{build_filter, filter}; +use arrow::datatypes::*; +use arrow::error::{ArrowError, Result}; + +/// A two-dimensional dataset with a number of +/// columns ([`Array`]) and rows and defined [`Schema`](crate::datatypes::Schema). +/// # Implementation +/// Cloning is `O(C)` where `C` is the number of columns. +#[derive(Clone, Debug, PartialEq)] +pub struct RecordBatch { + schema: Arc, + columns: Vec>, +} + +impl RecordBatch { + /// Creates a [`RecordBatch`] from a schema and columns. + /// # Errors + /// This function errors iff + /// * `columns` is empty + /// * the schema and column data types do not match + /// * `columns` have a different length + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow2::array::PrimitiveArray; + /// # use arrow2::datatypes::{Schema, Field, DataType}; + /// # use arrow2::record_batch::RecordBatch; + /// # fn main() -> arrow2::error::Result<()> { + /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]); + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int32, false) + /// ])); + /// + /// let batch = RecordBatch::try_new( + /// schema, + /// vec![Arc::new(id_array)] + /// )?; + /// # Ok(()) + /// # } + /// ``` + pub fn try_new(schema: Arc, columns: Vec>) -> Result { + let options = RecordBatchOptions::default(); + Self::validate_new_batch(&schema, columns.as_slice(), &options)?; + Ok(RecordBatch { schema, columns }) + } + + /// Creates a [`RecordBatch`] from a schema and columns, with additional options, + /// such as whether to strictly validate field names. + /// + /// See [`Self::try_new()`] for the expected conditions. + pub fn try_new_with_options( + schema: Arc, + columns: Vec>, + options: &RecordBatchOptions, + ) -> Result { + Self::validate_new_batch(&schema, &columns, options)?; + Ok(RecordBatch { schema, columns }) + } + + /// Creates a new empty [`RecordBatch`]. + pub fn new_empty(schema: Arc) -> Self { + let columns = schema + .fields() + .iter() + .map(|field| new_empty_array(field.data_type().clone()).into()) + .collect(); + RecordBatch { schema, columns } + } + + /// Creates a new [`RecordBatch`] from a [`arrow::chunk::Chunk`] + pub fn new_with_chunk(schema: &Arc, chunk: Chunk) -> Self { + Self { + schema: schema.clone(), + columns: chunk.into_arrays(), + } + } + + /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error + /// if any validation check fails. 
+ fn validate_new_batch( + schema: &Schema, + columns: &[Arc], + options: &RecordBatchOptions, + ) -> Result<()> { + // check that there are some columns + if columns.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "at least one column must be defined to create a record batch" + .to_string(), + )); + } + // check that number of fields in schema match column length + if schema.fields().len() != columns.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "number of columns({}) must match number of fields({}) in schema", + columns.len(), + schema.fields().len(), + ))); + } + // check that all columns have the same row count, and match the schema + let len = columns[0].len(); + + // This is a bit repetitive, but it is better to check the condition outside the loop + if options.match_field_names { + for (i, column) in columns.iter().enumerate() { + if column.len() != len { + return Err(ArrowError::InvalidArgumentError( + "all columns in a record batch must have the same length" + .to_string(), + )); + } + if column.data_type() != schema.field(i).data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "column types must match schema types, expected {:?} but found {:?} at column index {}", + schema.field(i).data_type(), + column.data_type(), + i))); + } + } + } else { + for (i, column) in columns.iter().enumerate() { + if column.len() != len { + return Err(ArrowError::InvalidArgumentError( + "all columns in a record batch must have the same length" + .to_string(), + )); + } + if !column.data_type().eq(schema.field(i).data_type()) { + return Err(ArrowError::InvalidArgumentError(format!( + "column types must match schema types, expected {:?} but found {:?} at column index {}", + schema.field(i).data_type(), + column.data_type(), + i))); + } + } + } + + Ok(()) + } + + /// Returns the [`Schema`](crate::datatypes::Schema) of the record batch. + pub fn schema(&self) -> &Arc { + &self.schema + } + + /// Returns the number of columns in the record batch. + /// + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow2::array::PrimitiveArray; + /// # use arrow2::datatypes::{Schema, Field, DataType}; + /// # use arrow2::record_batch::RecordBatch; + /// # fn main() -> arrow2::error::Result<()> { + /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]); + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int32, false) + /// ])); + /// + /// let batch = RecordBatch::try_new(schema, vec![Arc::new(id_array)])?; + /// + /// assert_eq!(batch.num_columns(), 1); + /// # Ok(()) + /// # } + /// ``` + pub fn num_columns(&self) -> usize { + self.columns.len() + } + + /// Returns the number of rows in each column. + /// + /// # Panics + /// + /// Panics if the `RecordBatch` contains no columns. + /// + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow2::array::PrimitiveArray; + /// # use arrow2::datatypes::{Schema, Field, DataType}; + /// # use arrow2::record_batch::RecordBatch; + /// # fn main() -> arrow2::error::Result<()> { + /// let id_array = PrimitiveArray::from_slice([1i32, 2, 3, 4, 5]); + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int32, false) + /// ])); + /// + /// let batch = RecordBatch::try_new(schema, vec![Arc::new(id_array)])?; + /// + /// assert_eq!(batch.num_rows(), 5); + /// # Ok(()) + /// # } + /// ``` + pub fn num_rows(&self) -> usize { + self.columns[0].len() + } + + /// Get a reference to a column's array by index. 
+ /// + /// # Panics + /// + /// Panics if `index` is outside of `0..num_columns`. + pub fn column(&self, index: usize) -> &Arc { + &self.columns[index] + } + + /// Get a reference to all columns in the record batch. + pub fn columns(&self) -> &[Arc] { + &self.columns[..] + } + + /// Create a `RecordBatch` from an iterable list of pairs of the + /// form `(field_name, array)`, with the same requirements on + /// fields and arrays as [`RecordBatch::try_new`]. This method is + /// often used to create a single `RecordBatch` from arrays, + /// e.g. for testing. + /// + /// The resulting schema is marked as nullable for each column if + /// the array for that column is has any nulls. To explicitly + /// specify nullibility, use [`RecordBatch::try_from_iter_with_nullable`] + /// + /// Example: + /// ``` + /// use std::sync::Arc; + /// use arrow::array::*; + /// use arrow::datatypes::DataType; + /// use datafusion::record_batch::RecordBatch; + /// + /// let a: Arc = Arc::new(Int32Array::from_slice(&[1, 2])); + /// let b: Arc = Arc::new(Utf8Array::::from_slice(&["a", "b"])); + /// + /// let record_batch = RecordBatch::try_from_iter(vec![ + /// ("a", a), + /// ("b", b), + /// ]); + /// ``` + pub fn try_from_iter(value: I) -> Result + where + I: IntoIterator)>, + F: AsRef, + { + // TODO: implement `TryFrom` trait, once + // https://github.com/rust-lang/rust/issues/50133 is no longer an + // issue + let iter = value.into_iter().map(|(field_name, array)| { + let nullable = array.null_count() > 0; + (field_name, array, nullable) + }); + + Self::try_from_iter_with_nullable(iter) + } + + /// Create a `RecordBatch` from an iterable list of tuples of the + /// form `(field_name, array, nullable)`, with the same requirements on + /// fields and arrays as [`RecordBatch::try_new`]. This method is often + /// used to create a single `RecordBatch` from arrays, e.g. for + /// testing. 
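As the docs above note, `try_from_iter` derives each field's nullability from the data it is handed, which is usually what a test wants; `try_from_iter_with_nullable` is the explicit override. A short sketch of the inferred behaviour (array contents are illustrative):

```rust
// Sketch: try_from_iter marks a field nullable only if its array actually contains
// nulls; use try_from_iter_with_nullable to override that. Values are illustrative.
use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field};
use datafusion::field_util::SchemaExt;
use datafusion::record_batch::RecordBatch;

fn inferred_nullability() {
    let a: ArrayRef = Arc::new(Int32Array::from_slice([1, 2, 3]));
    let b: ArrayRef = Arc::new(Int32Array::from(&[Some(1), None, Some(3)]));

    let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

    // "a" had no nulls, "b" did, and the resulting schema reflects that.
    assert_eq!(
        batch.schema().fields(),
        &[
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, true),
        ]
    );
}
```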
+ /// + /// Example: + /// ``` + /// use std::sync::Arc; + /// use arrow::array::*; + /// use arrow::datatypes::DataType; + /// use datafusion::record_batch::RecordBatch; + /// + /// let a: Arc = Arc::new(Int32Array::from_slice(&[1, 2])); + /// let b: Arc = Arc::new(Utf8Array::::from_slice(&["a", "b"])); + /// + /// // Note neither `a` nor `b` has any actual nulls, but we mark + /// // b an nullable + /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![ + /// ("a", a, false), + /// ("b", b, true), + /// ]); + /// ``` + pub fn try_from_iter_with_nullable(value: I) -> Result + where + I: IntoIterator, bool)>, + F: AsRef, + { + // TODO: implement `TryFrom` trait, once + // https://github.com/rust-lang/rust/issues/50133 is no longer an + // issue + let (fields, columns) = value + .into_iter() + .map(|(field_name, array, nullable)| { + let field_name = field_name.as_ref(); + let field = Field::new(field_name, array.data_type().clone(), nullable); + (field, array) + }) + .unzip(); + + let schema = Arc::new(Schema::new(fields)); + RecordBatch::try_new(schema, columns) + } + + /// Deconstructs itself into its internal components + pub fn into_inner(self) -> (Vec>, Arc) { + let Self { columns, schema } = self; + (columns, schema) + } + + /// Projects the schema onto the specified columns + pub fn project(&self, indices: &[usize]) -> Result { + let projected_schema = self.schema.project(indices)?; + let batch_fields = indices + .iter() + .map(|f| { + self.columns.get(*f).cloned().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "project index {} out of bounds, max field {}", + f, + self.columns.len() + )) + }) + }) + .collect::>>()?; + + RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields) + } + + /// Return a new RecordBatch where each column is sliced + /// according to `offset` and `length` + /// + /// # Panics + /// + /// Panics if `offset` with `length` is greater than column length. + pub fn slice(&self, offset: usize, length: usize) -> RecordBatch { + if self.schema.fields().is_empty() { + assert!((offset + length) == 0); + return RecordBatch::new_empty(self.schema.clone()); + } + assert!((offset + length) <= self.num_rows()); + + let columns = self + .columns() + .iter() + .map(|column| Arc::from(column.slice(offset, length))) + .collect(); + + Self { + schema: self.schema.clone(), + columns, + } + } +} + +/// Options that control the behaviour used when creating a [`RecordBatch`]. +#[derive(Debug)] +pub struct RecordBatchOptions { + /// Match field names of structs and lists. If set to `true`, the names must match. + pub match_field_names: bool, +} + +impl Default for RecordBatchOptions { + fn default() -> Self { + Self { + match_field_names: true, + } + } +} + +impl From for RecordBatch { + /// # Panics iff the null count of the array is not null. 
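`project` and `slice` (defined just above) are the two cheap reshaping operations on a batch: both hand back views that share the underlying buffers rather than copying them. A hedged sketch, with illustrative column names:

```rust
// Sketch: project keeps a subset of columns, slice keeps a contiguous run of rows;
// both essentially avoid copying the underlying data. Values are illustrative.
use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array};
use datafusion::record_batch::RecordBatch;

fn project_then_slice() {
    let a: ArrayRef = Arc::new(Int32Array::from_slice([1, 2, 3, 4]));
    let b: ArrayRef = Arc::new(Int32Array::from_slice([10, 20, 30, 40]));
    let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

    // Keep only column index 1 ("b"), then rows 1..3 of it.
    let only_b = batch.project(&[1]).unwrap();
    let window = only_b.slice(1, 2);

    assert_eq!(window.num_columns(), 1);
    assert_eq!(window.num_rows(), 2);
}
```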
+ fn from(array: StructArray) -> Self { + assert!(array.null_count() == 0); + let (fields, values, _) = array.into_data(); + RecordBatch { + schema: Arc::new(Schema::new(fields)), + columns: values, + } + } +} + +impl From for StructArray { + fn from(batch: RecordBatch) -> Self { + let (fields, values) = batch + .schema + .fields + .iter() + .zip(batch.columns.iter()) + .map(|t| (t.0.clone(), t.1.clone())) + .unzip(); + StructArray::from_data(DataType::Struct(fields), values, None) + } +} + +impl From for Chunk { + fn from(rb: RecordBatch) -> Self { + Chunk::new(rb.columns) + } +} + +impl From<&RecordBatch> for Chunk { + fn from(rb: &RecordBatch) -> Self { + Chunk::new(rb.columns.clone()) + } +} + +/// Returns a new [RecordBatch] with arrays containing only values matching the filter. +/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered. +/// Therefore, it is considered undefined behavior to pass `filter` with null values. +pub fn filter_record_batch( + record_batch: &RecordBatch, + filter_values: &BooleanArray, +) -> Result { + let num_colums = record_batch.columns().len(); + + let filtered_arrays = match num_colums { + 1 => { + vec![filter(record_batch.columns()[0].as_ref(), filter_values)?.into()] + } + _ => { + let filter = build_filter(filter_values)?; + record_batch + .columns() + .iter() + .map(|a| filter(a.as_ref()).into()) + .collect() + } + }; + RecordBatch::try_new(record_batch.schema().clone(), filtered_arrays) +} diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index ea447a746cc7..0c0afc2f5b3e 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -20,7 +20,7 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; use crate::error::{DataFusionError, Result}; -use crate::field_util::StructArrayExt; +use crate::field_util::{FieldExt, StructArrayExt}; use arrow::bitmap::Bitmap; use arrow::buffer::Buffer; use arrow::compute::concatenate; @@ -103,6 +103,8 @@ pub enum ScalarValue { IntervalYearMonth(Option), /// Interval with DayTime unit IntervalDayTime(Option), + /// Interval with MonthDayNano unit + IntervalMonthDayNano(Option), /// struct of nested ScalarValue (boxed to reduce size_of(ScalarValue)) #[allow(clippy::box_collection)] Struct(Option>>, Box>), @@ -177,6 +179,8 @@ impl PartialEq for ScalarValue { (IntervalYearMonth(_), _) => false, (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), (IntervalDayTime(_), _) => false, + (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), + (IntervalMonthDayNano(_), _) => false, (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), (Struct(_, _), _) => false, } @@ -269,6 +273,8 @@ impl PartialOrd for ScalarValue { (IntervalYearMonth(_), _) => None, (_, IntervalDayTime(_)) => None, (IntervalDayTime(_), _) => None, + (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2), + (IntervalMonthDayNano(_), _) => None, (Struct(v1, t1), Struct(v2, t2)) => { if t1.eq(t2) { v1.partial_cmp(v2) @@ -327,6 +333,7 @@ impl std::hash::Hash for ScalarValue { TimestampNanosecond(v, _) => v.hash(state), IntervalYearMonth(v) => v.hash(state), IntervalDayTime(v) => v.hash(state), + IntervalMonthDayNano(v) => v.hash(state), Struct(v, t) => { v.hash(state); t.hash(state); @@ -506,320 +513,6 @@ macro_rules! 
eq_array_primitive { } impl ScalarValue { - /// Return true if the value is numeric - pub fn is_numeric(&self) -> bool { - matches!( - self, - ScalarValue::Float32(_) - | ScalarValue::Float64(_) - | ScalarValue::Decimal128(_, _, _) - | ScalarValue::Int8(_) - | ScalarValue::Int16(_) - | ScalarValue::Int32(_) - | ScalarValue::Int64(_) - | ScalarValue::UInt8(_) - | ScalarValue::UInt16(_) - | ScalarValue::UInt32(_) - | ScalarValue::UInt64(_) - ) - } - - /// Add two numeric ScalarValues - pub fn add(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - if !lhs.is_numeric() || !rhs.is_numeric() { - return Err(DataFusionError::Internal(format!( - "Addition only supports numeric types, \ - here has {:?} and {:?}", - lhs.get_datatype(), - rhs.get_datatype() - ))); - } - - if lhs.is_null() || rhs.is_null() { - return Err(DataFusionError::Internal( - "Addition does not support empty values".to_string(), - )); - } - - // TODO: Finding a good way to support operation between different types without - // writing a hige match block. - // TODO: Add support for decimal types - match (lhs, rhs) { - (ScalarValue::Decimal128(_, _, _), _) | - (_, ScalarValue::Decimal128(_, _, _)) => { - Err(DataFusionError::Internal( - "Addition with Decimals are not supported for now".to_string() - )) - }, - // f64 / _ - (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() + f2.unwrap()))) - }, - // f32 / _ - (ScalarValue::Float32(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) - }, - (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap() as f64))) - }, - // i64 / _ - (ScalarValue::Int64(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) - }, - (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { - Ok(ScalarValue::Int64(Some(f1.unwrap() + f2.unwrap()))) - }, - // i32 / _ - (ScalarValue::Int32(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) - }, - (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => { - Ok(ScalarValue::Int64(Some(f1.unwrap() as i64 + f2.unwrap() as i64))) - }, - // i16 / _ - (ScalarValue::Int16(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) - }, - (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => { - Ok(ScalarValue::Int32(Some(f1.unwrap() as i32 + f2.unwrap() as i32))) - }, - // i8 / _ - (ScalarValue::Int8(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) - }, - (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => { - Ok(ScalarValue::Int16(Some(f1.unwrap() as i16 + f2.unwrap() as i16))) - }, - // u64 / _ - (ScalarValue::UInt64(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) - }, - (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => { - Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 + f2.unwrap() as u64))) - }, - // u32 / _ - (ScalarValue::UInt32(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) - }, - (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => { - Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 + f2.unwrap() as u64))) - }, - // u16 / _ - (ScalarValue::UInt16(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) - }, - (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) 
=> { - Ok(ScalarValue::UInt32(Some(f1.unwrap() as u32 + f2.unwrap() as u32))) - }, - // u8 / _ - (ScalarValue::UInt8(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) - }, - (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { - Ok(ScalarValue::UInt16(Some(f1.unwrap() as u16 + f2.unwrap() as u16))) - }, - _ => Err(DataFusionError::Internal( - format!( - "Addition only support calculation with the same type or f64 as one of the numbers for now, here has {:?} and {:?}", - lhs.get_datatype(), rhs.get_datatype() - ))), - } - } - - /// Multiply two numeric ScalarValues - pub fn mul(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - if !lhs.is_numeric() || !rhs.is_numeric() { - return Err(DataFusionError::Internal(format!( - "Multiplication is only supported on numeric types, \ - here has {:?} and {:?}", - lhs.get_datatype(), - rhs.get_datatype() - ))); - } - - if lhs.is_null() || rhs.is_null() { - return Err(DataFusionError::Internal( - "Multiplication does not support empty values".to_string(), - )); - } - - // TODO: Finding a good way to support operation between different types without - // writing a hige match block. - // TODO: Add support for decimal type - match (lhs, rhs) { - (ScalarValue::Decimal128(_, _, _), _) - | (_, ScalarValue::Decimal128(_, _, _)) => Err(DataFusionError::Internal( - "Multiplication with Decimals are not supported for now".to_string(), - )), - // f64 / _ - (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() * f2.unwrap()))) - } - // f32 / _ - (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => Ok( - ScalarValue::Float64(Some(f1.unwrap() as f64 * f2.unwrap() as f64)), - ), - // i64 / _ - (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { - Ok(ScalarValue::Int64(Some(f1.unwrap() * f2.unwrap()))) - } - // i32 / _ - (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => Ok(ScalarValue::Int64( - Some(f1.unwrap() as i64 * f2.unwrap() as i64), - )), - // i16 / _ - (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => Ok(ScalarValue::Int32( - Some(f1.unwrap() as i32 * f2.unwrap() as i32), - )), - // i8 / _ - (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => Ok(ScalarValue::Int16( - Some(f1.unwrap() as i16 * f2.unwrap() as i16), - )), - // u64 / _ - (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => Ok( - ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64)), - ), - // u32 / _ - (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => Ok( - ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64)), - ), - // u16 / _ - (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => Ok( - ScalarValue::UInt32(Some(f1.unwrap() as u32 * f2.unwrap() as u32)), - ), - // u8 / _ - (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => Ok(ScalarValue::UInt16( - Some(f1.unwrap() as u16 * f2.unwrap() as u16), - )), - _ => Err(DataFusionError::Internal(format!( - "Multiplication only support f64 for now, here has {:?} and {:?}", - lhs.get_datatype(), - rhs.get_datatype() - ))), - } - } - - /// Division between two numeric ScalarValues - pub fn div(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - if !lhs.is_numeric() || !rhs.is_numeric() { - return Err(DataFusionError::Internal(format!( - "Division is only supported on numeric types, \ - here has {:?} and {:?}", - lhs.get_datatype(), - rhs.get_datatype() - ))); - } - - if lhs.is_null() || rhs.is_null() { - return Err(DataFusionError::Internal( - "Division does not support empty values".to_string(), - )); - } - - // TODO: 
Finding a good way to support operation between different types without - // writing a hige match block. - // TODO: Add support for decimal types - match (lhs, rhs) { - (ScalarValue::Decimal128(_, _, _), _) | - (_, ScalarValue::Decimal128(_, _, _)) => { - Err(DataFusionError::Internal( - "Division with Decimals are not supported for now".to_string() - )) - }, - // f64 / _ - (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() / f2.unwrap()))) - }, - // f32 / _ - (ScalarValue::Float32(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64/ f2.unwrap()))) - }, - (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64/ f2.unwrap() as f64))) - }, - // i64 / _ - (ScalarValue::Int64(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) - }, - (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) - }, - // i32 / _ - (ScalarValue::Int32(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) - }, - (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) - }, - // i16 / _ - (ScalarValue::Int16(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) - }, - (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) - }, - // i8 / _ - (ScalarValue::Int8(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) - }, - (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) - }, - // u64 / _ - (ScalarValue::UInt64(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) - }, - (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) - }, - // u32 / _ - (ScalarValue::UInt32(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) - }, - (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) - }, - // u16 / _ - (ScalarValue::UInt16(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) - }, - (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) - }, - // u8 / _ - (ScalarValue::UInt8(f1), ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) - }, - (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) - }, - _ => Err(DataFusionError::Internal( - format!( - "Division only support calculation with the same type or f64 as denominator for now, here has {:?} and {:?}", - lhs.get_datatype(), rhs.get_datatype() - ))), - } - } - - /// Create null scalar value for specific data type. 
- pub fn new_null(dt: DataType) -> Self { - match dt { - DataType::Timestamp(TimeUnit::Second, _) => { - ScalarValue::TimestampSecond(None, None) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - ScalarValue::TimestampMillisecond(None, None) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - ScalarValue::TimestampMicrosecond(None, None) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - ScalarValue::TimestampNanosecond(None, None) - } - _ => todo!("Create null scalar value for datatype: {:?}", dt), - } - } - /// Create a decimal Scalar from value/precision and scale. pub fn try_new_decimal128( value: i128, @@ -880,6 +573,9 @@ impl ScalarValue { DataType::Interval(IntervalUnit::YearMonth) } ScalarValue::IntervalDayTime(_) => DataType::Interval(IntervalUnit::DayTime), + ScalarValue::IntervalMonthDayNano(_) => { + DataType::Interval(IntervalUnit::MonthDayNano) + } ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), } } @@ -973,7 +669,7 @@ impl ScalarValue { /// ``` pub fn iter_to_array( scalars: impl IntoIterator, - ) -> Result> { + ) -> Result { let mut scalars = scalars.into_iter().peekable(); // figure out the type based on the first element @@ -991,7 +687,7 @@ impl ScalarValue { macro_rules! build_array_primitive { ($TY:ty, $SCALAR_TY:ident, $DT:ident) => {{ { - Box::new(scalars + Arc::new(scalars .map(|sv| { if let ScalarValue::$SCALAR_TY(v) = sv { Ok(v) @@ -1003,7 +699,7 @@ impl ScalarValue { ))) } }).collect::>>()?.to($DT) - ) as Box + ) as Arc } }}; } @@ -1025,7 +721,7 @@ impl ScalarValue { }) .collect::>()?; - Box::new(array) + Arc::new(array) } }}; } @@ -1048,7 +744,7 @@ impl ScalarValue { } }) .collect::>()?; - Box::new(array) + Arc::new(array) } }}; } @@ -1087,18 +783,18 @@ impl ScalarValue { } let array: ListArray = array.into(); - Box::new(array) + Arc::new(array) }} } use DataType::*; - let array: Box = match &data_type { + let array: Arc = match &data_type { DataType::Decimal(precision, scale) => { let decimal_array = ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; - Box::new(decimal_array) + Arc::new(decimal_array) } - DataType::Boolean => Box::new( + DataType::Boolean => Arc::new( scalars .map(|sv| { if let ScalarValue::Boolean(v) = sv { @@ -1190,7 +886,7 @@ impl ScalarValue { DataType::List(_) => { // Fallback case handling homogeneous lists with any ScalarValue element type let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?; - Box::new(list_array) + Arc::new(list_array) } DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column @@ -1231,7 +927,7 @@ impl ScalarValue { .map(|c| Self::iter_to_array(c.clone()).map(Arc::from)) .collect::>>()?; - Box::new(StructArray::from_data(data_type, field_values, None)) + Arc::new(StructArray::from_data(data_type, field_values, None)) } _ => { return Err(DataFusionError::Internal(format!( @@ -1285,7 +981,7 @@ impl ScalarValue { flat_len += element_array.len() as i32; offsets.push(flat_len); - elements.push(element_array.into()); + elements.push(element_array); // Element is valid valid.push(true); @@ -1359,6 +1055,10 @@ impl ScalarValue { Some(value) => dyn_to_array!(self, value, size, i32), None => new_null_array(self.get_datatype(), size).into(), }, + ScalarValue::IntervalMonthDayNano(e) => match e { + Some(value) => dyn_to_array!(self, value, size, i128), + None => new_null_array(self.get_datatype(), size).into(), + }, ScalarValue::Int64(e) | ScalarValue::Date64(e) => match e { Some(value) => dyn_to_array!(self, 
value, size, i64), None => new_null_array(self.get_datatype(), size).into(), @@ -1704,6 +1404,9 @@ impl ScalarValue { ScalarValue::IntervalDayTime(val) => { eq_array_primitive!(array, index, DaysMsArray, val) } + ScalarValue::IntervalMonthDayNano(val) => { + eq_array_primitive!(array, index, Int128Array, val) + } ScalarValue::Struct(_, _) => unimplemented!(), } } @@ -1856,6 +1559,22 @@ impl TryFrom for i64 { } } +// special implementation for i128 because of Decimal128 +impl TryFrom for i128 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Decimal128(Some(inner_value), _, _) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + impl_try_from!(UInt8, u8); impl_try_from!(UInt16, u16); impl_try_from!(UInt32, u32); @@ -2105,6 +1824,7 @@ impl fmt::Display for ScalarValue { ScalarValue::Date64(e) => format_option!(f, e)?, ScalarValue::IntervalDayTime(e) => format_option!(f, e)?, ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?, + ScalarValue::IntervalMonthDayNano(e) => format_option!(f, e)?, ScalarValue::Struct(e, fields) => match e { Some(l) => write!( f, @@ -2123,7 +1843,7 @@ impl fmt::Display for ScalarValue { } impl fmt::Debug for ScalarValue { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({})", self), ScalarValue::Boolean(_) => write!(f, "Boolean({})", self), @@ -2166,6 +1886,9 @@ impl fmt::Debug for ScalarValue { ScalarValue::IntervalYearMonth(_) => { write!(f, "IntervalYearMonth(\"{}\")", self) } + ScalarValue::IntervalMonthDayNano(_) => { + write!(f, "IntervalMonthDayNano(\"{}\")", self) + } ScalarValue::Struct(e, fields) => { // Use Debug representation of field values match e { @@ -2194,6 +1917,8 @@ mod tests { fn scalar_decimal_test() { let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1); assert_eq!(DataType::Decimal(10, 1), decimal_value.get_datatype()); + let try_into_value: i128 = decimal_value.clone().try_into().unwrap(); + assert_eq!(123_i128, try_into_value); assert!(!decimal_value.is_null()); let neg_decimal_value = decimal_value.arithmetic_negate(); match neg_decimal_value { @@ -2262,9 +1987,8 @@ mod tests { ScalarValue::Decimal128(Some(3), 10, 2), ScalarValue::Decimal128(None, 10, 2), ]; - let array: ArrayRef = ScalarValue::iter_to_array(decimal_vec.into_iter()) - .unwrap() - .into(); + let array: ArrayRef = + ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap(); assert_eq!(4, array.len()); assert_eq!(DataType::Decimal(10, 2), array.data_type().clone()); @@ -2371,7 +2095,7 @@ mod tests { let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); - let expected = $ARRAYTYPE::from($INPUT).as_box(); + let expected = $ARRAYTYPE::from($INPUT).as_arc(); assert_eq!(&array, &expected); }}; @@ -2388,7 +2112,7 @@ mod tests { let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); - let expected: Box = Box::new(Int64Array::from($INPUT)); + let expected: Arc = Arc::new(Int64Array::from($INPUT)); assert_eq!(&array, &expected); }}; @@ -2405,7 +2129,7 @@ mod tests { let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); - let expected: Box = Box::new($ARRAYTYPE::from($INPUT)); + let expected: Arc = Arc::new($ARRAYTYPE::from($INPUT)); assert_eq!(&array, &expected); }}; @@ -2425,7 +2149,7 @@ mod tests { let expected: 
$ARRAYTYPE = $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect(); - let expected: Box = Box::new(expected); + let expected: Arc = Arc::new(expected); assert_eq!(&array, &expected); }}; @@ -2995,7 +2719,7 @@ mod tests { ), ]), ]; - let array: ArrayRef = ScalarValue::iter_to_array(scalars).unwrap().into(); + let array: ArrayRef = ScalarValue::iter_to_array(scalars).unwrap(); let expected = Arc::new(struct_array_from(vec![ (field_a, Int32Vec::from_slice(vec![23, 7, -1000]).as_arc()), @@ -3235,7 +2959,7 @@ mod tests { .try_push(Some(vec![Some(vec![Some(9)])])) .unwrap(); - let expected = outer_builder.as_box(); + let expected = outer_builder.as_arc(); assert_eq!(&array, &expected); } @@ -3265,245 +2989,4 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned())) ); } - - macro_rules! test_scalar_op { - ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident, $RESULT:expr, $RESULT_TYPE:ident) => {{ - let v1 = &ScalarValue::from($LHS as $LHS_TYPE); - let v2 = &ScalarValue::from($RHS as $RHS_TYPE); - assert_eq!( - ScalarValue::$OP(v1, v2).unwrap(), - ScalarValue::from($RESULT as $RESULT_TYPE) - ); - }}; - } - - macro_rules! test_scalar_op_err { - ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident) => {{ - let v1 = &ScalarValue::from($LHS as $LHS_TYPE); - let v2 = &ScalarValue::from($RHS as $RHS_TYPE); - let actual = ScalarValue::$OP(v1, v2).is_err(); - assert!(actual); - }}; - } - - #[test] - fn scalar_addition() { - test_scalar_op!(add, 1, f64, 2, f64, 3, f64); - test_scalar_op!(add, 1, f32, 2, f32, 3, f64); - test_scalar_op!(add, 1, i64, 2, i64, 3, i64); - test_scalar_op!(add, 100, i64, -32, i64, 68, i64); - test_scalar_op!(add, -102, i64, 32, i64, -70, i64); - test_scalar_op!(add, 1, i32, 2, i32, 3, i64); - test_scalar_op!( - add, - std::i32::MAX, - i32, - std::i32::MAX, - i32, - std::i32::MAX as i64 * 2, - i64 - ); - test_scalar_op!(add, 1, i16, 2, i16, 3, i32); - test_scalar_op!( - add, - std::i16::MAX, - i16, - std::i16::MAX, - i16, - std::i16::MAX as i32 * 2, - i32 - ); - test_scalar_op!(add, 1, i8, 2, i8, 3, i16); - test_scalar_op!( - add, - std::i8::MAX, - i8, - std::i8::MAX, - i8, - std::i8::MAX as i16 * 2, - i16 - ); - test_scalar_op!(add, 1, u64, 2, u64, 3, u64); - test_scalar_op!(add, 1, u32, 2, u32, 3, u64); - test_scalar_op!( - add, - std::u32::MAX, - u32, - std::u32::MAX, - u32, - std::u32::MAX as u64 * 2, - u64 - ); - test_scalar_op!(add, 1, u16, 2, u16, 3, u32); - test_scalar_op!( - add, - std::u16::MAX, - u16, - std::u16::MAX, - u16, - std::u16::MAX as u32 * 2, - u32 - ); - test_scalar_op!(add, 1, u8, 2, u8, 3, u16); - test_scalar_op!( - add, - std::u8::MAX, - u8, - std::u8::MAX, - u8, - std::u8::MAX as u16 * 2, - u16 - ); - test_scalar_op_err!(add, 1, i32, 2, u16); - test_scalar_op_err!(add, 1, i32, 2, u16); - - let v1 = &ScalarValue::from(1); - let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); - assert!(ScalarValue::add(v1, v2).is_err()); - - let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); - let v2 = &ScalarValue::from(2); - assert!(ScalarValue::add(v1, v2).is_err()); - - let v1 = &ScalarValue::Float32(None); - let v2 = &ScalarValue::from(2); - assert!(ScalarValue::add(v1, v2).is_err()); - - let v2 = &ScalarValue::Float32(None); - let v1 = &ScalarValue::from(2); - assert!(ScalarValue::add(v1, v2).is_err()); - - let v1 = &ScalarValue::Float32(None); - let v2 = &ScalarValue::Float32(None); - assert!(ScalarValue::add(v1, v2).is_err()); - } - - #[test] - fn scalar_multiplication() { - test_scalar_op!(mul, 1, f64, 2, f64, 2, 
f64); - test_scalar_op!(mul, 1, f32, 2, f32, 2, f64); - test_scalar_op!(mul, 15, i64, 2, i64, 30, i64); - test_scalar_op!(mul, 100, i64, -32, i64, -3200, i64); - test_scalar_op!(mul, -1.1, f64, 2, f64, -2.2, f64); - test_scalar_op!(mul, 1, i32, 2, i32, 2, i64); - test_scalar_op!( - mul, - std::i32::MAX, - i32, - std::i32::MAX, - i32, - std::i32::MAX as i64 * std::i32::MAX as i64, - i64 - ); - test_scalar_op!(mul, 1, i16, 2, i16, 2, i32); - test_scalar_op!( - mul, - std::i16::MAX, - i16, - std::i16::MAX, - i16, - std::i16::MAX as i32 * std::i16::MAX as i32, - i32 - ); - test_scalar_op!(mul, 1, i8, 2, i8, 2, i16); - test_scalar_op!( - mul, - std::i8::MAX, - i8, - std::i8::MAX, - i8, - std::i8::MAX as i16 * std::i8::MAX as i16, - i16 - ); - test_scalar_op!(mul, 1, u64, 2, u64, 2, u64); - test_scalar_op!(mul, 1, u32, 2, u32, 2, u64); - test_scalar_op!( - mul, - std::u32::MAX, - u32, - std::u32::MAX, - u32, - std::u32::MAX as u64 * std::u32::MAX as u64, - u64 - ); - test_scalar_op!(mul, 1, u16, 2, u16, 2, u32); - test_scalar_op!( - mul, - std::u16::MAX, - u16, - std::u16::MAX, - u16, - std::u16::MAX as u32 * std::u16::MAX as u32, - u32 - ); - test_scalar_op!(mul, 1, u8, 2, u8, 2, u16); - test_scalar_op!( - mul, - std::u8::MAX, - u8, - std::u8::MAX, - u8, - std::u8::MAX as u16 * std::u8::MAX as u16, - u16 - ); - test_scalar_op_err!(mul, 1, i32, 2, u16); - test_scalar_op_err!(mul, 1, i32, 2, u16); - - let v1 = &ScalarValue::from(1); - let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); - assert!(ScalarValue::mul(v1, v2).is_err()); - - let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); - let v2 = &ScalarValue::from(2); - assert!(ScalarValue::mul(v1, v2).is_err()); - - let v1 = &ScalarValue::Float32(None); - let v2 = &ScalarValue::from(2); - assert!(ScalarValue::mul(v1, v2).is_err()); - - let v2 = &ScalarValue::Float32(None); - let v1 = &ScalarValue::from(2); - assert!(ScalarValue::mul(v1, v2).is_err()); - - let v1 = &ScalarValue::Float32(None); - let v2 = &ScalarValue::Float32(None); - assert!(ScalarValue::mul(v1, v2).is_err()); - } - - #[test] - fn scalar_division() { - test_scalar_op!(div, 1, f64, 2, f64, 0.5, f64); - test_scalar_op!(div, 1, f32, 2, f32, 0.5, f64); - test_scalar_op!(div, 15, i64, 2, i64, 7.5, f64); - test_scalar_op!(div, 100, i64, -2, i64, -50, f64); - test_scalar_op!(div, 1, i32, 2, i32, 0.5, f64); - test_scalar_op!(div, 1, i16, 2, i16, 0.5, f64); - test_scalar_op!(div, 1, i8, 2, i8, 0.5, f64); - test_scalar_op!(div, 1, u64, 2, u64, 0.5, f64); - test_scalar_op!(div, 1, u32, 2, u32, 0.5, f64); - test_scalar_op!(div, 1, u16, 2, u16, 0.5, f64); - test_scalar_op!(div, 1, u8, 2, u8, 0.5, f64); - test_scalar_op_err!(div, 1, i32, 2, u16); - test_scalar_op_err!(div, 1, i32, 2, u16); - - let v1 = &ScalarValue::from(1); - let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); - assert!(ScalarValue::div(v1, v2).is_err()); - - let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); - let v2 = &ScalarValue::from(2); - assert!(ScalarValue::div(v1, v2).is_err()); - - let v1 = &ScalarValue::Float32(None); - let v2 = &ScalarValue::from(2); - assert!(ScalarValue::div(v1, v2).is_err()); - - let v2 = &ScalarValue::Float32(None); - let v1 = &ScalarValue::from(2); - assert!(ScalarValue::div(v1, v2).is_err()); - - let v1 = &ScalarValue::Float32(None); - let v2 = &ScalarValue::Float32(None); - assert!(ScalarValue::div(v1, v2).is_err()); - } } diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 8a01287294ba..4d99a056cc10 100644 --- a/datafusion/src/sql/planner.rs +++ 
b/datafusion/src/sql/planner.rs @@ -49,6 +49,7 @@ use crate::{ use arrow::datatypes::*; use arrow::types::days_ms; +use crate::field_util::SchemaExt; use hashbrown::HashMap; use sqlparser::ast::{ BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg, @@ -218,6 +219,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Process CTEs from top to bottom // do not allow self-references for cte in &with.cte_tables { + // A `WITH` block can't use the same name for many times + let cte_name: &str = cte.alias.name.value.as_ref(); + if ctes.contains_key(cte_name) { + return Err(DataFusionError::SQL(ParserError(format!( + "WITH query name {:?} specified more than once", + cte_name + )))); + } // create logical plan & pass backreferencing CTEs let logical_plan = self.query_to_plan_with_alias( &cte.query, @@ -699,7 +708,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { alias: Option, ) -> Result { let plans = self.plan_from_tables(&select.from, ctes)?; - let plan = match &select.selection { Some(predicate_expr) => { // build join schema @@ -716,33 +724,80 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?; let mut all_join_keys = HashSet::new(); - let mut left = plans[0].clone(); - for right in plans.iter().skip(1) { - let left_schema = left.schema(); - let right_schema = right.schema(); + + let mut plans = plans.into_iter(); + let mut left = plans.next().unwrap(); // have at least one plan + + // List of the plans that have not yet been joined + let mut remaining_plans: Vec> = + plans.into_iter().map(Some).collect(); + + // Take from the list of remaining plans, + loop { let mut join_keys = vec![]; - for (l, r) in &possible_join_keys { - if left_schema.field_from_column(l).is_ok() - && right_schema.field_from_column(r).is_ok() - { - join_keys.push((l.clone(), r.clone())); - } else if left_schema.field_from_column(r).is_ok() - && right_schema.field_from_column(l).is_ok() - { - join_keys.push((r.clone(), l.clone())); - } - } + + // Search all remaining plans for the next to + // join. Prefer the first one that has a join + // predicate in the predicate lists + let plan_with_idx = + remaining_plans.iter().enumerate().find(|(_idx, plan)| { + // skip plans that have been joined already + let plan = if let Some(plan) = plan { + plan + } else { + return false; + }; + + // can we find a match? + let left_schema = left.schema(); + let right_schema = plan.schema(); + for (l, r) in &possible_join_keys { + if left_schema.field_from_column(l).is_ok() + && right_schema.field_from_column(r).is_ok() + { + join_keys.push((l.clone(), r.clone())); + } else if left_schema.field_from_column(r).is_ok() + && right_schema.field_from_column(l).is_ok() + { + join_keys.push((r.clone(), l.clone())); + } + } + // stop if we found join keys + !join_keys.is_empty() + }); + + // If we did not find join keys, either there are + // no more plans, or we can't find any plans that + // can be joined with predicates if join_keys.is_empty() { - left = - LogicalPlanBuilder::from(left).cross_join(right)?.build()?; + assert!(plan_with_idx.is_none()); + + // pick the first non null plan to join + let plan_with_idx = remaining_plans + .iter() + .enumerate() + .find(|(_idx, plan)| plan.is_some()); + if let Some((idx, _)) = plan_with_idx { + let plan = std::mem::take(&mut remaining_plans[idx]).unwrap(); + left = LogicalPlanBuilder::from(left) + .cross_join(&plan)? 
+ .build()?; + } else { + // no more plans to join + break; + } } else { + // have a plan + let (idx, _) = plan_with_idx.expect("found plan node"); + let plan = std::mem::take(&mut remaining_plans[idx]).unwrap(); + let left_keys: Vec = join_keys.iter().map(|(l, _)| l.clone()).collect(); let right_keys: Vec = join_keys.iter().map(|(_, r)| r.clone()).collect(); let builder = LogicalPlanBuilder::from(left); left = builder - .join(right, JoinType::Inner, (left_keys, right_keys))? + .join(&plan, JoinType::Inner, (left_keys, right_keys))? .build()?; } @@ -1148,14 +1203,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .iter() .find(|field| match field.qualifier() { Some(field_q) => { - field.name() == &col.name + field.name() == col.name && field_q.ends_with(&format!(".{}", q)) } _ => false, }) { Some(df_field) => Expr::Column(Column { relation: df_field.qualifier().cloned(), - name: df_field.name().clone(), + name: df_field.name().to_string(), }), None => Expr::Column(col), } @@ -1490,6 +1545,54 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ref right, } => self.parse_sql_binary_op(left, op, right, schema), + SQLExpr::Substring { + expr, + substring_from, + substring_for, + } => { + #[cfg(feature = "unicode_expressions")] + { + let arg = self.sql_expr_to_logical_expr(expr, schema)?; + let args = match (substring_from, substring_for) { + (Some(from_expr), Some(for_expr)) => { + let from_logic = + self.sql_expr_to_logical_expr(from_expr, schema)?; + let for_logic = + self.sql_expr_to_logical_expr(for_expr, schema)?; + vec![arg, from_logic, for_logic] + } + (Some(from_expr), None) => { + let from_logic = + self.sql_expr_to_logical_expr(from_expr, schema)?; + vec![arg, from_logic] + } + (None, Some(for_expr)) => { + let from_logic = Expr::Literal(ScalarValue::Int64(Some(1))); + let for_logic = + self.sql_expr_to_logical_expr(for_expr, schema)?; + vec![arg, from_logic, for_logic] + } + _ => { + return Err(DataFusionError::Plan(format!( + "Substring without for/from is not valid {:?}", + sql + ))) + } + }; + Ok(Expr::ScalarFunction { + fun: functions::BuiltinScalarFunction::Substr, + args, + }) + } + + #[cfg(not(feature = "unicode_expressions"))] + { + Err(DataFusionError::Internal( + "statement substring requires compilation with feature flag: unicode_expressions.".to_string() + )) + } + } + SQLExpr::Trim { expr, trim_where } => { let (fun, where_expr) = match trim_where { Some((TrimWhereField::Leading, expr)) => { @@ -1820,16 +1923,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Interval is tricky thing // 1 day is not 24 hours because timezones, 1 year != 365/364! 30 days != 1 month // The true way to store and calculate intervals is to store it as it defined - // Due the fact that Arrow supports only two types YearMonth (month) and DayTime (day, time) - // It's not possible to store complex intervals - // It's possible to do select (NOW() + INTERVAL '1 year') + INTERVAL '1 day'; as workaround + // It's why we there are 3 different interval types in Arrow if result_month != 0 && (result_days != 0 || result_millis != 0) { - return Err(DataFusionError::NotImplemented(format!( - "DF does not support intervals that have both a Year/Month part as well as Days/Hours/Mins/Seconds: {:?}. Hint: try breaking the interval into two parts, one with Year/Month and the other with Days/Hours/Mins/Seconds - e.g. 
(NOW() + INTERVAL '1 year') + INTERVAL '1 day'", - value - ))); + let result: i128 = ((result_month as i128) << 96) + | ((result_days as i128) << 64) + // IntervalMonthDayNano uses nanos, but IntervalDayTime uses milles + | ((result_millis * 1_000_000_i64) as i128); + + return Ok(Expr::Literal(ScalarValue::IntervalMonthDayNano(Some( + result, + )))); } + // Month interval if result_month != 0 { return Ok(Expr::Literal(ScalarValue::IntervalYearMonth(Some( result_month as i32, @@ -2766,16 +2872,6 @@ mod tests { ); } - #[test] - fn select_unsupported_complex_interval() { - let sql = "SELECT INTERVAL '1 year 1 day'"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert!(matches!( - err, - DataFusionError::NotImplemented(msg) if msg == "DF does not support intervals that have both a Year/Month part as well as Days/Hours/Mins/Seconds: \"1 year 1 day\". Hint: try breaking the interval into two parts, one with Year/Month and the other with Days/Hours/Mins/Seconds - e.g. (NOW() + INTERVAL '1 year') + INTERVAL '1 day'", - )); - } - #[test] fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( @@ -3820,6 +3916,39 @@ mod tests { \n TableScan: public.person projection=None"; quick_test(sql, expected); } + + #[test] + fn cross_join_to_inner_join() { + let sql = "select person.id from person, orders, lineitem where person.id = lineitem.l_item_id and orders.o_item_id = lineitem.l_description;"; + let expected = "Projection: #person.id\ + \n Join: #lineitem.l_description = #orders.o_item_id\ + \n Join: #person.id = #lineitem.l_item_id\ + \n TableScan: person projection=None\ + \n TableScan: lineitem projection=None\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn cross_join_not_to_inner_join() { + let sql = "select person.id from person, orders, lineitem where person.id = person.age;"; + let expected = "Projection: #person.id\ + \n Filter: #person.id = #person.age\ + \n CrossJoin:\ + \n CrossJoin:\ + \n TableScan: person projection=None\ + \n TableScan: orders projection=None\ + \n TableScan: lineitem projection=None"; + quick_test(sql, expected); + } + + #[test] + fn cte_use_same_name_multiple_times() { + let sql = "with a as (select * from person), a as (select * from orders) select * from a;"; + let expected = "SQL error: ParserError(\"WITH query name \\\"a\\\" specified more than once\")"; + let result = logical_plan(sql).err().unwrap(); + assert_eq!(expected, format!("{}", result)); + } } fn parse_sql_number(n: &str) -> Result { diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs index 363ab5a7d366..2a74b214ffcd 100644 --- a/datafusion/src/test/exec.rs +++ b/datafusion/src/test/exec.rs @@ -26,13 +26,15 @@ use std::{ }; use tokio::sync::Barrier; +use crate::record_batch::RecordBatch; use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, error::{ArrowError, Result as ArrowResult}, - record_batch::RecordBatch, }; use futures::Stream; +use crate::execution::runtime_env::RuntimeEnv; +use crate::field_util::SchemaExt; use crate::physical_plan::{ common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -162,7 +164,11 @@ impl ExecutionPlan for MockExec { } /// Returns a stream which yields data - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { assert_eq!(partition, 0); // Result doesn't implement clone, so do it ourself @@ -293,7 
+299,11 @@ impl ExecutionPlan for BarrierExec { } /// Returns a stream which yields data - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { assert!(partition < self.data.len()); let (tx, rx) = tokio::sync::mpsc::channel(2); @@ -342,6 +352,13 @@ impl ExecutionPlan for BarrierExec { pub struct ErrorExec { schema: SchemaRef, } + +impl Default for ErrorExec { + fn default() -> Self { + Self::new() + } +} + impl ErrorExec { pub fn new() -> Self { let schema = Arc::new(Schema::new(vec![Field::new( @@ -379,7 +396,11 @@ impl ExecutionPlan for ErrorExec { } /// Returns a stream which yields data - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { Err(DataFusionError::Internal(format!( "ErrorExec, unsurprisingly, errored in partition {}", partition @@ -456,7 +477,11 @@ impl ExecutionPlan for StatisticsExec { } } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { unimplemented!("This plan only serves for testing statistics") } @@ -546,7 +571,11 @@ impl ExecutionPlan for BlockingExec { ))) } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { Ok(Box::pin(BlockingStream { schema: Arc::clone(&self.schema), _refs: Arc::clone(&self.refs), diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs index dce8d9b6d48d..6e0e44e5e147 100644 --- a/datafusion/src/test/mod.rs +++ b/datafusion/src/test/mod.rs @@ -20,10 +20,11 @@ use crate::datasource::object_store::local::local_unpartitioned_file; use crate::datasource::{MemTable, PartitionedFile, TableProvider}; use crate::error::Result; +use crate::field_util::{FieldExt, SchemaExt}; use crate::logical_plan::{LogicalPlan, LogicalPlanBuilder}; +use crate::record_batch::RecordBatch; use arrow::array::*; use arrow::datatypes::*; -use arrow::record_batch::RecordBatch; use futures::{Future, FutureExt}; use std::fs::File; use std::io::prelude::*; @@ -120,7 +121,7 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { .schema() .fields() .iter() - .map(|f| f.name().clone()) + .map(|f| f.name().to_string()) .collect(); assert_eq!(actual, expected); } @@ -150,7 +151,11 @@ pub fn build_table_i32( /// Returns the column names on the schema pub fn columns(schema: &Schema) -> Vec { - schema.fields().iter().map(|f| f.name().clone()).collect() + schema + .fields() + .iter() + .map(|f| f.name().to_string()) + .collect() } /// Return a new table provider that has a single Int32 column with diff --git a/datafusion/src/test/variable.rs b/datafusion/src/test/variable.rs index 12597b832df6..4a8bf94aa1db 100644 --- a/datafusion/src/test/variable.rs +++ b/datafusion/src/test/variable.rs @@ -22,6 +22,7 @@ use crate::scalar::ScalarValue; use crate::variable::VarProvider; /// System variable +#[derive(Default)] pub struct SystemVar {} impl SystemVar { @@ -40,6 +41,7 @@ impl VarProvider for SystemVar { } /// user defined variable +#[derive(Default)] pub struct UserDefinedVar {} impl UserDefinedVar { diff --git a/datafusion/src/test_util.rs b/datafusion/src/test_util.rs index 06850f6bdc20..429539ec1f53 100644 --- a/datafusion/src/test_util.rs +++ b/datafusion/src/test_util.rs @@ -20,6 +20,7 @@ use std::collections::BTreeMap; use std::{env, error::Error, path::PathBuf, sync::Arc}; +use crate::field_util::SchemaExt; use 
arrow::datatypes::{DataType, Field, Schema}; /// Compares formatted output of a record batch with an expected diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs index e0c75a32f306..d9a73c98a035 100644 --- a/datafusion/tests/custom_sources.rs +++ b/datafusion/tests/custom_sources.rs @@ -18,9 +18,9 @@ use arrow::array::{Int32Array, PrimitiveArray, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; - +use datafusion::field_util::{FieldExt, SchemaExt}; use datafusion::physical_plan::empty::EmptyExec; +use datafusion::record_batch::RecordBatch; use datafusion::scalar::ScalarValue; use datafusion::{datasource::TableProvider, physical_plan::collect}; use datafusion::{ @@ -33,7 +33,7 @@ use datafusion::logical_plan::{ col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, }; use datafusion::physical_plan::{ - ColumnStatistics, ExecutionPlan, Partitioning, RecordBatchStream, + project_schema, ColumnStatistics, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; @@ -45,6 +45,7 @@ use std::task::{Context, Poll}; use arrow::compute::aggregate; use async_trait::async_trait; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::plan::Projection; //// Custom source dataframe tests //// @@ -107,12 +108,7 @@ impl ExecutionPlan for CustomExecutionPlan { } fn schema(&self) -> SchemaRef { let schema = TEST_CUSTOM_SCHEMA_REF!(); - match &self.projection { - None => schema, - Some(p) => Arc::new(Schema::new( - p.iter().map(|i| schema.field(*i).clone()).collect(), - )), - } + project_schema(&schema, self.projection.as_ref()).expect("projected schema") } fn output_partitioning(&self) -> Partitioning { Partitioning::UnknownPartitioning(1) @@ -132,7 +128,11 @@ impl ExecutionPlan for CustomExecutionPlan { )) } } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) } @@ -196,7 +196,6 @@ impl TableProvider for CustomTableProvider { async fn scan( &self, projection: &Option>, - _batch_size: usize, _filters: &[Expr], _limit: Option, ) -> Result> { @@ -241,9 +240,10 @@ async fn custom_source_dataframe() -> Result<()> { let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); + assert_eq!("c2", physical_plan.schema().field(0).name()); - let batches = collect(physical_plan).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let batches = collect(physical_plan, runtime).await?; let origin_rec_batch = TEST_CUSTOM_RECORD_BATCH!()?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); @@ -289,7 +289,8 @@ async fn optimizers_catch_all_statistics() { ) .unwrap(); - let actual = collect(physical_plan).await.unwrap(); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let actual = collect(physical_plan, runtime).await.unwrap(); assert_eq!(actual.len(), 1); assert_eq!(format!("{:?}", actual[0]), format!("{:?}", expected)); diff --git a/datafusion/tests/dataframe.rs b/datafusion/tests/dataframe.rs index 99de1800df59..d4850a917d65 100644 --- a/datafusion/tests/dataframe.rs +++ b/datafusion/tests/dataframe.rs @@ -17,15 +17,14 @@ use std::sync::Arc; +use arrow::array::{Int32Array, Utf8Array}; 
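The recurring pattern in these test updates is that `collect` now takes the `RuntimeEnv` explicitly instead of reading it from ambient state. A hedged sketch of the new calling convention, mirroring the test code above (the table and column names are illustrative):

```rust
// Sketch of the new calling convention: the RuntimeEnv is threaded through
// explicitly when executing a physical plan. Table/column names are illustrative.
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::physical_plan::collect;

async fn run_query(ctx: &mut ExecutionContext) -> Result<()> {
    let plan = ctx.create_logical_plan("SELECT c2 FROM example")?;
    let plan = ctx.optimize(&plan)?;
    let plan = ctx.create_physical_plan(&plan).await?;

    // Grab the context's runtime (memory manager, batch size, ...) and pass it along.
    let runtime = ctx.state.lock().unwrap().runtime_env.clone();
    let batches = collect(plan, runtime).await?;
    assert!(!batches.is_empty());
    Ok(())
}
```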
use arrow::datatypes::{DataType, Field, Schema}; -use arrow::{ - array::{Int32Array, Utf8Array}, - record_batch::RecordBatch, -}; +use datafusion::record_batch::RecordBatch; use datafusion::assert_batches_eq; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; +use datafusion::field_util::SchemaExt; use datafusion::logical_plan::{col, Expr}; use datafusion::{datasource::MemTable, prelude::JoinType}; diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs index b9277f4f5969..2437140197a1 100644 --- a/datafusion/tests/dataframe_functions.rs +++ b/datafusion/tests/dataframe_functions.rs @@ -17,12 +17,12 @@ use std::sync::Arc; +use arrow::array::Int32Array; use arrow::array::Utf8Array; use arrow::datatypes::{DataType, Field, Schema}; -use arrow::{array::Int32Array, record_batch::RecordBatch}; - use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; +use datafusion::record_batch::RecordBatch; use datafusion::error::Result; @@ -32,6 +32,7 @@ use datafusion::prelude::*; use datafusion::execution::context::ExecutionContext; use datafusion::assert_batches_eq; +use datafusion::field_util::SchemaExt; fn create_test_table() -> Result> { let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/tests/merge_fuzz.rs b/datafusion/tests/merge_fuzz.rs new file mode 100644 index 000000000000..cf8e66dbb116 --- /dev/null +++ b/datafusion/tests/merge_fuzz.rs @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fuzz Test for various corner cases merging streams of RecordBatchs +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array}; +use arrow::compute::sort::SortOptions; +use datafusion::record_batch::RecordBatch; +use datafusion::{ + execution::runtime_env::{RuntimeConfig, RuntimeEnv}, + physical_plan::{ + collect, + expressions::{col, PhysicalSortExpr}, + memory::MemoryExec, + sorts::sort_preserving_merge::SortPreservingMergeExec, + }, +}; +use rand::{prelude::StdRng, Rng, SeedableRng}; + +#[tokio::test] +async fn test_merge_2() { + run_merge_test(vec![ + // (0..100) + // (0..100) + make_staggered_batches(0, 100, 2), + make_staggered_batches(0, 100, 3), + ]) + .await +} + +#[tokio::test] +async fn test_merge_2_no_overlap() { + run_merge_test(vec![ + // (0..20) + // (20..40) + make_staggered_batches(0, 20, 2), + make_staggered_batches(20, 40, 3), + ]) + .await +} + +#[tokio::test] +async fn test_merge_3() { + run_merge_test(vec![ + // (0 .. 100) + // (0 .. 100) + // (0 .. 51) + make_staggered_batches(0, 100, 2), + make_staggered_batches(0, 100, 3), + make_staggered_batches(0, 51, 4), + ]) + .await +} + +#[tokio::test] +async fn test_merge_3_gaps() { + run_merge_test(vec![ + // (0 .. 50)(50 .. 100) + // (0 ..33) (50 .. 100) + // (0 .. 
51) + concat( + make_staggered_batches(0, 50, 2), + make_staggered_batches(50, 100, 7), + ), + concat( + make_staggered_batches(0, 33, 21), + make_staggered_batches(50, 123, 31), + ), + make_staggered_batches(0, 51, 11), + ]) + .await +} + +/// Merge a set of input streams using SortPreservingMergeExec and +/// `Vec::sort` and ensure the results are the same. +/// +/// For each case, the `input` streams are turned into a set of of +/// streams which are then merged together by [SortPreservingMerge] +/// +/// Each `Vec` in `input` must be sorted and have a +/// single Int32 field named 'x'. +async fn run_merge_test(input: Vec>) { + // Produce output with the specified output batch sizes + let batch_sizes = [1, 2, 7, 49, 50, 51, 100]; + + for batch_size in batch_sizes { + let first_batch = input + .iter() + .map(|p| p.iter()) + .flatten() + .next() + .expect("at least one batch"); + let schema = first_batch.schema(); + + let sort = vec![PhysicalSortExpr { + expr: col("x", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + + let exec = MemoryExec::try_new(&input, schema.clone(), None).unwrap(); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); + + let runtime_config = RuntimeConfig::new().with_batch_size(batch_size); + + let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); + let collected = collect(merge, runtime).await.unwrap(); + + // verify the output batch size: all batches except the last + // should contain `batch_size` rows + for (i, batch) in collected.iter().enumerate() { + if i < collected.len() - 1 { + assert_eq!( + batch.num_rows(), + batch_size, + "Expected batch {} to have {} rows, got {}", + i, + batch_size, + batch.num_rows() + ); + } + } + + let expected = partitions_to_sorted_vec(&input); + let actual = batches_to_vec(&collected); + + assert_eq!(expected, actual, "failure in @ batch_size {}", batch_size); + } +} + +/// Extracts the i32 values from the set of batches and returns them as a single Vec +fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { + batches + .iter() + .map(|batch| { + assert_eq!(batch.num_columns(), 1); + batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + }) + .flatten() + .map(|v| v.copied()) + .collect() +} + +// extract values from batches and sort them +fn partitions_to_sorted_vec(partitions: &[Vec]) -> Vec> { + let mut values: Vec<_> = partitions + .iter() + .map(|batches| batches_to_vec(batches).into_iter()) + .flatten() + .collect(); + + values.sort_unstable(); + values +} + +/// Return the values `low..high` in order, in randomly sized +/// record batches in a field named 'x' of type `Int32` +fn make_staggered_batches(low: i32, high: i32, seed: u64) -> Vec { + let input: Int32Array = (low..high).map(Some).collect(); + + // split into several record batches + let mut remainder = + RecordBatch::try_from_iter(vec![("x", Arc::new(input) as ArrayRef)]).unwrap(); + + let mut batches = vec![]; + + // use a random number generator to pick a random sized output + let mut rng = StdRng::seed_from_u64(seed); + while remainder.num_rows() > 0 { + let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + + batches.push(remainder.slice(0, batch_size)); + remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); + } + + add_empty_batches(batches, &mut rng) +} + +/// Adds a random number of empty record batches into the stream +fn add_empty_batches(batches: Vec, rng: &mut StdRng) -> Vec { + let schema = batches[0].schema().clone(); 
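`run_merge_test` above drives the output batch size through `RuntimeConfig` rather than through the `ExecutionContext`; the same knob can be used when executing any plan. A minimal sketch under that assumption (the 1024-row value is illustrative):

```rust
// Sketch: build a standalone RuntimeEnv with a custom batch size, as the fuzz test
// above does, and use it to execute an arbitrary plan. The value 1024 is illustrative.
use std::sync::Arc;
use datafusion::error::Result;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::physical_plan::{collect, ExecutionPlan};
use datafusion::record_batch::RecordBatch;

async fn collect_with_batch_size(plan: Arc<dyn ExecutionPlan>) -> Result<Vec<RecordBatch>> {
    // Ask the runtime to produce batches of at most 1024 rows.
    let config = RuntimeConfig::new().with_batch_size(1024);
    let runtime = Arc::new(RuntimeEnv::new(config)?);
    collect(plan, runtime).await
}
```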
+ + batches + .into_iter() + .map(|batch| { + // insert 0, or 1 empty batches before and after the current batch + let empty_batch = RecordBatch::new_empty(schema.clone()); + std::iter::repeat(empty_batch.clone()) + .take(rng.gen_range(0..2)) + .chain(std::iter::once(batch)) + .chain(std::iter::repeat(empty_batch).take(rng.gen_range(0..2))) + }) + .flatten() + .collect() +} + +fn concat(mut v1: Vec, v2: Vec) -> Vec { + v1.extend(v2); + v1 +} diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs index 3c27b82a3b0b..abba09671cc9 100644 --- a/datafusion/tests/parquet_pruning.rs +++ b/datafusion/tests/parquet_pruning.rs @@ -29,9 +29,10 @@ use arrow::{ array_to_pages, to_parquet_schema, write_file, Compression, Compressor, DynIter, DynStreamingIterator, Encoding, FallibleStreamingIterator, Version, WriteOptions, }, - record_batch::RecordBatch, }; use chrono::{Datelike, Duration}; +use datafusion::field_util::SchemaExt; +use datafusion::record_batch::RecordBatch; use datafusion::{ arrow_print, datasource::TableProvider, @@ -539,7 +540,8 @@ impl ContextWithParquet { .await .expect("creating physical plan"); - let results = datafusion::physical_plan::collect(physical_plan.clone()) + let runtime = self.ctx.state.lock().unwrap().runtime_env.clone(); + let results = datafusion::physical_plan::collect(physical_plan.clone(), runtime) .await .expect("Running"); diff --git a/datafusion/tests/path_partition.rs b/datafusion/tests/path_partition.rs index e68ef32fa3ee..7697480c2db1 100644 --- a/datafusion/tests/path_partition.rs +++ b/datafusion/tests/path_partition.rs @@ -20,6 +20,7 @@ use std::{fs, io, sync::Arc}; use async_trait::async_trait; +use datafusion::field_util::SchemaExt; use datafusion::{ assert_batches_sorted_eq, datasource::{ diff --git a/datafusion/tests/provider_filter_pushdown.rs b/datafusion/tests/provider_filter_pushdown.rs index 45397267bb11..c3da1f3544ea 100644 --- a/datafusion/tests/provider_filter_pushdown.rs +++ b/datafusion/tests/provider_filter_pushdown.rs @@ -17,17 +17,20 @@ use arrow::array::*; use arrow::datatypes::*; -use arrow::record_batch::RecordBatch; use async_trait::async_trait; use datafusion::datasource::datasource::{TableProvider, TableProviderFilterPushDown}; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::field_util::SchemaExt; use datafusion::logical_plan::Expr; use datafusion::physical_plan::common::SizedRecordBatchStream; +use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; use datafusion::prelude::*; +use datafusion::record_batch::RecordBatch; use datafusion::scalar::ScalarValue; use std::sync::Arc; @@ -76,10 +79,17 @@ impl ExecutionPlan for CustomPlan { unreachable!() } - async fn execute(&self, _: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { + let metrics = ExecutionPlanMetricsSet::new(); + let baseline_metrics = BaselineMetrics::new(&metrics, partition); Ok(Box::pin(SizedRecordBatchStream::new( self.schema(), self.batches.clone(), + baseline_metrics, ))) } @@ -121,7 +131,6 @@ impl TableProvider for CustomProvider { async fn scan( &self, _: &Option>, - _: usize, filters: &[Expr], _: Option, ) -> Result> { diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 
edf530be8b7d..9d72752b091d 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -16,6 +16,7 @@ // under the License. use super::*; +use datafusion::scalar::ScalarValue; #[tokio::test] async fn csv_query_avg_multi_batch() -> Result<()> { @@ -25,7 +26,8 @@ async fn csv_query_avg_multi_batch() -> Result<()> { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(plan).await.unwrap(); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect(plan, runtime).await.unwrap(); let batch = &results[0]; let column = batch.column(0); let array = column.as_any().downcast_ref::().unwrap(); @@ -49,6 +51,42 @@ async fn csv_query_avg() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_covariance_1() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT covar_pop(c2, c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["-0.07916932235380847"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_covariance_2() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT covar(c2, c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["-0.07996901247859442"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_correlation() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT corr(c2, c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["-0.19064544190576607"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_variance_1() -> Result<()> { let mut ctx = ExecutionContext::new(); @@ -385,3 +423,53 @@ async fn csv_query_array_agg_one() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn csv_query_array_agg_distinct() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT array_agg(distinct c2) FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + + // The results for this query should be something like the following: + // +------------------------------------------+ + // | ARRAYAGG(DISTINCT aggregate_test_100.c2) | + // +------------------------------------------+ + // | [4, 2, 3, 5, 1] | + // +------------------------------------------+ + // Since ARRAY_AGG(DISTINCT) ordering is nondeterministic, check the schema and contents. + assert_eq!( + *actual[0].schema(), + Schema::new(vec![Field::new( + "ARRAYAGG(DISTINCT aggregate_test_100.c2)", + DataType::List(Box::new(Field::new("item", DataType::UInt32, true))), + false + ),]) + ); + + // We should have 1 row containing a list + let column = actual[0].column(0); + assert_eq!(column.len(), 1); + + if let ScalarValue::List(Some(mut v), _) = ScalarValue::try_from_array(column, 0)? 
{ + // workaround lack of Ord of ScalarValue + let cmp = |a: &ScalarValue, b: &ScalarValue| { + a.partial_cmp(b).expect("Can compare ScalarValues") + }; + v.sort_by(cmp); + assert_eq!( + *v, + vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::UInt32(Some(2)), + ScalarValue::UInt32(Some(3)), + ScalarValue::UInt32(Some(4)), + ScalarValue::UInt32(Some(5)) + ] + ); + } else { + unreachable!(); + } + + Ok(()) +} diff --git a/datafusion/tests/sql/avro.rs b/datafusion/tests/sql/avro.rs index 3983389dae34..d0cdf71b0868 100644 --- a/datafusion/tests/sql/avro.rs +++ b/datafusion/tests/sql/avro.rs @@ -124,7 +124,8 @@ async fn avro_single_nan_schema() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(plan).await.unwrap(); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect(plan, runtime).await.unwrap(); for batch in results { assert_eq!(1, batch.num_rows()); assert_eq!(1, batch.num_columns()); @@ -153,7 +154,7 @@ async fn avro_explain() { \n CoalescePartitionsExec\ \n HashAggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]\ \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\ - \n AvroExec: files=[ARROW_TEST_DATA/avro/alltypes_plain.avro], batch_size=8192, limit=None\ + \n AvroExec: files=[ARROW_TEST_DATA/avro/alltypes_plain.avro], limit=None\ \n", ], ]; diff --git a/datafusion/tests/sql/errors.rs b/datafusion/tests/sql/errors.rs index 9cd7bc96ff89..05ca0642bae0 100644 --- a/datafusion/tests/sql/errors.rs +++ b/datafusion/tests/sql/errors.rs @@ -37,7 +37,8 @@ async fn test_cast_expressions_error() -> Result<()> { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let result = collect(plan).await; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let result = collect(plan, runtime).await; match result { Ok(_) => panic!("expected error"), diff --git a/datafusion/tests/sql/explain_analyze.rs b/datafusion/tests/sql/explain_analyze.rs index d524eb29343f..25e0cd6b0bda 100644 --- a/datafusion/tests/sql/explain_analyze.rs +++ b/datafusion/tests/sql/explain_analyze.rs @@ -41,7 +41,8 @@ async fn explain_analyze_baseline_metrics() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(physical_plan.clone()).await.unwrap(); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect(physical_plan.clone(), runtime).await.unwrap(); let formatted = print::write(&results); println!("Query Output:\n\n{}", formatted); @@ -103,8 +104,9 @@ async fn explain_analyze_baseline_metrics() { fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool { use datafusion::physical_plan; + use datafusion::physical_plan::sorts; - plan.as_any().downcast_ref::().is_some() + plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() // CoalescePartitionsExec doesn't do any work so is not included || plan.as_any().downcast_ref::().is_some() @@ -325,7 +327,8 @@ async fn csv_explain_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let results = collect(plan).await.expect(&msg); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect(plan, runtime).await.expect(&msg); let actual = 
result_vec(&results); // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); @@ -522,7 +525,8 @@ async fn csv_explain_verbose_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let results = collect(plan).await.expect(&msg); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); @@ -648,7 +652,7 @@ async fn test_physical_plan_display_indent() { " CoalesceBatchesExec: target_batch_size=4096", " FilterExec: c12@1 < CAST(10 AS Float64)", " RepartitionExec: partitioning=RoundRobinBatch(3)", - " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None", + " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None", ]; let data_path = datafusion::test_util::arrow_test_data(); @@ -693,13 +697,13 @@ async fn test_physical_plan_display_indent_multi_children() { " ProjectionExec: expr=[c1@0 as c1]", " ProjectionExec: expr=[c1@0 as c1]", " RepartitionExec: partitioning=RoundRobinBatch(3)", - " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None", + " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([Column { name: \"c2\", index: 0 }], 3)", " ProjectionExec: expr=[c2@0 as c2]", " ProjectionExec: expr=[c1@0 as c2]", " RepartitionExec: partitioning=RoundRobinBatch(3)", - " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None", + " CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None", ]; let data_path = datafusion::test_util::arrow_test_data(); @@ -741,7 +745,7 @@ async fn csv_explain() { \n CoalesceBatchesExec: target_batch_size=4096\ \n FilterExec: CAST(c2@1 AS Int64) > 10\ \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\ - \n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, batch_size=8192, limit=None\ + \n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None\ \n" ]]; assert_eq!(expected, actual); diff --git a/datafusion/tests/sql/expr.rs b/datafusion/tests/sql/expr.rs index 8c2f6b970165..d6ab1a5f6838 100644 --- a/datafusion/tests/sql/expr.rs +++ b/datafusion/tests/sql/expr.rs @@ -367,6 +367,7 @@ async fn test_crypto_expressions() -> Result<()> { #[tokio::test] async fn test_interval_expressions() -> Result<()> { + // day nano intervals test_expression!( "interval '1'", "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs" @@ -456,6 +457,7 @@ async fn test_interval_expressions() -> Result<()> { "interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds'", "0 years 0 mons 5 days 4 hours 3 mins 2.100 secs" ); + // month intervals test_expression!( "interval '0.5 month'", "0 years 0 mons 15 days 0 hours 0 mins 0.00 secs" @@ -496,6 +498,34 @@ async fn test_interval_expressions() -> Result<()> { "interval '2' year", "2 years 0 mons 0 days 0 hours 0 mins 0.00 secs" ); + // complex + test_expression!( + "interval '1 year 1 day'", + "0 years 12 mons 1 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 year 1 day 1 hour'", + "0 years 12 mons 1 days 1 hours 0 mins 0.00 secs" + ); + 
test_expression!( + "interval '1 year 1 day 1 hour 1 minute'", + "0 years 12 mons 1 days 1 hours 1 mins 0.00 secs" + ); + test_expression!( + "interval '1 year 1 day 1 hour 1 minute 1 second'", + "0 years 12 mons 1 days 1 hours 1 mins 1.00 secs" + ); + + Ok(()) +} + +#[cfg(feature = "unicode_expressions")] +#[tokio::test] +async fn test_substring_expr() -> Result<()> { + test_expression!("substring('alphabet' from 2 for 1)", "l"); + test_expression!("substring('alphabet' from 8)", "t"); + test_expression!("substring('alphabet' for 1)", "a"); + Ok(()) } diff --git a/datafusion/tests/sql/joins.rs b/datafusion/tests/sql/joins.rs index 4934eeff88c5..1f7599a80f16 100644 --- a/datafusion/tests/sql/joins.rs +++ b/datafusion/tests/sql/joins.rs @@ -16,6 +16,7 @@ // under the License. use super::*; +use datafusion::from_slice::FromSlice; #[tokio::test] async fn equijoin() -> Result<()> { @@ -210,6 +211,202 @@ async fn left_join_unbalanced() -> Result<()> { Ok(()) } +#[tokio::test] +async fn left_join_null_filter() -> Result<()> { + // Since t2 is the non-preserved side of the join, we cannot push down a NULL filter. + // Note that this is only true because IS NULL does not remove nulls. For filters that + // remove nulls, we can rewrite the join as an inner join and then push down the filter. + let mut ctx = create_join_context_with_nulls()?; + let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NULL ORDER BY t1_id"; + let expected = vec![ + "+-------+-------+---------+", + "| t1_id | t2_id | t2_name |", + "+-------+-------+---------+", + "| 22 | 22 | |", + "| 33 | | |", + "| 77 | | |", + "| 88 | | |", + "+-------+-------+---------+", + ]; + + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn left_join_null_filter_on_join_column() -> Result<()> { + // Again, since t2 is the non-preserved side of the join, we cannot push down a NULL filter. 
+ let mut ctx = create_join_context_with_nulls()?; + let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NULL ORDER BY t1_id"; + let expected = vec![ + "+-------+-------+---------+", + "| t1_id | t2_id | t2_name |", + "+-------+-------+---------+", + "| 33 | | |", + "| 77 | | |", + "| 88 | | |", + "+-------+-------+---------+", + ]; + + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn left_join_not_null_filter() -> Result<()> { + let mut ctx = create_join_context_with_nulls()?; + let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NOT NULL ORDER BY t1_id"; + let expected = vec![ + "+-------+-------+---------+", + "| t1_id | t2_id | t2_name |", + "+-------+-------+---------+", + "| 11 | 11 | z |", + "| 44 | 44 | x |", + "| 99 | 99 | u |", + "+-------+-------+---------+", + ]; + + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn left_join_not_null_filter_on_join_column() -> Result<()> { + let mut ctx = create_join_context_with_nulls()?; + let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NOT NULL ORDER BY t1_id"; + let expected = vec![ + "+-------+-------+---------+", + "| t1_id | t2_id | t2_name |", + "+-------+-------+---------+", + "| 11 | 11 | z |", + "| 22 | 22 | |", + "| 44 | 44 | x |", + "| 99 | 99 | u |", + "+-------+-------+---------+", + ]; + + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn right_join_null_filter() -> Result<()> { + let mut ctx = create_join_context_with_nulls()?; + let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t2_id"; + let expected = vec![ + "+-------+---------+-------+", + "| t1_id | t1_name | t2_id |", + "+-------+---------+-------+", + "| | | 55 |", + "| 99 | | 99 |", + "+-------+---------+-------+", + ]; + + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn right_join_null_filter_on_join_column() -> Result<()> { + let mut ctx = create_join_context_with_nulls()?; + let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NULL ORDER BY t2_id"; + let expected = vec![ + "+-------+---------+-------+", + "| t1_id | t1_name | t2_id |", + "+-------+---------+-------+", + "| | | 55 |", + "+-------+---------+-------+", + ]; + + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn right_join_not_null_filter() -> Result<()> { + let mut ctx = create_join_context_with_nulls()?; + let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t2_id"; + let expected = vec![ + "+-------+---------+-------+", + "| t1_id | t1_name | t2_id |", + "+-------+---------+-------+", + "| 11 | a | 11 |", + "| 22 | b | 22 |", + "| 44 | d | 44 |", + "+-------+---------+-------+", + ]; + + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn right_join_not_null_filter_on_join_column() -> Result<()> { + let mut ctx = create_join_context_with_nulls()?; + let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON 
t1_id = t2_id WHERE t1_id IS NOT NULL ORDER BY t2_id"; + let expected = vec![ + "+-------+---------+-------+", + "| t1_id | t1_name | t2_id |", + "+-------+---------+-------+", + "| 11 | a | 11 |", + "| 22 | b | 22 |", + "| 44 | d | 44 |", + "| 99 | | 99 |", + "+-------+---------+-------+", + ]; + + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn full_join_null_filter() -> Result<()> { + let mut ctx = create_join_context_with_nulls()?; + let sql = "SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t1_id"; + let expected = vec![ + "+-------+---------+-------+", + "| t1_id | t1_name | t2_id |", + "+-------+---------+-------+", + "| 88 | | |", + "| 99 | | 99 |", + "| | | 55 |", + "+-------+---------+-------+", + ]; + + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn full_join_not_null_filter() -> Result<()> { + let mut ctx = create_join_context_with_nulls()?; + let sql = "SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t1_id"; + let expected = vec![ + "+-------+---------+-------+", + "| t1_id | t1_name | t2_id |", + "+-------+---------+-------+", + "| 11 | a | 11 |", + "| 22 | b | 22 |", + "| 33 | c | |", + "| 44 | d | 44 |", + "| 77 | e | |", + "+-------+---------+-------+", + ]; + + let actual = execute_to_batches(&mut ctx, sql).await; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn right_join() -> Result<()> { let mut ctx = create_join_context("t1_id", "t2_id")?; @@ -418,32 +615,32 @@ async fn cross_join_unbalanced() { // the order of the values is not determinisitic, so we need to sort to check the values let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name"; + "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name, t2_name"; let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", "+-------+---------+---------+", - "| 11 | a | z |", - "| 11 | a | y |", - "| 11 | a | x |", "| 11 | a | w |", - "| 22 | b | z |", - "| 22 | b | y |", - "| 22 | b | x |", + "| 11 | a | x |", + "| 11 | a | y |", + "| 11 | a | z |", "| 22 | b | w |", - "| 33 | c | z |", - "| 33 | c | y |", - "| 33 | c | x |", + "| 22 | b | x |", + "| 22 | b | y |", + "| 22 | b | z |", "| 33 | c | w |", - "| 44 | d | z |", - "| 44 | d | y |", - "| 44 | d | x |", + "| 33 | c | x |", + "| 33 | c | y |", + "| 33 | c | z |", "| 44 | d | w |", - "| 77 | e | z |", - "| 77 | e | y |", - "| 77 | e | x |", + "| 44 | d | x |", + "| 44 | d | y |", + "| 44 | d | z |", "| 77 | e | w |", + "| 77 | e | x |", + "| 77 | e | y |", + "| 77 | e | z |", "+-------+---------+---------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index 3a08ee031f12..4685447258e7 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -25,6 +25,7 @@ use datafusion::assert_batches_eq; use datafusion::assert_batches_sorted_eq; use datafusion::assert_contains; use datafusion::assert_not_contains; +use datafusion::from_slice::FromSlice; use datafusion::logical_plan::plan::{Aggregate, Projection}; use datafusion::logical_plan::LogicalPlan; use datafusion::logical_plan::TableScan; @@ -294,6 +295,55 @@ fn 
create_join_context_unbalanced( Ok(ctx) } +// Create memory tables with nulls +fn create_join_context_with_nulls() -> Result { + let mut ctx = ExecutionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("t1_id", DataType::UInt32, true), + Field::new("t1_name", DataType::Utf8, true), + ])); + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![11, 22, 33, 44, 77, 88, 99])), + Arc::new(StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + None, + None, + ])), + ], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table))?; + + let t2_schema = Arc::new(Schema::new(vec![ + Field::new("t2_id", DataType::UInt32, true), + Field::new("t2_name", DataType::Utf8, true), + ])); + let t2_data = RecordBatch::try_new( + t2_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![11, 22, 44, 55, 99])), + Arc::new(StringArray::from(vec![ + Some("z"), + None, + Some("x"), + Some("w"), + Some("u"), + ])), + ], + )?; + let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; + ctx.register_table("t2", Arc::new(t2_table))?; + + Ok(ctx) +} + fn get_tpch_table_schema(table: &str) -> Schema { match table { "customer" => Schema::new(vec![ @@ -482,7 +532,8 @@ async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec Result<()> { @@ -651,20 +652,28 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { async fn query_on_string_dictionary() -> Result<()> { // Test to ensure DataFusion can operate on dictionary types // Use StringDictionary (32 bit indexes = keys) - let original_data = vec![Some("one"), None, Some("three")]; - let mut array = MutableDictionaryArray::>::new(); - array.try_extend(original_data)?; - let array: DictionaryArray = array.into(); + let d1: DictionaryArray = + vec![Some("one"), None, Some("three")].into_iter().collect(); - let batch = - RecordBatch::try_from_iter(vec![("d1", Arc::new(array) as ArrayRef)]).unwrap(); + let d2: DictionaryArray = vec![Some("blarg"), None, Some("three")] + .into_iter() + .collect(); + + let d3: StringArray = vec![Some("XYZ"), None, Some("three")].into_iter().collect(); + + let batch = RecordBatch::try_from_iter(vec![ + ("d1", Arc::new(d1) as ArrayRef), + ("d2", Arc::new(d2) as ArrayRef), + ("d3", Arc::new(d3) as ArrayRef), + ]) + .unwrap(); let table = MemTable::try_new(batch.schema().clone(), vec![vec![batch]])?; let mut ctx = ExecutionContext::new(); ctx.register_table("test", Arc::new(table))?; // Basic SELECT - let sql = "SELECT * FROM test"; + let sql = "SELECT d1 FROM test"; let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ "+-------+", @@ -678,7 +687,7 @@ async fn query_on_string_dictionary() -> Result<()> { assert_batches_eq!(expected, &actual); // basic filtering - let sql = "SELECT * FROM test WHERE d1 IS NOT NULL"; + let sql = "SELECT d1 FROM test WHERE d1 IS NOT NULL"; let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ "+-------+", @@ -690,8 +699,56 @@ async fn query_on_string_dictionary() -> Result<()> { ]; assert_batches_eq!(expected, &actual); + // comparison with constant + let sql = "SELECT d1 FROM test WHERE d1 = 'three'"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // comparison with another dictionary column + let sql = "SELECT d1 
FROM test WHERE d1 = d2"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // order comparison with another dictionary column + let sql = "SELECT d1 FROM test WHERE d1 <= d2"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // comparison with a non dictionary column + let sql = "SELECT d1 FROM test WHERE d1 = d3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + // filtering with constant - let sql = "SELECT * FROM test WHERE d1 = 'three'"; + let sql = "SELECT d1 FROM test WHERE d1 = 'three'"; let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ "+-------+", @@ -716,6 +773,20 @@ async fn query_on_string_dictionary() -> Result<()> { ]; assert_batches_eq!(expected, &actual); + // Expression evaluation with two dictionaries + let sql = "SELECT concat(d1, d2) FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------+", + "| concat(test.d1,test.d2) |", + "+-------------------------+", + "| oneblarg |", + "| |", + "| threethree |", + "+-------------------------+", + ]; + assert_batches_eq!(expected, &actual); + // aggregation let sql = "SELECT COUNT(d1) FROM test"; let actual = execute_to_batches(&mut ctx, sql).await; diff --git a/datafusion/tests/sql/timestamp.rs b/datafusion/tests/sql/timestamp.rs index ce4cc4a97338..28a5c5d09a2b 100644 --- a/datafusion/tests/sql/timestamp.rs +++ b/datafusion/tests/sql/timestamp.rs @@ -16,6 +16,7 @@ // under the License. 
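Note on the pattern repeated throughout these test updates (and again in the timestamp test just below): `datafusion::physical_plan::collect` now takes the `RuntimeEnv` as a second argument, cloned out of the context state. A minimal sketch of the migrated call path, assuming a mutable `ExecutionContext` as in the tests above; the helper name `run_sql` is illustrative, not part of the patch:

```rust
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::physical_plan::collect;
use datafusion::record_batch::RecordBatch;

/// Illustrative helper: plan a SQL query and execute it with the
/// context's RuntimeEnv, mirroring the pattern used in these tests.
async fn run_sql(ctx: &mut ExecutionContext, sql: &str) -> Result<Vec<RecordBatch>> {
    let logical = ctx.create_logical_plan(sql)?;
    let optimized = ctx.optimize(&logical)?;
    let physical = ctx.create_physical_plan(&optimized).await?;

    // `collect` now requires the RuntimeEnv; clone it out of the context state
    // instead of calling `collect(physical)` as before.
    let runtime = ctx.state.lock().unwrap().runtime_env.clone();
    collect(physical, runtime).await
}
```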
use super::*; +use datafusion::from_slice::FromSlice; #[tokio::test] async fn query_cast_timestamp_millis() -> Result<()> { @@ -387,7 +388,8 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { let plan = ctx.create_physical_plan(&plan).await.expect(&msg); let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let res = collect(plan).await.expect(&msg); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let res = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&res); let res1 = actual[0][0].as_str(); diff --git a/datafusion/tests/sql_integration.rs b/datafusion/tests/sql_integration.rs new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/datafusion/tests/sql_integration.rs @@ -0,0 +1 @@ + diff --git a/datafusion/tests/statistics.rs b/datafusion/tests/statistics.rs index 2934d7889215..e3b94cc2b604 100644 --- a/datafusion/tests/statistics.rs +++ b/datafusion/tests/statistics.rs @@ -25,7 +25,7 @@ use datafusion::{ error::{DataFusionError, Result}, logical_plan::Expr, physical_plan::{ - ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, + project_schema, ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }, prelude::ExecutionContext, @@ -33,6 +33,8 @@ use datafusion::{ }; use async_trait::async_trait; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::field_util::SchemaExt; /// This is a testing structure for statistics /// It will act both as a table provider and execution plan @@ -43,7 +45,7 @@ struct StatisticsValidation { } impl StatisticsValidation { - fn new(stats: Statistics, schema: Schema) -> Self { + fn new(stats: Statistics, schema: SchemaRef) -> Self { assert!( stats .column_statistics @@ -52,10 +54,7 @@ impl StatisticsValidation { .unwrap_or(true), "if defined, the column statistics vector length should be the number of fields" ); - Self { - stats, - schema: Arc::new(schema), - } + Self { stats, schema } } } @@ -72,7 +71,6 @@ impl TableProvider for StatisticsValidation { async fn scan( &self, projection: &Option>, - _batch_size: usize, filters: &[Expr], // limit is ignored because it is not mandatory for a `TableProvider` to honor it _limit: Option, @@ -87,12 +85,7 @@ impl TableProvider for StatisticsValidation { Some(p) => p, None => (0..self.schema.fields().len()).collect(), }; - let projected_schema = Schema::new( - projection - .iter() - .map(|i| self.schema.field(*i).clone()) - .collect(), - ); + let projected_schema = project_schema(&self.schema, Some(&projection))?; let current_stat = self.stats.clone(); @@ -144,7 +137,11 @@ impl ExecutionPlan for StatisticsValidation { } } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { unimplemented!("This plan only serves for testing statistics") } @@ -173,7 +170,7 @@ impl ExecutionPlan for StatisticsValidation { fn init_ctx(stats: Statistics, schema: Schema) -> Result { let mut ctx = ExecutionContext::new(); let provider: Arc = - Arc::new(StatisticsValidation::new(stats, schema)); + Arc::new(StatisticsValidation::new(stats, Arc::new(schema))); ctx.register_table("stats_table", provider)?; Ok(ctx) } diff --git a/datafusion/tests/user_defined_plan.rs b/datafusion/tests/user_defined_plan.rs index 72ab6f9499c9..976db6cbd5e2 100644 --- a/datafusion/tests/user_defined_plan.rs +++ b/datafusion/tests/user_defined_plan.rs @@ -64,8 +64,8 @@ use arrow::{ array::{Int64Array, 
Utf8Array}, datatypes::SchemaRef, error::ArrowError, - record_batch::RecordBatch, }; +use datafusion::record_batch::RecordBatch; use datafusion::{ arrow_print::write, error::{DataFusionError, Result}, @@ -86,6 +86,7 @@ use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; use async_trait::async_trait; use datafusion::execution::context::ExecutionProps; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::plan::{Extension, Sort}; use datafusion::logical_plan::{DFSchemaRef, Limit}; @@ -217,9 +218,9 @@ async fn topk_plan() -> Result<()> { let mut ctx = setup_table(make_topk_context()).await?; let expected = vec![ - "| logical_plan after topk | TopK: k=3 |", - "| | Projection: #sales.customer_id, #sales.revenue |", - "| | TableScan: sales projection=Some([0, 1]) |", + "| logical_plan after topk | TopK: k=3 |", + "| | Projection: #sales.customer_id, #sales.revenue |", + "| | TableScan: sales projection=Some([0, 1]) |", ].join("\n"); let explain_query = format!("EXPLAIN VERBOSE {}", QUERY); @@ -453,7 +454,11 @@ impl ExecutionPlan for TopKExec { } /// Execute one partition and return an iterator over RecordBatch - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( "TopKExec invalid partition {}", @@ -462,7 +467,7 @@ impl ExecutionPlan for TopKExec { } Ok(Box::pin(TopKReader { - input: self.input.execute(partition).await?, + input: self.input.execute(partition, runtime).await?, k: self.k, done: false, state: BTreeMap::new(), diff --git a/dev/docker/ballista-base.dockerfile b/dev/docker/ballista-base.dockerfile index 5bc3488a2185..79b58cdbfd6e 100644 --- a/dev/docker/ballista-base.dockerfile +++ b/dev/docker/ballista-base.dockerfile @@ -23,7 +23,7 @@ # Base image extends debian:buster-slim -FROM rust:1.57.0-buster AS builder +FROM rust:1.58.0-buster AS builder RUN apt update && apt -y install musl musl-dev musl-tools libssl-dev openssl
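Related to the `batch_size=8192` strings dropped from the plan-display tests earlier in this diff: batch size is no longer a `scan`-time parameter but part of the `RuntimeEnv` that `execute` and `collect` now receive. A small sketch, assuming `RuntimeConfig` sits next to `RuntimeEnv` in `datafusion::execution::runtime_env` as the merge fuzz test above suggests; the helper name is illustrative:

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::physical_plan::{collect, ExecutionPlan};
use datafusion::record_batch::RecordBatch;

/// Illustrative: execute a physical plan with a caller-chosen batch size
/// by building a RuntimeEnv, rather than passing batch_size through scan().
async fn collect_with_batch_size(
    plan: Arc<dyn ExecutionPlan>,
    batch_size: usize,
) -> Result<Vec<RecordBatch>> {
    let config = RuntimeConfig::new().with_batch_size(batch_size);
    let runtime = Arc::new(RuntimeEnv::new(config)?);
    collect(plan, runtime).await
}
```

Custom operators such as `TopKExec` above follow the same thread: they accept the runtime in `execute` and forward it to their children, so runtime-scoped settings reach every node of the plan.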