Skip to content

Commit d6d35f7

Browse files
alambviirya
andauthored
Create datafusion-functions crate, extract encode and decode to (#8705)
* Extract encode and decode to `datafusion-functions` crate * better docs * Improve docs + macros * tweaks * updates * fix doc * tomlfmt * fix doc * update datafusion-cli Cargo.locl * update datafusion-cli cargo.lock * Remove outdated comment, make non pub * Apply suggestions from code review Co-authored-by: Liang-Chi Hsieh <[email protected]> --------- Co-authored-by: Liang-Chi Hsieh <[email protected]>
1 parent 262d093 commit d6d35f7

File tree

26 files changed

+639
-159
lines changed

26 files changed

+639
-159
lines changed

.github/workflows/rust.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ jobs:
7676

7777
- name: Check workspace with all features
7878
run: cargo check --workspace --benches --features avro,json
79+
80+
# Ensure that the datafusion crate can be built with only a subset of the function
81+
# packages enabled.
82+
- name: Check function packages (encoding_expressions)
83+
run: cargo check --no-default-features --features=encoding_expressions -p datafusion
84+
7985
- name: Check Cargo.lock for datafusion-cli
8086
run: |
8187
# If this test fails, try running `cargo update` in the `datafusion-cli` directory

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
[workspace]
1919
exclude = ["datafusion-cli"]
20-
members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks",
20+
members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/functions", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks",
2121
]
2222
resolver = "2"
2323

@@ -49,6 +49,7 @@ datafusion = { path = "datafusion/core", version = "35.0.0" }
4949
datafusion-common = { path = "datafusion/common", version = "35.0.0" }
5050
datafusion-execution = { path = "datafusion/execution", version = "35.0.0" }
5151
datafusion-expr = { path = "datafusion/expr", version = "35.0.0" }
52+
datafusion-functions = { path = "datafusion/functions", version = "35.0.0" }
5253
datafusion-optimizer = { path = "datafusion/optimizer", version = "35.0.0" }
5354
datafusion-physical-expr = { path = "datafusion/physical-expr", version = "35.0.0" }
5455
datafusion-physical-plan = { path = "datafusion/physical-plan", version = "35.0.0" }

datafusion-cli/Cargo.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/core/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ backtrace = ["datafusion-common/backtrace"]
4040
compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression", "tokio-util"]
4141
crypto_expressions = ["datafusion-physical-expr/crypto_expressions", "datafusion-optimizer/crypto_expressions"]
4242
default = ["crypto_expressions", "encoding_expressions", "regex_expressions", "unicode_expressions", "compression", "parquet"]
43-
encoding_expressions = ["datafusion-physical-expr/encoding_expressions"]
43+
encoding_expressions = ["datafusion-functions/encoding_expressions"]
4444
# Used for testing ONLY: causes all values to hash to the same value (test for collisions)
4545
force_hash_collisions = []
4646
parquet = ["datafusion-common/parquet", "dep:parquet"]
@@ -65,6 +65,7 @@ dashmap = { workspace = true }
6565
datafusion-common = { path = "../common", version = "35.0.0", features = ["object_store"], default-features = false }
6666
datafusion-execution = { workspace = true }
6767
datafusion-expr = { workspace = true }
68+
datafusion-functions = { path = "../functions", version = "35.0.0" }
6869
datafusion-optimizer = { path = "../optimizer", version = "35.0.0", default-features = false }
6970
datafusion-physical-expr = { path = "../physical-expr", version = "35.0.0", default-features = false }
7071
datafusion-physical-plan = { workspace = true }

datafusion/core/src/dataframe/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ use crate::physical_plan::{
4040
collect, collect_partitioned, execute_stream, execute_stream_partitioned,
4141
ExecutionPlan, SendableRecordBatchStream,
4242
};
43-
use crate::prelude::SessionContext;
4443

