diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 2272997fae06..0f229059d05c 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -17,11 +17,17 @@ //! Built-in functions module contains all the built-in functions definitions. -use crate::Volatility; +use crate::nullif::SUPPORTED_NULLIF_TYPES; +use crate::type_coercion::functions::data_types; +use crate::{ + conditional_expressions, struct_expressions, Signature, TypeSignature, Volatility, +}; +use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; use datafusion_common::{DataFusionError, Result}; use std::collections::HashMap; use std::fmt; use std::str::FromStr; +use std::sync::Arc; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -383,6 +389,672 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Uuid => Volatility::Volatile, } } + + /// Creates a detailed error message for a function with wrong signature. + /// + /// For example, a query like `select round(3.14, 1.1);` would yield: + /// ```text + /// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts. + /// Candidate functions: + /// round(Float64, Int64) + /// round(Float32, Int64) + /// round(Float64) + /// round(Float32) + /// ``` + fn generate_signature_error_msg(&self, input_expr_types: &[DataType]) -> String { + let candidate_signatures = self + .signature() + .type_signature + .to_string_repr() + .iter() + .map(|args_str| format!("\t{self}({args_str})")) + .collect::>() + .join("\n"); + + format!( + "No function matches the given name and argument types '{}({})'. 
You might need to add explicit type casts.\n\tCandidate functions:\n{}", + self, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures + ) + } + + /// Returns the output [`DataType`] of this function + pub fn return_type(self, input_expr_types: &[DataType]) -> Result { + use DataType::*; + use TimeUnit::*; + + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + if input_expr_types.is_empty() && !self.supports_zero_argument() { + return Err(DataFusionError::Plan( + self.generate_signature_error_msg(input_expr_types), + )); + } + + // verify that this is a valid set of data types for this function + data_types(input_expr_types, &self.signature()).map_err(|_| { + DataFusionError::Plan(self.generate_signature_error_msg(input_expr_types)) + })?; + + // the return type of the built in function. + // Some built-in functions' return type depends on the incoming type. + match self { + BuiltinScalarFunction::ArrayAppend => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::ArrayConcat => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept fixed size list as the args."
+ ))), + }, + BuiltinScalarFunction::ArrayDims => Ok(UInt8), + BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), + BuiltinScalarFunction::ArrayLength => Ok(UInt8), + BuiltinScalarFunction::ArrayNdims => Ok(UInt8), + BuiltinScalarFunction::ArrayPosition => Ok(UInt8), + BuiltinScalarFunction::ArrayPositions => Ok(UInt8), + BuiltinScalarFunction::ArrayPrepend => match &input_expr_types[1] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::ArrayRemove => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::ArrayReplace => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::ArrayToString => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::Cardinality => Ok(UInt64), + BuiltinScalarFunction::MakeArray => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), + BuiltinScalarFunction::TrimArray => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only 
accept list as the first argument" + ))), + }, + BuiltinScalarFunction::Ascii => Ok(Int32), + BuiltinScalarFunction::BitLength => { + utf8_to_int_type(&input_expr_types[0], "bit_length") + } + BuiltinScalarFunction::Btrim => { + utf8_to_str_type(&input_expr_types[0], "btrim") + } + BuiltinScalarFunction::CharacterLength => { + utf8_to_int_type(&input_expr_types[0], "character_length") + } + BuiltinScalarFunction::Chr => Ok(Utf8), + BuiltinScalarFunction::Coalesce => { + // COALESCE has multiple args and they might get coerced, get a preview of this + let coerced_types = data_types(input_expr_types, &self.signature()); + coerced_types.map(|types| types[0].clone()) + } + BuiltinScalarFunction::Concat => Ok(Utf8), + BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8), + BuiltinScalarFunction::DatePart => Ok(Float64), + BuiltinScalarFunction::DateTrunc | BuiltinScalarFunction::DateBin => { + match input_expr_types[1] { + Timestamp(Nanosecond, _) | Utf8 => Ok(Timestamp(Nanosecond, None)), + Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)), + Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)), + Timestamp(Second, _) => Ok(Timestamp(Second, None)), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept timestamp as the second arg." 
+ ))), + } + } + BuiltinScalarFunction::InitCap => { + utf8_to_str_type(&input_expr_types[0], "initcap") + } + BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), + BuiltinScalarFunction::Lower => { + utf8_to_str_type(&input_expr_types[0], "lower") + } + BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), + BuiltinScalarFunction::Ltrim => { + utf8_to_str_type(&input_expr_types[0], "ltrim") + } + BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"), + BuiltinScalarFunction::NullIf => { + // NULLIF has two args and they might get coerced, get a preview of this + let coerced_types = data_types(input_expr_types, &self.signature()); + coerced_types.map(|typs| typs[0].clone()) + } + BuiltinScalarFunction::OctetLength => { + utf8_to_int_type(&input_expr_types[0], "octet_length") + } + BuiltinScalarFunction::Pi => Ok(Float64), + BuiltinScalarFunction::Random => Ok(Float64), + BuiltinScalarFunction::Uuid => Ok(Utf8), + BuiltinScalarFunction::RegexpReplace => { + utf8_to_str_type(&input_expr_types[0], "regex_replace") + } + BuiltinScalarFunction::Repeat => { + utf8_to_str_type(&input_expr_types[0], "repeat") + } + BuiltinScalarFunction::Replace => { + utf8_to_str_type(&input_expr_types[0], "replace") + } + BuiltinScalarFunction::Reverse => { + utf8_to_str_type(&input_expr_types[0], "reverse") + } + BuiltinScalarFunction::Right => { + utf8_to_str_type(&input_expr_types[0], "right") + } + BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), + BuiltinScalarFunction::Rtrim => { + utf8_to_str_type(&input_expr_types[0], "rtrimp") + } + BuiltinScalarFunction::SHA224 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha224") + } + BuiltinScalarFunction::SHA256 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha256") + } + BuiltinScalarFunction::SHA384 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha384") + } + BuiltinScalarFunction::SHA512 => { 
+ utf8_or_binary_to_binary_type(&input_expr_types[0], "sha512") + } + BuiltinScalarFunction::Digest => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "digest") + } + BuiltinScalarFunction::SplitPart => { + utf8_to_str_type(&input_expr_types[0], "split_part") + } + BuiltinScalarFunction::StartsWith => Ok(Boolean), + BuiltinScalarFunction::Strpos => { + utf8_to_int_type(&input_expr_types[0], "strpos") + } + BuiltinScalarFunction::Substr => { + utf8_to_str_type(&input_expr_types[0], "substr") + } + BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] { + Int8 | Int16 | Int32 | Int64 => Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The to_hex function can only accept integers.".to_string(), + )); + } + }), + BuiltinScalarFunction::ToTimestamp => Ok(Timestamp(Nanosecond, None)), + BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)), + BuiltinScalarFunction::ToTimestampMicros => Ok(Timestamp(Microsecond, None)), + BuiltinScalarFunction::ToTimestampSeconds => Ok(Timestamp(Second, None)), + BuiltinScalarFunction::FromUnixtime => Ok(Timestamp(Second, None)), + BuiltinScalarFunction::Now => { + Ok(Timestamp(Nanosecond, Some("+00:00".into()))) + } + BuiltinScalarFunction::CurrentDate => Ok(Date32), + BuiltinScalarFunction::CurrentTime => Ok(Time64(Nanosecond)), + BuiltinScalarFunction::Translate => { + utf8_to_str_type(&input_expr_types[0], "translate") + } + BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"), + BuiltinScalarFunction::Upper => { + utf8_to_str_type(&input_expr_types[0], "upper") + } + BuiltinScalarFunction::RegexpMatch => Ok(match input_expr_types[0] { + LargeUtf8 => List(Arc::new(Field::new("item", LargeUtf8, true))), + Utf8 => List(Arc::new(Field::new("item", Utf8, true))), + Null => Null, + _ => { + // this error is internal as `data_types` should have captured this. 
+ return Err(DataFusionError::Internal( + "The regexp_extract function can only accept strings." + .to_string(), + )); + } + }), + + BuiltinScalarFunction::Factorial + | BuiltinScalarFunction::Gcd + | BuiltinScalarFunction::Lcm => Ok(Int64), + + BuiltinScalarFunction::Power => match &input_expr_types[0] { + Int64 => Ok(Int64), + _ => Ok(Float64), + }, + + BuiltinScalarFunction::Struct => { + let return_fields = input_expr_types + .iter() + .enumerate() + .map(|(pos, dt)| Field::new(format!("c{pos}"), dt.clone(), true)) + .collect::>(); + Ok(Struct(Fields::from(return_fields))) + } + + BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + + BuiltinScalarFunction::Log => match &input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + + BuiltinScalarFunction::ArrowTypeof => Ok(Utf8), + + BuiltinScalarFunction::Abs + | BuiltinScalarFunction::Acos + | BuiltinScalarFunction::Asin + | BuiltinScalarFunction::Atan + | BuiltinScalarFunction::Acosh + | BuiltinScalarFunction::Asinh + | BuiltinScalarFunction::Atanh + | BuiltinScalarFunction::Ceil + | BuiltinScalarFunction::Cos + | BuiltinScalarFunction::Cosh + | BuiltinScalarFunction::Degrees + | BuiltinScalarFunction::Exp + | BuiltinScalarFunction::Floor + | BuiltinScalarFunction::Ln + | BuiltinScalarFunction::Log10 + | BuiltinScalarFunction::Log2 + | BuiltinScalarFunction::Radians + | BuiltinScalarFunction::Round + | BuiltinScalarFunction::Signum + | BuiltinScalarFunction::Sin + | BuiltinScalarFunction::Sinh + | BuiltinScalarFunction::Sqrt + | BuiltinScalarFunction::Cbrt + | BuiltinScalarFunction::Tan + | BuiltinScalarFunction::Tanh + | BuiltinScalarFunction::Trunc => match input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + } + } + + /// Return the argument [`Signature`] supported by this function + pub fn signature(&self) -> Signature { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + use TypeSignature::*; + // 
note: the physical expression must accept the type returned by this function or the execution panics. + + // for now, the list is small, as we do not have many built-in functions. + match self { + BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayConcat => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayFill => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayLength => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayPosition => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayReplace => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayToString => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), + BuiltinScalarFunction::MakeArray => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::TrimArray => Signature::any(2, self.volatility()), + BuiltinScalarFunction::Struct => Signature::variadic( + struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), + self.volatility(), + ), + BuiltinScalarFunction::Concat + | BuiltinScalarFunction::ConcatWithSeparator => { + Signature::variadic(vec![Utf8], self.volatility()) + } + BuiltinScalarFunction::Coalesce => Signature::variadic( + conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(), + self.volatility(), + ), + BuiltinScalarFunction::SHA224 + | BuiltinScalarFunction::SHA256 + | BuiltinScalarFunction::SHA384 + | 
BuiltinScalarFunction::SHA512 + | BuiltinScalarFunction::MD5 => Signature::uniform( + 1, + vec![Utf8, LargeUtf8, Binary, LargeBinary], + self.volatility(), + ), + BuiltinScalarFunction::Ascii + | BuiltinScalarFunction::BitLength + | BuiltinScalarFunction::CharacterLength + | BuiltinScalarFunction::InitCap + | BuiltinScalarFunction::Lower + | BuiltinScalarFunction::OctetLength + | BuiltinScalarFunction::Reverse + | BuiltinScalarFunction::Upper => { + Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) + } + BuiltinScalarFunction::Btrim + | BuiltinScalarFunction::Ltrim + | BuiltinScalarFunction::Rtrim + | BuiltinScalarFunction::Trim => Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + self.volatility(), + ), + BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { + Signature::uniform(1, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { + Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), + ], + self.volatility(), + ) + } + BuiltinScalarFunction::Left + | BuiltinScalarFunction::Repeat + | BuiltinScalarFunction::Right => Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestamp => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestampMillis => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestampMicros => Signature::uniform( + 1, + 
vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::FromUnixtime => { + Signature::uniform(1, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::Digest => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Binary, Utf8]), + Exact(vec![LargeBinary, Utf8]), + ], + self.volatility(), + ), + BuiltinScalarFunction::DateTrunc => Signature::exact( + vec![Utf8, Timestamp(Nanosecond, None)], + self.volatility(), + ), + BuiltinScalarFunction::DateBin => { + let base_sig = |array_type: TimeUnit| { + vec![ + Exact(vec![ + Interval(MonthDayNano), + Timestamp(array_type.clone(), None), + Timestamp(Nanosecond, None), + ]), + Exact(vec![ + Interval(DayTime), + Timestamp(array_type.clone(), None), + Timestamp(Nanosecond, None), + ]), + Exact(vec![ + Interval(MonthDayNano), + Timestamp(array_type.clone(), None), + ]), + Exact(vec![Interval(DayTime), Timestamp(array_type, None)]), + ] + }; + + let full_sig = [Nanosecond, Microsecond, Millisecond, Second] + .into_iter() + .map(base_sig) + .collect::>() + .concat(); + + Signature::one_of(full_sig, self.volatility()) + } + BuiltinScalarFunction::DatePart => Signature::one_of( + vec![ + Exact(vec![Utf8, Date32]), + Exact(vec![Utf8, Date64]), + Exact(vec![Utf8, Timestamp(Second, None)]), + Exact(vec![Utf8, Timestamp(Microsecond, None)]), + Exact(vec![Utf8, Timestamp(Millisecond, None)]), + Exact(vec![Utf8, Timestamp(Nanosecond, None)]), + Exact(vec![Utf8, Timestamp(Nanosecond, Some("+00:00".into()))]), + ], + self.volatility(), + ), + BuiltinScalarFunction::SplitPart => Signature::one_of( + 
vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, Utf8, Int64]), + Exact(vec![Utf8, LargeUtf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + + BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { + Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + self.volatility(), + ) + } + + BuiltinScalarFunction::Substr => Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, Int64, Int64]), + ], + self.volatility(), + ), + + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { + Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) + } + BuiltinScalarFunction::RegexpReplace => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![Utf8, Utf8, Utf8, Utf8]), + ], + self.volatility(), + ), + + BuiltinScalarFunction::NullIf => { + Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), self.volatility()) + } + BuiltinScalarFunction::RegexpMatch => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8, Utf8]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), + BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), + BuiltinScalarFunction::Uuid => Signature::exact(vec![], self.volatility()), + BuiltinScalarFunction::Power => Signature::one_of( + vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], + self.volatility(), + ), + BuiltinScalarFunction::Round => Signature::one_of( + vec![ + Exact(vec![Float64, Int64]), + Exact(vec![Float32, Int64]), + Exact(vec![Float64]), + Exact(vec![Float32]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Atan2 => Signature::one_of( + 
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + self.volatility(), + ), + BuiltinScalarFunction::Log => Signature::one_of( + vec![ + Exact(vec![Float32]), + Exact(vec![Float64]), + Exact(vec![Float32, Float32]), + Exact(vec![Float64, Float64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Factorial => { + Signature::uniform(1, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { + Signature::uniform(2, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()), + BuiltinScalarFunction::Abs + | BuiltinScalarFunction::Acos + | BuiltinScalarFunction::Asin + | BuiltinScalarFunction::Atan + | BuiltinScalarFunction::Acosh + | BuiltinScalarFunction::Asinh + | BuiltinScalarFunction::Atanh + | BuiltinScalarFunction::Cbrt + | BuiltinScalarFunction::Ceil + | BuiltinScalarFunction::Cos + | BuiltinScalarFunction::Cosh + | BuiltinScalarFunction::Degrees + | BuiltinScalarFunction::Exp + | BuiltinScalarFunction::Floor + | BuiltinScalarFunction::Ln + | BuiltinScalarFunction::Log10 + | BuiltinScalarFunction::Log2 + | BuiltinScalarFunction::Radians + | BuiltinScalarFunction::Signum + | BuiltinScalarFunction::Sin + | BuiltinScalarFunction::Sinh + | BuiltinScalarFunction::Sqrt + | BuiltinScalarFunction::Tan + | BuiltinScalarFunction::Tanh + | BuiltinScalarFunction::Trunc => { + // math expressions expect 1 argument of type f64 or f32 + // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we + // return the best approximation for it (in f64). 
+ // We accept f32 because in this case it is clear that the best approximation + // will be as good as the number of digits in the number + Signature::uniform(1, vec![Float64, Float32], self.volatility()) + } + BuiltinScalarFunction::Now + | BuiltinScalarFunction::CurrentDate + | BuiltinScalarFunction::CurrentTime => { + Signature::uniform(0, vec![], self.volatility()) + } + } + } } fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { @@ -526,6 +1198,44 @@ impl FromStr for BuiltinScalarFunction { } } +macro_rules! make_utf8_to_return_type { + ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { + fn $FUNC(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + DataType::LargeUtf8 => $largeUtf8Type, + DataType::Utf8 => $utf8Type, + DataType::Null => DataType::Null, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal(format!( + "The {:?} function can only accept strings.", + name + ))); + } + }) + } + }; +} + +make_utf8_to_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); +make_utf8_to_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); + +fn utf8_or_binary_to_binary_type(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + DataType::LargeUtf8 + | DataType::Utf8 + | DataType::Binary + | DataType::LargeBinary => DataType::Binary, + DataType::Null => DataType::Null, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal(format!( + "The {name:?} function can only accept strings or binary arrays." 
+ ))); + } + }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3c68c4acd7dd..52ad65773c1a 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -22,9 +22,7 @@ use crate::expr::{ }; use crate::field_util::get_indexed_field; use crate::type_coercion::binary::get_result_type; -use crate::{ - aggregate_function, function, window_function, LogicalPlan, Projection, Subquery, -}; +use crate::{aggregate_function, window_function, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; use arrow::datatypes::DataType; use datafusion_common::{Column, DFField, DFSchema, DataFusionError, ExprSchema, Result}; @@ -87,7 +85,7 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - function::return_type(fun, &data_types) + fun.return_type(&data_types) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index f47b94322d94..bd242c493e43 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,17 +17,13 @@ //! Function module contains typing and signature for built-in and user defined functions. 
-use crate::function_err::generate_signature_error_msg; -use crate::nullif::SUPPORTED_NULLIF_TYPES; -use crate::type_coercion::functions::data_types; -use crate::ColumnarValue; -use crate::{ - conditional_expressions, struct_expressions, Accumulator, BuiltinScalarFunction, - Signature, TypeSignature, -}; -use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; -use datafusion_common::{DataFusionError, Result}; +use crate::{Accumulator, BuiltinScalarFunction, Signature}; +use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; +use arrow::datatypes::DataType; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::Result; use std::sync::Arc; +use strum::IntoEnumIterator; /// Scalar function /// @@ -54,646 +50,53 @@ pub type AccumulatorFunctionImplementation = pub type StateTypeFunction = Arc Result>> + Send + Sync>; -macro_rules! make_utf8_to_return_type { - ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { - fn $FUNC(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - DataType::LargeUtf8 => $largeUtf8Type, - DataType::Utf8 => $utf8Type, - DataType::Null => DataType::Null, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal(format!( - "The {:?} function can only accept strings.", - name - ))); - } - }) - } - }; -} - -make_utf8_to_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); -make_utf8_to_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); - -fn utf8_or_binary_to_binary_type(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::LargeBinary => DataType::Binary, - DataType::Null => DataType::Null, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal(format!( - "The {name:?} function can only accept strings or binary arrays." 
- ))); - } - }) -} - /// Returns the datatype of the scalar function +#[deprecated( + since = "27.0.0", + note = "please use `BuiltinScalarFunction::return_type` instead" +)] pub fn return_type( fun: &BuiltinScalarFunction, input_expr_types: &[DataType], ) -> Result { - use DataType::*; - use TimeUnit::*; - - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - if input_expr_types.is_empty() && !fun.supports_zero_argument() { - return Err(DataFusionError::Plan(generate_signature_error_msg( - fun, - input_expr_types, - ))); - } - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &signature(fun)).map_err(|_| { - DataFusionError::Plan(generate_signature_error_msg(fun, input_expr_types)) - })?; - - // the return type of the built in function. - // Some built-in functions' return type depends on the incoming type. - match fun { - BuiltinScalarFunction::ArrayAppend => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayConcat => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept fixed size list as the args." 
- ))), - }, - BuiltinScalarFunction::ArrayDims => Ok(UInt8), - BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new( - "item", - input_expr_types[0].clone(), - true, - )))), - BuiltinScalarFunction::ArrayLength => Ok(UInt8), - BuiltinScalarFunction::ArrayNdims => Ok(UInt8), - BuiltinScalarFunction::ArrayPosition => Ok(UInt8), - BuiltinScalarFunction::ArrayPositions => Ok(UInt8), - BuiltinScalarFunction::ArrayPrepend => match &input_expr_types[1] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayRemove => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayReplace => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayToString => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::Cardinality => Ok(UInt64), - BuiltinScalarFunction::MakeArray => Ok(List(Arc::new(Field::new( - "item", - input_expr_types[0].clone(), - true, - )))), - BuiltinScalarFunction::TrimArray => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only 
accept list as the first argument" - ))), - }, - BuiltinScalarFunction::Ascii => Ok(Int32), - BuiltinScalarFunction::BitLength => { - utf8_to_int_type(&input_expr_types[0], "bit_length") - } - BuiltinScalarFunction::Btrim => utf8_to_str_type(&input_expr_types[0], "btrim"), - BuiltinScalarFunction::CharacterLength => { - utf8_to_int_type(&input_expr_types[0], "character_length") - } - BuiltinScalarFunction::Chr => Ok(Utf8), - BuiltinScalarFunction::Coalesce => { - // COALESCE has multiple args and they might get coerced, get a preview of this - let coerced_types = data_types(input_expr_types, &signature(fun)); - coerced_types.map(|types| types[0].clone()) - } - BuiltinScalarFunction::Concat => Ok(Utf8), - BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8), - BuiltinScalarFunction::DatePart => Ok(Float64), - BuiltinScalarFunction::DateTrunc | BuiltinScalarFunction::DateBin => { - match input_expr_types[1] { - Timestamp(Nanosecond, _) | Utf8 => Ok(Timestamp(Nanosecond, None)), - Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)), - Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)), - Timestamp(Second, _) => Ok(Timestamp(Second, None)), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept timestamp as the second arg." 
- ))), - } - } - BuiltinScalarFunction::InitCap => { - utf8_to_str_type(&input_expr_types[0], "initcap") - } - BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), - BuiltinScalarFunction::Lower => utf8_to_str_type(&input_expr_types[0], "lower"), - BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), - BuiltinScalarFunction::Ltrim => utf8_to_str_type(&input_expr_types[0], "ltrim"), - BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"), - BuiltinScalarFunction::NullIf => { - // NULLIF has two args and they might get coerced, get a preview of this - let coerced_types = data_types(input_expr_types, &signature(fun)); - coerced_types.map(|typs| typs[0].clone()) - } - BuiltinScalarFunction::OctetLength => { - utf8_to_int_type(&input_expr_types[0], "octet_length") - } - BuiltinScalarFunction::Pi => Ok(Float64), - BuiltinScalarFunction::Random => Ok(Float64), - BuiltinScalarFunction::Uuid => Ok(Utf8), - BuiltinScalarFunction::RegexpReplace => { - utf8_to_str_type(&input_expr_types[0], "regex_replace") - } - BuiltinScalarFunction::Repeat => utf8_to_str_type(&input_expr_types[0], "repeat"), - BuiltinScalarFunction::Replace => { - utf8_to_str_type(&input_expr_types[0], "replace") - } - BuiltinScalarFunction::Reverse => { - utf8_to_str_type(&input_expr_types[0], "reverse") - } - BuiltinScalarFunction::Right => utf8_to_str_type(&input_expr_types[0], "right"), - BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), - BuiltinScalarFunction::Rtrim => utf8_to_str_type(&input_expr_types[0], "rtrimp"), - BuiltinScalarFunction::SHA224 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha224") - } - BuiltinScalarFunction::SHA256 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha256") - } - BuiltinScalarFunction::SHA384 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha384") - } - BuiltinScalarFunction::SHA512 => { - 
utf8_or_binary_to_binary_type(&input_expr_types[0], "sha512") - } - BuiltinScalarFunction::Digest => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "digest") - } - BuiltinScalarFunction::SplitPart => { - utf8_to_str_type(&input_expr_types[0], "split_part") - } - BuiltinScalarFunction::StartsWith => Ok(Boolean), - BuiltinScalarFunction::Strpos => utf8_to_int_type(&input_expr_types[0], "strpos"), - BuiltinScalarFunction::Substr => utf8_to_str_type(&input_expr_types[0], "substr"), - BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] { - Int8 | Int16 | Int32 | Int64 => Utf8, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The to_hex function can only accept integers.".to_string(), - )); - } - }), - BuiltinScalarFunction::ToTimestamp => Ok(Timestamp(Nanosecond, None)), - BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)), - BuiltinScalarFunction::ToTimestampMicros => Ok(Timestamp(Microsecond, None)), - BuiltinScalarFunction::ToTimestampSeconds => Ok(Timestamp(Second, None)), - BuiltinScalarFunction::FromUnixtime => Ok(Timestamp(Second, None)), - BuiltinScalarFunction::Now => Ok(Timestamp(Nanosecond, Some("+00:00".into()))), - BuiltinScalarFunction::CurrentDate => Ok(Date32), - BuiltinScalarFunction::CurrentTime => Ok(Time64(Nanosecond)), - BuiltinScalarFunction::Translate => { - utf8_to_str_type(&input_expr_types[0], "translate") - } - BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"), - BuiltinScalarFunction::Upper => utf8_to_str_type(&input_expr_types[0], "upper"), - BuiltinScalarFunction::RegexpMatch => Ok(match input_expr_types[0] { - LargeUtf8 => List(Arc::new(Field::new("item", LargeUtf8, true))), - Utf8 => List(Arc::new(Field::new("item", Utf8, true))), - Null => Null, - _ => { - // this error is internal as `data_types` should have captured this. 
- return Err(DataFusionError::Internal( - "The regexp_extract function can only accept strings.".to_string(), - )); - } - }), - - BuiltinScalarFunction::Factorial - | BuiltinScalarFunction::Gcd - | BuiltinScalarFunction::Lcm => Ok(Int64), - - BuiltinScalarFunction::Power => match &input_expr_types[0] { - Int64 => Ok(Int64), - _ => Ok(Float64), - }, - - BuiltinScalarFunction::Struct => { - let return_fields = input_expr_types - .iter() - .enumerate() - .map(|(pos, dt)| Field::new(format!("c{pos}"), dt.clone(), true)) - .collect::>(); - Ok(Struct(Fields::from(return_fields))) - } - - BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - - BuiltinScalarFunction::Log => match &input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - - BuiltinScalarFunction::ArrowTypeof => Ok(Utf8), - - BuiltinScalarFunction::Abs - | BuiltinScalarFunction::Acos - | BuiltinScalarFunction::Asin - | BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Cos - | BuiltinScalarFunction::Cosh - | BuiltinScalarFunction::Degrees - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 - | BuiltinScalarFunction::Radians - | BuiltinScalarFunction::Round - | BuiltinScalarFunction::Signum - | BuiltinScalarFunction::Sin - | BuiltinScalarFunction::Sinh - | BuiltinScalarFunction::Sqrt - | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Tan - | BuiltinScalarFunction::Tanh - | BuiltinScalarFunction::Trunc => match input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - } + fun.return_type(input_expr_types) } /// Return the [`Signature`] supported by the function `fun`. 
+#[deprecated( + since = "27.0.0", + note = "please use `BuiltinScalarFunction::signature` instead" +)] pub fn signature(fun: &BuiltinScalarFunction) -> Signature { - use DataType::*; - use IntervalUnit::*; - use TimeUnit::*; - use TypeSignature::*; - // note: the physical expression must accept the type returned by this function or the execution panics. - - // for now, the list is small, as we do not have many built-in functions. - match fun { - BuiltinScalarFunction::ArrayAppend => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayConcat => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayDims => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::ArrayFill => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayLength => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayNdims => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::ArrayPosition => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayPositions => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayPrepend => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayRemove => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayReplace => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayToString => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::Cardinality => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::MakeArray => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::TrimArray => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::Struct => Signature::variadic( - struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), - fun.volatility(), - ), - BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { - Signature::variadic(vec![Utf8], fun.volatility()) - } - BuiltinScalarFunction::Coalesce => Signature::variadic( - 
conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(), - fun.volatility(), - ), - BuiltinScalarFunction::SHA224 - | BuiltinScalarFunction::SHA256 - | BuiltinScalarFunction::SHA384 - | BuiltinScalarFunction::SHA512 - | BuiltinScalarFunction::MD5 => Signature::uniform( - 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], - fun.volatility(), - ), - BuiltinScalarFunction::Ascii - | BuiltinScalarFunction::BitLength - | BuiltinScalarFunction::CharacterLength - | BuiltinScalarFunction::InitCap - | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::OctetLength - | BuiltinScalarFunction::Reverse - | BuiltinScalarFunction::Upper => { - Signature::uniform(1, vec![Utf8, LargeUtf8], fun.volatility()) - } - BuiltinScalarFunction::Btrim - | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim - | BuiltinScalarFunction::Trim => Signature::one_of( - vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], - fun.volatility(), - ), - BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { - Signature::uniform(1, vec![Int64], fun.volatility()) - } - BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), - Exact(vec![Utf8, Int64, LargeUtf8]), - Exact(vec![LargeUtf8, Int64, LargeUtf8]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Left - | BuiltinScalarFunction::Repeat - | BuiltinScalarFunction::Right => Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestamp => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestampMillis => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), 
- Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestampMicros => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::FromUnixtime => { - Signature::uniform(1, vec![Int64], fun.volatility()) - } - BuiltinScalarFunction::Digest => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Binary, Utf8]), - Exact(vec![LargeBinary, Utf8]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::DateTrunc => { - Signature::exact(vec![Utf8, Timestamp(Nanosecond, None)], fun.volatility()) - } - BuiltinScalarFunction::DateBin => { - let base_sig = |array_type: TimeUnit| { - vec![ - Exact(vec![ - Interval(MonthDayNano), - Timestamp(array_type.clone(), None), - Timestamp(Nanosecond, None), - ]), - Exact(vec![ - Interval(DayTime), - Timestamp(array_type.clone(), None), - Timestamp(Nanosecond, None), - ]), - Exact(vec![ - Interval(MonthDayNano), - Timestamp(array_type.clone(), None), - ]), - Exact(vec![Interval(DayTime), Timestamp(array_type, None)]), - ] - }; - - let full_sig = [Nanosecond, Microsecond, Millisecond, Second] - .into_iter() - .map(base_sig) - .collect::>() - .concat(); - - Signature::one_of(full_sig, fun.volatility()) - } - BuiltinScalarFunction::DatePart => Signature::one_of( - vec![ - Exact(vec![Utf8, Date32]), - Exact(vec![Utf8, Date64]), - Exact(vec![Utf8, Timestamp(Second, None)]), - Exact(vec![Utf8, Timestamp(Microsecond, None)]), - Exact(vec![Utf8, Timestamp(Millisecond, None)]), - Exact(vec![Utf8, Timestamp(Nanosecond, 
None)]), - Exact(vec![Utf8, Timestamp(Nanosecond, Some("+00:00".into()))]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::SplitPart => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, Utf8, Int64]), - Exact(vec![Utf8, LargeUtf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), - ], - fun.volatility(), - ), - - BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { - Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - fun.volatility(), - ) - } - - BuiltinScalarFunction::Substr => Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, Int64, Int64]), - ], - fun.volatility(), - ), + fun.signature() +} - BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { - Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], fun.volatility()) - } - BuiltinScalarFunction::RegexpReplace => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![Utf8, Utf8, Utf8, Utf8]), - ], - fun.volatility(), - ), +/// Suggest a valid function based on an invalid input function name +pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String { + let valid_funcs = if is_window_func { + // All aggregate functions and builtin window functions + AggregateFunction::iter() + .map(|func| func.to_string()) + .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) + .collect() + } else { + // All scalar functions and aggregate functions + BuiltinScalarFunction::iter() + .map(|func| func.to_string()) + .chain(AggregateFunction::iter().map(|func| func.to_string())) + .collect() + }; + find_closest_match(valid_funcs, input_function_name) +} - BuiltinScalarFunction::NullIf => { - Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), fun.volatility()) - } - 
BuiltinScalarFunction::RegexpMatch => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8, Utf8]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Pi => Signature::exact(vec![], fun.volatility()), - BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()), - BuiltinScalarFunction::Uuid => Signature::exact(vec![], fun.volatility()), - BuiltinScalarFunction::Power => Signature::one_of( - vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], - fun.volatility(), - ), - BuiltinScalarFunction::Round => Signature::one_of( - vec![ - Exact(vec![Float64, Int64]), - Exact(vec![Float32, Int64]), - Exact(vec![Float64]), - Exact(vec![Float32]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Atan2 => Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], - fun.volatility(), - ), - BuiltinScalarFunction::Log => Signature::one_of( - vec![ - Exact(vec![Float32]), - Exact(vec![Float64]), - Exact(vec![Float32, Float32]), - Exact(vec![Float64, Float64]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Factorial => { - Signature::uniform(1, vec![Int64], fun.volatility()) - } - BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { - Signature::uniform(2, vec![Int64], fun.volatility()) - } - BuiltinScalarFunction::ArrowTypeof => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::Abs - | BuiltinScalarFunction::Acos - | BuiltinScalarFunction::Asin - | BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Cos - | BuiltinScalarFunction::Cosh - | BuiltinScalarFunction::Degrees - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 - | 
BuiltinScalarFunction::Radians - | BuiltinScalarFunction::Signum - | BuiltinScalarFunction::Sin - | BuiltinScalarFunction::Sinh - | BuiltinScalarFunction::Sqrt - | BuiltinScalarFunction::Tan - | BuiltinScalarFunction::Tanh - | BuiltinScalarFunction::Trunc => { - // math expressions expect 1 argument of type f64 or f32 - // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we - // return the best approximation for it (in f64). - // We accept f32 because in this case it is clear that the best approximation - // will be as good as the number of digits in the number - Signature::uniform(1, vec![Float64, Float32], fun.volatility()) - } - BuiltinScalarFunction::Now - | BuiltinScalarFunction::CurrentDate - | BuiltinScalarFunction::CurrentTime => { - Signature::uniform(0, vec![], fun.volatility()) - } - } +/// Find the closest matching string to the target string in the candidates list, using edit distance (case insensitive) +/// Input `candidates` must not be empty otherwise it will panic +fn find_closest_match(candidates: Vec, target: &str) -> String { + let target = target.to_lowercase(); + candidates + .into_iter() + .min_by_key(|candidate| { + datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target) + }) + .expect("No candidates provided.") // Panic if `candidates` argument is empty } diff --git a/datafusion/expr/src/function_err.rs b/datafusion/expr/src/function_err.rs deleted file mode 100644 index 1635ac3b0c8c..000000000000 --- a/datafusion/expr/src/function_err.rs +++ /dev/null @@ -1,125 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License.
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Function_err module enhances frontend error messages for unresolved functions due to incorrect parameters, -//! by providing the correct function signatures. -//! -//! For example, a query like `select round(3.14, 1.1);` would yield: -//! ```text -//! Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts. -//! Candidate functions: -//! round(Float64, Int64) -//! round(Float32, Int64) -//! round(Float64) -//! round(Float32) -//! ``` - -use crate::function::signature; -use crate::{ - AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, TypeSignature, -}; -use arrow::datatypes::DataType; -use datafusion_common::utils::datafusion_strsim; -use strum::IntoEnumIterator; - -impl TypeSignature { - fn to_string_repr(&self) -> Vec { - match self { - TypeSignature::Variadic(types) => { - vec![format!("{}, ..", join_types(types, "/"))] - } - TypeSignature::Uniform(arg_count, valid_types) => { - vec![std::iter::repeat(join_types(valid_types, "/")) - .take(*arg_count) - .collect::>() - .join(", ")] - } - TypeSignature::Exact(types) => { - vec![join_types(types, ", ")] - } - TypeSignature::Any(arg_count) => { - vec![std::iter::repeat("Any") - .take(*arg_count) - .collect::>() - .join(", ")] - } - TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], - TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], - TypeSignature::OneOf(sigs) => { - sigs.iter().flat_map(|s| s.to_string_repr()).collect() - } - } - } -} - -/// Helper function to join types with specified delimiter. 
-fn join_types(types: &[T], delimiter: &str) -> String { - types - .iter() - .map(|t| t.to_string()) - .collect::>() - .join(delimiter) -} - -/// Creates a detailed error message for a function with wrong signature. -pub fn generate_signature_error_msg( - fun: &BuiltinScalarFunction, - input_expr_types: &[DataType], -) -> String { - let candidate_signatures = signature(fun) - .type_signature - .to_string_repr() - .iter() - .map(|args_str| format!("\t{fun}({args_str})")) - .collect::>() - .join("\n"); - - format!( - "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}", - fun, join_types(input_expr_types, ", "), candidate_signatures - ) -} - -/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) -/// Input `candidates` must not be empty otherwise it will panic -fn find_closest_match(candidates: Vec, target: &str) -> String { - let target = target.to_lowercase(); - candidates - .into_iter() - .min_by_key(|candidate| { - datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target) - }) - .expect("No candidates provided.") // Panic if `candidates` argument is empty -} - -/// Suggest a valid function based on an invalid input function name -pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String { - let valid_funcs = if is_window_func { - // All aggregate functions and builtin window functions - AggregateFunction::iter() - .map(|func| func.to_string()) - .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) - .collect() - } else { - // All scalar functions and aggregate functions - BuiltinScalarFunction::iter() - .map(|func| func.to_string()) - .chain(AggregateFunction::iter().map(|func| func.to_string())) - .collect() - }; - find_closest_match(valid_funcs, input_function_name) -} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 
5945480aba1d..1675afb9c98a 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -37,7 +37,6 @@ pub mod expr_rewriter; pub mod expr_schema; pub mod field_util; pub mod function; -pub mod function_err; mod literal; pub mod logical_plan; mod nullif; diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index a2caba4fb8bb..e4ffd74d8daa 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -60,6 +60,48 @@ pub enum TypeSignature { OneOf(Vec), } +impl TypeSignature { + pub(crate) fn to_string_repr(&self) -> Vec { + match self { + TypeSignature::Variadic(types) => { + vec![format!("{}, ..", Self::join_types(types, "/"))] + } + TypeSignature::Uniform(arg_count, valid_types) => { + vec![std::iter::repeat(Self::join_types(valid_types, "/")) + .take(*arg_count) + .collect::>() + .join(", ")] + } + TypeSignature::Exact(types) => { + vec![Self::join_types(types, ", ")] + } + TypeSignature::Any(arg_count) => { + vec![std::iter::repeat("Any") + .take(*arg_count) + .collect::>() + .join(", ")] + } + TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], + TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], + TypeSignature::OneOf(sigs) => { + sigs.iter().flat_map(|s| s.to_string_repr()).collect() + } + } + } + + /// Helper function to join types with specified delimiter. + pub(crate) fn join_types( + types: &[T], + delimiter: &str, + ) -> String { + types + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(delimiter) + } +} + /// The signature of a function defines the supported argument types /// and its volatility. 
#[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 0d0061a5e435..3ee6a2401b02 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -40,8 +40,8 @@ use datafusion_expr::type_coercion::other::{ use datafusion_expr::type_coercion::{is_datetime, is_numeric, is_utf8_or_large_utf8}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ - aggregate_function, function, is_false, is_not_false, is_not_true, is_not_unknown, - is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator, + aggregate_function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, + is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator, Projection, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_expr::{ExprSchemable, Signature}; @@ -390,7 +390,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let nex_expr = coerce_arguments_for_signature( args.as_slice(), &self.schema, - &function::signature(&fun), + &fun.signature(), )?; let expr = Expr::ScalarFunction(ScalarFunction::new(fun, nex_expr)); Ok(expr) diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 648dd4a144c6..5a5bdf4702e0 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -45,7 +45,7 @@ use arrow::{ }; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - function, BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, + BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, }; use std::sync::Arc; @@ -62,7 +62,7 @@ pub fn create_physical_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; - let data_type = function::return_type(fun, &input_expr_types)?; + let data_type = fun.return_type(&input_expr_types)?; 
let fun_expr: ScalarFunctionImplementation = match fun { // These functions need args and input schema to pick an implementation @@ -2921,7 +2921,7 @@ mod tests { execution_props: &ExecutionProps, ) -> Result> { let type_coerced_phy_exprs = - coerce(input_phy_exprs, input_schema, &function::signature(fun)).unwrap(); + coerce(input_phy_exprs, input_schema, &fun.signature()).unwrap(); create_physical_expr(fun, &type_coerced_phy_exprs, input_schema, execution_props) } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 104a65832dcd..0fb6b7554776 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -18,7 +18,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; -use datafusion_expr::function_err::suggest_valid_function; +use datafusion_expr::function::suggest_valid_function; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::window_frame::regularize; use datafusion_expr::{