Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions datafusion/common/src/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1852,6 +1852,8 @@ mod tests {
assert_eq!(hashes1, hashes2);
}

// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
#[test]
#[cfg(not(feature = "force_hash_collisions"))]
fn test_create_hashes_with_quality_hash_state() {
Expand Down
97 changes: 97 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,55 @@ async fn test_udaf() {
assert!(!test_state.retract_batch());
}

#[tokio::test]
async fn test_zero_argument_udaf() {
let TestContext { mut ctx, .. } = TestContext::new();
NullaryAccumulator::register(&mut ctx, "window_start");

let actual = execute(&ctx, "SELECT window_start() from t").await.unwrap();

insta::assert_snapshot!(batches_to_string(&actual), @r"
+----------------+
| window_start() |
+----------------+
| 7 |
+----------------+
");

let actual = execute(
&ctx,
"SELECT window_start() FILTER (WHERE value < 0.0) FROM t",
)
.await
.unwrap();

insta::assert_snapshot!(batches_to_string(&actual), @r"
+----------------------------------------------------+
| window_start() FILTER (WHERE t.value < Float64(0)) |
+----------------------------------------------------+
| |
+----------------------------------------------------+
");

let actual = execute(
&ctx,
"SELECT value, window_start() FROM t GROUP BY value ORDER BY value",
)
.await
.unwrap();

insta::assert_snapshot!(batches_to_string(&actual), @r"
+-------+----------------+
| value | window_start() |
+-------+----------------+
| 1.0 | 7 |
| 2.0 | 7 |
| 3.0 | 7 |
| 5.0 | 7 |
+-------+----------------+
");
}

/// User defined aggregate used as a window function
#[tokio::test]
async fn test_udaf_as_window() {
Expand Down Expand Up @@ -559,6 +608,54 @@ impl TestState {
}
}

#[derive(Debug, Default)]
struct NullaryAccumulator {
value: Option<i64>,
}

impl NullaryAccumulator {
fn register(ctx: &mut SessionContext, name: &str) {
let accumulator: AccumulatorFactoryFunction =
Arc::new(|_| Ok(Box::<Self>::default()));

let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
name,
Signature::nullary(Volatility::Immutable),
DataType::Int64,
accumulator,
vec![Field::new("value", DataType::Int64, true).into()],
));

ctx.register_udaf(udaf)
}
}

impl Accumulator for NullaryAccumulator {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
assert!(values.is_empty());
self.value = Some(7);
Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
assert_eq!(states.len(), 1);
self.value = Some(7);
Ok(())
}

fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(self.value))
}

fn size(&self) -> usize {
size_of_val(self)
}
}