4544
use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
4645
use arrow::compute::{cast, concat};
@@ -59,6 +58,7 @@ use datafusion_expr::{
5958
TableProviderFilterPushDown, UNNAMED_TABLE,
6059
};
6160

61+
use crate::prelude::SessionContext;
6262
use async_trait::async_trait;
6363

6464
/// Contains options that control how data is

datafusion/core/src/execution/context/mod.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,7 +1340,7 @@ impl SessionState {
13401340
);
13411341
}
13421342

1343-
SessionState {
1343+
let mut new_self = SessionState {
13441344
session_id,
13451345
analyzer: Analyzer::new(),
13461346
optimizer: Optimizer::new(),
@@ -1356,7 +1356,13 @@ impl SessionState {
13561356
execution_props: ExecutionProps::new(),
13571357
runtime_env: runtime,
13581358
table_factories,
1359-
}
1359+
};
1360+
1361+
// register built in functions
1362+
datafusion_functions::register_all(&mut new_self)
1363+
.expect("can not register built in functions");
1364+
1365+
new_self
13601366
}
13611367
/// Returns new [`SessionState`] using the provided
13621368
/// [`SessionConfig`] and [`RuntimeEnv`].
@@ -1976,6 +1982,10 @@ impl FunctionRegistry for SessionState {
19761982
plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry")
19771983
})
19781984
}
1985+
1986+
fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
1987+
Ok(self.scalar_functions.insert(udf.name().into(), udf))
1988+
}
19791989
}
19801990

