-
Notifications
You must be signed in to change notification settings - Fork 0
feat: implement stateless udf registration #97
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Sl1mb0
wants to merge
11
commits into
main
Choose a base branch
from
tm/stateless-udf
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 5 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
53bbc4c
feat: implement stateless udf registration
Sl1mb0 4b5b67c
build: make sqlparser workspace dependency
Sl1mb0 b1f1d32
refactor: clean things up a bit
Sl1mb0 55dff3b
fix: sql statement parsing
Sl1mb0 6db6b4d
fix: sabor de multihilo
Sl1mb0 31cc08d
refactor: return dataframe instead of string matrix
Sl1mb0 8fabafa
test: empty string & multiple functions in single statement
Sl1mb0 1b80aa4
refactor: create separate methods for query planning and invocation
Sl1mb0 1683af4
fix: use empty schema
Sl1mb0 8b72463
refactor: *only* parse the UdfQuery; allowing callers to handle regis…
Sl1mb0 d25d705
refactor: make things generally better
Sl1mb0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,195 @@ | ||
| //! Embedded SQL approach for executing Python UDFs within SQL queries. | ||
Sl1mb0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| use datafusion::prelude::SessionContext; | ||
| use datafusion_common::{DataFusionError, Result as DataFusionResult}; | ||
| use datafusion_expr::ScalarUDF; | ||
| use datafusion_sql::parser::{DFParserBuilder, Statement}; | ||
| use sqlparser::ast::{CreateFunctionBody, Expr, Statement as SqlStatement, Value}; | ||
| use sqlparser::dialect::dialect_from_str; | ||
|
|
||
| use crate::{WasmComponentPrecompiled, WasmScalarUdf}; | ||
|
|
||
/// A combined query: one or more `CREATE FUNCTION ... LANGUAGE python` statements
/// that define Python UDFs, followed by the SQL that uses them.
///
/// The wrapper performs no validation; the text is stored as given and parsed
/// later by whoever consumes it.
#[derive(Debug, Clone)]
pub struct UdfQuery(String);

impl UdfQuery {
    /// Create a new UDF query.
    ///
    /// Accepts anything convertible into a `String` (`String`, `&str`, ...), so
    /// callers holding a string literal do not need to allocate up front.
    pub fn new(query: impl Into<String>) -> Self {
        Self(query.into())
    }

