
Commit d0ba458 (parent: c4d9d9e)

fix UDAF fallback bug when handling DeclarativeAggregator (#1369)

enable UDAF fallback by default
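As the diffs below show, the fallback path evaluates a DeclarativeAggregate's update expressions over a JoinedRow laid out as the aggregation buffer followed by the input columns; input-side BoundReferences therefore have to be shifted past the buffer width, and results of the reused projections have to be copied before being stored. With those fixes in place, spark.auron.udafFallback.enable now defaults to true.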

3 files changed: +37 -26 lines changed

spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
Lines changed: 11 additions & 0 deletions

@@ -100,4 +100,15 @@ class AuronFunctionSuite
       checkAnswer(df, Seq(Row("11", "537061726B2053514C")))
     }
   }
+
+  test("stddev_samp function with UDAF fallback") {
+    withSQLConf("spark.auron.udafFallback.enable" -> "true") {
+      withTable("t1") {
+        sql("create table t1(c1 double) using parquet")
+        sql("insert into t1 values(10.0), (20.0), (30.0), (31.0), (null)")
+        val df = sql("select stddev_samp(c1) from t1")
+        checkAnswer(df, Seq(Row(9.844626283748239)))
+      }
+    }
+  }
 }
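The expected value in the new test can be verified by hand: stddev_samp skips the null row and divides the squared deviations by n - 1. A minimal standalone Scala check (plain Scala, no Spark required):

    // Sample standard deviation of the four non-null values inserted above.
    val xs = Seq(10.0, 20.0, 30.0, 31.0)
    val mean = xs.sum / xs.length                                           // 22.75
    val variance = xs.map(x => math.pow(x - mean, 2)).sum / (xs.length - 1) // 290.75 / 3
    println(math.sqrt(variance))                                            // 9.844626283748239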

spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ public enum AuronConf {
     INPUT_BATCH_STATISTICS_ENABLE("spark.auron.enableInputBatchStatistics", true),

     /// supports UDAF and other aggregate functions not implemented
-    UDAF_FALLBACK_ENABLE("spark.auron.udafFallback.enable", false),
+    UDAF_FALLBACK_ENABLE("spark.auron.udafFallback.enable", true),

     // TypedImperativeAggregate one row mem use size
     SUGGESTED_UDAF_ROW_MEM_USAGE("spark.auron.suggested.udaf.memUsedSize", 64),
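With the default flipped to true, workloads that need the old behavior must opt out explicitly. A hypothetical opt-out, assuming the key is read from the session configuration like other spark.auron.* flags (the app name is illustrative):

    import org.apache.spark.sql.SparkSession

    // Restore the pre-commit behavior by disabling the UDAF fallback.
    val spark = SparkSession.builder()
      .appName("auron-udaf-fallback-demo")
      .config("spark.auron.udafFallback.enable", "false")
      .getOrCreate()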

spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
Lines changed: 25 additions & 25 deletions

@@ -41,12 +41,7 @@ import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.auron.memory.OnHeapSpillManager
 import org.apache.spark.sql.auron.util.Using
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.JoinedRow
-import org.apache.spark.sql.catalyst.expressions.Nondeterministic
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, JoinedRow, Nondeterministic, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
@@ -68,10 +63,6 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
     bytes
   })

-  val inputAttributes: Seq[Attribute] = javaParamsSchema.fields.map { field =>
-    AttributeReference(field.name, field.dataType, field.nullable)()
-  }
-
   private val outputSchema = {
     val schema = StructType(Seq(StructField("", expr.dataType, expr.nullable)))
     ArrowUtils.toArrowSchema(schema)
@@ -93,7 +84,7 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
   override def initialValue: AggregateEvaluator[B, BufferRowsColumn[B]] = {
     val evaluator = expr match {
       case declarative: DeclarativeAggregate =>
-        new DeclarativeEvaluator(declarative, inputAttributes)
+        new DeclarativeEvaluator(declarative)
       case imperative: TypedImperativeAggregate[B] =>
         new TypedImperativeEvaluator(imperative)
     }
@@ -243,7 +234,7 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging {
   }
 }

-class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: Seq[Attribute])
+class DeclarativeEvaluator(val agg: DeclarativeAggregate)
     extends AggregateEvaluator[UnsafeRow, DeclarativeAggRowsColumn] {

   val initializedRow: UnsafeRow = {
@@ -252,8 +243,20 @@ class DeclarativeEvaluator(val agg: DeclarativeAggregate)
   }
   val releasedRow: UnsafeRow = null

-  val updater: UnsafeProjection =
-    UnsafeProjection.create(agg.updateExpressions, agg.aggBufferAttributes ++ inputAttributes)
+  val updater: UnsafeProjection = {
+    val updateExpressions = agg.updateExpressions.map { expr =>
+      expr.transform {
+        case BoundReference(ordinal, dt, nullable) =>
+          BoundReference(ordinal + agg.aggBufferAttributes.length, dt, nullable)
+        case expr => expr
+      }
+    }
+    UnsafeProjection.create(
+      updateExpressions,
+      agg.aggBufferAttributes ++ agg.children.map(c => {
+        AttributeReference("", c.dataType, c.nullable)()
+      }))
+  }

   val merger: UnsafeProjection = UnsafeProjection.create(
     agg.mergeExpressions,
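The ordinal shift above is the core of the fix: the update expressions arrive with input references bound against the input row alone, while updater evaluates them over a JoinedRow of buffer-then-input. A minimal sketch of the same rebinding outside Auron, using only Catalyst classes (the buffer width and values are made up for illustration):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, JoinedRow, UnsafeProjection}
    import org.apache.spark.sql.types.DoubleType

    // An update expression like `sum + input`, with `input` bound at ordinal 0
    // of the input row. The joined row is (buffer ++ input) with buffer width 1,
    // so the input reference must be rebound to ordinal 0 + 1 = 1.
    val bufferWidth = 1
    val rebound = BoundReference(0, DoubleType, nullable = true).transform {
      case BoundReference(ordinal, dt, nullable) =>
        BoundReference(ordinal + bufferWidth, dt, nullable)
    }
    val update = Add(BoundReference(0, DoubleType, nullable = true), rebound) // sum + input
    val updater = UnsafeProjection.create(Seq(update))

    val joined = new JoinedRow(InternalRow(1.5), InternalRow(2.5)) // buffer = [1.5], input = [2.5]
    assert(updater(joined).getDouble(0) == 4.0) // without the shift this would add the buffer to itself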
@@ -322,7 +325,7 @@ case class DeclarativeAggRowsColumn(

   override def resize(len: Int): Unit = {
     rows.appendAll((rows.length until len).map(_ => {
-      val newRow = evaluator.initializedRow.copy()
+      val newRow = evaluator.initializedRow
       rowsMemUsed += newRow.getSizeInBytes
       newRow
     }))
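Dropping the .copy() in resize is safe only in combination with the change below: updateRow and mergeRow no longer mutate a stored row in place but replace the slot with a fresh copy of the projection output, so many slots can temporarily share the single immutable initializedRow.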
@@ -335,28 +338,25 @@ case class DeclarativeAggRowsColumn(

   override def updateRow(i: Int, inputRow: InternalRow): Unit = {
     if (i == rows.length) {
-      val newRow = evaluator.updater(evaluator.joiner(evaluator.initializedRow.copy(), inputRow))
-      rowsMemUsed += newRow.getSizeInBytes
-      rows.append(newRow)
+      rows.append(evaluator.updater(evaluator.joiner(evaluator.initializedRow, inputRow)).copy())
     } else {
       rowsMemUsed -= rows(i).getSizeInBytes
-      rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow))
-      rowsMemUsed += rows(i).getSizeInBytes
+      rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow)).copy()
     }
+    rowsMemUsed += rows(i).getSizeInBytes
   }

   override def mergeRow(i: Int, mergeRows: BufferRowsColumn[UnsafeRow], mergeIdx: Int): Unit = {
     mergeRows match {
       case mergeRows: DeclarativeAggRowsColumn =>
+        val mergeRow = mergeRows.rows(mergeIdx)
         if (i == rows.length) {
-          val newRow = mergeRows.rows(mergeIdx)
-          rowsMemUsed += newRow.getSizeInBytes
-          rows.append(newRow)
+          rows.append(mergeRow)
         } else {
           rowsMemUsed -= rows(i).getSizeInBytes
-          rows(i) = evaluator.merger(evaluator.joiner(rows(i), mergeRows.rows(mergeIdx)))
-          rowsMemUsed += rows(i).getSizeInBytes
+          rows(i) = evaluator.merger(evaluator.joiner(rows(i), mergeRow)).copy()
         }
+        rowsMemUsed += rows(i).getSizeInBytes
         mergeRows.rows(mergeIdx) = evaluator.releasedRow
     }
   }
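The added .copy() calls address a classic Catalyst pitfall: an UnsafeProjection reuses one backing UnsafeRow across invocations, so a result that is stored must be copied first or it will alias the next call's output. A standalone sketch of the failure mode (values are illustrative):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
    import org.apache.spark.sql.types.IntegerType

    val proj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, nullable = false)))

    val kept = proj(InternalRow(1)).copy() // owns its bytes
    val aliased = proj(InternalRow(1))     // points into the projection's shared buffer
    proj(InternalRow(2))                   // overwrites that shared buffer

    println(kept.getInt(0))    // 1
    println(aliased.getInt(0)) // 2 -- silently changed by the later call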
