diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 92a719b3bc47..0b75d362f9cf 100644
--- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -35,7 +35,7 @@ case class PaimonScan(
filters: Seq[Predicate],
reservedFilters: Seq[Filter],
override val pushDownLimit: Option[Int],
- disableBucketedScan: Boolean = false)
+ disableBucketedScan: Boolean = true)
extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
with SupportsRuntimeFiltering {
@@ -57,11 +57,9 @@ case class PaimonScan(
case _ => None
}
if (partitionFilter.nonEmpty) {
- this.runtimeFilters = filters
readBuilder.withFilter(partitionFilter.head)
// set inputPartitions null to trigger to get the new splits.
inputPartitions = null
}
}
-
}
diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
index 10b83ccf08b1..395f8707ab9d 100644
--- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
+++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
@@ -18,6 +18,61 @@
package org.apache.paimon.spark
+import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate}
import org.apache.paimon.table.Table
-class PaimonScanBuilder(table: Table) extends PaimonBaseScanBuilder(table)
+import org.apache.spark.sql.connector.read.SupportsPushDownFilters
+import org.apache.spark.sql.sources.Filter
+
+import scala.collection.mutable
+
+class PaimonScanBuilder(table: Table)
+ extends PaimonBaseScanBuilder(table)
+ with SupportsPushDownFilters {
+
+ private var pushedSparkFilters = Array.empty[Filter]
+
+ /**
+ * Pushes down filters, and returns filters that need to be evaluated after scanning. Rows
+ * should be returned from the data source if and only if all the filters match. That is, filters
+ * must be interpreted as ANDed together.
+ */
+ override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ val pushable = mutable.ArrayBuffer.empty[(Filter, Predicate)]
+ val postScan = mutable.ArrayBuffer.empty[Filter]
+ val reserved = mutable.ArrayBuffer.empty[Filter]
+
+ val converter = new SparkFilterConverter(table.rowType)
+ val visitor = new PartitionPredicateVisitor(table.partitionKeys())
+ filters.foreach {
+ filter =>
+ val predicate = converter.convertIgnoreFailure(filter)
+ if (predicate == null) {
+ postScan.append(filter)
+ } else {
+ pushable.append((filter, predicate))
+ if (predicate.visit(visitor)) {
+ reserved.append(filter)
+ } else {
+ postScan.append(filter)
+ }
+ }
+ }
+
+ if (pushable.nonEmpty) {
+ this.pushedSparkFilters = pushable.map(_._1).toArray
+ this.pushedPaimonPredicates = pushable.map(_._2).toArray
+ }
+ if (reserved.nonEmpty) {
+ this.reservedFilters = reserved.toArray
+ }
+ if (postScan.nonEmpty) {
+ this.hasPostScanPredicates = true
+ }
+ postScan.toArray
+ }
+
+ override def pushedFilters(): Array[Filter] = {
+ pushedSparkFilters
+ }
+}
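
The Spark 3.2 builder above sorts each V1 Filter into three buckets: filters convertible to Paimon predicates are pushed down, the convertible subset that touches only partition columns is additionally kept as reserved filters for statistics-based pruning, and everything else is returned to Spark for post-scan evaluation. Below is a minimal, self-contained sketch of that classification; the `convert` function and the names in it are illustrative stand-ins for SparkFilterConverter and PartitionPredicateVisitor, not the actual implementation.

import org.apache.spark.sql.sources.{EqualTo, Filter, GreaterThan, IsNotNull}

object PushDownClassificationSketch {

  // Stand-in for SparkFilterConverter.convertIgnoreFailure: returns None when a
  // filter cannot be expressed as a Paimon predicate (purely illustrative).
  private def convert(filter: Filter): Option[String] = filter match {
    case EqualTo(attr, value) => Some(s"$attr = $value")
    case GreaterThan(attr, value) => Some(s"$attr > $value")
    case _ => None
  }

  /** Returns (pushed, reserved, postScan), mirroring the split done by pushFilters above. */
  def classify(
      filters: Seq[Filter],
      partitionKeys: Set[String]): (Seq[Filter], Seq[Filter], Seq[Filter]) = {
    val (pushable, notConvertible) = filters.partition(f => convert(f).isDefined)
    // Reserved filters reference partition columns only; the scan keeps them so it can
    // prune partitions/files from statistics.
    val reserved = pushable.filter(_.references.forall(partitionKeys.contains))
    // Pushed filters that are not partition-only are still re-evaluated by Spark after the scan.
    val postScan = notConvertible ++ pushable.filterNot(reserved.contains)
    (pushable, reserved, postScan)
  }

  def main(args: Array[String]): Unit = {
    val filters: Seq[Filter] =
      Seq(EqualTo("dt", "2025-01-15"), GreaterThan("id", 10), IsNotNull("name"))
    val (pushed, reserved, postScan) = classify(filters, partitionKeys = Set("dt"))
    println(s"pushed=$pushed, reserved=$reserved, postScan=$postScan")
  }
}
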
diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonSplitScanBuilder.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonSplitScanBuilder.scala
new file mode 100644
index 000000000000..ede18b8cc990
--- /dev/null
+++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonSplitScanBuilder.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark
+
+import org.apache.paimon.table.KnownSplitsTable
+
+import org.apache.spark.sql.connector.read.Scan
+
+class PaimonSplitScanBuilder(table: KnownSplitsTable) extends PaimonScanBuilder(table) {
+ override def build(): Scan = {
+ PaimonSplitScan(table, table.splits(), requiredSchema, pushedPaimonPredicates)
+ }
+}
diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala
new file mode 100644
index 000000000000..56223c36cd75
--- /dev/null
+++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.catalyst.analysis.expressions
+
+import org.apache.paimon.predicate.{Predicate, PredicateBuilder}
+import org.apache.paimon.spark.SparkFilterConverter
+import org.apache.paimon.types.RowType
+
+import org.apache.spark.sql.PaimonUtils.{normalizeExprs, translateFilter}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+
+trait ExpressionHelper extends ExpressionHelperBase {
+
+ def convertConditionToPaimonPredicate(
+ condition: Expression,
+ output: Seq[Attribute],
+ rowType: RowType,
+ ignorePartialFailure: Boolean = false): Option[Predicate] = {
+ val converter = new SparkFilterConverter(rowType)
+ val filters = normalizeExprs(Seq(condition), output)
+ .flatMap(splitConjunctivePredicates(_).flatMap {
+ f =>
+ val filter = translateFilter(f, supportNestedPredicatePushdown = true)
+ if (filter.isEmpty && !ignorePartialFailure) {
+ throw new RuntimeException(
+ "Exec update failed:" +
+ s" cannot translate expression to source filter: $f")
+ }
+ filter
+ })
+
+ val predicates = filters.map(converter.convert(_, ignorePartialFailure)).filter(_ != null)
+ if (predicates.isEmpty) {
+ None
+ } else {
+ Some(PredicateBuilder.and(predicates: _*))
+ }
+ }
+}
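
convertConditionToPaimonPredicate normalizes the Catalyst condition, splits it into conjuncts, translates each conjunct into a V1 source filter, and ANDs whatever converts into a single Paimon predicate; with ignorePartialFailure a conjunct that cannot be translated is silently dropped, otherwise it raises. The sketch below illustrates only that partial-failure behaviour with stand-in string types; `translate` is hypothetical, and the real path goes through translateFilter and SparkFilterConverter.

import scala.util.Try

object ConvertConditionSketch {
  // Stand-ins for the Catalyst/Paimon types used by convertConditionToPaimonPredicate.
  type Conjunct = String
  type PaimonPredicate = String

  // Hypothetical per-conjunct translation that may fail.
  private def translate(c: Conjunct): Option[PaimonPredicate] =
    if (c.contains("udf")) None else Some(c)

  def convert(conjuncts: Seq[Conjunct], ignorePartialFailure: Boolean): Option[PaimonPredicate] = {
    val converted = conjuncts.flatMap {
      c =>
        val p = translate(c)
        if (p.isEmpty && !ignorePartialFailure) {
          throw new RuntimeException(s"cannot translate expression to source filter: $c")
        }
        p
    }
    // AND the surviving predicates, mirroring PredicateBuilder.and(...).
    if (converted.isEmpty) None else Some(converted.mkString(" AND "))
  }

  def main(args: Array[String]): Unit = {
    println(convert(Seq("dt = '2025-01-15'", "udf(name) = 1"), ignorePartialFailure = true))
    println(Try(convert(Seq("udf(name) = 1"), ignorePartialFailure = false)))
  }
}
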
diff --git a/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
new file mode 100644
index 000000000000..e0705b761ab9
--- /dev/null
+++ b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+
+class PaimonPushDownTest extends PaimonPushDownTestBase {
+
+ override def checkFilterExists(sql: String): Boolean = {
+ spark
+ .sql(sql)
+ .queryExecution
+ .optimizedPlan
+ .find {
+ case Filter(_: Expression, _) => true
+ case _ => false
+ }
+ .isDefined
+ }
+
+ override def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = {
+ spark
+ .sql(sql)
+ .queryExecution
+ .optimizedPlan
+ .find {
+ case Filter(c: Expression, _) =>
+ c.find {
+ case EqualTo(a: AttributeReference, r: Literal) =>
+ a.name.equals(name) && r.equals(value)
+ case _ => false
+ }.isDefined
+ case _ => false
+ }
+ .isDefined
+ }
+}
diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
new file mode 100644
index 000000000000..4c62d58a811b
--- /dev/null
+++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark
+
+import org.apache.paimon.predicate.Predicate
+import org.apache.paimon.table.{BucketMode, FileStoreTable, Table}
+import org.apache.paimon.table.source.{DataSplit, Split}
+
+import org.apache.spark.sql.PaimonUtils.fieldReference
+import org.apache.spark.sql.connector.expressions._
+import org.apache.spark.sql.connector.read.{SupportsReportPartitioning, SupportsRuntimeFiltering}
+import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.sources.{Filter, In}
+import org.apache.spark.sql.types.StructType
+
+import scala.collection.JavaConverters._
+
+case class PaimonScan(
+ table: Table,
+ requiredSchema: StructType,
+ filters: Seq[Predicate],
+ reservedFilters: Seq[Filter],
+ override val pushDownLimit: Option[Int],
+ bucketedScanDisabled: Boolean = false)
+ extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
+ with SupportsRuntimeFiltering
+ with SupportsReportPartitioning {
+
+ def disableBucketedScan(): PaimonScan = {
+ copy(bucketedScanDisabled = true)
+ }
+
+ @transient
+ private lazy val extractBucketTransform: Option[Transform] = {
+ table match {
+ case fileStoreTable: FileStoreTable =>
+ val bucketSpec = fileStoreTable.bucketSpec()
+ if (bucketSpec.getBucketMode != BucketMode.HASH_FIXED) {
+ None
+ } else if (bucketSpec.getBucketKeys.size() > 1) {
+ None
+ } else {
+ // Spark does not support bucket with several input attributes,
+ // so we only support one bucket key case.
+ assert(bucketSpec.getNumBuckets > 0)
+ assert(bucketSpec.getBucketKeys.size() == 1)
+ val bucketKey = bucketSpec.getBucketKeys.get(0)
+ if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) {
+ Some(Expressions.bucket(bucketSpec.getNumBuckets, bucketKey))
+ } else {
+ None
+ }
+ }
+
+ case _ => None
+ }
+ }
+
+ private def shouldDoBucketedScan: Boolean = {
+ !bucketedScanDisabled && conf.v2BucketingEnabled && extractBucketTransform.isDefined
+ }
+
+ // Since Spark 3.3
+ override def outputPartitioning: Partitioning = {
+ extractBucketTransform
+ .map(bucket => new KeyGroupedPartitioning(Array(bucket), lazyInputPartitions.size))
+ .getOrElse(new UnknownPartitioning(0))
+ }
+
+ override def getInputPartitions(splits: Array[Split]): Seq[PaimonInputPartition] = {
+ if (!shouldDoBucketedScan || splits.exists(!_.isInstanceOf[DataSplit])) {
+ return super.getInputPartitions(splits)
+ }
+
+ splits
+ .map(_.asInstanceOf[DataSplit])
+ .groupBy(_.bucket())
+ .map {
+ case (bucket, groupedSplits) =>
+ PaimonBucketedInputPartition(groupedSplits, bucket)
+ }
+ .toSeq
+ }
+
+ // Since Spark 3.2
+ override def filterAttributes(): Array[NamedReference] = {
+ val requiredFields = readBuilder.readType().getFieldNames.asScala
+ table
+ .partitionKeys()
+ .asScala
+ .toArray
+ .filter(requiredFields.contains)
+ .map(fieldReference)
+ }
+
+ override def filter(filters: Array[Filter]): Unit = {
+ val converter = new SparkFilterConverter(table.rowType())
+ val partitionFilter = filters.flatMap {
+ case in @ In(attr, _) if table.partitionKeys().contains(attr) =>
+ Some(converter.convert(in))
+ case _ => None
+ }
+ if (partitionFilter.nonEmpty) {
+ readBuilder.withFilter(partitionFilter.head)
+ // set inputPartitions null to trigger to get the new splits.
+ inputPartitions = null
+ }
+ }
+}
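
When the table uses fixed hash bucketing with a single bucket key and spark.sql.sources.v2.bucketing.enabled is set, the scan groups DataSplits by bucket so each Spark input partition maps to exactly one Paimon bucket, and outputPartitioning can report KeyGroupedPartitioning to avoid shuffles on the bucket key. A small sketch of that grouping, using a hypothetical Split case class in place of Paimon's DataSplit:

final case class Split(bucket: Int, file: String)
final case class BucketedPartition(bucket: Int, splits: Seq[Split])

object BucketedScanSketch {
  def groupByBucket(splits: Seq[Split]): Seq[BucketedPartition] =
    splits
      .groupBy(_.bucket) // one Spark input partition per Paimon bucket
      .map { case (bucket, grouped) => BucketedPartition(bucket, grouped) }
      .toSeq

  def main(args: Array[String]): Unit = {
    val splits = Seq(Split(0, "f0"), Split(1, "f1"), Split(0, "f2"))
    // Two partitions: bucket 0 -> {f0, f2}, bucket 1 -> {f1}; the scan can then report
    // KeyGroupedPartitioning over bucket(numBuckets, bucketKey).
    groupByBucket(splits).foreach(println)
  }
}
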
diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
new file mode 100644
index 000000000000..43459648095d
--- /dev/null
+++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.connector
+
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference}
+import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
+import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
+import org.apache.spark.sql.types.StringType
+
+// Copy from Spark 3.4+
+private[sql] object PredicateUtils {
+
+ def toV1(predicate: Predicate): Option[Filter] = {
+
+ def isValidBinaryPredicate(): Boolean = {
+ if (
+ predicate.children().length == 2 &&
+ predicate.children()(0).isInstanceOf[NamedReference] &&
+ predicate.children()(1).isInstanceOf[LiteralValue[_]]
+ ) {
+ true
+ } else {
+ false
+ }
+ }
+
+ predicate.name() match {
+ case "IN" if predicate.children()(0).isInstanceOf[NamedReference] =>
+ val attribute = predicate.children()(0).toString
+ val values = predicate.children().drop(1)
+ if (values.length > 0) {
+ if (!values.forall(_.isInstanceOf[LiteralValue[_]])) return None
+ val dataType = values(0).asInstanceOf[LiteralValue[_]].dataType
+ if (!values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) {
+ return None
+ }
+ val inValues = values.map(
+ v =>
+ CatalystTypeConverters.convertToScala(
+ v.asInstanceOf[LiteralValue[_]].value,
+ dataType))
+ Some(In(attribute, inValues))
+ } else {
+ Some(In(attribute, Array.empty[Any]))
+ }
+
+ case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate() =>
+ val attribute = predicate.children()(0).toString
+ val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
+ val v1Value = CatalystTypeConverters.convertToScala(value.value, value.dataType)
+ val v1Filter = predicate.name() match {
+ case "=" => EqualTo(attribute, v1Value)
+ case "<=>" => EqualNullSafe(attribute, v1Value)
+ case ">" => GreaterThan(attribute, v1Value)
+ case ">=" => GreaterThanOrEqual(attribute, v1Value)
+ case "<" => LessThan(attribute, v1Value)
+ case "<=" => LessThanOrEqual(attribute, v1Value)
+ }
+ Some(v1Filter)
+
+ case "IS_NULL" | "IS_NOT_NULL"
+ if predicate.children().length == 1 &&
+ predicate.children()(0).isInstanceOf[NamedReference] =>
+ val attribute = predicate.children()(0).toString
+ val v1Filter = predicate.name() match {
+ case "IS_NULL" => IsNull(attribute)
+ case "IS_NOT_NULL" => IsNotNull(attribute)
+ }
+ Some(v1Filter)
+
+ case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate() =>
+ val attribute = predicate.children()(0).toString
+ val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
+ if (!value.dataType.sameType(StringType)) return None
+ val v1Value = value.value.toString
+ val v1Filter = predicate.name() match {
+ case "STARTS_WITH" =>
+ StringStartsWith(attribute, v1Value)
+ case "ENDS_WITH" =>
+ StringEndsWith(attribute, v1Value)
+ case "CONTAINS" =>
+ StringContains(attribute, v1Value)
+ }
+ Some(v1Filter)
+
+ case "ALWAYS_TRUE" | "ALWAYS_FALSE" if predicate.children().isEmpty =>
+ val v1Filter = predicate.name() match {
+ case "ALWAYS_TRUE" => AlwaysTrue()
+ case "ALWAYS_FALSE" => AlwaysFalse()
+ }
+ Some(v1Filter)
+
+ case "AND" =>
+ val and = predicate.asInstanceOf[V2And]
+ val left = toV1(and.left())
+ val right = toV1(and.right())
+ if (left.nonEmpty && right.nonEmpty) {
+ Some(And(left.get, right.get))
+ } else {
+ None
+ }
+
+ case "OR" =>
+ val or = predicate.asInstanceOf[V2Or]
+ val left = toV1(or.left())
+ val right = toV1(or.right())
+ if (left.nonEmpty && right.nonEmpty) {
+ Some(Or(left.get, right.get))
+ } else if (left.nonEmpty) {
+ left
+ } else {
+ right
+ }
+
+ case "NOT" =>
+ val child = toV1(predicate.asInstanceOf[V2Not].child())
+ if (child.nonEmpty) {
+ Some(Not(child.get))
+ } else {
+ None
+ }
+
+ case _ => None
+ }
+ }
+
+ def toV1(predicates: Array[Predicate]): Array[Filter] = {
+ predicates.flatMap(toV1(_))
+ }
+}
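
PredicateUtils is a copy of the Spark 3.4+ helper that maps connector V2 Predicates back to V1 Filters, returning None when a predicate has no V1 equivalent. A hedged usage sketch follows; because the object is private[sql], the example sits under org.apache.spark.sql, and the package and object names are illustrative only.

// Illustrative package choice so the private[sql] object is accessible.
package org.apache.spark.sql.example

import org.apache.spark.sql.connector.expressions.{Expression, Expressions, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{Not, Predicate => V2Predicate}
import org.apache.spark.sql.internal.connector.PredicateUtils
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object PredicateUtilsRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // name LIKE 'pai%' as a V2 predicate ...
    val startsWith = new V2Predicate(
      "STARTS_WITH",
      Array[Expression](
        Expressions.column("name"),
        LiteralValue(UTF8String.fromString("pai"), StringType)))
    // ... and its negation, both mapped back to V1 filters.
    println(PredicateUtils.toV1(startsWith)) // expected: Some(StringStartsWith(name,pai))
    println(PredicateUtils.toV1(new Not(startsWith))) // expected: Some(Not(StringStartsWith(name,pai)))
  }
}
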
diff --git a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
new file mode 100644
index 000000000000..26677d85c71a
--- /dev/null
+++ b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+class PaimonPushDownTest extends PaimonPushDownTestBase {}
diff --git a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
new file mode 100644
index 000000000000..26677d85c71a
--- /dev/null
+++ b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+class PaimonPushDownTest extends PaimonPushDownTestBase {}
diff --git a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
new file mode 100644
index 000000000000..26677d85c71a
--- /dev/null
+++ b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+class PaimonPushDownTest extends PaimonPushDownTestBase {}
diff --git a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
new file mode 100644
index 000000000000..26677d85c71a
--- /dev/null
+++ b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+class PaimonPushDownTest extends PaimonPushDownTestBase {}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
index 2050c937c6a3..c0b8cfd66be1 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
@@ -49,7 +49,7 @@
import static org.apache.paimon.predicate.PredicateBuilder.convertJavaObject;
-/** Conversion from {@link Filter} to {@link Predicate}. */
+/** Conversion from {@link Filter} to {@link Predicate}, remove it when Spark 3.2 is dropped. */
public class SparkFilterConverter {
 public static final List
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala
- /**
- * Pushes down filters, and returns filters that need to be evaluated after scanning. Rows
- * should be returned from the data source if and only if all of the filters match. That is,
- * filters must be interpreted as ANDed together.
- */
- override def pushFilters(filters: Array[Filter]): Array[Filter] = {
- val pushable = mutable.ArrayBuffer.empty[(Filter, Predicate)]
- val postScan = mutable.ArrayBuffer.empty[Filter]
- val partitionFilter = mutable.ArrayBuffer.empty[Filter]
-
- val converter = new SparkFilterConverter(table.rowType)
- val visitor = new PartitionPredicateVisitor(table.partitionKeys())
- filters.foreach {
- filter =>
- val predicate = converter.convertIgnoreFailure(filter)
- if (predicate == null) {
- postScan.append(filter)
- } else {
- pushable.append((filter, predicate))
- if (predicate.visit(visitor)) {
- partitionFilter.append(filter)
- } else {
- postScan.append(filter)
- }
- }
- }
-
- if (pushable.nonEmpty) {
- this.pushedPredicates = pushable.toArray
- }
- if (partitionFilter.nonEmpty) {
- this.partitionFilters = partitionFilter.toArray
- }
- if (postScan.nonEmpty) {
- this.postScanFilters = postScan.toArray
- }
- postScan.toArray
- }
-
- override def pushedFilters(): Array[Filter] = {
- pushedPredicates.map(_._1)
+ PaimonScan(table, requiredSchema, pushedPaimonPredicates, reservedFilters, pushDownLimit)
}
override def pruneColumns(requiredSchema: StructType): Unit = {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala
index 490a1b133f6f..1f4e88e8d160 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala
@@ -18,11 +18,11 @@
package org.apache.paimon.spark
+import org.apache.paimon.predicate.Predicate
import org.apache.paimon.table.Table
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.LocalScan
-import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
/** A scan does not require [[RDD]] to execute */
@@ -30,7 +30,7 @@ case class PaimonLocalScan(
rows: Array[InternalRow],
readSchema: StructType,
table: Table,
- filters: Array[Filter])
+ filters: Array[Predicate])
extends LocalScan {
override def description(): String = {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 2f1e6c53ab0a..f2818f7c032c 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -23,10 +23,11 @@ import org.apache.paimon.table.{BucketMode, FileStoreTable, Table}
import org.apache.paimon.table.source.{DataSplit, Split}
import org.apache.spark.sql.PaimonUtils.fieldReference
-import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference, SortDirection, SortOrder, Transform}
-import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning, SupportsRuntimeFiltering}
+import org.apache.spark.sql.connector.expressions._
+import org.apache.spark.sql.connector.expressions.filter.{Predicate => SparkPredicate}
+import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
-import org.apache.spark.sql.sources.{Filter, In}
+import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import scala.collection.JavaConverters._
@@ -39,7 +40,7 @@ case class PaimonScan(
override val pushDownLimit: Option[Int],
bucketedScanDisabled: Boolean = false)
extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit)
- with SupportsRuntimeFiltering
+ with SupportsRuntimeV2Filtering
with SupportsReportPartitioning
with SupportsReportOrdering {
@@ -156,19 +157,18 @@ case class PaimonScan(
.map(fieldReference)
}
- override def filter(filters: Array[Filter]): Unit = {
- val converter = new SparkFilterConverter(table.rowType())
- val partitionFilter = filters.flatMap {
- case in @ In(attr, _) if table.partitionKeys().contains(attr) =>
- Some(converter.convert(in))
+ override def filter(predicates: Array[SparkPredicate]): Unit = {
+ val converter = SparkV2FilterConverter(table.rowType())
+ val partitionKeys = table.partitionKeys().asScala.toSeq
+ val partitionFilter = predicates.flatMap {
+ case p if SparkV2FilterConverter.isSupportedRuntimeFilter(p, partitionKeys) =>
+ converter.convert(p, ignoreFailure = true)
case _ => None
}
if (partitionFilter.nonEmpty) {
- this.runtimeFilters = filters
- readBuilder.withFilter(partitionFilter.head)
+ readBuilder.withFilter(partitionFilter.toList.asJava)
// set inputPartitions null to trigger to get the new splits.
inputPartitions = null
}
}
-
}
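
With SupportsRuntimeV2Filtering, Spark hands the scan runtime predicates (typically from dynamic partition pruning) over the attributes reported by filterAttributes; only IN predicates on partition keys are converted, pushed into the read builder, and the cached input partitions are cleared so splits are re-planned against the narrowed filter. A minimal sketch of that flow, with hypothetical types standing in for the read builder and split cache:

object RuntimeNarrowingSketch {
  final case class ScanState(filters: List[String], cachedSplits: Option[Seq[String]])

  // Mirrors PaimonScan.filter: keep only a usable partition IN-filter, push it,
  // and drop the cached splits so they get re-planned.
  def applyRuntimeFilter(state: ScanState, partitionInFilter: Option[String]): ScanState =
    partitionInFilter match {
      case Some(f) =>
        // push the extra filter and invalidate cached input partitions
        state.copy(filters = state.filters :+ f, cachedSplits = None)
      case None => state // nothing usable: keep the already planned splits
    }

  def main(args: Array[String]): Unit = {
    val before = ScanState(List("id > 10"), Some(Seq("split-1", "split-2")))
    println(applyRuntimeFilter(before, Some("dt IN ('2025-01-15')")))
  }
}
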
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
index 0393a1cd1578..7b6c65c37f76 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
@@ -18,22 +18,71 @@
package org.apache.paimon.spark
-import org.apache.paimon.predicate.PredicateBuilder
+import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, PredicateBuilder}
import org.apache.paimon.spark.aggregate.LocalAggregator
import org.apache.paimon.table.Table
import org.apache.paimon.table.source.DataSplit
+import org.apache.spark.sql.PaimonUtils
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit}
+import org.apache.spark.sql.connector.expressions.filter.{Predicate => SparkPredicate}
+import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownV2Filters}
+import org.apache.spark.sql.sources.Filter
import scala.collection.JavaConverters._
+import scala.collection.mutable
class PaimonScanBuilder(table: Table)
extends PaimonBaseScanBuilder(table)
+ with SupportsPushDownV2Filters
with SupportsPushDownLimit
with SupportsPushDownAggregates {
+
private var localScan: Option[Scan] = None
+ private var pushedSparkPredicates = Array.empty[SparkPredicate]
+
+ /** Pushes down filters, and returns filters that need to be evaluated after scanning. */
+ override def pushPredicates(predicates: Array[SparkPredicate]): Array[SparkPredicate] = {
+ val pushable = mutable.ArrayBuffer.empty[(SparkPredicate, Predicate)]
+ val postScan = mutable.ArrayBuffer.empty[SparkPredicate]
+ val reserved = mutable.ArrayBuffer.empty[Filter]
+
+ val converter = SparkV2FilterConverter(table.rowType)
+ val visitor = new PartitionPredicateVisitor(table.partitionKeys())
+ predicates.foreach {
+ predicate =>
+ converter.convert(predicate, ignoreFailure = true) match {
+ case Some(paimonPredicate) =>
+ pushable.append((predicate, paimonPredicate))
+ if (paimonPredicate.visit(visitor)) {
+ // We need to filter the stats using filter instead of predicate.
+ reserved.append(PaimonUtils.filterV2ToV1(predicate).get)
+ } else {
+ postScan.append(predicate)
+ }
+ case None =>
+ postScan.append(predicate)
+ }
+ }
+
+ if (pushable.nonEmpty) {
+ this.pushedSparkPredicates = pushable.map(_._1).toArray
+ this.pushedPaimonPredicates = pushable.map(_._2).toArray
+ }
+ if (reserved.nonEmpty) {
+ this.reservedFilters = reserved.toArray
+ }
+ if (postScan.nonEmpty) {
+ this.hasPostScanPredicates = true
+ }
+ postScan.toArray
+ }
+
+ override def pushedPredicates: Array[SparkPredicate] = {
+ pushedSparkPredicates
+ }
+
override def pushLimit(limit: Int): Boolean = {
// It is safe, since we will do nothing if it is the primary table and the split is not `rawConvertible`
pushDownLimit = Some(limit)
@@ -52,8 +101,8 @@ class PaimonScanBuilder(table: Table)
return true
}
- // Only support with push down partition filter
- if (postScanFilters.nonEmpty) {
+ // Only support when there is no post scan predicates.
+ if (hasPostScanPredicates) {
return false
}
@@ -63,8 +112,8 @@ class PaimonScanBuilder(table: Table)
}
val readBuilder = table.newReadBuilder
- if (pushedPredicates.nonEmpty) {
- val pushedPartitionPredicate = PredicateBuilder.and(pushedPredicates.map(_._2): _*)
+ if (pushedPaimonPredicates.nonEmpty) {
+ val pushedPartitionPredicate = PredicateBuilder.and(pushedPaimonPredicates.toList.asJava)
readBuilder.withFilter(pushedPartitionPredicate)
}
val dataSplits = readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
@@ -77,7 +126,7 @@ class PaimonScanBuilder(table: Table)
aggregator.result(),
aggregator.resultSchema(),
table,
- pushedPredicates.map(_._1)))
+ pushedPaimonPredicates))
true
}
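
The V2 builder classifies predicates the same way as the Spark 3.2 V1 builder, but partition-only predicates are translated back to V1 filters via PaimonUtils.filterV2ToV1, because the reserved filters are later evaluated against partition statistics that still operate on V1 Filters; the `.get` assumes any predicate accepted by PartitionPredicateVisitor has a V1 counterpart. A small sketch of that round trip under the same assumption, reusing the filterV2ToV1 helper added below in PaimonUtils:

import org.apache.spark.sql.PaimonUtils
import org.apache.spark.sql.connector.expressions.{Expression, Expressions, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{Predicate => V2Predicate}
import org.apache.spark.sql.types.IntegerType

object ReservedFilterSketch {
  def main(args: Array[String]): Unit = {
    // A simple partition predicate such as `hh = 12` survives the V2 -> V1 round trip,
    // which is what the `.get` in pushPredicates relies on.
    val hhEq12 = new V2Predicate(
      "=",
      Array[Expression](Expressions.column("hh"), LiteralValue(12, IntegerType)))
    println(PaimonUtils.filterV2ToV1(hhEq12)) // expected: Some(EqualTo(hh,12))
  }
}
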
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala
index 8d9e643f9485..7d0bf831550e 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala
@@ -26,9 +26,9 @@ import org.apache.paimon.table.source.{DataSplit, Split}
import org.apache.spark.sql.connector.read.{Batch, Scan}
import org.apache.spark.sql.types.StructType
-class PaimonSplitScanBuilder(table: KnownSplitsTable) extends PaimonBaseScanBuilder(table) {
+class PaimonSplitScanBuilder(table: KnownSplitsTable) extends PaimonScanBuilder(table) {
override def build(): Scan = {
- PaimonSplitScan(table, table.splits(), requiredSchema, pushedPredicates.map(_._2))
+ PaimonSplitScan(table, table.splits(), requiredSchema, pushedPaimonPredicates)
}
}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala
index 11ef302672e1..e99d2aa7794c 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala
@@ -20,9 +20,11 @@ package org.apache.paimon.spark
import org.apache.paimon.data.{BinaryString, Decimal, Timestamp}
import org.apache.paimon.predicate.{Predicate, PredicateBuilder}
+import org.apache.paimon.spark.util.shim.TypeUtils.treatPaimonTimestampTypeAsSparkTimestampType
import org.apache.paimon.types.{DataTypeRoot, DecimalType, RowType}
import org.apache.paimon.types.DataTypeRoot._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.expressions.{Literal, NamedReference}
import org.apache.spark.sql.connector.expressions.filter.{And, Not, Or, Predicate => SparkPredicate}
@@ -35,6 +37,15 @@ case class SparkV2FilterConverter(rowType: RowType) {
val builder = new PredicateBuilder(rowType)
+ def convert(sparkPredicate: SparkPredicate, ignoreFailure: Boolean): Option[Predicate] = {
+ try {
+ Some(convert(sparkPredicate))
+ } catch {
+ case _ if ignoreFailure => None
+ case e: Exception => throw e
+ }
+ }
+
def convert(sparkPredicate: SparkPredicate): Predicate = {
sparkPredicate.name() match {
case EQUAL_TO =>
@@ -147,36 +158,6 @@ case class SparkV2FilterConverter(rowType: RowType) {
}
}
- private object UnaryPredicate {
- def unapply(sparkPredicate: SparkPredicate): Option[String] = {
- sparkPredicate.children() match {
- case Array(n: NamedReference) => Some(toFieldName(n))
- case _ => None
- }
- }
- }
-
- private object BinaryPredicate {
- def unapply(sparkPredicate: SparkPredicate): Option[(String, Any)] = {
- sparkPredicate.children() match {
- case Array(l: NamedReference, r: Literal[_]) => Some((toFieldName(l), r.value))
- case Array(l: Literal[_], r: NamedReference) => Some((toFieldName(r), l.value))
- case _ => None
- }
- }
- }
-
- private object MultiPredicate {
- def unapply(sparkPredicate: SparkPredicate): Option[(String, Array[Any])] = {
- sparkPredicate.children() match {
- case Array(first: NamedReference, rest @ _*)
- if rest.nonEmpty && rest.forall(_.isInstanceOf[Literal[_]]) =>
- Some(toFieldName(first), rest.map(_.asInstanceOf[Literal[_]].value).toArray)
- case _ => None
- }
- }
- }
-
private def fieldIndex(fieldName: String): Int = {
val index = rowType.getFieldIndex(fieldName)
// TODO: support nested field
@@ -205,15 +186,19 @@ case class SparkV2FilterConverter(rowType: RowType) {
value.asInstanceOf[org.apache.spark.sql.types.Decimal].toJavaBigDecimal,
precision,
scale)
- case DataTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE | DataTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE =>
+ case DataTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE =>
Timestamp.fromMicros(value.asInstanceOf[Long])
+ case DataTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE =>
+ if (treatPaimonTimestampTypeAsSparkTimestampType()) {
+ Timestamp.fromSQLTimestamp(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]))
+ } else {
+ Timestamp.fromMicros(value.asInstanceOf[Long])
+ }
case _ =>
throw new UnsupportedOperationException(
s"Convert value: $value to datatype: $dataType is unsupported.")
}
}
-
- private def toFieldName(ref: NamedReference): String = ref.fieldNames().mkString(".")
}
object SparkV2FilterConverter {
@@ -233,4 +218,48 @@ object SparkV2FilterConverter {
private val STRING_START_WITH = "STARTS_WITH"
private val STRING_END_WITH = "ENDS_WITH"
private val STRING_CONTAINS = "CONTAINS"
+
+ private object UnaryPredicate {
+ def unapply(sparkPredicate: SparkPredicate): Option[String] = {
+ sparkPredicate.children() match {
+ case Array(n: NamedReference) => Some(toFieldName(n))
+ case _ => None
+ }
+ }
+ }
+
+ private object BinaryPredicate {
+ def unapply(sparkPredicate: SparkPredicate): Option[(String, Any)] = {
+ sparkPredicate.children() match {
+ case Array(l: NamedReference, r: Literal[_]) => Some((toFieldName(l), r.value))
+ case Array(l: Literal[_], r: NamedReference) => Some((toFieldName(r), l.value))
+ case _ => None
+ }
+ }
+ }
+
+ private object MultiPredicate {
+ def unapply(sparkPredicate: SparkPredicate): Option[(String, Array[Any])] = {
+ sparkPredicate.children() match {
+ case Array(first: NamedReference, rest @ _*)
+ if rest.nonEmpty && rest.forall(_.isInstanceOf[Literal[_]]) =>
+ Some(toFieldName(first), rest.map(_.asInstanceOf[Literal[_]].value).toArray)
+ case _ => None
+ }
+ }
+ }
+
+ private def toFieldName(ref: NamedReference): String = ref.fieldNames().mkString(".")
+
+ def isSupportedRuntimeFilter(
+ sparkPredicate: SparkPredicate,
+ partitionKeys: Seq[String]): Boolean = {
+ sparkPredicate.name() match {
+ case IN =>
+ MultiPredicate.unapply(sparkPredicate) match {
+ case Some((fieldName, _)) => partitionKeys.contains(fieldName)
+ case None => false
+ }
+ case _ => false
+ }
+ }
}
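
isSupportedRuntimeFilter accepts only IN predicates whose first child is a column reference to a partition key; anything else is rejected so runtime filtering never pushes a predicate the reader cannot honor. A short sketch of the expected behaviour, assuming the converter object added above is on the classpath:

import org.apache.paimon.spark.SparkV2FilterConverter
import org.apache.spark.sql.connector.expressions.{Expression, Expressions, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{Predicate => SparkPredicate}
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object RuntimeFilterCheckSketch {
  def main(args: Array[String]): Unit = {
    // dt IN ('2025-01-15') as a Spark V2 predicate.
    val dtIn = new SparkPredicate(
      "IN",
      Array[Expression](
        Expressions.column("dt"),
        LiteralValue(UTF8String.fromString("2025-01-15"), StringType)))
    // IN over a partition column qualifies as a runtime filter ...
    println(SparkV2FilterConverter.isSupportedRuntimeFilter(dtIn, Seq("dt"))) // expected: true
    // ... while the same predicate over a non-partition column does not.
    println(SparkV2FilterConverter.isSupportedRuntimeFilter(dtIn, Seq("other_col"))) // expected: false
  }
}
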
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala
index 2eef2c41aebe..fcece0a26236 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala
@@ -19,12 +19,12 @@
package org.apache.paimon.spark.catalyst.analysis.expressions
import org.apache.paimon.predicate.{Predicate, PredicateBuilder}
-import org.apache.paimon.spark.SparkFilterConverter
+import org.apache.paimon.spark.SparkV2FilterConverter
import org.apache.paimon.spark.catalyst.Compatibility
import org.apache.paimon.types.RowType
import org.apache.spark.sql.{Column, SparkSession}
-import org.apache.spark.sql.PaimonUtils.{normalizeExprs, translateFilter}
+import org.apache.spark.sql.PaimonUtils.{normalizeExprs, translateFilterV2}
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, Cast, Expression, GetStructField, Literal, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
@@ -33,7 +33,40 @@ import org.apache.spark.sql.paimon.shims.SparkShimLoader
import org.apache.spark.sql.types.{DataType, NullType}
/** An expression helper. */
-trait ExpressionHelper extends PredicateHelper {
+trait ExpressionHelper extends ExpressionHelperBase {
+
+ def convertConditionToPaimonPredicate(
+ condition: Expression,
+ output: Seq[Attribute],
+ rowType: RowType,
+ ignorePartialFailure: Boolean = false): Option[Predicate] = {
+ val converter = SparkV2FilterConverter(rowType)
+ val sparkPredicates = normalizeExprs(Seq(condition), output)
+ .flatMap(splitConjunctivePredicates(_).flatMap {
+ f =>
+ val predicate =
+ try {
+ translateFilterV2(f)
+ } catch {
+ case _: Throwable =>
+ None
+ }
+ if (predicate.isEmpty && !ignorePartialFailure) {
+ throw new RuntimeException(s"Cannot translate expression to predicate: $f")
+ }
+ predicate
+ })
+
+ val predicates = sparkPredicates.flatMap(converter.convert(_, ignorePartialFailure))
+ if (predicates.isEmpty) {
+ None
+ } else {
+ Some(PredicateBuilder.and(predicates: _*))
+ }
+ }
+}
+
+trait ExpressionHelperBase extends PredicateHelper {
import ExpressionHelper._
@@ -162,32 +195,6 @@ trait ExpressionHelper extends PredicateHelper {
isPredicatePartitionColumnsOnly(e, partitionCols, resolver) &&
!SubqueryExpression.hasSubquery(expr))
}
-
- def convertConditionToPaimonPredicate(
- condition: Expression,
- output: Seq[Attribute],
- rowType: RowType,
- ignorePartialFailure: Boolean = false): Option[Predicate] = {
- val converter = new SparkFilterConverter(rowType)
- val filters = normalizeExprs(Seq(condition), output)
- .flatMap(splitConjunctivePredicates(_).flatMap {
- f =>
- val filter = translateFilter(f, supportNestedPredicatePushdown = true)
- if (filter.isEmpty && !ignorePartialFailure) {
- throw new RuntimeException(
- "Exec update failed:" +
- s" cannot translate expression to source filter: $f")
- }
- filter
- })
-
- val predicates = filters.map(converter.convert(_, ignorePartialFailure)).filter(_ != null)
- if (predicates.isEmpty) {
- None
- } else {
- Some(PredicateBuilder.and(predicates: _*))
- }
- }
}
object ExpressionHelper {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
index 28ac1623fb59..d41fd7d4d287 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
@@ -63,6 +63,7 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper with SQLCon
def convertPartitionFilterToMap(
filter: Filter,
partitionRowType: RowType): Map[String, String] = {
+ // todo: replace it with SparkV2FilterConverter when we drop Spark3.2
val converter = new SparkFilterConverter(partitionRowType)
splitConjunctiveFilters(filter).map {
case EqualNullSafe(attribute, value) =>
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala
index d01a840f8ece..a1ce25137436 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy.translateFilterV2WithMapping
+import org.apache.spark.sql.internal.connector.PredicateUtils
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.PartitioningUtils
@@ -74,6 +75,10 @@ object PaimonUtils {
translateFilterV2WithMapping(predicate, None)
}
+ def filterV2ToV1(predicate: Predicate): Option[Filter] = {
+ PredicateUtils.toV1(predicate)
+ }
+
def fieldReference(name: String): FieldReference = {
fieldReference(Seq(name))
}
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
similarity index 97%
rename from paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
rename to paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
index 503f1c8e3e9d..15c021babb75 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownLimit}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.junit.jupiter.api.Assertions
-class PaimonPushDownTest extends PaimonSparkTestBase {
+abstract class PaimonPushDownTestBase extends PaimonSparkTestBase {
import testImplicits._
@@ -101,6 +101,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
}
test("Paimon pushDown: limit for append-only tables") {
+ assume(gteqSpark3_3)
spark.sql(s"""
|CREATE TABLE T (a INT, b STRING, c STRING)
|PARTITIONED BY (c)
@@ -128,6 +129,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
}
test("Paimon pushDown: limit for primary key table") {
+ assume(gteqSpark3_3)
spark.sql(s"""
|CREATE TABLE T (a INT, b STRING, c STRING)
|TBLPROPERTIES ('primary-key'='a')
@@ -202,6 +204,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
}
test("Paimon pushDown: limit for table with deletion vector") {
+ assume(gteqSpark3_3)
Seq(true, false).foreach(
deletionVectorsEnabled => {
Seq(true, false).foreach(
@@ -279,14 +282,14 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
SparkTable(loadTable(tableName)).newScanBuilder(CaseInsensitiveStringMap.empty())
}
- private def checkFilterExists(sql: String): Boolean = {
+ def checkFilterExists(sql: String): Boolean = {
spark.sql(sql).queryExecution.optimizedPlan.exists {
case Filter(_: Expression, _) => true
case _ => false
}
}
- private def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = {
+ def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = {
spark.sql(sql).queryExecution.optimizedPlan.exists {
case Filter(c: Expression, _) =>
c.exists {
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala
index b9cbc29b3aa3..e51a8f7a0114 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala
@@ -21,10 +21,13 @@ package org.apache.paimon.spark.sql
import org.apache.paimon.data.{BinaryString, Decimal, Timestamp}
import org.apache.paimon.predicate.PredicateBuilder
import org.apache.paimon.spark.{PaimonSparkTestBase, SparkV2FilterConverter}
+import org.apache.paimon.spark.util.shim.TypeUtils.treatPaimonTimestampTypeAsSparkTimestampType
+import org.apache.paimon.table.source.DataSplit
import org.apache.paimon.types.RowType
import org.apache.spark.SparkConf
import org.apache.spark.sql.PaimonUtils.translateFilterV2
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.connector.expressions.filter.Predicate
@@ -54,9 +57,26 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase {
| decimal_col DECIMAL(10, 5),
| boolean_col BOOLEAN,
| date_col DATE,
- | binary BINARY
+ | binary_col BINARY
|) USING paimon
|""".stripMargin)
+
+ sql("""
+ |INSERT INTO test_tbl VALUES
+ |('hello', 1, 1, 1, 1, 1.0, 1.0, 12.12345, true, date('2025-01-15'), binary('b1'))
+ |""".stripMargin)
+ sql("""
+ |INSERT INTO test_tbl VALUES
+ |('world', 2, 2, 2, 2, 2.0, 2.0, 22.12345, false, date('2025-01-16'), binary('b2'))
+ |""".stripMargin)
+ sql("""
+ |INSERT INTO test_tbl VALUES
+ |('hi', 3, 3, 3, 3, 3.0, 3.0, 32.12345, false, date('2025-01-17'), binary('b3'))
+ |""".stripMargin)
+ sql("""
+ |INSERT INTO test_tbl VALUES
+ |('paimon', 4, 4, null, 4, 4.0, 4.0, 42.12345, true, date('2025-01-18'), binary('b4'))
+ |""".stripMargin)
}
override protected def afterAll(): Unit = {
@@ -71,147 +91,269 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase {
lazy val converter: SparkV2FilterConverter = SparkV2FilterConverter(rowType)
test("V2Filter: all types") {
- var actual = converter.convert(v2Filter("string_col = 'hello'"))
+ var filter = "string_col = 'hello'"
+ var actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.equal(0, BinaryString.fromString("hello"))))
+ checkAnswer(sql(s"SELECT string_col from test_tbl WHERE $filter"), Seq(Row("hello")))
+ assert(scanFilesCount(filter) == 1)
- actual = converter.convert(v2Filter("byte_col = 1"))
+ filter = "byte_col = 1"
+ actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.equal(1, 1.toByte)))
+ checkAnswer(sql(s"SELECT byte_col from test_tbl WHERE $filter"), Seq(Row(1.toByte)))
+ assert(scanFilesCount(filter) == 1)
- actual = converter.convert(v2Filter("short_col = 1"))
+ filter = "short_col = 1"
+ actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.equal(2, 1.toShort)))
+ checkAnswer(sql(s"SELECT short_col from test_tbl WHERE $filter"), Seq(Row(1.toShort)))
+ assert(scanFilesCount(filter) == 1)
- actual = converter.convert(v2Filter("int_col = 1"))
+ filter = "int_col = 1"
+ actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.equal(3, 1)))
+ checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter"), Seq(Row(1)))
+ assert(scanFilesCount(filter) == 1)
- actual = converter.convert(v2Filter("long_col = 1"))
+ filter = "long_col = 1"
+ actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.equal(4, 1L)))
+ checkAnswer(sql(s"SELECT long_col from test_tbl WHERE $filter"), Seq(Row(1L)))
+ assert(scanFilesCount(filter) == 1)
- actual = converter.convert(v2Filter("float_col = 1.0"))
+ filter = "float_col = 1.0"
+ actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.equal(5, 1.0f)))
+ checkAnswer(sql(s"SELECT float_col from test_tbl WHERE $filter"), Seq(Row(1.0f)))
+ assert(scanFilesCount(filter) == 1)
- actual = converter.convert(v2Filter("double_col = 1.0"))
+ filter = "double_col = 1.0"
+ actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.equal(6, 1.0d)))
+ checkAnswer(sql(s"SELECT double_col from test_tbl WHERE $filter"), Seq(Row(1.0d)))
+ assert(scanFilesCount(filter) == 1)
- actual = converter.convert(v2Filter("decimal_col = 12.12345"))
+ filter = "decimal_col = 12.12345"
+ actual = converter.convert(v2Filter(filter))
assert(
actual.equals(
builder.equal(7, Decimal.fromBigDecimal(new java.math.BigDecimal("12.12345"), 10, 5))))
+ checkAnswer(
+ sql(s"SELECT decimal_col from test_tbl WHERE $filter"),
+ Seq(Row(new java.math.BigDecimal("12.12345"))))
+ assert(scanFilesCount(filter) == 1)
- actual = converter.convert(v2Filter("boolean_col = true"))
+ filter = "boolean_col = true"
+ actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.equal(8, true)))
+ checkAnswer(sql(s"SELECT boolean_col from test_tbl WHERE $filter"), Seq(Row(true), Row(true)))
+ assert(scanFilesCount(filter) == 2)
- actual = converter.convert(v2Filter("date_col = cast('2025-01-15' as date)"))
+ filter = "date_col = date('2025-01-15')"
+ actual = converter.convert(v2Filter(filter))
val localDate = LocalDate.parse("2025-01-15")
val epochDay = localDate.toEpochDay.toInt
assert(actual.equals(builder.equal(9, epochDay)))
+ checkAnswer(
+ sql(s"SELECT date_col from test_tbl WHERE $filter"),
+ sql("SELECT date('2025-01-15')"))
+ assert(scanFilesCount(filter) == 1)
+ filter = "binary_col = binary('b1')"
intercept[UnsupportedOperationException] {
- actual = converter.convert(v2Filter("binary = binary('b1')"))
+ actual = converter.convert(v2Filter(filter))
}
+ checkAnswer(sql(s"SELECT binary_col from test_tbl WHERE $filter"), sql("SELECT binary('b1')"))
+ assert(scanFilesCount(filter) == 4)
}
test("V2Filter: timestamp and timestamp_ntz") {
withTimeZone("Asia/Shanghai") {
withTable("ts_tbl", "ts_ntz_tbl") {
sql("CREATE TABLE ts_tbl (ts_col TIMESTAMP) USING paimon")
+ sql("INSERT INTO ts_tbl VALUES (timestamp'2025-01-15 00:00:00.123')")
+ sql("INSERT INTO ts_tbl VALUES (timestamp'2025-01-16 00:00:00.123')")
+
+ val filter1 = "ts_col = timestamp'2025-01-15 00:00:00.123'"
val rowType1 = loadTable("ts_tbl").rowType()
val converter1 = SparkV2FilterConverter(rowType1)
- val actual1 =
- converter1.convert(v2Filter("ts_col = timestamp'2025-01-15 00:00:00.123'", "ts_tbl"))
- assert(
- actual1.equals(new PredicateBuilder(rowType1)
+ val actual1 = converter1.convert(v2Filter(filter1, "ts_tbl"))
+ if (treatPaimonTimestampTypeAsSparkTimestampType()) {
+ assert(actual1.equals(new PredicateBuilder(rowType1)
+ .equal(0, Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-15T00:00:00.123")))))
+ } else {
+ assert(actual1.equals(new PredicateBuilder(rowType1)
.equal(0, Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-14T16:00:00.123")))))
+ }
+ checkAnswer(
+ sql(s"SELECT ts_col from ts_tbl WHERE $filter1"),
+ sql("SELECT timestamp'2025-01-15 00:00:00.123'"))
+ assert(scanFilesCount(filter1, "ts_tbl") == 1)
// Spark support TIMESTAMP_NTZ since Spark 3.4
if (gteqSpark3_4) {
sql("CREATE TABLE ts_ntz_tbl (ts_ntz_col TIMESTAMP_NTZ) USING paimon")
+ sql("INSERT INTO ts_ntz_tbl VALUES (timestamp_ntz'2025-01-15 00:00:00.123')")
+ sql("INSERT INTO ts_ntz_tbl VALUES (timestamp_ntz'2025-01-16 00:00:00.123')")
+ val filter2 = "ts_ntz_col = timestamp_ntz'2025-01-15 00:00:00.123'"
val rowType2 = loadTable("ts_ntz_tbl").rowType()
val converter2 = SparkV2FilterConverter(rowType2)
- val actual2 = converter2.convert(
- v2Filter("ts_ntz_col = timestamp_ntz'2025-01-15 00:00:00.123'", "ts_ntz_tbl"))
+ val actual2 = converter2.convert(v2Filter(filter2, "ts_ntz_tbl"))
assert(actual2.equals(new PredicateBuilder(rowType2)
.equal(0, Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-15T00:00:00.123")))))
+ checkAnswer(
+ sql(s"SELECT ts_ntz_col from ts_ntz_tbl WHERE $filter2"),
+ sql("SELECT timestamp_ntz'2025-01-15 00:00:00.123'"))
+ assert(scanFilesCount(filter2, "ts_ntz_tbl") == 1)
}
}
}
}
test("V2Filter: EqualTo") {
- val actual = converter.convert(v2Filter("int_col = 1"))
+ val filter = "int_col = 1"
+ val actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.equal(3, 1)))
+ checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(1)))
+ assert(scanFilesCount(filter) == 1)
}
test("V2Filter: EqualNullSafe") {
- var actual = converter.convert(v2Filter("int_col <=> 1"))
+ var filter = "int_col <=> 1"
+ var actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.equal(3, 1)))
+ checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(1)))
+ assert(scanFilesCount(filter) == 1)
- actual = converter.convert(v2Filter("int_col <=> null"))
+ filter = "int_col <=> null"
+ actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.isNull(3)))
+ checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(null)))
+ assert(scanFilesCount(filter) == 1)
}
test("V2Filter: GreaterThan") {
- val actual = converter.convert(v2Filter("int_col > 1"))
- assert(actual.equals(builder.greaterThan(3, 1)))
+ val filter = "int_col > 2"
+ val actual = converter.convert(v2Filter(filter))
+ assert(actual.equals(builder.greaterThan(3, 2)))
+ checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(3)))
+ assert(scanFilesCount(filter) == 1)
}
test("V2Filter: GreaterThanOrEqual") {
- val actual = converter.convert(v2Filter("int_col >= 1"))
- assert(actual.equals(builder.greaterOrEqual(3, 1)))
+ val filter = "int_col >= 2"
+ val actual = converter.convert(v2Filter(filter))
+ assert(actual.equals(builder.greaterOrEqual(3, 2)))
+ checkAnswer(
+ sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+ Seq(Row(2), Row(3)))
+ assert(scanFilesCount(filter) == 2)
}
test("V2Filter: LessThan") {
- val actual = converter.convert(v2Filter("int_col < 1"))
- assert(actual.equals(builder.lessThan(3, 1)))
+ val filter = "int_col < 2"
+ val actual = converter.convert(v2Filter("int_col < 2"))
+ assert(actual.equals(builder.lessThan(3, 2)))
+ checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(1)))
+ assert(scanFilesCount(filter) == 1)
}
test("V2Filter: LessThanOrEqual") {
- val actual = converter.convert(v2Filter("int_col <= 1"))
- assert(actual.equals(builder.lessOrEqual(3, 1)))
+ val filter = "int_col <= 2"
+ val actual = converter.convert(v2Filter("int_col <= 2"))
+ assert(actual.equals(builder.lessOrEqual(3, 2)))
+ checkAnswer(
+ sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col "),
+ Seq(Row(1), Row(2)))
+ assert(scanFilesCount(filter) == 2)
}
test("V2Filter: In") {
- val actual = converter.convert(v2Filter("int_col IN (1, 2, 3)"))
- assert(actual.equals(builder.in(3, List(1, 2, 3).map(_.asInstanceOf[AnyRef]).asJava)))
+ val filter = "int_col IN (1, 2)"
+ val actual = converter.convert(v2Filter("int_col IN (1, 2)"))
+ assert(actual.equals(builder.in(3, List(1, 2).map(_.asInstanceOf[AnyRef]).asJava)))
+ checkAnswer(
+ sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+ Seq(Row(1), Row(2)))
+ assert(scanFilesCount(filter) == 2)
}
test("V2Filter: IsNull") {
- val actual = converter.convert(v2Filter("int_col IS NULL"))
+ val filter = "int_col IS NULL"
+ val actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.isNull(3)))
+ checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(null)))
+ assert(scanFilesCount(filter) == 1)
}
test("V2Filter: IsNotNull") {
- val actual = converter.convert(v2Filter("int_col IS NOT NULL"))
+ val filter = "int_col IS NOT NULL"
+ val actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.isNotNull(3)))
+ checkAnswer(
+ sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+ Seq(Row(1), Row(2), Row(3)))
+ assert(scanFilesCount(filter) == 3)
}
test("V2Filter: And") {
- val actual = converter.convert(v2Filter("int_col > 1 AND int_col < 10"))
- assert(actual.equals(PredicateBuilder.and(builder.greaterThan(3, 1), builder.lessThan(3, 10))))
+ val filter = "int_col > 1 AND int_col < 3"
+ val actual = converter.convert(v2Filter(filter))
+ assert(actual.equals(PredicateBuilder.and(builder.greaterThan(3, 1), builder.lessThan(3, 3))))
+ checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(2)))
+ assert(scanFilesCount(filter) == 1)
}
test("V2Filter: Or") {
- val actual = converter.convert(v2Filter("int_col > 1 OR int_col < 10"))
- assert(actual.equals(PredicateBuilder.or(builder.greaterThan(3, 1), builder.lessThan(3, 10))))
+ val filter = "int_col > 2 OR int_col IS NULL"
+ val actual = converter.convert(v2Filter(filter))
+ assert(actual.equals(PredicateBuilder.or(builder.greaterThan(3, 2), builder.isNull(3))))
+ checkAnswer(
+ sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+ Seq(Row(null), Row(3)))
+ assert(scanFilesCount(filter) == 2)
}
test("V2Filter: Not") {
- val actual = converter.convert(v2Filter("NOT (int_col > 1)"))
- assert(actual.equals(builder.greaterThan(3, 1).negate().get()))
+ val filter = "NOT (int_col > 2)"
+ val actual = converter.convert(v2Filter(filter))
+ assert(actual.equals(builder.greaterThan(3, 2).negate().get()))
+ checkAnswer(
+ sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+ Seq(Row(1), Row(2)))
+ assert(scanFilesCount(filter) == 2)
}
test("V2Filter: StartWith") {
- val actual = converter.convert(v2Filter("string_col LIKE 'h%'"))
+ val filter = "string_col LIKE 'h%'"
+ val actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.startsWith(0, BinaryString.fromString("h"))))
+ checkAnswer(
+ sql(s"SELECT string_col from test_tbl WHERE $filter ORDER BY string_col"),
+ Seq(Row("hello"), Row("hi")))
+ assert(scanFilesCount(filter) == 2)
}
test("V2Filter: EndWith") {
- val actual = converter.convert(v2Filter("string_col LIKE '%o'"))
+ val filter = "string_col LIKE '%o'"
+ val actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.endsWith(0, BinaryString.fromString("o"))))
+ checkAnswer(
+ sql(s"SELECT string_col from test_tbl WHERE $filter ORDER BY string_col"),
+ Seq(Row("hello")))
+ // EndsWith currently has no file-skipping effect.
+ assert(scanFilesCount(filter) == 4)
}
test("V2Filter: Contains") {
- val actual = converter.convert(v2Filter("string_col LIKE '%e%'"))
+ val filter = "string_col LIKE '%e%'"
+ val actual = converter.convert(v2Filter(filter))
assert(actual.equals(builder.contains(0, BinaryString.fromString("e"))))
+ checkAnswer(
+ sql(s"SELECT string_col from test_tbl WHERE $filter ORDER BY string_col"),
+ Seq(Row("hello")))
+ // Contains currently has no file-skipping effect.
+ assert(scanFilesCount(filter) == 4)
}
private def v2Filter(str: String, tableName: String = "test_tbl"): Predicate = {
@@ -221,4 +363,11 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase {
.condition
translateFilterV2(condition).get
}
+
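+ // Counts the data files read for the given filter by summing dataFiles() over all DataSplits of the scan; used above to verify file skipping.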
+ private def scanFilesCount(str: String, tableName: String = "test_tbl"): Int = {
+ getPaimonScan(s"SELECT * FROM $tableName WHERE $str").lazyInputPartitions
+ .flatMap(_.splits)
+ .map(_.asInstanceOf[DataSplit].dataFiles().size())
+ .sum
+ }
}