
Commit 91f2a25

Matt Hawes authored and maropu committed
[SPARK-28818][SQL][2.4] Respect source column nullability in the arrays created by freqItems()
### What changes were proposed in this pull request?

This PR replaces the hard-coded non-nullability of the array elements returned by `freqItems()` with a nullability that reflects the original schema. Essentially, [the functional change](https://github.com/apache/spark/pull/25575/files#diff-bf59bb9f3dc351f5bf6624e5edd2dcf4R122) to the schema generation is:

```
StructField(name + "_freqItems", ArrayType(dataType, false))
```

Becomes:

```
StructField(name + "_freqItems", ArrayType(dataType, originalField.nullable))
```

Respecting the original nullability prevents issues when Spark depends on `ArrayType`'s `containsNull` being accurate. The example that uncovered this is calling `collect()` on the dataframe (see the [ticket](https://issues.apache.org/jira/browse/SPARK-28818) for a full repro), though it is likely that there are several places where this could cause a problem. I've also refactored a small amount of the surrounding code to remove some unnecessary steps and group related operations together.

Note: This is the backport PR of #25575 and the credit should go to MGHawes.

### Why are the changes needed?

I think it's pretty clear why this change is needed. It fixes a bug that currently prevents users from calling `df.freqItems.collect()`, and that potentially causes other, as yet unknown, issues.

### Does this PR introduce any user-facing change?

Yes. The nullability of columns is now respected when calling `freqItems` on them.

### How was this patch tested?

I added a test that specifically checks that the nullability is carried through, and that explicitly calls `collect()` to catch the exact regression that was observed. I also ran the test against the old version of the code and it fails as expected.

Closes #29327 from maropu/SPARK-28818-2.4.

Lead-authored-by: Matt Hawes <[email protected]>
Co-authored-by: Takeshi Yamamuro <[email protected]>
Signed-off-by: Takeshi Yamamuro <[email protected]>
1 parent 4a8f692 · commit 91f2a25
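For context, here is a minimal sketch of the kind of repro described in the ticket. It is illustrative only: it assumes a `spark` session with `spark.implicits._` in scope, and the column name and data are made up rather than taken from the JIRA.

```scala
import spark.implicits._

// A nullable string column in which null itself is a frequent item.
val df = Seq(Some("a"), None, Some("b")).toDF("col")

// Before this patch, the result schema declared the arrays as containsNull = false,
// so collecting an array that actually holds a null value failed with a NullPointerException.
df.stat.freqItems(Seq("col")).collect()
```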

2 files changed (+35 −10 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala

+10 −9

```diff
@@ -89,11 +89,6 @@ object FrequentItems extends Logging {
     // number of max items to keep counts for
     val sizeOfMap = (1 / support).toInt
     val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
-    val originalSchema = df.schema
-    val colInfo: Array[(String, DataType)] = cols.map { name =>
-      val index = originalSchema.fieldIndex(name)
-      (name, originalSchema.fields(index).dataType)
-    }.toArray
 
     val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)(
       seqOp = (counts, row) => {
@@ -117,10 +112,16 @@
     )
     val justItems = freqItems.map(m => m.baseMap.keys.toArray)
     val resultRow = Row(justItems : _*)
-    // append frequent Items to the column name for easy debugging
-    val outputCols = colInfo.map { v =>
-      StructField(v._1 + "_freqItems", ArrayType(v._2, false))
-    }
+
+    val originalSchema = df.schema
+    val outputCols = cols.map { name =>
+      val index = originalSchema.fieldIndex(name)
+      val originalField = originalSchema.fields(index)
+
+      // append frequent Items to the column name for easy debugging
+      StructField(name + "_freqItems", ArrayType(originalField.dataType, originalField.nullable))
+    }.toArray
+
     val schema = StructType(outputCols).toAttributes
     Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
   }
```
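As a side note, the effect of the change on the generated output field is simply that the element nullability now mirrors the source column. A small illustrative sketch (not part of the patch; the column name and types below are assumptions for the example):

```scala
import org.apache.spark.sql.types._

// A nullable source column, as it would appear in the original schema.
val originalField = StructField("col", StringType, nullable = true)

// Before: array elements were always declared non-nullable.
val before = StructField(originalField.name + "_freqItems",
  ArrayType(originalField.dataType, containsNull = false))

// After: element nullability mirrors the source column.
val after = StructField(originalField.name + "_freqItems",
  ArrayType(originalField.dataType, originalField.nullable))
```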

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

+25 −1

```diff
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}
 
 class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -366,6 +366,30 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-28818: Respect original column nullability in `freqItems`") {
+    val rows = spark.sparkContext.parallelize(
+      Seq(Row("1", "a"), Row("2", null), Row("3", "b"))
+    )
+    val schema = StructType(Seq(
+      StructField("non_null", StringType, false),
+      StructField("nullable", StringType, true)
+    ))
+    val df = spark.createDataFrame(rows, schema)
+
+    val result = df.stat.freqItems(df.columns)
+
+    val nonNullableDataType = result.schema("non_null_freqItems").dataType.asInstanceOf[ArrayType]
+    val nullableDataType = result.schema("nullable_freqItems").dataType.asInstanceOf[ArrayType]
+
+    assert(nonNullableDataType.containsNull == false)
+    assert(nullableDataType.containsNull == true)
+    // Original bug was a NullPointerException exception caused by calling collect(), test for this
+    val resultRow = result.collect()(0)
+
+    assert(resultRow.get(0).asInstanceOf[Seq[String]].toSet == Set("1", "2", "3"))
+    assert(resultRow.get(1).asInstanceOf[Seq[String]].toSet == Set("a", "b", null))
+  }
+
   test("sampleBy") {
     val df = spark.range(0, 100).select((col("id") % 3).as("key"))
     val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
```
