Skip to content

Commit ae0f401

Browse files
authored
Add support for FixedSizeList type in arrow_cast, hashing (apache#8344)
* Add support for parsing FixedSizeList type * fix fmt * support cast fixedsizelist from list * clean comment * support cast between NULL and FixedSizeList * add test for FixedSizeList hash * add test for cast fixedsizelist
1 parent 72b81f1 commit ae0f401

File tree

6 files changed

+155
-8
lines changed

6 files changed

+155
-8
lines changed

datafusion/common/src/hash_utils.rs

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array};
2727
use arrow_buffer::i256;
2828

2929
use crate::cast::{
30-
as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array,
31-
as_primitive_array, as_string_array, as_struct_array,
30+
as_boolean_array, as_fixed_size_list_array, as_generic_binary_array,
31+
as_large_list_array, as_list_array, as_primitive_array, as_string_array,
32+
as_struct_array,
3233
};
3334
use crate::error::{DataFusionError, Result, _internal_err};
3435

@@ -267,6 +268,38 @@ where
267268
Ok(())
268269
}
269270

271+
fn hash_fixed_list_array(
272+
array: &FixedSizeListArray,
273+
random_state: &RandomState,
274+
hashes_buffer: &mut [u64],
275+
) -> Result<()> {
276+
let values = array.values().clone();
277+
let value_len = array.value_length();
278+
let offset_size = value_len as usize / array.len();
279+
let nulls = array.nulls();
280+
let mut values_hashes = vec![0u64; values.len()];
281+
create_hashes(&[values], random_state, &mut values_hashes)?;
282+
if let Some(nulls) = nulls {
283+
for i in 0..array.len() {
284+
if nulls.is_valid(i) {
285+
let hash = &mut hashes_buffer[i];
286+
for values_hash in &values_hashes[i * offset_size..(i + 1) * offset_size]
287+
{
288+
*hash = combine_hashes(*hash, *values_hash);
289+
}
290+
}
291+
}
292+
} else {
293+
for i in 0..array.len() {
294+
let hash = &mut hashes_buffer[i];
295+
for values_hash in &values_hashes[i * offset_size..(i + 1) * offset_size] {
296+
*hash = combine_hashes(*hash, *values_hash);
297+
}
298+
}
299+
}
300+
Ok(())
301+
}
302+
270303
/// Test version of `create_hashes` that produces the same value for
271304
/// all hashes (to test collisions)
272305
///
@@ -366,6 +399,10 @@ pub fn create_hashes<'a>(
366399
let array = as_large_list_array(array)?;
367400
hash_list_array(array, random_state, hashes_buffer)?;
368401
}
402+
DataType::FixedSizeList(_,_) => {
403+
let array = as_fixed_size_list_array(array)?;
404+
hash_fixed_list_array(array, random_state, hashes_buffer)?;
405+
}
369406
_ => {
370407
// This is internal because we should have caught this before.
371408
return _internal_err!(
@@ -546,6 +583,30 @@ mod tests {
546583
assert_eq!(hashes[2], hashes[3]);
547584
}
548585

586+
#[test]
// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
fn create_hashes_for_fixed_size_list_arrays() {
    // Six rows of fixed-size lists (3 items each); rows (0, 5), (1, 4) and
    // (2, 3) have identical contents, including nulls.
    let rows = vec![
        Some(vec![Some(0), Some(1), Some(2)]),
        None,
        Some(vec![Some(3), None, Some(5)]),
        Some(vec![Some(3), None, Some(5)]),
        None,
        Some(vec![Some(0), Some(1), Some(2)]),
    ];
    let fixed_size_list =
        FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(rows, 3);
    let list_array: ArrayRef = Arc::new(fixed_size_list);

    let mut hashes = vec![0; list_array.len()];
    let random_state = RandomState::with_seeds(0, 0, 0, 0);
    create_hashes(&[list_array], &random_state, &mut hashes).unwrap();

    // Rows with identical contents must produce identical hashes.
    for (left, right) in [(0, 5), (1, 4), (2, 3)] {
        assert_eq!(hashes[left], hashes[right]);
    }
}
609+
549610
#[test]
550611
// Tests actual values of hashes, which are different if forcing collisions
551612
#[cfg(not(feature = "force_hash_collisions"))]

datafusion/common/src/scalar.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ use crate::cast::{
3434
};
3535
use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err};
3636
use crate::hash_utils::create_hashes;
37-
use crate::utils::{array_into_large_list_array, array_into_list_array};
38-
37+
use crate::utils::{
38+
array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array,
39+
};
3940
use arrow::compute::kernels::numeric::*;
4041
use arrow::util::display::{ArrayFormatter, FormatOptions};
4142
use arrow::{
@@ -2223,9 +2224,11 @@ impl ScalarValue {
22232224
let list_array = as_fixed_size_list_array(array)?;
22242225
let nested_array = list_array.value(index);
22252226
// Produces a single element `FixedSizeListArray` with the value at `index`.
2226-
let arr = Arc::new(array_into_list_array(nested_array));
2227+
let list_size = nested_array.len();
2228+
let arr =
2229+
Arc::new(array_into_fixed_size_list_array(nested_array, list_size));
22272230

2228-
ScalarValue::List(arr)
2231+
ScalarValue::FixedSizeList(arr)
22292232
}
22302233
DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?,
22312234
DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?,
@@ -2971,6 +2974,19 @@ impl TryFrom<&DataType> for ScalarValue {
29712974
.to_owned()
29722975
.into(),
29732976
),
2977+
// `ScalarValue::FixedSizeList` contains a single-element `FixedSizeList`.
2978+
DataType::FixedSizeList(field, _) => ScalarValue::FixedSizeList(
2979+
new_null_array(
2980+
&DataType::FixedSizeList(
2981+
Arc::new(Field::new("item", field.data_type().clone(), true)),
2982+
1,
2983+
),
2984+
1,
2985+
)
2986+
.as_fixed_size_list()
2987+
.to_owned()
2988+
.into(),
2989+
),
29742990
DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()),
29752991
DataType::Null => ScalarValue::Null,
29762992
_ => {

datafusion/common/src/utils.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ use arrow::compute;
2525
use arrow::compute::{partition, SortColumn, SortOptions};
2626
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
2727
use arrow::record_batch::RecordBatch;
28-
use arrow_array::{Array, LargeListArray, ListArray, RecordBatchOptions};
28+
use arrow_array::{
29+
Array, FixedSizeListArray, LargeListArray, ListArray, RecordBatchOptions,
30+
};
2931
use arrow_schema::DataType;
3032
use sqlparser::ast::Ident;
3133
use sqlparser::dialect::GenericDialect;
@@ -368,6 +370,19 @@ pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray {
368370
)
369371
}
370372

373+
pub fn array_into_fixed_size_list_array(
374+
arr: ArrayRef,
375+
list_size: usize,
376+
) -> FixedSizeListArray {
377+
let list_size = list_size as i32;
378+
FixedSizeListArray::new(
379+
Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
380+
list_size,
381+
arr,
382+
None,
383+
)
384+
}
385+
371386
/// Wrap arrays into a single element `ListArray`.
372387
///
373388
/// Example:

datafusion/expr/src/utils.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,7 @@ pub fn can_hash(data_type: &DataType) -> bool {
911911
}
912912
DataType::List(_) => true,
913913
DataType::LargeList(_) => true,
914+
DataType::FixedSizeList(_, _) => true,
914915
_ => false,
915916
}
916917
}

