Skip to content

Commit

Permalink
minimized the test
Browse files Browse the repository at this point in the history
  • Loading branch information
himadripal committed Jul 9, 2024
1 parent de05de8 commit 64598f3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 289 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1491,11 +1491,11 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.isEmpty, "SPJ should be triggered")

val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
val partions = collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
partitions.length)
val expectedBuckets = Math.min(table1buckets1, table2buckets1) *
Math.min(table1buckets2, table2buckets2)
assert(scans == Seq(expectedBuckets, expectedBuckets))
assert(partions == Seq(expectedBuckets, expectedBuckets))

checkAnswer(df, Seq(
Row(0, 0, "aa", "aa"),
Expand Down Expand Up @@ -1721,237 +1721,30 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}


test("SPARK-47094: SPJ: Support compatible buckets common divisor is one of the numbers") {
val table1 = "tab1e1"
val table2 = "table2"

Seq(
((4, 8), (8, 4)),
((3, 9), (9, 3)),
).foreach {
case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) =>
catalog.clearTables()

val partition1 = Array(bucket(table1buckets1, "store_id"),
bucket(table1buckets2, "dept_id"))
val partition2 = Array(bucket(table2buckets1, "store_id"),
bucket(table2buckets2, "dept_id"))

Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) =>
createTable(tab, columns2, part)
val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " +
"(0, 0, 'aa'), " +
"(0, 0, 'ab'), " + // duplicate partition key
"(0, 1, 'ac'), " +
"(0, 2, 'ad'), " +
"(0, 3, 'ae'), " +
"(0, 4, 'af'), " +
"(0, 5, 'ag'), " +
"(1, 0, 'ah'), " +
"(1, 0, 'ai'), " + // duplicate partition key
"(1, 1, 'aj'), " +
"(1, 2, 'ak'), " +
"(1, 3, 'al'), " +
"(1, 4, 'am'), " +
"(1, 5, 'an'), " +
"(2, 0, 'ao'), " +
"(2, 0, 'ap'), " + // duplicate partition key
"(2, 1, 'aq'), " +
"(2, 2, 'ar'), " +
"(2, 3, 'as'), " +
"(2, 4, 'at'), " +
"(2, 5, 'au'), " +
"(3, 0, 'av'), " +
"(3, 0, 'aw'), " + // duplicate partition key
"(3, 1, 'ax'), " +
"(3, 2, 'ay'), " +
"(3, 3, 'az'), " +
"(3, 4, 'ba'), " +
"(3, 5, 'bb'), " +
"(4, 0, 'bc'), " +
"(4, 0, 'bd'), " + // duplicate partition key
"(4, 1, 'be'), " +
"(4, 2, 'bf'), " +
"(4, 3, 'bg'), " +
"(4, 4, 'bh'), " +
"(4, 5, 'bi'), " +
"(5, 0, 'bj'), " +
"(5, 0, 'bk'), " + // duplicate partition key
"(5, 1, 'bl'), " +
"(5, 2, 'bm'), " +
"(5, 3, 'bn'), " +
"(5, 4, 'bo'), " +
"(5, 5, 'bp')"

// additional unmatched partitions to test push down
val finalStr = if (tab == table1) {
insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')"
} else {
insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')"
}

sql(finalStr)
}

Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
withSQLConf(
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key ->
allowJoinKeysSubsetOfPartitionKeys.toString,
SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
val df = sql(
s"""
|${selectWithMergeJoinHint("t1", "t2")}
|t1.store_id, t1.dept_id, t1.data, t2.data
|FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
|ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id
|ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data
|""".stripMargin)

val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.isEmpty, "SPJ should be triggered")

val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
partitions.length)

def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt
val expectedPartitions = gcd(table1buckets1, table2buckets1) *
gcd(table1buckets2, table2buckets2)
assert(scans == Seq(expectedPartitions, expectedPartitions))