/// Models a user defined aggregate function that computes the a sum
/// of timestamps (not a quantity that has much real world meaning)
#[derive(Debug)]
Expand Down
8 changes: 8 additions & 0 deletions datafusion/expr-common/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ pub fn check_arg_count(
);
}
}
TypeSignature::Nullary => {
if !input_fields.is_empty() {
return plan_err!(
"The function {func_name} expects zero arguments, but {} were provided",
input_fields.len()
);
}
}
TypeSignature::OneOf(variants) => {
let ok = variants
.iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use std::mem::{size_of, size_of_val};

use arrow::array::new_empty_array;
use arrow::{
array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray},
array::{Array, ArrayRef, AsArray, BooleanArray, PrimitiveArray},
compute,
compute::take_arrays,
datatypes::UInt32Type,
Expand Down Expand Up @@ -100,6 +100,12 @@ pub struct GroupsAccumulatorAdapter {
/// bottleneck in earlier implementations when there were many
/// distinct groups.
allocation_bytes: usize,

/// Whether this adapter can convert raw input rows directly to aggregate
/// state. Nullary aggregates do not carry a value array with the input row
/// cardinality, so this optimization cannot safely reconstruct one state
/// per input row for them.
supports_convert_to_state: bool,
}

struct AccumulatorState {
Expand Down Expand Up @@ -137,6 +143,24 @@ impl GroupsAccumulatorAdapter {
factory: Box::new(factory),
states: vec![],
allocation_bytes: 0,
supports_convert_to_state: true,
}
}

/// Create a new adapter with explicit control over whether row-to-state
/// conversion is supported.
pub fn new_with_convert_to_state<F>(
factory: F,
supports_convert_to_state: bool,
) -> Self
where
F: Fn() -> Result<Box<dyn Accumulator>> + Send + 'static,
{
Self {
factory: Box::new(factory),
states: vec![],
allocation_bytes: 0,
supports_convert_to_state,
}
}

Expand Down Expand Up @@ -196,13 +220,24 @@ impl GroupsAccumulatorAdapter {
{
self.make_accumulators_if_needed(total_num_groups)?;

assert_eq!(values[0].len(), group_indices.len());
if values.is_empty() {
for (idx, group_index) in group_indices.iter().enumerate() {
if opt_filter
.is_some_and(|filter| !filter.is_valid(idx) || !filter.value(idx))
{
continue;
}
self.states[*group_index].indices.push(idx as u32);
}
} else {
assert_eq!(values[0].len(), group_indices.len());

// figure out which input rows correspond to which groups.
// Note that self.state.indices starts empty for all groups
// (it is cleared out below)
for (idx, group_index) in group_indices.iter().enumerate() {
self.states[*group_index].indices.push(idx as u32);
// figure out which input rows correspond to which groups.
// Note that self.state.indices starts empty for all groups
// (it is cleared out below)
for (idx, group_index) in group_indices.iter().enumerate() {
self.states[*group_index].indices.push(idx as u32);
}
}

// groups_with_rows holds a list of group indexes that have
Expand Down Expand Up @@ -230,13 +265,18 @@ impl GroupsAccumulatorAdapter {
offset_so_far += indices.len();
offsets.push(offset_so_far);
}
let batch_indices = batch_indices.into();
let values_and_filter = if values.is_empty() {
None
} else {
let batch_indices = batch_indices.into();

// reorder the values and opt_filter by batch_indices so that
// all values for each group are contiguous, then invoke the
// accumulator once per group with values
let values = take_arrays(values, &batch_indices, None)?;
let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?;
// reorder the values and opt_filter by batch_indices so that
// all values for each group are contiguous, then invoke the
// accumulator once per group with values
let values = take_arrays(values, &batch_indices, None)?;
let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?;
Some((values, opt_filter))
};

// invoke each accumulator with the appropriate rows, first
// pulling the input arguments for this group into their own
Expand All @@ -249,11 +289,14 @@ impl GroupsAccumulatorAdapter {
let state = &mut self.states[group_idx];
sizes_pre += state.size();

let values_to_accumulate = slice_and_maybe_filter(
&values,
opt_filter.as_ref().map(|f| f.as_boolean()),
offsets,
)?;
let values_to_accumulate = match &values_and_filter {
Some((values, opt_filter)) => slice_and_maybe_filter(
values,
opt_filter.as_ref().map(|f| f.as_boolean()),
offsets,
)?,
None => vec![],
};
f(state.accumulator.as_mut(), &values_to_accumulate)?;

// clear out the state so they are empty for next
Expand Down Expand Up @@ -443,7 +486,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
}

fn supports_convert_to_state(&self) -> bool {
true
self.supports_convert_to_state
}
}

Expand Down
31 changes: 23 additions & 8 deletions datafusion/functions-aggregate/src/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,12 @@ mod tests {
hll.count() as u64
}

fn array_hashes(array: &ArrayRef) -> Vec<u64> {
let mut hashes = vec![0; array.len()];
create_hashes([array.as_ref()], &HLL_HASH_STATE, &mut hashes).unwrap();
hashes
}

fn serialize(g: &mut GroupHll) -> Vec<u8> {
let mut buf = Vec::new();
g.serialize(&mut buf);
Expand Down Expand Up @@ -992,17 +998,20 @@ mod tests {
BooleanArray::from(vec![Some(true), None, Some(false), None, Some(true)]);

let mut acc = HllGroupsAccumulator::new();
// put all rows in group 0
let group_indices = vec![0usize; 5];
acc.update_batch(&[values], &group_indices, Some(&filter), 1)
// Put true rows in group 0, and the null/false filter rows in group 1.
// The collision CI job forces every hash to 0, so group 1 is what proves
// filtered-out rows are not counted there.
let group_indices = vec![0usize, 1, 1, 1, 0];
let hashes = array_hashes(&values);
acc.update_batch(&[values], &group_indices, Some(&filter), 2)
.unwrap();

// Only rows 0 and 4 (values 1 and 5) should be counted.
let result = acc.evaluate(EmitTo::All).unwrap();
let counts = result.as_any().downcast_ref::<UInt64Array>().unwrap();
// reference: hash 1 and 5 into a dense sketch
let expected = reference_count(&[h(1), h(5)]);
let expected = reference_count(&[hashes[0], hashes[4]]);
assert_eq!(counts.value(0), expected);
assert_eq!(counts.value(1), 0);
}

/// Regression: a short (≤ 12-byte) Utf8View string must hash identically
Expand All @@ -1020,16 +1029,21 @@ mod tests {
assert!(!batch2.as_string_view().data_buffers().is_empty());

let group_indices = vec![0usize, 0];
let batch1_hashes = array_hashes(&batch1);
let batch2_hashes = array_hashes(&batch2);
let expected =
reference_count(&[batch1_hashes[0], batch1_hashes[1], batch2_hashes[1]]);
let mut acc = HllGroupsAccumulator::new();
acc.update_batch(&[batch1], &group_indices, None, 1)
.unwrap();
acc.update_batch(&[batch2], &group_indices, None, 1)
.unwrap();

// True distinct values: {"aaa", "bbb", LONG} == 3.
// Under normal hashing the true distinct values are {"aaa", "bbb", LONG};
// under collision testing they intentionally collapse to fewer hashes.
let result = acc.evaluate(EmitTo::All).unwrap();
let counts = result.as_any().downcast_ref::<UInt64Array>().unwrap();
assert_eq!(counts.value(0), 3);
assert_eq!(counts.value(0), expected);
}

/// Regression: a short (≤ 12-byte) Utf8View string must hash identically
Expand All @@ -1039,6 +1053,7 @@ mod tests {
// Multiset: {"aaa" x2, "bbb", LONG}, so 3 distinct values.
let mixed: ArrayRef =
Arc::new(StringViewArray::from(vec!["aaa", "bbb", LONG, "aaa"]));
let expected = reference_count(&array_hashes(&mixed)[0..3]);
let mut acc_single = HLLAccumulator::new();
acc_single.update_batch(&[mixed]).unwrap();

Expand All @@ -1057,6 +1072,6 @@ mod tests {
distinct_count(&mut acc_single),
distinct_count(&mut acc_split)
);
assert_eq!(distinct_count(&mut acc_single), 3);
assert_eq!(distinct_count(&mut acc_single), expected);
}
}
10 changes: 7 additions & 3 deletions datafusion/physical-expr/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef};
use datafusion_common::metadata::FieldMetadata;
use datafusion_common::{
DFSchema, Result, ScalarValue, assert_or_internal_err, internal_err, not_impl_err,
DFSchema, Result, ScalarValue, internal_err, not_impl_err, plan_err,
};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::{
Expand Down Expand Up @@ -261,8 +261,6 @@ impl AggregateExprBuilder {
is_distinct,
is_reversed,
} = self;
assert_or_internal_err!(!args.is_empty(), "args should not be empty");

let ordering_types = order_bys
.iter()
.map(|e| e.expr.data_type(&schema))
Expand All @@ -275,6 +273,12 @@ impl AggregateExprBuilder {
.map(|arg| arg.return_field(&schema))
.collect::<Result<Vec<_>>>()?;

if args.is_empty() && fun.name() == "count" {
return plan_err!(
"Physical count aggregate requires an argument; use COUNT_STAR_EXPANSION for count(*)"
);
}

check_arg_count(
fun.name(),
&input_exprs_fields,
Expand Down
Loading
Loading