Skip to content

Commit e6d34dc

Browse files
authored
(feat) Optimizer & Memo Full Logical -> Physical (#110)
1. Does not handle costing. 2. Fixes memo regression. 3. Extend DSL support of `stored` plans.
1 parent c452468 commit e6d34dc

File tree

27 files changed

+1590
-685
lines changed

27 files changed

+1590
-685
lines changed

optd-cli/src/main.rs

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,12 @@
3434
3535
use clap::{Parser, Subcommand};
3636
use colored::Colorize;
37-
use optd::catalog::Catalog;
3837
use optd::catalog::iceberg::memory_catalog;
3938
use optd::dsl::analyzer::hir::{CoreData, HIR, Udf, Value};
4039
use optd::dsl::compile::{Config, compile_hir};
4140
use optd::dsl::engine::{Continuation, Engine, EngineResponse};
4241
use optd::dsl::utils::errors::{CompileError, Diagnose};
43-
use optd::dsl::utils::retriever::{MockRetriever, Retriever};
42+
use optd::dsl::utils::retriever::MockRetriever;
4443
use std::collections::HashMap;
4544
use std::sync::Arc;
4645
use tokio::runtime::Builder;
@@ -66,22 +65,17 @@ enum Commands {
6665
RunFunctions(Config),
6766
}
6867

69-
/// A unimplemented user-defined function.
70-
pub fn unimplemented_udf(
71-
_args: &[Value],
72-
_catalog: &dyn Catalog,
73-
_retriever: &dyn Retriever,
74-
) -> Value {
75-
println!("This user-defined function is unimplemented!");
76-
Value::new(CoreData::<Value>::None)
77-
}
78-
7968
fn main() -> Result<(), Vec<CompileError>> {
8069
let cli = Cli::parse();
8170

8271
let mut udfs = HashMap::new();
8372
let udf = Udf {
84-
func: unimplemented_udf,
73+
func: Arc::new(|_, _, _| {
74+
Box::pin(async move {
75+
println!("This user-defined function is unimplemented!");
76+
Value::new(CoreData::<Value>::None)
77+
})
78+
}),
8579
};
8680
udfs.insert("unimplemented_udf".to_string(), udf.clone());
8781

optd/src/demo/demo.opt

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
data Physical
21
data PhysicalProperties
32
data Statistics
4-
data LogicalProperties
3+
// Taking folded here is not the most interesting property,
4+
// but it ensures they are the same for all expressions in the same group.
5+
data LogicalProperties(folded: I64)
56

67
data Logical =
78
| Add(left: Logical, right: Logical)
@@ -10,21 +11,42 @@ data Logical =
1011
| Div(left: Logical, right: Logical)
1112
\ Const(val: I64)
1213

14+
data Physical =
15+
| PhysicalAdd(left: Physical, right: Physical)
16+
| PhysicalSub(left: Physical, right: Physical)
17+
| PhysicalMult(left: Physical, right: Physical)
18+
| PhysicalDiv(left: Physical, right: Physical)
19+
\ PhysicalConst(val: I64)
20+
21+
// This will be the input plan that will be optimized.
22+
// Result is: ((1 - 2) * (3 / 4)) + ((5 - 6) * (7 / 8)) = 0
1323
fn input(): Logical =
1424
Add(
1525
Mult(
1626
Sub(Const(1), Const(2)),
1727
Div(Const(3), Const(4))
1828
),
1929
Mult(
20-
Sub(Const(5), Const(6)),
30+
Sub(Const(1), Const(2)),
2131
Div(Const(7), Const(8))
2232
)
2333
)
2434

25-
// TODO(Alexis): This should be $ really, make costing and derive consistent with each other.
35+
// External function to allow the retrieval of properties.
36+
fn properties(op: Logical*): LogicalProperties
37+
38+
// FIXME: This should be $ really (or other), make costing and derive consistent with each other.
2639
// Also, be careful of not forking in there! And make it a required function in analyzer.
27-
fn derive(op: Logical) = LogicalProperties
40+
fn derive(op: Logical*) = match op
41+
| Add(left, right) ->
42+
LogicalProperties(left.properties()#folded + right.properties()#folded)
43+
| Sub(left, right) ->
44+
LogicalProperties(left.properties()#folded - right.properties()#folded)
45+
| Mult(left, right) ->
46+
LogicalProperties(left.properties()#folded * right.properties()#folded)
47+
| Div(left, right) ->
48+
LogicalProperties(left.properties()#folded / right.properties()#folded)
49+
\ Const(val) -> LogicalProperties(val)
2850

2951
[transformation]
3052
fn (op: Logical*) mult_commute(): Logical? = match op
@@ -55,4 +77,12 @@ fn (op: Logical*) const_fold_sub(): Logical? = match op
5577
fn (op: Logical*) const_fold_div(): Logical? = match op
5678
| Div(Const(a), Const(b)) ->
5779
if b == 0 then none else Const(a / b)
58-
\ _ -> none
80+
\ _ -> none
81+
82+
[implementation]
83+
fn (op: Logical*) to_physical(props: PhysicalProperties?) = match op
84+
| Add(left, right) -> PhysicalAdd(left.to_physical(props), right.to_physical(props))
85+
| Sub(left, right) -> PhysicalSub(left.to_physical(props), right.to_physical(props))
86+
| Mult(left, right) -> PhysicalMult(left.to_physical(props), right.to_physical(props))
87+
| Div(left, right) -> PhysicalDiv(left.to_physical(props), right.to_physical(props))
88+
\ Const(val) -> PhysicalConst(val)

optd/src/demo/mod.rs

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,53 @@
11
use crate::{
2-
catalog::iceberg::memory_catalog,
2+
catalog::{Catalog, iceberg::memory_catalog},
33
dsl::{
4-
analyzer::hir::Value,
4+
analyzer::hir::{CoreData, LogicalOp, Materializable, Udf, Value},
55
compile::{Config, compile_hir},
66
engine::{Continuation, Engine, EngineResponse},
7-
utils::retriever::MockRetriever,
7+
utils::retriever::{MockRetriever, Retriever},
88
},
99
memo::MemoryMemo,
10-
optimizer::{OptimizeRequest, Optimizer, hir_cir::into_cir::value_to_logical},
10+
optimizer::{ClientRequest, OptimizeRequest, Optimizer, hir_cir::into_cir::value_to_logical},
1111
};
1212
use std::{collections::HashMap, sync::Arc, time::Duration};
13-
use tokio::{sync::mpsc, time::timeout};
13+
use tokio::{
14+
sync::mpsc,
15+
time::{sleep, timeout},
16+
};
17+
18+
pub async fn properties(
19+
args: Vec<Value>,
20+
_catalog: Arc<dyn Catalog>,
21+
retriever: Arc<dyn Retriever>,
22+
) -> Value {
23+
let arg = args[0].clone();
24+
let group_id = match &arg.data {
25+
CoreData::Logical(Materializable::Materialized(LogicalOp { group_id, .. })) => {
26+
group_id.unwrap()
27+
}
28+
CoreData::Logical(Materializable::UnMaterialized(group_id)) => *group_id,
29+
_ => panic!("Expected a logical plan"),
30+
};
31+
32+
retriever.get_properties(group_id).await
33+
}
1434

1535
async fn run_demo() {
1636
// Compile the HIR.
1737
let config = Config::new("src/demo/demo.opt".into());
18-
let udfs = HashMap::new();
38+
39+
// Create a properties UDF.
40+
let properties_udf = Udf {
41+
func: Arc::new(|args, catalog, retriever| {
42+
Box::pin(async move { properties(args, catalog, retriever).await })
43+
}),
44+
};
45+
46+
// Create the UDFs HashMap.
47+
let mut udfs = HashMap::new();
48+
udfs.insert("properties".to_string(), properties_udf);
49+
50+
// Compile with the config and UDFs.
1951
let hir = compile_hir(config, udfs).unwrap();
2052

2153
// Create necessary components.
@@ -35,15 +67,15 @@ async fn run_demo() {
3567
let optimize_channel = Optimizer::launch(memo, catalog, hir);
3668
let (tx, mut rx) = mpsc::channel(1);
3769
optimize_channel
38-
.send(OptimizeRequest {
39-
plan: logical_plan,
40-
physical_tx: tx,
41-
})
70+
.send(ClientRequest::Optimize(OptimizeRequest {
71+
plan: logical_plan.clone(),
72+
physical_tx: tx.clone(),
73+
}))
4274
.await
4375
.unwrap();
4476

45-
// Timeout after 2 seconds.
46-
let timeout_duration = Duration::from_secs(2);
77+
// Timeout after 5 seconds.
78+
let timeout_duration = Duration::from_secs(5);
4779
let result = timeout(timeout_duration, async {
4880
while let Some(response) = rx.recv().await {
4981
println!("Received response: {:?}", response);
@@ -55,6 +87,13 @@ async fn run_demo() {
5587
Ok(_) => println!("Finished receiving responses."),
5688
Err(_) => println!("Timed out after 5 seconds."),
5789
}
90+
91+
// Dump the memo (debug utility).
92+
optimize_channel
93+
.send(ClientRequest::DumpMemo)
94+
.await
95+
.unwrap();
96+
sleep(Duration::from_secs(10)).await;
5897
}
5998

6099
#[cfg(test)]

optd/src/dsl/analyzer/from_ast/converter.rs

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,12 @@ impl ASTConverter {
177177
#[cfg(test)]
178178
mod converter_tests {
179179
use super::*;
180-
use crate::catalog::Catalog;
181180
use crate::dsl::analyzer::from_ast::from_ast;
182181
use crate::dsl::analyzer::hir::{CoreData, FunKind};
183182
use crate::dsl::analyzer::type_checks::registry::{Generic, TypeKind};
184183
use crate::dsl::parser::ast::{self, Adt, Function, Item, Module, Type as AstType};
185-
use crate::dsl::utils::retriever::Retriever;
186184
use crate::dsl::utils::span::{Span, Spanned};
185+
use std::sync::Arc;
187186

188187
// Helper functions to create test items
189188
fn create_test_span() -> Span {
@@ -382,19 +381,15 @@ mod converter_tests {
382381
let ext_func = create_simple_function("external_function", false);
383382
let module = create_module_with_functions(vec![ext_func]);
384383

385-
pub fn external_function(
386-
_args: &[Value],
387-
_catalog: &dyn Catalog,
388-
_retriever: &dyn Retriever,
389-
) -> Value {
390-
println!("Hello from UDF!");
391-
Value::new(CoreData::<Value>::None)
392-
}
393-
394384
// Link the dummy function.
395385
let mut udfs = HashMap::new();
396386
let udf = Udf {
397-
func: external_function,
387+
func: Arc::new(|_, _, _| {
388+
Box::pin(async move {
389+
println!("Hello from UDF!");
390+
Value::new(CoreData::None)
391+
})
392+
}),
398393
};
399394
udfs.insert("external_function".to_string(), udf);
400395

@@ -408,13 +403,6 @@ mod converter_tests {
408403
// Check that the function is in the context.
409404
let func_val = hir.context.lookup("external_function");
410405
assert!(func_val.is_some());
411-
412-
// Verify it is the same function pointer.
413-
if let CoreData::Function(FunKind::Udf(udf)) = &func_val.unwrap().data {
414-
assert_eq!(udf.func as usize, external_function as usize);
415-
} else {
416-
panic!("Expected UDF function");
417-
}
418406
}
419407

420408
#[test]

optd/src/dsl/analyzer/hir/mod.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ use crate::dsl::utils::retriever::Retriever;
2121
use crate::dsl::utils::span::Span;
2222
use context::Context;
2323
use map::Map;
24-
use std::fmt::Debug;
24+
use std::fmt::{self, Debug};
25+
use std::pin::Pin;
2526
use std::{collections::HashMap, sync::Arc};
2627

2728
pub(crate) mod context;
@@ -72,22 +73,30 @@ impl TypedSpan {
7273
}
7374
}
7475

75-
#[derive(Debug, Clone)]
76+
/// Type aliases for user-defined functions (UDFs).
77+
type UdfFutureOutput = Pin<Box<dyn Future<Output = Value> + Send>>;
78+
type UdfFunction =
79+
dyn Fn(Vec<Value>, Arc<dyn Catalog>, Arc<dyn Retriever>) -> UdfFutureOutput + Send + Sync;
80+
81+
#[derive(Clone)]
7682
pub struct Udf {
77-
/// The function pointer to the user-defined function.
78-
///
79-
/// Note that [`Value`]s passed to and returned from this UDF do not have associated metadata.
80-
pub func: fn(&[Value], &dyn Catalog, &dyn Retriever) -> Value,
83+
pub func: Arc<UdfFunction>,
84+
}
85+
86+
impl fmt::Debug for Udf {
87+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88+
write!(f, "[udf]")
89+
}
8190
}
8291

8392
impl Udf {
84-
pub fn call(
93+
pub async fn call(
8594
&self,
8695
values: &[Value],
87-
catalog: &dyn Catalog,
88-
retriever: &dyn Retriever,
96+
catalog: Arc<dyn Catalog>,
97+
retriever: Arc<dyn Retriever>,
8998
) -> Value {
90-
(self.func)(values, catalog, retriever)
99+
(self.func)(values.to_vec(), catalog, retriever).await
91100
}
92101
}
93102

optd/src/dsl/engine/eval/expr.rs

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ impl<O: Clone + Send + 'static> Engine<O> {
588588
Arc::new(move |arg_values| {
589589
Box::pin(capture!([udf, catalog, deriver, k], async move {
590590
// Call the UDF with the argument values.
591-
let result = udf.call(&arg_values, catalog.as_ref(), deriver.as_ref());
591+
let result = udf.call(&arg_values, catalog, deriver).await;
592592

593593
// Pass the result to the continuation.
594594
k(result).await
@@ -1143,18 +1143,22 @@ mod tests {
11431143

11441144
// Define a Rust UDF that calculates the sum of array elements
11451145
let sum_function = Value::new(CoreData::Function(FunKind::Udf(Udf {
1146-
func: |args, _catalog, _deriver| match &args[0].data {
1147-
CoreData::Array(elements) => {
1148-
let mut sum = 0;
1149-
for elem in elements {
1150-
if let CoreData::Literal(Literal::Int64(value)) = &elem.data {
1151-
sum += value;
1146+
func: Arc::new(|args, _catalog, _retriever| {
1147+
Box::pin(async move {
1148+
match &args[0].data {
1149+
CoreData::Array(elements) => {
1150+
let mut sum: i64 = 0;
1151+
for elem in elements {
1152+
if let CoreData::Literal(Literal::Int64(value)) = &elem.data {
1153+
sum += value;
1154+
}
1155+
}
1156+
Value::new(CoreData::Literal(Literal::Int64(sum)))
11521157
}
1158+
_ => panic!("Expected array argument"),
11531159
}
1154-
Value::new(CoreData::Literal(Literal::Int64(sum)))
1155-
}
1156-
_ => panic!("Expected array argument"),
1157-
},
1160+
})
1161+
}),
11581162
})));
11591163

11601164
ctx.bind("sum".to_string(), sum_function);
@@ -1271,16 +1275,18 @@ mod tests {
12711275
ctx.bind(
12721276
"get".to_string(),
12731277
Value::new(CoreData::Function(FunKind::Udf(Udf {
1274-
func: |args, _catalog, _deriver| {
1275-
if args.len() != 2 {
1276-
panic!("get function requires 2 arguments");
1277-
}
1278+
func: Arc::new(|args, _catalog, _retriever| {
1279+
Box::pin(async move {
1280+
if args.len() != 2 {
1281+
panic!("get function requires 2 arguments");
1282+
}
12781283

1279-
match &args[0].data {
1280-
CoreData::Map(map) => map.get(&args[1]),
1281-
_ => panic!("First argument must be a map"),
1282-
}
1283-
},
1284+
match &args[0].data {
1285+
CoreData::Map(map) => map.get(&args[1]),
1286+
_ => panic!("First argument must be a map"),
1287+
}
1288+
})
1289+
}),
12841290
}))),
12851291
);
12861292

0 commit comments

Comments
 (0)