diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 736cf648f0a8..2dd270459ca6 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1810,22 +1810,14 @@ impl FunctionRegistry for SessionState { } } -/// Task Context Properties -pub enum TaskProperties { - ///SessionConfig - SessionConfig(SessionConfig), - /// Name-value pairs of task properties - KVPairs(HashMap), -} - /// Task Execution Context pub struct TaskContext { /// Session Id session_id: String, /// Optional Task Identify task_id: Option, - /// Task properties - properties: TaskProperties, + /// Session configuration + session_config: SessionConfig, /// Scalar functions associated with this task context scalar_functions: HashMap>, /// Aggregate functions associated with this task context @@ -1844,10 +1836,43 @@ impl TaskContext { aggregate_functions: HashMap>, runtime: Arc, ) -> Self { + let session_config = if task_props.is_empty() { + SessionConfig::new() + } else { + SessionConfig::new() + .with_batch_size(task_props.get(OPT_BATCH_SIZE).unwrap().parse().unwrap()) + .with_target_partitions( + task_props.get(TARGET_PARTITIONS).unwrap().parse().unwrap(), + ) + .with_repartition_joins( + task_props.get(REPARTITION_JOINS).unwrap().parse().unwrap(), + ) + .with_repartition_aggregations( + task_props + .get(REPARTITION_AGGREGATIONS) + .unwrap() + .parse() + .unwrap(), + ) + .with_repartition_windows( + task_props + .get(REPARTITION_WINDOWS) + .unwrap() + .parse() + .unwrap(), + ) + .with_parquet_pruning( + task_props.get(PARQUET_PRUNING).unwrap().parse().unwrap(), + ) + .with_collect_statistics( + task_props.get(COLLECT_STATISTICS).unwrap().parse().unwrap(), + ) + }; + Self { task_id: Some(task_id), session_id, - properties: TaskProperties::KVPairs(task_props), + session_config, scalar_functions, aggregate_functions, runtime, @@ -1855,44 +1880,8 @@ impl TaskContext { } /// Return the SessionConfig associated with the Task - pub fn session_config(&self) -> SessionConfig { - let task_props = &self.properties; - match task_props { - TaskProperties::KVPairs(props) => { - let session_config = SessionConfig::new(); - if props.is_empty() { - session_config - } else { - session_config - .with_batch_size( - props.get(OPT_BATCH_SIZE).unwrap().parse().unwrap(), - ) - .with_target_partitions( - props.get(TARGET_PARTITIONS).unwrap().parse().unwrap(), - ) - .with_repartition_joins( - props.get(REPARTITION_JOINS).unwrap().parse().unwrap(), - ) - .with_repartition_aggregations( - props - .get(REPARTITION_AGGREGATIONS) - .unwrap() - .parse() - .unwrap(), - ) - .with_repartition_windows( - props.get(REPARTITION_WINDOWS).unwrap().parse().unwrap(), - ) - .with_parquet_pruning( - props.get(PARQUET_PRUNING).unwrap().parse().unwrap(), - ) - .with_collect_statistics( - props.get(COLLECT_STATISTICS).unwrap().parse().unwrap(), - ) - } - } - TaskProperties::SessionConfig(session_config) => session_config.clone(), - } + pub fn session_config(&self) -> &SessionConfig { + &self.session_config } /// Return the session_id of this [TaskContext] @@ -1914,24 +1903,7 @@ impl TaskContext { /// Create a new task context instance from SessionContext impl From<&SessionContext> for TaskContext { fn from(session: &SessionContext) -> Self { - let session_id = session.session_id.clone(); - let (config, scalar_functions, aggregate_functions) = { - let session_state = session.state.read(); - ( - session_state.config.clone(), - session_state.scalar_functions.clone(), - session_state.aggregate_functions.clone(), - ) - }; - let runtime = session.runtime_env(); - Self { - task_id: None, - session_id, - properties: TaskProperties::SessionConfig(config), - scalar_functions, - aggregate_functions, - runtime, - } + TaskContext::from(&*session.state.read()) } } @@ -1939,14 +1911,14 @@ impl From<&SessionContext> for TaskContext { impl From<&SessionState> for TaskContext { fn from(state: &SessionState) -> Self { let session_id = state.session_id.clone(); - let config = state.config.clone(); + let session_config = state.config.clone(); let scalar_functions = state.scalar_functions.clone(); let aggregate_functions = state.aggregate_functions.clone(); let runtime = state.runtime_env.clone(); Self { task_id: None, session_id, - properties: TaskProperties::SessionConfig(config), + session_config, scalar_functions, aggregate_functions, runtime, diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index 0dfac7876734..2d7e06f0e647 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -916,7 +916,7 @@ async fn do_sort( schema.clone(), expr, metrics_set, - Arc::new(context.session_config()), + Arc::new(context.session_config().clone()), context.runtime_env(), fetch, );