From dd71d2b6637c210bedc62f5ce495ccbe04cba131 Mon Sep 17 00:00:00 2001 From: zouxxyy Date: Thu, 16 Jan 2025 00:51:45 +0800 Subject: [PATCH] use v2 filter --- .../org/apache/paimon/spark/PaimonScan.scala | 4 +- .../paimon/spark/PaimonScanBuilder.scala | 57 ++++- .../paimon/spark/PaimonSplitScanBuilder.scala | 29 +++ .../expressions/ExpressionHelper.scala | 55 +++++ .../paimon/spark/sql/PaimonPushDownTest.scala | 54 ++++ .../org/apache/paimon/spark/PaimonScan.scala | 125 ++++++++++ .../internal/connector/PredicateUtils.scala | 147 +++++++++++ .../paimon/spark/sql/PaimonPushDownTest.scala | 21 ++ .../paimon/spark/sql/PaimonPushDownTest.scala | 21 ++ .../paimon/spark/sql/PaimonPushDownTest.scala | 21 ++ .../paimon/spark/sql/PaimonPushDownTest.scala | 21 ++ .../paimon/spark/SparkFilterConverter.java | 2 +- .../apache/paimon/spark/PaimonBaseScan.scala | 2 - .../paimon/spark/PaimonBaseScanBuilder.scala | 58 +---- .../apache/paimon/spark/PaimonLocalScan.scala | 4 +- .../org/apache/paimon/spark/PaimonScan.scala | 24 +- .../paimon/spark/PaimonScanBuilder.scala | 63 ++++- .../apache/paimon/spark/PaimonSplitScan.scala | 4 +- .../paimon/spark/SparkV2FilterConverter.scala | 95 ++++--- .../expressions/ExpressionHelper.scala | 65 ++--- .../paimon/spark/commands/PaimonCommand.scala | 1 + .../org/apache/spark/sql/PaimonUtils.scala | 5 + ...est.scala => PaimonPushDownTestBase.scala} | 9 +- .../sql/SparkV2FilterConverterTestBase.scala | 233 ++++++++++++++---- 24 files changed, 931 insertions(+), 189 deletions(-) create mode 100644 paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonSplitScanBuilder.scala create mode 100644 paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala create mode 100644 paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala create mode 100644 paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala create mode 100644 paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala create mode 100644 paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala create mode 100644 paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala create mode 100644 paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala create mode 100644 paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala rename paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/{PaimonPushDownTest.scala => PaimonPushDownTestBase.scala} (97%) diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala index 92a719b3bc47..0b75d362f9cf 100644 --- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala +++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -35,7 +35,7 @@ case class PaimonScan( filters: Seq[Predicate], reservedFilters: Seq[Filter], override val pushDownLimit: Option[Int], - disableBucketedScan: Boolean = false) + disableBucketedScan: Boolean = true) extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit) with SupportsRuntimeFiltering { @@ -57,11 +57,9 @@ case class PaimonScan( case _ => None } if 
(partitionFilter.nonEmpty) { - this.runtimeFilters = filters readBuilder.withFilter(partitionFilter.head) // set inputPartitions null to trigger to get the new splits. inputPartitions = null } } - } diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala index 10b83ccf08b1..395f8707ab9d 100644 --- a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala +++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala @@ -18,6 +18,61 @@ package org.apache.paimon.spark +import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate} import org.apache.paimon.table.Table -class PaimonScanBuilder(table: Table) extends PaimonBaseScanBuilder(table) +import org.apache.spark.sql.connector.read.SupportsPushDownFilters +import org.apache.spark.sql.sources.Filter + +import scala.collection.mutable + +class PaimonScanBuilder(table: Table) + extends PaimonBaseScanBuilder(table) + with SupportsPushDownFilters { + + private var pushedSparkFilters = Array.empty[Filter] + + /** + * Pushes down filters, and returns filters that need to be evaluated after scanning.
Rows + * should be returned from the data source if and only if all the filters match. That is, filters + * must be interpreted as ANDed together. + */ + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val pushable = mutable.ArrayBuffer.empty[(Filter, Predicate)] + val postScan = mutable.ArrayBuffer.empty[Filter] + val reserved = mutable.ArrayBuffer.empty[Filter] + + val converter = new SparkFilterConverter(table.rowType) + val visitor = new PartitionPredicateVisitor(table.partitionKeys()) + filters.foreach { + filter => + val predicate = converter.convertIgnoreFailure(filter) + if (predicate == null) { + postScan.append(filter) + } else { + pushable.append((filter, predicate)) + if (predicate.visit(visitor)) { + reserved.append(filter) + } else { + postScan.append(filter) + } + } + } + + if (pushable.nonEmpty) { + this.pushedSparkFilters = pushable.map(_._1).toArray + this.pushedPaimonPredicates = pushable.map(_._2).toArray + } + if (reserved.nonEmpty) { + this.reservedFilters = reserved.toArray + } + if (postScan.nonEmpty) { + this.hasPostScanPredicates = true + } + postScan.toArray + } + + override def pushedFilters(): Array[Filter] = { + pushedSparkFilters + } +} diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonSplitScanBuilder.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonSplitScanBuilder.scala new file mode 100644 index 000000000000..ede18b8cc990 --- /dev/null +++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonSplitScanBuilder.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark + +import org.apache.paimon.table.KnownSplitsTable + +import org.apache.spark.sql.connector.read.Scan + +class PaimonSplitScanBuilder(table: KnownSplitsTable) extends PaimonScanBuilder(table) { + override def build(): Scan = { + PaimonSplitScan(table, table.splits(), requiredSchema, pushedPaimonPredicates) + } +} diff --git a/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala new file mode 100644 index 000000000000..56223c36cd75 --- /dev/null +++ b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.catalyst.analysis.expressions + +import org.apache.paimon.predicate.{Predicate, PredicateBuilder} +import org.apache.paimon.spark.SparkFilterConverter +import org.apache.paimon.types.RowType + +import org.apache.spark.sql.PaimonUtils.{normalizeExprs, translateFilter} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} + +trait ExpressionHelper extends ExpressionHelperBase { + + def convertConditionToPaimonPredicate( + condition: Expression, + output: Seq[Attribute], + rowType: RowType, + ignorePartialFailure: Boolean = false): Option[Predicate] = { + val converter = new SparkFilterConverter(rowType) + val filters = normalizeExprs(Seq(condition), output) + .flatMap(splitConjunctivePredicates(_).flatMap { + f => + val filter = translateFilter(f, supportNestedPredicatePushdown = true) + if (filter.isEmpty && !ignorePartialFailure) { + throw new RuntimeException( + "Exec update failed:" + + s" cannot translate expression to source filter: $f") + } + filter + }) + + val predicates = filters.map(converter.convert(_, ignorePartialFailure)).filter(_ != null) + if (predicates.isEmpty) { + None + } else { + Some(PredicateBuilder.and(predicates: _*)) + } + } +} diff --git a/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala new file mode 100644 index 000000000000..e0705b761ab9 --- /dev/null +++ b/paimon-spark/paimon-spark-3.2/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.Filter + +class PaimonPushDownTest extends PaimonPushDownTestBase { + + override def checkFilterExists(sql: String): Boolean = { + spark + .sql(sql) + .queryExecution + .optimizedPlan + .find { + case Filter(_: Expression, _) => true + case _ => false + } + .isDefined + } + + override def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = { + spark + .sql(sql) + .queryExecution + .optimizedPlan + .find { + case Filter(c: Expression, _) => + c.find { + case EqualTo(a: AttributeReference, r: Literal) => + a.name.equals(name) && r.equals(value) + case _ => false + }.isDefined + case _ => false + } + .isDefined + } +} diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala new file mode 100644 index 000000000000..4c62d58a811b --- /dev/null +++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark + +import org.apache.paimon.predicate.Predicate +import org.apache.paimon.table.{BucketMode, FileStoreTable, Table} +import org.apache.paimon.table.source.{DataSplit, Split} + +import org.apache.spark.sql.PaimonUtils.fieldReference +import org.apache.spark.sql.connector.expressions._ +import org.apache.spark.sql.connector.read.{SupportsReportPartitioning, SupportsRuntimeFiltering} +import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.sources.{Filter, In} +import org.apache.spark.sql.types.StructType + +import scala.collection.JavaConverters._ + +case class PaimonScan( + table: Table, + requiredSchema: StructType, + filters: Seq[Predicate], + reservedFilters: Seq[Filter], + override val pushDownLimit: Option[Int], + bucketedScanDisabled: Boolean = false) + extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit) + with SupportsRuntimeFiltering + with SupportsReportPartitioning { + + def disableBucketedScan(): PaimonScan = { + copy(bucketedScanDisabled = true) + } + + @transient + private lazy val extractBucketTransform: Option[Transform] = { + table match { + case fileStoreTable: FileStoreTable => + val bucketSpec = fileStoreTable.bucketSpec() + if (bucketSpec.getBucketMode != BucketMode.HASH_FIXED) { + None + } else if (bucketSpec.getBucketKeys.size() > 1) { + None + } else { + // Spark does not support bucket with several input attributes, + // so we only support one bucket key case. + assert(bucketSpec.getNumBuckets > 0) + assert(bucketSpec.getBucketKeys.size() == 1) + val bucketKey = bucketSpec.getBucketKeys.get(0) + if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) { + Some(Expressions.bucket(bucketSpec.getNumBuckets, bucketKey)) + } else { + None + } + } + + case _ => None + } + } + + private def shouldDoBucketedScan: Boolean = { + !bucketedScanDisabled && conf.v2BucketingEnabled && extractBucketTransform.isDefined + } + + // Since Spark 3.3 + override def outputPartitioning: Partitioning = { + extractBucketTransform + .map(bucket => new KeyGroupedPartitioning(Array(bucket), lazyInputPartitions.size)) + .getOrElse(new UnknownPartitioning(0)) + } + + override def getInputPartitions(splits: Array[Split]): Seq[PaimonInputPartition] = { + if (!shouldDoBucketedScan || splits.exists(!_.isInstanceOf[DataSplit])) { + return super.getInputPartitions(splits) + } + + splits + .map(_.asInstanceOf[DataSplit]) + .groupBy(_.bucket()) + .map { + case (bucket, groupedSplits) => + PaimonBucketedInputPartition(groupedSplits, bucket) + } + .toSeq + } + + // Since Spark 3.2 + override def filterAttributes(): Array[NamedReference] = { + val requiredFields = readBuilder.readType().getFieldNames.asScala + table + .partitionKeys() + .asScala + .toArray + .filter(requiredFields.contains) + .map(fieldReference) + } + + override def filter(filters: Array[Filter]): Unit = { + val converter = new SparkFilterConverter(table.rowType()) + val partitionFilter = filters.flatMap { + case in @ In(attr, _) if table.partitionKeys().contains(attr) => + Some(converter.convert(in)) + case _ => None + } + if (partitionFilter.nonEmpty) { + readBuilder.withFilter(partitionFilter.head) + // set inputPartitions null to trigger to get the new splits. 
+ inputPartitions = null + } + } +} diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala new file mode 100644 index 000000000000..43459648095d --- /dev/null +++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal.connector + +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference} +import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} +import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith} +import org.apache.spark.sql.types.StringType + +// Copy from Spark 3.4+ +private[sql] object PredicateUtils { + + def toV1(predicate: Predicate): Option[Filter] = { + + def isValidBinaryPredicate(): Boolean = { + if ( + predicate.children().length == 2 && + predicate.children()(0).isInstanceOf[NamedReference] && + predicate.children()(1).isInstanceOf[LiteralValue[_]] + ) { + true + } else { + false + } + } + + predicate.name() match { + case "IN" if predicate.children()(0).isInstanceOf[NamedReference] => + val attribute = predicate.children()(0).toString + val values = predicate.children().drop(1) + if (values.length > 0) { + if (!values.forall(_.isInstanceOf[LiteralValue[_]])) return None + val dataType = values(0).asInstanceOf[LiteralValue[_]].dataType + if (!values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) { + return None + } + val inValues = values.map( + v => + CatalystTypeConverters.convertToScala( + v.asInstanceOf[LiteralValue[_]].value, + dataType)) + Some(In(attribute, inValues)) + } else { + Some(In(attribute, Array.empty[Any])) + } + + case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate() => + val attribute = predicate.children()(0).toString + val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] + val v1Value = CatalystTypeConverters.convertToScala(value.value, value.dataType) + val v1Filter = predicate.name() match { + case "=" => EqualTo(attribute, v1Value) + case "<=>" => EqualNullSafe(attribute, v1Value) + case ">" => GreaterThan(attribute, v1Value) + case ">=" => GreaterThanOrEqual(attribute, v1Value) + case "<" => LessThan(attribute, v1Value) + case "<=" => LessThanOrEqual(attribute, v1Value) + } + Some(v1Filter) + + case 
"IS_NULL" | "IS_NOT_NULL" + if predicate.children().length == 1 && + predicate.children()(0).isInstanceOf[NamedReference] => + val attribute = predicate.children()(0).toString + val v1Filter = predicate.name() match { + case "IS_NULL" => IsNull(attribute) + case "IS_NOT_NULL" => IsNotNull(attribute) + } + Some(v1Filter) + + case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate() => + val attribute = predicate.children()(0).toString + val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] + if (!value.dataType.sameType(StringType)) return None + val v1Value = value.value.toString + val v1Filter = predicate.name() match { + case "STARTS_WITH" => + StringStartsWith(attribute, v1Value) + case "ENDS_WITH" => + StringEndsWith(attribute, v1Value) + case "CONTAINS" => + StringContains(attribute, v1Value) + } + Some(v1Filter) + + case "ALWAYS_TRUE" | "ALWAYS_FALSE" if predicate.children().isEmpty => + val v1Filter = predicate.name() match { + case "ALWAYS_TRUE" => AlwaysTrue() + case "ALWAYS_FALSE" => AlwaysFalse() + } + Some(v1Filter) + + case "AND" => + val and = predicate.asInstanceOf[V2And] + val left = toV1(and.left()) + val right = toV1(and.right()) + if (left.nonEmpty && right.nonEmpty) { + Some(And(left.get, right.get)) + } else { + None + } + + case "OR" => + val or = predicate.asInstanceOf[V2Or] + val left = toV1(or.left()) + val right = toV1(or.right()) + if (left.nonEmpty && right.nonEmpty) { + Some(Or(left.get, right.get)) + } else if (left.nonEmpty) { + left + } else { + right + } + + case "NOT" => + val child = toV1(predicate.asInstanceOf[V2Not].child()) + if (child.nonEmpty) { + Some(Not(child.get)) + } else { + None + } + + case _ => None + } + } + + def toV1(predicates: Array[Predicate]): Array[Filter] = { + predicates.flatMap(toV1(_)) + } +} diff --git a/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala new file mode 100644 index 000000000000..26677d85c71a --- /dev/null +++ b/paimon-spark/paimon-spark-3.3/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.sql + +class PaimonPushDownTest extends PaimonPushDownTestBase {} diff --git a/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala new file mode 100644 index 000000000000..26677d85c71a --- /dev/null +++ b/paimon-spark/paimon-spark-3.4/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +class PaimonPushDownTest extends PaimonPushDownTestBase {} diff --git a/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala new file mode 100644 index 000000000000..26677d85c71a --- /dev/null +++ b/paimon-spark/paimon-spark-3.5/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +class PaimonPushDownTest extends PaimonPushDownTestBase {} diff --git a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala new file mode 100644 index 000000000000..26677d85c71a --- /dev/null +++ b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.spark.sql + +class PaimonPushDownTest extends PaimonPushDownTestBase {} diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java index 2050c937c6a3..c0b8cfd66be1 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java @@ -49,7 +49,7 @@ import static org.apache.paimon.predicate.PredicateBuilder.convertJavaObject; -/** Conversion from {@link Filter} to {@link Predicate}. */ +/** Conversion from {@link Filter} to {@link Predicate}, remove it when Spark 3.2 is dropped. */ public class SparkFilterConverter { public static final List SUPPORT_FILTERS = diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala index b9d235a9de1d..5e790cb301de 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala @@ -49,8 +49,6 @@ abstract class PaimonBaseScan( with ColumnPruningAndPushDown with StatisticsHelper { - protected var runtimeFilters: Array[Filter] = Array.empty - protected var inputPartitions: Seq[PaimonInputPartition] = _ override val coreOptions: CoreOptions = CoreOptions.fromMap(table.options()) diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala index a265ee78f5b9..1db178448413 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala @@ -18,77 +18,31 @@ package org.apache.paimon.spark -import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate} +import org.apache.paimon.predicate.Predicate import org.apache.paimon.table.Table import org.apache.spark.internal.Logging -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownRequiredColumns} import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType -import scala.collection.mutable - abstract class PaimonBaseScanBuilder(table: Table) extends ScanBuilder - with SupportsPushDownFilters with SupportsPushDownRequiredColumns with Logging { protected var requiredSchema: StructType = SparkTypeUtils.fromPaimonRowType(table.rowType()) - protected var pushedPredicates: Array[(Filter, Predicate)] = Array.empty + protected var pushedPaimonPredicates: Array[Predicate] = Array.empty - protected var partitionFilters: 
Array[Filter] = Array.empty + protected var reservedFilters: Array[Filter] = Array.empty - protected var postScanFilters: Array[Filter] = Array.empty + protected var hasPostScanPredicates = false protected var pushDownLimit: Option[Int] = None override def build(): Scan = { - PaimonScan(table, requiredSchema, pushedPredicates.map(_._2), partitionFilters, pushDownLimit) - } - - /** - * Pushes down filters, and returns filters that need to be evaluated after scanning.
Rows - * should be returned from the data source if and only if all of the filters match. That is, - * filters must be interpreted as ANDed together. - */ - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val pushable = mutable.ArrayBuffer.empty[(Filter, Predicate)] - val postScan = mutable.ArrayBuffer.empty[Filter] - val partitionFilter = mutable.ArrayBuffer.empty[Filter] - - val converter = new SparkFilterConverter(table.rowType) - val visitor = new PartitionPredicateVisitor(table.partitionKeys()) - filters.foreach { - filter => - val predicate = converter.convertIgnoreFailure(filter) - if (predicate == null) { - postScan.append(filter) - } else { - pushable.append((filter, predicate)) - if (predicate.visit(visitor)) { - partitionFilter.append(filter) - } else { - postScan.append(filter) - } - } - } - - if (pushable.nonEmpty) { - this.pushedPredicates = pushable.toArray - } - if (partitionFilter.nonEmpty) { - this.partitionFilters = partitionFilter.toArray - } - if (postScan.nonEmpty) { - this.postScanFilters = postScan.toArray - } - postScan.toArray - } - - override def pushedFilters(): Array[Filter] = { - pushedPredicates.map(_._1) + PaimonScan(table, requiredSchema, pushedPaimonPredicates, reservedFilters, pushDownLimit) } override def pruneColumns(requiredSchema: StructType): Unit = { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala index 490a1b133f6f..1f4e88e8d160 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonLocalScan.scala @@ -18,11 +18,11 @@ package org.apache.paimon.spark +import org.apache.paimon.predicate.Predicate import org.apache.paimon.table.Table import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.LocalScan -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType /** A scan does not require [[RDD]] to execute */ @@ -30,7 +30,7 @@ case class PaimonLocalScan( rows: Array[InternalRow], readSchema: StructType, table: Table, - filters: Array[Filter]) + filters: Array[Predicate]) extends LocalScan { override def description(): String = { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala index 2f1e6c53ab0a..f2818f7c032c 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala @@ -23,10 +23,11 @@ import org.apache.paimon.table.{BucketMode, FileStoreTable, Table} import org.apache.paimon.table.source.{DataSplit, Split} import org.apache.spark.sql.PaimonUtils.fieldReference -import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference, SortDirection, SortOrder, Transform} -import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning, SupportsRuntimeFiltering} +import org.apache.spark.sql.connector.expressions._ +import org.apache.spark.sql.connector.expressions.filter.{Predicate => SparkPredicate} +import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning, SupportsRuntimeV2Filtering} import 
org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} -import org.apache.spark.sql.sources.{Filter, In} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import scala.collection.JavaConverters._ @@ -39,7 +40,7 @@ case class PaimonScan( override val pushDownLimit: Option[Int], bucketedScanDisabled: Boolean = false) extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, pushDownLimit) - with SupportsRuntimeFiltering + with SupportsRuntimeV2Filtering with SupportsReportPartitioning with SupportsReportOrdering { @@ -156,19 +157,18 @@ case class PaimonScan( .map(fieldReference) } - override def filter(filters: Array[Filter]): Unit = { - val converter = new SparkFilterConverter(table.rowType()) - val partitionFilter = filters.flatMap { - case in @ In(attr, _) if table.partitionKeys().contains(attr) => - Some(converter.convert(in)) + override def filter(predicates: Array[SparkPredicate]): Unit = { + val converter = SparkV2FilterConverter(table.rowType()) + val partitionKeys = table.partitionKeys().asScala.toSeq + val partitionFilter = predicates.flatMap { + case p if SparkV2FilterConverter.isSupportedRuntimeFilter(p, partitionKeys) => + converter.convert(p, ignoreFailure = true) case _ => None } if (partitionFilter.nonEmpty) { - this.runtimeFilters = filters - readBuilder.withFilter(partitionFilter.head) + readBuilder.withFilter(partitionFilter.toList.asJava) // set inputPartitions null to trigger to get the new splits. inputPartitions = null } } - } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala index 0393a1cd1578..7b6c65c37f76 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala @@ -18,22 +18,71 @@ package org.apache.paimon.spark -import org.apache.paimon.predicate.PredicateBuilder +import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, PredicateBuilder} import org.apache.paimon.spark.aggregate.LocalAggregator import org.apache.paimon.table.Table import org.apache.paimon.table.source.DataSplit +import org.apache.spark.sql.PaimonUtils import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit} +import org.apache.spark.sql.connector.expressions.filter.{Predicate => SparkPredicate} +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownV2Filters} +import org.apache.spark.sql.sources.Filter import scala.collection.JavaConverters._ +import scala.collection.mutable class PaimonScanBuilder(table: Table) extends PaimonBaseScanBuilder(table) + with SupportsPushDownV2Filters with SupportsPushDownLimit with SupportsPushDownAggregates { + private var localScan: Option[Scan] = None + private var pushedSparkPredicates = Array.empty[SparkPredicate] + + /** Pushes down filters, and returns filters that need to be evaluated after scanning. 
*/ + override def pushPredicates(predicates: Array[SparkPredicate]): Array[SparkPredicate] = { + val pushable = mutable.ArrayBuffer.empty[(SparkPredicate, Predicate)] + val postScan = mutable.ArrayBuffer.empty[SparkPredicate] + val reserved = mutable.ArrayBuffer.empty[Filter] + + val converter = SparkV2FilterConverter(table.rowType) + val visitor = new PartitionPredicateVisitor(table.partitionKeys()) + predicates.foreach { + predicate => + converter.convert(predicate, ignoreFailure = true) match { + case Some(paimonPredicate) => + pushable.append((predicate, paimonPredicate)) + if (paimonPredicate.visit(visitor)) { + // We need to filter the stats using filter instead of predicate. + reserved.append(PaimonUtils.filterV2ToV1(predicate).get) + } else { + postScan.append(predicate) + } + case None => + postScan.append(predicate) + } + } + + if (pushable.nonEmpty) { + this.pushedSparkPredicates = pushable.map(_._1).toArray + this.pushedPaimonPredicates = pushable.map(_._2).toArray + } + if (reserved.nonEmpty) { + this.reservedFilters = reserved.toArray + } + if (postScan.nonEmpty) { + this.hasPostScanPredicates = true + } + postScan.toArray + } + + override def pushedPredicates: Array[SparkPredicate] = { + pushedSparkPredicates + } + override def pushLimit(limit: Int): Boolean = { // It is safe, since we will do nothing if it is the primary table and the split is not `rawConvertible` pushDownLimit = Some(limit) @@ -52,8 +101,8 @@ class PaimonScanBuilder(table: Table) return true } - // Only support with push down partition filter - if (postScanFilters.nonEmpty) { + // Only support when there is no post scan predicates. + if (hasPostScanPredicates) { return false } @@ -63,8 +112,8 @@ class PaimonScanBuilder(table: Table) } val readBuilder = table.newReadBuilder - if (pushedPredicates.nonEmpty) { - val pushedPartitionPredicate = PredicateBuilder.and(pushedPredicates.map(_._2): _*) + if (pushedPaimonPredicates.nonEmpty) { + val pushedPartitionPredicate = PredicateBuilder.and(pushedPaimonPredicates.toList.asJava) readBuilder.withFilter(pushedPartitionPredicate) } val dataSplits = readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit]) @@ -77,7 +126,7 @@ class PaimonScanBuilder(table: Table) aggregator.result(), aggregator.resultSchema(), table, - pushedPredicates.map(_._1))) + pushedPaimonPredicates)) true } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala index 8d9e643f9485..7d0bf831550e 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSplitScan.scala @@ -26,9 +26,9 @@ import org.apache.paimon.table.source.{DataSplit, Split} import org.apache.spark.sql.connector.read.{Batch, Scan} import org.apache.spark.sql.types.StructType -class PaimonSplitScanBuilder(table: KnownSplitsTable) extends PaimonBaseScanBuilder(table) { +class PaimonSplitScanBuilder(table: KnownSplitsTable) extends PaimonScanBuilder(table) { override def build(): Scan = { - PaimonSplitScan(table, table.splits(), requiredSchema, pushedPredicates.map(_._2)) + PaimonSplitScan(table, table.splits(), requiredSchema, pushedPaimonPredicates) } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala index 11ef302672e1..e99d2aa7794c 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala @@ -20,9 +20,11 @@ package org.apache.paimon.spark import org.apache.paimon.data.{BinaryString, Decimal, Timestamp} import org.apache.paimon.predicate.{Predicate, PredicateBuilder} +import org.apache.paimon.spark.util.shim.TypeUtils.treatPaimonTimestampTypeAsSparkTimestampType import org.apache.paimon.types.{DataTypeRoot, DecimalType, RowType} import org.apache.paimon.types.DataTypeRoot._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.expressions.{Literal, NamedReference} import org.apache.spark.sql.connector.expressions.filter.{And, Not, Or, Predicate => SparkPredicate} @@ -35,6 +37,15 @@ case class SparkV2FilterConverter(rowType: RowType) { val builder = new PredicateBuilder(rowType) + def convert(sparkPredicate: SparkPredicate, ignoreFailure: Boolean): Option[Predicate] = { + try { + Some(convert(sparkPredicate)) + } catch { + case _ if ignoreFailure => None + case e: Exception => throw e + } + } + def convert(sparkPredicate: SparkPredicate): Predicate = { sparkPredicate.name() match { case EQUAL_TO => @@ -147,36 +158,6 @@ case class SparkV2FilterConverter(rowType: RowType) { } } - private object UnaryPredicate { - def unapply(sparkPredicate: SparkPredicate): Option[String] = { - sparkPredicate.children() match { - case Array(n: NamedReference) => Some(toFieldName(n)) - case _ => None - } - } - } - - private object BinaryPredicate { - def unapply(sparkPredicate: SparkPredicate): Option[(String, Any)] = { - sparkPredicate.children() match { - case Array(l: NamedReference, r: Literal[_]) => Some((toFieldName(l), r.value)) - case Array(l: Literal[_], r: NamedReference) => Some((toFieldName(r), l.value)) - case _ => None - } - } - } - - private object MultiPredicate { - def unapply(sparkPredicate: SparkPredicate): Option[(String, Array[Any])] = { - sparkPredicate.children() match { - case Array(first: NamedReference, rest @ _*) - if rest.nonEmpty && rest.forall(_.isInstanceOf[Literal[_]]) => - Some(toFieldName(first), rest.map(_.asInstanceOf[Literal[_]].value).toArray) - case _ => None - } - } - } - private def fieldIndex(fieldName: String): Int = { val index = rowType.getFieldIndex(fieldName) // TODO: support nested field @@ -205,15 +186,19 @@ case class SparkV2FilterConverter(rowType: RowType) { value.asInstanceOf[org.apache.spark.sql.types.Decimal].toJavaBigDecimal, precision, scale) - case DataTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE | DataTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE => + case DataTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE => Timestamp.fromMicros(value.asInstanceOf[Long]) + case DataTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE => + if (treatPaimonTimestampTypeAsSparkTimestampType()) { + Timestamp.fromSQLTimestamp(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long])) + } else { + Timestamp.fromMicros(value.asInstanceOf[Long]) + } case _ => throw new UnsupportedOperationException( s"Convert value: $value to datatype: $dataType is unsupported.") } } - - private def toFieldName(ref: NamedReference): String = ref.fieldNames().mkString(".") } object SparkV2FilterConverter { @@ -233,4 +218,48 @@ object SparkV2FilterConverter { private val STRING_START_WITH = "STARTS_WITH" private val 
STRING_END_WITH = "ENDS_WITH" private val STRING_CONTAINS = "CONTAINS" + + private object UnaryPredicate { + def unapply(sparkPredicate: SparkPredicate): Option[String] = { + sparkPredicate.children() match { + case Array(n: NamedReference) => Some(toFieldName(n)) + case _ => None + } + } + } + + private object BinaryPredicate { + def unapply(sparkPredicate: SparkPredicate): Option[(String, Any)] = { + sparkPredicate.children() match { + case Array(l: NamedReference, r: Literal[_]) => Some((toFieldName(l), r.value)) + case Array(l: Literal[_], r: NamedReference) => Some((toFieldName(r), l.value)) + case _ => None + } + } + } + + private object MultiPredicate { + def unapply(sparkPredicate: SparkPredicate): Option[(String, Array[Any])] = { + sparkPredicate.children() match { + case Array(first: NamedReference, rest @ _*) + if rest.nonEmpty && rest.forall(_.isInstanceOf[Literal[_]]) => + Some(toFieldName(first), rest.map(_.asInstanceOf[Literal[_]].value).toArray) + case _ => None + } + } + } + + private def toFieldName(ref: NamedReference): String = ref.fieldNames().mkString(".") + + def isSupportedRuntimeFilter( + sparkPredicate: SparkPredicate, + partitionKeys: Seq[String]): Boolean = { + sparkPredicate.name() match { + case IN => + MultiPredicate.unapply(sparkPredicate) match { + case Some((fieldName, _)) => partitionKeys.contains(fieldName) + // exhaustive match: fall back to false when the IN predicate's children are not a field plus literals + case None => false + } + case _ => false + } + } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala index 2eef2c41aebe..fcece0a26236 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/expressions/ExpressionHelper.scala @@ -19,12 +19,12 @@ package org.apache.paimon.spark.catalyst.analysis.expressions import org.apache.paimon.predicate.{Predicate, PredicateBuilder} -import org.apache.paimon.spark.SparkFilterConverter +import org.apache.paimon.spark.SparkV2FilterConverter import org.apache.paimon.spark.catalyst.Compatibility import org.apache.paimon.types.RowType import org.apache.spark.sql.{Column, SparkSession} -import org.apache.spark.sql.PaimonUtils.{normalizeExprs, translateFilter} +import org.apache.spark.sql.PaimonUtils.{normalizeExprs, translateFilterV2} import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, Cast, Expression, GetStructField, Literal, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} @@ -33,7 +33,40 @@ import org.apache.spark.sql.paimon.shims.SparkShimLoader import org.apache.spark.sql.types.{DataType, NullType} /** An expression helper. 
*/ -trait ExpressionHelper extends PredicateHelper { +trait ExpressionHelper extends ExpressionHelperBase { + + def convertConditionToPaimonPredicate( + condition: Expression, + output: Seq[Attribute], + rowType: RowType, + ignorePartialFailure: Boolean = false): Option[Predicate] = { + val converter = SparkV2FilterConverter(rowType) + val sparkPredicates = normalizeExprs(Seq(condition), output) + .flatMap(splitConjunctivePredicates(_).flatMap { + f => + val predicate = + try { + translateFilterV2(f) + } catch { + case _: Throwable => + None + } + if (predicate.isEmpty && !ignorePartialFailure) { + throw new RuntimeException(s"Cannot translate expression to predicate: $f") + } + predicate + }) + + val predicates = sparkPredicates.flatMap(converter.convert(_, ignorePartialFailure)) + if (predicates.isEmpty) { + None + } else { + Some(PredicateBuilder.and(predicates: _*)) + } + } +} + +trait ExpressionHelperBase extends PredicateHelper { import ExpressionHelper._ @@ -162,32 +195,6 @@ trait ExpressionHelper extends PredicateHelper { isPredicatePartitionColumnsOnly(e, partitionCols, resolver) && !SubqueryExpression.hasSubquery(expr)) } - - def convertConditionToPaimonPredicate( - condition: Expression, - output: Seq[Attribute], - rowType: RowType, - ignorePartialFailure: Boolean = false): Option[Predicate] = { - val converter = new SparkFilterConverter(rowType) - val filters = normalizeExprs(Seq(condition), output) - .flatMap(splitConjunctivePredicates(_).flatMap { - f => - val filter = translateFilter(f, supportNestedPredicatePushdown = true) - if (filter.isEmpty && !ignorePartialFailure) { - throw new RuntimeException( - "Exec update failed:" + - s" cannot translate expression to source filter: $f") - } - filter - }) - - val predicates = filters.map(converter.convert(_, ignorePartialFailure)).filter(_ != null) - if (predicates.isEmpty) { - None - } else { - Some(PredicateBuilder.and(predicates: _*)) - } - } } object ExpressionHelper { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala index 28ac1623fb59..d41fd7d4d287 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala @@ -63,6 +63,7 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper with SQLCon def convertPartitionFilterToMap( filter: Filter, partitionRowType: RowType): Map[String, String] = { + // todo: replace it with SparkV2FilterConverter when we drop Spark3.2 val converter = new SparkFilterConverter(partitionRowType) splitConjunctiveFilters(filter).map { case EqualNullSafe(attribute, value) => diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala index d01a840f8ece..a1ce25137436 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/PaimonUtils.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.execution.datasources.DataSourceStrategy import 
org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy.translateFilterV2WithMapping +import org.apache.spark.sql.internal.connector.PredicateUtils import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.PartitioningUtils @@ -74,6 +75,10 @@ object PaimonUtils { translateFilterV2WithMapping(predicate, None) } + def filterV2ToV1(predicate: Predicate): Option[Filter] = { + PredicateUtils.toV1(predicate) + } + def fieldReference(name: String): FieldReference = { fieldReference(Seq(name)) } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala similarity index 97% rename from paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala rename to paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala index 503f1c8e3e9d..15c021babb75 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTestBase.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownLimit} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.junit.jupiter.api.Assertions -class PaimonPushDownTest extends PaimonSparkTestBase { +abstract class PaimonPushDownTestBase extends PaimonSparkTestBase { import testImplicits._ @@ -101,6 +101,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase { } test("Paimon pushDown: limit for append-only tables") { + assume(gteqSpark3_3) spark.sql(s""" |CREATE TABLE T (a INT, b STRING, c STRING) |PARTITIONED BY (c) @@ -128,6 +129,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase { } test("Paimon pushDown: limit for primary key table") { + assume(gteqSpark3_3) spark.sql(s""" |CREATE TABLE T (a INT, b STRING, c STRING) |TBLPROPERTIES ('primary-key'='a') @@ -202,6 +204,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase { } test("Paimon pushDown: limit for table with deletion vector") { + assume(gteqSpark3_3) Seq(true, false).foreach( deletionVectorsEnabled => { Seq(true, false).foreach( @@ -279,14 +282,14 @@ class PaimonPushDownTest extends PaimonSparkTestBase { SparkTable(loadTable(tableName)).newScanBuilder(CaseInsensitiveStringMap.empty()) } - private def checkFilterExists(sql: String): Boolean = { + def checkFilterExists(sql: String): Boolean = { spark.sql(sql).queryExecution.optimizedPlan.exists { case Filter(_: Expression, _) => true case _ => false } } - private def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = { + def checkEqualToFilterExists(sql: String, name: String, value: Literal): Boolean = { spark.sql(sql).queryExecution.optimizedPlan.exists { case Filter(c: Expression, _) => c.exists { diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala index b9cbc29b3aa3..e51a8f7a0114 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala @@ -21,10 +21,13 @@ package 
org.apache.paimon.spark.sql import org.apache.paimon.data.{BinaryString, Decimal, Timestamp} import org.apache.paimon.predicate.PredicateBuilder import org.apache.paimon.spark.{PaimonSparkTestBase, SparkV2FilterConverter} +import org.apache.paimon.spark.util.shim.TypeUtils.treatPaimonTimestampTypeAsSparkTimestampType +import org.apache.paimon.table.source.DataSplit import org.apache.paimon.types.RowType import org.apache.spark.SparkConf import org.apache.spark.sql.PaimonUtils.translateFilterV2 +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.connector.expressions.filter.Predicate @@ -54,9 +57,26 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase { | decimal_col DECIMAL(10, 5), | boolean_col BOOLEAN, | date_col DATE, - | binary BINARY + | binary_col BINARY |) USING paimon |""".stripMargin) + + sql(""" + |INSERT INTO test_tbl VALUES + |('hello', 1, 1, 1, 1, 1.0, 1.0, 12.12345, true, date('2025-01-15'), binary('b1')) + |""".stripMargin) + sql(""" + |INSERT INTO test_tbl VALUES + |('world', 2, 2, 2, 2, 2.0, 2.0, 22.12345, false, date('2025-01-16'), binary('b2')) + |""".stripMargin) + sql(""" + |INSERT INTO test_tbl VALUES + |('hi', 3, 3, 3, 3, 3.0, 3.0, 32.12345, false, date('2025-01-17'), binary('b3')) + |""".stripMargin) + sql(""" + |INSERT INTO test_tbl VALUES + |('paimon', 4, 4, null, 4, 4.0, 4.0, 42.12345, true, date('2025-01-18'), binary('b4')) + |""".stripMargin) } override protected def afterAll(): Unit = { @@ -71,147 +91,269 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase { lazy val converter: SparkV2FilterConverter = SparkV2FilterConverter(rowType) test("V2Filter: all types") { - var actual = converter.convert(v2Filter("string_col = 'hello'")) + var filter = "string_col = 'hello'" + var actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(0, BinaryString.fromString("hello")))) + checkAnswer(sql(s"SELECT string_col from test_tbl WHERE $filter"), Seq(Row("hello"))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("byte_col = 1")) + filter = "byte_col = 1" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(1, 1.toByte))) + checkAnswer(sql(s"SELECT byte_col from test_tbl WHERE $filter"), Seq(Row(1.toByte))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("short_col = 1")) + filter = "short_col = 1" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(2, 1.toShort))) + checkAnswer(sql(s"SELECT short_col from test_tbl WHERE $filter"), Seq(Row(1.toShort))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("int_col = 1")) + filter = "int_col = 1" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(3, 1))) + checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter"), Seq(Row(1))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("long_col = 1")) + filter = "long_col = 1" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(4, 1L))) + checkAnswer(sql(s"SELECT long_col from test_tbl WHERE $filter"), Seq(Row(1L))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("float_col = 1.0")) + filter = "float_col = 1.0" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(5, 1.0f))) + checkAnswer(sql(s"SELECT float_col from test_tbl WHERE $filter"), Seq(Row(1.0f))) + 
assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("double_col = 1.0")) + filter = "double_col = 1.0" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(6, 1.0d))) + checkAnswer(sql(s"SELECT double_col from test_tbl WHERE $filter"), Seq(Row(1.0d))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("decimal_col = 12.12345")) + filter = "decimal_col = 12.12345" + actual = converter.convert(v2Filter(filter)) assert( actual.equals( builder.equal(7, Decimal.fromBigDecimal(new java.math.BigDecimal("12.12345"), 10, 5)))) + checkAnswer( + sql(s"SELECT decimal_col from test_tbl WHERE $filter"), + Seq(Row(new java.math.BigDecimal("12.12345")))) + assert(scanFilesCount(filter) == 1) - actual = converter.convert(v2Filter("boolean_col = true")) + filter = "boolean_col = true" + actual = converter.convert(v2Filter(filter)) assert(actual.equals(builder.equal(8, true))) + checkAnswer(sql(s"SELECT boolean_col from test_tbl WHERE $filter"), Seq(Row(true), Row(true))) + assert(scanFilesCount(filter) == 2) - actual = converter.convert(v2Filter("date_col = cast('2025-01-15' as date)")) + filter = "date_col = date('2025-01-15')" + actual = converter.convert(v2Filter(filter)) val localDate = LocalDate.parse("2025-01-15") val epochDay = localDate.toEpochDay.toInt assert(actual.equals(builder.equal(9, epochDay))) + checkAnswer( + sql(s"SELECT date_col from test_tbl WHERE $filter"), + sql("SELECT date('2025-01-15')")) + assert(scanFilesCount(filter) == 1) + filter = "binary_col = binary('b1')" intercept[UnsupportedOperationException] { - actual = converter.convert(v2Filter("binary = binary('b1')")) + actual = converter.convert(v2Filter(filter)) } + checkAnswer(sql(s"SELECT binary_col from test_tbl WHERE $filter"), sql("SELECT binary('b1')")) + assert(scanFilesCount(filter) == 4) } test("V2Filter: timestamp and timestamp_ntz") { withTimeZone("Asia/Shanghai") { withTable("ts_tbl", "ts_ntz_tbl") { sql("CREATE TABLE ts_tbl (ts_col TIMESTAMP) USING paimon") + sql("INSERT INTO ts_tbl VALUES (timestamp'2025-01-15 00:00:00.123')") + sql("INSERT INTO ts_tbl VALUES (timestamp'2025-01-16 00:00:00.123')") + + val filter1 = "ts_col = timestamp'2025-01-15 00:00:00.123'" val rowType1 = loadTable("ts_tbl").rowType() val converter1 = SparkV2FilterConverter(rowType1) - val actual1 = - converter1.convert(v2Filter("ts_col = timestamp'2025-01-15 00:00:00.123'", "ts_tbl")) - assert( - actual1.equals(new PredicateBuilder(rowType1) + val actual1 = converter1.convert(v2Filter(filter1, "ts_tbl")) + if (treatPaimonTimestampTypeAsSparkTimestampType()) { + assert(actual1.equals(new PredicateBuilder(rowType1) + .equal(0, Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-15T00:00:00.123"))))) + } else { + assert(actual1.equals(new PredicateBuilder(rowType1) .equal(0, Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-14T16:00:00.123"))))) + } + checkAnswer( + sql(s"SELECT ts_col from ts_tbl WHERE $filter1"), + sql("SELECT timestamp'2025-01-15 00:00:00.123'")) + assert(scanFilesCount(filter1, "ts_tbl") == 1) // Spark support TIMESTAMP_NTZ since Spark 3.4 if (gteqSpark3_4) { sql("CREATE TABLE ts_ntz_tbl (ts_ntz_col TIMESTAMP_NTZ) USING paimon") + sql("INSERT INTO ts_ntz_tbl VALUES (timestamp_ntz'2025-01-15 00:00:00.123')") + sql("INSERT INTO ts_ntz_tbl VALUES (timestamp_ntz'2025-01-16 00:00:00.123')") + val filter2 = "ts_ntz_col = timestamp_ntz'2025-01-15 00:00:00.123'" val rowType2 = loadTable("ts_ntz_tbl").rowType() val converter2 = 
SparkV2FilterConverter(rowType2)
-          val actual2 = converter2.convert(
-            v2Filter("ts_ntz_col = timestamp_ntz'2025-01-15 00:00:00.123'", "ts_ntz_tbl"))
+          val actual2 = converter2.convert(v2Filter(filter2, "ts_ntz_tbl"))
           assert(actual2.equals(new PredicateBuilder(rowType2)
             .equal(0, Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-15T00:00:00.123")))))
+          checkAnswer(
+            sql(s"SELECT ts_ntz_col from ts_ntz_tbl WHERE $filter2"),
+            sql("SELECT timestamp_ntz'2025-01-15 00:00:00.123'"))
+          assert(scanFilesCount(filter2, "ts_ntz_tbl") == 1)
         }
       }
     }
   }
 
   test("V2Filter: EqualTo") {
-    val actual = converter.convert(v2Filter("int_col = 1"))
+    val filter = "int_col = 1"
+    val actual = converter.convert(v2Filter(filter))
     assert(actual.equals(builder.equal(3, 1)))
+    checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(1)))
+    assert(scanFilesCount(filter) == 1)
   }
 
   test("V2Filter: EqualNullSafe") {
-    var actual = converter.convert(v2Filter("int_col <=> 1"))
+    var filter = "int_col <=> 1"
+    var actual = converter.convert(v2Filter(filter))
     assert(actual.equals(builder.equal(3, 1)))
+    checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(1)))
+    assert(scanFilesCount(filter) == 1)
 
-    actual = converter.convert(v2Filter("int_col <=> null"))
+    filter = "int_col <=> null"
+    actual = converter.convert(v2Filter(filter))
     assert(actual.equals(builder.isNull(3)))
+    checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(null)))
+    assert(scanFilesCount(filter) == 1)
   }
 
   test("V2Filter: GreaterThan") {
-    val actual = converter.convert(v2Filter("int_col > 1"))
-    assert(actual.equals(builder.greaterThan(3, 1)))
+    val filter = "int_col > 2"
+    val actual = converter.convert(v2Filter(filter))
+    assert(actual.equals(builder.greaterThan(3, 2)))
+    checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(3)))
+    assert(scanFilesCount(filter) == 1)
   }
 
   test("V2Filter: GreaterThanOrEqual") {
-    val actual = converter.convert(v2Filter("int_col >= 1"))
-    assert(actual.equals(builder.greaterOrEqual(3, 1)))
+    val filter = "int_col >= 2"
+    val actual = converter.convert(v2Filter(filter))
+    assert(actual.equals(builder.greaterOrEqual(3, 2)))
+    checkAnswer(
+      sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+      Seq(Row(2), Row(3)))
+    assert(scanFilesCount(filter) == 2)
   }
 
   test("V2Filter: LessThan") {
-    val actual = converter.convert(v2Filter("int_col < 1"))
-    assert(actual.equals(builder.lessThan(3, 1)))
+    val filter = "int_col < 2"
+    val actual = converter.convert(v2Filter(filter))
+    assert(actual.equals(builder.lessThan(3, 2)))
+    checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(1)))
+    assert(scanFilesCount(filter) == 1)
   }
 
   test("V2Filter: LessThanOrEqual") {
-    val actual = converter.convert(v2Filter("int_col <= 1"))
-    assert(actual.equals(builder.lessOrEqual(3, 1)))
+    val filter = "int_col <= 2"
+    val actual = converter.convert(v2Filter(filter))
+    assert(actual.equals(builder.lessOrEqual(3, 2)))
+    checkAnswer(
+      sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+      Seq(Row(1), Row(2)))
+    assert(scanFilesCount(filter) == 2)
   }
 
   test("V2Filter: In") {
-    val actual = converter.convert(v2Filter("int_col IN (1, 2, 3)"))
-    assert(actual.equals(builder.in(3, List(1, 2, 3).map(_.asInstanceOf[AnyRef]).asJava)))
+    val filter = "int_col IN (1, 2)"
+    val actual = converter.convert(v2Filter(filter))
+    assert(actual.equals(builder.in(3, List(1, 2).map(_.asInstanceOf[AnyRef]).asJava)))
+    checkAnswer(
+      sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+      Seq(Row(1), Row(2)))
+    assert(scanFilesCount(filter) == 2)
   }
 
   test("V2Filter: IsNull") {
-    val actual = converter.convert(v2Filter("int_col IS NULL"))
+    val filter = "int_col IS NULL"
+    val actual = converter.convert(v2Filter(filter))
     assert(actual.equals(builder.isNull(3)))
+    checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(null)))
+    assert(scanFilesCount(filter) == 1)
   }
 
   test("V2Filter: IsNotNull") {
-    val actual = converter.convert(v2Filter("int_col IS NOT NULL"))
+    val filter = "int_col IS NOT NULL"
+    val actual = converter.convert(v2Filter(filter))
     assert(actual.equals(builder.isNotNull(3)))
+    checkAnswer(
+      sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+      Seq(Row(1), Row(2), Row(3)))
+    assert(scanFilesCount(filter) == 3)
   }
 
   test("V2Filter: And") {
-    val actual = converter.convert(v2Filter("int_col > 1 AND int_col < 10"))
-    assert(actual.equals(PredicateBuilder.and(builder.greaterThan(3, 1), builder.lessThan(3, 10))))
+    val filter = "int_col > 1 AND int_col < 3"
+    val actual = converter.convert(v2Filter(filter))
+    assert(actual.equals(PredicateBuilder.and(builder.greaterThan(3, 1), builder.lessThan(3, 3))))
+    checkAnswer(sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"), Seq(Row(2)))
+    assert(scanFilesCount(filter) == 1)
   }
 
   test("V2Filter: Or") {
-    val actual = converter.convert(v2Filter("int_col > 1 OR int_col < 10"))
-    assert(actual.equals(PredicateBuilder.or(builder.greaterThan(3, 1), builder.lessThan(3, 10))))
+    val filter = "int_col > 2 OR int_col IS NULL"
+    val actual = converter.convert(v2Filter(filter))
+    assert(actual.equals(PredicateBuilder.or(builder.greaterThan(3, 2), builder.isNull(3))))
+    checkAnswer(
+      sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+      Seq(Row(null), Row(3)))
+    assert(scanFilesCount(filter) == 2)
   }
 
   test("V2Filter: Not") {
-    val actual = converter.convert(v2Filter("NOT (int_col > 1)"))
-    assert(actual.equals(builder.greaterThan(3, 1).negate().get()))
+    val filter = "NOT (int_col > 2)"
+    val actual = converter.convert(v2Filter(filter))
+    assert(actual.equals(builder.greaterThan(3, 2).negate().get()))
+    checkAnswer(
+      sql(s"SELECT int_col from test_tbl WHERE $filter ORDER BY int_col"),
+      Seq(Row(1), Row(2)))
+    assert(scanFilesCount(filter) == 2)
   }
 
   test("V2Filter: StartWith") {
-    val actual = converter.convert(v2Filter("string_col LIKE 'h%'"))
+    val filter = "string_col LIKE 'h%'"
+    val actual = converter.convert(v2Filter(filter))
     assert(actual.equals(builder.startsWith(0, BinaryString.fromString("h"))))
+    checkAnswer(
+      sql(s"SELECT string_col from test_tbl WHERE $filter ORDER BY string_col"),
+      Seq(Row("hello"), Row("hi")))
+    assert(scanFilesCount(filter) == 2)
   }
 
   test("V2Filter: EndWith") {
-    val actual = converter.convert(v2Filter("string_col LIKE '%o'"))
+    val filter = "string_col LIKE '%o'"
+    val actual = converter.convert(v2Filter(filter))
     assert(actual.equals(builder.endsWith(0, BinaryString.fromString("o"))))
+    checkAnswer(
+      sql(s"SELECT string_col from test_tbl WHERE $filter ORDER BY string_col"),
+      Seq(Row("hello")))
+    // EndWith does not support file skipping yet, so all data files are scanned.
+    assert(scanFilesCount(filter) == 4)
   }
 
   test("V2Filter: Contains") {
-    val actual = converter.convert(v2Filter("string_col LIKE '%e%'"))
+    val filter = "string_col LIKE '%e%'"
+    val actual = converter.convert(v2Filter(filter))
     assert(actual.equals(builder.contains(0, BinaryString.fromString("e"))))
+    checkAnswer(
+      sql(s"SELECT string_col from test_tbl WHERE $filter ORDER BY string_col"),
+      Seq(Row("hello")))
+    // Contains does not support file skipping yet, so all data files are scanned.
+    assert(scanFilesCount(filter) == 4)
   }
 
   private def v2Filter(str: String, tableName: String = "test_tbl"): Predicate = {
@@ -221,4 +363,11 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase {
       .condition
     translateFilterV2(condition).get
   }
+
+  private def scanFilesCount(str: String, tableName: String = "test_tbl"): Int = {
+    getPaimonScan(s"SELECT * FROM $tableName WHERE $str").lazyInputPartitions
+      .flatMap(_.splits)
+      .map(_.asInstanceOf[DataSplit].dataFiles().size())
+      .sum
+  }
 }