Skip to content

support merging primitive dictionaries in interleave and concat #7468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 121 additions & 10 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@
//! assert_eq!(arr.len(), 3);
//! ```

use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values};
use arrow_array::builder::{BooleanBuilder, GenericByteBuilder, PrimitiveBuilder};
use crate::dictionary::{
merge_dictionary_values, should_merge_dictionary_values, ShouldMergeValues,
};
use arrow_array::builder::{
BooleanBuilder, GenericByteBuilder, PrimitiveBuilder, PrimitiveDictionaryBuilder,
};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
Expand Down Expand Up @@ -84,6 +88,7 @@ fn fixed_size_list_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capa
}

fn concat_dictionaries<K: ArrowDictionaryKeyType>(
value_type: &DataType,
arrays: &[&dyn Array],
) -> Result<ArrayRef, ArrowError> {
let mut output_len = 0;
Expand All @@ -93,11 +98,41 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
.inspect(|d| output_len += d.len())
.collect();

if !should_merge_dictionary_values::<K>(&dictionaries, output_len) {
return concat_fallback(arrays, Capacities::Array(output_len));
let is_overflow = match should_merge_dictionary_values::<K>(&dictionaries, output_len) {
ShouldMergeValues::ConcatWillOverflow => true,
ShouldMergeValues::Yes => false,
ShouldMergeValues::No => {
return concat_fallback(arrays, Capacities::Array(output_len));
}
};

macro_rules! primitive_dict_helper {
($t:ty) => {
merge_concat_primitive_dictionaries::<K, $t>(&dictionaries, output_len)
};
}

let merged = merge_dictionary_values(&dictionaries, None)?;
downcast_primitive! {
value_type => (primitive_dict_helper),
DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
merge_concat_byte_dictionaries(&dictionaries, output_len)
},
// merge not yet implemented for this type and it's not going to overflow, so fall back
// to concatenating values
_ if !is_overflow => concat_fallback(arrays, Capacities::Array(output_len)),
other => Err(ArrowError::NotYetImplemented(format!(
"concat of dictionaries would overflow key type {key_type:?} and \
value type {other:?} not yet supported for merging",
key_type = K::DATA_TYPE,
)))
}
}

