Skip to content

Commit 4320a75

Browse files
Implement Take for UnionArray (#4883)
Implement Take for UnionArray (#4883)
1 parent 39e4d94 commit 4320a75

File tree

1 file changed

+53
-1
lines changed

1 file changed

+53
-1
lines changed

arrow-select/src/take.rs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use arrow_buffer::{
2828
ScalarBuffer,
2929
};
3030
use arrow_data::{ArrayData, ArrayDataBuilder};
31-
use arrow_schema::{ArrowError, DataType, FieldRef};
31+
use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
3232

3333
use num::{One, Zero};
3434

@@ -223,6 +223,21 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
223223
Ok(new_null_array(&DataType::Null, indices.len()))
224224
}
225225
}
226+
DataType::Union(fields, UnionMode::Sparse) => {
227+
let mut field_type_ids = Vec::with_capacity(fields.len());
228+
let mut children = Vec::with_capacity(fields.len());
229+
let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
230+
let type_ids = take_native(values.type_ids(), indices).into_inner();
231+
for (type_id, field) in fields.iter() {
232+
let values = values.child(type_id);
233+
let values = take_impl(values, indices)?;
234+
let field = (**field).clone();
235+
children.push((field, values));
236+
field_type_ids.push(type_id);
237+
}
238+
let array = UnionArray::try_new(field_type_ids.as_slice(), type_ids, None, children)?;
239+
Ok(Arc::new(array))
240+
}
226241
t => unimplemented!("Take not supported for data type {:?}", t)
227242
}
228243
}
@@ -2013,4 +2028,41 @@ mod tests {
20132028
let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
20142029
assert_eq!(&values, &[Some("foo"), None, None, None])
20152030
}
2031+
2032+
// Exercises `take` on a `UnionArray` built from two children (a struct array
// and a string array) and checks the string child of the result.
// NOTE(review): only child(1) is asserted; the resulting type ids and the
// struct child are not verified by this test.
#[test]
2033+
fn test_take_union() {
2034+
// First child: five struct rows, the last of which is null.
let structs = create_test_struct(vec![
2035+
Some((Some(true), Some(42))),
2036+
Some((Some(false), Some(28))),
2037+
Some((Some(false), Some(19))),
2038+
Some((Some(true), Some(31))),
2039+
None,
2040+
]);
2041+
// Second child: five string slots, with nulls at positions 1 and 3.
let strings =
2042+
StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2043+
// Every one of the five slots carries type id 1.
// Presumably id 1 maps to the "f2" (string) child given the `&[0, 1]` ids below; confirm.
let type_ids = Buffer::from_slice_ref(vec![1i8; 5]);
2044+
2045+
// Both children have length 5, the same as `type_ids`.
let children: Vec<(Field, Arc<dyn Array>)> = vec![
2046+
(
2047+
Field::new("f1", structs.data_type().clone(), true),
2048+
Arc::new(structs),
2049+
),
2050+
(
2051+
Field::new("f2", strings.data_type().clone(), true),
2052+
Arc::new(strings),
2053+
),
2054+
];
2055+
// `None` is passed for the third argument (no offsets buffer).
// Presumably this yields a sparse union, the layout the new `take_impl` arm handles; confirm.
let array = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap();
2056+
2057+
// Indices repeat (0 appears twice) and are out of order; six outputs from five inputs.
let indices = vec![0, 3, 1, 0, 2, 4];
2058+
let index = UInt32Array::from(indices.clone());
2059+
let actual = take(&array, &index, None).unwrap();
2060+
let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2061+
// Pull out the child registered under type id 1 and view it as strings.
let strings = actual.child(1);
2062+
let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2063+
2064+
let actual = strings.iter().collect::<Vec<_>>();
2065+
// Source strings are [a, null, c, null, d]; gathering positions
// [0, 3, 1, 0, 2, 4] gives [a, null, null, a, c, d].
let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2066+
assert_eq!(expected, actual);
2067+
}
20162068
}

0 commit comments

Comments
 (0)