diff --git a/crates/control_plane/src/service.rs b/crates/control_plane/src/service.rs
index 60cb88bb6..8774aa1b7 100644
--- a/crates/control_plane/src/service.rs
+++ b/crates/control_plane/src/service.rs
@@ -286,15 +286,17 @@ impl ControlService for ControlServiceImpl {
         // let mut buffer = Vec::new();
         // let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
-        // let mut stream_writer = StreamWriter::try_new_with_options(
-        //     &mut buffer, &records[0].schema_ref(), options).unwrap();
+        // let mut stream_writer =
+        //     StreamWriter::try_new_with_options(&mut buffer, &records[0].schema_ref(), options)
+        //         .unwrap();
         // stream_writer.write(&records[0]).unwrap();
         // stream_writer.finish().unwrap();
         // drop(stream_writer);
-        //
+        // Try to add flatbuffer verification
         // println!("{:?}", buffer.len());
-        // let res = general_purpose::STANDARD.encode(buffer);
+        // let base64 = general_purpose::STANDARD.encode(buffer);
+        // Ok((base64, columns))
         // let encoded = general_purpose::STANDARD.decode(res.clone()).unwrap();
         //
         // let mut verifier = Verifier::new(&VerifierOptions::default(), &encoded);
diff --git a/crates/control_plane/src/sql/functions/common.rs b/crates/control_plane/src/sql/functions/common.rs
index b12e4fd41..008396df6 100644
--- a/crates/control_plane/src/sql/functions/common.rs
+++ b/crates/control_plane/src/sql/functions/common.rs
@@ -1,15 +1,18 @@
 use crate::models::ColumnInfo;
 use arrow::array::{
-    Array, Int32Array, Int64Array, StructArray, TimestampMicrosecondArray,
-    TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UnionArray,
+    Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+    TimestampNanosecondArray, TimestampSecondArray, UnionArray,
 };
 use arrow::datatypes::{Field, Schema, TimeUnit};
 use arrow::record_batch::RecordBatch;
+use chrono::{DateTime, Utc};
 use datafusion::arrow::array::ArrayRef;
 use datafusion::arrow::datatypes::DataType;
 use datafusion::common::Result as DataFusionResult;
 use std::sync::Arc;
 
+const TIMESTAMP_FORMAT: &str = "%Y-%m-%d-%H:%M:%S%.9f";
+
 pub fn first_non_empty_type(union_array: &UnionArray) -> Option<(DataType, ArrayRef)> {
     for i in 0..union_array.type_ids().len() {
         let type_id = union_array.type_id(i);
@@ -71,8 +74,6 @@ pub fn convert_record_batches(
         columns.push(converted_column);
     }
     let new_schema = Arc::new(Schema::new(fields));
-    println!("new schema: {:?}", new_schema);
-    println!("columns: {:?}", columns);
     let converted_batch = RecordBatch::try_new(new_schema, columns)?;
     converted_batches.push(converted_batch);
 }
@@ -81,60 +82,116 @@
 }
 
 fn convert_timestamp_to_struct(column: &ArrayRef, unit: &TimeUnit) -> ArrayRef {
-    let (epoch, fraction) = match unit {
-        TimeUnit::Second => {
-            let array = column
-                .as_any()
-                .downcast_ref::<TimestampSecondArray>()
-                .unwrap();
-            let epoch: Int64Array = array.clone().unary(|x| x);
-            let fraction: Int32Array = Int32Array::from(vec![0; column.len()]);
-            (epoch, fraction)
-        }
-        TimeUnit::Millisecond => {
-            let array = column
-                .as_any()
-                .downcast_ref::<TimestampMillisecondArray>()
-                .unwrap();
-            let epoch: Int64Array = array.clone().unary(|x| x / 1_000);
-            let fraction: Int32Array = array.clone().unary(|x| (x % 1_000 * 1_000_000) as i32);
-            (epoch, fraction)
-        }
-        TimeUnit::Microsecond => {
-            let array = column
-                .as_any()
-                .downcast_ref::<TimestampMicrosecondArray>()
-                .unwrap();
-            let epoch: Int64Array = array.clone().unary(|x| x / 1_000_000);
-            let fraction: Int32Array = array.clone().unary(|x| (x % 1_000_000 * 1_000) as i32);
-            (epoch, fraction)
-        }
-        TimeUnit::Nanosecond => {
-            let array = column
-                .as_any()
-                .downcast_ref::<TimestampNanosecondArray>()
-                .unwrap();
-            let epoch: Int64Array = array.clone().unary(|x| x / 1_000_000_000);
-            let fraction: Int32Array = array.clone().unary(|x| (x % 1_000_000_000) as i32);
-            (epoch, fraction)
-        }
+    let now = Utc::now();
+    let timestamps: Vec<_> = match unit {
+        TimeUnit::Second => column
+            .as_any()
+            .downcast_ref::<TimestampSecondArray>()
+            .unwrap()
+            .iter()
+            .map(|x| {
+                let ts = DateTime::from_timestamp(x.unwrap_or(now.timestamp()), 0).unwrap();
+                format!("{}.{:06}", ts.timestamp(), ts.timestamp_subsec_micros())
+            })
+            .collect(),
+        TimeUnit::Millisecond => column
+            .as_any()
+            .downcast_ref::<TimestampMillisecondArray>()
+            .unwrap()
+            .iter()
+            .map(|x| {
+                let ts =
+                    DateTime::from_timestamp_millis(x.unwrap_or(now.timestamp_millis())).unwrap();
+                format!("{}.{:06}", ts.timestamp(), ts.timestamp_subsec_micros())
+            })
+            .collect(),
+        TimeUnit::Microsecond => column
+            .as_any()
+            .downcast_ref::<TimestampMicrosecondArray>()
+            .unwrap()
+            .iter()
+            .map(|x| {
+                let ts =
+                    DateTime::from_timestamp_micros(x.unwrap_or(now.timestamp_micros())).unwrap();
+                format!("{}.{:06}", ts.timestamp(), ts.timestamp_subsec_micros())
+            })
+            .collect(),
+        TimeUnit::Nanosecond => column
+            .as_any()
+            .downcast_ref::<TimestampNanosecondArray>()
+            .unwrap()
+            .iter()
+            .map(|x| {
+                let ts =
+                    DateTime::from_timestamp_nanos(x.unwrap_or(now.timestamp_nanos_opt().unwrap()));
+                format!("{}.{:06}", ts.timestamp(), ts.timestamp_subsec_micros())
+            })
+            .collect(),
     };
-
+    // let (epoch, fraction) = match unit {
+    //     TimeUnit::Second => {
+    //         let array = column
+    //             .as_any()
+    //             .downcast_ref::<TimestampSecondArray>()
+    //             .unwrap();
+    //         let epoch: Int64Array = array.iter().map(|x| x.unwrap_or(now)).collect();
+    //         let fraction: Int32Array = Int32Array::from(vec![0; column.len()]);
+    //         (epoch, fraction)
+    //     }
+    //     TimeUnit::Millisecond => {
+    //         let array = column
+    //             .as_any()
+    //             .downcast_ref::<TimestampMillisecondArray>()
+    //             .unwrap();
+    //         let now_millis = now * 1_000;
+    //         let epoch: Int64Array = array.iter().map(|x| x.unwrap_or(now_millis) / 1_000).collect();
+    //         let fraction: Int32Array = array
+    //             .iter()
+    //             .map(|x| (x.unwrap_or(0) % 1_000 * 1_000_000) as i32)
+    //             .collect();
+    //         (epoch, fraction)
+    //     }
+    //     TimeUnit::Microsecond => {
+    //         let array = column
+    //             .as_any()
+    //             .downcast_ref::<TimestampMicrosecondArray>()
+    //             .unwrap();
+    //         let now_micros = now * 1_000_000;
+    //         let epoch: Int64Array = array.iter().map(|x| x.unwrap_or(now_micros) / 1_000_000).collect();
+    //         let fraction: Int32Array = array
+    //             .iter()
+    //             .map(|x| (x.unwrap_or(0) % 1_000_000 * 1_000) as i32)
+    //             .collect();
+    //         (epoch, fraction)
+    //     }
+    //     TimeUnit::Nanosecond => {
+    //         let array = column
+    //             .as_any()
+    //             .downcast_ref::<TimestampNanosecondArray>()
+    //             .unwrap();
+    //         let now_nanos = now * 1_000_000_000;
+    //         let epoch: Int64Array = array.iter().map(|x| x.unwrap_or(now_nanos) / 1_000_000_000).collect();
+    //         let fraction: Int32Array = array
+    //             .iter()
+    //             .map(|x| (x.unwrap_or(0) % 1_000_000_000) as i32)
+    //             .collect();
+    //         (epoch, fraction)
+    //     }
+    // };
+    // let string_values: Vec<_> = epoch.iter().map(|x| x.unwrap_or(0).to_string()).collect();
+    // let string_array = StringArray::from(string_values);
     // let timezone = Int32Array::from(vec![1440; column.len()]); // Assuming UTC timezone
-    let struct_array = StructArray::from(vec![
-        (
-            Arc::new(Field::new("epoch", DataType::Int64, false)),
-            Arc::new(epoch) as ArrayRef,
-        ),
-        (
-            Arc::new(Field::new("fraction", DataType::Int32, false)),
-            Arc::new(fraction) as ArrayRef,
-        ),
-        // (
-        //     Arc::new(Field::new("timezone", DataType::Int32, false)),
-        //     Arc::new(timezone) as ArrayRef,
-        // ),
-    ]);
+    // let struct_array = StructArray::try_new(
+    //     vec![
+    //         Arc::new(Field::new("epoch", DataType::Int64, false)),
+    //         Arc::new(Field::new("fraction", DataType::Int32, false)),
+    //     ]
+    //     .into(),
+    //     vec![Arc::new(epoch) as ArrayRef, Arc::new(fraction) as ArrayRef],
+    //     None,
+    // )
+    // .unwrap();
+    // Arc::new(epoch) as ArrayRef
-    Arc::new(struct_array) as ArrayRef
+    Arc::new(StringArray::from(timestamps)) as ArrayRef
 }
diff --git a/crates/control_plane/src/sql/functions/date_add.rs b/crates/control_plane/src/sql/functions/date_add.rs
new file mode 100644
index 000000000..b563f0b72
--- /dev/null
+++ b/crates/control_plane/src/sql/functions/date_add.rs
@@ -0,0 +1,198 @@
+use arrow::array::Array;
+use arrow::datatypes::DataType::{Date32, Date64, Int64, Time32, Time64, Timestamp, Utf8};
+use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
+use arrow::datatypes::{DataType, Fields};
+use datafusion::common::{plan_err, Result};
+use datafusion::logical_expr::TypeSignature::Exact;
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD,
+};
+use datafusion::scalar::ScalarValue;
+use std::any::Any;
+
+#[derive(Debug)]
+pub struct DateAddFunc {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for DateAddFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl DateAddFunc {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    Exact(vec![Utf8, Int64, Date32]),
+                    Exact(vec![Utf8, Int64, Date64]),
+                    Exact(vec![Utf8, Int64, Time32(Second)]),
+                    Exact(vec![Utf8, Int64, Time32(Nanosecond)]),
+                    Exact(vec![Utf8, Int64, Time32(Microsecond)]),
+                    Exact(vec![Utf8, Int64, Time32(Millisecond)]),
+                    Exact(vec![Utf8, Int64, Time64(Second)]),
+                    Exact(vec![Utf8, Int64, Time64(Nanosecond)]),
+                    Exact(vec![Utf8, Int64, Time64(Microsecond)]),
+                    Exact(vec![Utf8, Int64, Time64(Millisecond)]),
+                    Exact(vec![Utf8, Int64, Timestamp(Second, None)]),
+                    Exact(vec![Utf8, Int64, Timestamp(Millisecond, None)]),
+                    Exact(vec![Utf8, Int64, Timestamp(Microsecond, None)]),
+                    Exact(vec![Utf8, Int64, Timestamp(Nanosecond, None)]),
+                    Exact(vec![
+                        Utf8,
+                        Int64,
+                        Timestamp(Second, Some(TIMEZONE_WILDCARD.into())),
+                    ]),
+                    Exact(vec![
+                        Utf8,
+                        Int64,
+                        Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())),
+                    ]),
+                    Exact(vec![
+                        Utf8,
+                        Int64,
+                        Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())),
+                    ]),
+                    Exact(vec![
+                        Utf8,
+                        Int64,
+                        Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())),
+                    ]),
+                ],
+                Volatility::Immutable,
+            ),
+            aliases: vec![
+                String::from("date_add"),
+                String::from("time_add"),
+                String::from("timeadd"),
+                String::from("timestamp_add"),
+                String::from("timestampadd"),
+            ],
+        }
+    }
+
+    fn add_nanoseconds(val: &ScalarValue, nanoseconds: i64) -> Result<ColumnarValue> {
+        Ok(ColumnarValue::Scalar(
+            val.add(ScalarValue::DurationNanosecond(Some(nanoseconds)))
+                .unwrap_or(ScalarValue::DurationNanosecond(Some(0))),
+        ))
+    }
+    fn add_years(val: &ScalarValue, years: i64) -> Result<ColumnarValue> {
+        Ok(ColumnarValue::Scalar(
+            val.add(ScalarValue::new_interval_ym(
+                i32::try_from(years).unwrap_or(0),
+                0,
+            ))
+            .unwrap_or(ScalarValue::new_interval_ym(0, 0)),
+        ))
+    }
+    fn add_months(val: &ScalarValue, months: i64) -> Result<ColumnarValue> {
+        Ok(ColumnarValue::Scalar(
+            val.add(ScalarValue::new_interval_ym(
+                0,
+                i32::try_from(months).unwrap_or(0),
+            ))
+            .unwrap_or(ScalarValue::new_interval_ym(0, 0)),
+        ))
+    }
+    fn add_days(val: &ScalarValue, days: i64) -> Result<ColumnarValue> {
+        Ok(ColumnarValue::Scalar(
+            val.add(ScalarValue::new_interval_dt(
+                i32::try_from(days).unwrap_or(0),
+                0,
+            ))
+            .unwrap_or(ScalarValue::new_interval_dt(0, 0)),
+        ))
+    }
+}
+
+/// dateadd SQL function
+/// Syntax: `DATEADD(<date_or_time_part>, <value>, <date_or_time_expr>)`
+/// - <date_or_time_part>: This indicates the units of time that you want to add.
+///   For example, if you want to add two days, then specify day. This unit of measure must be
+///   one of the values listed in Supported date and time parts.
+/// - <value>: This is the number of units of time that you want to add.
+///   For example, if the unit of time is day, and you want to add two days, specify 2.
+///   If you want to subtract two days, specify -2.
+/// - <date_or_time_expr>: Must evaluate to a date, time, or timestamp.
+///   This is the date, time, or timestamp to which you want to add.
+///   For example, if you want to add two days to August 1, 2024, then specify '2024-08-01'::DATE.
+///   If the data type is TIME, then the date_or_time_part must be in units of hours or smaller, not days or bigger.
+///   If the input data type is DATE, and the date_or_time_part is hours or smaller, the input value will not be rejected,
+///   but instead will be treated as a TIMESTAMP with hours, minutes, seconds, and fractions of a second all initially set to 0 (e.g. midnight on the specified date).
+///
+/// Note: `dateadd` returns
+/// If date_or_time_expr is a time, then the return data type is a time.
+/// If date_or_time_expr is a timestamp, then the return data type is a timestamp.
+/// If date_or_time_expr is a date:
+/// - If date_or_time_part is day or larger (for example, month, year), the function returns a DATE value.
+/// - If date_or_time_part is smaller than a day (for example, hour, minute, second), the function returns a TIMESTAMP_NTZ value, with 00:00:00.000 as the starting time for the date.
+/// Usage notes:
+/// - When date_or_time_part is year, quarter, or month (or any of their variations),
+///   if the result month has fewer days than the original day of the month, the result day of the month might be different from the original day.
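+///   For instance (illustrative, not from a test suite): `dateadd(month, 1, '2024-01-31'::DATE)`
+///   yields `2024-02-29`, because February 2024 has only 29 days.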
+///
+/// Examples
+/// - dateadd(day, 30, CAST('2024-12-26' AS TIMESTAMP))
+impl ScalarUDFImpl for DateAddFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "dateadd"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if arg_types.len() != 3 {
+            return plan_err!("function requires three arguments");
+        }
+        Ok(arg_types[2].clone())
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        if args.len() != 3 {
+            return plan_err!("function requires three arguments");
+        }
+
+        let date_or_time_part = match &args[0] {
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(part))) => part.clone(),
+            _ => return plan_err!("Invalid unit type format"),
+        };
+        let value = match &args[1] {
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(val))) => *val,
+            _ => return plan_err!("Invalid value type"),
+        };
+        let date_or_time_expr = match &args[2] {
+            ColumnarValue::Scalar(val) => val.clone(),
+            _ => return plan_err!("Invalid datetime type"),
+        };
+        // there shouldn't be overflows
+        match date_or_time_part.as_str() {
+            // should consider leap years (365-366 days)
+            "year" | "y" | "yy" | "yyy" | "yyyy" | "yr" | "years" => {
+                DateAddFunc::add_years(&date_or_time_expr, value)
+            }
+            // should consider months of 28-31 days
+            "month" | "mm" | "mon" | "mons" | "months" => {
+                DateAddFunc::add_months(&date_or_time_expr, value)
+            }
+            "day" | "d" | "dd" | "days" | "dayofmonth" => {
+                DateAddFunc::add_days(&date_or_time_expr, value)
+            }
+            "week" | "w" | "wk" | "weekofyear" | "woy" | "wy" => {
+                DateAddFunc::add_days(&date_or_time_expr, value * 7)
+            }
+            // should consider months of 28-31 days
+            "quarter" | "q" | "qtr" | "qtrs" | "quarters" => {
+                DateAddFunc::add_months(&date_or_time_expr, value * 3)
+            }
+            "hour" | "h" | "hh" | "hr" | "hours" | "hrs" => {
+                DateAddFunc::add_nanoseconds(&date_or_time_expr, value * 3_600_000_000_000)
+            }
+            "minute" | "m" | "mi" | "min" | "minutes" | "mins" => {
+                DateAddFunc::add_nanoseconds(&date_or_time_expr, value * 60_000_000_000)
+            }
+            "second" | "s" | "sec" | "seconds" | "secs" => {
+                DateAddFunc::add_nanoseconds(&date_or_time_expr, value * 1_000_000_000)
+            }
+            "millisecond" | "ms" | "msec" | "milliseconds" => {
+                DateAddFunc::add_nanoseconds(&date_or_time_expr, value * 1_000_000)
+            }
+            "microsecond" | "us" | "usec" | "microseconds" => {
+                DateAddFunc::add_nanoseconds(&date_or_time_expr, value * 1_000)
+            }
+            "nanosecond" | "ns" | "nsec" | "nanosec" | "nsecond" | "nanoseconds" | "nanosecs"
+            | "nseconds" => DateAddFunc::add_nanoseconds(&date_or_time_expr, value),
+            _ => plan_err!("Invalid date_or_time_part type"),
+        }
+    }
+}
diff --git a/crates/control_plane/src/sql/functions/greatest.rs b/crates/control_plane/src/sql/functions/greatest.rs
new file mode 100644
index 000000000..767573566
--- /dev/null
+++ b/crates/control_plane/src/sql/functions/greatest.rs
@@ -0,0 +1,158 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::sql::functions::greatest_least_utils::GreatestLeastOperator;
+use arrow::array::{make_comparator, Array, BooleanArray};
+use arrow::buffer::BooleanBuffer;
+use arrow::compute::kernels::cmp;
+use arrow::compute::SortOptions;
+use arrow::datatypes::DataType;
+use datafusion::common::{internal_err, Result, ScalarValue};
+use datafusion::logical_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL;
+use datafusion::logical_expr::{ColumnarValue, Documentation};
+use datafusion::logical_expr::{ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::OnceLock;
+
+const SORT_OPTIONS: SortOptions = SortOptions {
+    // We want greatest first
+    descending: false,
+
+    // NULL will be less than any other value
+    nulls_first: true,
+};
+
+#[derive(Debug)]
+pub struct GreatestFunc {
+    signature: Signature,
+}
+
+impl Default for GreatestFunc {
+    fn default() -> Self {
+        GreatestFunc::new()
+    }
+}
+
+impl GreatestFunc {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+        }
+    }
+}
+
+impl GreatestLeastOperator for GreatestFunc {
+    const NAME: &'static str = "greatest";
+
+    fn keep_scalar<'a>(lhs: &'a ScalarValue, rhs: &'a ScalarValue) -> Result<&'a ScalarValue> {
+        if !lhs.data_type().is_nested() {
+            return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) };
+        }
+
+        // If complex type we can't compare directly as we want null values to be smaller
+        let cmp = make_comparator(
+            lhs.to_array()?.as_ref(),
+            rhs.to_array()?.as_ref(),
+            SORT_OPTIONS,
+        )?;
+
+        if cmp(0, 0).is_ge() {
+            Ok(lhs)
+        } else {
+            Ok(rhs)
+        }
+    }
+
+    /// Return a boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array
+    /// Nulls are always considered smaller than any other value
+    fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
+        // Fast path:
+        // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorized kernel
+        // - If both arrays are not nested: nested types, such as lists, are not supported as the null semantics are not well-defined.
+        // - both arrays do not have any nulls: cmp::gt_eq will return null if any of the input is null, while we want to return false in that case
+        if !lhs.data_type().is_nested()
+            && lhs.logical_null_count() == 0
+            && rhs.logical_null_count() == 0
+        {
+            return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into());
+        }
+
+        let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?;
+
+        if lhs.len() != rhs.len() {
+            return internal_err!("All arrays should have the same length for greatest comparison");
+        }
+
+        let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge());
+
+        // No nulls as we only want to keep the values that are larger, it's either true or false
+        Ok(BooleanArray::new(values, None))
+    }
+}
+
+impl ScalarUDFImpl for GreatestFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "greatest"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        super::greatest_least_utils::execute_conditional::<Self>(args)
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let coerced_type = super::greatest_least_utils::find_coerced_type::<Self>(arg_types)?;
+
+        Ok(vec![coerced_type; arg_types.len()])
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_greatest_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_greatest_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_CONDITIONAL)
+            .with_sql_example(
+                r#"```sql
+> select greatest(4, 7, 5);
++---------------------------+
+| greatest(4,7,5)           |
++---------------------------+
+| 7                         |
++---------------------------+
+```"#,
+            )
+            .with_argument(
+                "expression1, expression_n",
+                "Expressions to compare and return the greatest value. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary.",
+            )
+            .build()
+            .unwrap()
+    })
+}
diff --git a/crates/control_plane/src/sql/functions/greatest_least_utils.rs b/crates/control_plane/src/sql/functions/greatest_least_utils.rs
new file mode 100644
index 000000000..cb8f73255
--- /dev/null
+++ b/crates/control_plane/src/sql/functions/greatest_least_utils.rs
@@ -0,0 +1,127 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, ArrayRef, BooleanArray};
+use arrow::compute::kernels::zip::zip;
+use arrow::datatypes::DataType;
+use datafusion::common::{internal_err, plan_err, Result, ScalarValue};
+use datafusion::logical_expr::type_coercion::binary::type_union_resolution;
+use datafusion::logical_expr::ColumnarValue;
+use std::sync::Arc;
+
+pub(super) trait GreatestLeastOperator {
+    const NAME: &'static str;
+
+    fn keep_scalar<'a>(lhs: &'a ScalarValue, rhs: &'a ScalarValue) -> Result<&'a ScalarValue>;
+
+    /// Return array with true for values that we should keep from the lhs array
+    fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray>;
+}
+
+fn keep_array<Op: GreatestLeastOperator>(lhs: ArrayRef, rhs: ArrayRef) -> Result<ArrayRef> {
+    // True for values that we should keep from the left array
+    let keep_lhs = Op::get_indexes_to_keep(lhs.as_ref(), rhs.as_ref())?;
+
+    let result = zip(&keep_lhs, &lhs, &rhs)?;
+
+    Ok(result)
+}
+
+pub(super) fn execute_conditional<Op: GreatestLeastOperator>(
+    args: &[ColumnarValue],
+) -> Result<ColumnarValue> {
+    if args.is_empty() {
+        return internal_err!(
+            "{} was called with no arguments. It requires at least 1.",
+            Op::NAME
+        );
+    }
+
+    // Some engines (e.g. SQL Server) allow greatest/least with a single arg; it's a no-op
+    if args.len() == 1 {
+        return Ok(args[0].clone());
+    }
+
+    // Split to scalars and arrays for later optimization
+    let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x {
+        ColumnarValue::Scalar(_) => true,
+        ColumnarValue::Array(_) => false,
+    });
+
+    let mut arrays_iter = arrays.iter().map(|x| match x {
+        ColumnarValue::Array(a) => a,
+        _ => unreachable!(),
+    });
+
+    let first_array = arrays_iter.next();
+
+    let mut result: ArrayRef;
+
+    // Optimization: merge all scalars into one to avoid recomputing (constant folding)
+    if !scalars.is_empty() {
+        let mut scalars_iter = scalars.iter().map(|x| match x {
+            ColumnarValue::Scalar(s) => s,
+            _ => unreachable!(),
+        });
+
+        // We have at least one scalar
+        let mut result_scalar = scalars_iter.next().unwrap();
+
+        for scalar in scalars_iter {
+            result_scalar = Op::keep_scalar(result_scalar, scalar)?;
+        }
+
+        // If we only have scalars, return the one that we should keep (largest/least)
+        if arrays.is_empty() {
+            return Ok(ColumnarValue::Scalar(result_scalar.clone()));
+        }
+
+        // We have at least one array
+        let first_array = first_array.unwrap();
+
+        // Start with the result value
+        result = keep_array::<Op>(
+            Arc::clone(first_array),
+            result_scalar.to_array_of_size(first_array.len())?,
+        )?;
+    } else {
+        // If we only have arrays, start with the first array
+        // (We must have at least one array)
+        result = Arc::clone(first_array.unwrap());
+    }
+
+    for array in arrays_iter {
+        result = keep_array::<Op>(Arc::clone(array), result)?;
+    }
+
+    Ok(ColumnarValue::Array(result))
+}
+
+pub(super) fn find_coerced_type<Op: GreatestLeastOperator>(
+    data_types: &[DataType],
+) -> Result<DataType> {
+    if data_types.is_empty() {
+        plan_err!(
+            "{} was called without any arguments. It requires at least 1.",
+            Op::NAME
+        )
+    } else if let Some(coerced_type) = type_union_resolution(data_types) {
+        Ok(coerced_type)
+    } else {
+        plan_err!("Cannot find a common type for arguments")
+    }
+}
diff --git a/crates/control_plane/src/sql/functions/least.rs b/crates/control_plane/src/sql/functions/least.rs
new file mode 100644
index 000000000..f6205715c
--- /dev/null
+++ b/crates/control_plane/src/sql/functions/least.rs
@@ -0,0 +1,171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::sql::functions::greatest_least_utils::GreatestLeastOperator;
+use arrow::array::{make_comparator, Array, BooleanArray};
+use arrow::buffer::BooleanBuffer;
+use arrow::compute::kernels::cmp;
+use arrow::compute::SortOptions;
+use arrow::datatypes::DataType;
+use datafusion::common::{internal_err, Result, ScalarValue};
+use datafusion::logical_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL;
+use datafusion::logical_expr::{ColumnarValue, Documentation};
+use datafusion::logical_expr::{ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::OnceLock;
+
+const SORT_OPTIONS: SortOptions = SortOptions {
+    // Having the smallest result first
+    descending: false,
+
+    // NULL will be greater than any other value
+    nulls_first: false,
+};
+
+#[derive(Debug)]
+pub struct LeastFunc {
+    signature: Signature,
+}
+
+impl Default for LeastFunc {
+    fn default() -> Self {
+        LeastFunc::new()
+    }
+}
+
+impl LeastFunc {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+        }
+    }
+}
+
+impl GreatestLeastOperator for LeastFunc {
+    const NAME: &'static str = "least";
+
+    fn keep_scalar<'a>(lhs: &'a ScalarValue, rhs: &'a ScalarValue) -> Result<&'a ScalarValue> {
+        // Manual checking for nulls as:
+        // 1. If we're going to use <=, in Rust None is smaller than Some(T), which we don't want
+        // 2. And we can't use make_comparator as it has no natural order (Arrow error)
+        if lhs.is_null() {
+            return Ok(rhs);
+        }
+
+        if rhs.is_null() {
+            return Ok(lhs);
+        }
+
+        if !lhs.data_type().is_nested() {
+            return if lhs <= rhs { Ok(lhs) } else { Ok(rhs) };
+        }
+
+        // Not using <= as in Rust None is smaller than Some(T)
+
+        // If complex type we can't compare directly as we want null values to be larger
+        let cmp = make_comparator(
+            lhs.to_array()?.as_ref(),
+            rhs.to_array()?.as_ref(),
+            SORT_OPTIONS,
+        )?;
+
+        if cmp(0, 0).is_le() {
+            Ok(lhs)
+        } else {
+            Ok(rhs)
+        }
+    }
+
+    /// Return a boolean array where `arr[i] = lhs[i] <= rhs[i]` for all i, where `arr` is the result array
+    /// Nulls are always considered larger than any other value
+    fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
+        // Fast path:
+        // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorized kernel
+        // - If both arrays are not nested: nested types, such as lists, are not supported as the null semantics are not well-defined.
+        // - both arrays do not have any nulls: cmp::lt_eq will return null if any of the input is null, while we want to return false in that case
+        if !lhs.data_type().is_nested()
+            && lhs.logical_null_count() == 0
+            && rhs.logical_null_count() == 0
+        {
+            return cmp::lt_eq(&lhs, &rhs).map_err(|e| e.into());
+        }
+
+        let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?;
+
+        if lhs.len() != rhs.len() {
+            return internal_err!("All arrays should have the same length for least comparison");
+        }
+
+        let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_le());
+
+        // No nulls as we only want to keep the values that are smaller, it's either true or false
+        Ok(BooleanArray::new(values, None))
+    }
+}
+
+impl ScalarUDFImpl for LeastFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "least"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        super::greatest_least_utils::execute_conditional::<Self>(args)
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let coerced_type = super::greatest_least_utils::find_coerced_type::<Self>(arg_types)?;
+
+        Ok(vec![coerced_type; arg_types.len()])
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_smallest_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_smallest_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_CONDITIONAL)
+            .with_sql_example(
+                r#"```sql
+> select least(4, 7, 5);
++---------------------------+
+| least(4,7,5)              |
++---------------------------+
+| 4                         |
++---------------------------+
+```"#,
+            )
+            .with_argument(
+                "expression1, expression_n",
+                "Expressions to compare and return the smallest value. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary.",
+            )
+            .build()
+            .unwrap()
+    })
+}
diff --git a/crates/control_plane/src/sql/functions/mod.rs b/crates/control_plane/src/sql/functions/mod.rs
index a0a341a6a..52bbead77 100644
--- a/crates/control_plane/src/sql/functions/mod.rs
+++ b/crates/control_plane/src/sql/functions/mod.rs
@@ -1,2 +1,7 @@
+pub mod common;
 pub mod parse_json;
-pub mod common;
\ No newline at end of file
+
+pub mod date_add;
+pub mod greatest;
+pub mod greatest_least_utils;
+pub mod least;
diff --git a/crates/control_plane/src/sql/planner.rs b/crates/control_plane/src/sql/planner.rs
index 6adddf5c8..aa49b3fea 100644
--- a/crates/control_plane/src/sql/planner.rs
+++ b/crates/control_plane/src/sql/planner.rs
@@ -101,8 +101,6 @@ where
         all_constraints.extend(inline_constraints);
         // Build column default values
         let column_defaults = self.build_column_defaults(&columns, planner_context)?;
-        println!("column_defaults: {:?}", column_defaults);
-        println!("statement 11: {:?}", statement);
         let has_columns = !columns.is_empty();
         let schema = self.build_schema(columns.clone())?.to_dfschema_ref()?;
         if has_columns {
diff --git a/crates/control_plane/src/sql/sql.rs b/crates/control_plane/src/sql/sql.rs
index ba0d0920a..744ae87c2 100644
--- a/crates/control_plane/src/sql/sql.rs
+++ b/crates/control_plane/src/sql/sql.rs
@@ -1,5 +1,8 @@
 use crate::models::created_entity_response;
 use crate::sql::context::CustomContextProvider;
+use crate::sql::functions::date_add::DateAddFunc;
+use crate::sql::functions::greatest::GreatestFunc;
+use crate::sql::functions::least::LeastFunc;
 use crate::sql::functions::parse_json::ParseJsonFunc;
 use crate::sql::planner::ExtendedSqlToRel;
 use arrow::array::RecordBatch;
@@ -35,6 +38,9 @@ pub struct SqlExecutor {
 impl SqlExecutor {
     pub fn new(mut ctx: SessionContext) -> Self {
         ctx.register_udf(ScalarUDF::from(ParseJsonFunc::new()));
+        ctx.register_udf(ScalarUDF::from(DateAddFunc::new()));
+        ctx.register_udf(ScalarUDF::from(LeastFunc::new()));
+        ctx.register_udf(ScalarUDF::from(GreatestFunc::new()));
         register_all(&mut ctx).expect("Cannot register UDF JSON funcs");
         Self { ctx }
     }
@@ -75,9 +81,13 @@ impl SqlExecutor {
     pub fn preprocess_query(&self, query: &String) -> String {
         // Replace field[0].subfield -> json_get(json_get(field, 0), 'subfield')
         let re = regex::Regex::new(r"(\w+)\[(\d+)]\.(\w+)").unwrap();
+        let date_add = regex::Regex::new(r"(date|time|timestamp)(_?add)\(([a-zA-Z]+),").unwrap();
         let query = re
             .replace_all(query, "json_get(json_get($1, $2), '$3')")
             .to_string();
+        let query = date_add
+            .replace_all(&query, "$1$2('$3',")
+            .to_string();
         // TODO implement alter session logic
         query.replace(
             "alter session set query_tag = 'snowplow_dbt'",
@@ -323,7 +333,6 @@
     ) -> Result<Vec<RecordBatch>> {
         let plan = self.get_custom_logical_plan(query, warehouse_name).await?;
        let res = self.ctx.execute_logical_plan(plan).await?.collect().await;
-        println!("Result: {:?}", res);
         res
     }
diff --git a/crates/nexus/src/http/dbt/handlers.rs b/crates/nexus/src/http/dbt/handlers.rs
index db4b6a85f..512355df6 100644
--- a/crates/nexus/src/http/dbt/handlers.rs
+++ b/crates/nexus/src/http/dbt/handlers.rs
@@ -129,7 +129,7 @@ pub async fn query(
             row_type: columns.into_iter().map(|c| c.into()).collect(),
             // row_set_base_64: Option::from(result.clone()),
             row_set_base_64: None,
-            row_set: serde_json::from_str(&*result).unwrap(),
+            row_set: ResponseData::rows_to_vec(result),
             total: Some(1),
             query_result_format: Option::from("json".to_string()),
             error_code: None,
diff --git a/crates/nexus/src/http/dbt/schemas.rs b/crates/nexus/src/http/dbt/schemas.rs
index 2eaa1a74d..00226ff08 100644
--- a/crates/nexus/src/http/dbt/schemas.rs
+++ b/crates/nexus/src/http/dbt/schemas.rs
@@ -93,7 +93,7 @@ pub struct ResponseData {
     #[serde(rename = "rowsetBase64")]
     pub row_set_base_64: Option<String>,
     #[serde(rename = "rowset")]
-    pub row_set: Vec<Vec<String>>,
+    pub row_set: Vec<Vec<serde_json::Value>>,
     pub total: Option<i64>,
     #[serde(rename = "queryResultFormat")]
     pub query_result_format: Option<String>,
@@ -103,6 +103,15 @@
     pub sql_state: Option<String>,
 }
 
+impl ResponseData {
+    pub fn rows_to_vec(json_rows_string: String) -> Vec<Vec<serde_json::Value>> {
+        let json_array: Vec<serde_json::Map<String, serde_json::Value>> =
+            serde_json::from_str(&json_rows_string).unwrap();
+        json_array
+            .into_iter()
+            .map(|obj| obj.values().cloned().collect())
+            .collect()
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 pub struct JsonResponse {
     pub data: Option<ResponseData>,
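
Two standalone sketches follow; they are illustrations for review, not part of the patch.

First, the "epoch.fraction" string encoding that the new `convert_timestamp_to_struct` emits, shown here for the millisecond unit. The fractional part must be zero-padded to six digits: without padding, 42 microseconds would serialize as "….42" and be indistinguishable from 420 milliseconds. This assumes only the `chrono` crate, which the patch already imports; `encode_millis` is a hypothetical helper name.

```rust
use chrono::DateTime;

// Mirrors the TimeUnit::Millisecond arm of convert_timestamp_to_struct:
// whole seconds, a dot, then the sub-second part as zero-padded microseconds.
fn encode_millis(ms: i64) -> String {
    // from_timestamp_millis returns None for out-of-range input; the patch
    // itself unwraps, so this sketch does too.
    let ts = DateTime::from_timestamp_millis(ms).unwrap();
    format!("{}.{:06}", ts.timestamp(), ts.timestamp_subsec_micros())
}

fn main() {
    assert_eq!(encode_millis(1_700_000_000_042), "1700000000.042000");
    assert_eq!(encode_millis(1_700_000_000_500), "1700000000.500000");
}
```

Second, what the new `date_add` regex in `preprocess_query` rewrites: Snowflake SQL passes the unit as a bare keyword (`dateadd(day, …)`), while `DateAddFunc` expects a `Utf8` literal as its first argument, so the regex wraps the unit in quotes before planning. A minimal check with the `regex` crate the executor already uses:

```rust
fn main() {
    // Same pattern and replacement as preprocess_query in sql.rs.
    let date_add = regex::Regex::new(r"(date|time|timestamp)(_?add)\(([a-zA-Z]+),").unwrap();
    let sql = "select dateadd(day, 30, created_at), timestampadd(hour, -2, ts) from t";
    assert_eq!(
        date_add.replace_all(sql, "$1$2('$3',"),
        "select dateadd('day', 30, created_at), timestampadd('hour', -2, ts) from t"
    );
}
```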