fn merge_concat_byte_dictionaries<K: ArrowDictionaryKeyType>(
dictionaries: &[&DictionaryArray<K>],
output_len: usize,
) -> Result<ArrayRef, ArrowError> {
let merged = merge_dictionary_values(dictionaries, None)?;

// Recompute keys
let mut key_values = Vec::with_capacity(output_len);
Expand All @@ -113,7 +148,7 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(

let nulls = has_nulls.then(|| {
let mut nulls = BooleanBufferBuilder::new(output_len);
for d in &dictionaries {
for d in dictionaries {
match d.nulls() {
Some(n) => nulls.append_buffer(n.inner()),
None => nulls.append_n(d.len(), true),
Expand All @@ -130,6 +165,19 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
Ok(Arc::new(array))
}

/// Concatenates primitive-valued dictionaries by re-encoding every logical
/// value through a [`PrimitiveDictionaryBuilder`], which deduplicates the
/// merged dictionary values as a side effect of appending.
fn merge_concat_primitive_dictionaries<K: ArrowDictionaryKeyType, V: ArrowPrimitiveType>(
    dictionaries: &[&DictionaryArray<K>],
    output_len: usize,
) -> Result<ArrayRef, ArrowError> {
    let mut builder = PrimitiveDictionaryBuilder::<K, V>::with_capacity(output_len, 0);
    // Flatten all inputs into one stream of logical values, preserving order.
    let logical_values = dictionaries
        .iter()
        .flat_map(|dict| dict.downcast_dict::<PrimitiveArray<V>>().unwrap());
    for value in logical_values {
        builder.append_option(value);
    }
    Ok(Arc::new(builder.finish()))
}

fn concat_lists<OffsetSize: OffsetSizeTrait>(
arrays: &[&dyn Array],
field: &FieldRef,
Expand Down Expand Up @@ -231,8 +279,8 @@ fn concat_bytes<T: ByteArrayType>(arrays: &[&dyn Array]) -> Result<ArrayRef, Arr
}

macro_rules! dict_helper {
($t:ty, $arrays:expr) => {
return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _)
($t:ty, $value_type:expr, $arrays:expr) => {
concat_dictionaries::<$t>($value_type.as_ref(), $arrays)
};
}

Expand Down Expand Up @@ -300,9 +348,9 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
downcast_primitive! {
d => (primitive_concat, arrays),
DataType::Boolean => concat_boolean(arrays),
DataType::Dictionary(k, _) => {
DataType::Dictionary(k, v) => {
downcast_integer! {
k.as_ref() => (dict_helper, arrays),
k.as_ref() => (dict_helper, v, arrays),
_ => unreachable!("illegal dictionary key type {k}")
}
}
Expand Down Expand Up @@ -938,6 +986,69 @@ mod tests {
assert!((30..40).contains(&values_len), "{values_len}")
}

#[test]
fn test_concat_dictionary_overflows() {
    // Each input array is as long as the entire Int8 key space, so a naive
    // concatenation of dictionary values could overflow the key type.
    let len = usize::try_from(i8::MAX).unwrap();

    let make_dict = |value: i8| {
        DictionaryArray::<Int8Type>::new(
            Int8Array::from_value(0, len),
            Arc::new(Int8Array::from_value(value, len)),
        )
    };
    let a = make_dict(0);
    let b = make_dict(1);

    // Case 1: a single input array can never overflow the key space.
    let result = concat(&[&a]).unwrap();
    let dict = result.as_dictionary::<Int8Type>();
    let typed = dict.downcast_dict::<Int8Array>().unwrap();
    let flattened: Vec<_> = typed.into_iter().map(Option::unwrap).collect();
    assert_eq!(flattened, vec![0; len]);

    // Case 2: two arrays holding only two distinct values between them —
    // merging the dictionaries keeps the keys in range.
    let result = concat(&[&a, &b]).unwrap();
    let dict = result.as_dictionary::<Int8Type>();
    let typed = dict.downcast_dict::<Int8Array>().unwrap();
    let flattened: Vec<_> = typed.into_iter().map(Option::unwrap).collect();
    let mut expected = vec![0i8; len];
    expected.extend(vec![1i8; len]);
    assert_eq!(flattened, expected);
}

#[test]
fn test_unsupported_concat_dictionary_overflow() {
    // Each input array is as long as the entire Int8 key space; Null-typed
    // values have no merge implementation, so an overflow must surface as
    // a graceful error rather than a panic.
    let len = usize::try_from(i8::MAX).unwrap();

    let make_dict = || {
        DictionaryArray::<Int8Type>::new(
            Int8Array::from_value(0, len),
            Arc::new(NullArray::new(len)),
        )
    };
    let a = make_dict();
    let b = make_dict();

    // Case 1: a single input array can never overflow, so this succeeds.
    concat(&[&a]).unwrap();

    // Case 2: two arrays would overflow the Int8 key space, and merging
    // Null values is unsupported, so concat reports the limitation.
    let err = concat(&[&a, &b]).unwrap_err();
    assert_eq!(
        err.to_string(),
        "Not yet implemented: concat of dictionaries would overflow key type Int8 and \
        value type Null not yet supported for merging"
    );
}

#[test]
fn test_concat_string_sizes() {
let a: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();
Expand Down
37 changes: 28 additions & 9 deletions arrow-select/src/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,19 @@ fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
}
}

/// Whether selection kernels should attempt to merge dictionary values
// Standard derives so callers can debug-print, copy, and compare the
// heuristic's verdict; all variants are unit-like so `Copy` is free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShouldMergeValues {
    /// Concatenation of the dictionary values will lead to overflowing
    /// the key space; it's necessary to attempt to merge
    ConcatWillOverflow,
    /// The heuristic suggests that merging will be beneficial
    Yes,
    /// The heuristic suggests that merging is not necessary
    No,
}

/// A type-erased function that compares two array for pointer equality
type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool;
type PtrEq = fn(&dyn Array, &dyn Array) -> bool;

/// A weak heuristic of whether to merge dictionary values that aims to only
/// perform the expensive merge computation when it is likely to yield at least
Expand All @@ -112,15 +123,15 @@ type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool;
pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
dictionaries: &[&DictionaryArray<K>],
len: usize,
) -> bool {
) -> ShouldMergeValues {
use DataType::*;
let first_values = dictionaries[0].values().as_ref();
let ptr_eq: Box<PtrEq> = match first_values.data_type() {
Utf8 => Box::new(bytes_ptr_eq::<Utf8Type>),
LargeUtf8 => Box::new(bytes_ptr_eq::<LargeUtf8Type>),
Binary => Box::new(bytes_ptr_eq::<BinaryType>),
LargeBinary => Box::new(bytes_ptr_eq::<LargeBinaryType>),
_ => return false,
let ptr_eq: PtrEq = match first_values.data_type() {
Utf8 => bytes_ptr_eq::<Utf8Type>,
LargeUtf8 => bytes_ptr_eq::<LargeUtf8Type>,
Binary => bytes_ptr_eq::<BinaryType>,
LargeBinary => bytes_ptr_eq::<LargeBinaryType>,
_ => |_, _| false,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing this line allows should_merge_dictionary_values to return true for the primitive data types, but it also introduces a regression for weird types like dictionaries-of-unions. This heuristic might ask for these to merge if the values exceed the expected output length, which will now fail. Before they might be able to concat / interleave via the fallback methods even if merging is not supported (however they could easily hit the panic case).

I think probably the solution is to adjust this heuristic so it outputs why it needs to merge (e.g. must merge to attempt to avoid overflow) - and in callers, for overflow we would fail gracefully for types which can't yet be merged.

};

let mut single_dictionary = true;
Expand All @@ -136,7 +147,15 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
let overflow = K::Native::from_usize(total_values).is_none();
let values_exceed_length = total_values >= len;

!single_dictionary && (overflow || values_exceed_length)
if single_dictionary {
ShouldMergeValues::No
} else if overflow {
ShouldMergeValues::ConcatWillOverflow
} else if values_exceed_length {
ShouldMergeValues::Yes
} else {
ShouldMergeValues::No
}
}

/// Given an array of dictionaries and an optional key mask compute a values array
Expand Down
Loading
Loading