Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Chong Gao committed Oct 24, 2024
1 parent 136e69a commit efb351a
Showing 1 changed file with 13 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ package org.apache.spark.sql.rapids.aggregate
import scala.collection.immutable.Seq

import ai.rapids.cudf
import ai.rapids.cudf.{ColumnVector, DType, GroupByAggregation, ReductionAggregation}
import ai.rapids.cudf.{DType, GroupByAggregation, ReductionAggregation}
import com.nvidia.spark.rapids._
//import com.nvidia.spark.rapids.Arm.withResourceIfAllowed
//import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression
//import com.nvidia.spark.rapids.jni.HLL
import com.nvidia.spark.rapids.Arm.withResourceIfAllowed
import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression
import com.nvidia.spark.rapids.jni.HLL
import com.nvidia.spark.rapids.shims.ShimExpression

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
Expand Down Expand Up @@ -57,7 +57,7 @@ case class CudfMergeHLLPP(override val dataType: DataType,
* Perform the final evaluation step to compute approximate count distinct from sketches.
* Input is long columns, first construct struct of long then feed to cuDF
*/
case class GpuHyperLogLogPlusPlusEvaluation(children: Seq[Expression],
case class GpuHyperLogLogPlusPlusEvaluation(childExpr: Expression,
numRegistersPerSketch: Int)
extends GpuExpression with ShimExpression {
override def dataType: DataType = LongType
Expand All @@ -66,19 +66,14 @@ case class GpuHyperLogLogPlusPlusEvaluation(children: Seq[Expression],

override def nullable: Boolean = false

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
override def children: scala.Seq[Expression] = Seq(childExpr)

// TODO
// withResourceIfAllowed(childExpr.columnarEval(batch)) { sketches =>
// val distinctValues = HLL.estimateDistinctValueFromSketches(
// sketches.getBase, numRegistersPerSketch)
// GpuColumnVector.from(distinctValues, LongType)
// }

val numRows = batch.numRows()
val zeros = Array.fill(numRows)(0L)
val longCv = ColumnVector.fromLongs(zeros: _*)
GpuColumnVector.from(longCv, LongType)
override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
withResourceIfAllowed(childExpr.columnarEval(batch)) { sketches =>
val distinctValues = HLL.estimateDistinctValueFromSketches(
sketches.getBase, numRegistersPerSketch)
GpuColumnVector.from(distinctValues, LongType)
}
}
}

Expand Down Expand Up @@ -160,7 +155,7 @@ case class GpuHyperLogLogPlusPlus(childExpr: Expression, relativeSD: Double)
}

override lazy val evaluateExpression: Expression =
GpuHyperLogLogPlusPlusEvaluation(aggBufferAttributes, numRegistersPerSketch)
GpuHyperLogLogPlusPlusEvaluation(genStruct.head, numRegistersPerSketch)

override def dataType: DataType = LongType

Expand Down

0 comments on commit efb351a

Please sign in to comment.