diff --git a/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs b/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs index 7dccc5b0ba7c..9bf59d9e333d 100644 --- a/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs +++ b/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs @@ -20,8 +20,8 @@ use std::sync::Arc; use arrow::array::{ArrayRef, StructArray}; +use arrow::compute::cast; use arrow::datatypes::{Field, FieldRef, Fields}; -use arrow::downcast_dictionary_array; use arrow_schema::DataType; use datafusion_common::Result; @@ -33,15 +33,16 @@ pub(super) fn build_struct_fields(data_types: &[DataType]) -> Result { .collect() } -/// Flattens dictionary-encoded arrays to their underlying value arrays. +/// Casts dictionary-encoded arrays to their underlying value type, preserving row count. /// Non-dictionary arrays are returned as-is. -fn flatten_dictionary_array(array: &ArrayRef) -> ArrayRef { - downcast_dictionary_array! { - array => { +fn flatten_dictionary_array(array: &ArrayRef) -> Result { + match array.data_type() { + DataType::Dictionary(_, value_type) => { + let casted = cast(array, value_type)?; // Recursively flatten in case of nested dictionaries - flatten_dictionary_array(array.values()) + flatten_dictionary_array(&casted) } - _ => Arc::clone(array) + _ => Ok(Arc::clone(array)), } } @@ -68,7 +69,7 @@ pub(super) fn build_struct_inlist_values( let flattened_arrays: Vec = join_key_arrays .iter() .map(flatten_dictionary_array) - .collect(); + .collect::>>()?; // Build the source array/struct let source_array: ArrayRef = if flattened_arrays.len() == 1 { @@ -99,7 +100,9 @@ pub(super) fn build_struct_inlist_values( #[cfg(test)] mod tests { use super::*; - use arrow::array::{Int32Array, StringArray}; + use arrow::array::{ + DictionaryArray, Int8Array, Int32Array, StringArray, StringDictionaryBuilder, + }; use arrow_schema::DataType; use std::sync::Arc; @@ -130,4 +133,41 @@ mod tests { ) ); } + + #[test] + fn test_build_multi_column_inlist_with_dictionary() { + let mut builder = StringDictionaryBuilder::::new(); + builder.append_value("foo"); + builder.append_value("foo"); + builder.append_value("foo"); + let dict_array = Arc::new(builder.finish()) as ArrayRef; + + let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + + let result = build_struct_inlist_values(&[dict_array, int_array]) + .unwrap() + .unwrap(); + + assert_eq!(result.len(), 3); + assert_eq!( + *result.data_type(), + DataType::Struct( + build_struct_fields(&[DataType::Utf8, DataType::Int32]).unwrap() + ) + ); + } + + #[test] + fn test_build_single_column_dictionary_inlist() { + let keys = Int8Array::from(vec![0i8, 0, 0]); + let values = Arc::new(StringArray::from(vec!["foo"])); + let dict_array = Arc::new(DictionaryArray::new(keys, values)) as ArrayRef; + + let result = build_struct_inlist_values(std::slice::from_ref(&dict_array)) + .unwrap() + .unwrap(); + + assert_eq!(result.len(), 3); + assert_eq!(*result.data_type(), DataType::Utf8); + } } diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 2fb544a638d6..282d7c374bdd 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -5290,3 +5290,31 @@ DROP TABLE empty_proj_left; statement count 0 DROP TABLE empty_proj_right; + +# Issue #20437: HashJoin panic with dictionary-encoded columns in multi-key joins +# https://github.com/apache/datafusion/issues/20437 + +statement ok +CREATE TABLE issue_20437_small AS +SELECT id, arrow_cast(region, 'Dictionary(Int32, Utf8)') AS region +FROM (VALUES (1, 'west'), (2, 'west')) AS t(id, region); + +statement ok +CREATE TABLE issue_20437_large AS +SELECT id, region, value +FROM (VALUES (1, 'west', 100), (2, 'west', 200), (3, 'east', 300)) AS t(id, region, value); + +query ITI +SELECT s.id, s.region, l.value +FROM issue_20437_small s +JOIN issue_20437_large l ON s.id = l.id AND s.region = l.region +ORDER BY s.id; +---- +1 west 100 +2 west 200 + +statement count 0 +DROP TABLE issue_20437_small; + +statement count 0 +DROP TABLE issue_20437_large;