Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,52 @@ async fn unparse_cross_join() -> Result<()> {

Ok(())
}

// Issue #20437: https://github.com/apache/datafusion/issues/20437
#[tokio::test]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to keep unit tests at a minimum when possible. Use sqllogictests instead here

async fn test_hash_join_multi_key_dictionary_encoded() -> Result<()> {
let ctx = SessionContext::new();

ctx.sql(
"CREATE TABLE small AS
SELECT id, arrow_cast(region, 'Dictionary(Int32, Utf8)') AS region
FROM (VALUES (1, 'west'), (2, 'west')) AS t(id, region)",
)
.await?
.collect()
.await?;

ctx.sql(
"CREATE TABLE large AS
SELECT id, region, value
FROM (VALUES (1, 'west', 100), (2, 'west', 200), (3, 'east', 300)) AS t(id, region, value)",
)
.await?
.collect()
.await?;

let results = ctx
.sql(
"SELECT s.id, s.region, l.value
FROM small s
JOIN large l ON s.id = l.id AND s.region = l.region
ORDER BY s.id",
)
.await?
.collect()
.await?;

assert_batches_eq!(
[
"+----+--------+-------+",
"| id | region | value |",
"+----+--------+-------+",
"| 1 | west | 100 |",
"| 2 | west | 200 |",
"+----+--------+-------+",
],
&results
);

Ok(())
}
58 changes: 49 additions & 9 deletions datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
use std::sync::Arc;

use arrow::array::{ArrayRef, StructArray};
use arrow::compute::cast;
use arrow::datatypes::{Field, FieldRef, Fields};
use arrow::downcast_dictionary_array;
use arrow_schema::DataType;
use datafusion_common::Result;

Expand All @@ -33,15 +33,16 @@ pub(super) fn build_struct_fields(data_types: &[DataType]) -> Result<Fields> {
.collect()
}

/// Flattens dictionary-encoded arrays to their underlying value arrays.
/// Casts dictionary-encoded arrays to their underlying value type, preserving row count.
/// Non-dictionary arrays are returned as-is.
fn flatten_dictionary_array(array: &ArrayRef) -> ArrayRef {
downcast_dictionary_array! {
array => {
fn flatten_dictionary_array(array: &ArrayRef) -> Result<ArrayRef> {
match array.data_type() {
DataType::Dictionary(_, value_type) => {
let casted = cast(array, value_type)?;
// Recursively flatten in case of nested dictionaries
flatten_dictionary_array(array.values())
flatten_dictionary_array(&casted)
}
_ => Arc::clone(array)
_ => Ok(Arc::clone(array)),
}
}

Expand All @@ -68,7 +69,7 @@ pub(super) fn build_struct_inlist_values(
let flattened_arrays: Vec<ArrayRef> = join_key_arrays
.iter()
.map(flatten_dictionary_array)
.collect();
.collect::<Result<Vec<_>>>()?;

// Build the source array/struct
let source_array: ArrayRef = if flattened_arrays.len() == 1 {
Expand Down Expand Up @@ -99,7 +100,9 @@ pub(super) fn build_struct_inlist_values(
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{Int32Array, StringArray};
use arrow::array::{
DictionaryArray, Int8Array, Int32Array, StringArray, StringDictionaryBuilder,
};
use arrow_schema::DataType;
use std::sync::Arc;

Expand Down Expand Up @@ -130,4 +133,41 @@ mod tests {
)
);
}

#[test]
fn test_build_multi_column_inlist_with_dictionary() {
let mut builder = StringDictionaryBuilder::<arrow::datatypes::Int8Type>::new();
builder.append_value("foo");
builder.append_value("foo");
builder.append_value("foo");
let dict_array = Arc::new(builder.finish()) as ArrayRef;

let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;

let result = build_struct_inlist_values(&[dict_array, int_array])
.unwrap()
.unwrap();

assert_eq!(result.len(), 3);
assert_eq!(
*result.data_type(),
DataType::Struct(
build_struct_fields(&[DataType::Utf8, DataType::Int32]).unwrap()
)
);
}

#[test]
fn test_build_single_column_dictionary_inlist() {
let keys = Int8Array::from(vec![0i8, 0, 0]);
let values = Arc::new(StringArray::from(vec!["foo"]));
let dict_array = Arc::new(DictionaryArray::new(keys, values)) as ArrayRef;

let result = build_struct_inlist_values(std::slice::from_ref(&dict_array))
.unwrap()
.unwrap();

assert_eq!(result.len(), 3);
assert_eq!(*result.data_type(), DataType::Utf8);
}
}