diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java
index e58cddc274c5f..65f2bd88940a3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java
@@ -136,6 +136,12 @@
*
Since version: 3.3.0
*
*
+ * Name: BOOLEAN_EXPRESSION
+ *
+ * - A simple wrapper for any expression that returns boolean type.
+ * - Since version: 4.1.0
+ *
+ *
*
*
* @since 3.3.0
@@ -145,5 +151,8 @@ public class Predicate extends GeneralScalarExpression {
public Predicate(String name, Expression[] children) {
super(name, children);
+ if ("BOOLEAN_EXPRESSION".equals(name)) {
+ assert children.length == 1;
+ }
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 4298f31227500..2bc994acaf33f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -99,6 +99,8 @@ public String build(Expression expr) {
case "CONTAINS" -> visitContains(build(e.children()[0]), build(e.children()[1]));
case "=", "<>", "<=>", "<", "<=", ">", ">=" ->
visitBinaryComparison(name, e.children()[0], e.children()[1]);
+ case "BOOLEAN_EXPRESSION" ->
+ build(expr.children()[0]);
case "+", "*", "/", "%", "&", "|", "^" ->
visitBinaryArithmetic(name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1]));
case "-" -> {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index 7cc03f3ac3fa6..53a5c44733986 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -243,6 +243,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
case "ENDS_WITH" => convertBinaryExpr(expr, EndsWith)
case "CONTAINS" => convertBinaryExpr(expr, Contains)
case "IN" => convertExpr(expr, children => In(children.head, children.tail))
+ case "BOOLEAN_EXPRESSION" => toCatalyst(expr.children().head)
case _ => None
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index fad73a6d81464..4e6c7ef417801 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -37,12 +37,21 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
def build(): Option[V2Expression] = generateExpression(e, isPredicate)
def buildPredicate(): Option[V2Predicate] = {
-
if (isPredicate) {
- val translated = build()
+ val translated0 = build()
+ val conf = SQLConf.get
+ val alwaysCreateV2Predicate = conf.getConf(SQLConf.DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE)
+ val translated = if (alwaysCreateV2Predicate && e.dataType == BooleanType) {
+ translated0.map {
+ case p: V2Predicate => p
+ case other => new V2Predicate("BOOLEAN_EXPRESSION", Array(other))
+ }
+ } else {
+ translated0
+ }
val modifiedExprOpt = if (
- SQLConf.get.getConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE)
+ conf.getConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE)
&& translated.isDefined
&& !translated.get.isInstanceOf[V2Predicate]) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 0fe08fc719a8e..caeff591ea924 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1695,9 +1695,19 @@ object SQLConf {
buildConf("spark.sql.dataSource.skipAssertOnPredicatePushdown")
.internal()
.doc("Enable skipping assert when expression in not translated to predicate.")
+ .version("4.0.0")
.booleanConf
.createWithDefault(!Utils.isTesting)
+ val DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE =
+ buildConf("spark.sql.dataSource.alwaysCreateV2Predicate")
+ .internal()
+ .doc("When true, the v2 push-down framework always wraps the expression that returns " +
+ "boolean type with a v2 Predicate so that it can be pushed down.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
+
// This is used to set the default data source
val DEFAULT_DATA_SOURCE_NAME = buildConf("spark.sql.sources.default")
.doc("The default data source to use in input/output.")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/PushablePredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/PushablePredicateSuite.scala
index 8b99e3aa6981a..2d2c3ae110bf2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/PushablePredicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/PushablePredicateSuite.scala
@@ -18,38 +18,71 @@
package org.apache.spark.sql.connector
import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
+import org.apache.spark.sql.connector.expressions.filter.{AlwaysTrue, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.v2.PushablePredicate
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.BooleanType
class PushablePredicateSuite extends QueryTest with SharedSparkSession {
- test("PushablePredicate None returned - flag on") {
- withSQLConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> "true") {
- val pushable = PushablePredicate.unapply(Literal.create("string"))
- assert(!pushable.isDefined)
- }
- }
-
- test("PushablePredicate success - flag on") {
- withSQLConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> "true") {
- val pushable = PushablePredicate.unapply(Literal.create(true))
- assert(pushable.isDefined)
+ test("simple boolean expression should always return v2 Predicate") {
+ Seq(true, false).foreach { createV2Predicate =>
+ Seq(true, false).foreach { noAssert =>
+ withSQLConf(
+ SQLConf.DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE.key -> createV2Predicate.toString,
+ SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> noAssert.toString) {
+ val pushable = PushablePredicate.unapply(Literal.create(true))
+ assert(pushable.isDefined)
+ assert(pushable.get.isInstanceOf[AlwaysTrue])
+ }
+ }
}
}
- test("PushablePredicate success") {
- withSQLConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> "false") {
- val pushable = PushablePredicate.unapply(Literal.create(true))
- assert(pushable.isDefined)
+ test("non-boolean expression") {
+ Seq(true, false).foreach { createV2Predicate =>
+ Seq(true, false).foreach { noAssert =>
+ withSQLConf(
+ SQLConf.DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE.key -> createV2Predicate.toString,
+ SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> noAssert.toString) {
+ val catalystExpr = Literal.create("string")
+ if (noAssert) {
+ val pushable = PushablePredicate.unapply(catalystExpr)
+ assert(pushable.isEmpty)
+ } else {
+ intercept[java.lang.AssertionError] {
+ PushablePredicate.unapply(catalystExpr)
+ }
+ }
+ }
+ }
}
}
- test("PushablePredicate throws") {
- withSQLConf(SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> "false") {
- intercept[java.lang.AssertionError] {
- PushablePredicate.unapply(Literal.create("string"))
+ test("non-trivial boolean expression") {
+ Seq(true, false).foreach { createV2Predicate =>
+ Seq(true, false).foreach { noAssert =>
+ withSQLConf(
+ SQLConf.DATA_SOURCE_ALWAYS_CREATE_V2_PREDICATE.key -> createV2Predicate.toString,
+ SQLConf.DATA_SOURCE_DONT_ASSERT_ON_PREDICATE.key -> noAssert.toString) {
+ val catalystExpr = Cast(Literal.create("true"), BooleanType)
+ if (createV2Predicate) {
+ val pushable = PushablePredicate.unapply(catalystExpr)
+ assert(pushable.isDefined)
+ assert(pushable.get.isInstanceOf[V2Predicate])
+ } else {
+ if (noAssert) {
+ val pushable = PushablePredicate.unapply(catalystExpr)
+ assert(pushable.isEmpty)
+ } else {
+ intercept[java.lang.AssertionError] {
+ PushablePredicate.unapply(catalystExpr)
+ }
+ }
+ }
+ }
}
}
}