Skip to content
Open
698 changes: 658 additions & 40 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ resolver = "3"
[workspace.dependencies]
arrow = { version = "55.2.0", default-features = false, features = ["ipc"] }
chrono = { version = "0.4.42", default-features = false }
datafusion = { version = "49.0.1", default-features = false }
datafusion-common = { version = "49.0.1", default-features = false }
datafusion-expr = { version = "49.0.1", default-features = false }
datafusion-sql = { version = "49.0.1", default-features = false }
datafusion-udf-wasm-arrow2bytes = { path = "arrow2bytes", version = "0.1.0" }
datafusion-udf-wasm-bundle = { path = "guests/bundle", version = "0.1.0" }
datafusion-udf-wasm-guest = { path = "guests/rust", version = "0.1.0" }
datafusion-udf-wasm-python = { path = "guests/python", version = "0.1.0" }
sqlparser = { version = "0.55.0", default-features = false, features = ["std", "visitor"] }
tokio = { version = "1.48.0", default-features = false }
pyo3 = { version = "0.27.1", default-features = false, features = ["macros"] }
tar = { version = "0.4.44", default-features = false }
Expand Down Expand Up @@ -61,8 +64,10 @@ private_intra_doc_links = "deny"
[patch.crates-io]
# use same DataFusion fork as InfluxDB
# See https://github.com/influxdata/arrow-datafusion/pull/72
datafusion = { git = "https://github.com/influxdata/arrow-datafusion.git", rev = "8347a71f62d4fef8d37548f22b93877170039357" }
datafusion-common = { git = "https://github.com/influxdata/arrow-datafusion.git", rev = "8347a71f62d4fef8d37548f22b93877170039357" }
datafusion-expr = { git = "https://github.com/influxdata/arrow-datafusion.git", rev = "8347a71f62d4fef8d37548f22b93877170039357" }
datafusion-sql = { git = "https://github.com/influxdata/arrow-datafusion.git", rev = "8347a71f62d4fef8d37548f22b93877170039357" }

# faster tests
[profile.dev.package]
Expand Down
3 changes: 3 additions & 0 deletions host/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ workspace = true

[dependencies]
arrow.workspace = true
datafusion.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-udf-wasm-arrow2bytes.workspace = true
datafusion-sql.workspace = true
sqlparser.workspace = true
tar.workspace = true
tempfile.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "sync"] }
Expand Down
4 changes: 4 additions & 0 deletions host/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ mod bindings;
mod conversion;
mod error;
mod tokio_helpers;
pub mod udf_query;

/// State of the WASM payload.
struct WasmStateImpl {
Expand Down Expand Up @@ -366,3 +367,6 @@ impl ScalarUDFImpl for WasmScalarUdf {
})
}
}

// Re-export the UDF query functionality
pub use udf_query::{UdfQuery, UdfQueryInvocator};
195 changes: 195 additions & 0 deletions host/src/udf_query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
//! Embedded SQL approach for executing Python UDFs within SQL queries.
use datafusion::prelude::SessionContext;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::ScalarUDF;
use datafusion_sql::parser::{DFParserBuilder, Statement};
use sqlparser::ast::{CreateFunctionBody, Expr, Statement as SqlStatement, Value};
use sqlparser::dialect::dialect_from_str;

use crate::{WasmComponentPrecompiled, WasmScalarUdf};

/// A SQL query containing a Python UDF and SQL string that uses the UDF
#[derive(Debug, Clone)]
pub struct UdfQuery(String);

impl UdfQuery {
/// Create a new UDF query
pub fn new(query: String) -> Self {
Self(query)
}

/// Get the query string
pub fn query(&self) -> &str {
&self.0
}
}

/// Accepts a [UdfQuery] and invokes the query using DataFusion
pub struct UdfQueryInvocator<'a> {
/// DataFusion session context
session_ctx: SessionContext,
/// Pre-compiled Python WASM component
component: &'a WasmComponentPrecompiled,
}

impl std::fmt::Debug for UdfQueryInvocator<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("UdfQueryInvocator")
.field("session_ctx", &"SessionContext { ... }")
.field("python_component", &self.component)
.finish()
}
}

impl<'a> UdfQueryInvocator<'a> {
/// Registers the UDF query in DataFusion
pub async fn new(
session_ctx: SessionContext,
component: &'a WasmComponentPrecompiled,
) -> DataFusionResult<Self> {
Ok(Self {
session_ctx,
component,
})
}

/// Invoke the query, returning a result
pub async fn invoke(&mut self, udf_query: UdfQuery) -> DataFusionResult<Vec<Vec<String>>> {
let query_str = udf_query.query();

let (code, sql_query) = self.parse_combined_query(query_str)?;

let udfs = WasmScalarUdf::new(self.component, code).await?;

for udf in udfs {
let scalar_udf = ScalarUDF::new_from_impl(udf);
self.session_ctx.register_udf(scalar_udf);
}

let df = self.session_ctx.sql(&sql_query).await?;
let batches = df.collect().await?;

let mut results = Vec::new();
for batch in batches {
for row_idx in 0..batch.num_rows() {
let mut row = Vec::new();
for col_idx in 0..batch.num_columns() {
let column = batch.column(col_idx);
let value = arrow::util::display::array_value_to_string(column, row_idx)?;
row.push(value);
}
results.push(row);
}
}

Ok(results)
}

/// Parse the combined query to extract Python code and SQL
fn parse_combined_query(&self, query: &str) -> DataFusionResult<(String, String)> {
let task_ctx = self.session_ctx.task_ctx();
let options = task_ctx.session_config().options();

let dialect = dialect_from_str(options.sql_parser.dialect.clone()).expect("valid dialect");
let recursion_limit = options.sql_parser.recursion_limit;

let statements = DFParserBuilder::new(query)
.with_dialect(dialect.as_ref())
.with_recursion_limit(recursion_limit)
.build()?
.parse_statements()?;

let mut code = String::new();
let mut sql_statements = Vec::new();

for statement in statements {
match statement {
Statement::Statement(s) => parse_statement(*s, &mut code, &mut sql_statements)?,
_ => {
// do nothing
}
}
}

if code.is_empty() {
return Err(DataFusionError::Plan(
"no Python UDF found in query".to_string(),
));
}

if sql_statements.is_empty() {
return Err(DataFusionError::Plan("no SQL query found".to_string()));
}

let sql_query = sql_statements
.iter()
.map(|s| s.to_string())
.collect::<Vec<String>>()
.join(";\n");

Ok((code, sql_query))
}
}

