diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java index e58cddc274c5f..65f2bd88940a3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java @@ -136,6 +136,12 @@ *
  • Since version: 3.3.0
  • * * + *
  • Name: BOOLEAN_EXPRESSION + * + *
  • * * * @since 3.3.0 @@ -145,5 +151,8 @@ public class Predicate extends GeneralScalarExpression { public Predicate(String name, Expression[] children) { super(name, children); + if ("BOOLEAN_EXPRESSION".equals(name)) { + assert children.length == 1; + } } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 4298f31227500..2bc994acaf33f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -99,6 +99,8 @@ public String build(Expression expr) { case "CONTAINS" -> visitContains(build(e.children()[0]), build(e.children()[1])); case "=", "<>", "<=>", "<", "<=", ">", ">=" -> visitBinaryComparison(name, e.children()[0], e.children()[1]); + case "BOOLEAN_EXPRESSION" -> + build(expr.children()[0]); case "+", "*", "/", "%", "&", "|", "^" -> visitBinaryArithmetic(name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); case "-" -> { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index 7cc03f3ac3fa6..53a5c44733986 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -243,6 +243,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { case "ENDS_WITH" => convertBinaryExpr(expr, EndsWith) case "CONTAINS" => convertBinaryExpr(expr, Contains) case "IN" => convertExpr(expr, children => In(children.head, children.tail)) + case "BOOLEAN_EXPRESSION" => toCatalyst(expr.children().head) case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index fad73a6d81464..4e6c7ef417801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -37,12 +37,21 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L def build(): Option[V2Expression] = generateExpression(e, isPredicate) def buildPredicate(): Option[V2Predicate] = { - if (isPredicate) { - val translated = build() + val translated0 = build() + val conf = SQLConf.get + val alwaysCreateV2Predicate = conf.getConf(SQLConf.DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE) + val translated = if (alwaysCreateV2Predicate && e.dataType == BooleanType) { + translated0.map { + case p: V2Predicate => p + case other => new V2Predicate("BOOLEAN_EXPRESSION", Array(other)) + } + } else { + translated0 + } val modifiedExprOpt = if ( - SQLConf.get.getConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE) + conf.getConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE) && translated.isDefined && !translated.get.isInstanceOf[V2Predicate]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0fe08fc719a8e..caeff591ea924 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1695,9 +1695,19 @@ object SQLConf { buildConf("spark.sql.dataSource.skipAssertOnPredicatePushdown") .internal() .doc("Enable skipping assert when expression in not translated to predicate.") + .version("4.0.0") .booleanConf .createWithDefault(!Utils.isTesting) + val DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE = + buildConf("spark.sql.dataSource.alwaysCreateV2Predicate") + .internal() + .doc("When true, the v2 push-down framework always wraps the expression that returns " + + "boolean type with a v2 Predicate so that it can be pushed down.") + .version("4.1.0") + .booleanConf + .createWithDefault(true) + // This is used to set the default data source val DEFAULT_DATA_SOURCE_NAME = buildConf("spark.sql.sources.default") .doc("The default data source to use in input/output.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/PushablePredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/PushablePredicateSuite.scala index 8b99e3aa6981a..2d2c3ae110bf2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/PushablePredicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/PushablePredicateSuite.scala @@ -18,38 +18,71 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.connector.expressions.filter.{AlwaysTrue, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.v2.PushablePredicate import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.BooleanType class PushablePredicateSuite extends QueryTest with SharedSparkSession { - test("PushablePredicate None returned - flag on") { - withSQLConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> "true") { - val pushable = PushablePredicate.unapply(Literal.create("string")) - assert(!pushable.isDefined) - } - } - - test("PushablePredicate success - flag on") { - withSQLConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> "true") { - val pushable = PushablePredicate.unapply(Literal.create(true)) - assert(pushable.isDefined) + test("simple boolean expression should always return v2 Predicate") { + Seq(true, false).foreach { createV2Predicate => + Seq(true, false).foreach { noAssert => + withSQLConf( + SQLConf.DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE.key -> createV2Predicate.toString, + SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> noAssert.toString) { + val pushable = PushablePredicate.unapply(Literal.create(true)) + assert(pushable.isDefined) + assert(pushable.get.isInstanceOf[AlwaysTrue]) + } + } } } - test("PushablePredicate success") { - withSQLConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> "false") { - val pushable = PushablePredicate.unapply(Literal.create(true)) - assert(pushable.isDefined) + test("non-boolean expression") { + Seq(true, false).foreach { createV2Predicate => + Seq(true, false).foreach { noAssert => + withSQLConf( + SQLConf.DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE.key -> createV2Predicate.toString, + SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> noAssert.toString) { + val catalystExpr = Literal.create("string") + if (noAssert) { + val pushable = PushablePredicate.unapply(catalystExpr) + assert(pushable.isEmpty) + } else { + intercept[java.lang.AssertionError] { + PushablePredicate.unapply(catalystExpr) + } + } + } + } } } - test("PushablePredicate throws") { - withSQLConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> "false") { - intercept[java.lang.AssertionError] { - PushablePredicate.unapply(Literal.create("string")) + test("non-trivial boolean expression") { + Seq(true, false).foreach { createV2Predicate => + Seq(true, false).foreach { noAssert => + withSQLConf( + SQLConf.DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE.key -> createV2Predicate.toString, + SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> noAssert.toString) { + val catalystExpr = Cast(Literal.create("true"), BooleanType) + if (createV2Predicate) { + val pushable = PushablePredicate.unapply(catalystExpr) + assert(pushable.isDefined) + assert(pushable.get.isInstanceOf[V2Predicate]) + } else { + if (noAssert) { + val pushable = PushablePredicate.unapply(catalystExpr) + assert(pushable.isEmpty) + } else { + intercept[java.lang.AssertionError] { + PushablePredicate.unapply(catalystExpr) + } + } + } + } } } }