diff --git a/ballista-examples/src/bin/ballista-dataframe.rs b/ballista-examples/src/bin/ballista-dataframe.rs index 345b6982dd85..cab8f5645450 100644 --- a/ballista-examples/src/bin/ballista-dataframe.rs +++ b/ballista-examples/src/bin/ballista-dataframe.rs @@ -25,7 +25,7 @@ async fn main() -> Result<()> { let config = BallistaConfig::builder() .set("ballista.shuffle.partitions", "4") .build()?; - let ctx = BallistaContext::remote("localhost", 50050, &config); + let ctx = BallistaContext::remote("localhost", 50050, &config).await?; let testdata = datafusion::test_util::parquet_test_data(); diff --git a/ballista-examples/src/bin/ballista-sql.rs b/ballista-examples/src/bin/ballista-sql.rs index 25fc333ed247..b4a1260f6f8a 100644 --- a/ballista-examples/src/bin/ballista-sql.rs +++ b/ballista-examples/src/bin/ballista-sql.rs @@ -25,7 +25,7 @@ async fn main() -> Result<()> { let config = BallistaConfig::builder() .set("ballista.shuffle.partitions", "4") .build()?; - let ctx = BallistaContext::remote("localhost", 50050, &config); + let ctx = BallistaContext::remote("localhost", 50050, &config).await?; let testdata = datafusion::test_util::arrow_test_data(); diff --git a/ballista/rust/client/README.md b/ballista/rust/client/README.md index 08ca67b94a24..ecf364f4f749 100644 --- a/ballista/rust/client/README.md +++ b/ballista/rust/client/README.md @@ -106,7 +106,7 @@ async fn main() -> Result<()> { .build()?; // connect to Ballista scheduler - let ctx = BallistaContext::remote("localhost", 50050, &config); + let ctx = BallistaContext::remote("localhost", 50050, &config).await?; // register csv file with the execution context ctx.register_csv( diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs index 8db9a0c05821..e4233d4f4025 100644 --- a/ballista/rust/client/src/context.rs +++ b/ballista/rust/client/src/context.rs @@ -17,6 +17,7 @@ //! Distributed execution context. 
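For orientation, here is a minimal sketch of the updated client usage that the example and README changes above imply: `BallistaContext::remote` is now `async` and fallible (it performs the scheduler handshake up front), so callers must `.await?` it before registering tables and running SQL. This is an illustrative sketch only — the `ballista` prelude import, the file path, and the query are assumptions, not part of this patch.

```rust
use ballista::prelude::*;
use datafusion::prelude::CsvReadOptions;

#[tokio::main]
async fn main() -> Result<()> {
    let config = BallistaConfig::builder()
        .set("ballista.shuffle.partitions", "4")
        .build()?;

    // Connecting now opens the gRPC connection to the scheduler and obtains a
    // server-side session id, so it is async and can fail.
    let ctx = BallistaContext::remote("localhost", 50050, &config).await?;

    // Register a CSV file and run a query, mirroring the README example above.
    ctx.register_csv("test", "testdata/example.csv", CsvReadOptions::new())
        .await?;
    let df = ctx.sql("SELECT count(*) FROM test").await?;
    df.show().await?;

    Ok(())
}
```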
+use log::info; use parking_lot::Mutex; use sqlparser::ast::Statement; use std::collections::HashMap; @@ -25,7 +26,8 @@ use std::path::PathBuf; use std::sync::Arc; use ballista_core::config::BallistaConfig; -use ballista_core::serde::protobuf::LogicalPlanNode; +use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient; +use ballista_core::serde::protobuf::{ExecuteQueryParams, KeyValuePair, LogicalPlanNode}; use ballista_core::utils::create_df_ctx_with_ballista_query_planner; use datafusion::catalog::TableReference; @@ -63,26 +65,86 @@ impl BallistaContextState { } } + pub fn config(&self) -> &BallistaConfig { + &self.config + } +} + +pub struct BallistaContext { + state: Arc>, + context: Arc, +} + +impl BallistaContext { + /// Create a context for executing queries against a remote Ballista scheduler instance + pub async fn remote( + host: &str, + port: u16, + config: &BallistaConfig, + ) -> ballista_core::error::Result { + let state = BallistaContextState::new(host.to_owned(), port, config); + + let scheduler_url = + format!("http://{}:{}", &state.scheduler_host, state.scheduler_port); + info!( + "Connecting to Ballista scheduler at {}", + scheduler_url.clone() + ); + let mut scheduler = SchedulerGrpcClient::connect(scheduler_url.clone()) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + + let remote_session_id = scheduler + .execute_query(ExecuteQueryParams { + query: None, + settings: config + .settings() + .iter() + .map(|(k, v)| KeyValuePair { + key: k.to_owned(), + value: v.to_owned(), + }) + .collect::>(), + optional_session_id: None, + }) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))? + .into_inner() + .session_id; + + info!( + "Server side SessionContext created with session id: {}", + remote_session_id + ); + + let ctx = { + create_df_ctx_with_ballista_query_planner::( + scheduler_url, + remote_session_id, + state.config(), + ) + }; + + Ok(Self { + state: Arc::new(Mutex::new(state)), + context: Arc::new(ctx), + }) + } + #[cfg(feature = "standalone")] - pub async fn new_standalone( + pub async fn standalone( config: &BallistaConfig, concurrent_tasks: usize, ) -> ballista_core::error::Result { - use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient; use ballista_core::serde::protobuf::PhysicalPlanNode; use ballista_core::serde::BallistaCodec; log::info!("Running in local mode. Scheduler will be run in-proc"); let addr = ballista_scheduler::standalone::new_standalone_scheduler().await?; - - let scheduler = loop { - match SchedulerGrpcClient::connect(format!( - "http://localhost:{}", - addr.port() - )) - .await - { + let scheduler_url = format!("http://localhost:{}", addr.port()); + let mut scheduler = loop { + match SchedulerGrpcClient::connect(scheduler_url.clone()).await { Err(_) => { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; log::info!("Attempting to connect to in-proc scheduler..."); @@ -91,6 +153,37 @@ impl BallistaContextState { } }; + let remote_session_id = scheduler + .execute_query(ExecuteQueryParams { + query: None, + settings: config + .settings() + .iter() + .map(|(k, v)| KeyValuePair { + key: k.to_owned(), + value: v.to_owned(), + }) + .collect::>(), + optional_session_id: None, + }) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))? 
+ .into_inner() + .session_id; + + info!( + "Server side SessionContext created with session id: {}", + remote_session_id + ); + + let ctx = { + create_df_ctx_with_ballista_query_planner::( + scheduler_url, + remote_session_id, + config, + ) + }; + let default_codec: BallistaCodec = BallistaCodec::default(); @@ -101,43 +194,12 @@ impl BallistaContextState { ) .await?; - Ok(Self { - config: config.clone(), - scheduler_host: "localhost".to_string(), - scheduler_port: addr.port(), - tables: HashMap::new(), - }) - } - - pub fn config(&self) -> &BallistaConfig { - &self.config - } -} - -pub struct BallistaContext { - state: Arc>, -} - -impl BallistaContext { - /// Create a context for executing queries against a remote Ballista scheduler instance - pub fn remote(host: &str, port: u16, config: &BallistaConfig) -> Self { - let state = BallistaContextState::new(host.to_owned(), port, config); - - Self { - state: Arc::new(Mutex::new(state)), - } - } - - #[cfg(feature = "standalone")] - pub async fn standalone( - config: &BallistaConfig, - concurrent_tasks: usize, - ) -> ballista_core::error::Result { let state = - BallistaContextState::new_standalone(config, concurrent_tasks).await?; + BallistaContextState::new("localhost".to_string(), addr.port(), config); Ok(Self { state: Arc::new(Mutex::new(state)), + context: Arc::new(ctx), }) } @@ -152,15 +214,7 @@ impl BallistaContext { let path = PathBuf::from(path); let path = fs::canonicalize(&path)?; - // use local DataFusion context for now but later this might call the scheduler - let mut ctx = { - let guard = self.state.lock(); - create_df_ctx_with_ballista_query_planner::( - &guard.scheduler_host, - guard.scheduler_port, - guard.config(), - ) - }; + let ctx = self.context.clone(); let df = ctx.read_avro(path.to_str().unwrap(), options).await?; Ok(df) } @@ -172,15 +226,7 @@ impl BallistaContext { let path = PathBuf::from(path); let path = fs::canonicalize(&path)?; - // use local DataFusion context for now but later this might call the scheduler - let mut ctx = { - let guard = self.state.lock(); - create_df_ctx_with_ballista_query_planner::( - &guard.scheduler_host, - guard.scheduler_port, - guard.config(), - ) - }; + let ctx = self.context.clone(); let df = ctx.read_parquet(path.to_str().unwrap()).await?; Ok(df) } @@ -196,15 +242,7 @@ impl BallistaContext { let path = PathBuf::from(path); let path = fs::canonicalize(&path)?; - // use local DataFusion context for now but later this might call the scheduler - let mut ctx = { - let guard = self.state.lock(); - create_df_ctx_with_ballista_query_planner::( - &guard.scheduler_host, - guard.scheduler_port, - guard.config(), - ) - }; + let ctx = self.context.clone(); let df = ctx.read_csv(path.to_str().unwrap(), options).await?; Ok(df) } @@ -291,34 +329,31 @@ impl BallistaContext { /// This method is `async` because queries of type `CREATE EXTERNAL TABLE` /// might require the schema to be inferred. 
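The reworked `sql` method below reuses this single shared `SessionContext` — and therefore the same server-side session id negotiated above — for every query on a `BallistaContext`, which is why it now checks `table_exist` before re-registering client-side tables. A rough usage sketch of what that enables, continuing the connection sketch earlier (table name, path, and queries are illustrative, not from this patch):

```rust
// Same `config` as in the earlier sketch; one BallistaContext now corresponds
// to exactly one server-side session.
let ctx = BallistaContext::remote("localhost", 50050, &config).await?;

ctx.register_csv("trips", "testdata/trips.csv", CsvReadOptions::new())
    .await?;

// Both queries run against the same shared SessionContext / session id, so the
// table registered above does not need to be registered again for the second call.
let total = ctx.sql("SELECT count(*) FROM trips").await?.collect().await?;
let sample = ctx.sql("SELECT * FROM trips LIMIT 10").await?.collect().await?;
```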
pub async fn sql(&self, sql: &str) -> Result> { - let mut ctx = { - let state = self.state.lock(); - create_df_ctx_with_ballista_query_planner::( - &state.scheduler_host, - state.scheduler_port, - state.config(), - ) - }; + let mut ctx = self.context.clone(); let is_show = self.is_show_statement(sql).await?; // the show tables、 show columns sql can not run at scheduler because the tables is store at client if is_show { let state = self.state.lock(); - ctx = SessionContext::with_config( + ctx = Arc::new(SessionContext::with_config( SessionConfig::new().with_information_schema( state.config.default_with_information_schema(), ), - ); + )); } // register tables with DataFusion context { let state = self.state.lock(); for (name, prov) in &state.tables { - ctx.register_table( - TableReference::Bare { table: name }, - Arc::clone(prov), - )?; + // ctx is shared between queries, check table exists or not before register + let table_ref = TableReference::Bare { table: name }; + if !ctx.table_exist(table_ref)? { + ctx.register_table( + TableReference::Bare { table: name }, + Arc::clone(prov), + )?; + } } } @@ -341,16 +376,16 @@ impl BallistaContext { .has_header(*has_header), ) .await?; - Ok(Arc::new(DataFrame::new(ctx.state, &plan))) + Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan))) } FileType::Parquet => { self.register_parquet(name, location).await?; - Ok(Arc::new(DataFrame::new(ctx.state, &plan))) + Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan))) } FileType::Avro => { self.register_avro(name, location, AvroReadOptions::default()) .await?; - Ok(Arc::new(DataFrame::new(ctx.state, &plan))) + Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan))) } _ => Err(DataFusionError::NotImplemented(format!( "Unsupported file type {:?}.", @@ -475,7 +510,6 @@ mod tests { use datafusion::arrow::datatypes::Schema; use datafusion::arrow::util::pretty; use datafusion::datasource::file_format::csv::CsvFormat; - use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, }; @@ -483,9 +517,6 @@ mod tests { use ballista_core::config::{ BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA, }; - use std::fs::File; - use std::io::Write; - use tempfile::TempDir; let config = BallistaConfigBuilder::default() .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true") .build() diff --git a/ballista/rust/client/src/prelude.rs b/ballista/rust/client/src/prelude.rs index d162d0c017bd..15558a357daa 100644 --- a/ballista/rust/client/src/prelude.rs +++ b/ballista/rust/client/src/prelude.rs @@ -19,6 +19,7 @@ pub use crate::context::BallistaContext; pub use ballista_core::config::BallistaConfig; +pub use ballista_core::config::BALLISTA_DEFAULT_BATCH_SIZE; pub use ballista_core::config::BALLISTA_DEFAULT_SHUFFLE_PARTITIONS; pub use ballista_core::error::{BallistaError, Result}; diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 4b493105d933..015618fe876a 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -722,6 +722,8 @@ message TaskDefinition { bytes plan = 2; // Output partition for shuffle writer PhysicalHashRepartition output_partitioning = 3; + string session_id = 4; + repeated KeyValuePair props = 5; } message PollWorkResult { @@ -767,7 +769,10 @@ message ExecuteQueryParams { bytes logical_plan = 1; string sql = 2; } - repeated KeyValuePair settings = 3; + oneof optional_session_id { + string session_id = 3; + } + repeated 
KeyValuePair settings = 4; } message ExecuteSqlParams { @@ -776,6 +781,7 @@ message ExecuteSqlParams { message ExecuteQueryResult { string job_id = 1; + string session_id = 2; } message GetJobStatusParams { diff --git a/ballista/rust/core/src/config.rs b/ballista/rust/core/src/config.rs index 8cdaf1f82952..fffe0ead3d75 100644 --- a/ballista/rust/core/src/config.rs +++ b/ballista/rust/core/src/config.rs @@ -28,6 +28,11 @@ use crate::error::{BallistaError, Result}; use datafusion::arrow::datatypes::DataType; pub const BALLISTA_DEFAULT_SHUFFLE_PARTITIONS: &str = "ballista.shuffle.partitions"; +pub const BALLISTA_DEFAULT_BATCH_SIZE: &str = "ballista.batch.size"; +pub const BALLISTA_REPARTITION_JOINS: &str = "ballista.repartition.joins"; +pub const BALLISTA_REPARTITION_AGGREGATIONS: &str = "ballista.repartition.aggregations"; +pub const BALLISTA_REPARTITION_WINDOWS: &str = "ballista.repartition.windows"; +pub const BALLISTA_PARQUET_PRUNING: &str = "ballista.parquet.pruning"; pub const BALLISTA_WITH_INFORMATION_SCHEMA: &str = "ballista.with_information_schema"; pub type ParseResult = result::Result; @@ -148,6 +153,21 @@ impl BallistaConfig { ConfigEntry::new(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS.to_string(), "Sets the default number of partitions to create when repartitioning query stages".to_string(), DataType::UInt16, Some("2".to_string())), + ConfigEntry::new(BALLISTA_DEFAULT_BATCH_SIZE.to_string(), + "Sets the default batch size".to_string(), + DataType::UInt16, Some("8192".to_string())), + ConfigEntry::new(BALLISTA_REPARTITION_JOINS.to_string(), + "Configuration for repartition joins".to_string(), + DataType::Boolean, Some("true".to_string())), + ConfigEntry::new(BALLISTA_REPARTITION_AGGREGATIONS.to_string(), + "Configuration for repartition aggregations".to_string(), + DataType::Boolean,Some("true".to_string())), + ConfigEntry::new(BALLISTA_REPARTITION_WINDOWS.to_string(), + "Configuration for repartition windows".to_string(), + DataType::Boolean,Some("true".to_string())), + ConfigEntry::new(BALLISTA_PARQUET_PRUNING.to_string(), + "Configuration for parquet prune".to_string(), + DataType::Boolean,Some("true".to_string())), ConfigEntry::new(BALLISTA_WITH_INFORMATION_SCHEMA.to_string(), "Sets whether enable information_schema".to_string(), DataType::Boolean,Some("false".to_string())), @@ -166,6 +186,26 @@ impl BallistaConfig { self.get_usize_setting(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS) } + pub fn default_batch_size(&self) -> usize { + self.get_usize_setting(BALLISTA_DEFAULT_BATCH_SIZE) + } + + pub fn repartition_joins(&self) -> bool { + self.get_bool_setting(BALLISTA_REPARTITION_JOINS) + } + + pub fn repartition_aggregations(&self) -> bool { + self.get_bool_setting(BALLISTA_REPARTITION_AGGREGATIONS) + } + + pub fn repartition_windows(&self) -> bool { + self.get_bool_setting(BALLISTA_REPARTITION_WINDOWS) + } + + pub fn parquet_pruning(&self) -> bool { + self.get_bool_setting(BALLISTA_PARQUET_PRUNING) + } + pub fn default_with_information_schema(&self) -> bool { self.get_bool_setting(BALLISTA_WITH_INFORMATION_SCHEMA) } diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs index d4daeb26c42e..46a168f17273 100644 --- a/ballista/rust/core/src/execution_plans/distributed_query.rs +++ b/ballista/rust/core/src/execution_plans/distributed_query.rs @@ -40,6 +40,7 @@ use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; +use 
crate::serde::protobuf::execute_query_params::OptionalSessionId; use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec}; use async_trait::async_trait; use datafusion::execution::context::TaskContext; @@ -63,16 +64,24 @@ pub struct DistributedQueryExec { extension_codec: Arc, /// Phantom data for serializable plan message plan_repr: PhantomData, + /// Session id + session_id: String, } impl DistributedQueryExec { - pub fn new(scheduler_url: String, config: BallistaConfig, plan: LogicalPlan) -> Self { + pub fn new( + scheduler_url: String, + config: BallistaConfig, + plan: LogicalPlan, + session_id: String, + ) -> Self { Self { scheduler_url, config, plan, extension_codec: Arc::new(DefaultLogicalExtensionCodec {}), plan_repr: PhantomData, + session_id, } } @@ -81,6 +90,7 @@ impl DistributedQueryExec { config: BallistaConfig, plan: LogicalPlan, extension_codec: Arc, + session_id: String, ) -> Self { Self { scheduler_url, @@ -88,6 +98,7 @@ impl DistributedQueryExec { plan, extension_codec, plan_repr: PhantomData, + session_id, } } @@ -97,6 +108,7 @@ impl DistributedQueryExec { plan: LogicalPlan, extension_codec: Arc, plan_repr: PhantomData, + session_id: String, ) -> Self { Self { scheduler_url, @@ -104,6 +116,7 @@ impl DistributedQueryExec { plan, extension_codec, plan_repr, + session_id, } } } @@ -144,6 +157,7 @@ impl ExecutionPlan for DistributedQueryExec { plan: self.plan.clone(), extension_codec: self.extension_codec.clone(), plan_repr: self.plan_repr, + session_id: self.session_id.clone(), })) } @@ -155,6 +169,7 @@ impl ExecutionPlan for DistributedQueryExec { assert_eq!(0, partition); info!("Connecting to Ballista scheduler at {}", self.scheduler_url); + // TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again let mut scheduler = SchedulerGrpcClient::connect(self.scheduler_url.clone()) .await @@ -176,7 +191,7 @@ impl ExecutionPlan for DistributedQueryExec { DataFusionError::Execution(format!("failed to encode logical plan: {:?}", e)) })?; - let job_id = scheduler + let query_result = scheduler .execute_query(ExecuteQueryParams { query: Some(Query::LogicalPlan(buf)), settings: self @@ -188,12 +203,22 @@ impl ExecutionPlan for DistributedQueryExec { value: v.to_owned(), }) .collect::>(), + optional_session_id: Some(OptionalSessionId::SessionId( + self.session_id.clone(), + )), }) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))? - .into_inner() - .job_id; + .into_inner(); + + let response_session_id = query_result.session_id; + assert_eq!( + self.session_id.clone(), + response_session_id, + "Session id inconsistent between Client and Server side in DistributedQueryExec." 
+ ); + let job_id = query_result.job_id; let mut prev_status: Option = None; loop { diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 3f622e821690..3e36c00eedf8 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -111,7 +111,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Projection(projection) => { let input: LogicalPlan = - into_logical_plan!(projection.input, &ctx, extension_codec)?; + into_logical_plan!(projection.input, ctx, extension_codec)?; let x: Vec = projection .expr .iter() @@ -131,7 +131,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Selection(selection) => { let input: LogicalPlan = - into_logical_plan!(selection.input, &ctx, extension_codec)?; + into_logical_plan!(selection.input, ctx, extension_codec)?; let expr: Expr = selection .expr .as_ref() @@ -148,7 +148,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Window(window) => { let input: LogicalPlan = - into_logical_plan!(window.input, &ctx, extension_codec)?; + into_logical_plan!(window.input, ctx, extension_codec)?; let window_expr = window .window_expr .iter() @@ -161,7 +161,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Aggregate(aggregate) => { let input: LogicalPlan = - into_logical_plan!(aggregate.input, &ctx, extension_codec)?; + into_logical_plan!(aggregate.input, ctx, extension_codec)?; let group_expr = aggregate .group_expr .iter() @@ -261,7 +261,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Sort(sort) => { let input: LogicalPlan = - into_logical_plan!(sort.input, &ctx, extension_codec)?; + into_logical_plan!(sort.input, ctx, extension_codec)?; let sort_expr: Vec = sort .expr .iter() @@ -275,7 +275,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Repartition(repartition) => { use datafusion::logical_plan::Partitioning; let input: LogicalPlan = - into_logical_plan!(repartition.input, &ctx, extension_codec)?; + into_logical_plan!(repartition.input, ctx, extension_codec)?; use protobuf::repartition_node::PartitionMethod; let pb_partition_method = repartition.partition_method.clone().ok_or_else(|| { BallistaError::General(String::from( @@ -342,7 +342,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Analyze(analyze) => { let input: LogicalPlan = - into_logical_plan!(analyze.input, &ctx, extension_codec)?; + into_logical_plan!(analyze.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input) .explain(analyze.verbose, true)? .build() @@ -350,7 +350,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Explain(explain) => { let input: LogicalPlan = - into_logical_plan!(explain.input, &ctx, extension_codec)?; + into_logical_plan!(explain.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input) .explain(explain.verbose, false)? .build() @@ -358,7 +358,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Limit(limit) => { let input: LogicalPlan = - into_logical_plan!(limit.input, &ctx, extension_codec)?; + into_logical_plan!(limit.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input) .limit(limit.limit as usize)? 
.build() @@ -388,17 +388,17 @@ impl AsLogicalPlan for LogicalPlanNode { let builder = LogicalPlanBuilder::from(into_logical_plan!( join.left, - &ctx, + ctx, extension_codec )?); let builder = match join_constraint.into() { JoinConstraint::On => builder.join( - &into_logical_plan!(join.right, &ctx, extension_codec)?, + &into_logical_plan!(join.right, ctx, extension_codec)?, join_type.into(), (left_keys, right_keys), )?, JoinConstraint::Using => builder.join_using( - &into_logical_plan!(join.right, &ctx, extension_codec)?, + &into_logical_plan!(join.right, ctx, extension_codec)?, join_type.into(), left_keys, )?, @@ -426,8 +426,8 @@ impl AsLogicalPlan for LogicalPlanNode { builder.build().map_err(|e| e.into()) } LogicalPlanType::CrossJoin(crossjoin) => { - let left = into_logical_plan!(crossjoin.left, &ctx, extension_codec)?; - let right = into_logical_plan!(crossjoin.right, &ctx, extension_codec)?; + let left = into_logical_plan!(crossjoin.left, ctx, extension_codec)?; + let right = into_logical_plan!(crossjoin.right, ctx, extension_codec)?; LogicalPlanBuilder::from(left) .cross_join(&right)? @@ -905,7 +905,7 @@ impl AsLogicalPlan for LogicalPlanNode { macro_rules! into_logical_plan { ($PB:expr, $CTX:expr, $CODEC:expr) => {{ if let Some(field) = $PB.as_ref() { - field.as_ref().try_into_logical_plan(&$CTX, $CODEC) + field.as_ref().try_into_logical_plan($CTX, $CODEC) } else { Err(proto_error("Missing required field in protobuf")) } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 1068d2cd90b1..5f3d054e2034 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -24,10 +24,13 @@ use std::marker::PhantomData; use std::sync::Arc; use std::{convert::TryInto, io::Cursor}; -use datafusion::logical_plan::{JoinConstraint, JoinType, LogicalPlan, Operator}; +use datafusion::logical_plan::{ + FunctionRegistry, JoinConstraint, JoinType, LogicalPlan, Operator, +}; use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::plan::Extension; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; @@ -132,7 +135,8 @@ pub trait AsExecutionPlan: Debug + Send + Sync + Clone { fn try_into_physical_plan( &self, - ctx: &SessionContext, + registry: &dyn FunctionRegistry, + runtime: &RuntimeEnv, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result, BallistaError>; @@ -149,7 +153,7 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { &self, buf: &[u8], inputs: &[Arc], - ctx: &SessionContext, + registry: &dyn FunctionRegistry, ) -> Result, BallistaError>; fn try_encode( @@ -167,7 +171,7 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { &self, _buf: &[u8], _inputs: &[Arc], - _ctx: &SessionContext, + _registry: &dyn FunctionRegistry, ) -> Result, BallistaError> { Err(BallistaError::NotImplemented( "PhysicalExtensionCodec is not provided".to_string(), @@ -353,7 +357,8 @@ mod tests { use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::logical_plan::plan::Extension; use datafusion::logical_plan::{ - col, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, + col, DFSchemaRef, Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilder, + UserDefinedLogicalNode, }; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::planner::{DefaultPhysicalPlanner, ExtensionPlanner}; @@ -369,6 +374,7 @@ mod 
tests { use std::convert::TryInto; use std::fmt; use std::fmt::{Debug, Formatter}; + use std::ops::Deref; use std::sync::Arc; pub mod proto { @@ -660,7 +666,7 @@ mod tests { &self, buf: &[u8], inputs: &[Arc], - _ctx: &SessionContext, + _registry: &dyn FunctionRegistry, ) -> Result, BallistaError> { if let Some((input, _)) = inputs.split_first() { let proto = TopKExecProto::decode(buf).map_err(|e| { @@ -701,8 +707,9 @@ mod tests { async fn test_extension_plan() -> crate::error::Result<()> { let store = Arc::new(LocalFileSystem {}); let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); - let session_state = SessionState::with_config(SessionConfig::new(), runtime) - .with_query_planner(Arc::new(TopKQueryPlanner {})); + let session_state = + SessionState::with_config_rt(SessionConfig::new(), runtime.clone()) + .with_query_planner(Arc::new(TopKQueryPlanner {})); let ctx = SessionContext::with_state(session_state); @@ -736,7 +743,8 @@ mod tests { topk_exec.clone(), &extension_codec, )?; - let physical_round_trip = proto.try_into_physical_plan(&ctx, &extension_codec)?; + let physical_round_trip = + proto.try_into_physical_plan(&ctx, runtime.deref(), &extension_codec)?; assert_eq!( format!("{:?}", topk_exec), diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index d7a84959c87c..39463c791be0 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -35,7 +35,9 @@ use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; use datafusion::datafusion_storage::object_store::local::LocalFileSystem; use datafusion::datasource::listing::PartitionedFile; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::window_frames::WindowFrame; +use datafusion::logical_plan::FunctionRegistry; use datafusion::physical_plan::aggregates::create_aggregate_expr; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -57,7 +59,6 @@ use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec}; use datafusion::physical_plan::{ AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, }; -use datafusion::prelude::SessionContext; use datafusion_proto::from_proto::parse_expr; use prost::bytes::BufMut; use prost::Message; @@ -89,7 +90,8 @@ impl AsExecutionPlan for PhysicalPlanNode { fn try_into_physical_plan( &self, - ctx: &SessionContext, + registry: &dyn FunctionRegistry, + runtime: &RuntimeEnv, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result, BallistaError> { let plan = self.physical_plan_type.as_ref().ok_or_else(|| { @@ -100,8 +102,12 @@ impl AsExecutionPlan for PhysicalPlanNode { })?; match plan { PhysicalPlanType::Projection(projection) => { - let input: Arc = - into_physical_plan!(projection.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + projection.input, + registry, + runtime, + extension_codec + )?; let exprs = projection .expr .iter() @@ -112,8 +118,12 @@ impl AsExecutionPlan for PhysicalPlanNode { Ok(Arc::new(ProjectionExec::try_new(exprs, input)?)) } PhysicalPlanType::Filter(filter) => { - let input: Arc = - into_physical_plan!(filter.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + filter.input, + registry, + runtime, + extension_codec + )?; let predicate = filter .expr .as_ref() @@ -127,7 +137,7 @@ impl AsExecutionPlan for 
PhysicalPlanNode { Ok(Arc::new(FilterExec::try_new(predicate, input)?)) } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( - decode_scan_config(scan.base_conf.as_ref().unwrap(), ctx)?, + decode_scan_config(scan.base_conf.as_ref().unwrap(), runtime)?, scan.has_header, str_to_byte(&scan.delimiter)?, ))), @@ -135,19 +145,23 @@ impl AsExecutionPlan for PhysicalPlanNode { let predicate = scan .pruning_predicate .as_ref() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .transpose()?; Ok(Arc::new(ParquetExec::new( - decode_scan_config(scan.base_conf.as_ref().unwrap(), ctx)?, + decode_scan_config(scan.base_conf.as_ref().unwrap(), runtime)?, predicate, ))) } PhysicalPlanType::AvroScan(scan) => Ok(Arc::new(AvroExec::new( - decode_scan_config(scan.base_conf.as_ref().unwrap(), ctx)?, + decode_scan_config(scan.base_conf.as_ref().unwrap(), runtime)?, ))), PhysicalPlanType::CoalesceBatches(coalesce_batches) => { - let input: Arc = - into_physical_plan!(coalesce_batches.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + coalesce_batches.input, + registry, + runtime, + extension_codec + )?; Ok(Arc::new(CoalesceBatchesExec::new( input, coalesce_batches.target_batch_size as usize, @@ -155,12 +169,16 @@ impl AsExecutionPlan for PhysicalPlanNode { } PhysicalPlanType::Merge(merge) => { let input: Arc = - into_physical_plan!(merge.input, ctx, extension_codec)?; + into_physical_plan!(merge.input, registry, runtime, extension_codec)?; Ok(Arc::new(CoalescePartitionsExec::new(input))) } PhysicalPlanType::Repartition(repart) => { - let input: Arc = - into_physical_plan!(repart.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + repart.input, + registry, + runtime, + extension_codec + )?; match repart.partition_method { Some(PartitionMethod::Hash(ref hash_part)) => { let expr = hash_part @@ -200,17 +218,21 @@ impl AsExecutionPlan for PhysicalPlanNode { } PhysicalPlanType::GlobalLimit(limit) => { let input: Arc = - into_physical_plan!(limit.input, ctx, extension_codec)?; + into_physical_plan!(limit.input, registry, runtime, extension_codec)?; Ok(Arc::new(GlobalLimitExec::new(input, limit.limit as usize))) } PhysicalPlanType::LocalLimit(limit) => { let input: Arc = - into_physical_plan!(limit.input, ctx, extension_codec)?; + into_physical_plan!(limit.input, registry, runtime, extension_codec)?; Ok(Arc::new(LocalLimitExec::new(input, limit.limit as usize))) } PhysicalPlanType::Window(window_agg) => { - let input: Arc = - into_physical_plan!(window_agg.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + window_agg.input, + registry, + runtime, + extension_codec + )?; let input_schema = window_agg .input_schema .as_ref() @@ -256,8 +278,12 @@ impl AsExecutionPlan for PhysicalPlanNode { )?)) } PhysicalPlanType::HashAggregate(hash_agg) => { - let input: Arc = - into_physical_plan!(hash_agg.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + hash_agg.input, + registry, + runtime, + extension_codec + )?; let mode = protobuf::AggregateMode::from_i32(hash_agg.mode).ok_or_else(|| { proto_error(format!( "Received a HashAggregateNode message with unknown AggregateMode {}", @@ -341,10 +367,18 @@ impl AsExecutionPlan for PhysicalPlanNode { )?)) } PhysicalPlanType::HashJoin(hashjoin) => { - let left: Arc = - into_physical_plan!(hashjoin.left, ctx, extension_codec)?; - let right: Arc = - into_physical_plan!(hashjoin.right, ctx, extension_codec)?; + let left: Arc = into_physical_plan!( + hashjoin.left, + 
registry, + runtime, + extension_codec + )?; + let right: Arc = into_physical_plan!( + hashjoin.right, + registry, + runtime, + extension_codec + )?; let on: Vec<(Column, Column)> = hashjoin .on .iter() @@ -386,20 +420,36 @@ impl AsExecutionPlan for PhysicalPlanNode { PhysicalPlanType::Union(union) => { let mut inputs: Vec> = vec![]; for input in &union.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(input.try_into_physical_plan( + registry, + runtime, + extension_codec, + )?); } Ok(Arc::new(UnionExec::new(inputs))) } PhysicalPlanType::CrossJoin(crossjoin) => { - let left: Arc = - into_physical_plan!(crossjoin.left, ctx, extension_codec)?; - let right: Arc = - into_physical_plan!(crossjoin.right, ctx, extension_codec)?; + let left: Arc = into_physical_plan!( + crossjoin.left, + registry, + runtime, + extension_codec + )?; + let right: Arc = into_physical_plan!( + crossjoin.right, + registry, + runtime, + extension_codec + )?; Ok(Arc::new(CrossJoinExec::try_new(left, right)?)) } PhysicalPlanType::ShuffleWriter(shuffle_writer) => { - let input: Arc = - into_physical_plan!(shuffle_writer.input, ctx, extension_codec)?; + let input: Arc = into_physical_plan!( + shuffle_writer.input, + registry, + runtime, + extension_codec + )?; let output_partitioning = parse_protobuf_hash_partitioning( shuffle_writer.output_partitioning.as_ref(), @@ -435,7 +485,7 @@ impl AsExecutionPlan for PhysicalPlanNode { } PhysicalPlanType::Sort(sort) => { let input: Arc = - into_physical_plan!(sort.input, ctx, extension_codec)?; + into_physical_plan!(sort.input, registry, runtime, extension_codec)?; let exprs = sort .expr .iter() @@ -489,13 +539,13 @@ impl AsExecutionPlan for PhysicalPlanNode { let inputs: Vec> = extension .inputs .iter() - .map(|i| i.try_into_physical_plan(ctx, extension_codec)) + .map(|i| i.try_into_physical_plan(registry, runtime, extension_codec)) .collect::>()?; let extension_node = extension_codec.try_decode( extension.node.as_slice(), &inputs, - ctx, + registry, )?; Ok(extension_node) @@ -908,7 +958,7 @@ impl AsExecutionPlan for PhysicalPlanNode { fn decode_scan_config( proto: &protobuf::FileScanExecConf, - ctx: &SessionContext, + runtime: &RuntimeEnv, ) -> Result { let schema = Arc::new(convert_required!(proto.schema)?); let projection = proto @@ -930,7 +980,7 @@ fn decode_scan_config( .collect::, _>>()?; let object_store = if let Some(file) = file_groups.get(0).and_then(|h| h.get(0)) { - ctx.runtime_env().object_store(file.file_meta.path())?.0 + runtime.object_store(file.file_meta.path())?.0 } else { Arc::new(LocalFileSystem {}) }; @@ -948,9 +998,11 @@ fn decode_scan_config( #[macro_export] macro_rules! into_physical_plan { - ($PB:expr, $CTX:expr, $CODEC:expr) => {{ + ($PB:expr, $REG:expr, $RUNTIME:expr, $CODEC:expr) => {{ if let Some(field) = $PB.as_ref() { - field.as_ref().try_into_physical_plan(&$CTX, $CODEC) + field + .as_ref() + .try_into_physical_plan($REG, $RUNTIME, $CODEC) } else { Err(proto_error("Missing required field in protobuf")) } @@ -959,6 +1011,7 @@ macro_rules! 
into_physical_plan { #[cfg(test)] mod roundtrip_tests { + use std::ops::Deref; use std::sync::Arc; use crate::serde::{AsExecutionPlan, BallistaCodec}; @@ -1001,8 +1054,13 @@ mod roundtrip_tests { codec.physical_extension_codec(), ) .expect("to proto"); + let runtime = ctx.runtime_env(); let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx, codec.physical_extension_codec()) + .try_into_physical_plan( + &ctx, + runtime.deref(), + codec.physical_extension_codec(), + ) .expect("from proto"); assert_eq!( format!("{:?}", exec_plan), diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index b769cd60b10b..6670ab5cedd8 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -225,25 +225,26 @@ fn build_exec_plan_diagram( Ok(node_id) } -/// Create a DataFusion context that uses the BallistaQueryPlanner to send logical plans +/// Create a client DataFusion context that uses the BallistaQueryPlanner to send logical plans /// to a Ballista scheduler pub fn create_df_ctx_with_ballista_query_planner( - scheduler_host: &str, - scheduler_port: u16, + scheduler_url: String, + session_id: String, config: &BallistaConfig, ) -> SessionContext { - let scheduler_url = format!("http://{}:{}", scheduler_host, scheduler_port); let planner: Arc> = Arc::new(BallistaQueryPlanner::new(scheduler_url, config.clone())); let session_config = SessionConfig::new() .with_target_partitions(config.default_shuffle_partitions()) .with_information_schema(true); - let session_state = SessionState::with_config( + let mut session_state = SessionState::with_config_rt( session_config, Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()), ) .with_query_planner(planner); + session_state.session_id = session_id; + // the SessionContext created here is the client side context, but the session_id is from server side. SessionContext::with_state(session_state) } @@ -297,7 +298,7 @@ impl QueryPlanner for BallistaQueryPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - _session_state: &SessionState, + session_state: &SessionState, ) -> std::result::Result, DataFusionError> { match logical_plan { LogicalPlan::CreateExternalTable(_) => { @@ -310,6 +311,7 @@ impl QueryPlanner for BallistaQueryPlanner { logical_plan.clone(), self.extension_codec.clone(), self.plan_repr, + session_state.session_id.clone(), ))), } } diff --git a/ballista/rust/executor/src/execution_loop.rs b/ballista/rust/executor/src/execution_loop.rs index 833c06f4617a..bb4edb6b969d 100644 --- a/ballista/rust/executor/src/execution_loop.rs +++ b/ballista/rust/executor/src/execution_loop.rs @@ -16,6 +16,7 @@ // under the License. use std::collections::HashMap; +use std::ops::Deref; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc::{Receiver, Sender, TryRecvError}; use std::{sync::Arc, time::Duration}; @@ -126,24 +127,36 @@ async fn run_received_tasks = U::try_decode(task.plan.as_slice()).and_then(|proto| { proto.try_into_physical_plan( - executor.ctx.as_ref(), + task_context.deref(), + runtime.deref(), codec.physical_extension_codec(), ) })?; diff --git a/ballista/rust/executor/src/executor.rs b/ballista/rust/executor/src/executor.rs index d6579e2f9ba9..e64e30861d98 100644 --- a/ballista/rust/executor/src/executor.rs +++ b/ballista/rust/executor/src/executor.rs @@ -17,6 +17,7 @@ //! 
Ballista executor logic +use std::collections::HashMap; use std::sync::Arc; use ballista_core::error::BallistaError; @@ -25,9 +26,11 @@ use ballista_core::serde::protobuf; use ballista_core::serde::protobuf::ExecutorRegistration; use datafusion::error::DataFusionError; use datafusion::execution::context::TaskContext; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::udaf::AggregateUDF; +use datafusion::physical_plan::udf::ScalarUDF; use datafusion::physical_plan::{ExecutionPlan, Partitioning}; -use datafusion::prelude::SessionContext; /// Ballista executor pub struct Executor { @@ -37,8 +40,14 @@ pub struct Executor { /// Directory for storing partial results pub work_dir: String, - /// DataFusion session context - pub ctx: Arc, + /// Scalar functions that are registered in the Executor + pub scalar_functions: HashMap>, + + /// Aggregate functions registered in the Executor + pub aggregate_functions: HashMap>, + + /// Runtime environment for Executor + pub runtime: Arc, } impl Executor { @@ -46,12 +55,15 @@ impl Executor { pub fn new( metadata: ExecutorRegistration, work_dir: &str, - ctx: Arc, + runtime: Arc, ) -> Self { Self { metadata, work_dir: work_dir.to_owned(), - ctx, + // TODO add logic to dynamically load UDF/UDAFs libs from files + scalar_functions: HashMap::new(), + aggregate_functions: HashMap::new(), + runtime, } } } diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs index 1ad5f05e7ff3..81910f264bc4 100644 --- a/ballista/rust/executor/src/executor_server.rs +++ b/ballista/rust/executor/src/executor_server.rs @@ -16,6 +16,7 @@ // under the License. use std::collections::HashMap; +use std::ops::Deref; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::mpsc; @@ -179,18 +180,29 @@ impl ExecutorServer ExecutorServer = U::try_decode(encoded_plan).and_then(|proto| { proto.try_into_physical_plan( - self.executor.ctx.as_ref(), + task_context.deref(), + runtime.deref(), self.codec.physical_extension_codec(), ) })?; diff --git a/ballista/rust/executor/src/main.rs b/ballista/rust/executor/src/main.rs index c88dbf1fc041..66d9bd71af1c 100644 --- a/ballista/rust/executor/src/main.rs +++ b/ballista/rust/executor/src/main.rs @@ -32,6 +32,7 @@ use tonic::transport::Server; use uuid::Uuid; use ballista_core::config::TaskSchedulingPolicy; +use ballista_core::error::BallistaError; use ballista_core::serde::protobuf::{ executor_registration, scheduler_grpc_client::SchedulerGrpcClient, ExecutorRegistration, LogicalPlanNode, PhysicalPlanNode, @@ -42,7 +43,7 @@ use ballista_core::{print_version, BALLISTA_VERSION}; use ballista_executor::executor::Executor; use ballista_executor::flight_service::BallistaFlightService; use config::prelude::*; -use datafusion::prelude::SessionContext; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; #[macro_use] extern crate configure_me; @@ -111,11 +112,13 @@ async fn main() -> Result<()> { .into(), ), }; - let executor = Arc::new(Executor::new( - executor_meta, - &work_dir, - Arc::new(SessionContext::new()), - )); + + let config = RuntimeConfig::new().with_temp_file_path(work_dir.clone()); + let runtime = Arc::new(RuntimeEnv::new(config).map_err(|_| { + BallistaError::Internal("Failed to init Executor RuntimeEnv".to_owned()) + })?); + + let executor = Arc::new(Executor::new(executor_meta, &work_dir, runtime)); let scheduler = 
SchedulerGrpcClient::connect(scheduler_url) .await diff --git a/ballista/rust/executor/src/standalone.rs b/ballista/rust/executor/src/standalone.rs index dc55e0180fcf..edbb857cca3a 100644 --- a/ballista/rust/executor/src/standalone.rs +++ b/ballista/rust/executor/src/standalone.rs @@ -27,7 +27,7 @@ use ballista_core::{ serde::protobuf::{scheduler_grpc_client::SchedulerGrpcClient, ExecutorRegistration}, BALLISTA_VERSION, }; -use datafusion::prelude::SessionContext; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use log::info; use tempfile::TempDir; use tokio::net::TcpListener; @@ -71,8 +71,14 @@ pub async fn new_standalone_executor< .into_string() .unwrap(); info!("work_dir: {}", work_dir); - let ctx = Arc::new(SessionContext::new()); - let executor = Arc::new(Executor::new(executor_meta, &work_dir, ctx)); + + let config = RuntimeConfig::new().with_temp_file_path(work_dir.clone()); + + let executor = Arc::new(Executor::new( + executor_meta, + &work_dir, + Arc::new(RuntimeEnv::new(config).unwrap()), + )); let service = BallistaFlightService::new(executor.clone()); let server = FlightServiceServer::new(service); diff --git a/ballista/rust/scheduler/src/main.rs b/ballista/rust/scheduler/src/main.rs index e37f419148ee..4304821823e0 100644 --- a/ballista/rust/scheduler/src/main.rs +++ b/ballista/rust/scheduler/src/main.rs @@ -46,7 +46,6 @@ use ballista_scheduler::state::backend::{StateBackend, StateBackendClient}; use ballista_core::config::TaskSchedulingPolicy; use ballista_core::serde::BallistaCodec; use log::info; -use tokio::sync::RwLock; #[macro_use] extern crate configure_me; @@ -62,7 +61,7 @@ mod config { } use config::prelude::*; -use datafusion::prelude::SessionContext; +use datafusion::execution::context::default_session_builder; async fn start_server( config_backend: Arc, @@ -85,13 +84,12 @@ async fn start_server( config_backend.clone(), namespace.clone(), policy, - Arc::new(RwLock::new(SessionContext::new())), BallistaCodec::default(), + default_session_builder, ), _ => SchedulerServer::new( config_backend.clone(), namespace.clone(), - Arc::new(RwLock::new(SessionContext::new())), BallistaCodec::default(), ), }; diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs index 3580e81db8a2..d7ee22bbdf35 100644 --- a/ballista/rust/scheduler/src/planner.rs +++ b/ballista/rust/scheduler/src/planner.rs @@ -280,6 +280,7 @@ mod test { }; use datafusion::physical_plan::{displayable, ExecutionPlan}; use datafusion::prelude::SessionContext; + use std::ops::Deref; use ballista_core::serde::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use std::sync::Arc; @@ -293,7 +294,7 @@ mod test { #[tokio::test] async fn distributed_hash_aggregate_plan() -> Result<(), BallistaError> { - let mut ctx = datafusion_test_context("testdata").await?; + let ctx = datafusion_test_context("testdata").await?; // simplified form of TPC-H query 1 let df = ctx @@ -380,7 +381,7 @@ mod test { #[tokio::test] async fn distributed_join_plan() -> Result<(), BallistaError> { - let mut ctx = datafusion_test_context("testdata").await?; + let ctx = datafusion_test_context("testdata").await?; // simplified form of TPC-H query 12 let df = ctx @@ -555,7 +556,7 @@ order by #[tokio::test] async fn roundtrip_serde_hash_aggregate() -> Result<(), BallistaError> { - let mut ctx = datafusion_test_context("testdata").await?; + let ctx = datafusion_test_context("testdata").await?; // simplified form of TPC-H query 1 let df = ctx @@ -602,8 +603,12 @@ order by plan.clone(), 
codec.physical_extension_codec(), )?; - let result_exec_plan: Arc = - (&proto).try_into_physical_plan(&ctx, codec.physical_extension_codec())?; + let runtime = ctx.runtime_env(); + let result_exec_plan: Arc = (&proto).try_into_physical_plan( + &ctx, + runtime.deref(), + codec.physical_extension_codec(), + )?; Ok(result_exec_plan) } } diff --git a/ballista/rust/scheduler/src/scheduler_server/grpc.rs b/ballista/rust/scheduler/src/scheduler_server/grpc.rs index a421bbdbb7ba..eabcb503e9f0 100644 --- a/ballista/rust/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/rust/scheduler/src/scheduler_server/grpc.rs @@ -16,9 +16,9 @@ // under the License. use anyhow::Context; -use ballista_core::config::TaskSchedulingPolicy; +use ballista_core::config::{BallistaConfig, TaskSchedulingPolicy}; use ballista_core::error::BallistaError; -use ballista_core::serde::protobuf::execute_query_params::Query; +use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId, Query}; use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient; use ballista_core::serde::protobuf::executor_registration::OptionalHost; use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc; @@ -40,13 +40,16 @@ use futures::StreamExt; use log::{debug, error, info, trace, warn}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use std::convert::TryInto; +use std::ops::Deref; use std::sync::Arc; use std::time::Instant; use std::time::{SystemTime, UNIX_EPOCH}; use tonic::{Request, Response, Status}; use crate::scheduler_server::event::QueryStageSchedulerEvent; -use crate::scheduler_server::SchedulerServer; +use crate::scheduler_server::{ + create_datafusion_context, update_datafusion_context, SchedulerServer, +}; use crate::state::task_scheduler::TaskScheduler; #[tonic::async_trait] @@ -321,31 +324,62 @@ impl SchedulerGrpc &self, request: Request, ) -> std::result::Result, tonic::Status> { + let query_params = request.into_inner(); if let ExecuteQueryParams { query: Some(query), - settings: _, - } = request.into_inner() + settings, + optional_session_id, + } = query_params { - let plan = match query { - Query::LogicalPlan(message) => { - let ctx = self.ctx.read().await; - T::try_decode(message.as_slice()) - .and_then(|m| { - m.try_into_logical_plan( - &ctx, - self.codec.logical_extension_codec(), - ) - }) - .map_err(|e| { - let msg = - format!("Could not parse logical plan protobuf: {}", e); - error!("{}", msg); - tonic::Status::internal(msg) - })? 
+ // parse config + let mut config_builder = BallistaConfig::builder(); + for kv_pair in &settings { + config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); + } + let config = config_builder.build().map_err(|e| { + let msg = format!("Could not parse configs: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + + let df_session = match optional_session_id { + Some(OptionalSessionId::SessionId(session_id)) => { + let session_ctx = self + .state + .session_registry() + .lookup_session(session_id.as_str()) + .await + .expect( + "SessionContext does not exist in SessionContextRegistry.", + ); + update_datafusion_context(session_ctx, &config) } + _ => { + let df_session = + create_datafusion_context(&config, self.session_builder); + self.state + .session_registry() + .register_session(df_session.clone()) + .await; + df_session + } + }; + + let plan = match query { + Query::LogicalPlan(message) => T::try_decode(message.as_slice()) + .and_then(|m| { + m.try_into_logical_plan( + df_session.deref(), + self.codec.logical_extension_codec(), + ) + }) + .map_err(|e| { + let msg = format!("Could not parse logical plan protobuf: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?, Query::Sql(sql) => { - let mut ctx = self.ctx.write().await; - let df = ctx.sql(&sql).await.map_err(|e| { + let df = df_session.sql(&sql).await.map_err(|e| { let msg = format!("Error parsing SQL: {}", e); error!("{}", msg); tonic::Status::internal(msg) @@ -358,6 +392,7 @@ impl SchedulerGrpc // Generate job id. // TODO Maybe the format will be changed in the future let job_id = generate_job_id(); + let session_id = df_session.session_id(); let state = self.state.clone(); let query_stage_event_sender = self.query_stage_event_loop.get_sender().map_err(|e| { @@ -380,8 +415,18 @@ impl SchedulerGrpc tonic::Status::internal(format!("Could not save job metadata: {}", e)) })?; + state + .save_job_session(&job_id, &session_id) + .await + .map_err(|e| { + tonic::Status::internal(format!( + "Could not save job session mapping: {}", + e + )) + })?; + let job_id_spawn = job_id.clone(); - let ctx = self.ctx.read().await.clone(); + let ctx = df_session.clone(); tokio::spawn(async move { if let Err(e) = async { // create physical plan @@ -445,7 +490,32 @@ impl SchedulerGrpc } }); - Ok(Response::new(ExecuteQueryResult { job_id })) + Ok(Response::new(ExecuteQueryResult { job_id, session_id })) + } else if let ExecuteQueryParams { + query: None, + settings, + optional_session_id: None, + } = query_params + { + // parse config for new session + let mut config_builder = BallistaConfig::builder(); + for kv_pair in &settings { + config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); + } + let config = config_builder.build().map_err(|e| { + let msg = format!("Could not parse configs: {}", e); + error!("{}", msg); + tonic::Status::internal(msg) + })?; + let df_session = create_datafusion_context(&config, self.session_builder); + self.state + .session_registry() + .register_session(df_session.clone()) + .await; + Ok(Response::new(ExecuteQueryResult { + job_id: "NA".to_owned(), + session_id: df_session.session_id(), + })) } else { Err(tonic::Status::internal("Error parsing request")) } @@ -476,7 +546,6 @@ fn generate_job_id() -> String { #[cfg(all(test, feature = "sled"))] mod test { use std::sync::Arc; - use tokio::sync::RwLock; use tonic::Request; @@ -488,7 +557,6 @@ mod test { }; use ballista_core::serde::scheduler::ExecutorSpecification; use ballista_core::serde::BallistaCodec; - use 
datafusion::prelude::SessionContext; use super::{SchedulerGrpc, SchedulerServer}; @@ -500,7 +568,6 @@ mod test { SchedulerServer::new( state_storage.clone(), namespace.to_owned(), - Arc::new(RwLock::new(SessionContext::new())), BallistaCodec::default(), ); let exec_meta = ExecutorRegistration { @@ -528,8 +595,7 @@ mod test { namespace.to_string(), BallistaCodec::default(), ); - let ctx = scheduler.ctx.read().await; - state.init(&ctx).await?; + state.init().await?; // executor should be registered assert_eq!(state.get_executors_metadata().await.unwrap().len(), 1); @@ -551,8 +617,7 @@ mod test { namespace.to_string(), BallistaCodec::default(), ); - let ctx = scheduler.ctx.read().await; - state.init(&ctx).await?; + state.init().await?; // executor should be registered assert_eq!(state.get_executors_metadata().await.unwrap().len(), 1); Ok(()) diff --git a/ballista/rust/scheduler/src/scheduler_server/mod.rs b/ballista/rust/scheduler/src/scheduler_server/mod.rs index 37f95b3ad0cb..2f6220f33c7f 100644 --- a/ballista/rust/scheduler/src/scheduler_server/mod.rs +++ b/ballista/rust/scheduler/src/scheduler_server/mod.rs @@ -28,6 +28,7 @@ use ballista_core::event_loop::EventLoop; use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient; use ballista_core::serde::protobuf::TaskStatus; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; +use datafusion::execution::context::{default_session_builder, SessionState}; use datafusion::prelude::{SessionConfig, SessionContext}; use crate::scheduler_server::event::{QueryStageSchedulerEvent, SchedulerServerEvent}; @@ -49,6 +50,7 @@ mod grpc; mod query_stage_scheduler; type ExecutorsClient = Arc>>>; +type SessionBuilder = fn(SessionConfig) -> SessionState; #[derive(Clone)] pub struct SchedulerServer { @@ -58,23 +60,38 @@ pub struct SchedulerServer, event_loop: Option>, query_stage_event_loop: EventLoop, - ctx: Arc>, codec: BallistaCodec, + /// SessionState Builder + session_builder: SessionBuilder, } impl SchedulerServer { pub fn new( config: Arc, namespace: String, - ctx: Arc>, codec: BallistaCodec, ) -> Self { SchedulerServer::new_with_policy( config, namespace, TaskSchedulingPolicy::PullStaged, - ctx, codec, + default_session_builder, + ) + } + + pub fn new_with_builder( + config: Arc, + namespace: String, + codec: BallistaCodec, + session_builder: SessionBuilder, + ) -> Self { + SchedulerServer::new_with_policy( + config, + namespace, + TaskSchedulingPolicy::PullStaged, + codec, + session_builder, ) } @@ -82,8 +99,8 @@ impl SchedulerServer, namespace: String, policy: TaskSchedulingPolicy, - ctx: Arc>, codec: BallistaCodec, + session_builder: SessionBuilder, ) -> Self { let state = Arc::new(SchedulerState::new(config, namespace, codec.clone())); @@ -115,16 +132,15 @@ impl SchedulerServer Result<()> { { // initialize state - let ctx = self.ctx.read().await; - self.state.init(&ctx).await?; + self.state.init().await?; } { @@ -180,10 +196,19 @@ impl SchedulerServer SessionContext { - let config = - SessionConfig::new().with_target_partitions(config.default_shuffle_partitions()); - SessionContext::with_config(config) +pub fn create_datafusion_context( + config: &BallistaConfig, + session_builder: SessionBuilder, +) -> Arc { + let config = SessionConfig::new() + .with_target_partitions(config.default_shuffle_partitions()) + .with_batch_size(config.default_batch_size()) + .with_repartition_joins(config.repartition_joins()) + .with_repartition_aggregations(config.repartition_aggregations()) + 
.with_repartition_windows(config.repartition_windows()) + .with_parquet_pruning(config.parquet_pruning()); + let session_state = session_builder(config); + Arc::new(SessionContext::with_state(session_state)) } /// Update the existing DataFusion session context with Ballista Configuration @@ -191,18 +216,69 @@ pub fn update_datafusion_context( session_ctx: Arc<SessionContext>, config: &BallistaConfig, ) -> Arc<SessionContext> { - session_ctx.state.write().config.target_partitions = - config.default_shuffle_partitions(); + { + let mut mut_state = session_ctx.state.write(); + mut_state.config.target_partitions = config.default_shuffle_partitions(); + mut_state.config.batch_size = config.default_batch_size(); + mut_state.config.repartition_joins = config.repartition_joins(); + mut_state.config.repartition_aggregations = config.repartition_aggregations(); + mut_state.config.repartition_windows = config.repartition_windows(); + mut_state.config.parquet_pruning = config.parquet_pruning(); + } session_ctx } +/// A registry that holds all the DataFusion session contexts +pub struct SessionContextRegistry { + /// A map from session_id to SessionContext + pub running_sessions: RwLock<HashMap<String, Arc<SessionContext>>>, +} + +impl Default for SessionContextRegistry { + fn default() -> Self { + Self::new() + } +} + +impl SessionContextRegistry { + /// Create an empty registry that session contexts can be registered into. + /// Sessions are added via `register_session` and resolved via `lookup_session`. + pub fn new() -> Self { + Self { + running_sessions: RwLock::new(HashMap::new()), + } + } + + /// Adds a new session to this registry. + pub async fn register_session( + &self, + session_ctx: Arc<SessionContext>, + ) -> Option<Arc<SessionContext>> { + let session_id = session_ctx.session_id(); + let mut sessions = self.running_sessions.write().await; + sessions.insert(session_id, session_ctx) + } + + /// Look up a previously registered session context by its session id. + pub async fn lookup_session(&self, session_id: &str) -> Option<Arc<SessionContext>> { + let sessions = self.running_sessions.read().await; + sessions.get(session_id).cloned() + } + + /// Remove a session from this registry. 
+ pub async fn unregister_session( + &self, + session_id: &str, + ) -> Option> { + let mut sessions = self.running_sessions.write().await; + sessions.remove(session_id) + } +} #[cfg(all(test, feature = "sled"))] mod test { use std::sync::Arc; use std::time::{Duration, Instant}; - use tokio::sync::RwLock; - use ballista_core::config::TaskSchedulingPolicy; use ballista_core::error::{BallistaError, Result}; use ballista_core::execution_plans::ShuffleWriterExec; @@ -213,6 +289,7 @@ mod test { use ballista_core::serde::scheduler::ExecutorData; use ballista_core::serde::BallistaCodec; use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::execution::context::default_session_builder; use datafusion::logical_plan::{col, sum, LogicalPlan, LogicalPlanBuilder}; use datafusion::prelude::{SessionConfig, SessionContext}; @@ -250,9 +327,7 @@ mod test { plan_of_linear_stages: LogicalPlan, total_available_task_slots: usize, ) -> Result<()> { - let config = - SessionConfig::new().with_target_partitions(total_available_task_slots); - let scheduler = test_scheduler(policy, config).await?; + let scheduler = test_scheduler(policy).await?; if matches!(policy, TaskSchedulingPolicy::PushStaged) { let executors = test_executors(total_available_task_slots); for executor_data in executors { @@ -262,9 +337,10 @@ mod test { .save_executor_data(executor_data); } } - + let config = + SessionConfig::new().with_target_partitions(total_available_task_slots); + let ctx = Arc::new(SessionContext::with_config(config)); let plan = async { - let ctx = scheduler.ctx.read().await.clone(); let optimized_plan = ctx.optimize(&plan_of_linear_stages).map_err(|e| { BallistaError::General(format!( "Could not create optimized logical plan: {}", @@ -284,7 +360,15 @@ mod test { .await?; let job_id = "job"; - + scheduler + .state + .session_registry() + .register_session(ctx.clone()) + .await; + scheduler + .state + .save_job_session(job_id, ctx.session_id().as_str()) + .await?; { // verify job submit scheduler @@ -459,7 +543,6 @@ mod test { async fn test_scheduler( policy: TaskSchedulingPolicy, - config: SessionConfig, ) -> Result> { let state_storage = Arc::new(StandaloneClient::try_new_temporary()?); let mut scheduler: SchedulerServer = @@ -467,8 +550,8 @@ mod test { state_storage.clone(), "default".to_owned(), policy, - Arc::new(RwLock::new(SessionContext::with_config(config))), BallistaCodec::default(), + default_session_builder, ); scheduler.init().await?; diff --git a/ballista/rust/scheduler/src/standalone.rs b/ballista/rust/scheduler/src/standalone.rs index b5404e713525..f333587e495d 100644 --- a/ballista/rust/scheduler/src/standalone.rs +++ b/ballista/rust/scheduler/src/standalone.rs @@ -21,11 +21,9 @@ use ballista_core::{ error::Result, serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer, BALLISTA_VERSION, }; -use datafusion::prelude::SessionContext; use log::info; use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpListener; -use tokio::sync::RwLock; use tonic::transport::Server; use crate::{ @@ -39,7 +37,6 @@ pub async fn new_standalone_scheduler() -> Result { SchedulerServer::new( Arc::new(client), "ballista".to_string(), - Arc::new(RwLock::new(SessionContext::new())), BallistaCodec::default(), ); scheduler_server.init().await?; diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs index 4d8d4a04f7b4..711d5b0acb07 100644 --- a/ballista/rust/scheduler/src/state/mod.rs +++ b/ballista/rust/scheduler/src/state/mod.rs @@ -23,10 +23,10 @@ use 
datafusion::physical_plan::ExecutionPlan; use ballista_core::error::Result; +use crate::scheduler_server::SessionContextRegistry; use ballista_core::serde::protobuf::{ExecutorHeartbeat, JobStatus}; use ballista_core::serde::scheduler::ExecutorMetadata; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; -use datafusion::prelude::SessionContext; use crate::state::backend::StateBackendClient; use crate::state::executor_manager::ExecutorManager; @@ -64,8 +64,8 @@ impl SchedulerState Result<()> { - self.persistent_state.init(ctx).await?; + pub async fn init(&self) -> Result<()> { + self.persistent_state.init().await?; Ok(()) } @@ -120,6 +120,16 @@ impl SchedulerState Result<()> { + self.persistent_state + .save_job_session(job_id, session_id) + .await + } + + pub fn get_session_from_job(&self, job_id: &str) -> Option { + self.persistent_state.get_session_from_job(job_id) + } + pub async fn save_job_metadata( &self, job_id: &str, @@ -152,6 +162,10 @@ impl SchedulerState Option> { self.persistent_state.get_stage_plan(job_id, stage_id) } + + pub fn session_registry(&self) -> Arc { + self.persistent_state.session_registry() + } } #[cfg(all(test, feature = "sled"))] diff --git a/ballista/rust/scheduler/src/state/persistent_state.rs b/ballista/rust/scheduler/src/state/persistent_state.rs index 3ef07e964327..18abe007431e 100644 --- a/ballista/rust/scheduler/src/state/persistent_state.rs +++ b/ballista/rust/scheduler/src/state/persistent_state.rs @@ -20,18 +20,19 @@ use parking_lot::RwLock; use prost::Message; use std::any::type_name; use std::collections::HashMap; +use std::ops::Deref; use std::sync::Arc; use ballista_core::error::{BallistaError, Result}; use ballista_core::serde::protobuf::JobStatus; +use crate::scheduler_server::SessionContextRegistry; use crate::state::backend::StateBackendClient; use crate::state::stage_manager::StageKey; use ballista_core::serde::scheduler::ExecutorMetadata; use ballista_core::serde::{protobuf, AsExecutionPlan, AsLogicalPlan, BallistaCodec}; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::SessionContext; #[derive(Clone)] pub(crate) struct PersistentSchedulerState< @@ -46,8 +47,13 @@ pub(crate) struct PersistentSchedulerState< // for in-memory cache executors_metadata: Arc>>, + // TODO add remove logic jobs: Arc>>, stages: Arc>>>, + job2session: Arc>>, + + /// DataFusion session contexts that are registered within the Scheduler + session_context_registry: Arc, } impl @@ -65,14 +71,16 @@ impl executors_metadata: Arc::new(RwLock::new(HashMap::new())), jobs: Arc::new(RwLock::new(HashMap::new())), stages: Arc::new(RwLock::new(HashMap::new())), + job2session: Arc::new(RwLock::new(HashMap::new())), + session_context_registry: Arc::new(SessionContextRegistry::default()), } } /// Load the state stored in storage into memory - pub(crate) async fn init(&self, ctx: &SessionContext) -> Result<()> { + pub(crate) async fn init(&self) -> Result<()> { self.init_executors_metadata_from_storage().await?; self.init_jobs_from_storage().await?; - self.init_stages_from_storage(ctx).await?; + self.init_stages_from_storage().await?; Ok(()) } @@ -110,7 +118,7 @@ impl Ok(()) } - async fn init_stages_from_storage(&self, ctx: &SessionContext) -> Result<()> { + async fn init_stages_from_storage(&self) -> Result<()> { let entries = self .config_client .get_from_prefix(&get_stage_prefix(&self.namespace)) @@ -119,9 +127,21 @@ impl let mut stages = self.stages.write(); for (key, entry) in entries { let (job_id, stage_id) = 
extract_stage_id_from_stage_key(&key).unwrap(); + let session_id = self + .get_session_from_job(&job_id) + .expect("session id does not exist for job"); + let session_ctx = self + .session_context_registry + .lookup_session(&session_id) + .await + .expect("SessionContext does not exist in SessionContextRegistry."); let value = U::try_decode(&entry)?; - let plan = value - .try_into_physical_plan(ctx, self.codec.physical_extension_codec())?; + let runtime = session_ctx.runtime_env(); + let plan = value.try_into_physical_plan( + session_ctx.deref(), + runtime.deref(), + self.codec.physical_extension_codec(), + )?; stages.insert((job_id, stage_id), plan); } @@ -166,6 +186,21 @@ impl executors_metadata.values().cloned().collect() } + pub(crate) async fn save_job_session( + &self, + job_id: &str, + session_id: &str, + ) -> Result<()> { + let mut job2_sess = self.job2session.write(); + job2_sess.insert(job_id.to_string(), session_id.to_string()); + Ok(()) + } + + pub(crate) fn get_session_from_job(&self, job_id: &str) -> Option { + let job_session = self.job2session.read(); + job_session.get(job_id).cloned() + } + pub(crate) async fn save_job_metadata( &self, job_id: &str, @@ -241,6 +276,10 @@ impl Ok(()) } + + pub fn session_registry(&self) -> Arc { + self.session_context_registry.clone() + } } fn get_executors_metadata_prefix(namespace: &str) -> String { @@ -266,7 +305,6 @@ fn get_stage_prefix(namespace: &str) -> String { fn get_stage_plan_key(namespace: &str, job_id: &str, stage_id: u32) -> String { format!("{}/{}/{}", get_stage_prefix(namespace), job_id, stage_id,) } - fn extract_job_id_from_job_key(job_key: &str) -> Result<&str> { job_key.split('/').nth(2).ok_or_else(|| { BallistaError::Internal(format!("Unexpected task key: {}", job_key)) diff --git a/ballista/rust/scheduler/src/state/stage_manager.rs b/ballista/rust/scheduler/src/state/stage_manager.rs index b15978436768..e926c1db444c 100644 --- a/ballista/rust/scheduler/src/state/stage_manager.rs +++ b/ballista/rust/scheduler/src/state/stage_manager.rs @@ -28,6 +28,7 @@ use ballista_core::error::{BallistaError, Result}; use ballista_core::serde::protobuf; use ballista_core::serde::protobuf::{task_status, FailedTask, TaskStatus}; +/// job_id + stage_id pub type StageKey = (String, u32); #[derive(Clone)] diff --git a/ballista/rust/scheduler/src/state/task_scheduler.rs b/ballista/rust/scheduler/src/state/task_scheduler.rs index 1ae66b1b248b..132fbcf1c695 100644 --- a/ballista/rust/scheduler/src/state/task_scheduler.rs +++ b/ballista/rust/scheduler/src/state/task_scheduler.rs @@ -21,7 +21,8 @@ use async_trait::async_trait; use ballista_core::error::BallistaError; use ballista_core::execution_plans::ShuffleWriterExec; use ballista_core::serde::protobuf::{ - job_status, task_status, FailedJob, RunningTask, TaskDefinition, TaskStatus, + job_status, task_status, FailedJob, KeyValuePair, RunningTask, TaskDefinition, + TaskStatus, }; use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto; use ballista_core::serde::scheduler::{ExecutorData, PartitionId}; @@ -153,6 +154,22 @@ impl TaskScheduler )) })?; + let session_id = self.get_session_from_job(&job_id).expect("session id does not exist for job"); + let session_props = self + .session_registry() + .lookup_session(&session_id) + .await + .expect("SessionContext does not exist in SessionContextRegistry.") + .copied_config() + .to_props(); + let task_props = session_props + .iter() + .map(|(k, v)| KeyValuePair { + key: k.to_owned(), + value: v.to_owned(), + }) + .collect::>(); + 
ret[idx].push(TaskDefinition { plan: buf, task_id, @@ -160,6 +177,8 @@ impl TaskScheduler output_partitioning, ) .map_err(|_| tonic::Status::internal("TBD".to_string()))?, + session_id, + props: task_props, }); executor.available_task_slots -= 1; num_tasks += 1; diff --git a/ballista/rust/scheduler/src/test_utils.rs b/ballista/rust/scheduler/src/test_utils.rs index 9d6e83fec89e..3f8b57670834 100644 --- a/ballista/rust/scheduler/src/test_utils.rs +++ b/ballista/rust/scheduler/src/test_utils.rs @@ -28,7 +28,7 @@ pub const TPCH_TABLES: &[&str] = &[ pub async fn datafusion_test_context(path: &str) -> Result { let default_shuffle_partitions = 2; let config = SessionConfig::new().with_target_partitions(default_shuffle_partitions); - let mut ctx = SessionContext::with_config(config); + let ctx = SessionContext::with_config(config); for table in TPCH_TABLES { let schema = get_tpch_schema(table); let options = CsvReadOptions::new() diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index aa70500d19e3..a2932ebdc271 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -30,7 +30,9 @@ use std::{ }; use ballista::context::BallistaContext; -use ballista::prelude::{BallistaConfig, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS}; +use ballista::prelude::{ + BallistaConfig, BALLISTA_DEFAULT_BATCH_SIZE, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, +}; use datafusion::datasource::{MemTable, TableProvider}; use datafusion::error::{DataFusionError, Result}; @@ -81,9 +83,10 @@ struct BallistaBenchmarkOpt { #[structopt(short = "i", long = "iterations", default_value = "3")] iterations: usize, - // /// Batch size when reading CSV or Parquet files - // #[structopt(short = "s", long = "batch-size", default_value = "8192")] - // batch_size: usize, + /// Batch size when reading CSV or Parquet files + #[structopt(short = "s", long = "batch-size", default_value = "8192")] + batch_size: usize, + /// Path to data files #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] path: PathBuf, @@ -172,6 +175,10 @@ struct BallistaLoadtestOpt { #[structopt(short = "n", long = "partitions", default_value = "2")] partitions: usize, + /// Batch size when reading CSV or Parquet files + #[structopt(short = "s", long = "batch-size", default_value = "8192")] + batch_size: usize, + /// Path to data files #[structopt(parse(from_os_str), required = true, short = "p", long = "data-path")] path: PathBuf, @@ -274,7 +281,7 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result Result = Vec::with_capacity(1); for i in 0..opt.iterations { let start = Instant::now(); - let plan = create_logical_plan(&mut ctx, opt.query)?; + let plan = create_logical_plan(&ctx, opt.query)?; result = execute_query(&ctx, &plan, opt.debug).await?; let elapsed = start.elapsed().as_secs_f64() * 1000.0; millis.push(elapsed as f64); @@ -337,11 +344,14 @@ async fn benchmark_ballista(opt: BallistaBenchmarkOpt) -> Result<()> { BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, &format!("{}", opt.partitions), ) + .set(BALLISTA_DEFAULT_BATCH_SIZE, &format!("{}", opt.batch_size)) .build() .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; let ctx = - BallistaContext::remote(opt.host.unwrap().as_str(), opt.port.unwrap(), &config); + BallistaContext::remote(opt.host.unwrap().as_str(), opt.port.unwrap(), &config) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; // register tables with Ballista context let path = opt.path.to_str().unwrap(); @@ -417,6 +427,7 @@ async fn loadtest_ballista(opt: 
BallistaLoadtestOpt) -> Result<()> { BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, &format!("{}", opt.partitions), ) + .set(BALLISTA_DEFAULT_BATCH_SIZE, &format!("{}", opt.batch_size)) .build() .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; @@ -425,11 +436,15 @@ async fn loadtest_ballista(opt: BallistaLoadtestOpt) -> Result<()> { let mut clients = vec![]; for _num in 0..concurrency { - clients.push(BallistaContext::remote( - opt.host.clone().unwrap().as_str(), - opt.port.unwrap(), - &config, - )); + clients.push( + BallistaContext::remote( + opt.host.clone().unwrap().as_str(), + opt.port.unwrap(), + &config, + ) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?, + ); } // register tables with Ballista context @@ -580,7 +595,7 @@ fn get_query_sql(query: usize) -> Result { } } -fn create_logical_plan(ctx: &mut SessionContext, query: usize) -> Result { +fn create_logical_plan(ctx: &SessionContext, query: usize) -> Result { let sql = get_query_sql(query)?; ctx.create_logical_plan(&sql) } @@ -629,7 +644,7 @@ async fn convert_tbl(opt: ConvertOpt) -> Result<()> { .file_extension(".tbl"); let config = SessionConfig::new().with_batch_size(opt.batch_size); - let mut ctx = SessionContext::with_config(config); + let ctx = SessionContext::with_config(config); // build plan to read the TBL file let mut csv = ctx.read_csv(&input_path, options).await?; @@ -1281,7 +1296,7 @@ mod tests { let config = SessionConfig::new() .with_target_partitions(1) .with_batch_size(10); - let mut ctx = SessionContext::with_config(config); + let ctx = SessionContext::with_config(config); for &table in TABLES { let schema = get_schema(table); @@ -1292,7 +1307,7 @@ mod tests { ctx.register_table(table, Arc::new(provider))?; } - let plan = create_logical_plan(&mut ctx, n)?; + let plan = create_logical_plan(&ctx, n)?; execute_query(&ctx, &plan, false).await?; Ok(()) @@ -1303,7 +1318,7 @@ mod tests { // load expected answers from tpch-dbgen // read csv as all strings, trim and cast to expected type as the csv string // to value parser does not handle data with leading/trailing spaces - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let schema = string_schema(get_answer_schema(n)); let options = CsvReadOptions::new() .schema(&schema) @@ -1373,12 +1388,13 @@ mod tests { protobuf, AsExecutionPlan, AsLogicalPlan, BallistaCodec, }; use datafusion::physical_plan::ExecutionPlan; + use std::ops::Deref; async fn round_trip_query(n: usize) -> Result<()> { let config = SessionConfig::new() .with_target_partitions(1) .with_batch_size(10); - let mut ctx = SessionContext::with_config(config); + let ctx = SessionContext::with_config(config); let codec: BallistaCodec< protobuf::LogicalPlanNode, protobuf::PhysicalPlanNode, @@ -1408,7 +1424,7 @@ mod tests { } // test logical plan round trip - let plan = create_logical_plan(&mut ctx, n)?; + let plan = create_logical_plan(&ctx, n)?; let proto: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( &plan, @@ -1450,8 +1466,13 @@ mod tests { codec.physical_extension_codec(), ) .unwrap(); + let runtime = ctx.runtime_env(); let round_trip: Arc = (&proto) - .try_into_physical_plan(&ctx, codec.physical_extension_codec()) + .try_into_physical_plan( + &ctx, + runtime.deref(), + codec.physical_extension_codec(), + ) .unwrap(); assert_eq!( format!("{:?}", physical_plan), diff --git a/datafusion-cli/src/context.rs b/datafusion-cli/src/context.rs index dc609bc68ea8..6b76baffa04f 100644 --- a/datafusion-cli/src/context.rs +++ 
b/datafusion-cli/src/context.rs @@ -32,8 +32,8 @@ pub enum Context { impl Context { /// create a new remote context with given host and port - pub fn new_remote(host: &str, port: u16) -> Result { - Ok(Context::Remote(BallistaContext::try_new(host, port)?)) + pub async fn new_remote(host: &str, port: u16) -> Result { + Ok(Context::Remote(BallistaContext::try_new(host, port).await?)) } /// create a local context using the given config @@ -56,12 +56,15 @@ impl Context { pub struct BallistaContext(ballista::context::BallistaContext); #[cfg(feature = "ballista")] impl BallistaContext { - pub fn try_new(host: &str, port: u16) -> Result { + pub async fn try_new(host: &str, port: u16) -> Result { use ballista::context::BallistaContext; use ballista::prelude::BallistaConfig; let config: BallistaConfig = BallistaConfig::new() .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; - Ok(Self(BallistaContext::remote(host, port, &config))) + let remote_ctx = BallistaContext::remote(host, port, &config) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + Ok(Self(remote_ctx)) } pub async fn sql(&mut self, sql: &str) -> Result> { self.0.sql(sql).await @@ -72,7 +75,7 @@ impl BallistaContext { pub struct BallistaContext(); #[cfg(not(feature = "ballista"))] impl BallistaContext { - pub fn try_new(_host: &str, _port: u16) -> Result { + pub async fn try_new(_host: &str, _port: u16) -> Result { Err(DataFusionError::NotImplemented( "Remote execution not supported. Compile with feature 'ballista' to enable" .to_string(), diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 92b997d41e0d..d76fe38e5fb6 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -105,7 +105,7 @@ pub async fn main() -> Result<()> { }; let mut ctx: Context = match (args.host, args.port) { - (Some(ref h), Some(p)) => Context::new_remote(h, p)?, + (Some(ref h), Some(p)) => Context::new_remote(h, p).await?, _ => Context::new_local(&session_config), }; diff --git a/datafusion-examples/examples/avro_sql.rs b/datafusion-examples/examples/avro_sql.rs index 9a2189f03ec3..93804c910370 100644 --- a/datafusion-examples/examples/avro_sql.rs +++ b/datafusion-examples/examples/avro_sql.rs @@ -25,7 +25,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> Result<()> { // create local execution context - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::arrow_test_data(); diff --git a/datafusion-examples/examples/csv_sql.rs b/datafusion-examples/examples/csv_sql.rs index 59c6fa072e4e..0dfc288c7510 100644 --- a/datafusion-examples/examples/csv_sql.rs +++ b/datafusion-examples/examples/csv_sql.rs @@ -23,7 +23,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> Result<()> { // create local execution context - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::arrow_test_data(); diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 655c568804f9..982f1b4a9cb9 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -23,7 +23,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> Result<()> { // create local execution context - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/dataframe_in_memory.rs 
b/datafusion-examples/examples/dataframe_in_memory.rs index 67504e9b5baa..0cbca7ba5416 100644 --- a/datafusion-examples/examples/dataframe_in_memory.rs +++ b/datafusion-examples/examples/dataframe_in_memory.rs @@ -44,7 +44,7 @@ async fn main() -> Result<()> { )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch]])?; diff --git a/datafusion-examples/examples/flight_server.rs b/datafusion-examples/examples/flight_server.rs index d8b0405f19b9..fcdec561ae38 100644 --- a/datafusion-examples/examples/flight_server.rs +++ b/datafusion-examples/examples/flight_server.rs @@ -90,7 +90,7 @@ impl FlightService for FlightServiceImpl { println!("do_get: {}", sql); // create local execution context - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs index ccae8c500542..11793a837f3a 100644 --- a/datafusion-examples/examples/memtable.rs +++ b/datafusion-examples/examples/memtable.rs @@ -31,7 +31,7 @@ async fn main() -> Result<()> { let mem_table = create_memtable()?; // create local execution context - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // Register the in-memory table containing the data ctx.register_table("users", Arc::new(mem_table))?; diff --git a/datafusion-examples/examples/parquet_sql.rs b/datafusion-examples/examples/parquet_sql.rs index 6ab24ff41ab6..9d3dd4967507 100644 --- a/datafusion-examples/examples/parquet_sql.rs +++ b/datafusion-examples/examples/parquet_sql.rs @@ -23,7 +23,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> Result<()> { // create local session context - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/parquet_sql_multiple_files.rs b/datafusion-examples/examples/parquet_sql_multiple_files.rs index ee6b9e9c054a..b948f16835ac 100644 --- a/datafusion-examples/examples/parquet_sql_multiple_files.rs +++ b/datafusion-examples/examples/parquet_sql_multiple_files.rs @@ -28,7 +28,7 @@ use std::sync::Arc; #[tokio::main] async fn main() -> Result<()> { // create local execution context - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 2b69e8d1e651..f3829c9ffa11 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -46,7 +46,7 @@ fn create_context() -> Result { )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
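// --- Illustrative sketch (not part of the patch): with `register_*`, `read_*`
// --- and `sql` now taking `&self`, callers no longer need a `mut` binding for
// --- SessionContext, which is why the examples above drop `let mut ctx`.
// --- Table and column names below are made up for illustration only.
use std::sync::Arc;
use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

async fn immutable_ctx_example() -> Result<()> {
    // No `mut` needed: mutable state sits behind the context's internal lock.
    let ctx = SessionContext::new();

    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let provider = MemTable::try_new(schema, vec![vec![batch]])?;

    // Both calls take `&self` after this change.
    ctx.register_table("t", Arc::new(provider))?;
    ctx.sql("SELECT a FROM t").await?.show().await?;
    Ok(())
}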
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 70ad04e7e695..321adcf1fc85 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -49,7 +49,7 @@ fn create_context() -> Result { )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch]])?; diff --git a/datafusion-proto/src/from_proto.rs b/datafusion-proto/src/from_proto.rs index 17f03b923694..dd8b13547da5 100644 --- a/datafusion-proto/src/from_proto.rs +++ b/datafusion-proto/src/from_proto.rs @@ -16,8 +16,8 @@ // under the License. use crate::protobuf; -use datafusion::prelude::{bit_length, SessionContext}; -use datafusion::sql::planner::ContextProvider; +use datafusion::logical_plan::FunctionRegistry; +use datafusion::prelude::bit_length; use datafusion::{ arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode}, error::DataFusionError, @@ -839,7 +839,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { pub fn parse_expr( proto: &protobuf::LogicalExprNode, - ctx: &SessionContext, + registry: &dyn FunctionRegistry, ) -> Result { use datafusion::physical_plan::window_functions; use protobuf::{logical_expr_node::ExprType, window_expr_node, ScalarFunction}; @@ -851,9 +851,9 @@ pub fn parse_expr( match expr_type { ExprType::BinaryExpr(binary_expr) => Ok(Expr::BinaryExpr { - left: Box::new(parse_required_expr(&binary_expr.l, ctx, "l")?), + left: Box::new(parse_required_expr(&binary_expr.l, registry, "l")?), op: from_proto_binary_op(&binary_expr.op)?, - right: Box::new(parse_required_expr(&binary_expr.r, ctx, "r")?), + right: Box::new(parse_required_expr(&binary_expr.r, registry, "r")?), }), ExprType::Column(column) => Ok(Expr::Column(column.into())), ExprType::Literal(literal) => { @@ -868,12 +868,12 @@ pub fn parse_expr( let partition_by = expr .partition_by .iter() - .map(|e| parse_expr(e, ctx)) + .map(|e| parse_expr(e, registry)) .collect::, _>>()?; let order_by = expr .order_by .iter() - .map(|e| parse_expr(e, ctx)) + .map(|e| parse_expr(e, registry)) .collect::, _>>()?; let window_frame = expr .window_frame @@ -900,7 +900,7 @@ pub fn parse_expr( fun: window_functions::WindowFunction::AggregateFunction( aggr_function, ), - args: vec![parse_required_expr(&expr.expr, ctx, "expr")?], + args: vec![parse_required_expr(&expr.expr, registry, "expr")?], partition_by, order_by, window_frame, @@ -915,7 +915,7 @@ pub fn parse_expr( fun: window_functions::WindowFunction::BuiltInWindowFunction( built_in_function, ), - args: vec![parse_required_expr(&expr.expr, ctx, "expr")?], + args: vec![parse_required_expr(&expr.expr, registry, "expr")?], partition_by, order_by, window_frame, @@ -931,31 +931,31 @@ pub fn parse_expr( args: expr .expr .iter() - .map(|e| parse_expr(e, ctx)) + .map(|e| parse_expr(e, registry)) .collect::, _>>()?, distinct: false, // TODO }) } ExprType::Alias(alias) => Ok(Expr::Alias( - Box::new(parse_required_expr(&alias.expr, ctx, "expr")?), + Box::new(parse_required_expr(&alias.expr, registry, "expr")?), alias.alias.clone(), )), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( &is_null.expr, - ctx, + registry, "expr", )?))), 
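// --- Illustrative sketch (not part of the patch): `parse_expr` now resolves
// --- UDFs/UDAFs through any `&dyn FunctionRegistry` instead of a full
// --- SessionContext, and the executor-side `TaskContext` introduced in this
// --- change implements that trait. The helper below only mirrors the
// --- constructor shape added in datafusion/src/execution/context.rs; the
// --- empty UDF maps and the helper's name are assumptions for illustration.
use std::collections::HashMap;
use std::sync::Arc;
use datafusion::error::Result;
use datafusion::execution::context::TaskContext;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};

fn task_context_from_props(
    task_id: String,
    session_id: String,
    props: HashMap<String, String>,
) -> Result<TaskContext> {
    let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default())?);
    // Scalar/aggregate UDFs would normally come from the task definition;
    // empty maps are sufficient when only built-in expressions are decoded.
    Ok(TaskContext::new(
        task_id,
        session_id,
        props,
        HashMap::new(),
        HashMap::new(),
        runtime,
    ))
    // The returned TaskContext can then be passed wherever a
    // `&dyn FunctionRegistry` is expected, e.g. to the reworked `parse_expr`.
}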
ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( - parse_required_expr(&is_not_null.expr, ctx, "expr")?, + parse_required_expr(&is_not_null.expr, registry, "expr")?, ))), ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( - ¬.expr, ctx, "expr", + ¬.expr, registry, "expr", )?))), ExprType::Between(between) => Ok(Expr::Between { - expr: Box::new(parse_required_expr(&between.expr, ctx, "expr")?), + expr: Box::new(parse_required_expr(&between.expr, registry, "expr")?), negated: between.negated, - low: Box::new(parse_required_expr(&between.low, ctx, "expr")?), - high: Box::new(parse_required_expr(&between.high, ctx, "expr")?), + low: Box::new(parse_required_expr(&between.low, registry, "expr")?), + high: Box::new(parse_required_expr(&between.high, registry, "expr")?), }), ExprType::Case(case) => { let when_then_expr = case @@ -963,42 +963,42 @@ pub fn parse_expr( .iter() .map(|e| { let when_expr = - parse_required_expr_inner(&e.when_expr, ctx, "when_expr")?; + parse_required_expr_inner(&e.when_expr, registry, "when_expr")?; let then_expr = - parse_required_expr_inner(&e.then_expr, ctx, "then_expr")?; + parse_required_expr_inner(&e.then_expr, registry, "then_expr")?; Ok((Box::new(when_expr), Box::new(then_expr))) }) .collect::, Box)>, Error>>()?; Ok(Expr::Case { - expr: parse_optional_expr(&case.expr, ctx)?.map(Box::new), + expr: parse_optional_expr(&case.expr, registry)?.map(Box::new), when_then_expr, - else_expr: parse_optional_expr(&case.else_expr, ctx)?.map(Box::new), + else_expr: parse_optional_expr(&case.else_expr, registry)?.map(Box::new), }) } ExprType::Cast(cast) => { - let expr = Box::new(parse_required_expr(&cast.expr, ctx, "expr")?); + let expr = Box::new(parse_required_expr(&cast.expr, registry, "expr")?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::Cast { expr, data_type }) } ExprType::TryCast(cast) => { - let expr = Box::new(parse_required_expr(&cast.expr, ctx, "expr")?); + let expr = Box::new(parse_required_expr(&cast.expr, registry, "expr")?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::TryCast { expr, data_type }) } ExprType::Sort(sort) => Ok(Expr::Sort { - expr: Box::new(parse_required_expr(&sort.expr, ctx, "expr")?), + expr: Box::new(parse_required_expr(&sort.expr, registry, "expr")?), asc: sort.asc, nulls_first: sort.nulls_first, }), ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( - parse_required_expr(&negative.expr, ctx, "expr")?, + parse_required_expr(&negative.expr, registry, "expr")?, ))), ExprType::InList(in_list) => Ok(Expr::InList { - expr: Box::new(parse_required_expr(&in_list.expr, ctx, "expr")?), + expr: Box::new(parse_required_expr(&in_list.expr, registry, "expr")?), list: in_list .list .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, negated: in_list.negated, }), @@ -1009,159 +1009,162 @@ pub fn parse_expr( let args = &expr.args; match scalar_function { - ScalarFunction::Asin => Ok(asin(parse_expr(&args[0], ctx)?)), - ScalarFunction::Acos => Ok(acos(parse_expr(&args[0], ctx)?)), + ScalarFunction::Asin => Ok(asin(parse_expr(&args[0], registry)?)), + ScalarFunction::Acos => Ok(acos(parse_expr(&args[0], registry)?)), ScalarFunction::Array => Ok(array( args.to_owned() .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), - ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], ctx)?)), - ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], 
ctx)?)), - ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], ctx)?)), - ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], ctx)?)), - ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], ctx)?)), - ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], ctx)?)), - ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], ctx)?)), - ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], ctx)?)), - ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], ctx)?)), - ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], ctx)?)), - ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], ctx)?)), - ScalarFunction::Round => Ok(round(parse_expr(&args[0], ctx)?)), - ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0], ctx)?)), - ScalarFunction::Abs => Ok(abs(parse_expr(&args[0], ctx)?)), - ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], ctx)?)), + ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)), + ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)), + ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry)?)), + ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], registry)?)), + ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry)?)), + ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry)?)), + ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], registry)?)), + ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], registry)?)), + ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], registry)?)), + ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], registry)?)), + ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry)?)), + ScalarFunction::Round => Ok(round(parse_expr(&args[0], registry)?)), + ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0], registry)?)), + ScalarFunction::Abs => Ok(abs(parse_expr(&args[0], registry)?)), + ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], registry)?)), ScalarFunction::OctetLength => { - Ok(octet_length(parse_expr(&args[0], ctx)?)) + Ok(octet_length(parse_expr(&args[0], registry)?)) } - ScalarFunction::Lower => Ok(lower(parse_expr(&args[0], ctx)?)), - ScalarFunction::Upper => Ok(upper(parse_expr(&args[0], ctx)?)), - ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], ctx)?)), - ScalarFunction::Ltrim => Ok(ltrim(parse_expr(&args[0], ctx)?)), - ScalarFunction::Rtrim => Ok(rtrim(parse_expr(&args[0], ctx)?)), + ScalarFunction::Lower => Ok(lower(parse_expr(&args[0], registry)?)), + ScalarFunction::Upper => Ok(upper(parse_expr(&args[0], registry)?)), + ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], registry)?)), + ScalarFunction::Ltrim => Ok(ltrim(parse_expr(&args[0], registry)?)), + ScalarFunction::Rtrim => Ok(rtrim(parse_expr(&args[0], registry)?)), ScalarFunction::DatePart => Ok(date_part( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, )), ScalarFunction::DateTrunc => Ok(date_trunc( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, )), - ScalarFunction::Sha224 => Ok(sha224(parse_expr(&args[0], ctx)?)), - ScalarFunction::Sha256 => Ok(sha256(parse_expr(&args[0], ctx)?)), - ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], ctx)?)), - ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], ctx)?)), - ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], ctx)?)), - ScalarFunction::NullIf => Ok(nullif(parse_expr(&args[0], ctx)?)), + ScalarFunction::Sha224 => Ok(sha224(parse_expr(&args[0], registry)?)), + ScalarFunction::Sha256 => 
Ok(sha256(parse_expr(&args[0], registry)?)), + ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), + ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), + ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), + ScalarFunction::NullIf => Ok(nullif(parse_expr(&args[0], registry)?)), ScalarFunction::Digest => Ok(digest( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, )), - ScalarFunction::Ascii => Ok(ascii(parse_expr(&args[0], ctx)?)), - ScalarFunction::BitLength => Ok(bit_length(parse_expr(&args[0], ctx)?)), - ScalarFunction::CharacterLength => { - Ok(character_length(parse_expr(&args[0], ctx)?)) + ScalarFunction::Ascii => Ok(ascii(parse_expr(&args[0], registry)?)), + ScalarFunction::BitLength => { + Ok(bit_length(parse_expr(&args[0], registry)?)) } - ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], ctx)?)), - ScalarFunction::InitCap => Ok(ascii(parse_expr(&args[0], ctx)?)), - ScalarFunction::Left => { - Ok(left(parse_expr(&args[0], ctx)?, parse_expr(&args[1], ctx)?)) + ScalarFunction::CharacterLength => { + Ok(character_length(parse_expr(&args[0], registry)?)) } + ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry)?)), + ScalarFunction::InitCap => Ok(ascii(parse_expr(&args[0], registry)?)), + ScalarFunction::Left => Ok(left( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::Random => Ok(random()), ScalarFunction::Repeat => Ok(repeat( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, )), ScalarFunction::Replace => Ok(replace( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, - parse_expr(&args[2], ctx)?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, )), - ScalarFunction::Reverse => Ok(reverse(parse_expr(&args[0], ctx)?)), + ScalarFunction::Reverse => Ok(reverse(parse_expr(&args[0], registry)?)), ScalarFunction::Right => Ok(right( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, )), ScalarFunction::Concat => Ok(concat_expr( args.to_owned() .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr( args.to_owned() .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), ScalarFunction::Lpad => Ok(lpad( args.to_owned() .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), ScalarFunction::Rpad => Ok(rpad( args.to_owned() .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), ScalarFunction::RegexpReplace => Ok(regexp_replace( args.to_owned() .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), ScalarFunction::RegexpMatch => Ok(regexp_match( args.to_owned() .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), ScalarFunction::Btrim => Ok(btrim( args.to_owned() .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), ScalarFunction::SplitPart => Ok(split_part( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, - parse_expr(&args[2], ctx)?, + 
parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, )), ScalarFunction::StartsWith => Ok(starts_with( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, )), ScalarFunction::Strpos => Ok(strpos( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, )), ScalarFunction::Substr => Ok(substr( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, )), - ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], ctx)?)), + ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), ScalarFunction::ToTimestampMillis => { - Ok(to_timestamp_millis(parse_expr(&args[0], ctx)?)) + Ok(to_timestamp_millis(parse_expr(&args[0], registry)?)) } ScalarFunction::ToTimestampMicros => { - Ok(to_timestamp_micros(parse_expr(&args[0], ctx)?)) + Ok(to_timestamp_micros(parse_expr(&args[0], registry)?)) } ScalarFunction::ToTimestampSeconds => { - Ok(to_timestamp_seconds(parse_expr(&args[0], ctx)?)) + Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?)) } ScalarFunction::Now => Ok(now_expr( args.to_owned() .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), ScalarFunction::Translate => Ok(translate( - parse_expr(&args[0], ctx)?, - parse_expr(&args[1], ctx)?, - parse_expr(&args[2], ctx)?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, )), _ => Err(proto_error( "Protobuf deserialization error: Unsupported scalar function", @@ -1169,30 +1172,23 @@ pub fn parse_expr( } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { - let scalar_fn = ctx - .state - .read() - .get_function_meta(fun_name.as_str()).ok_or_else(|| Error::General(format!("invalid aggregate function message, function {} is not registered in the ExecutionContext", fun_name)))?; - + let scalar_fn = registry.udf(fun_name.as_str())?; Ok(Expr::ScalarUDF { fun: scalar_fn, args: args .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, }) } ExprType::AggregateUdfExpr(protobuf::AggregateUdfExprNode { fun_name, args }) => { - let agg_fn = ctx - .state - .read() - .get_aggregate_meta(fun_name.as_str()).ok_or_else(|| Error::General(format!("invalid aggregate function message, function {} is not registered in the ExecutionContext", fun_name)))?; + let agg_fn = registry.udaf(fun_name.as_str())?; Ok(Expr::AggregateUDF { fun: agg_fn, args: args .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, }) } @@ -1464,32 +1460,32 @@ fn from_proto_binary_op(op: &str) -> Result { fn parse_optional_expr( p: &Option>, - ctx: &SessionContext, + registry: &dyn FunctionRegistry, ) -> Result, Error> { match p { - Some(expr) => parse_expr(expr.as_ref(), ctx).map(Some), + Some(expr) => parse_expr(expr.as_ref(), registry).map(Some), None => Ok(None), } } fn parse_required_expr( p: &Option>, - ctx: &SessionContext, + registry: &dyn FunctionRegistry, field: impl Into, ) -> Result { match p { - Some(expr) => parse_expr(expr.as_ref(), ctx), + Some(expr) => parse_expr(expr.as_ref(), registry), None => Err(Error::required(field)), } } fn parse_required_expr_inner( p: &Option, - ctx: &SessionContext, + registry: &dyn FunctionRegistry, field: impl Into, ) -> 
Result { match p { - Some(expr) => parse_expr(expr, ctx), + Some(expr) => parse_expr(expr, registry), None => Err(Error::required(field)), } } diff --git a/datafusion/benches/aggregate_query_sql.rs b/datafusion/benches/aggregate_query_sql.rs index 78e914cb3984..807e64ff5e27 100644 --- a/datafusion/benches/aggregate_query_sql.rs +++ b/datafusion/benches/aggregate_query_sql.rs @@ -40,7 +40,7 @@ fn create_context( array_len: usize, batch_size: usize, ) -> Result>> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let provider = create_table_provider(partitions_len, array_len, batch_size)?; ctx.register_table("t", provider)?; Ok(Arc::new(Mutex::new(ctx))) diff --git a/datafusion/benches/filter_query_sql.rs b/datafusion/benches/filter_query_sql.rs index e3401b11a903..8dc981700598 100644 --- a/datafusion/benches/filter_query_sql.rs +++ b/datafusion/benches/filter_query_sql.rs @@ -57,7 +57,7 @@ fn create_context(array_len: usize, batch_size: usize) -> Result }) .collect::>(); - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![batches])?; diff --git a/datafusion/benches/math_query_sql.rs b/datafusion/benches/math_query_sql.rs index 88d486a5dcc5..11b107f02f4b 100644 --- a/datafusion/benches/math_query_sql.rs +++ b/datafusion/benches/math_query_sql.rs @@ -69,7 +69,7 @@ fn create_context( }) .collect::>(); - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![batches])?; diff --git a/datafusion/benches/parquet_query_sql.rs b/datafusion/benches/parquet_query_sql.rs index 183f64739069..5416e2c8edb9 100644 --- a/datafusion/benches/parquet_query_sql.rs +++ b/datafusion/benches/parquet_query_sql.rs @@ -193,7 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { assert!(Path::new(&file_path).exists(), "path not found"); println!("Using parquet file {}", file_path); - let mut context = SessionContext::new(); + let context = SessionContext::new(); let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap(); rt.block_on(context.register_parquet("t", file_path.as_str())) @@ -219,7 +219,7 @@ fn criterion_benchmark(c: &mut Criterion) { let query = query.as_str(); c.bench_function(query, |b| { b.iter(|| { - let mut context = context.clone(); + let context = context.clone(); rt.block_on(async move { let query = context.sql(query).await.unwrap(); let mut stream = query.execute_stream().await.unwrap(); diff --git a/datafusion/benches/sort_limit_query_sql.rs b/datafusion/benches/sort_limit_query_sql.rs index e3806f22d0d5..bad723a1af44 100644 --- a/datafusion/benches/sort_limit_query_sql.rs +++ b/datafusion/benches/sort_limit_query_sql.rs @@ -83,7 +83,7 @@ fn create_context() -> Arc> { rt.block_on(async { // create local session context - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.state.write().config.target_partitions = 1; let task_ctx = ctx.task_ctx(); diff --git a/datafusion/benches/window_query_sql.rs b/datafusion/benches/window_query_sql.rs index f76958bbbd32..42a1e51be361 100644 --- a/datafusion/benches/window_query_sql.rs +++ b/datafusion/benches/window_query_sql.rs @@ -40,7 +40,7 @@ fn create_context( array_len: usize, batch_size: usize, ) -> Result>> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let provider = 
create_table_provider(partitions_len, array_len, batch_size)?; ctx.register_table("t", provider)?; Ok(Arc::new(Mutex::new(ctx))) diff --git a/datafusion/src/catalog/schema.rs b/datafusion/src/catalog/schema.rs index 73e23bc5984a..62f37729c31b 100644 --- a/datafusion/src/catalog/schema.rs +++ b/datafusion/src/catalog/schema.rs @@ -292,7 +292,7 @@ mod tests { let catalog = MemoryCatalogProvider::new(); catalog.register_schema("active", Arc::new(schema)).unwrap(); - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_catalog("cat", Arc::new(catalog)); diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs index 28f845dc899a..c4c4b838acc6 100644 --- a/datafusion/src/dataframe.rs +++ b/datafusion/src/dataframe.rs @@ -59,7 +59,7 @@ use std::any::Any; /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { -/// let mut ctx = SessionContext::new(); +/// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.filter(col("a").lt_eq(col("b")))? /// .aggregate(vec![col("a")], vec![min(col("b"))])? @@ -97,7 +97,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.select_columns(&["a", "b"])?; /// # Ok(()) @@ -119,7 +119,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.select(vec![col("a") * col("b"), col("c")])?; /// # Ok(()) @@ -147,7 +147,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.filter(col("a").lt_eq(col("b")))?; /// # Ok(()) @@ -167,7 +167,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// /// // The following use is the equivalent of "SELECT MIN(b) GROUP BY a" @@ -196,7 +196,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.limit(100)?; /// # Ok(()) @@ -216,7 +216,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.union(df.clone())?; /// # Ok(()) @@ -236,7 +236,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", 
CsvReadOptions::new()).await?; /// let df = df.union(df.clone())?; /// let df = df.distinct()?; @@ -260,7 +260,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.sort(vec![col("a").sort(true, true), col("b").sort(false, false)])?; /// # Ok(()) @@ -280,7 +280,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let left = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await? /// .select(vec![ @@ -318,7 +318,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df1 = df.repartition(Partitioning::RoundRobinBatch(4))?; /// # Ok(()) @@ -342,7 +342,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.collect().await?; /// # Ok(()) @@ -361,7 +361,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// df.show().await?; /// # Ok(()) @@ -379,7 +379,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// df.show_limit(10).await?; /// # Ok(()) @@ -397,7 +397,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let stream = df.execute_stream().await?; /// # Ok(()) @@ -417,7 +417,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.collect_partitioned().await?; /// # Ok(()) @@ -436,7 +436,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.execute_stream_partitioned().await?; /// # Ok(()) @@ -458,7 +458,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let 
df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let schema = df.schema(); /// # Ok(()) @@ -482,7 +482,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.limit(100)?.explain(false, false)?.collect().await?; /// # Ok(()) @@ -502,7 +502,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let f = df.registry(); /// // use f.udf("name", vec![...]) to use the udf @@ -521,7 +521,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.intersect(df.clone())?; /// # Ok(()) @@ -543,7 +543,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.except(df.clone())?; /// # Ok(()) @@ -878,7 +878,7 @@ mod tests { #[tokio::test] async fn register_table() -> Result<()> { let df = test_table().await?.select_columns(&["c1", "c12"])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), &df.to_logical_plan())); // register a dataframe as a table diff --git a/datafusion/src/datasource/listing/helpers.rs b/datafusion/src/datasource/listing/helpers.rs index 10b3b488b0d0..d4ae381bede1 100644 --- a/datafusion/src/datasource/listing/helpers.rs +++ b/datafusion/src/datasource/listing/helpers.rs @@ -247,7 +247,7 @@ pub async fn pruned_partition_list( // Filter the partitions using a local datafusion context // TODO having the external context would allow us to resolve `Volatility::Stable` // scalar functions (`ScalarFunction` & `ScalarUDF`) and `ScalarVariable`s - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let mut df = ctx.read_table(Arc::new(mem_table))?; for filter in applicable_filters { df = df.filter(filter.clone())?; diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 5aebe99070f5..aa17d57309b8 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -116,7 +116,7 @@ const DEFAULT_SCHEMA: &str = "public"; /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { -/// let mut ctx = SessionContext::new(); +/// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.filter(col("a").lt_eq(col("b")))? /// .aggregate(vec![col("a")], vec![min(col("b"))])? @@ -170,7 +170,7 @@ impl SessionContext { /// Creates a new session context using the provided configuration and RuntimeEnv. 
pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let state = SessionState::with_config(config, runtime); + let state = SessionState::with_config_rt(config, runtime); Self { session_id: state.session_id.clone(), session_start_time: chrono::Utc::now(), @@ -192,6 +192,11 @@ impl SessionContext { self.state.read().runtime_env.clone() } + /// Return the session_id of this Session + pub fn session_id(&self) -> String { + self.session_id.clone() + } + /// Return a copied version of config for this Session pub fn copied_config(&self) -> SessionConfig { self.state.read().config.clone() @@ -201,7 +206,7 @@ impl SessionContext { /// /// This method is `async` because queries of type `CREATE EXTERNAL TABLE` /// might require the schema to be inferred. - pub async fn sql(&mut self, sql: &str) -> Result> { + pub async fn sql(&self, sql: &str) -> Result> { let plan = self.create_logical_plan(sql)?; match plan { LogicalPlan::CreateExternalTable(CreateExternalTable { @@ -381,7 +386,7 @@ impl SessionContext { /// Creates a DataFrame for reading an Avro data source. pub async fn read_avro( - &mut self, + &self, uri: impl Into, options: AvroReadOptions<'_>, ) -> Result> { @@ -435,7 +440,7 @@ impl SessionContext { /// Creates a DataFrame for reading a CSV data source. pub async fn read_csv( - &mut self, + &self, uri: impl Into, options: CsvReadOptions<'_>, ) -> Result> { @@ -457,10 +462,7 @@ impl SessionContext { } /// Creates a DataFrame for reading a Parquet data source. - pub async fn read_parquet( - &mut self, - uri: impl Into, - ) -> Result> { + pub async fn read_parquet(&self, uri: impl Into) -> Result> { let uri: String = uri.into(); let (object_store, path) = self.runtime_env().object_store(&uri)?; let target_partitions = self.copied_config().target_partitions; @@ -472,10 +474,7 @@ impl SessionContext { } /// Creates a DataFrame for reading a custom TableProvider. - pub fn read_table( - &mut self, - provider: Arc, - ) -> Result> { + pub fn read_table(&self, provider: Arc) -> Result> { Ok(Arc::new(DataFrame::new( self.state.clone(), &LogicalPlanBuilder::scan(UNNAMED_TABLE, provider, None)?.build()?, @@ -486,7 +485,7 @@ impl SessionContext { /// find the files to be processed /// This is async because it might need to resolve the schema. pub async fn register_listing_table<'a>( - &'a mut self, + &'a self, name: &'a str, uri: &'a str, options: ListingOptions, @@ -512,7 +511,7 @@ impl SessionContext { /// Registers a CSV data source so that it can be referenced from SQL statements /// executed against this context. pub async fn register_csv( - &mut self, + &self, name: &str, uri: &str, options: CsvReadOptions<'_>, @@ -534,7 +533,7 @@ impl SessionContext { // Registers a Json data source so that it can be referenced from SQL statements /// executed against this context. pub async fn register_json( - &mut self, + &self, name: &str, uri: &str, options: NdJsonReadOptions<'_>, @@ -549,7 +548,7 @@ impl SessionContext { /// Registers a Parquet data source so that it can be referenced from SQL statements /// executed against this context. - pub async fn register_parquet(&mut self, name: &str, uri: &str) -> Result<()> { + pub async fn register_parquet(&self, name: &str, uri: &str) -> Result<()> { let (target_partitions, enable_pruning) = { let conf = self.copied_config(); (conf.target_partitions, conf.parquet_pruning) @@ -572,7 +571,7 @@ impl SessionContext { /// Registers an Avro data source so that it can be referenced from SQL statements /// executed against this context. 
pub async fn register_avro( - &mut self, + &self, name: &str, uri: &str, options: AvroReadOptions<'_>, @@ -623,7 +622,7 @@ impl SessionContext { /// Returns the `TableProvider` previously registered for this /// reference, if any pub fn register_table<'a>( - &'a mut self, + &'a self, table_ref: impl Into>, provider: Arc, ) -> Result>> { @@ -638,7 +637,7 @@ impl SessionContext { /// /// Returns the registered provider, if any pub fn deregister_table<'a>( - &'a mut self, + &'a self, table_ref: impl Into>, ) -> Result>> { let table_ref = table_ref.into(); @@ -947,6 +946,33 @@ impl SessionConfig { self.parquet_pruning = enabled; self } + + /// Convert configuration to name-value pairs + pub fn to_props(&self) -> HashMap { + let mut map = HashMap::new(); + map.insert(BATCH_SIZE.to_owned(), format!("{}", self.batch_size)); + map.insert( + TARGET_PARTITIONS.to_owned(), + format!("{}", self.target_partitions), + ); + map.insert( + REPARTITION_JOINS.to_owned(), + format!("{}", self.repartition_joins), + ); + map.insert( + REPARTITION_AGGREGATIONS.to_owned(), + format!("{}", self.repartition_aggregations), + ); + map.insert( + REPARTITION_WINDOWS.to_owned(), + format!("{}", self.repartition_windows), + ); + map.insert( + PARQUET_PRUNING.to_owned(), + format!("{}", self.parquet_pruning), + ); + map + } } /// Holds per-execution properties and data (such as starting timestamps, etc). @@ -1015,7 +1041,7 @@ impl ExecutionProps { #[derive(Clone)] pub struct SessionState { /// Uuid for the session - session_id: String, + pub session_id: String, /// Responsible for optimizing a logical plan pub optimizers: Vec>, /// Responsible for optimizing a physical execution plan @@ -1036,9 +1062,17 @@ pub struct SessionState { pub runtime_env: Arc, } +/// Default session builder using the provided configuration +pub fn default_session_builder(config: SessionConfig) -> SessionState { + SessionState::with_config_rt( + config, + Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()), + ) +} + impl SessionState { /// Returns new SessionState using the provided configuration and runtime - pub fn with_config(config: SessionConfig, runtime: Arc) -> Self { + pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { let session_id = Uuid::new_v4().to_string(); let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc; @@ -1309,13 +1343,17 @@ pub enum TaskProperties { /// Task Execution Context pub struct TaskContext { /// Session Id - pub session_id: String, + session_id: String, /// Optional Task Identify - pub task_id: Option, + task_id: Option, /// Task properties - pub properties: TaskProperties, + properties: TaskProperties, + /// Scalar functions associated with this task context + scalar_functions: HashMap>, + /// Aggregate functions associated with this task context + aggregate_functions: HashMap>, /// Runtime environment associated with this task context - pub runtime: Arc, + runtime: Arc, } impl TaskContext { @@ -1323,13 +1361,17 @@ impl TaskContext { pub fn new( task_id: String, session_id: String, - task_settings: HashMap, + task_props: HashMap, + scalar_functions: HashMap>, + aggregate_functions: HashMap>, runtime: Arc, ) -> Self { Self { task_id: Some(task_id), session_id, - properties: TaskProperties::KVPairs(task_settings), + properties: TaskProperties::KVPairs(task_props), + scalar_functions, + aggregate_functions, runtime, } } @@ -1369,18 +1411,42 @@ impl TaskContext { TaskProperties::SessionConfig(session_config) => session_config.clone(), } } + + /// Return the session_id of this 
[TaskContext] + pub fn session_id(&self) -> String { + self.session_id.clone() + } + + /// Return the task_id of this [TaskContext] + pub fn task_id(&self) -> Option { + self.task_id.clone() + } + + /// Return the [RuntimeEnv] associated with this [TaskContext] + pub fn runtime_env(&self) -> Arc { + self.runtime.clone() + } } /// Create a new task context instance from SessionContext impl From<&SessionContext> for TaskContext { fn from(session: &SessionContext) -> Self { let session_id = session.session_id.clone(); - let config = session.state.read().config.clone(); + let (config, scalar_functions, aggregate_functions) = { + let session_state = session.state.read(); + ( + session_state.config.clone(), + session_state.scalar_functions.clone(), + session_state.aggregate_functions.clone(), + ) + }; let runtime = session.runtime_env(); Self { task_id: None, session_id, properties: TaskProperties::SessionConfig(config), + scalar_functions, + aggregate_functions, runtime, } } @@ -1391,16 +1457,48 @@ impl From<&SessionState> for TaskContext { fn from(state: &SessionState) -> Self { let session_id = state.session_id.clone(); let config = state.config.clone(); + let scalar_functions = state.scalar_functions.clone(); + let aggregate_functions = state.aggregate_functions.clone(); let runtime = state.runtime_env.clone(); Self { task_id: None, session_id, properties: TaskProperties::SessionConfig(config), + scalar_functions, + aggregate_functions, runtime, } } } +impl FunctionRegistry for TaskContext { + fn udfs(&self) -> HashSet { + self.scalar_functions.keys().cloned().collect() + } + + fn udf(&self, name: &str) -> Result> { + let result = self.scalar_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDF named \"{}\" in the TaskContext", + name + )) + }) + } + + fn udaf(&self, name: &str) -> Result> { + let result = self.aggregate_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDAF named \"{}\" in the TaskContext", + name + )) + }) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1479,7 +1577,7 @@ mod tests { ctx.register_table("dual", provider)?; let results = - plan_and_collect(&mut ctx, "SELECT @@version, @name, @integer + 1 FROM dual") + plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual") .await?; let expected = vec![ @@ -1498,7 +1596,7 @@ mod tests { async fn register_deregister() -> Result<()> { let tmp_dir = TempDir::new()?; let partition_count = 4; - let mut ctx = create_ctx(&tmp_dir, partition_count).await?; + let ctx = create_ctx(&tmp_dir, partition_count).await?; let provider = test::create_table_dual(); ctx.register_table("dual", provider)?; @@ -1757,11 +1855,11 @@ mod tests { #[tokio::test] async fn aggregate_decimal_min() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select min(c1) from d_table") + let result = plan_and_collect(&ctx, "select min(c1) from d_table") .await .unwrap(); let expected = vec![ @@ -1781,12 +1879,12 @@ mod tests { #[tokio::test] async fn aggregate_decimal_max() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select max(c1) from d_table") 
+ let result = plan_and_collect(&ctx, "select max(c1) from d_table") .await .unwrap(); let expected = vec![ @@ -1806,11 +1904,11 @@ mod tests { #[tokio::test] async fn aggregate_decimal_sum() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select sum(c1) from d_table") + let result = plan_and_collect(&ctx, "select sum(c1) from d_table") .await .unwrap(); let expected = vec![ @@ -1830,11 +1928,11 @@ mod tests { #[tokio::test] async fn aggregate_decimal_avg() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select avg(c1) from d_table") + let result = plan_and_collect(&ctx, "select avg(c1) from d_table") .await .unwrap(); let expected = vec![ @@ -2145,7 +2243,7 @@ mod tests { #[tokio::test] async fn group_by_date_trunc() -> Result<()> { let tmp_dir = TempDir::new()?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("c2", DataType::UInt64, false), Field::new( @@ -2176,7 +2274,7 @@ mod tests { .await?; let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT date_trunc('week', t1) as week, SUM(c2) FROM test GROUP BY date_trunc('week', t1)", ).await?; @@ -2196,7 +2294,7 @@ mod tests { #[tokio::test] async fn group_by_largeutf8() { { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // input data looks like: // A, 1 @@ -2227,7 +2325,7 @@ mod tests { ctx.register_table("t", Arc::new(provider)).unwrap(); let results = - plan_and_collect(&mut ctx, "SELECT str, count(val) FROM t GROUP BY str") + plan_and_collect(&ctx, "SELECT str, count(val) FROM t GROUP BY str") .await .expect("ran plan correctly"); @@ -2246,7 +2344,7 @@ mod tests { #[tokio::test] async fn unprojected_filter() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let df = ctx .read_table(test::table_with_sequence(1, 3).unwrap()) .unwrap(); @@ -2271,7 +2369,7 @@ mod tests { #[tokio::test] async fn group_by_dictionary() { async fn run_test_case() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // input data looks like: // A, 1 @@ -2299,12 +2397,10 @@ mod tests { let provider = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap(); ctx.register_table("t", Arc::new(provider)).unwrap(); - let results = plan_and_collect( - &mut ctx, - "SELECT dict, count(val) FROM t GROUP BY dict", - ) - .await - .expect("ran plan correctly"); + let results = + plan_and_collect(&ctx, "SELECT dict, count(val) FROM t GROUP BY dict") + .await + .expect("ran plan correctly"); let expected = vec![ "+------+--------------+", @@ -2319,7 +2415,7 @@ mod tests { // Now, use dict as an aggregate let results = - plan_and_collect(&mut ctx, "SELECT val, count(dict) FROM t GROUP BY val") + plan_and_collect(&ctx, "SELECT val, count(dict) FROM t GROUP BY val") .await .expect("ran plan correctly"); @@ -2336,7 +2432,7 @@ mod tests { // Now, use dict as an aggregate let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT val, count(distinct dict) FROM t GROUP BY val", ) .await @@ -2368,7 +2464,7 @@ mod tests { partitions: Vec>, ) -> Result> { let tmp_dir = TempDir::new()?; - let mut ctx = SessionContext::new(); + let ctx = 
SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("c_group", DataType::Utf8, false), Field::new("c_int8", DataType::Int8, false), @@ -2407,7 +2503,7 @@ mod tests { .await?; let results = plan_and_collect( - &mut ctx, + &ctx, " SELECT c_group, @@ -2517,14 +2613,13 @@ mod tests { #[tokio::test] async fn limit() -> Result<()> { let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; + let ctx = create_ctx(&tmp_dir, 1).await?; ctx.register_table("t", test::table_with_sequence(1, 1000).unwrap()) .unwrap(); - let results = - plan_and_collect(&mut ctx, "SELECT i FROM t ORDER BY i DESC limit 3") - .await - .unwrap(); + let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i DESC limit 3") + .await + .unwrap(); let expected = vec![ "+------+", "| i |", "+------+", "| 1000 |", "| 999 |", "| 998 |", @@ -2533,7 +2628,7 @@ mod tests { assert_batches_eq!(expected, &results); - let results = plan_and_collect(&mut ctx, "SELECT i FROM t ORDER BY i limit 3") + let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i limit 3") .await .unwrap(); @@ -2543,7 +2638,7 @@ mod tests { assert_batches_eq!(expected, &results); - let results = plan_and_collect(&mut ctx, "SELECT i FROM t limit 3") + let results = plan_and_collect(&ctx, "SELECT i FROM t limit 3") .await .unwrap(); @@ -2557,7 +2652,7 @@ mod tests { #[tokio::test] async fn limit_multi_partitions() -> Result<()> { let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; + let ctx = create_ctx(&tmp_dir, 1).await?; let partitions = vec![ vec![test::make_partition(0)], @@ -2573,14 +2668,14 @@ mod tests { ctx.register_table("t", provider).unwrap(); // select all rows - let results = plan_and_collect(&mut ctx, "SELECT i FROM t").await.unwrap(); + let results = plan_and_collect(&ctx, "SELECT i FROM t").await.unwrap(); let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); assert_eq!(num_rows, 15); for limit in 1..10 { let query = format!("SELECT i FROM t limit {}", limit); - let results = plan_and_collect(&mut ctx, &query).await.unwrap(); + let results = plan_and_collect(&ctx, &query).await.unwrap(); let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); assert_eq!(num_rows, limit, "mismatch with query {}", query); @@ -2591,7 +2686,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_functions() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -2603,19 +2698,19 @@ mod tests { "+-----------+", ]; - let results = plan_and_collect(&mut ctx, "SELECT sqrt(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT sqrt(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); - let results = plan_and_collect(&mut ctx, "SELECT SQRT(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT SQRT(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); // Using double quotes allows specifying the function name with capitalization - let err = plan_and_collect(&mut ctx, "SELECT \"SQRT\"(i) FROM t") + let err = plan_and_collect(&ctx, "SELECT \"SQRT\"(i) FROM t") .await .unwrap_err(); assert_eq!( @@ -2623,7 +2718,7 @@ mod tests { "Error during planning: Invalid function 'SQRT'" ); - let results = plan_and_collect(&mut ctx, "SELECT \"sqrt\"(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT \"sqrt\"(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); @@ -2631,7 +2726,7 
@@ mod tests { #[tokio::test] async fn case_builtin_math_expression() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let type_values = vec![ ( @@ -2691,7 +2786,7 @@ mod tests { "| 1 |", "+-----------+", ]; - let results = plan_and_collect(&mut ctx, "SELECT sqrt(v) FROM t") + let results = plan_and_collect(&ctx, "SELECT sqrt(v) FROM t") .await .unwrap(); @@ -2717,7 +2812,7 @@ mod tests { )); // doesn't work as it was registered with non lowercase - let err = plan_and_collect(&mut ctx, "SELECT MY_FUNC(i) FROM t") + let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t") .await .unwrap_err(); assert_eq!( @@ -2726,7 +2821,7 @@ mod tests { ); // Can call it if you put quotes - let result = plan_and_collect(&mut ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; + let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; let expected = vec![ "+--------------+", @@ -2742,7 +2837,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_aggregates() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -2754,19 +2849,19 @@ mod tests { "+----------+", ]; - let results = plan_and_collect(&mut ctx, "SELECT max(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT max(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); - let results = plan_and_collect(&mut ctx, "SELECT MAX(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT MAX(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); // Using double quotes allows specifying the function name with capitalization - let err = plan_and_collect(&mut ctx, "SELECT \"MAX\"(i) FROM t") + let err = plan_and_collect(&ctx, "SELECT \"MAX\"(i) FROM t") .await .unwrap_err(); assert_eq!( @@ -2774,7 +2869,7 @@ mod tests { "Error during planning: Invalid function 'MAX'" ); - let results = plan_and_collect(&mut ctx, "SELECT \"max\"(i) FROM t") + let results = plan_and_collect(&ctx, "SELECT \"max\"(i) FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &results); @@ -2799,7 +2894,7 @@ mod tests { ctx.register_udaf(my_avg); // doesn't work as it was registered as non lowercase - let err = plan_and_collect(&mut ctx, "SELECT MY_AVG(i) FROM t") + let err = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t") .await .unwrap_err(); assert_eq!( @@ -2808,7 +2903,7 @@ mod tests { ); // Can call it if you put quotes - let result = plan_and_collect(&mut ctx, "SELECT \"MY_AVG\"(i) FROM t").await?; + let result = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?; let expected = vec![ "+-------------+", @@ -2829,7 +2924,7 @@ mod tests { // The main stipulation of this test: use a file extension that isn't .csv. 
let file_extension = ".tst"; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let schema = populate_csv_partitions(&tmp_dir, 2, file_extension)?; ctx.register_csv( "test", @@ -2840,8 +2935,7 @@ mod tests { ) .await?; let results = - plan_and_collect(&mut ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test") - .await?; + plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?; assert_eq!(results.len(), 1); let expected = vec![ @@ -2885,7 +2979,7 @@ mod tests { #[tokio::test] async fn ctx_sql_should_optimize_plan() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let plan1 = ctx .create_logical_plan("SELECT * FROM (SELECT 1) AS one WHERE TRUE AND TRUE")?; @@ -2916,13 +3010,13 @@ mod tests { vec![Arc::new(Int32Array::from_slice(&[4, 5]))], )?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; ctx.register_table("t", Arc::new(provider))?; - let result = plan_and_collect(&mut ctx, "SELECT AVG(a) FROM t").await?; + let result = plan_and_collect(&ctx, "SELECT AVG(a) FROM t").await?; let batch = &result[0]; assert_eq!(1, batch.num_columns()); @@ -2942,9 +3036,9 @@ mod tests { #[tokio::test] async fn custom_query_planner() -> Result<()> { let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); - let session_state = SessionState::with_config(SessionConfig::new(), runtime) + let session_state = SessionState::with_config_rt(SessionConfig::new(), runtime) .with_query_planner(Arc::new(MyQueryPlanner {})); - let mut ctx = SessionContext::with_state(session_state); + let ctx = SessionContext::with_state(session_state); let df = ctx.sql("SELECT 1").await?; df.collect().await.expect_err("query not supported"); @@ -2953,7 +3047,7 @@ mod tests { #[tokio::test] async fn disabled_default_catalog_and_schema() -> Result<()> { - let mut ctx = SessionContext::with_config( + let ctx = SessionContext::with_config( SessionConfig::new().create_default_catalog_and_schema(false), ); @@ -2996,8 +3090,7 @@ mod tests { } async fn catalog_and_schema_test(config: SessionConfig) { - let mut ctx = SessionContext::with_config(config); - + let ctx = SessionContext::with_config(config); let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); schema @@ -3010,7 +3103,7 @@ mod tests { for table_ref in &["my_catalog.my_schema.test", "my_schema.test", "test"] { let result = plan_and_collect( - &mut ctx, + &ctx, &format!("SELECT COUNT(*) AS count FROM {}", table_ref), ) .await @@ -3029,7 +3122,7 @@ mod tests { #[tokio::test] async fn cross_catalog_access() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let catalog_a = MemoryCatalogProvider::new(); let schema_a = MemorySchemaProvider::new(); @@ -3046,7 +3139,7 @@ mod tests { ctx.register_catalog("catalog_b", Arc::new(catalog_b)); let result = plan_and_collect( - &mut ctx, + &ctx, "SELECT cat, SUM(i) AS total FROM ( SELECT i, 'a' AS cat FROM catalog_a.schema_a.table_a UNION ALL @@ -3097,7 +3190,7 @@ mod tests { #[tokio::test] async fn sql_create_schema() -> Result<()> { // the information schema used to introduce cyclic Arcs - let mut ctx = SessionContext::with_config( + let ctx = SessionContext::with_config( SessionConfig::new().with_information_schema(true), ); @@ -3120,7 +3213,7 @@ mod tests { #[tokio::test] async fn normalized_column_identifiers() { // create local execution context - let mut ctx = 
SessionContext::new(); + let ctx = SessionContext::new(); // register csv file with the execution context ctx.register_csv( @@ -3132,7 +3225,7 @@ mod tests { .unwrap(); let sql = "SELECT A, b FROM case_insensitive_test"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3145,7 +3238,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); let sql = "SELECT t.A, b FROM case_insensitive_test AS t"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3160,7 +3253,7 @@ mod tests { // Aliases let sql = "SELECT t.A as x, b FROM case_insensitive_test AS t"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3173,7 +3266,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); let sql = "SELECT t.A AS X, b FROM case_insensitive_test AS t"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3186,7 +3279,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); let sql = r#"SELECT t.A AS "X", b FROM case_insensitive_test AS t"#; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3201,7 +3294,7 @@ mod tests { // Order by let sql = "SELECT t.A AS x, b FROM case_insensitive_test AS t ORDER BY x"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3214,7 +3307,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); let sql = "SELECT t.A AS x, b FROM case_insensitive_test AS t ORDER BY X"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3227,7 +3320,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); let sql = r#"SELECT t.A AS "X", b FROM case_insensitive_test AS t ORDER BY "X""#; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3242,7 +3335,7 @@ mod tests { // Where let sql = "SELECT a, b FROM case_insensitive_test where A IS NOT null"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3257,7 +3350,7 @@ mod tests { // Group by let sql = "SELECT a as x, count(*) as c FROM case_insensitive_test GROUP BY X"; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3271,7 +3364,7 @@ mod tests { let sql = r#"SELECT a as "X", count(*) as c FROM case_insensitive_test GROUP BY "X""#; - let result = plan_and_collect(&mut ctx, sql) + let result = plan_and_collect(&ctx, sql) .await .expect("ran plan correctly"); let expected = vec![ @@ -3327,7 +3420,7 @@ mod tests { /// Execute SQL and return results async fn plan_and_collect( - ctx: &mut SessionContext, + ctx: &SessionContext, sql: &str, ) -> Result> { ctx.sql(sql).await?.collect().await @@ -3336,8 +3429,8 @@ mod tests { /// Execute SQL and return results async fn execute(sql: &str, partition_count: usize) -> Result> { let 
tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, partition_count).await?; - plan_and_collect(&mut ctx, sql).await + let ctx = create_ctx(&tmp_dir, partition_count).await?; + plan_and_collect(&ctx, sql).await } /// Generate CSV partitions within the supplied directory @@ -3374,7 +3467,7 @@ mod tests { tmp_dir: &TempDir, partition_count: usize, ) -> Result { - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; @@ -3404,19 +3497,19 @@ mod tests { #[async_trait] impl CallReadTrait for CallRead { async fn call_read_csv(&self) -> Arc { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap() } async fn call_read_avro(&self) -> Arc { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.read_avro("dummy", AvroReadOptions::default()) .await .unwrap() } async fn call_read_parquet(&self) -> Arc { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.read_parquet("dummy").await.unwrap() } } diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 9f3915c63185..2c5f9c570f2f 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -34,7 +34,7 @@ //! //! # #[tokio::main] //! # async fn main() -> Result<()> { -//! let mut ctx = SessionContext::new(); +//! let ctx = SessionContext::new(); //! //! // create the dataframe //! let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; @@ -73,7 +73,7 @@ //! //! # #[tokio::main] //! # async fn main() -> Result<()> { -//! let mut ctx = SessionContext::new(); +//! let ctx = SessionContext::new(); //! //! ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?; //! 
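
The lib.rs doc-example hunks above show the net effect of this change for library users: because `SessionContext`'s registration and query methods now take `&self`, the context no longer needs to be declared `mut`. Below is a minimal end-to-end sketch of that usage; the import paths, the `tests/example.csv` file, and the column names `a`/`b` follow the doc examples but are illustrative assumptions, not part of this patch.

use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::prelude::CsvReadOptions;

#[tokio::main]
async fn main() -> Result<()> {
    // No `mut` needed: registration and query methods borrow the context immutably.
    let ctx = SessionContext::new();
    ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new())
        .await?;

    // `sql` also takes `&self` now; collect the results as RecordBatches.
    let df = ctx.sql("SELECT a, MIN(b) FROM example GROUP BY a").await?;
    let batches = df.collect().await?;
    println!("{} result batch(es)", batches.len());
    Ok(())
}
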
diff --git a/datafusion/src/physical_plan/file_format/csv.rs b/datafusion/src/physical_plan/file_format/csv.rs index 28466b8aace2..8a5356b80b67 100644 --- a/datafusion/src/physical_plan/file_format/csv.rs +++ b/datafusion/src/physical_plan/file_format/csv.rs @@ -459,7 +459,7 @@ mod tests { async fn write_csv_results() -> Result<()> { // create partitioned input file and context let tmp_dir = TempDir::new()?; - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(&tmp_dir, 8, ".csv")?; @@ -478,7 +478,7 @@ mod tests { df.write_csv(&out_dir).await?; // create a new context and verify that the results were saved to a partitioned csv file - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::UInt32, false), diff --git a/datafusion/src/physical_plan/file_format/json.rs b/datafusion/src/physical_plan/file_format/json.rs index 4eb2fda57821..afcdde4cfb84 100644 --- a/datafusion/src/physical_plan/file_format/json.rs +++ b/datafusion/src/physical_plan/file_format/json.rs @@ -353,7 +353,7 @@ mod tests { async fn write_json_results() -> Result<()> { // create partitioned input file and context let tmp_dir = TempDir::new()?; - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let path = format!("{}/1.json", TEST_DATA_BASE); @@ -368,7 +368,7 @@ mod tests { df.write_json(&out_dir).await?; // create a new context and verify that the results were saved to a partitioned csv file - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // register each partition as well as the top level dir let json_read_option = NdJsonReadOptions::default(); diff --git a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs index c77e7d8a27c0..7d797477df08 100644 --- a/datafusion/src/physical_plan/file_format/parquet.rs +++ b/datafusion/src/physical_plan/file_format/parquet.rs @@ -1308,7 +1308,7 @@ mod tests { // create partitioned input file and context let tmp_dir = TempDir::new()?; // let mut ctx = create_ctx(&tmp_dir, 4).await?; - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?; // register csv file with the execution context @@ -1326,7 +1326,7 @@ mod tests { // write_parquet(&mut ctx, "SELECT c1, c2 FROM test", &out_dir, None).await?; // create a new context and verify that the results were saved to a partitioned csv file - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // register each partition as well as the top level dir ctx.register_parquet("part0", &format!("{}/part-0.parquet", out_dir)) diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 1189f23d921c..36b984043b7e 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -1469,7 +1469,7 @@ mod tests { fn make_session_state() -> SessionState { let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()); - SessionState::with_config(SessionConfig::new(), runtime) + SessionState::with_config_rt(SessionConfig::new(), runtime) } async fn plan(logical_plan: &LogicalPlan) -> Result> { diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs index 6ac4a47c86ae..67b9085bb4a7 100644 --- 
a/datafusion/src/physical_plan/sorts/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -582,9 +582,9 @@ async fn do_sort( expr, metrics_set, Arc::new(context.session_config()), - context.runtime.clone(), + context.runtime_env(), ); - context.runtime.register_requester(sorter.id()); + context.runtime_env().register_requester(sorter.id()); while let Some(batch) = input.next().await { let batch = batch?; sorter.insert_batch(batch).await?; diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs index c2ff27d4f048..957bcf50433e 100644 --- a/datafusion/tests/custom_sources.rs +++ b/datafusion/tests/custom_sources.rs @@ -210,7 +210,7 @@ impl TableProvider for CustomTableProvider { #[tokio::test] async fn custom_source_dataframe() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let table = ctx.read_table(Arc::new(CustomTableProvider))?; let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) @@ -257,7 +257,7 @@ async fn custom_source_dataframe() -> Result<()> { #[tokio::test] async fn optimizers_catch_all_statistics() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(CustomTableProvider)) .unwrap(); diff --git a/datafusion/tests/dataframe.rs b/datafusion/tests/dataframe.rs index 3f106be73b4b..7b89909d400d 100644 --- a/datafusion/tests/dataframe.rs +++ b/datafusion/tests/dataframe.rs @@ -58,7 +58,7 @@ async fn join() -> Result<()> { ], )?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let table1 = MemTable::try_new(schema1, vec![vec![batch1]])?; let table2 = MemTable::try_new(schema2, vec![vec![batch2]])?; @@ -96,7 +96,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { ) .unwrap(); - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); ctx.register_table("t", Arc::new(provider)).unwrap(); @@ -132,7 +132,7 @@ async fn filter_with_alias_overwrite() -> Result<()> { ) .unwrap(); - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); ctx.register_table("t", Arc::new(provider)).unwrap(); diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs index 8b7c5d894416..10cbe98ead25 100644 --- a/datafusion/tests/dataframe_functions.rs +++ b/datafusion/tests/dataframe_functions.rs @@ -55,7 +55,7 @@ fn create_test_table() -> Result> { ], )?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let table = MemTable::try_new(schema, vec![vec![batch]])?; diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs index 83d51e071aaa..4301e6556b57 100644 --- a/datafusion/tests/parquet_pruning.rs +++ b/datafusion/tests/parquet_pruning.rs @@ -480,7 +480,7 @@ impl ContextWithParquet { let parquet_path = file.path().to_string_lossy(); // now, setup a the file as a data source and run a query against it - let mut ctx = SessionContext::with_config(config); + let ctx = SessionContext::with_config(config); ctx.register_parquet("t", &parquet_path).await.unwrap(); let provider = ctx.deregister_table("t").unwrap().unwrap(); diff --git a/datafusion/tests/path_partition.rs b/datafusion/tests/path_partition.rs index 8475ea5b43a3..351b85d3d50e 100644 --- a/datafusion/tests/path_partition.rs +++ b/datafusion/tests/path_partition.rs @@ -42,10 +42,10 @@ use 
futures::{stream, StreamExt}; #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); register_partitioned_aggregate_csv( - &mut ctx, + &ctx, &[ "mytable/date=2021-10-27/file.csv", "mytable/date=2021-10-28/file.csv", @@ -78,10 +78,10 @@ async fn csv_filter_with_file_col() -> Result<()> { #[tokio::test] async fn csv_projection_on_partition() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); register_partitioned_aggregate_csv( - &mut ctx, + &ctx, &[ "mytable/date=2021-10-27/file.csv", "mytable/date=2021-10-28/file.csv", @@ -114,10 +114,10 @@ async fn csv_projection_on_partition() -> Result<()> { #[tokio::test] async fn csv_grouping_by_partition() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); register_partitioned_aggregate_csv( - &mut ctx, + &ctx, &[ "mytable/date=2021-10-26/file.csv", "mytable/date=2021-10-27/file.csv", @@ -148,10 +148,10 @@ async fn csv_grouping_by_partition() -> Result<()> { #[tokio::test] async fn parquet_multiple_partitions() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); register_partitioned_alltypes_parquet( - &mut ctx, + &ctx, &[ "year=2021/month=09/day=09/file.parquet", "year=2021/month=10/day=09/file.parquet", @@ -190,10 +190,10 @@ async fn parquet_multiple_partitions() -> Result<()> { #[tokio::test] async fn parquet_statistics() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); register_partitioned_alltypes_parquet( - &mut ctx, + &ctx, &[ "year=2021/month=09/day=09/file.parquet", "year=2021/month=10/day=09/file.parquet", @@ -249,11 +249,11 @@ async fn parquet_statistics() -> Result<()> { #[tokio::test] async fn parquet_overlapping_columns() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // `id` is both a column of the file and a partitioning col register_partitioned_alltypes_parquet( - &mut ctx, + &ctx, &[ "id=1/file.parquet", "id=2/file.parquet", @@ -275,7 +275,7 @@ async fn parquet_overlapping_columns() -> Result<()> { } fn register_partitioned_aggregate_csv( - ctx: &mut SessionContext, + ctx: &SessionContext, store_paths: &[&str], partition_cols: &[&str], table_path: &str, @@ -298,7 +298,7 @@ fn register_partitioned_aggregate_csv( } async fn register_partitioned_alltypes_parquet( - ctx: &mut SessionContext, + ctx: &SessionContext, store_paths: &[&str], partition_cols: &[&str], table_path: &str, diff --git a/datafusion/tests/provider_filter_pushdown.rs b/datafusion/tests/provider_filter_pushdown.rs index 62d1cac51c09..e722d4ceeeb6 100644 --- a/datafusion/tests/provider_filter_pushdown.rs +++ b/datafusion/tests/provider_filter_pushdown.rs @@ -173,7 +173,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<() one_batch: create_batch(1, 5)?, }; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let df = ctx .read_table(Arc::new(provider.clone()))? .filter(col("flag").eq(lit(value)))? 
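
The path_partition.rs and provider_filter_pushdown.rs hunks above switch the test helpers from `&mut SessionContext` to `&SessionContext`. A hedged sketch of that pattern in isolation is shown below; the helper name, table name, and data are made up for illustration, and only the `&self` receivers come from this patch.

use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;

// Hypothetical helper in the spirit of `register_partitioned_aggregate_csv` above:
// it only needs an immutable borrow because `register_table` now takes `&self`.
fn register_numbers(ctx: &SessionContext) -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let provider = MemTable::try_new(schema, vec![vec![batch]])?;
    ctx.register_table("numbers", Arc::new(provider))?;
    Ok(())
}

#[tokio::test]
async fn shared_context_without_mut() -> Result<()> {
    // The context is bound without `mut`, yet helpers and queries all work on it.
    let ctx = SessionContext::new();
    register_numbers(&ctx)?;
    let batches = ctx.sql("SELECT SUM(i) FROM numbers").await?.collect().await?;
    assert_eq!(batches[0].num_rows(), 1);
    Ok(())
}
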
diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index c3553a9cae61..272bb82d6269 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -20,8 +20,8 @@ use datafusion::scalar::ScalarValue; #[tokio::test] async fn csv_query_avg_multi_batch() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(c12) FROM aggregate_test_100"; let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); @@ -41,10 +41,10 @@ async fn csv_query_avg_multi_batch() -> Result<()> { #[tokio::test] async fn csv_query_avg() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.5089725099127211"]]; assert_float_eq(&expected, &actual); @@ -53,10 +53,10 @@ async fn csv_query_avg() -> Result<()> { #[tokio::test] async fn csv_query_covariance_1() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT covar_pop(c2, c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["-0.07916932235380847"]]; assert_float_eq(&expected, &actual); @@ -65,10 +65,10 @@ async fn csv_query_covariance_1() -> Result<()> { #[tokio::test] async fn csv_query_covariance_2() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT covar(c2, c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["-0.07996901247859442"]]; assert_float_eq(&expected, &actual); @@ -77,10 +77,10 @@ async fn csv_query_covariance_2() -> Result<()> { #[tokio::test] async fn csv_query_correlation() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT corr(c2, c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["-0.19064544190576607"]]; assert_float_eq(&expected, &actual); @@ -89,10 +89,10 @@ async fn csv_query_correlation() -> Result<()> { #[tokio::test] async fn csv_query_variance_1() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT var_pop(c2) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["1.8675"]]; assert_float_eq(&expected, &actual); @@ -101,10 +101,10 @@ async fn csv_query_variance_1() -> Result<()> { #[tokio::test] async fn csv_query_variance_2() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = 
SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT var_pop(c6) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["26156334342021890000000000000000000000"]]; assert_float_eq(&expected, &actual); @@ -113,10 +113,10 @@ async fn csv_query_variance_2() -> Result<()> { #[tokio::test] async fn csv_query_variance_3() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT var_pop(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.09234223721582163"]]; assert_float_eq(&expected, &actual); @@ -125,10 +125,10 @@ async fn csv_query_variance_3() -> Result<()> { #[tokio::test] async fn csv_query_variance_4() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT var(c2) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["1.8863636363636365"]]; assert_float_eq(&expected, &actual); @@ -137,10 +137,10 @@ async fn csv_query_variance_4() -> Result<()> { #[tokio::test] async fn csv_query_variance_5() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT var_samp(c2) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["1.8863636363636365"]]; assert_float_eq(&expected, &actual); @@ -149,10 +149,10 @@ async fn csv_query_variance_5() -> Result<()> { #[tokio::test] async fn csv_query_stddev_1() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT stddev_pop(c2) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["1.3665650368716449"]]; assert_float_eq(&expected, &actual); @@ -161,10 +161,10 @@ async fn csv_query_stddev_1() -> Result<()> { #[tokio::test] async fn csv_query_stddev_2() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT stddev_pop(c6) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["5114326382039172000"]]; assert_float_eq(&expected, &actual); @@ -173,10 +173,10 @@ async fn csv_query_stddev_2() -> Result<()> { #[tokio::test] async fn csv_query_stddev_3() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT stddev_pop(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.30387865541334363"]]; 
assert_float_eq(&expected, &actual); @@ -185,10 +185,10 @@ async fn csv_query_stddev_3() -> Result<()> { #[tokio::test] async fn csv_query_stddev_4() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT stddev(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.3054095399405338"]]; assert_float_eq(&expected, &actual); @@ -197,10 +197,10 @@ async fn csv_query_stddev_4() -> Result<()> { #[tokio::test] async fn csv_query_stddev_5() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT stddev_samp(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.3054095399405338"]]; assert_float_eq(&expected, &actual); @@ -209,10 +209,10 @@ async fn csv_query_stddev_5() -> Result<()> { #[tokio::test] async fn csv_query_stddev_6() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.9504384952922168"]]; assert_float_eq(&expected, &actual); @@ -221,10 +221,10 @@ async fn csv_query_stddev_6() -> Result<()> { #[tokio::test] async fn csv_query_median_1() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_median(c2) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["3"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -232,10 +232,10 @@ async fn csv_query_median_1() -> Result<()> { #[tokio::test] async fn csv_query_median_2() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_median(c6) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["1146409980542786560"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -243,10 +243,10 @@ async fn csv_query_median_2() -> Result<()> { #[tokio::test] async fn csv_query_median_3() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_median(c12) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["0.5550065410522981"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -254,8 +254,8 @@ async fn csv_query_median_3() -> Result<()> { #[tokio::test] async fn csv_query_external_table_count() { - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT 
COUNT(c12) FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -271,9 +271,9 @@ async fn csv_query_external_table_count() { #[tokio::test] async fn csv_query_external_table_sum() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // cast smallint and int to bigint to avoid overflow during calculation - register_aggregate_csv_by_sql(&mut ctx).await; + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; @@ -289,8 +289,8 @@ async fn csv_query_external_table_sum() { #[tokio::test] async fn csv_query_count() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT count(c12) FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -306,8 +306,8 @@ async fn csv_query_count() -> Result<()> { #[tokio::test] async fn csv_query_count_distinct() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT count(distinct c2) FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -323,8 +323,8 @@ async fn csv_query_count_distinct() -> Result<()> { #[tokio::test] async fn csv_query_count_distinct_expr() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT count(distinct c2 % 2) FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -340,8 +340,8 @@ async fn csv_query_count_distinct_expr() -> Result<()> { #[tokio::test] async fn csv_query_count_star() { - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT COUNT(*) FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -356,8 +356,8 @@ async fn csv_query_count_star() { #[tokio::test] async fn csv_query_count_one() { - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT COUNT(1) FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -372,8 +372,8 @@ async fn csv_query_count_one() { #[tokio::test] async fn csv_query_approx_count() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -412,8 +412,8 @@ async fn csv_query_approx_count() -> Result<()> { // float values. 
#[tokio::test] async fn csv_query_approx_percentile_cont() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; // Generate an assertion that the estimated $percentile value for $column is // within 5% of the $actual percentile value. @@ -478,8 +478,8 @@ async fn csv_query_approx_percentile_cont() -> Result<()> { #[tokio::test] async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; // compare approx_percentile_cont and approx_percentile_cont_with_weight let sql = "SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1"; @@ -517,7 +517,7 @@ async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> { assert_batches_eq!(expected, &actual); let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100", ) .await @@ -525,7 +525,7 @@ async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> { assert_eq!(results.to_string(), "Error during planning: The function ApproxPercentileContWithWeight does not support inputs of type Utf8."); let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100", ) .await @@ -533,7 +533,7 @@ async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> { assert_eq!(results.to_string(), "Error during planning: The weight argument for ApproxPercentileContWithWeight does not support inputs of type Utf8."); let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100", ) .await @@ -545,8 +545,8 @@ async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> { #[tokio::test] async fn csv_query_sum_crossjoin() { - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -601,8 +601,8 @@ async fn query_count_without_from() -> Result<()> { #[tokio::test] async fn csv_query_array_agg() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 2) test"; let actual = execute_to_batches(&ctx, sql).await; @@ -619,8 +619,8 @@ async fn csv_query_array_agg() -> Result<()> { #[tokio::test] async fn csv_query_array_agg_empty() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test"; let actual = execute_to_batches(&ctx, sql).await; @@ -637,8 +637,8 @@ async fn csv_query_array_agg_empty() -> Result<()> { #[tokio::test] async fn csv_query_array_agg_one() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = 
SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) test"; let actual = execute_to_batches(&ctx, sql).await; @@ -655,8 +655,8 @@ async fn csv_query_array_agg_one() -> Result<()> { #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT array_agg(distinct c2) FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; @@ -705,11 +705,11 @@ async fn csv_query_array_agg_distinct() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_sum() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t", ) .await @@ -722,7 +722,7 @@ async fn aggregate_timestamps_sum() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_count() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = execute_to_batches( @@ -745,7 +745,7 @@ async fn aggregate_timestamps_count() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_min() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = execute_to_batches( @@ -768,7 +768,7 @@ async fn aggregate_timestamps_min() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_max() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = execute_to_batches( @@ -791,11 +791,11 @@ async fn aggregate_timestamps_max() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_avg() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = plan_and_collect( - &mut ctx, + &ctx, "SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t", ) .await diff --git a/datafusion/tests/sql/avro.rs b/datafusion/tests/sql/avro.rs index 5289efa981c7..b5ea5477ead2 100644 --- a/datafusion/tests/sql/avro.rs +++ b/datafusion/tests/sql/avro.rs @@ -17,7 +17,7 @@ use super::*; -async fn register_alltypes_avro(ctx: &mut SessionContext) { +async fn register_alltypes_avro(ctx: &SessionContext) { let testdata = datafusion::test_util::arrow_test_data(); ctx.register_avro( "alltypes_plain", @@ -30,8 +30,8 @@ async fn register_alltypes_avro(ctx: &mut SessionContext) { #[tokio::test] async fn avro_query() { - let mut ctx = SessionContext::new(); - register_alltypes_avro(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_avro(&ctx).await; // NOTE that string_col is actually a binary column and does not have the UTF8 logical type // so we need an explicit cast let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; @@ -71,7 +71,7 @@ async fn avro_query_multiple_files() { ) .unwrap(); - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_avro( "alltypes_plain", table_path.display().to_string().as_str(), @@ -111,7 +111,7 @@ async fn avro_query_multiple_files() { #[tokio::test] async fn 
avro_single_nan_schema() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::arrow_test_data(); ctx.register_avro( "single_nan", @@ -134,11 +134,11 @@ async fn avro_single_nan_schema() { #[tokio::test] async fn avro_explain() { - let mut ctx = SessionContext::new(); - register_alltypes_avro(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_avro(&ctx).await; let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); let expected = vec![ vec![ diff --git a/datafusion/tests/sql/create_drop.rs b/datafusion/tests/sql/create_drop.rs index ad786610af4a..748bd7ae960e 100644 --- a/datafusion/tests/sql/create_drop.rs +++ b/datafusion/tests/sql/create_drop.rs @@ -23,8 +23,8 @@ use super::*; #[tokio::test] async fn create_table_as() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "CREATE TABLE my_table AS SELECT * FROM aggregate_simple"; ctx.sql(sql).await.unwrap(); @@ -47,8 +47,8 @@ async fn create_table_as() -> Result<()> { #[tokio::test] async fn drop_table() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "CREATE TABLE my_table AS SELECT * FROM aggregate_simple"; ctx.sql(sql).await.unwrap(); @@ -67,8 +67,8 @@ async fn drop_table() -> Result<()> { #[tokio::test] async fn csv_query_create_external_table() { - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT c1, c2, c3, c4, c5, c6, c7, c8, c9, 10, c11, c12, c13 FROM aggregate_test_100 LIMIT 1"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -83,7 +83,7 @@ async fn csv_query_create_external_table() { #[tokio::test] async fn create_external_table_with_timestamps() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let data = "Jorge,2018-12-13T12:12:10.011Z\n\ Andrew,2018-11-13T17:11:10.011Z"; @@ -110,12 +110,12 @@ async fn create_external_table_with_timestamps() { file_path.to_str().expect("path is utf8") ); - plan_and_collect(&mut ctx, &sql) + plan_and_collect(&ctx, &sql) .await .expect("Executing CREATE EXTERNAL TABLE"); let sql = "SELECT * from csv_with_timestamps"; - let result = plan_and_collect(&mut ctx, sql).await.unwrap(); + let result = plan_and_collect(&ctx, sql).await.unwrap(); let expected = vec![ "+--------+-------------------------+", "| name | ts |", diff --git a/datafusion/tests/sql/errors.rs b/datafusion/tests/sql/errors.rs index ef4bb6ee34ea..7680c3e8f1df 100644 --- a/datafusion/tests/sql/errors.rs +++ b/datafusion/tests/sql/errors.rs @@ -20,8 +20,8 @@ use super::*; #[tokio::test] async fn csv_query_error() -> Result<()> { // sin(utf8) should error - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + let ctx = create_ctx()?; + register_aggregate_csv(&ctx).await?; let sql = "SELECT sin(c1) FROM aggregate_test_100"; let plan = ctx.create_logical_plan(sql); assert!(plan.is_err()); @@ -31,8 +31,8 @@ async fn csv_query_error() -> Result<()> { #[tokio::test] async fn test_cast_expressions_error() -> Result<()> { // sin(utf8) should 
error - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + let ctx = create_ctx()?; + register_aggregate_csv(&ctx).await?; let sql = "SELECT CAST(c1 AS INT) FROM aggregate_test_100"; let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); @@ -54,8 +54,8 @@ async fn test_cast_expressions_error() -> Result<()> { #[tokio::test] async fn test_aggregation_with_bad_arguments() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; let logical_plan = ctx.create_logical_plan(sql); let err = logical_plan.unwrap_err(); @@ -105,7 +105,7 @@ async fn query_cte_incorrect() -> Result<()> { #[tokio::test] async fn test_select_wildcard_without_table() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT * "; let actual = ctx.sql(sql).await; match actual { @@ -122,8 +122,8 @@ async fn test_select_wildcard_without_table() -> Result<()> { #[tokio::test] async fn invalid_qualified_table_references() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; for table_ref in &[ "nonexistentschema.aggregate_test_100", diff --git a/datafusion/tests/sql/explain_analyze.rs b/datafusion/tests/sql/explain_analyze.rs index 6ed8507ec89f..ef6ea52dd774 100644 --- a/datafusion/tests/sql/explain_analyze.rs +++ b/datafusion/tests/sql/explain_analyze.rs @@ -22,8 +22,8 @@ async fn explain_analyze_baseline_metrics() { // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE // and then validate the presence of baseline metrics for supported operators let config = SessionConfig::new().with_target_partitions(3); - let mut ctx = SessionContext::with_config(config); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::with_config(config); + register_aggregate_csv_by_sql(&ctx).await; // a query with as many operators as we have metrics for let sql = "EXPLAIN ANALYZE \ SELECT count(*) as cnt FROM \ @@ -168,8 +168,8 @@ async fn explain_analyze_baseline_metrics() { async fn csv_explain_plans() { // This test verifies the look of each plan in its full cycle plan creation - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; // Logical plan @@ -342,10 +342,10 @@ async fn csv_explain_plans() { #[tokio::test] async fn csv_explain_verbose() { - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::<String>(); @@ -365,8 +365,8 @@ async fn csv_explain_verbose() { async fn csv_explain_verbose_plans() { // This test verifies the look of each plan in its full cycle plan creation - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN VERBOSE
SELECT c1 FROM aggregate_test_100 where c2 > 10"; // Logical plan @@ -545,8 +545,8 @@ async fn csv_explain_verbose_plans() { async fn explain_analyze_runs_optimizers() { // repro for https://github.com/apache/arrow-datafusion/issues/917 // where EXPLAIN ANALYZE was not correctly running optimizer - let mut ctx = SessionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // This happens as an optimization pass where count(*) can be // answered using statistics only. @@ -570,12 +570,12 @@ async fn explain_analyze_runs_optimizers() { #[tokio::test] async fn tpch_explain_q10() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); - register_tpch_csv(&mut ctx, "customer").await?; - register_tpch_csv(&mut ctx, "orders").await?; - register_tpch_csv(&mut ctx, "lineitem").await?; - register_tpch_csv(&mut ctx, "nation").await?; + register_tpch_csv(&ctx, "customer").await?; + register_tpch_csv(&ctx, "orders").await?; + register_tpch_csv(&ctx, "lineitem").await?; + register_tpch_csv(&ctx, "nation").await?; let sql = "select c_custkey, @@ -634,8 +634,8 @@ order by async fn test_physical_plan_display_indent() { // Hard code target_partitions as it appears in the RepartitionExec output let config = SessionConfig::new().with_target_partitions(3); - let mut ctx = SessionContext::with_config(config); - register_aggregate_csv(&mut ctx).await.unwrap(); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await.unwrap(); let sql = "SELECT c1, MAX(c12), MIN(c12) as the_min \ FROM aggregate_test_100 \ WHERE c12 < 10 \ @@ -680,9 +680,9 @@ async fn test_physical_plan_display_indent() { async fn test_physical_plan_display_indent_multi_children() { // Hard code target_partitions as it appears in the RepartitionExec output let config = SessionConfig::new().with_target_partitions(3); - let mut ctx = SessionContext::with_config(config); + let ctx = SessionContext::with_config(config); // ensure indenting works for nodes with multiple children - register_aggregate_csv(&mut ctx).await.unwrap(); + register_aggregate_csv(&ctx).await.unwrap(); let sql = "SELECT c1 \ FROM (select c1 from aggregate_test_100) AS a \ JOIN\ @@ -731,10 +731,10 @@ async fn test_physical_plan_display_indent_multi_children() { async fn csv_explain() { // This test uses the execute function that creates the full plan cycle: logical, optimized logical, and physical, // then execute the physical plan and return the final explain results - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); // Note can't use `assert_batches_eq` as the plan needs to be @@ -758,7 +758,7 @@ async fn csv_explain() { // Also, expect same result with lowercase explain let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); assert_eq!(expected, actual); } @@ -766,8 +766,8 @@ async fn csv_explain() { #[tokio::test] async fn csv_explain_analyze() { // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE - let mut ctx = SessionContext::new(); -
register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN ANALYZE SELECT count(*), c1 FROM aggregate_test_100 group by c1"; let actual = execute_to_batches(&ctx, sql).await; let formatted = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() @@ -787,8 +787,8 @@ async fn csv_explain_analyze() { #[tokio::test] async fn csv_explain_analyze_verbose() { // This test uses the execute function to run an actual plan under EXPLAIN VERBOSE ANALYZE - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "EXPLAIN ANALYZE VERBOSE SELECT count(*), c1 FROM aggregate_test_100 group by c1"; let actual = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/tests/sql/expr.rs b/datafusion/tests/sql/expr.rs index d600c7b4d733..ae9ce4be9159 100644 --- a/datafusion/tests/sql/expr.rs +++ b/datafusion/tests/sql/expr.rs @@ -124,7 +124,7 @@ async fn query_not() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT NOT c1 FROM test"; let actual = execute_to_batches(&ctx, sql).await; @@ -143,12 +143,12 @@ async fn query_not() -> Result<()> { #[tokio::test] async fn csv_query_sum_cast() { - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; // c8 = i32; c9 = i64 let sql = "SELECT c8 + c9 FROM aggregate_test_100"; // check that the physical and logical schemas are equal - execute(&mut ctx, sql).await; + execute(&ctx, sql).await; } #[tokio::test] @@ -166,7 +166,7 @@ async fn query_is_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT c1 IS NULL FROM test"; let actual = execute_to_batches(&ctx, sql).await; @@ -198,7 +198,7 @@ async fn query_is_not_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT c1 IS NOT NULL FROM test"; let actual = execute_to_batches(&ctx, sql).await; @@ -262,7 +262,7 @@ async fn query_scalar_minus_array() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT 4 - c1 FROM test"; let actual = execute_to_batches(&ctx, sql).await; @@ -658,9 +658,9 @@ async fn test_cast_expressions() -> Result<()> { #[tokio::test] async fn test_random_expression() -> Result<()> { - let mut ctx = create_ctx()?; + let ctx = create_ctx()?; let sql = "SELECT random() r1"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let r1 = actual[0][0].parse::<f64>().unwrap(); assert!(0.0 <= r1); assert!(r1 < 1.0); @@ -685,8 +685,8 @@ async fn case_with_bool_type_result() -> Result<()> { #[tokio::test] async fn in_list_array() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv_by_sql(&mut ctx).await; + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; let sql = "SELECT c1 IN ('a',
'c') AS utf8_in_true ,c1 IN ('x', 'y') AS utf8_in_false @@ -810,8 +810,8 @@ async fn test_in_list_scalar() -> Result<()> { #[tokio::test] async fn csv_query_boolean_eq_neq() { - let mut ctx = SessionContext::new(); - register_boolean(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_boolean(&ctx).await.unwrap(); // verify the plumbing is all hooked up for eq and neq let sql = "SELECT a, b, a = b as eq, b = true as eq_scalar, a != b as neq, a != true as neq_scalar FROM t1"; let actual = execute_to_batches(&ctx, sql).await; @@ -836,8 +836,8 @@ async fn csv_query_boolean_eq_neq() { #[tokio::test] async fn csv_query_boolean_lt_lt_eq() { - let mut ctx = SessionContext::new(); - register_boolean(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_boolean(&ctx).await.unwrap(); // verify the plumbing is all hooked up for < and <= let sql = "SELECT a, b, a < b as lt, b = true as lt_scalar, a <= b as lt_eq, a <= true as lt_eq_scalar FROM t1"; let actual = execute_to_batches(&ctx, sql).await; @@ -862,8 +862,8 @@ async fn csv_query_boolean_lt_lt_eq() { #[tokio::test] async fn csv_query_boolean_gt_gt_eq() { - let mut ctx = SessionContext::new(); - register_boolean(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_boolean(&ctx).await.unwrap(); // verify the plumbing is all hooked up for > and >= let sql = "SELECT a, b, a > b as gt, b = true as gt_scalar, a >= b as gt_eq, a >= true as gt_eq_scalar FROM t1"; let actual = execute_to_batches(&ctx, sql).await; @@ -888,8 +888,8 @@ async fn csv_query_boolean_gt_gt_eq() { #[tokio::test] async fn csv_query_boolean_distinct_from() { - let mut ctx = SessionContext::new(); - register_boolean(&mut ctx).await.unwrap(); + let ctx = SessionContext::new(); + register_boolean(&ctx).await.unwrap(); // verify the plumbing is all hooked up for is distinct from and is not distinct from let sql = "SELECT a, b, \ a is distinct from b as df, \ @@ -919,10 +919,10 @@ async fn csv_query_boolean_distinct_from() { #[tokio::test] async fn csv_query_nullif_divide_by_0() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c8/nullif(c7, 0) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let actual = &actual[80..90]; // We just want to compare rows 80-89 let expected = vec![ vec!["258"], @@ -941,8 +941,8 @@ async fn csv_query_nullif_divide_by_0() -> Result<()> { } #[tokio::test] async fn csv_count_star() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT COUNT(*), COUNT(1) AS c, COUNT(c1) FROM aggregate_test_100"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -958,10 +958,10 @@ async fn csv_count_star() -> Result<()> { #[tokio::test] async fn csv_query_avg_sqrt() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + let ctx = create_ctx()?; + register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let mut actual = execute(&ctx, sql).await; actual.sort(); let expected = vec![vec!["0.6706002946036462"]]; assert_float_eq(&expected, &actual); @@ -971,10 +971,10 @@ async fn csv_query_avg_sqrt() -> Result<()> { // this query used to 
deadlock due to the call udf(udf()) #[tokio::test] async fn csv_query_sqrt_sqrt() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + let ctx = create_ctx()?; + register_aggregate_csv(&ctx).await?; let sql = "SELECT sqrt(sqrt(c12)) FROM aggregate_test_100 LIMIT 1"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; // sqrt(sqrt(c12=0.9294097332465232)) = 0.9818650561397431 let expected = vec![vec!["0.9818650561397431"]]; assert_float_eq(&expected, &actual); diff --git a/datafusion/tests/sql/functions.rs b/datafusion/tests/sql/functions.rs index 0e37b3c5cabc..3b6150da13a8 100644 --- a/datafusion/tests/sql/functions.rs +++ b/datafusion/tests/sql/functions.rs @@ -20,16 +20,16 @@ use super::*; /// sqrt(f32) is slightly different than sqrt(CAST(f32 AS double))) #[tokio::test] async fn sqrt_f32_vs_f64() -> Result<()> { - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx).await?; + let ctx = create_ctx()?; + register_aggregate_csv(&ctx).await?; // sqrt(f32)'s plan passes let sql = "SELECT avg(sqrt(c11)) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["0.6584407806396484"]]; assert_eq!(actual, expected); let sql = "SELECT avg(sqrt(CAST(c11 AS double))) FROM aggregate_test_100"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![vec!["0.6584408483418833"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -37,8 +37,8 @@ async fn sqrt_f32_vs_f64() -> Result<()> { #[tokio::test] async fn csv_query_cast() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT CAST(c12 AS float) FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; let actual = execute_to_batches(&ctx, sql).await; @@ -57,8 +57,8 @@ async fn csv_query_cast() -> Result<()> { #[tokio::test] async fn csv_query_cast_literal() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c12, CAST(1 AS float) FROM aggregate_test_100 WHERE c12 > CAST(0 AS float) LIMIT 2"; let actual = execute_to_batches(&ctx, sql).await; @@ -93,7 +93,7 @@ async fn query_concat() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test"; let actual = execute_to_batches(&ctx, sql).await; @@ -129,10 +129,10 @@ async fn query_array() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let expected = vec![ vec!["[,0]"], vec!["[a,1]"], @@ -160,7 +160,7 @@ async fn query_count_distinct() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT COUNT(DISTINCT c1) FROM test"; let actual = execute_to_batches(&ctx, sql).await; diff --git 
a/datafusion/tests/sql/group_by.rs b/datafusion/tests/sql/group_by.rs index 5430ca9d09f0..78323310a350 100644 --- a/datafusion/tests/sql/group_by.rs +++ b/datafusion/tests/sql/group_by.rs @@ -19,8 +19,8 @@ use super::*; #[tokio::test] async fn csv_query_group_by_int_min_max() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c2"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -40,8 +40,8 @@ async fn csv_query_group_by_int_min_max() -> Result<()> { #[tokio::test] async fn csv_query_group_by_float32() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "SELECT COUNT(*) as cnt, c1 FROM aggregate_simple GROUP BY c1 ORDER BY cnt DESC"; @@ -65,8 +65,8 @@ async fn csv_query_group_by_float32() -> Result<()> { #[tokio::test] async fn csv_query_group_by_float64() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "SELECT COUNT(*) as cnt, c2 FROM aggregate_simple GROUP BY c2 ORDER BY cnt DESC"; @@ -90,8 +90,8 @@ async fn csv_query_group_by_float64() -> Result<()> { #[tokio::test] async fn csv_query_group_by_boolean() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_simple_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; let sql = "SELECT COUNT(*) as cnt, c3 FROM aggregate_simple GROUP BY c3 ORDER BY cnt DESC"; @@ -112,8 +112,8 @@ async fn csv_query_group_by_boolean() -> Result<()> { #[tokio::test] async fn csv_query_group_by_two_columns() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, c2, MIN(c3) FROM aggregate_test_100 GROUP BY c1, c2"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -153,8 +153,8 @@ async fn csv_query_group_by_two_columns() -> Result<()> { #[tokio::test] async fn csv_query_group_by_and_having() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, MIN(c3) AS m FROM aggregate_test_100 GROUP BY c1 HAVING m < -100 AND MAX(c3) > 70"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -171,8 +171,8 @@ async fn csv_query_group_by_and_having() -> Result<()> { #[tokio::test] async fn csv_query_group_by_and_having_and_where() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, MIN(c3) AS m FROM aggregate_test_100 WHERE c1 IN ('a', 'b') @@ -192,8 +192,8 @@ async fn csv_query_group_by_and_having_and_where() -> Result<()> { #[tokio::test] async fn csv_query_having_without_group_by() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, c2, c3 FROM aggregate_test_100 HAVING c2 >= 4 AND c3 > 90"; let actual 
= execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -213,8 +213,8 @@ async fn csv_query_having_without_group_by() -> Result<()> { #[tokio::test] async fn csv_query_group_by_avg() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -234,8 +234,8 @@ async fn csv_query_group_by_avg() -> Result<()> { #[tokio::test] async fn csv_query_group_by_int_count() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, count(c12) FROM aggregate_test_100 GROUP BY c1"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -255,8 +255,8 @@ async fn csv_query_group_by_int_count() -> Result<()> { #[tokio::test] async fn csv_query_group_with_aliased_aggregate() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, count(c12) AS count FROM aggregate_test_100 GROUP BY c1"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -276,8 +276,8 @@ async fn csv_query_group_with_aliased_aggregate() -> Result<()> { #[tokio::test] async fn csv_query_group_by_string_min_max() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -312,7 +312,7 @@ async fn query_group_on_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1"; @@ -371,7 +371,7 @@ async fn query_group_on_null_multi_col() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2"; @@ -400,7 +400,7 @@ async fn query_group_on_null_multi_col() -> Result<()> { #[tokio::test] async fn csv_group_by_date() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("date", DataType::Date32, false), Field::new("cnt", DataType::Int32, false), diff --git a/datafusion/tests/sql/information_schema.rs b/datafusion/tests/sql/information_schema.rs index 57ef9ba3b65d..1317488e703a 100644 --- a/datafusion/tests/sql/information_schema.rs +++ b/datafusion/tests/sql/information_schema.rs @@ -29,9 +29,9 @@ use super::*; #[tokio::test] async fn information_schema_tables_not_exist_by_default() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); - let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let err = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap_err(); assert_eq!( @@ -42,10 +42,10 @@ async fn information_schema_tables_not_exist_by_default() { #[tokio::test] async fn 
information_schema_tables_no_tables() { - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_information_schema(true)); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap(); @@ -62,14 +62,14 @@ async fn information_schema_tables_no_tables() { #[tokio::test] async fn information_schema_tables_tables_default_catalog() { - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_information_schema(true)); // Now, register an empty table ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap(); @@ -88,7 +88,7 @@ async fn information_schema_tables_tables_default_catalog() { ctx.register_table("t2", table_with_sequence(1, 1).unwrap()) .unwrap(); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap(); @@ -107,7 +107,7 @@ async fn information_schema_tables_tables_default_catalog() { #[tokio::test] async fn information_schema_tables_tables_with_multiple_catalogs() { - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_information_schema(true)); let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); @@ -132,7 +132,7 @@ async fn information_schema_tables_tables_with_multiple_catalogs() { .unwrap(); ctx.register_catalog("my_other_catalog", Arc::new(catalog)); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap(); @@ -182,7 +182,7 @@ async fn information_schema_tables_table_types() { } } - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("physical", Arc::new(TestTable(TableType::Base))) @@ -192,7 +192,7 @@ async fn information_schema_tables_table_types() { ctx.register_table("temp", Arc::new(TestTable(TableType::Temporary))) .unwrap(); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") .await .unwrap(); @@ -212,27 +212,27 @@ async fn information_schema_tables_table_types() { #[tokio::test] async fn information_schema_show_tables_no_information_schema() { - let mut ctx = SessionContext::with_config(SessionConfig::new()); + let ctx = SessionContext::with_config(SessionConfig::new()); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); // use show tables alias - let err = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap_err(); + let err = plan_and_collect(&ctx, "SHOW TABLES").await.unwrap_err(); assert_eq!(err.to_string(), "Error during planning: SHOW TABLES is not supported unless information_schema is enabled"); } #[tokio::test] async fn information_schema_show_tables() { - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); // use show tables alias - let result = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap(); + let result = plan_and_collect(&ctx, "SHOW 
TABLES").await.unwrap(); let expected = vec![ "+---------------+--------------------+------------+------------+", @@ -245,19 +245,19 @@ async fn information_schema_show_tables() { ]; assert_batches_sorted_eq!(expected, &result); - let result = plan_and_collect(&mut ctx, "SHOW tables").await.unwrap(); + let result = plan_and_collect(&ctx, "SHOW tables").await.unwrap(); assert_batches_sorted_eq!(expected, &result); } #[tokio::test] async fn information_schema_show_columns_no_information_schema() { - let mut ctx = SessionContext::with_config(SessionConfig::new()); + let ctx = SessionContext::with_config(SessionConfig::new()); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") + let err = plan_and_collect(&ctx, "SHOW COLUMNS FROM t") .await .unwrap_err(); @@ -266,7 +266,7 @@ async fn information_schema_show_columns_no_information_schema() { #[tokio::test] async fn information_schema_show_columns_like_where() { - let mut ctx = SessionContext::with_config(SessionConfig::new()); + let ctx = SessionContext::with_config(SessionConfig::new()); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -274,12 +274,12 @@ async fn information_schema_show_columns_like_where() { let expected = "Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported"; - let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t LIKE 'f'") + let err = plan_and_collect(&ctx, "SHOW COLUMNS FROM t LIKE 'f'") .await .unwrap_err(); assert_eq!(err.to_string(), expected); - let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t WHERE column_name = 'bar'") + let err = plan_and_collect(&ctx, "SHOW COLUMNS FROM t WHERE column_name = 'bar'") .await .unwrap_err(); assert_eq!(err.to_string(), expected); @@ -287,15 +287,13 @@ async fn information_schema_show_columns_like_where() { #[tokio::test] async fn information_schema_show_columns() { - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") - .await - .unwrap(); + let result = plan_and_collect(&ctx, "SHOW COLUMNS FROM t").await.unwrap(); let expected = vec![ "+---------------+--------------+------------+-------------+-----------+-------------+", @@ -306,13 +304,11 @@ async fn information_schema_show_columns() { ]; assert_batches_sorted_eq!(expected, &result); - let result = plan_and_collect(&mut ctx, "SHOW columns from t") - .await - .unwrap(); + let result = plan_and_collect(&ctx, "SHOW columns from t").await.unwrap(); assert_batches_sorted_eq!(expected, &result); // This isn't ideal but it is consistent behavior for `SELECT * from T` - let err = plan_and_collect(&mut ctx, "SHOW columns from T") + let err = plan_and_collect(&ctx, "SHOW columns from T") .await .unwrap_err(); assert_eq!( @@ -324,13 +320,13 @@ async fn information_schema_show_columns() { // test errors with WHERE and LIKE #[tokio::test] async fn information_schema_show_columns_full_extended() { - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let result = plan_and_collect(&mut ctx, "SHOW FULL COLUMNS FROM t") + let result = plan_and_collect(&ctx, "SHOW FULL COLUMNS FROM t") .await .unwrap(); let expected = vec![ @@ -342,7 +338,7 @@ async fn information_schema_show_columns_full_extended() 
{ ]; assert_batches_sorted_eq!(expected, &result); - let result = plan_and_collect(&mut ctx, "SHOW EXTENDED COLUMNS FROM t") + let result = plan_and_collect(&ctx, "SHOW EXTENDED COLUMNS FROM t") .await .unwrap(); assert_batches_sorted_eq!(expected, &result); @@ -350,13 +346,13 @@ async fn information_schema_show_columns_full_extended() { #[tokio::test] async fn information_schema_show_table_table_names() { - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM public.t") + let result = plan_and_collect(&ctx, "SHOW COLUMNS FROM public.t") .await .unwrap(); @@ -369,12 +365,12 @@ async fn information_schema_show_table_table_names() { ]; assert_batches_sorted_eq!(expected, &result); - let result = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t") + let result = plan_and_collect(&ctx, "SHOW columns from datafusion.public.t") .await .unwrap(); assert_batches_sorted_eq!(expected, &result); - let err = plan_and_collect(&mut ctx, "SHOW columns from t2") + let err = plan_and_collect(&ctx, "SHOW columns from t2") .await .unwrap_err(); assert_eq!( @@ -382,7 +378,7 @@ async fn information_schema_show_table_table_names() { "Error during planning: Unknown relation for SHOW COLUMNS: t2" ); - let err = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t2") + let err = plan_and_collect(&ctx, "SHOW columns from datafusion.public.t2") .await .unwrap_err(); assert_eq!( @@ -393,9 +389,9 @@ async fn information_schema_show_table_table_names() { #[tokio::test] async fn show_unsupported() { - let mut ctx = SessionContext::with_config(SessionConfig::new()); + let ctx = SessionContext::with_config(SessionConfig::new()); - let err = plan_and_collect(&mut ctx, "SHOW SOMETHING_UNKNOWN") + let err = plan_and_collect(&ctx, "SHOW SOMETHING_UNKNOWN") .await .unwrap_err(); @@ -404,9 +400,9 @@ async fn show_unsupported() { #[tokio::test] async fn information_schema_columns_not_exist_by_default() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); - let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns") + let err = plan_and_collect(&ctx, "SELECT * from information_schema.columns") .await .unwrap_err(); assert_eq!( @@ -452,7 +448,7 @@ fn table_with_many_types() -> Arc<dyn TableProvider> { #[tokio::test] async fn information_schema_columns() { - let mut ctx = + let ctx = SessionContext::with_config(SessionConfig::new().with_information_schema(true)); let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); @@ -469,7 +465,7 @@ async fn information_schema_columns() { .unwrap(); ctx.register_catalog("my_catalog", Arc::new(catalog)); - let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns") + let result = plan_and_collect(&ctx, "SELECT * from information_schema.columns") .await .unwrap(); @@ -491,9 +487,6 @@ async fn information_schema_columns() { } /// Execute SQL and return results -async fn plan_and_collect( - ctx: &mut SessionContext, - sql: &str, -) -> Result<Vec<RecordBatch>> { +async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> { ctx.sql(sql).await?.collect().await } diff --git a/datafusion/tests/sql/intersection.rs b/datafusion/tests/sql/intersection.rs index eec22eecdf55..fadeefee4136 100644 --- a/datafusion/tests/sql/intersection.rs +++ b/datafusion/tests/sql/intersection.rs @@ -49,8 +49,8 @@ async fn
intersect_with_null_equal() { #[tokio::test] async fn test_intersect_all() -> Result<()> { - let mut ctx = SessionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // execute the query let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT ALL SELECT int_col, double_col FROM alltypes_plain LIMIT 4"; let actual = execute_to_batches(&ctx, sql).await; @@ -70,8 +70,8 @@ async fn test_intersect_all() -> Result<()> { #[tokio::test] async fn test_intersect_distinct() -> Result<()> { - let mut ctx = SessionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // execute the query let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT SELECT int_col, double_col FROM alltypes_plain"; let actual = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/tests/sql/joins.rs b/datafusion/tests/sql/joins.rs index 36ad05755591..d8f5d4ad8a2b 100644 --- a/datafusion/tests/sql/joins.rs +++ b/datafusion/tests/sql/joins.rs @@ -552,21 +552,21 @@ async fn equijoin_implicit_syntax_reversed() -> Result<()> { #[tokio::test] async fn cross_join() { - let mut ctx = create_join_context("t1_id", "t2_id").unwrap(); + let ctx = create_join_context("t1_id", "t2_id").unwrap(); let sql = "SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(4 * 4, actual.len()); let sql = "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(4 * 4, actual.len()); let sql = "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(4 * 4, actual.len()); let actual = execute_to_batches(&ctx, sql).await; @@ -597,13 +597,13 @@ async fn cross_join() { // Two partitions (from UNION) on the left let sql = "SELECT * FROM (SELECT t1_id, t1_name FROM t1 UNION ALL SELECT t1_id, t1_name FROM t1) AS t1 CROSS JOIN t2"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(4 * 4 * 2, actual.len()); // Two partitions (from UNION) on the right let sql = "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN (SELECT t2_name FROM t2 UNION ALL SELECT t2_name FROM t2) AS t2"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; assert_eq!(4 * 4 * 2, actual.len()); } @@ -648,7 +648,7 @@ async fn cross_join_unbalanced() { #[tokio::test] async fn test_join_timestamp() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // register time table let timestamp_schema = Arc::new(Schema::new(vec![Field::new( @@ -691,7 +691,7 @@ async fn test_join_timestamp() -> Result<()> { #[tokio::test] async fn test_join_float32() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // register population table let population_schema = Arc::new(Schema::new(vec![ @@ -732,7 +732,7 @@ async fn test_join_float32() -> Result<()> { #[tokio::test] async fn test_join_float64() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); // register population table let population_schema = Arc::new(Schema::new(vec![ @@ -857,7 +857,7 @@ async fn 
join_tables_with_duplicated_column_name_not_in_on_constraint() -> Resul .unwrap(); let cities = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("countries", Arc::new(countries))?; ctx.register_table("cities", Arc::new(cities))?; @@ -885,7 +885,7 @@ async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Resul #[tokio::test] async fn join_timestamp() -> Result<()> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let expected = vec![ diff --git a/datafusion/tests/sql/json.rs b/datafusion/tests/sql/json.rs index c87bcfd77ffb..e2209b8764fb 100644 --- a/datafusion/tests/sql/json.rs +++ b/datafusion/tests/sql/json.rs @@ -21,7 +21,7 @@ const TEST_DATA_BASE: &str = "tests/jsons"; #[tokio::test] async fn json_query() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let path = format!("{}/2.json", TEST_DATA_BASE); ctx.register_json("t1", &path, NdJsonReadOptions::default()) .await @@ -54,7 +54,7 @@ async fn json_query() { #[tokio::test] #[should_panic] async fn json_single_nan_schema() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let path = format!("{}/3.json", TEST_DATA_BASE); ctx.register_json("single_nan", &path, NdJsonReadOptions::default()) .await @@ -73,14 +73,14 @@ async fn json_single_nan_schema() { #[tokio::test] async fn json_explain() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let path = format!("{}/2.json", TEST_DATA_BASE); ctx.register_json("t1", &path, NdJsonReadOptions::default()) .await .unwrap(); let sql = "EXPLAIN SELECT count(*) from t1"; - let actual = execute(&mut ctx, sql).await; + let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); let expected = vec![ vec![ diff --git a/datafusion/tests/sql/limit.rs b/datafusion/tests/sql/limit.rs index 3e7c1d6049d4..fc2dc4c95645 100644 --- a/datafusion/tests/sql/limit.rs +++ b/datafusion/tests/sql/limit.rs @@ -19,8 +19,8 @@ use super::*; #[tokio::test] async fn csv_query_limit() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 2"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+----+", "| c1 |", "+----+", "| c |", "| d |", "+----+"]; @@ -30,8 +30,8 @@ async fn csv_query_limit() -> Result<()> { #[tokio::test] async fn csv_query_limit_bigger_than_nbr_of_rows() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 200"; let actual = execute_to_batches(&ctx, sql).await; // println!("{}", pretty_format_batches(&a).unwrap()); @@ -56,8 +56,8 @@ async fn csv_query_limit_bigger_than_nbr_of_rows() -> Result<()> { #[tokio::test] async fn csv_query_limit_with_same_nbr_of_rows() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 100"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -81,8 +81,8 @@ async fn csv_query_limit_with_same_nbr_of_rows() -> Result<()> { #[tokio::test] 
async fn csv_query_limit_zero() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 0"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["++", "++"]; diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index 1e8938a06f83..724fffce7bbf 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -66,9 +66,9 @@ macro_rules! assert_metrics { macro_rules! test_expression { ($SQL:expr, $EXPECTED:expr) => { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let sql = format!("SELECT {}", $SQL); - let actual = execute(&mut ctx, sql.as_str()).await; + let actual = execute(&ctx, sql.as_str()).await; assert_eq!(actual[0][0], $EXPECTED); }; } @@ -153,7 +153,7 @@ fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> { } fn create_case_context() -> Result<SessionContext> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)])); let data = RecordBatch::try_new( schema.clone(), @@ -170,7 +170,7 @@ fn create_case_context() -> Result<SessionContext> { } fn create_join_context(column_left: &str, column_right: &str) -> Result<SessionContext> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new(column_left, DataType::UInt32, true), @@ -214,7 +214,7 @@ fn create_join_context(column_left: &str, column_right: &str) -> Result Result { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::UInt32, true), @@ -256,7 +256,7 @@ fn create_join_context_unbalanced( column_left: &str, column_right: &str, ) -> Result<SessionContext> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new(column_left, DataType::UInt32, true), @@ -302,7 +302,7 @@ fn create_join_context_unbalanced( // Create memory tables with nulls fn create_join_context_with_nulls() -> Result<SessionContext> { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new("t1_id", DataType::UInt32, true), @@ -404,7 +404,7 @@ fn get_tpch_table_schema(table: &str) -> Schema { } } -async fn register_tpch_csv(ctx: &mut SessionContext, table: &str) -> Result<()> { +async fn register_tpch_csv(ctx: &SessionContext, table: &str) -> Result<()> { let schema = get_tpch_table_schema(table); ctx.register_csv( @@ -416,7 +416,7 @@ async fn register_tpch_csv(ctx: &mut SessionContext, table: &str) -> Result<()> Ok(()) } -async fn register_aggregate_csv_by_sql(ctx: &mut SessionContext) { +async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { let testdata = datafusion::test_util::arrow_test_data(); // TODO: The following c9 should be migrated to UInt32 and c10 should be UInt64 once @@ -458,7 +458,7 @@ async fn register_aggregate_csv_by_sql(ctx: &mut SessionContext) { } /// Create table "t1" with two boolean columns "a" and "b" -async fn register_boolean(ctx: &mut SessionContext) -> Result<()> { +async fn register_boolean(ctx: &SessionContext) -> Result<()> { let a: BooleanArray = [ Some(true), Some(true), @@ -493,7 +493,7 @@ async fn register_boolean(ctx: &mut SessionContext) -> Result<()> { Ok(()) } -async fn register_aggregate_simple_csv(ctx: &mut SessionContext) -> Result<()> { +async fn
register_aggregate_simple_csv(ctx: &SessionContext) -> Result<()> { // It's not possible to use aggregate_test_100, not enough similar values to test grouping on floats let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Float32, false), @@ -510,7 +510,7 @@ async fn register_aggregate_simple_csv(ctx: &mut SessionContext) -> Result<()> { Ok(()) } -async fn register_aggregate_csv(ctx: &mut SessionContext) -> Result<()> { +async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { let testdata = datafusion::test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); ctx.register_csv( @@ -523,10 +523,7 @@ async fn register_aggregate_csv(ctx: &mut SessionContext) -> Result<()> { } /// Execute SQL and return results as a RecordBatch -async fn plan_and_collect( - ctx: &mut SessionContext, - sql: &str, -) -> Result<Vec<RecordBatch>> { +async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> { ctx.sql(sql).await?.collect().await } @@ -553,7 +550,7 @@ async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> /// Execute query and return result set as 2-d table of Vecs /// `result[row][column]` -async fn execute(ctx: &mut SessionContext, sql: &str) -> Vec<Vec<String>> { +async fn execute(ctx: &SessionContext, sql: &str) -> Vec<Vec<String>> { result_vec(&execute_to_batches(ctx, sql).await) } @@ -600,7 +597,7 @@ fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> { result } -async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut SessionContext) { +async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &SessionContext) { let df = ctx .sql( "CREATE EXTERNAL TABLE aggregate_simple ( @@ -622,7 +619,7 @@ async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut SessionCont ); } -async fn register_alltypes_parquet(ctx: &mut SessionContext) { +async fn register_alltypes_parquet(ctx: &SessionContext) { let testdata = datafusion::test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", @@ -831,7 +828,7 @@ async fn nyc() -> Result<()> { Field::new("total_amount", DataType::Float64, true), ]); - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); ctx.register_csv( "tripdata", "file.csv", diff --git a/datafusion/tests/sql/order.rs b/datafusion/tests/sql/order.rs index c6613d581bc3..2683f20e7fdd 100644 --- a/datafusion/tests/sql/order.rs +++ b/datafusion/tests/sql/order.rs @@ -19,8 +19,8 @@ use super::*; #[tokio::test] async fn test_sort_unprojected_col() -> Result<()> { - let mut ctx = SessionContext::new(); - register_alltypes_parquet(&mut ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // execute the query let sql = "SELECT id FROM alltypes_plain ORDER BY int_col, double_col"; let actual = execute_to_batches(&ctx, sql).await; @@ -34,8 +34,8 @@ async fn test_sort_unprojected_col() -> Result<()> { #[tokio::test] async fn test_order_by_agg_expr() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12)"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ diff --git a/datafusion/tests/sql/parquet.rs b/datafusion/tests/sql/parquet.rs index 77949938cc7a..6dfb37ed3c7e 100644 --- a/datafusion/tests/sql/parquet.rs +++ b/datafusion/tests/sql/parquet.rs @@ -24,8 +24,8 @@ use super::*; #[tokio::test] async fn parquet_query() { - let mut ctx = SessionContext::new(); - register_alltypes_parquet(&mut
ctx).await; + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await; // NOTE that string_col is actually a binary column and does not have the UTF8 logical type // so we need an explicit cast let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; @@ -50,7 +50,7 @@ async fn parquet_query() { #[tokio::test] async fn parquet_single_nan_schema() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); ctx.register_parquet("single_nan", &format!("{}/single_nan.parquet", testdata)) .await @@ -70,7 +70,7 @@ async fn parquet_single_nan_schema() { #[tokio::test] #[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] async fn parquet_list_columns() { - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); ctx.register_parquet( "list_columns", @@ -212,7 +212,7 @@ async fn schema_merge_ignores_metadata() { // Read the parquet files into a dataframe to confirm results // (no errors) - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let df = ctx .read_parquet(table_dir.to_str().unwrap().to_string()) .await diff --git a/datafusion/tests/sql/partitioned_csv.rs b/datafusion/tests/sql/partitioned_csv.rs index e90b7a0beffa..9b87c2f09b67 100644 --- a/datafusion/tests/sql/partitioned_csv.rs +++ b/datafusion/tests/sql/partitioned_csv.rs @@ -78,8 +78,7 @@ pub async fn create_ctx( tmp_dir: &TempDir, partition_count: usize, ) -> Result { - let mut ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; diff --git a/datafusion/tests/sql/predicates.rs b/datafusion/tests/sql/predicates.rs index 879107c84e94..1369baa75f4c 100644 --- a/datafusion/tests/sql/predicates.rs +++ b/datafusion/tests/sql/predicates.rs @@ -19,8 +19,8 @@ use super::*; #[tokio::test] async fn csv_query_with_predicate() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, c12 FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -37,8 +37,8 @@ async fn csv_query_with_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_negative_predicate() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, c4 FROM aggregate_test_100 WHERE c3 < -55 AND -c4 > 30000"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -55,8 +55,8 @@ async fn csv_query_with_negative_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_negated_predicate() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE NOT(c1 != 'a')"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -72,8 +72,8 @@ async fn csv_query_with_negated_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_is_not_null_predicate() -> Result<()> { - let mut ctx = 
SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NOT NULL";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
@@ -89,8 +89,8 @@ async fn csv_query_with_is_not_null_predicate() -> Result<()> {

 #[tokio::test]
 async fn csv_query_with_is_null_predicate() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NULL";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
@@ -106,8 +106,8 @@ async fn csv_query_with_is_null_predicate() -> Result<()> {

 #[tokio::test]
 async fn query_where_neg_num() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv_by_sql(&mut ctx).await;
+    let ctx = SessionContext::new();
+    register_aggregate_csv_by_sql(&ctx).await;

     // Negative numbers do not parse correctly as of Arrow 2.0.0
     let sql = "select c7, c8 from aggregate_test_100 where c7 >= -2 and c7 < 10";
@@ -134,8 +134,8 @@ async fn query_where_neg_num() -> Result<()> {

 #[tokio::test]
 async fn like() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv_by_sql(&mut ctx).await;
+    let ctx = SessionContext::new();
+    register_aggregate_csv_by_sql(&ctx).await;
     let sql = "SELECT COUNT(c1) FROM aggregate_test_100 WHERE c13 LIKE '%FB%'";
     // check that the physical and logical schemas are equal
     let actual = execute_to_batches(&ctx, sql).await;
@@ -152,8 +152,8 @@ async fn like() -> Result<()> {

 #[tokio::test]
 async fn csv_between_expr() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "SELECT c4 FROM aggregate_test_100 WHERE c12 BETWEEN 0.995 AND 1.0";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
@@ -169,8 +169,8 @@ async fn csv_between_expr() -> Result<()> {

 #[tokio::test]
 async fn csv_between_expr_negated() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "SELECT c4 FROM aggregate_test_100 WHERE c12 NOT BETWEEN 0 AND 0.995";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
@@ -193,7 +193,7 @@ async fn like_on_strings() -> Result<()> {
     let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();

     let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table("test", Arc::new(table))?;

     let sql = "SELECT * FROM test WHERE c1 LIKE '%a%'";
@@ -220,7 +220,7 @@ async fn like_on_string_dictionaries() -> Result<()> {
     let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();

     let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table("test", Arc::new(table))?;

     let sql = "SELECT * FROM test WHERE c1 LIKE '%a%'";
@@ -247,7 +247,7 @@ async fn test_regexp_is_match() -> Result<()> {
     let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();

     let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table("test", Arc::new(table))?;

     let sql = "SELECT * FROM test WHERE c1 ~ 'z'";
@@ -333,8 +333,8 @@ async fn except_with_null_equal() {

 #[tokio::test]
 async fn test_expect_all() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_alltypes_parquet(&mut ctx).await;
+    let ctx = SessionContext::new();
+    register_alltypes_parquet(&ctx).await;
     // execute the query
     let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT ALL SELECT int_col, double_col FROM alltypes_plain where int_col < 1";
     let actual = execute_to_batches(&ctx, sql).await;
@@ -354,8 +354,8 @@ async fn test_expect_all() -> Result<()> {

 #[tokio::test]
 async fn test_expect_distinct() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_alltypes_parquet(&mut ctx).await;
+    let ctx = SessionContext::new();
+    register_alltypes_parquet(&ctx).await;
     // execute the query
     let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT SELECT int_col, double_col FROM alltypes_plain where int_col < 1";
     let actual = execute_to_batches(&ctx, sql).await;
diff --git a/datafusion/tests/sql/projection.rs b/datafusion/tests/sql/projection.rs
index d2bcfbcd6d63..d717698e4522 100644
--- a/datafusion/tests/sql/projection.rs
+++ b/datafusion/tests/sql/projection.rs
@@ -35,8 +35,8 @@ async fn projection_same_fields() -> Result<()> {

 #[tokio::test]
 async fn projection_type_alias() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_simple_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_simple_csv(&ctx).await?;

     // Query that aliases one column to the name of a different column
     // that also has a different type (c1 == float32, c3 == boolean)
@@ -58,8 +58,8 @@ async fn projection_type_alias() -> Result<()> {

 #[tokio::test]
 async fn csv_query_group_by_avg_with_projection() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
diff --git a/datafusion/tests/sql/references.rs b/datafusion/tests/sql/references.rs
index ec658d534063..23f6058b8034 100644
--- a/datafusion/tests/sql/references.rs
+++ b/datafusion/tests/sql/references.rs
@@ -19,8 +19,8 @@ use super::*;

 #[tokio::test]
 async fn qualified_table_references() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;

     for table_ref in &[
         "aggregate_test_100",
@@ -43,7 +43,7 @@ async fn qualified_table_references() -> Result<()> {

 #[tokio::test]
 async fn qualified_table_references_and_fields() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();

     let c1: StringArray = vec!["foofoo", "foobar", "foobaz"]
         .into_iter()
diff --git a/datafusion/tests/sql/select.rs b/datafusion/tests/sql/select.rs
index 54120e023be9..4f878d23c808 100644
--- a/datafusion/tests/sql/select.rs
+++ b/datafusion/tests/sql/select.rs
@@ -24,8 +24,8 @@ use tempfile::TempDir;

 #[tokio::test]
 async fn all_where_empty() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;

     let sql = "SELECT * FROM aggregate_test_100 WHERE 1=2";
@@ -249,8 +249,8 @@ async fn select_values_list() -> Result<()> {

 #[tokio::test]
 async fn select_all() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_simple_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_simple_csv(&ctx).await?;

     let sql = "SELECT c1 FROM aggregate_simple order by c1";
     let results = execute_to_batches(&ctx, sql).await;
@@ -288,11 +288,11 @@ async fn select_all() -> Result<()> {

 #[tokio::test]
 async fn select_distinct() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_simple_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_simple_csv(&ctx).await?;

     let sql = "SELECT DISTINCT * FROM aggregate_simple";
-    let mut actual = execute(&mut ctx, sql).await;
+    let mut actual = execute(&ctx, sql).await;
     actual.sort();

     let mut dedup = actual.clone();
@@ -305,8 +305,8 @@ async fn select_distinct() -> Result<()> {

 #[tokio::test]
 async fn select_distinct_simple_1() {
-    let mut ctx = SessionContext::new();
-    register_aggregate_simple_csv(&mut ctx).await.unwrap();
+    let ctx = SessionContext::new();
+    register_aggregate_simple_csv(&ctx).await.unwrap();

     let sql = "SELECT DISTINCT c1 FROM aggregate_simple order by c1";
     let actual = execute_to_batches(&ctx, sql).await;
@@ -327,8 +327,8 @@ async fn select_distinct_simple_1() {

 #[tokio::test]
 async fn select_distinct_simple_2() {
-    let mut ctx = SessionContext::new();
-    register_aggregate_simple_csv(&mut ctx).await.unwrap();
+    let ctx = SessionContext::new();
+    register_aggregate_simple_csv(&ctx).await.unwrap();

     let sql = "SELECT DISTINCT c1, c2 FROM aggregate_simple order by c1";
     let actual = execute_to_batches(&ctx, sql).await;
@@ -349,8 +349,8 @@ async fn select_distinct_simple_2() {

 #[tokio::test]
 async fn select_distinct_simple_3() {
-    let mut ctx = SessionContext::new();
-    register_aggregate_simple_csv(&mut ctx).await.unwrap();
+    let ctx = SessionContext::new();
+    register_aggregate_simple_csv(&ctx).await.unwrap();

     let sql = "SELECT distinct c3 FROM aggregate_simple order by c3";
     let actual = execute_to_batches(&ctx, sql).await;
@@ -368,8 +368,8 @@ async fn select_distinct_simple_3() {

 #[tokio::test]
 async fn select_distinct_simple_4() {
-    let mut ctx = SessionContext::new();
-    register_aggregate_simple_csv(&mut ctx).await.unwrap();
+    let ctx = SessionContext::new();
+    register_aggregate_simple_csv(&ctx).await.unwrap();

     let sql = "SELECT distinct c1+c2 as a FROM aggregate_simple";
     let actual = execute_to_batches(&ctx, sql).await;
@@ -434,8 +434,8 @@ async fn select_distinct_from_utf8() {

 #[tokio::test]
 async fn csv_query_with_decimal_by_sql() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_simple_aggregate_csv_with_decimal_by_sql(&mut ctx).await;
+    let ctx = SessionContext::new();
+    register_simple_aggregate_csv_with_decimal_by_sql(&ctx).await;
     let sql = "SELECT c1 from aggregate_simple";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
@@ -465,7 +465,7 @@ async fn csv_query_with_decimal_by_sql() -> Result<()> {

 #[tokio::test]
 async fn use_between_expression_in_select_query() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();

     let sql = "SELECT 1 NOT BETWEEN 3 AND 5";
     let actual = execute_to_batches(&ctx, sql).await;
@@ -515,7 +515,7 @@ async fn use_between_expression_in_select_query() -> Result<()> {

 #[tokio::test]
 async fn query_get_indexed_field() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     let schema = Arc::new(Schema::new(vec![Field::new(
         "some_list",
         DataType::List(Box::new(Field::new("item", DataType::Int64, true))),
@@ -549,7 +549,7 @@ async fn query_get_indexed_field() -> Result<()> {

 #[tokio::test]
 async fn query_nested_get_indexed_field() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true)));
     // Nested schema of { "some_list": [[i64]] }
     let schema = Arc::new(Schema::new(vec![Field::new(
@@ -607,7 +607,7 @@ async fn query_nested_get_indexed_field() -> Result<()> {

 #[tokio::test]
 async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true)));
     // Nested schema of { "some_struct": { "bar": [i64] } }
     let struct_fields = vec![Field::new("bar", nested_dt.clone(), true)];
@@ -676,7 +676,7 @@ async fn query_on_string_dictionary() -> Result<()> {
     .unwrap();

     let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table("test", Arc::new(table))?;

     // Basic SELECT
@@ -903,8 +903,8 @@ async fn query_cte() -> Result<()> {

 #[tokio::test]
 async fn csv_select_nested() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "SELECT o1, o2, c3
                FROM (
                  SELECT c1 AS o1, c2 + 1 AS o2, c3
@@ -991,11 +991,11 @@ async fn parallel_query_with_filter() -> Result<()> {

 #[tokio::test]
 async fn query_empty_table() {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     let empty_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty())));
     ctx.register_table("test_tbl", empty_table).unwrap();
     let sql = "SELECT * FROM test_tbl";
-    let result = plan_and_collect(&mut ctx, sql)
+    let result = plan_and_collect(&ctx, sql)
         .await
         .expect("Query empty table");
     let expected = vec!["++", "++"];
diff --git a/datafusion/tests/sql/timestamp.rs b/datafusion/tests/sql/timestamp.rs
index 4e0e7a8c79a8..1e475fb175bd 100644
--- a/datafusion/tests/sql/timestamp.rs
+++ b/datafusion/tests/sql/timestamp.rs
@@ -20,7 +20,7 @@ use datafusion::from_slice::FromSlice;

 #[tokio::test]
 async fn query_cast_timestamp_millis() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();

     let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)]));
     let t1_data = RecordBatch::try_new(
@@ -52,7 +52,7 @@ async fn query_cast_timestamp_millis() -> Result<()> {

 #[tokio::test]
 async fn query_cast_timestamp_micros() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();

     let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)]));
     let t1_data = RecordBatch::try_new(
@@ -85,7 +85,7 @@ async fn query_cast_timestamp_micros() -> Result<()> {

 #[tokio::test]
 async fn query_cast_timestamp_seconds() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();

     let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)]));
     let t1_data = RecordBatch::try_new(
@@ -116,7 +116,7 @@ async fn query_cast_timestamp_seconds() -> Result<()> {

 #[tokio::test]
 async fn query_cast_timestamp_nanos_to_others() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table("ts_data", make_timestamp_nano_table()?)?;

     // Original column is nanos, convert to millis and check timestamp
@@ -166,7 +166,7 @@ async fn query_cast_timestamp_nanos_to_others() -> Result<()> {

 #[tokio::test]
 async fn query_cast_timestamp_seconds_to_others() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table("ts_secs", make_timestamp_table::()?)?;

     // Original column is seconds, convert to millis and check timestamp
@@ -216,7 +216,7 @@ async fn query_cast_timestamp_seconds_to_others() -> Result<()> {

 #[tokio::test]
 async fn query_cast_timestamp_micros_to_others() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table(
         "ts_micros",
         make_timestamp_table::()?,
@@ -268,7 +268,7 @@ async fn query_cast_timestamp_micros_to_others() -> Result<()> {

 #[tokio::test]
 async fn to_timestamp() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table("ts_data", make_timestamp_nano_table()?)?;

     let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp('2020-09-08T12:00:00+00:00')";
@@ -287,7 +287,7 @@ async fn to_timestamp() -> Result<()> {

 #[tokio::test]
 async fn to_timestamp_millis() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table(
         "ts_data",
         make_timestamp_table::()?,
@@ -308,7 +308,7 @@ async fn to_timestamp_millis() -> Result<()> {

 #[tokio::test]
 async fn to_timestamp_micros() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table(
         "ts_data",
         make_timestamp_table::()?,
@@ -330,7 +330,7 @@ async fn to_timestamp_micros() -> Result<()> {

 #[tokio::test]
 async fn to_timestamp_seconds() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table("ts_data", make_timestamp_table::()?)?;

     let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_seconds('2020-09-08T12:00:00+00:00')";
@@ -349,7 +349,7 @@ async fn to_timestamp_seconds() -> Result<()> {

 #[tokio::test]
 async fn count_distinct_timestamps() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table("ts_data", make_timestamp_nano_table()?)?;

     let sql = "SELECT COUNT(DISTINCT(ts)) FROM ts_data";
@@ -369,8 +369,8 @@ async fn count_distinct_timestamps() -> Result<()> {
 #[tokio::test]
 async fn test_current_timestamp_expressions() -> Result<()> {
     let t1 = chrono::Utc::now().timestamp();
-    let mut ctx = SessionContext::new();
-    let actual = execute(&mut ctx, "SELECT NOW(), NOW() as t2").await;
+    let ctx = SessionContext::new();
+    let actual = execute(&ctx, "SELECT NOW(), NOW() as t2").await;
     let res1 = actual[0][0].as_str();
     let res2 = actual[0][1].as_str();
     let t3 = chrono::Utc::now().timestamp();
@@ -416,7 +416,7 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> {

 #[tokio::test]
 async fn timestamp_minmax() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     let table_a = make_timestamp_tz_table::(None)?;
     let table_b = make_timestamp_tz_table::(Some("UTC".to_owned()))?;

@@ -440,7 +440,7 @@ async fn timestamp_minmax() -> Result<()> {

 #[tokio::test]
 async fn timestamp_coercion() -> Result<()> {
     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_tz_table::(Some("UTC".to_owned()))?;
         let table_b =
@@ -469,7 +469,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -496,7 +496,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -523,7 +523,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -550,7 +550,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -577,7 +577,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -604,7 +604,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -631,7 +631,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -658,7 +658,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -685,7 +685,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -712,7 +712,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -739,7 +739,7 @@ async fn timestamp_coercion() -> Result<()> {
     }

     {
-        let mut ctx = SessionContext::new();
+        let ctx = SessionContext::new();
         let table_a = make_timestamp_table::()?;
         let table_b = make_timestamp_table::()?;
         ctx.register_table("table_a", table_a)?;
@@ -770,7 +770,7 @@ async fn timestamp_coercion() -> Result<()> {

 #[tokio::test]
 async fn group_by_timestamp_millis() -> Result<()> {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();

     let schema = Arc::new(Schema::new(vec![
         Field::new(
diff --git a/datafusion/tests/sql/udf.rs b/datafusion/tests/sql/udf.rs
index c35c1d54148e..178a47d61c2d 100644
--- a/datafusion/tests/sql/udf.rs
+++ b/datafusion/tests/sql/udf.rs
@@ -27,10 +27,10 @@ use datafusion::{
 /// physical plan have the same schema.
 #[tokio::test]
 async fn csv_query_custom_udf_with_cast() -> Result<()> {
-    let mut ctx = create_ctx()?;
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = create_ctx()?;
+    register_aggregate_csv(&ctx).await?;
     let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100";
-    let actual = execute(&mut ctx, sql).await;
+    let actual = execute(&ctx, sql).await;
     let expected = vec![vec!["0.6584408483418833"]];
     assert_float_eq(&expected, &actual);
     Ok(())
@@ -172,7 +172,7 @@ async fn simple_udaf() -> Result<()> {

     ctx.register_udaf(my_avg);

-    let result = plan_and_collect(&mut ctx, "SELECT MY_AVG(a) FROM t").await?;
+    let result = plan_and_collect(&ctx, "SELECT MY_AVG(a) FROM t").await?;

     let expected = vec![
         "+-------------+",
diff --git a/datafusion/tests/sql/unicode.rs b/datafusion/tests/sql/unicode.rs
index f9cb4c482989..b9c9cbd1c5bc 100644
--- a/datafusion/tests/sql/unicode.rs
+++ b/datafusion/tests/sql/unicode.rs
@@ -116,10 +116,10 @@ async fn generic_query_length>>(
     let table = MemTable::try_new(schema, vec![vec![data]])?;

-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_table("test", Arc::new(table))?;
     let sql = "SELECT length(c1) FROM test";
-    let actual = execute(&mut ctx, sql).await;
+    let actual = execute(&ctx, sql).await;
     let expected = vec![vec!["0"], vec!["1"], vec!["2"], vec!["3"]];
     assert_eq!(expected, actual);
     Ok(())
diff --git a/datafusion/tests/sql/union.rs b/datafusion/tests/sql/union.rs
index 958e4b9d9389..4ebd4e88e3b0 100644
--- a/datafusion/tests/sql/union.rs
+++ b/datafusion/tests/sql/union.rs
@@ -29,11 +29,11 @@ async fn union_all() -> Result<()> {

 #[tokio::test]
 async fn csv_union_all() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "SELECT c1 FROM aggregate_test_100 UNION ALL SELECT c1 FROM aggregate_test_100";
-    let actual = execute(&mut ctx, sql).await;
+    let actual = execute(&ctx, sql).await;
     assert_eq!(actual.len(), 200);
     Ok(())
 }
diff --git a/datafusion/tests/sql/wildcard.rs b/datafusion/tests/sql/wildcard.rs
index b312fe1212e5..106f47a45778 100644
--- a/datafusion/tests/sql/wildcard.rs
+++ b/datafusion/tests/sql/wildcard.rs
@@ -19,8 +19,8 @@ use super::*;

 #[tokio::test]
 async fn select_qualified_wildcard() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_simple_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_simple_csv(&ctx).await?;

     let sql = "SELECT agg.* FROM aggregate_simple as agg order by c1";
     let results = execute_to_batches(&ctx, sql).await;
@@ -54,8 +54,8 @@ async fn select_qualified_wildcard() -> Result<()> {

 #[tokio::test]
 async fn select_non_alias_qualified_wildcard() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_simple_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_simple_csv(&ctx).await?;

     let sql = "SELECT aggregate_simple.* FROM aggregate_simple order by c1";
     let results = execute_to_batches(&ctx, sql).await;
@@ -132,8 +132,8 @@ async fn select_non_alias_qualified_wildcard_join() -> Result<()> {

 #[tokio::test]
 async fn select_wrong_qualified_wildcard() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_simple_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_simple_csv(&ctx).await?;

     let sql = "SELECT agg.* FROM aggregate_simple order by c1";
     let result = ctx.create_logical_plan(sql);
diff --git a/datafusion/tests/sql/window.rs b/datafusion/tests/sql/window.rs
index 1b7335b12233..ba493e6fe559 100644
--- a/datafusion/tests/sql/window.rs
+++ b/datafusion/tests/sql/window.rs
@@ -20,8 +20,8 @@ use super::*;
 /// for window functions without order by the first, last, and nth function call does not make sense
 #[tokio::test]
 async fn csv_query_window_with_empty_over() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "select \
     c9, \
     count(c5) over (), \
@@ -49,8 +49,8 @@ async fn csv_query_window_with_empty_over() -> Result<()> {
 /// for window functions without order by the first, last, and nth function call does not make sense
 #[tokio::test]
 async fn csv_query_window_with_partition_by() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "select \
     c9, \
     sum(cast(c4 as Int)) over (partition by c3), \
@@ -79,8 +79,8 @@ async fn csv_query_window_with_partition_by() -> Result<()> {

 #[tokio::test]
 async fn csv_query_window_with_order_by() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "select \
     c9, \
     sum(c5) over (order by c9), \
@@ -112,8 +112,8 @@ async fn csv_query_window_with_order_by() -> Result<()> {

 #[tokio::test]
 async fn csv_query_window_with_partition_by_order_by() -> Result<()> {
-    let mut ctx = SessionContext::new();
-    register_aggregate_csv(&mut ctx).await?;
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
     let sql = "select \
     c9, \
     sum(c5) over (partition by c4 order by c9), \
diff --git a/datafusion/tests/statistics.rs b/datafusion/tests/statistics.rs
index 76b185e2c2a3..9f5a1aab0f93 100644
--- a/datafusion/tests/statistics.rs
+++ b/datafusion/tests/statistics.rs
@@ -172,7 +172,7 @@ impl ExecutionPlan for StatisticsValidation {
 }

 fn init_ctx(stats: Statistics, schema: Schema) -> Result {
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     let provider: Arc =
         Arc::new(StatisticsValidation::new(stats, Arc::new(schema)));
     ctx.register_table("stats_table", provider)?;
@@ -210,7 +210,7 @@ fn fully_defined() -> (Statistics, Schema) {
 #[tokio::test]
 async fn sql_basic() -> Result<()> {
     let (stats, schema) = fully_defined();
-    let mut ctx = init_ctx(stats.clone(), schema)?;
+    let ctx = init_ctx(stats.clone(), schema)?;

     let df = ctx.sql("SELECT * from stats_table").await.unwrap();

@@ -228,7 +228,7 @@ async fn sql_basic() -> Result<()> {
 #[tokio::test]
 async fn sql_filter() -> Result<()> {
     let (stats, schema) = fully_defined();
-    let mut ctx = init_ctx(stats, schema)?;
+    let ctx = init_ctx(stats, schema)?;

     let df = ctx
         .sql("SELECT * FROM stats_table WHERE c1 = 5")
@@ -249,7 +249,7 @@ async fn sql_filter() -> Result<()> {
 #[tokio::test]
 async fn sql_limit() -> Result<()> {
     let (stats, schema) = fully_defined();
-    let mut ctx = init_ctx(stats.clone(), schema)?;
+    let ctx = init_ctx(stats.clone(), schema)?;

     let df = ctx.sql("SELECT * FROM stats_table LIMIT 5").await.unwrap();
     let physical_plan = ctx
@@ -284,7 +284,7 @@ async fn sql_limit() -> Result<()> {
 #[tokio::test]
 async fn sql_window() -> Result<()> {
     let (stats, schema) = fully_defined();
-    let mut ctx = init_ctx(stats.clone(), schema)?;
+    let ctx = init_ctx(stats.clone(), schema)?;

     let df = ctx
         .sql("SELECT c2, sum(c1) over (partition by c2) FROM stats_table")
diff --git a/datafusion/tests/user_defined_plan.rs b/datafusion/tests/user_defined_plan.rs
index 3b41a28e077e..86da4da85320 100644
--- a/datafusion/tests/user_defined_plan.rs
+++ b/datafusion/tests/user_defined_plan.rs
@@ -247,7 +247,7 @@ async fn topk_plan() -> Result<()> {
 fn make_topk_context() -> SessionContext {
     let config = SessionConfig::new().with_target_partitions(48);
     let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap());
-    let state = SessionState::with_config(config, runtime)
+    let state = SessionState::with_config_rt(config, runtime)
         .with_query_planner(Arc::new(TopKQueryPlanner {}))
         .add_optimizer_rule(Arc::new(TopKOptimizerRule {}));
     SessionContext::with_state(state)
diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md
index 4f2a6cf1a01f..6be802400a77 100644
--- a/docs/source/user-guide/example-usage.md
+++ b/docs/source/user-guide/example-usage.md
@@ -36,7 +36,7 @@ use datafusion::prelude::*;
 #[tokio::main]
 async fn main() -> datafusion::error::Result<()> {
     // register the table
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?;

     // create a plan to run a SQL query
@@ -56,7 +56,7 @@ use datafusion::prelude::*;
 #[tokio::main]
 async fn main() -> datafusion::error::Result<()> {
     // create the dataframe
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;

     let df = df.filter(col("a").lt_eq(col("b")))?
diff --git a/docs/source/user-guide/library.md b/docs/source/user-guide/library.md
index 5f4224816455..cb9ca48e78f8 100644
--- a/docs/source/user-guide/library.md
+++ b/docs/source/user-guide/library.md
@@ -57,7 +57,7 @@ use datafusion::prelude::*;
 #[tokio::main]
 async fn main() -> datafusion::error::Result<()> {
     // register the table
-    let mut ctx = SessionContext::new();
+    let ctx = SessionContext::new();
     ctx.register_csv("test", "", CsvReadOptions::new()).await?;

     // create a plan to run a SQL query