chore: extract math_funcs expressions to folders based on spark grouping (#1219)

* extract math_funcs expressions to folders based on spark grouping

* fix merge conflicts and move chr to `string_funcs`
rluvaton authored Jan 20, 2025
1 parent 2588e13 commit 517c255
Showing 21 changed files with 661 additions and 589 deletions.
2 changes: 1 addition & 1 deletion native/spark-expr/benches/decimal_div.rs
@@ -19,7 +19,7 @@ use arrow::compute::cast
use arrow_array::builder::Decimal128Builder;
use arrow_schema::DataType;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
-use datafusion_comet_spark_expr::scalar_funcs::spark_decimal_div;
+use datafusion_comet_spark_expr::spark_decimal_div;
use datafusion_expr::ColumnarValue;
use std::sync::Arc;

8 changes: 4 additions & 4 deletions native/spark-expr/src/comet_scalar_funcs.rs
@@ -16,11 +16,11 @@
// under the License.

use crate::hash_funcs::*;
-use crate::scalar_funcs::{
-    spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_round, spark_unhex, spark_unscaled_value, SparkChrFunc,
+use crate::{
+    spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
+    spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_unhex,
+    spark_unscaled_value, SparkChrFunc,
+};
-use crate::{spark_date_add, spark_date_sub, spark_read_side_padding};
use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::registry::FunctionRegistry;
2 changes: 1 addition & 1 deletion native/spark-expr/src/hash_funcs/sha2.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

-use crate::scalar_funcs::hex_strings;
+use crate::math_funcs::hex::hex_strings;
use arrow_array::{Array, StringArray};
use datafusion::functions::crypto::{sha224, sha256, sha384, sha512};
use datafusion_common::cast::as_binary_array;
25 changes: 12 additions & 13 deletions native/spark-expr/src/lib.rs
@@ -21,29 +21,22 @@

mod error;

mod checkoverflow;
pub use checkoverflow::CheckOverflow;

mod kernels;
pub mod scalar_funcs;
mod schema_adapter;
mod static_invoke;
pub use schema_adapter::SparkSchemaAdapterFactory;
pub use static_invoke::*;

mod negative;
mod struct_funcs;
pub use negative::{create_negate_expr, NegativeExpr};
mod normalize_nan;
pub use struct_funcs::{CreateNamedStruct, GetStructField};

mod json_funcs;
pub mod test_common;
pub mod timezone;
mod unbound;
pub use unbound::UnboundColumn;
pub mod utils;
pub use normalize_nan::NormalizeNaNAndZero;
mod predicate_funcs;
pub mod utils;
pub use predicate_funcs::{spark_isnan, RLike};

mod agg_funcs;
@@ -57,24 +50,30 @@ mod string_funcs;
mod datetime_funcs;
pub use agg_funcs::*;

pub use crate::{CreateNamedStruct, GetStructField};
pub use crate::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
pub use cast::{spark_cast, Cast, SparkCastOptions};
mod conditional_funcs;
mod conversion_funcs;
mod math_funcs;

pub use array_funcs::*;
pub use bitwise_funcs::*;
pub use conditional_funcs::*;
pub use conversion_funcs::*;

pub use comet_scalar_funcs::create_comet_physical_fun;
pub use datetime_funcs::*;
pub use datetime_funcs::{
    spark_date_add, spark_date_sub, DateTruncExpr, HourExpr, MinuteExpr, SecondExpr,
    TimestampTruncExpr,
};
pub use error::{SparkError, SparkResult};
pub use hash_funcs::*;
pub use json_funcs::ToJson;
pub use math_funcs::{
    create_negate_expr, spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_make_decimal,
    spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr,
    NormalizeNaNAndZero,
};
pub use string_funcs::*;
pub use struct_funcs::*;

/// Spark supports three evaluation modes when evaluating expressions, which affect
/// the behavior when processing input values that are invalid or would result in an
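The lib.rs reshuffle flattens the public API: everything that used to live under `scalar_funcs` is now re-exported at the crate root. A minimal consumer-side sketch, using only names visible in this diff (the snippet itself is not part of the commit):

use datafusion_comet_spark_expr::{spark_ceil, spark_decimal_div, spark_floor};

This is the same pattern as the one-line benchmark fix above, which drops the `scalar_funcs::` path segment from the import.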
83 changes: 83 additions & 0 deletions native/spark-expr/src/math_funcs/ceil.rs
@@ -0,0 +1,83 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::downcast_compute_op;
use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
use arrow::array::{Float32Array, Float64Array, Int64Array};
use arrow_array::{Array, ArrowNativeTypeOp};
use arrow_schema::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use num::integer::div_ceil;
use std::sync::Arc;

/// `ceil` function that simulates Spark `ceil` expression
pub fn spark_ceil(
    args: &[ColumnarValue],
    data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
    let value = &args[0];
    match value {
        ColumnarValue::Array(array) => match array.data_type() {
            DataType::Float32 => {
                let result = downcast_compute_op!(array, "ceil", ceil, Float32Array, Int64Array);
                Ok(ColumnarValue::Array(result?))
            }
            DataType::Float64 => {
                let result = downcast_compute_op!(array, "ceil", ceil, Float64Array, Int64Array);
                Ok(ColumnarValue::Array(result?))
            }
            DataType::Int64 => {
                let result = array.as_any().downcast_ref::<Int64Array>().unwrap();
                Ok(ColumnarValue::Array(Arc::new(result.clone())))
            }
            DataType::Decimal128(_, scale) if *scale > 0 => {
                let f = decimal_ceil_f(scale);
                let (precision, scale) = get_precision_scale(data_type);
                make_decimal_array(array, precision, scale, &f)
            }
            other => Err(DataFusionError::Internal(format!(
                "Unsupported data type {:?} for function ceil",
                other,
            ))),
        },
        ColumnarValue::Scalar(a) => match a {
            ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
                a.map(|x| x.ceil() as i64),
            ))),
            ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
                a.map(|x| x.ceil() as i64),
            ))),
            ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))),
            ScalarValue::Decimal128(a, _, scale) if *scale > 0 => {
                let f = decimal_ceil_f(scale);
                let (precision, scale) = get_precision_scale(data_type);
                make_decimal_scalar(a, precision, scale, &f)
            }
            _ => Err(DataFusionError::Internal(format!(
                "Unsupported data type {:?} for function ceil",
                value.data_type(),
            ))),
        },
    }
}

#[inline]
fn decimal_ceil_f(scale: &i8) -> impl Fn(i128) -> i128 {
    let div = 10_i128.pow_wrapping(*scale as u32);
    move |x: i128| div_ceil(x, div)
}
92 changes: 92 additions & 0 deletions native/spark-expr/src/math_funcs/div.rs
@@ -0,0 +1,92 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::math_funcs::utils::get_precision_scale;
use arrow::{
array::{ArrayRef, AsArray},
datatypes::Decimal128Type,
};
use arrow_array::{Array, Decimal128Array};
use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::DataFusionError;
use num::{BigInt, Signed, ToPrimitive};
use std::sync::Arc;

// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
// both s2 and s3 are 38 at max., s1 is 77 at max. DataFusion division cannot handle such scale >
// Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal division using BigInt.
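//
// Worked example (values assumed for illustration, not part of the commit):
// dividing Decimal(4, 2) values 1.00 / 3.00 with a requested result scale
// s3 = 6 widens the left operand by l_exp = (s2 + s3 + 1) - s1 = 7, giving
// (100 * 10^7) / 300 = 3_333_333. The extra digit from the "+ 1" is then
// rounded half-up: (3_333_333 + 5) / 10 = 333_333, i.e. 0.333333 at scale 6.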
pub fn spark_decimal_div(
    args: &[ColumnarValue],
    data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
    let left = &args[0];
    let right = &args[1];
    let (p3, s3) = get_precision_scale(data_type);

    let (left, right): (ArrayRef, ArrayRef) = match (left, right) {
        (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
        (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
            (l.to_array_of_size(r.len())?, Arc::clone(r))
        }
        (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
            (Arc::clone(l), r.to_array_of_size(l.len())?)
        }
        (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
    };
    let left = left.as_primitive::<Decimal128Type>();
    let right = right.as_primitive::<Decimal128Type>();
    let (p1, s1) = get_precision_scale(left.data_type());
    let (p2, s2) = get_precision_scale(right.data_type());

    let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32);
    let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32);
    let result: Decimal128Array = if p1 as u32 + l_exp > DECIMAL128_MAX_PRECISION as u32
        || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32
    {
        let ten = BigInt::from(10);
        let l_mul = ten.pow(l_exp);
        let r_mul = ten.pow(r_exp);
        let five = BigInt::from(5);
        let zero = BigInt::from(0);
        arrow::compute::kernels::arity::binary(left, right, |l, r| {
            let l = BigInt::from(l) * &l_mul;
            let r = BigInt::from(r) * &r_mul;
            let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
            let res = if div.is_negative() {
                div - &five
            } else {
                div + &five
            } / &ten;
            res.to_i128().unwrap_or(i128::MAX)
        })?
    } else {
        let l_mul = 10_i128.pow(l_exp);
        let r_mul = 10_i128.pow(r_exp);
        arrow::compute::kernels::arity::binary(left, right, |l, r| {
            let l = l * l_mul;
            let r = r * r_mul;
            let div = if r == 0 { 0 } else { l / r };
            let res = if div.is_negative() { div - 5 } else { div + 5 } / 10;
            res.to_i128().unwrap_or(i128::MAX)
        })?
    };
    let result = result.with_data_type(DataType::Decimal128(p3, s3));
    Ok(ColumnarValue::Array(Arc::new(result)))
}
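The `BigInt` branch above exists because rescaling can overflow `i128` long before the division itself would. A small sketch of when the guard trips, with input types assumed for illustration:

fn main() {
    // Assumed inputs: Decimal(38, 10) / Decimal(38, 10) with result scale s3 = 6.
    let (p1, s1, s2, s3) = (38u32, 10u32, 10u32, 6u32);
    let l_exp = (s2 + s3 + 1).saturating_sub(s1); // = 7
    // p1 + l_exp = 45 exceeds DECIMAL128_MAX_PRECISION (38), so the BigInt
    // branch runs; the zero-divisor check and half-up rounding match the
    // i128 fast path exactly.
    assert!(p1 + l_exp > 38);
}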
83 changes: 83 additions & 0 deletions native/spark-expr/src/math_funcs/floor.rs
@@ -0,0 +1,83 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::downcast_compute_op;
use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
use arrow::array::{Float32Array, Float64Array, Int64Array};
use arrow_array::{Array, ArrowNativeTypeOp};
use arrow_schema::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use num::integer::div_floor;
use std::sync::Arc;

/// `floor` function that simulates Spark `floor` expression
pub fn spark_floor(
    args: &[ColumnarValue],
    data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
    let value = &args[0];
    match value {
        ColumnarValue::Array(array) => match array.data_type() {
            DataType::Float32 => {
                let result = downcast_compute_op!(array, "floor", floor, Float32Array, Int64Array);
                Ok(ColumnarValue::Array(result?))
            }
            DataType::Float64 => {
                let result = downcast_compute_op!(array, "floor", floor, Float64Array, Int64Array);
                Ok(ColumnarValue::Array(result?))
            }
            DataType::Int64 => {
                let result = array.as_any().downcast_ref::<Int64Array>().unwrap();
                Ok(ColumnarValue::Array(Arc::new(result.clone())))
            }
            DataType::Decimal128(_, scale) if *scale > 0 => {
                let f = decimal_floor_f(scale);
                let (precision, scale) = get_precision_scale(data_type);
                make_decimal_array(array, precision, scale, &f)
            }
            other => Err(DataFusionError::Internal(format!(
                "Unsupported data type {:?} for function floor",
                other,
            ))),
        },
        ColumnarValue::Scalar(a) => match a {
            ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
                a.map(|x| x.floor() as i64),
            ))),
            ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
                a.map(|x| x.floor() as i64),
            ))),
            ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))),
            ScalarValue::Decimal128(a, _, scale) if *scale > 0 => {
                let f = decimal_floor_f(scale);
                let (precision, scale) = get_precision_scale(data_type);
                make_decimal_scalar(a, precision, scale, &f)
            }
            _ => Err(DataFusionError::Internal(format!(
                "Unsupported data type {:?} for function floor",
                value.data_type(),
            ))),
        },
    }
}

#[inline]
fn decimal_floor_f(scale: &i8) -> impl Fn(i128) -> i128 {
    let div = 10_i128.pow_wrapping(*scale as u32);
    move |x: i128| div_floor(x, div)
}
File renamed without changes.
File renamed without changes.