checkAnswer(df, Seq(
Row(0, 0, "aa", "aa"),
Row(0, 0, "aa", "ab"),
Row(0, 0, "ab", "aa"),
Row(0, 0, "ab", "ab"),
Row(0, 1, "ac", "ac"),
Row(0, 2, "ad", "ad"),
Row(0, 3, "ae", "ae"),
Row(0, 4, "af", "af"),
Row(0, 5, "ag", "ag"),
Row(1, 0, "ah", "ah"),
Row(1, 0, "ah", "ai"),
Row(1, 0, "ai", "ah"),
Row(1, 0, "ai", "ai"),
Row(1, 1, "aj", "aj"),
Row(1, 2, "ak", "ak"),
Row(1, 3, "al", "al"),
Row(1, 4, "am", "am"),
Row(1, 5, "an", "an"),
Row(2, 0, "ao", "ao"),
Row(2, 0, "ao", "ap"),
Row(2, 0, "ap", "ao"),
Row(2, 0, "ap", "ap"),
Row(2, 1, "aq", "aq"),
Row(2, 2, "ar", "ar"),
Row(2, 3, "as", "as"),
Row(2, 4, "at", "at"),
Row(2, 5, "au", "au"),
Row(3, 0, "av", "av"),
Row(3, 0, "av", "aw"),
Row(3, 0, "aw", "av"),
Row(3, 0, "aw", "aw"),
Row(3, 1, "ax", "ax"),
Row(3, 2, "ay", "ay"),
Row(3, 3, "az", "az"),
Row(3, 4, "ba", "ba"),
Row(3, 5, "bb", "bb"),
Row(4, 0, "bc", "bc"),
Row(4, 0, "bc", "bd"),
Row(4, 0, "bd", "bc"),
Row(4, 0, "bd", "bd"),
Row(4, 1, "be", "be"),
Row(4, 2, "bf", "bf"),
Row(4, 3, "bg", "bg"),
Row(4, 4, "bh", "bh"),
Row(4, 5, "bi", "bi"),
Row(5, 0, "bj", "bj"),
Row(5, 0, "bj", "bk"),
Row(5, 0, "bk", "bj"),
Row(5, 0, "bk", "bk"),
Row(5, 1, "bl", "bl"),
Row(5, 2, "bm", "bm"),
Row(5, 3, "bn", "bn"),
Row(5, 4, "bo", "bo"),
Row(5, 5, "bp", "bp")
))
}
}
}
}

test("SPARK-47094: SPJ: Does not trigger when incompatible number of buckets on both side") {
val table1 = "tab1e1"
val table2 = "table2"

Seq(
((3, 5), (5, 3)),
((5, 7), (7, 5)),
(2, 3),
(3, 4)
).foreach {
case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) =>
case (table1buckets1, table2buckets1) =>
catalog.clearTables()

val partition1 = Array(bucket(table1buckets1, "store_id"),
bucket(table1buckets2, "dept_id"))
val partition2 = Array(bucket(table2buckets1, "store_id"),
bucket(table2buckets2, "dept_id"))
val partition1 = Array(bucket(table1buckets1, "store_id"))
val partition2 = Array(bucket(table2buckets1, "store_id"))

Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) =>
createTable(tab, columns2, part)
val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " +
"(0, 0, 'aa'), " +
"(0, 0, 'ab'), " + // duplicate partition key
"(0, 1, 'ac'), " +
"(0, 2, 'ad'), " +
"(0, 3, 'ae'), " +
"(0, 4, 'af'), " +
"(0, 5, 'ag'), " +
"(1, 0, 'ah'), " +
"(1, 0, 'ai'), " + // duplicate partition key
"(1, 1, 'aj'), " +
"(1, 2, 'ak'), " +
"(1, 3, 'al'), " +
"(1, 4, 'am'), " +
"(1, 5, 'an'), " +
"(2, 0, 'ao'), " +
"(2, 0, 'ap'), " + // duplicate partition key
"(2, 1, 'aq'), " +
"(2, 2, 'ar'), " +
"(2, 3, 'as'), " +
"(2, 4, 'at'), " +
"(2, 5, 'au'), " +
"(3, 0, 'av'), " +
"(3, 0, 'aw'), " + // duplicate partition key
"(3, 1, 'ax'), " +
"(3, 2, 'ay'), " +
"(3, 3, 'az'), " +
"(3, 4, 'ba'), " +
"(3, 5, 'bb'), " +
"(4, 0, 'bc'), " +
"(4, 0, 'bd'), " + // duplicate partition key
"(4, 1, 'be'), " +
"(4, 2, 'bf'), " +
"(4, 3, 'bg'), " +
"(4, 4, 'bh'), " +
"(4, 5, 'bi'), " +
"(5, 0, 'bj'), " +
"(5, 0, 'bk'), " + // duplicate partition key
"(5, 1, 'bl'), " +
"(5, 2, 'bm'), " +
"(5, 3, 'bn'), " +
"(5, 4, 'bo'), " +
"(5, 5, 'bp')"
"(1, 0, 'ab'), " + // duplicate partition key
"(2, 2, 'ac'), " +
"(3, 3, 'ad'), " +
"(4, 2, 'bc') "

