diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 890ae56ffa529..2d578d7495178 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -26,7 +26,18 @@
 from contextlib import redirect_stdout
 
 from pyspark.sql import Row, functions, DataFrame
-from pyspark.sql.functions import col, lit, count, struct, date_format, to_date, array, explode
+from pyspark.sql.functions import (
+    col,
+    lit,
+    count,
+    struct,
+    date_format,
+    to_date,
+    array,
+    explode,
+    when,
+    concat,
+)
 from pyspark.sql.types import (
     StringType,
     IntegerType,
@@ -189,6 +200,26 @@ def test_drop(self):
         self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
         self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"])
 
+    def test_drop_nonexistent_col(self):
+        df1 = self.spark.createDataFrame(
+            [("a", "b", "c")],
+            schema="colA string, colB string, colC string",
+        )
+        df2 = self.spark.createDataFrame(
+            [("c", "d", "e")],
+            schema="colC string, colD string, colE string",
+        )
+        df3 = df1.join(df2, df1["colC"] == df2["colC"]).withColumn(
+            "colB",
+            when(df1["colB"] == "b", concat(df1["colB"].cast("string"), lit("x"))).otherwise(
+                df1["colB"]
+            ),
+        )
+        df4 = df3.drop(df1["colB"])
+
+        self.assertEqual(df4.columns, ["colA", "colB", "colC", "colC", "colD", "colE"])
+        self.assertEqual(df4.count(), 1)
+
     def test_drop_join(self):
         left_df = self.spark.createDataFrame(
             [(1, "a"), (2, "b"), (3, "c")],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f9272cc037359..845171722073e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -454,7 +454,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveNaturalAndUsingJoin ::
       ResolveOutputRelation ::
       new ResolveTableConstraints(catalogManager) ::
-      new ResolveDataFrameDropColumns(catalogManager) ::
       new ResolveSetVariable(catalogManager) ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
@@ -1483,6 +1482,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     new ResolveReferencesInUpdate(catalogManager)
   private val resolveReferencesInSort =
     new ResolveReferencesInSort(catalogManager)
+  private val resolveDataFrameDropColumns =
+    new ResolveDataFrameDropColumns(catalogManager)
 
   /**
    * Return true if there're conflicting attributes among children's outputs of a plan
@@ -1793,6 +1794,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           // Pass for Execute Immediate as arguments will be resolved by [[SubstituteExecuteImmediate]].
          case e : ExecuteImmediateQuery => e

+          case d: DataFrameDropColumns if !d.resolved =>
+            resolveDataFrameDropColumns(d)
+
           case q: LogicalPlan =>
             logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}")
             q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 69591ed8c5f9b..d3e52d11b465f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -509,6 +509,33 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
       includeLastResort = includeLastResort)
   }
 
+  // Try to resolve `UnresolvedAttribute` by the children with Plan Ids.
+  // The `UnresolvedAttribute` must have a Plan Id:
+  // - If the Plan Id is not found in the plan, raise CANNOT_RESOLVE_DATAFRAME_COLUMN.
+  // - If the Plan Id is found in the plan but the column is not, return None.
+  // - Otherwise, return the resolved expression.
+  private[sql] def tryResolveColumnByPlanChildren(
+      u: UnresolvedAttribute,
+      q: LogicalPlan,
+      includeLastResort: Boolean = false): Option[Expression] = {
+    assert(u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty,
+      s"UnresolvedAttribute $u should have a Plan Id tag")
+
+    resolveDataFrameColumn(u, q.children).map { r =>
+      resolveExpression(
+        r,
+        resolveColumnByName = nameParts => {
+          q.resolveChildren(nameParts, conf.resolver)
+        },
+        getAttrCandidates = () => {
+          assert(q.children.length == 1)
+          q.children.head.output
+        },
+        throws = true,
+        includeLastResort = includeLastResort)
+    }
+  }
+
   /**
    * The last resort to resolve columns. Currently it does two things:
    * - Try to resolve column names as outer references
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
index 0f9b93cc2986d..a0f67fa3f445f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.DF_DROP_COLUMNS
 import org.apache.spark.sql.connector.catalog.CatalogManager
 
@@ -27,17 +27,24 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
  * Note that DataFrameDropColumns allows and ignores non-existing columns.
  */
 class ResolveDataFrameDropColumns(val catalogManager: CatalogManager)
-  extends Rule[LogicalPlan] with ColumnResolutionHelper {
+  extends SQLConfHelper with ColumnResolutionHelper {
 
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
     _.containsPattern(DF_DROP_COLUMNS)) {
     case d: DataFrameDropColumns if d.childrenResolved =>
       // expressions in dropList can be unresolved, e.g.
       //   df.drop(col("non-existing-column"))
-      val dropped = d.dropList.map {
+      val dropped = d.dropList.flatMap {
         case u: UnresolvedAttribute =>
-          resolveExpressionByPlanChildren(u, d)
-        case e => e
+          if (u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty) {
+            // The Plan Id comes from Spark Connect.
+            // Here we ignore the `UnresolvedAttribute` if its Plan Id can be found
+            // but the column cannot.
+            tryResolveColumnByPlanChildren(u, d)
+          } else {
+            Some(resolveExpressionByPlanChildren(u, d))
+          }
+        case e => Some(e)
       }
       val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr)))
       if (remaining.size == d.child.output.size) {
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameSuite.scala
new file mode 100644
index 0000000000000..2993f44efcebb
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameSuite.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.functions.{concat, lit, when}
+
+class DataFrameSuite extends QueryTest with RemoteSparkSession {
+
+  test("drop") {
+    val sparkSession = spark
+    import sparkSession.implicits._
+
+    val df1 = Seq[(String, String, String)](("a", "b", "c")).toDF("colA", "colB", "colC")
+
+    val df2 = Seq[(String, String, String)](("c", "d", "e")).toDF("colC", "colD", "colE")
+
+    val df3 = df1
+      .join(df2, df1.col("colC") === df2.col("colC"))
+      .withColumn(
+        "colB",
+        when(df1.col("colB") === "b", concat(df1.col("colB").cast("string"), lit("x")))
+          .otherwise(df1.col("colB")))
+
+    val df4 = df3.drop(df1.col("colB"))
+
+    assert(df4.columns === Array("colA", "colB", "colC", "colC", "colD", "colE"))
+    assert(df4.count() === 1)
+  }
+}
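
For reference, a minimal sketch of the behavior these tests pin down (assuming a Spark Connect session; the sc://localhost endpoint is illustrative, and the DataFrames mirror the tests above). df1["colB"] is tagged with df1's plan id; after withColumn replaces that attribute in the joined plan, tryResolveColumnByPlanChildren finds the plan id but not the column and returns None, so flatMap silently skips it and the drop becomes a no-op, consistent with the rule's contract that DataFrameDropColumns ignores non-existing columns. (Previously the attribute fell through to resolveExpressionByPlanChildren, which could fail the query instead of ignoring the column.)

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import concat, lit, when

    # Illustrative session setup; any Spark Connect session behaves the same way.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    df1 = spark.createDataFrame([("a", "b", "c")], "colA string, colB string, colC string")
    df2 = spark.createDataFrame([("c", "d", "e")], "colC string, colD string, colE string")

    # withColumn replaces df1's original colB attribute in the joined plan.
    df3 = df1.join(df2, df1["colC"] == df2["colC"]).withColumn(
        "colB",
        when(df1["colB"] == "b", concat(df1["colB"].cast("string"), lit("x"))).otherwise(
            df1["colB"]
        ),
    )

    # df1["colB"] carries df1's plan id: the plan id is found in df3's lineage,
    # but the replaced column is not, so the drop resolves to None and is ignored.
    df4 = df3.drop(df1["colB"])
    print(df4.columns)  # ['colA', 'colB', 'colC', 'colC', 'colD', 'colE']
    print(df4.count())  # 1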