19811991
impl OptimizerConfig for SessionState {

datafusion/core/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,11 @@ pub mod sql {
516516
pub use datafusion_sql::*;
517517
}
518518

519+
/// re-export of [`datafusion_functions`] crate
520+
pub mod functions {
521+
pub use datafusion_functions::*;
522+
}
523+
519524
#[cfg(test)]
520525
pub mod test;
521526
pub mod test_util;

datafusion/core/src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub use datafusion_expr::{
3838
logical_plan::{JoinType, Partitioning},
3939
Expr,
4040
};
41+
pub use datafusion_functions::expr_fn::*;
4142

4243
pub use std::ops::Not;
4344
pub use std::ops::{Add, Div, Mul, Neg, Rem, Sub};

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use arrow::{
2020
array::{Int32Array, StringArray},
2121
record_batch::RecordBatch,
2222
};
23+
use arrow_schema::SchemaRef;
2324
use std::sync::Arc;
2425

2526
use datafusion::dataframe::DataFrame;
@@ -31,14 +32,19 @@ use datafusion::prelude::*;
3132
use datafusion::execution::context::SessionContext;
3233

3334
use datafusion::assert_batches_eq;
35+
use datafusion_common::DFSchema;
3436
use datafusion_expr::expr::Alias;
35-
use datafusion_expr::{approx_median, cast};
37+
use datafusion_expr::{approx_median, cast, ExprSchemable};
3638

37-
async fn create_test_table() -> Result<DataFrame> {
38-
let schema = Arc::new(Schema::new(vec![
39+
fn test_schema() -> SchemaRef {
40+
Arc::new(Schema::new(vec![
3941
Field::new("a", DataType::Utf8, false),
4042
Field::new("b", DataType::Int32, false),
41-
]));
43+
]))
44+
}
45+
46+
async fn create_test_table() -> Result<DataFrame> {
47+
let schema = test_schema();
4248

4349
// define data.
4450
let batch = RecordBatch::try_new(
@@ -790,3 +796,48 @@ async fn test_fn_upper() -> Result<()> {
790796

791797
Ok(())
792798
}
799+
800+
#[tokio::test]
801+
async fn test_fn_encode() -> Result<()> {
802+
let expr = encode(col("a"), lit("hex"));
803+
804+
let expected = [
805+
"+----------------------------+",
806+
"| encode(test.a,Utf8(\"hex\")) |",
807+
"+----------------------------+",
808+
"| 616263444546 |",
809+
"| 616263313233 |",
810+
"| 434241646566 |",
811+
"| 313233416263446566 |",
812+
"+----------------------------+",
813+
];
814+
assert_fn_batches!(expr, expected);
815+
816+
Ok(())
817+
}
818+
819+
#[tokio::test]
820+
async fn test_fn_decode() -> Result<()> {
821+
// Note that the decode function returns binary, and the default display of
822+
// binary is "hexadecimal" and therefore the output looks like decode did
823+
// nothing. So compare to a constant.
824+
let df_schema = DFSchema::try_from(test_schema().as_ref().clone())?;
825+
let expr = decode(encode(col("a"), lit("hex")), lit("hex"))
826+
// need to cast to utf8 otherwise the default display of binary array is hex
827+
// so it looks like nothing is done
828+
.cast_to(&DataType::Utf8, &df_schema)?;
829+
830+
let expected = [
831+
"+------------------------------------------------+",
832+
"| decode(encode(test.a,Utf8(\"hex\")),Utf8(\"hex\")) |",
833+
"+------------------------------------------------+",
834+
"| abcDEF |",
835+
"| abc123 |",
836+
"| CBAdef |",
837+
"| 123AbcDef |",
838+
"+------------------------------------------------+",
839+
];
840+
assert_fn_batches!(expr, expected);
841+
842+
Ok(())
843+
}

datafusion/execution/src/registry.rs

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717

1818
//! FunctionRegistry trait
1919
20-
use datafusion_common::Result;
20+
use datafusion_common::{not_impl_err, plan_datafusion_err, DataFusionError, Result};
2121
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
22+
use std::collections::HashMap;
2223
use std::{collections::HashSet, sync::Arc};
2324

2425
/// A registry knows how to build logical expressions out of user-defined function' names
@@ -34,6 +35,17 @@ pub trait FunctionRegistry {
3435

3536
/// Returns a reference to the udwf named `name`.
3637
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;
38+
39+
/// Registers a new [`ScalarUDF`], returning any previously registered
40+
/// implementation.
41+
///
42+
/// Returns an error (the default) if the function can not be registered,
43+
/// for example if the registry is read only.
44+
fn register_udf(&mut self, _udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
45+
not_impl_err!("Registering ScalarUDF")
46+
}
47+
48+
// TODO add register_udaf and register_udwf
3749
}
3850

3951
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
@@ -53,3 +65,51 @@ pub trait SerializerRegistry: Send + Sync {
5365
bytes: &[u8],
5466
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
5567
}
68+
69+
/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
70+
#[derive(Default, Debug)]
71+
pub struct MemoryFunctionRegistry {
72+
/// Scalar Functions
73+
udfs: HashMap<String, Arc<ScalarUDF>>,
74+
/// Aggregate Functions
75+
udafs: HashMap<String, Arc<AggregateUDF>>,
76+
/// Window Functions
77+
udwfs: HashMap<String, Arc<WindowUDF>>,
78+
}
79+
80+
impl MemoryFunctionRegistry {
81+
pub fn new() -> Self {
82+
Self::default()
83+
}
84+
}
85+
86+
impl FunctionRegistry for MemoryFunctionRegistry {
87+
fn udfs(&self) -> HashSet<String> {
88+
self.udfs.keys().cloned().collect()
89+
}
90+
91+
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
92+
self.udfs
93+
.get(name)
94+
.cloned()
95+
.ok_or_else(|| plan_datafusion_err!("Function {name} not found"))
96+
}
97+
98+
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
99+
self.udafs
100+
.get(name)
101+
.cloned()
102+
.ok_or_else(|| plan_datafusion_err!("Aggregate Function {name} not found"))
103+
}
104+
105+
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
106+
self.udwfs
107+
.get(name)
108+
.cloned()
109+
.ok_or_else(|| plan_datafusion_err!("Window Function {name} not found"))
110+
}
111+
112+
fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
113+
Ok(self.udfs.insert(udf.name().to_string(), udf))
114+
}
115+
}