// additional unmatched partitions to test push down
val finalStr = if (tab == table1) {
insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')"
} else {
insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')"
}

sql(finalStr)
sql(insertStr)
}

Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys =>
Expand All @@ -1973,75 +1766,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {

val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.nonEmpty, "SPJ should not be triggered")

val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD.
partitions.length)

def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt
val gcdOfNumBuckets = gcd(table1buckets1, table2buckets1) *
gcd(table1buckets2, table2buckets2)

scans.foreach { partition =>
assert(partition != gcdOfNumBuckets
&& partition != table1buckets1 && partition != table2buckets1)
}

checkAnswer(df, Seq(
Row(0, 0, "aa", "aa"),
Row(0, 0, "aa", "ab"),
Row(0, 0, "ab", "aa"),
Row(0, 0, "ab", "ab"),
Row(0, 1, "ac", "ac"),
Row(0, 2, "ad", "ad"),
Row(0, 3, "ae", "ae"),
Row(0, 4, "af", "af"),
Row(0, 5, "ag", "ag"),
Row(1, 0, "ah", "ah"),
Row(1, 0, "ah", "ai"),
Row(1, 0, "ai", "ah"),
Row(1, 0, "ai", "ai"),
Row(1, 1, "aj", "aj"),
Row(1, 2, "ak", "ak"),
Row(1, 3, "al", "al"),
Row(1, 4, "am", "am"),
Row(1, 5, "an", "an"),
Row(2, 0, "ao", "ao"),
Row(2, 0, "ao", "ap"),
Row(2, 0, "ap", "ao"),
Row(2, 0, "ap", "ap"),
Row(2, 1, "aq", "aq"),
Row(2, 2, "ar", "ar"),
Row(2, 3, "as", "as"),
Row(2, 4, "at", "at"),
Row(2, 5, "au", "au"),
Row(3, 0, "av", "av"),
Row(3, 0, "av", "aw"),
Row(3, 0, "aw", "av"),
Row(3, 0, "aw", "aw"),
Row(3, 1, "ax", "ax"),
Row(3, 2, "ay", "ay"),
Row(3, 3, "az", "az"),
Row(3, 4, "ba", "ba"),
Row(3, 5, "bb", "bb"),
Row(4, 0, "bc", "bc"),
Row(4, 0, "bc", "bd"),
Row(4, 0, "bd", "bc"),
Row(4, 0, "bd", "bd"),
Row(4, 1, "be", "be"),
Row(4, 2, "bf", "bf"),
Row(4, 3, "bg", "bg"),
Row(4, 4, "bh", "bh"),
Row(4, 5, "bi", "bi"),
Row(5, 0, "bj", "bj"),
Row(5, 0, "bj", "bk"),
Row(5, 0, "bk", "bj"),
Row(5, 0, "bk", "bk"),
Row(5, 1, "bl", "bl"),
Row(5, 2, "bm", "bm"),
Row(5, 3, "bn", "bn"),
Row(5, 4, "bo", "bo"),
Row(5, 5, "bp", "bp")
))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In

if (otherFunc == BucketFunction) {
val gcd = this.gcd(thisNumBuckets, otherNumBuckets)
if (gcd>1 && gcd!=thisNumBuckets) {
if (gcd > 1 && gcd != thisNumBuckets) {
return BucketReducer(gcd)
}
}
Expand Down

0 comments on commit 64598f3

Please sign in to comment.