    /// Get the query string exactly as it was passed to [`UdfQuery::new`].
    pub fn query(&self) -> &str {
        &self.0
    }
}
|
|
||
| /// Accepts a [UdfQuery] and invokes the query using DataFusion | ||
| pub struct UdfQueryInvocator<'a> { | ||
| /// DataFusion session context | ||
| session_ctx: SessionContext, | ||
| /// Pre-compiled Python WASM component | ||
| component: &'a WasmComponentPrecompiled, | ||
Sl1mb0 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| impl std::fmt::Debug for UdfQueryInvocator<'_> { | ||
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
| f.debug_struct("UdfQueryInvocator") | ||
| .field("session_ctx", &"SessionContext { ... }") | ||
| .field("python_component", &self.component) | ||
| .finish() | ||
| } | ||
| } | ||
|
|
||
| impl<'a> UdfQueryInvocator<'a> { | ||
| /// Registers the UDF query in DataFusion | ||
| pub async fn new( | ||
| session_ctx: SessionContext, | ||
| component: &'a WasmComponentPrecompiled, | ||
| ) -> DataFusionResult<Self> { | ||
| Ok(Self { | ||
| session_ctx, | ||
| component, | ||
| }) | ||
| } | ||
|
|
||
| /// Invoke the query, returning a result | ||
| pub async fn invoke(&mut self, udf_query: UdfQuery) -> DataFusionResult<Vec<Vec<String>>> { | ||
Sl1mb0 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| let query_str = udf_query.query(); | ||
|
|
||
| let (code, sql_query) = self.parse_combined_query(query_str)?; | ||
|
|
||
| let udfs = WasmScalarUdf::new(self.component, code).await?; | ||
|
|
||
| for udf in udfs { | ||
| let scalar_udf = ScalarUDF::new_from_impl(udf); | ||
| self.session_ctx.register_udf(scalar_udf); | ||
| } | ||
|
|
||
| let df = self.session_ctx.sql(&sql_query).await?; | ||
| let batches = df.collect().await?; | ||
|
|
||
| let mut results = Vec::new(); | ||
| for batch in batches { | ||
| for row_idx in 0..batch.num_rows() { | ||
| let mut row = Vec::new(); | ||
| for col_idx in 0..batch.num_columns() { | ||
| let column = batch.column(col_idx); | ||
| let value = arrow::util::display::array_value_to_string(column, row_idx)?; | ||
| row.push(value); | ||
| } | ||
| results.push(row); | ||
| } | ||
| } | ||
|
|
||
| Ok(results) | ||
| } | ||
|
|
||
| /// Parse the combined query to extract Python code and SQL | ||
| fn parse_combined_query(&self, query: &str) -> DataFusionResult<(String, String)> { | ||
| let task_ctx = self.session_ctx.task_ctx(); | ||
| let options = task_ctx.session_config().options(); | ||
|
|
||
| let dialect = dialect_from_str(options.sql_parser.dialect.clone()).expect("valid dialect"); | ||
| let recursion_limit = options.sql_parser.recursion_limit; | ||
|
|
||
| let statements = DFParserBuilder::new(query) | ||
| .with_dialect(dialect.as_ref()) | ||
| .with_recursion_limit(recursion_limit) | ||
| .build()? | ||
| .parse_statements()?; | ||
|
|
||
| let mut code = String::new(); | ||
| let mut sql_statements = Vec::new(); | ||
|
|
||
| for statement in statements { | ||
| match statement { | ||
| Statement::Statement(s) => parse_statement(*s, &mut code, &mut sql_statements)?, | ||
| _ => { | ||
| // do nothing | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if code.is_empty() { | ||
| return Err(DataFusionError::Plan( | ||
| "no Python UDF found in query".to_string(), | ||
| )); | ||
| } | ||
|
|
||
| if sql_statements.is_empty() { | ||
| return Err(DataFusionError::Plan("no SQL query found".to_string())); | ||
| } | ||
|
|
||
| let sql_query = sql_statements | ||
| .iter() | ||
| .map(|s| s.to_string()) | ||
| .collect::<Vec<String>>() | ||
| .join(";\n"); | ||
|
|
||
| Ok((code, sql_query)) | ||
| } | ||
| } | ||
|
|
||
| /// Parse a single SQL statement to extract Python UDF code | ||
| fn parse_statement( | ||
| stmt: SqlStatement, | ||
| code: &mut String, | ||
Sl1mb0 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| sql: &mut Vec<SqlStatement>, | ||
| ) -> DataFusionResult<()> { | ||
| match stmt { | ||
| SqlStatement::CreateFunction(cf) => { | ||
| let function_body = cf.function_body.as_ref(); | ||
| let language = cf.language.as_ref(); | ||
|
|
||
| if let Some(lang) = language | ||
| && lang.to_string().to_lowercase() != "python" | ||
| { | ||
| return Err(DataFusionError::Plan(format!( | ||
| "only Python is supported, got: {}", | ||
| lang | ||
| ))); | ||
| } | ||
|
|
||
| match function_body { | ||
| Some(body) => extract_function_body(body, code), | ||
| None => Err(DataFusionError::Plan( | ||
| "function body is required for Python UDFs".to_string(), | ||
| )), | ||
| } | ||
| } | ||
| _ => { | ||
| sql.push(stmt); | ||
| Ok(()) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /// Extracts the code from the function body | ||
| fn extract_function_body(body: &CreateFunctionBody, code: &mut String) -> DataFusionResult<()> { | ||
| match body { | ||
| CreateFunctionBody::AsAfterOptions(e) | CreateFunctionBody::AsBeforeOptions(e) => { | ||
| let s = expression_into_str(e)?; | ||
| code.push_str(s); | ||
| code.push('\n'); | ||
| Ok(()) | ||
| } | ||
| CreateFunctionBody::Return(_) => Err(DataFusionError::Plan( | ||
| "`RETURN` function body not supported for Python UDFs".to_string(), | ||
| )), | ||
| } | ||
| } | ||
|
|
||
| /// Convert an expression into a string | ||
| fn expression_into_str(expr: &Expr) -> DataFusionResult<&str> { | ||
| match expr { | ||
| Expr::Value(v) => match &v.value { | ||
| Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(s), | ||
| _ => Err(DataFusionError::Plan("expected string value".to_string())), | ||
| }, | ||
| _ => Err(DataFusionError::Plan( | ||
| "expected value expression".to_string(), | ||
| )), | ||
| } | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| mod python; | ||
| mod rust; | ||
| mod test_utils; | ||
| mod udf_query; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,5 +3,5 @@ mod examples; | |
| mod inspection; | ||
| mod runtime; | ||
| mod state; | ||
| mod test_utils; | ||
| pub(crate) mod test_utils; | ||
| mod types; | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| use datafusion::prelude::SessionContext; | ||
| use datafusion_udf_wasm_host::udf_query::{UdfQuery, UdfQueryInvocator}; | ||
|
|
||
| use crate::integration_tests::python::test_utils::python_component; | ||
|
|
||
/// End-to-end check for a single Python UDF: the query defines `add_one` via
/// `CREATE FUNCTION ... LANGUAGE python` and then calls it in a `SELECT`.
/// Expects exactly one row with one column whose rendered value is "2".
| #[tokio::test(flavor = "multi_thread")] | ||
| async fn test_simple_udf_query() { | ||
| let query = r#" | ||
| CREATE FUNCTION add_one() | ||
| LANGUAGE python | ||
| AS ' | ||
| def add_one(x: int) -> int: | ||
| return x + 1 | ||
| '; | ||
| SELECT add_one(1); | ||
| "#; | ||
|
|
||
// Fresh session context plus the pre-compiled Python component from the test utilities.
| let ctx = SessionContext::new(); | ||
| let python_component = python_component().await; | ||
|
|
||
| let udf_query = UdfQuery::new(query.to_string()); | ||
| let mut invocator = UdfQueryInvocator::new(ctx, python_component).await.unwrap(); | ||
|
|
||
| let result = invocator.invoke(udf_query).await.unwrap(); | ||
|
|
||
| // Verify the result | ||
| assert_eq!(result.len(), 1); | ||
| assert_eq!(result[0].len(), 1); | ||
| assert_eq!(result[0][0], "2"); | ||
| } | ||
|
|
||
/// End-to-end check for two Python UDFs defined in the same query: `add_one`
/// and `multiply_two` are each declared with their own `CREATE FUNCTION`
/// statement and both are called from one `SELECT`.
/// Expects exactly one row with two columns, rendered as "2" and "6".
| #[tokio::test(flavor = "multi_thread")] | ||
| async fn test_multiple_functions() { | ||
| let query = r#" | ||
| CREATE FUNCTION add_one() | ||
| LANGUAGE python | ||
| AS ' | ||
| def add_one(x: int) -> int: | ||
| return x + 1 | ||
| '; | ||
| CREATE FUNCTION multiply_two() | ||
| LANGUAGE python | ||
| AS ' | ||
| def multiply_two(x: int) -> int: | ||
| return x * 2 | ||
| '; | ||
| SELECT add_one(1), multiply_two(3); | ||
Sl1mb0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "#; | ||
|
|
||
// Fresh session context plus the pre-compiled Python component from the test utilities.
| let ctx = SessionContext::new(); | ||
| let python_component = python_component().await; | ||
|
|
||
| let udf_query = UdfQuery::new(query.to_string()); | ||
| let mut invocator = UdfQueryInvocator::new(ctx, python_component).await.unwrap(); | ||
|
|
||
| let result = invocator.invoke(udf_query).await.unwrap(); | ||
|
|
||
| // Verify the result | ||
| assert_eq!(result.len(), 1); | ||
| assert_eq!(result[0].len(), 2); | ||
| assert_eq!(result[0][0], "2"); // add_one(1) = 2 | ||
| assert_eq!(result[0][1], "6"); // multiply_two(3) = 6 | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.