@@ -28,7 +28,7 @@ use arrow_buffer::{
28
28
ScalarBuffer ,
29
29
} ;
30
30
use arrow_data:: { ArrayData , ArrayDataBuilder } ;
31
- use arrow_schema:: { ArrowError , DataType , FieldRef } ;
31
+ use arrow_schema:: { ArrowError , DataType , FieldRef , UnionMode } ;
32
32
33
33
use num:: { One , Zero } ;
34
34
@@ -223,6 +223,21 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
223
223
Ok ( new_null_array( & DataType :: Null , indices. len( ) ) )
224
224
}
225
225
}
226
+ DataType :: Union ( fields, UnionMode :: Sparse ) => {
227
+ let mut field_type_ids = Vec :: with_capacity( fields. len( ) ) ;
228
+ let mut children = Vec :: with_capacity( fields. len( ) ) ;
229
+ let values = values. as_any( ) . downcast_ref:: <UnionArray >( ) . unwrap( ) ;
230
+ let type_ids = take_native( values. type_ids( ) , indices) . into_inner( ) ;
231
+ for ( type_id, field) in fields. iter( ) {
232
+ let values = values. child( type_id) ;
233
+ let values = take_impl( values, indices) ?;
234
+ let field = ( * * field) . clone( ) ;
235
+ children. push( ( field, values) ) ;
236
+ field_type_ids. push( type_id) ;
237
+ }
238
+ let array = UnionArray :: try_new( field_type_ids. as_slice( ) , type_ids, None , children) ?;
239
+ Ok ( Arc :: new( array) )
240
+ }
226
241
t => unimplemented!( "Take not supported for data type {:?}" , t)
227
242
}
228
243
}
@@ -2013,4 +2028,41 @@ mod tests {
2013
2028
let values = r. as_string :: < i32 > ( ) . iter ( ) . collect :: < Vec < _ > > ( ) ;
2014
2029
assert_eq ! ( & values, & [ Some ( "foo" ) , None , None , None ] )
2015
2030
}
2031
+
2032
+ #[ test]
2033
+ fn test_take_union ( ) {
2034
+ let structs = create_test_struct ( vec ! [
2035
+ Some ( ( Some ( true ) , Some ( 42 ) ) ) ,
2036
+ Some ( ( Some ( false ) , Some ( 28 ) ) ) ,
2037
+ Some ( ( Some ( false ) , Some ( 19 ) ) ) ,
2038
+ Some ( ( Some ( true ) , Some ( 31 ) ) ) ,
2039
+ None ,
2040
+ ] ) ;
2041
+ let strings =
2042
+ StringArray :: from ( vec ! [ Some ( "a" ) , None , Some ( "c" ) , None , Some ( "d" ) ] ) ;
2043
+ let type_ids = Buffer :: from_slice_ref ( vec ! [ 1i8 ; 5 ] ) ;
2044
+
2045
+ let children: Vec < ( Field , Arc < dyn Array > ) > = vec ! [
2046
+ (
2047
+ Field :: new( "f1" , structs. data_type( ) . clone( ) , true ) ,
2048
+ Arc :: new( structs) ,
2049
+ ) ,
2050
+ (
2051
+ Field :: new( "f2" , strings. data_type( ) . clone( ) , true ) ,
2052
+ Arc :: new( strings) ,
2053
+ ) ,
2054
+ ] ;
2055
+ let array = UnionArray :: try_new ( & [ 0 , 1 ] , type_ids, None , children) . unwrap ( ) ;
2056
+
2057
+ let indices = vec ! [ 0 , 3 , 1 , 0 , 2 , 4 ] ;
2058
+ let index = UInt32Array :: from ( indices. clone ( ) ) ;
2059
+ let actual = take ( & array, & index, None ) . unwrap ( ) ;
2060
+ let actual = actual. as_any ( ) . downcast_ref :: < UnionArray > ( ) . unwrap ( ) ;
2061
+ let strings = actual. child ( 1 ) ;
2062
+ let strings = strings. as_any ( ) . downcast_ref :: < StringArray > ( ) . unwrap ( ) ;
2063
+
2064
+ let actual = strings. iter ( ) . collect :: < Vec < _ > > ( ) ;
2065
+ let expected = vec ! [ Some ( "a" ) , None , None , Some ( "a" ) , Some ( "c" ) , Some ( "d" ) ] ;
2066
+ assert_eq ! ( expected, actual) ;
2067
+ }
2016
2068
}
0 commit comments