datafusion/sql/src/expr/arrow_cast.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ impl<'a> Parser<'a> {
150150
Token::Dictionary => self.parse_dictionary(),
151151
Token::List => self.parse_list(),
152152
Token::LargeList => self.parse_large_list(),
153+
Token::FixedSizeList => self.parse_fixed_size_list(),
153154
tok => Err(make_error(
154155
self.val,
155156
&format!("finding next type, got unexpected '{tok}'"),
@@ -177,6 +178,19 @@ impl<'a> Parser<'a> {
177178
))))
178179
}
179180

181+
/// Parses the FixedSizeList type
182+
fn parse_fixed_size_list(&mut self) -> Result<DataType> {
183+
self.expect_token(Token::LParen)?;
184+
let length = self.parse_i32("FixedSizeList")?;
185+
self.expect_token(Token::Comma)?;
186+
let data_type = self.parse_next_type()?;
187+
self.expect_token(Token::RParen)?;
188+
Ok(DataType::FixedSizeList(
189+
Arc::new(Field::new("item", data_type, true)),
190+
length,
191+
))
192+
}
193+
180194
/// Parses the next timeunit
181195
fn parse_time_unit(&mut self, context: &str) -> Result<TimeUnit> {
182196
match self.next_token()? {
@@ -508,6 +522,7 @@ impl<'a> Tokenizer<'a> {
508522

509523
"List" => Token::List,
510524
"LargeList" => Token::LargeList,
525+
"FixedSizeList" => Token::FixedSizeList,
511526

512527
"Second" => Token::TimeUnit(TimeUnit::Second),
513528
"Millisecond" => Token::TimeUnit(TimeUnit::Millisecond),
@@ -598,6 +613,7 @@ enum Token {
598613
DoubleQuotedString(String),
599614
List,
600615
LargeList,
616+
FixedSizeList,
601617
}
602618

603619
impl Display for Token {
@@ -606,6 +622,7 @@ impl Display for Token {
606622
Token::SimpleType(t) => write!(f, "{t}"),
607623
Token::List => write!(f, "List"),
608624
Token::LargeList => write!(f, "LargeList"),
625+
Token::FixedSizeList => write!(f, "FixedSizeList"),
609626
Token::Timestamp => write!(f, "Timestamp"),
610627
Token::Time32 => write!(f, "Time32"),
611628
Token::Time64 => write!(f, "Time64"),

datafusion/sqllogictest/test_files/arrow_typeof.slt

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,4 +384,41 @@ LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, di
384384
query T
385385
select arrow_typeof(arrow_cast(make_array([1, 2, 3]), 'LargeList(LargeList(Int64))'));
386386
----
387-
LargeList(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
387+
LargeList(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
388+
389+
## FixedSizeList
390+
391+
query ?
392+
select arrow_cast(null, 'FixedSizeList(1, Int64)');
393+
----
394+
NULL
395+
396+
#TODO: arrow-rs doesn't support it yet
397+
#query ?
398+
#select arrow_cast('1', 'FixedSizeList(1, Int64)');
399+
#----
400+
#[1]
401+
402+
403+
query ?
404+
select arrow_cast([1], 'FixedSizeList(1, Int64)');
405+
----
406+
[1]
407+
408+
query error DataFusion error: Optimizer rule 'simplify_expressions' failed
409+
select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(4, Int64)');
410+
411+
query ?
412+
select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)');
413+
----
414+
[1, 2, 3]
415+
416+
query T
417+
select arrow_typeof(arrow_cast(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 'FixedSizeList(3, Int64)'));
418+
----
419+
FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3)
420+
421+
query ?
422+
select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)');
423+
----
424+
[1, 2, 3]

0 commit comments

Comments
 (0)