datafusion/expr/src/built_in_function.rs

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,10 @@ pub enum BuiltinScalarFunction {
6969
Cos,
7070
/// cos
7171
Cosh,
72-
/// Decode
73-
Decode,
7472
/// degrees
7573
Degrees,
7674
/// Digest
7775
Digest,
78-
/// Encode
79-
Encode,
8076
/// exp
8177
Exp,
8278
/// factorial
@@ -381,9 +377,7 @@ impl BuiltinScalarFunction {
381377
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
382378
BuiltinScalarFunction::Cos => Volatility::Immutable,
383379
BuiltinScalarFunction::Cosh => Volatility::Immutable,
384-
BuiltinScalarFunction::Decode => Volatility::Immutable,
385380
BuiltinScalarFunction::Degrees => Volatility::Immutable,
386-
BuiltinScalarFunction::Encode => Volatility::Immutable,
387381
BuiltinScalarFunction::Exp => Volatility::Immutable,
388382
BuiltinScalarFunction::Factorial => Volatility::Immutable,
389383
BuiltinScalarFunction::Floor => Volatility::Immutable,
@@ -774,30 +768,6 @@ impl BuiltinScalarFunction {
774768
BuiltinScalarFunction::Digest => {
775769
utf8_or_binary_to_binary_type(&input_expr_types[0], "digest")
776770
}
777-
BuiltinScalarFunction::Encode => Ok(match input_expr_types[0] {
778-
Utf8 => Utf8,
779-
LargeUtf8 => LargeUtf8,
780-
Binary => Utf8,
781-
LargeBinary => LargeUtf8,
782-
Null => Null,
783-
_ => {
784-
return plan_err!(
785-
"The encode function can only accept utf8 or binary."
786-
);
787-
}
788-
}),
789-
BuiltinScalarFunction::Decode => Ok(match input_expr_types[0] {
790-
Utf8 => Binary,
791-
LargeUtf8 => LargeBinary,
792-
Binary => Binary,
793-
LargeBinary => LargeBinary,
794-
Null => Null,
795-
_ => {
796-
return plan_err!(
797-
"The decode function can only accept utf8 or binary."
798-
);
799-
}
800-
}),
801771
BuiltinScalarFunction::SplitPart => {
802772
utf8_to_str_type(&input_expr_types[0], "split_part")
803773
}
@@ -1089,24 +1059,6 @@ impl BuiltinScalarFunction {
10891059
],
10901060
self.volatility(),
10911061
),
1092-
BuiltinScalarFunction::Encode => Signature::one_of(
1093-
vec![
1094-
Exact(vec![Utf8, Utf8]),
1095-
Exact(vec![LargeUtf8, Utf8]),
1096-
Exact(vec![Binary, Utf8]),
1097-
Exact(vec![LargeBinary, Utf8]),
1098-
],
1099-
self.volatility(),
1100-
),
1101-
BuiltinScalarFunction::Decode => Signature::one_of(
1102-
vec![
1103-
Exact(vec![Utf8, Utf8]),
1104-
Exact(vec![LargeUtf8, Utf8]),
1105-
Exact(vec![Binary, Utf8]),
1106-
Exact(vec![LargeBinary, Utf8]),
1107-
],
1108-
self.volatility(),
1109-
),
11101062
BuiltinScalarFunction::DateTrunc => Signature::one_of(
11111063
vec![
11121064
Exact(vec![Utf8, Timestamp(Nanosecond, None)]),
@@ -1551,10 +1503,6 @@ impl BuiltinScalarFunction {
15511503
BuiltinScalarFunction::SHA384 => &["sha384"],
15521504
BuiltinScalarFunction::SHA512 => &["sha512"],
15531505

1554-
// encode/decode
1555-
BuiltinScalarFunction::Encode => &["encode"],
1556-
BuiltinScalarFunction::Decode => &["decode"],
1557-
15581506
// other functions
15591507
BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"],
15601508

0 commit comments

Comments
 (0)