Skip to content

Commit 6797375

Browse files
authored
fix: use return_type_from_args and mark nullable if any of the input is nullable (#14841)
* fix: use `return_type_from_args` and mark nullable if any of the input is nullable * keep `return_type` function * Revert "keep `return_type` function" This reverts commit b1b7aac.
1 parent e0da97b commit 6797375

File tree

1 file changed

+36
-3
lines changed

1 file changed

+36
-3
lines changed

datafusion/functions/src/unicode/strpos.rs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use arrow::array::{
2323
ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType,
2424
};
2525
use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26-
use datafusion_common::{exec_err, Result};
26+
use datafusion_common::{exec_err, internal_err, Result};
2727
use datafusion_expr::{
2828
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
2929
};
@@ -79,8 +79,17 @@ impl ScalarUDFImpl for StrposFunc {
7979
&self.signature
8080
}
8181

82-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
83-
utf8_to_int_type(&arg_types[0], "strpos/instr/position")
82+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
83+
internal_err!("return_type_from_args should be used instead")
84+
}
85+
86+
fn return_type_from_args(
87+
&self,
88+
args: datafusion_expr::ReturnTypeArgs,
89+
) -> Result<datafusion_expr::ReturnInfo> {
90+
utf8_to_int_type(&args.arg_types[0], "strpos/instr/position").map(|data_type| {
91+
datafusion_expr::ReturnInfo::new(data_type, args.nullables.iter().any(|x| *x))
92+
})
8493
}
8594

8695
fn invoke_with_args(
@@ -201,6 +210,7 @@ mod tests {
201210
use arrow::array::{Array, Int32Array, Int64Array};
202211
use arrow::datatypes::DataType::{Int32, Int64};
203212

213+
use arrow::datatypes::DataType;
204214
use datafusion_common::{Result, ScalarValue};
205215
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
206216

@@ -288,4 +298,27 @@ mod tests {
288298
test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
289299
test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array);
290300
}
301+
302+
#[test]
303+
fn nullable_return_type() {
304+
fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool {
305+
let strpos = StrposFunc::new();
306+
let args = datafusion_expr::ReturnTypeArgs {
307+
arg_types: &[DataType::Utf8, DataType::Utf8],
308+
nullables: &[string_array_nullable, substring_nullable],
309+
scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
310+
};
311+
312+
let (_, nullable) = strpos.return_type_from_args(args).unwrap().into_parts();
313+
314+
nullable
315+
}
316+
317+
assert!(!get_nullable(false, false));
318+
319+
// If any of the arguments is nullable, the result is nullable
320+
assert!(get_nullable(true, false));
321+
assert!(get_nullable(false, true));
322+
assert!(get_nullable(true, true));
323+
}
291324
}

0 commit comments

Comments
 (0)