Skip to content
65 changes: 19 additions & 46 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,36 @@ use std::fmt::{Debug, Formatter};
use std::mem::size_of_val;
use std::sync::Arc;

use arrow::array::{Array, RecordBatch};
use arrow::array::Array;
use arrow::compute::{filter, is_not_null};
use arrow::datatypes::FieldRef;
use arrow::{
array::{
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
},
datatypes::{DataType, Field, Schema},
datatypes::{DataType, Field},
};
use datafusion_common::{
downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
Result, ScalarValue,
downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::expr::{AggregateFunction, Sort};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature,
TypeSignature, Volatility,
Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
Volatility,
};
use datafusion_functions_aggregate_common::tdigest::{
TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
};
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

use crate::utils::{get_scalar_value, validate_percentile_expr};

create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);

/// Computes the approximate percentile continuous of a set of numbers
Expand Down Expand Up @@ -164,7 +166,8 @@ impl ApproxPercentileCont {
&self,
args: AccumulatorArgs,
) -> Result<ApproxPercentileAccumulator> {
let percentile = validate_input_percentile_expr(&args.exprs[1])?;
let percentile =
validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?;

let is_descending = args
.order_bys
Expand Down Expand Up @@ -214,45 +217,15 @@ impl ApproxPercentileCont {
}
}

/// Evaluates `expr` against an empty [`RecordBatch`] (built on an empty
/// [`Schema`]) and returns the scalar it produces.
///
/// # Errors
///
/// Propagates any error raised by `expr.evaluate`, and returns an internal
/// error when the evaluation yields an array instead of a scalar.
fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
    // No input columns are supplied; presumably only expressions that do not
    // read from the batch (e.g. literals) evaluate successfully here.
    let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty()));

    match expr.evaluate(&empty_batch)? {
        ColumnarValue::Scalar(scalar) => Ok(scalar),
        ColumnarValue::Array(_) => internal_err!("Didn't expect ColumnarValue::Array"),
    }
}

/// Extracts the percentile argument of `APPROX_PERCENTILE_CONT` from `expr`
/// and checks that it is usable.
///
/// The expression is resolved through `get_scalar_value`; the resulting scalar
/// must be a non-null `Float32` or `Float64`, and its value must lie in the
/// inclusive range `0.0..=1.0`.
///
/// # Errors
///
/// * a "not implemented" error when `get_scalar_value` fails for `expr`
///   (the underlying error is discarded and replaced by the literal message);
/// * a "not implemented" error when the scalar is anything other than a
///   non-null `Float32` / `Float64`;
/// * a plan error when the value falls outside `0.0..=1.0`.
fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
    let scalar = match get_scalar_value(expr) {
        Ok(scalar) => scalar,
        Err(_) => {
            return Err(not_impl_datafusion_err!(
                "Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"
            ))
        }
    };

    // Widen Float32 to f64 so a single range check covers both float types.
    let percentile = match scalar {
        ScalarValue::Float32(Some(value)) => f64::from(value),
        ScalarValue::Float64(Some(value)) => value,
        other => {
            return not_impl_err!(
                "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
                other.data_type()
            )
        }
    };

    // Only values inside the closed unit interval are accepted.
    if (0.0..=1.0).contains(&percentile) {
        Ok(percentile)
    } else {
        plan_err!(
            "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
        )
    }
}

fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
let max_size = match get_scalar_value(expr)
.map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
let scalar_value = get_scalar_value(expr).map_err(|_e| {
DataFusionError::Plan(
"Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal"
.to_string(),
)
})?;

let max_size = match scalar_value {
ScalarValue::UInt8(Some(q)) => q as usize,
ScalarValue::UInt16(Some(q)) => q as usize,
ScalarValue::UInt32(Some(q)) => q as usize,
Expand All @@ -262,7 +235,7 @@ fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
sv => {
return not_impl_err!(
return plan_err!(
"Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
sv.data_type()
)
Expand Down
4 changes: 4 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,15 @@ pub mod hyperloglog;
pub mod median;
pub mod min_max;
pub mod nth_value;
pub mod percentile_cont;
pub mod regr;
pub mod stddev;
pub mod string_agg;
pub mod sum;
pub mod variance;

pub mod planner;
mod utils;

use crate::approx_percentile_cont::approx_percentile_cont_udaf;
use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf;
Expand Down Expand Up @@ -123,6 +125,7 @@ pub mod expr_fn {
pub use super::min_max::max;
pub use super::min_max::min;
pub use super::nth_value::nth_value;
pub use super::percentile_cont::percentile_cont;
pub use super::regr::regr_avgx;
pub use super::regr::regr_avgy;
pub use super::regr::regr_count;
Expand Down Expand Up @@ -171,6 +174,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
approx_distinct::approx_distinct_udaf(),
approx_percentile_cont_udaf(),
approx_percentile_cont_with_weight_udaf(),
percentile_cont::percentile_cont_udaf(),
string_agg::string_agg_udaf(),
bit_and_or_xor::bit_and_udaf(),
bit_and_or_xor::bit_or_udaf(),
Expand Down
Loading