/// Parse a single SQL statement to extract Python UDF code
fn parse_statement(
stmt: SqlStatement,
code: &mut String,
sql: &mut Vec<SqlStatement>,
) -> DataFusionResult<()> {
match stmt {
SqlStatement::CreateFunction(cf) => {
let function_body = cf.function_body.as_ref();
let language = cf.language.as_ref();

if let Some(lang) = language
&& lang.to_string().to_lowercase() != "python"
{
return Err(DataFusionError::Plan(format!(
"only Python is supported, got: {}",
lang
)));
}

match function_body {
Some(body) => extract_function_body(body, code),
None => Err(DataFusionError::Plan(
"function body is required for Python UDFs".to_string(),
)),
}
}
_ => {
sql.push(stmt);
Ok(())
}
}
}

/// Extracts the code from the function body
fn extract_function_body(body: &CreateFunctionBody, code: &mut String) -> DataFusionResult<()> {
match body {
CreateFunctionBody::AsAfterOptions(e) | CreateFunctionBody::AsBeforeOptions(e) => {
let s = expression_into_str(e)?;
code.push_str(s);
code.push('\n');
Ok(())
}
CreateFunctionBody::Return(_) => Err(DataFusionError::Plan(
"`RETURN` function body not supported for Python UDFs".to_string(),
)),
}
}

/// Convert an expression into a string
fn expression_into_str(expr: &Expr) -> DataFusionResult<&str> {
match expr {
Expr::Value(v) => match &v.value {
Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(s),
_ => Err(DataFusionError::Plan("expected string value".to_string())),
},
_ => Err(DataFusionError::Plan(
"expected value expression".to_string(),
)),
}
}
1 change: 1 addition & 0 deletions host/tests/integration_tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod python;
mod rust;
mod test_utils;
mod udf_query;
2 changes: 1 addition & 1 deletion host/tests/integration_tests/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ mod examples;
mod inspection;
mod runtime;
mod state;
mod test_utils;
pub(crate) mod test_utils;
mod types;
2 changes: 1 addition & 1 deletion host/tests/integration_tests/python/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use tokio::sync::OnceCell;

static COMPONENT: OnceCell<WasmComponentPrecompiled> = OnceCell::const_new();

async fn python_component() -> &'static WasmComponentPrecompiled {
pub(crate) async fn python_component() -> &'static WasmComponentPrecompiled {
COMPONENT
.get_or_init(async || {
WasmComponentPrecompiled::new(datafusion_udf_wasm_bundle::BIN_PYTHON.into())
Expand Down
66 changes: 66 additions & 0 deletions host/tests/integration_tests/udf_query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use datafusion::prelude::SessionContext;
use datafusion_udf_wasm_host::udf_query::{UdfQuery, UdfQueryInvocator};

use crate::integration_tests::python::test_utils::python_component;

#[tokio::test(flavor = "multi_thread")]
async fn test_simple_udf_query() {
let query = r#"
CREATE FUNCTION add_one()
LANGUAGE python
AS '
def add_one(x: int) -> int:
return x + 1
';
SELECT add_one(1);
"#;

let ctx = SessionContext::new();
let python_component = python_component().await;

let udf_query = UdfQuery::new(query.to_string());
let mut invocator = UdfQueryInvocator::new(ctx, python_component).await.unwrap();

let result = invocator.invoke(udf_query).await.unwrap();

// Verify the result
assert_eq!(result.len(), 1);
assert_eq!(result[0].len(), 1);
assert_eq!(result[0][0], "2");
}

#[tokio::test(flavor = "multi_thread")]
async fn test_multiple_functions() {
let query = r#"
CREATE FUNCTION add_one()
LANGUAGE python
AS '
def add_one(x: int) -> int:
return x + 1
';
CREATE FUNCTION multiply_two()
LANGUAGE python
AS '
def multiply_two(x: int) -> int:
return x * 2
';
SELECT add_one(1), multiply_two(3);
"#;

let ctx = SessionContext::new();
let python_component = python_component().await;

let udf_query = UdfQuery::new(query.to_string());
let mut invocator = UdfQueryInvocator::new(ctx, python_component).await.unwrap();

let result = invocator.invoke(udf_query).await.unwrap();

// Verify the result
assert_eq!(result.len(), 1);
assert_eq!(result[0].len(), 2);
assert_eq!(result[0][0], "2"); // add_one(1) = 2
assert_eq!(result[0][1], "6"); // multiply_two(3) = 6
}
Loading