Skip to content

Commit 5936edc

Browse files
mingmwang and alamb authored
Refactor SessionContext, SessionState and SessionConfig to support multi-tenancy configurations - Part 2 (#2029)
* Refactor SessionContext, SessionState and SessionConfig to support multi-tenancy configurations - Part 2 * fix UT * Resolve review comments * Fix minor merge issues Co-authored-by: Andrew Lamb <[email protected]>
1 parent 8de2a76 commit 5936edc

File tree

23 files changed

+464
-446
lines changed

23 files changed

+464
-446
lines changed

ballista/rust/core/src/serde/logical_plan/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ impl AsLogicalPlan for LogicalPlanNode {
228228
};
229229

230230
let object_store = ctx
231+
.runtime_env()
231232
.object_store(scan.path.as_str())
232233
.map_err(|e| {
233234
BallistaError::NotImplemented(format!(
@@ -1287,9 +1288,10 @@ mod roundtrip_tests {
12871288
let codec: BallistaCodec<protobuf::LogicalPlanNode, protobuf::PhysicalPlanNode> =
12881289
BallistaCodec::default();
12891290
let custom_object_store = Arc::new(TestObjectStore {});
1290-
ctx.register_object_store("test", custom_object_store.clone());
1291+
ctx.runtime_env()
1292+
.register_object_store("test", custom_object_store.clone());
12911293

1292-
let (os, _) = ctx.object_store("test://foo.csv")?;
1294+
let (os, _) = ctx.runtime_env().object_store("test://foo.csv")?;
12931295

12941296
println!("Object Store {:?}", os);
12951297

ballista/rust/core/src/serde/mod.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ mod tests {
350350
use datafusion::datasource::object_store::local::LocalFileSystem;
351351
use datafusion::error::DataFusionError;
352352
use datafusion::execution::context::{QueryPlanner, SessionState, TaskContext};
353+
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
353354
use datafusion::logical_plan::plan::Extension;
354355
use datafusion::logical_plan::{
355356
col, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode,
@@ -699,10 +700,11 @@ mod tests {
699700
#[tokio::test]
700701
async fn test_extension_plan() -> crate::error::Result<()> {
701702
let store = Arc::new(LocalFileSystem {});
702-
let config =
703-
SessionConfig::new().with_query_planner(Arc::new(TopKQueryPlanner {}));
703+
let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap());
704+
let session_state = SessionState::with_config(SessionConfig::new(), runtime)
705+
.with_query_planner(Arc::new(TopKQueryPlanner {}));
704706

705-
let ctx = SessionContext::with_config(config);
707+
let ctx = SessionContext::with_state(session_state);
706708

707709
let scan = LogicalPlanBuilder::scan_csv(
708710
store,

ballista/rust/core/src/serde/physical_plan/from_proto.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use chrono::{TimeZone, Utc};
2929
use datafusion::datasource::object_store::local::LocalFileSystem;
3030
use datafusion::datasource::object_store::{FileMeta, SizedFile};
3131
use datafusion::datasource::PartitionedFile;
32-
use datafusion::execution::context::SessionState;
32+
use datafusion::execution::context::ExecutionProps;
3333

3434
use datafusion::physical_plan::file_format::FileScanConfig;
3535

@@ -153,12 +153,12 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc<dyn PhysicalExpr> {
153153
.map(|x| x.try_into())
154154
.collect::<Result<Vec<_>, _>>()?;
155155

156-
// TODO Do not create new the SessionState
157-
let session_state = SessionState::new();
156+
// TODO Do not create new the ExecutionProps
157+
let execution_props = ExecutionProps::new();
158158

159159
let fun_expr = functions::create_physical_fun(
160160
&(&scalar_function).into(),
161-
&session_state.execution_props,
161+
&execution_props,
162162
)?;
163163

164164
Arc::new(ScalarFunctionExpr::new(

ballista/rust/core/src/serde/physical_plan/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,7 @@ fn decode_scan_config(
909909
.collect::<Result<Vec<_>, _>>()?;
910910

911911
let object_store = if let Some(file) = file_groups.get(0).and_then(|h| h.get(0)) {
912-
ctx.object_store(file.file_meta.path())?.0
912+
ctx.runtime_env().object_store(file.file_meta.path())?.0
913913
} else {
914914
Arc::new(LocalFileSystem {})
915915
};

ballista/rust/core/src/utils.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
4747
use datafusion::physical_plan::common::batch_byte_size;
4848
use datafusion::physical_plan::empty::EmptyExec;
4949

50+
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
5051
use datafusion::physical_plan::file_format::{CsvExec, ParquetExec};
5152
use datafusion::physical_plan::filter::FilterExec;
5253
use datafusion::physical_plan::hash_aggregate::HashAggregateExec;
@@ -234,11 +235,16 @@ pub fn create_df_ctx_with_ballista_query_planner<T: 'static + AsLogicalPlan>(
234235
let scheduler_url = format!("http://{}:{}", scheduler_host, scheduler_port);
235236
let planner: Arc<BallistaQueryPlanner<T>> =
236237
Arc::new(BallistaQueryPlanner::new(scheduler_url, config.clone()));
237-
let config = SessionConfig::new()
238-
.with_query_planner(planner)
238+
239+
let session_config = SessionConfig::new()
239240
.with_target_partitions(config.default_shuffle_partitions())
240241
.with_information_schema(true);
241-
SessionContext::with_config(config)
242+
let session_state = SessionState::with_config(
243+
session_config,
244+
Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()),
245+
)
246+
.with_query_planner(planner);
247+
SessionContext::with_state(session_state)
242248
}
243249

244250
pub struct BallistaQueryPlanner<T: AsLogicalPlan> {

ballista/rust/scheduler/src/scheduler_server/mod.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,23 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T
179179
}
180180
}
181181

182-
/// Create a DataFusion context that is compatible with Ballista
182+
/// Create a DataFusion session context that is compatible with Ballista Configuration
183183
pub fn create_datafusion_context(config: &BallistaConfig) -> SessionContext {
184184
let config =
185185
SessionConfig::new().with_target_partitions(config.default_shuffle_partitions());
186186
SessionContext::with_config(config)
187187
}
188188

189+
/// Update the existing DataFusion session context with Ballista Configuration
190+
pub fn update_datafusion_context(
191+
session_ctx: Arc<SessionContext>,
192+
config: &BallistaConfig,
193+
) -> Arc<SessionContext> {
194+
session_ctx.state.write().config.target_partitions =
195+
config.default_shuffle_partitions();
196+
session_ctx
197+
}
198+
189199
#[cfg(all(test, feature = "sled"))]
190200
mod test {
191201
use std::sync::Arc;

datafusion-proto/src/from_proto.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,7 +1168,7 @@ pub fn parse_expr(
11681168
ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => {
11691169
let scalar_fn = ctx
11701170
.state
1171-
.lock()
1171+
.read()
11721172
.get_function_meta(fun_name.as_str()).ok_or_else(|| Error::General(format!("invalid aggregate function message, function {} is not registered in the ExecutionContext", fun_name)))?;
11731173

11741174
Ok(Expr::ScalarUDF {
@@ -1182,7 +1182,7 @@ pub fn parse_expr(
11821182
ExprType::AggregateUdfExpr(protobuf::AggregateUdfExprNode { fun_name, args }) => {
11831183
let agg_fn = ctx
11841184
.state
1185-
.lock()
1185+
.read()
11861186
.get_aggregate_meta(fun_name.as_str()).ok_or_else(|| Error::General(format!("invalid aggregate function message, function {} is not registered in the ExecutionContext", fun_name)))?;
11871187

11881188
Ok(Expr::AggregateUDF {

datafusion/benches/sort_limit_query_sql.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ fn create_context() -> Arc<Mutex<SessionContext>> {
8484
rt.block_on(async {
8585
// create local session context
8686
let mut ctx = SessionContext::new();
87-
ctx.state.lock().config.target_partitions = 1;
87+
ctx.state.write().config.target_partitions = 1;
8888

8989
let task_ctx = ctx.task_ctx();
9090
let mem_table = MemTable::load(Arc::new(csv.await), Some(partitions), task_ctx)

datafusion/src/dataframe.rs

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ use crate::arrow::datatypes::SchemaRef;
3434
use crate::arrow::util::pretty;
3535
use crate::datasource::TableProvider;
3636
use crate::datasource::TableType;
37-
use crate::execution::context::{SessionContext, SessionState, TaskContext};
37+
use crate::execution::context::{SessionState, TaskContext};
3838
use crate::physical_plan::file_format::{plan_to_csv, plan_to_parquet};
3939
use crate::physical_plan::{collect, collect_partitioned};
4040
use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan};
4141
use crate::scalar::ScalarValue;
4242
use crate::sql::utils::find_window_exprs;
43-
use parking_lot::Mutex;
43+
use parking_lot::RwLock;
4444
use std::any::Any;
4545

4646
/// DataFrame represents a logical set of rows with the same named columns.
@@ -69,13 +69,13 @@ use std::any::Any;
6969
/// # }
7070
/// ```
7171
pub struct DataFrame {
72-
session_state: Arc<Mutex<SessionState>>,
72+
session_state: Arc<RwLock<SessionState>>,
7373
plan: LogicalPlan,
7474
}
7575

7676
impl DataFrame {
7777
/// Create a new Table based on an existing logical plan
78-
pub fn new(session_state: Arc<Mutex<SessionState>>, plan: &LogicalPlan) -> Self {
78+
pub fn new(session_state: Arc<RwLock<SessionState>>, plan: &LogicalPlan) -> Self {
7979
Self {
8080
session_state,
8181
plan: plan.clone(),
@@ -84,10 +84,9 @@ impl DataFrame {
8484

8585
/// Create a physical plan
8686
pub async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
87-
let state = self.session_state.lock().clone();
88-
let ctx = SessionContext::from(Arc::new(Mutex::new(state)));
89-
let plan = ctx.optimize(&self.plan)?;
90-
ctx.create_physical_plan(&plan).await
87+
let state = self.session_state.read().clone();
88+
let optimized_plan = state.optimize(&self.plan)?;
89+
state.create_physical_plan(&optimized_plan).await
9190
}
9291

9392
/// Filter the DataFrame by column. Returns a new DataFrame only containing the
@@ -351,8 +350,8 @@ impl DataFrame {
351350
/// ```
352351
pub async fn collect(&self) -> Result<Vec<RecordBatch>> {
353352
let plan = self.create_physical_plan().await?;
354-
let task_ctx = Arc::new(TaskContext::from(&self.session_state.lock().clone()));
355-
collect(plan, task_ctx).await
353+
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
354+
Ok(collect(plan, task_ctx).await?)
356355
}
357356

358357
/// Print results.
@@ -406,7 +405,7 @@ impl DataFrame {
406405
/// ```
407406
pub async fn execute_stream(&self) -> Result<SendableRecordBatchStream> {
408407
let plan = self.create_physical_plan().await?;
409-
let task_ctx = Arc::new(TaskContext::from(&self.session_state.lock().clone()));
408+
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
410409
execute_stream(plan, task_ctx).await
411410
}
412411

@@ -426,8 +425,8 @@ impl DataFrame {
426425
/// ```
427426
pub async fn collect_partitioned(&self) -> Result<Vec<Vec<RecordBatch>>> {
428427
let plan = self.create_physical_plan().await?;
429-
let task_ctx = Arc::new(TaskContext::from(&self.session_state.lock().clone()));
430-
collect_partitioned(plan, task_ctx).await
428+
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
429+
Ok(collect_partitioned(plan, task_ctx).await?)
431430
}
432431

433432
/// Executes this DataFrame and returns one stream per partition.
@@ -447,8 +446,8 @@ impl DataFrame {
447446
&self,
448447
) -> Result<Vec<SendableRecordBatchStream>> {
449448
let plan = self.create_physical_plan().await?;
450-
let task_ctx = Arc::new(TaskContext::from(&self.session_state.lock().clone()));
451-
execute_stream_partitioned(plan, task_ctx).await
449+
let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
450+
Ok(execute_stream_partitioned(plan, task_ctx).await?)
452451
}
453452

454453
/// Returns the schema describing the output of this DataFrame in terms of columns returned,
@@ -511,7 +510,7 @@ impl DataFrame {
511510
/// # }
512511
/// ```
513512
pub fn registry(&self) -> Arc<dyn FunctionRegistry> {
514-
let registry = self.session_state.lock().clone();
513+
let registry = self.session_state.read().clone();
515514
Arc::new(registry)
516515
}
517516

@@ -563,9 +562,8 @@ impl DataFrame {
563562
/// Write a `DataFrame` to a CSV file.
564563
pub async fn write_csv(&self, path: &str) -> Result<()> {
565564
let plan = self.create_physical_plan().await?;
566-
let state = self.session_state.lock().clone();
567-
let ctx = SessionContext::from(Arc::new(Mutex::new(state)));
568-
plan_to_csv(&ctx, plan, path).await
565+
let state = self.session_state.read().clone();
566+
plan_to_csv(&state, plan, path).await
569567
}
570568

571569
/// Write a `DataFrame` to a Parquet file.
@@ -575,9 +573,8 @@ impl DataFrame {
575573
writer_properties: Option<WriterProperties>,
576574
) -> Result<()> {
577575
let plan = self.create_physical_plan().await?;
578-
let state = self.session_state.lock().clone();
579-
let ctx = SessionContext::from(Arc::new(Mutex::new(state)));
580-
plan_to_parquet(&ctx, plan, path, writer_properties).await
576+
let state = self.session_state.read().clone();
577+
plan_to_parquet(&state, plan, path, writer_properties).await
581578
}
582579
}
583580

datafusion/src/datasource/object_store/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ pub trait ObjectStore: Sync + Send + Debug {
163163

164164
static LOCAL_SCHEME: &str = "file";
165165

166-
/// A Registry holds all the object stores at runtime with a scheme for each store.
166+
/// A Registry holds all the object stores at Runtime with a scheme for each store.
167167
/// This allows the user to extend DataFusion with different storage systems such as S3 or HDFS
168168
/// and query data inside these systems.
169169
pub struct ObjectStoreRegistry {

0 commit comments

Comments
 (0)