@@ -41,12 +41,7 @@ import org.apache.spark.memory.MemoryMode
4141import org .apache .spark .sql .auron .memory .OnHeapSpillManager
4242import org .apache .spark .sql .auron .util .Using
4343import org .apache .spark .sql .catalyst .InternalRow
44- import org .apache .spark .sql .catalyst .expressions .Attribute
45- import org .apache .spark .sql .catalyst .expressions .AttributeReference
46- import org .apache .spark .sql .catalyst .expressions .JoinedRow
47- import org .apache .spark .sql .catalyst .expressions .Nondeterministic
48- import org .apache .spark .sql .catalyst .expressions .UnsafeProjection
49- import org .apache .spark .sql .catalyst .expressions .UnsafeRow
44+ import org .apache .spark .sql .catalyst .expressions .{AttributeReference , BoundReference , JoinedRow , Nondeterministic , UnsafeProjection , UnsafeRow }
5045import org .apache .spark .sql .catalyst .expressions .aggregate .AggregateFunction
5146import org .apache .spark .sql .catalyst .expressions .aggregate .DeclarativeAggregate
5247import org .apache .spark .sql .catalyst .expressions .aggregate .TypedImperativeAggregate
@@ -68,10 +63,6 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
6863 bytes
6964 })
7065
71- val inputAttributes : Seq [Attribute ] = javaParamsSchema.fields.map { field =>
72- AttributeReference (field.name, field.dataType, field.nullable)()
73- }
74-
7566 private val outputSchema = {
7667 val schema = StructType (Seq (StructField (" " , expr.dataType, expr.nullable)))
7768 ArrowUtils .toArrowSchema(schema)
@@ -93,7 +84,7 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
9384 override def initialValue : AggregateEvaluator [B , BufferRowsColumn [B ]] = {
9485 val evaluator = expr match {
9586 case declarative : DeclarativeAggregate =>
96- new DeclarativeEvaluator (declarative, inputAttributes )
87+ new DeclarativeEvaluator (declarative)
9788 case imperative : TypedImperativeAggregate [B ] =>
9889 new TypedImperativeEvaluator (imperative)
9990 }
@@ -243,7 +234,7 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging {
243234 }
244235}
245236
246- class DeclarativeEvaluator (val agg : DeclarativeAggregate , inputAttributes : Seq [ Attribute ] )
237+ class DeclarativeEvaluator (val agg : DeclarativeAggregate )
247238 extends AggregateEvaluator [UnsafeRow , DeclarativeAggRowsColumn ] {
248239
249240 val initializedRow : UnsafeRow = {
@@ -252,8 +243,20 @@ class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: Seq[A
252243 }
253244 val releasedRow : UnsafeRow = null
254245
255- val updater : UnsafeProjection =
256- UnsafeProjection .create(agg.updateExpressions, agg.aggBufferAttributes ++ inputAttributes)
246+ val updater : UnsafeProjection = {
247+ val updateExpressions = agg.updateExpressions.map { expr =>
248+ expr.transform {
249+ case BoundReference (odin, dt, nullable) =>
250+ BoundReference (odin + agg.aggBufferAttributes.length, dt, nullable)
251+ case expr => expr
252+ }
253+ }
254+ UnsafeProjection .create(
255+ updateExpressions,
256+ agg.aggBufferAttributes ++ agg.children.map(c => {
257+ AttributeReference (" " , c.dataType, c.nullable)()
258+ }))
259+ }
257260
258261 val merger : UnsafeProjection = UnsafeProjection .create(
259262 agg.mergeExpressions,
@@ -322,7 +325,7 @@ case class DeclarativeAggRowsColumn(
322325
323326 override def resize (len : Int ): Unit = {
324327 rows.appendAll((rows.length until len).map(_ => {
325- val newRow = evaluator.initializedRow.copy()
328+ val newRow = evaluator.initializedRow
326329 rowsMemUsed += newRow.getSizeInBytes
327330 newRow
328331 }))
@@ -335,28 +338,25 @@ case class DeclarativeAggRowsColumn(
335338
336339 override def updateRow (i : Int , inputRow : InternalRow ): Unit = {
337340 if (i == rows.length) {
338- val newRow = evaluator.updater(evaluator.joiner(evaluator.initializedRow.copy(), inputRow))
339- rowsMemUsed += newRow.getSizeInBytes
340- rows.append(newRow)
341+ rows.append(evaluator.updater(evaluator.joiner(evaluator.initializedRow, inputRow)).copy())
341342 } else {
342343 rowsMemUsed -= rows(i).getSizeInBytes
343- rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow))
344- rowsMemUsed += rows(i).getSizeInBytes
344+ rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow)).copy()
345345 }
346+ rowsMemUsed += rows(i).getSizeInBytes
346347 }
347348
348349 override def mergeRow (i : Int , mergeRows : BufferRowsColumn [UnsafeRow ], mergeIdx : Int ): Unit = {
349350 mergeRows match {
350351 case mergeRows : DeclarativeAggRowsColumn =>
352+ val mergeRow = mergeRows.rows(mergeIdx)
351353 if (i == rows.length) {
352- val newRow = mergeRows.rows(mergeIdx)
353- rowsMemUsed += newRow.getSizeInBytes
354- rows.append(newRow)
354+ rows.append(mergeRow)
355355 } else {
356356 rowsMemUsed -= rows(i).getSizeInBytes
357- rows(i) = evaluator.merger(evaluator.joiner(rows(i), mergeRows.rows(mergeIdx)))
358- rowsMemUsed += rows(i).getSizeInBytes
357+ rows(i) = evaluator.merger(evaluator.joiner(rows(i), mergeRow)).copy()
359358 }
359+ rowsMemUsed += rows(i).getSizeInBytes
360360 mergeRows.rows(mergeIdx) = evaluator.releasedRow
361361 }
362362 }
0 commit comments