@@ -23,7 +23,7 @@ use arrow::array::{
23
23
ArrayRef , ArrowPrimitiveType , AsArray , PrimitiveArray , StringArrayType ,
24
24
} ;
25
25
use arrow:: datatypes:: { ArrowNativeType , DataType , Int32Type , Int64Type } ;
26
- use datafusion_common:: { exec_err, Result } ;
26
+ use datafusion_common:: { exec_err, internal_err , Result } ;
27
27
use datafusion_expr:: {
28
28
ColumnarValue , Documentation , ScalarUDFImpl , Signature , Volatility ,
29
29
} ;
@@ -79,8 +79,17 @@ impl ScalarUDFImpl for StrposFunc {
79
79
& self . signature
80
80
}
81
81
82
- fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
83
- utf8_to_int_type ( & arg_types[ 0 ] , "strpos/instr/position" )
82
+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
83
+ internal_err ! ( "return_type_from_args should be used instead" )
84
+ }
85
+
86
+ fn return_type_from_args (
87
+ & self ,
88
+ args : datafusion_expr:: ReturnTypeArgs ,
89
+ ) -> Result < datafusion_expr:: ReturnInfo > {
90
+ utf8_to_int_type ( & args. arg_types [ 0 ] , "strpos/instr/position" ) . map ( |data_type| {
91
+ datafusion_expr:: ReturnInfo :: new ( data_type, args. nullables . iter ( ) . any ( |x| * x) )
92
+ } )
84
93
}
85
94
86
95
fn invoke_with_args (
@@ -201,6 +210,7 @@ mod tests {
201
210
use arrow:: array:: { Array , Int32Array , Int64Array } ;
202
211
use arrow:: datatypes:: DataType :: { Int32 , Int64 } ;
203
212
213
+ use arrow:: datatypes:: DataType ;
204
214
use datafusion_common:: { Result , ScalarValue } ;
205
215
use datafusion_expr:: { ColumnarValue , ScalarUDFImpl } ;
206
216
@@ -288,4 +298,27 @@ mod tests {
288
298
test_strpos ! ( "" , "" -> 1 ; Utf8View LargeUtf8 i32 Int32 Int32Array ) ;
289
299
test_strpos ! ( "ДатаФусион数据融合📊🔥" , "📊" -> 15 ; Utf8View LargeUtf8 i32 Int32 Int32Array ) ;
290
300
}
301
+
302
+ #[ test]
303
+ fn nullable_return_type ( ) {
304
+ fn get_nullable ( string_array_nullable : bool , substring_nullable : bool ) -> bool {
305
+ let strpos = StrposFunc :: new ( ) ;
306
+ let args = datafusion_expr:: ReturnTypeArgs {
307
+ arg_types : & [ DataType :: Utf8 , DataType :: Utf8 ] ,
308
+ nullables : & [ string_array_nullable, substring_nullable] ,
309
+ scalar_arguments : & [ None :: < & ScalarValue > , None :: < & ScalarValue > ] ,
310
+ } ;
311
+
312
+ let ( _, nullable) = strpos. return_type_from_args ( args) . unwrap ( ) . into_parts ( ) ;
313
+
314
+ nullable
315
+ }
316
+
317
+ assert ! ( !get_nullable( false , false ) ) ;
318
+
319
+ // If any of the arguments is nullable, the result is nullable
320
+ assert ! ( get_nullable( true , false ) ) ;
321
+ assert ! ( get_nullable( false , true ) ) ;
322
+ assert ! ( get_nullable( true , true ) ) ;
323
+ }
291
324
}
0 commit comments