
Commit a93d914

yabola authored and dongjoon-hyun committed
[SPARK-41470][SQL] SPJ: Relax the constraint that Storage-Partitioned Join assumes InternalRow implements equals and hashCode
### What changes were proposed in this pull request?

Introduce a new wrapper class for a comparable InternalRow (the row returned by `HasPartitionKey`, together with its data types) and remove `InternalRowSet`, enabling easy `groupBy`, `Set`, `Map`, and other operations.

### Why are the changes needed?

Currently SPJ (Storage-Partitioned Join) assumes that the `InternalRow` returned by `HasPartitionKey` implements equals and hashCode. We should remove this restriction. For example, see the [comments](https://github.com/apache/iceberg/pull/6371/files#r1056852402) in the Iceberg PR suggesting that [StructInternalRow](https://github.com/apache/iceberg/blob/master/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/StructInternalRow.java#L362) should implement equals and hashCode; that is actually not necessary.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

Existing tests

Closes apache#39687 from yabola/InternalRow_hashcode.

Authored-by: chenliang.lu <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
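To see the idea concretely, here is a minimal sketch (not part of the patch) of how the wrapper restores value-based `Set` semantics. The rows and data types are illustrative; note that `GenericInternalRow` happens to implement value equality already, the point is that the DSv2 contract does not require a connector's rows to do so:

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.types.IntegerType

// Two distinct row instances carrying the same partition value.
val left: InternalRow = new GenericInternalRow(Array[Any](1))
val right: InternalRow = new GenericInternalRow(Array[Any](1))

// Wrapping each row with its data types gives a murmur-hash based hashCode
// and a RowOrdering-based equals, so hash collections compare by value.
val dataTypes = Seq(IntegerType)
val set = mutable.HashSet(
  new InternalRowComparableWrapper(left, dataTypes),
  new InternalRowComparableWrapper(right, dataTypes))
assert(set.size == 1) // deduplicated by value, not by reference
```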
1 parent d009d26 commit a93d914

File tree

9 files changed: +149 −179 lines

sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/InternalRowSet.scala

Lines changed: 0 additions & 65 deletions
This file was deleted.

sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/InternalRowSet.scala

Lines changed: 0 additions & 69 deletions
This file was deleted.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 4 additions & 4 deletions
@@ -22,6 +22,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
@@ -677,9 +678,6 @@ case class KeyGroupedShuffleSpec(
     }
   }
 
-  lazy val ordering: Ordering[InternalRow] =
-    RowOrdering.createNaturalAscendingOrdering(partitioning.expressions.map(_.dataType))
-
  override def numPartitions: Int = partitioning.numPartitions
 
  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
@@ -697,7 +695,9 @@ case class KeyGroupedShuffleSpec(
       distribution.clustering.length == otherDistribution.clustering.length &&
         numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
         partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
-          case (left, right) => ordering.compare(left, right) == 0
+          case (left, right) =>
+            InternalRowComparableWrapper(left, partitioning.expressions)
+              .equals(InternalRowComparableWrapper(right, partitioning.expressions))
         }
     case ShuffleSpecCollection(specs) =>
       specs.exists(isCompatibleWith)
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala

Lines changed: 84 additions & 0 deletions
This file was added.

@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunction, RowOrdering}
+import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+
+/**
+ * Wraps an [[InternalRow]] with its corresponding [[DataType]]s so that rows can be
+ * compared by value.
+ * It uses Spark's internal murmur hash to compute the hash code of a row, and uses
+ * [[RowOrdering]] to perform equality checks.
+ *
+ * @param dataTypes the data types for the row
+ */
+class InternalRowComparableWrapper(val row: InternalRow, val dataTypes: Seq[DataType]) {
+
+  private val structType = StructType(dataTypes.map(t => StructField("f", t)))
+  private val ordering = RowOrdering.createNaturalAscendingOrdering(dataTypes)
+
+  override def hashCode(): Int = Murmur3HashFunction.hash(row, structType, 42L).toInt
+
+  override def equals(other: Any): Boolean = {
+    if (!other.isInstanceOf[InternalRowComparableWrapper]) {
+      return false
+    }
+    val otherWrapper = other.asInstanceOf[InternalRowComparableWrapper]
+    if (!otherWrapper.dataTypes.equals(this.dataTypes)) {
+      return false
+    }
+    ordering.compare(row, otherWrapper.row) == 0
+  }
+}
+
+object InternalRowComparableWrapper {
+
+  def apply(
+      partition: InputPartition with HasPartitionKey,
+      partitionExpression: Seq[Expression]): InternalRowComparableWrapper = {
+    new InternalRowComparableWrapper(
+      partition.asInstanceOf[HasPartitionKey].partitionKey(), partitionExpression.map(_.dataType))
+  }
+
+  def apply(
+      partitionRow: InternalRow,
+      partitionExpression: Seq[Expression]): InternalRowComparableWrapper = {
+    new InternalRowComparableWrapper(partitionRow, partitionExpression.map(_.dataType))
+  }
+
+  def mergePartitions(
+      leftPartitioning: KeyGroupedPartitioning,
+      rightPartitioning: KeyGroupedPartitioning,
+      partitionExpression: Seq[Expression]): Seq[InternalRow] = {
+    val partitionDataTypes = partitionExpression.map(_.dataType)
+    val partitionsSet = new mutable.HashSet[InternalRowComparableWrapper]
+    leftPartitioning.partitionValues
+      .map(new InternalRowComparableWrapper(_, partitionDataTypes))
+      .foreach(partition => partitionsSet.add(partition))
+    rightPartitioning.partitionValues
+      .map(new InternalRowComparableWrapper(_, partitionDataTypes))
+      .foreach(partition => partitionsSet.add(partition))
+    partitionsSet.map(_.row).toSeq
+  }
+}
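As a hedged usage sketch (illustrative values, not from the patch), the companion `apply` overloads make grouping partition rows straightforward; only the data types of the expressions matter to the wrapper:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GenericInternalRow}
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.types.IntegerType

// Hypothetical partition keys, as a connector might report via HasPartitionKey.
val keys: Seq[InternalRow] = Seq(
  new GenericInternalRow(Array[Any](0)),
  new GenericInternalRow(Array[Any](1)),
  new GenericInternalRow(Array[Any](0)))

// Only the data type of each expression contributes to comparison.
val exprs: Seq[Expression] = Seq(AttributeReference("key", IntegerType)())

// groupBy works even if the underlying rows lack equals/hashCode.
val grouped = keys.groupBy(InternalRowComparableWrapper(_, exprs))
assert(grouped.size == 2)
```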

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala

Lines changed: 22 additions & 3 deletions
@@ -24,6 +24,7 @@ import java.util.OptionalLong
 
 import scala.collection.mutable
 
+import com.google.common.base.Objects
 import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -541,13 +542,31 @@ class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage
 
   def keyString(): String = key.toArray.mkString("/")
 
-  override def partitionKey(): InternalRow = {
-    InternalRow.fromSeq(key)
-  }
+  override def partitionKey(): InternalRow = PartitionInternalRow(key.toArray)
 
   def clear(): Unit = rows.clear()
 }
 
+/**
+ * Theoretically, the [[InternalRow]] returned by [[HasPartitionKey#partitionKey()]]
+ * does not need to implement equals and hashCode methods.
+ * But [[GenericInternalRow]] already implements equals and hashCode. Here we override them
+ * to simulate an implementation that lacks them, in order to verify the code's correctness.
+ */
+case class PartitionInternalRow(keys: Array[Any])
+  extends GenericInternalRow(keys) {
+  override def equals(other: Any): Boolean = {
+    if (!other.isInstanceOf[PartitionInternalRow]) {
+      return false
+    }
+    // Just compare by reference, not by value
+    this.keys == other.asInstanceOf[PartitionInternalRow].keys
+  }
+  override def hashCode: Int = {
+    Objects.hashCode(keys)
+  }
+}
+
 private class BufferedRowsReaderFactory(
     metadataColumnNames: Seq[String],
     nonMetaDataColumns: Seq[StructField],
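A short sketch (not in the patch) of why this helper is a good test double: its `equals` compares the backing arrays by reference, so value-equal keys look distinct, exactly like a connector row without value semantics:

```scala
// Two value-equal keys backed by different Array instances.
val a = PartitionInternalRow(Array[Any](1))
val b = PartitionInternalRow(Array[Any](1))

assert(a != b) // the arrays are compared by reference: not equal
assert(a == a) // same array instance: equal

// Consequently, code that used such rows directly as Set or Map keys would
// treat equal partition values as distinct, which is the bug class that
// InternalRowComparableWrapper guards against.
```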

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala

Lines changed: 21 additions & 20 deletions
@@ -25,10 +25,9 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
-import org.apache.spark.sql.catalyst.util.InternalRowSet
-import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
-import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeV2Filtering}
+import org.apache.spark.sql.connector.read._
 
 /**
  * Physical plan node for scanning a batch of data from a data source v2.
@@ -80,24 +79,24 @@ case class BatchScanExec(
           "during runtime filtering: not all partitions implement HasPartitionKey after " +
           "filtering")
       }
-
-      val newRows = new InternalRowSet(p.expressions.map(_.dataType))
-      newRows ++= newPartitions.map(_.asInstanceOf[HasPartitionKey].partitionKey())
-
-      val oldRows = p.partitionValues.toSet
-      // We require the new number of partition keys to be equal or less than the old number
-      // of partition keys here. In the case of less than, empty partitions will be added for
-      // those missing keys that are not present in the new input partitions.
-      if (oldRows.size < newRows.size) {
+      val newPartitionValues = newPartitions.map(partition =>
+        InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions))
+        .toSet
+      val oldPartitionValues = p.partitionValues
+        .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet
+      // We require the new number of partition values to be equal or less than the old number
+      // of partition values here. In the case of less than, empty partitions will be added for
+      // those missing values that are not present in the new input partitions.
+      if (oldPartitionValues.size < newPartitionValues.size) {
         throw new SparkException("During runtime filtering, data source must either report " +
-          "the same number of partition keys, or a subset of partition keys from the " +
-          s"original. Before: ${oldRows.size} partition keys. After: ${newRows.size} " +
-          "partition keys")
+          "the same number of partition values, or a subset of partition values from the " +
+          s"original. Before: ${oldPartitionValues.size} partition values. " +
+          s"After: ${newPartitionValues.size} partition values")
       }
 
-      if (!newRows.forall(oldRows.contains)) {
+      if (!newPartitionValues.forall(oldPartitionValues.contains)) {
         throw new SparkException("During runtime filtering, data source must not report new " +
-          "partition keys that are not present in the original partitioning.")
+          "partition values that are not present in the original partitioning.")
       }
 
       groupPartitions(newPartitions).get.map(_._2)
@@ -132,11 +131,13 @@ case class BatchScanExec(
 
     outputPartitioning match {
       case p: KeyGroupedPartitioning =>
-        val partitionMapping = finalPartitions.map(s =>
-          s.head.asInstanceOf[HasPartitionKey].partitionKey() -> s).toMap
+        val partitionMapping = finalPartitions.map(s => InternalRowComparableWrapper(
+          s.head.asInstanceOf[HasPartitionKey], p.expressions) -> s)
+          .toMap
        finalPartitions = p.partitionValues.map { partValue =>
          // Use empty partition for those partition values that are not present
-          partitionMapping.getOrElse(partValue, Seq.empty)
+          partitionMapping.getOrElse(
+            InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
        }
      case _ =>
    }
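The runtime-filtering validation above boils down to a subset check over wrapped partition values. A condensed sketch of the invariant, with hypothetical helper names:

```scala
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper

// After runtime filtering, the data source may drop partition values but must
// not invent new ones; dropped values are later backed by empty partitions.
def validateFilteredPartitions(
    oldValues: Set[InternalRowComparableWrapper],
    newValues: Set[InternalRowComparableWrapper]): Unit = {
  if (oldValues.size < newValues.size) {
    throw new SparkException("Data source must not report more partition values than before")
  }
  if (!newValues.forall(oldValues.contains)) {
    throw new SparkException("Data source must not report new partition values")
  }
}
```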
