diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index b17cfa1d9b..d353471967 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -43,6 +43,7 @@ use datafusion_comet_proto::spark_operator::Operator; use datafusion_spark::function::bitwise::bit_get::SparkBitGet; use datafusion_spark::function::datetime::date_add::SparkDateAdd; use datafusion_spark::function::datetime::date_sub::SparkDateSub; +use datafusion_spark::function::hash::sha1::SparkSha1; use datafusion_spark::function::hash::sha2::SparkSha2; use datafusion_spark::function::math::expm1::SparkExpm1; use datafusion_spark::function::string::char::CharFunc; @@ -332,6 +333,7 @@ fn prepare_datafusion_session_context( session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha1::default())); // Must be the last one to override existing functions with the same name datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 233261091b..9f418e3068 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -156,7 +156,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Md5] -> CometScalarFunction("md5"), classOf[Murmur3Hash] -> CometMurmur3Hash, classOf[Sha2] -> CometSha2, - classOf[XxHash64] -> CometXxHash64) + classOf[XxHash64] -> CometXxHash64, + classOf[Sha1] -> CometSha1) private val stringExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[Ascii] -> CometScalarFunction("ascii"), diff --git a/spark/src/main/scala/org/apache/comet/serde/hash.scala b/spark/src/main/scala/org/apache/comet/serde/hash.scala index a384ecf44e..5cc689f39a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/hash.scala +++ b/spark/src/main/scala/org/apache/comet/serde/hash.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, Sha2, XxHash64} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, Sha1, Sha2, XxHash64} import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType, StringType} import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -89,6 +89,20 @@ object CometSha2 extends CometExpressionSerde[Sha2] { } } +object CometSha1 extends CometExpressionSerde[Sha1] { + override def convert( + expr: Sha1, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + if (!HashUtils.isSupportedType(expr)) { + withInfo(expr, s"HashUtils doesn't support dataType: ${expr.child.dataType}") + return None + } + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + scalarFunctionExprToProtoWithReturnType("sha1", StringType, false, childExpr) + } +} + private object HashUtils { def isSupportedType(expr: Expression): Boolean = { for (child <- expr.children) { diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 9085c0fa29..1eca17dccc 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2024,7 +2024,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { |md5(col), md5(cast(a as string)), md5(cast(b as string)), |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), - |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1) + |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), + |sha1(col), sha1(cast(a as string)), sha1(cast(b as string)) |from test |""".stripMargin) } @@ -2136,7 +2137,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { |md5(col), md5(cast(a as string)), --md5(cast(b as string)), |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), - |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1) + |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), + |sha1(col), sha1(cast(a as string)) |from test |""".stripMargin) }