diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala
index d1d6ade1806..7aa6dfecb5e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala
@@ -291,10 +291,22 @@ private class GpuColumnarBatchSerializerInstance(metrics: Map[String, GpuMetric]
  */
 class SerializedTableColumn(
     val header: SerializedTableHeader,
-    val hostBuffer: HostMemoryBuffer) extends GpuColumnVectorBase(NullType) with SizeProvider {
+    val shb: SpillableHostBuffer) extends GpuColumnVectorBase(NullType) with SizeProvider {
   override def close(): Unit = {
-    if (hostBuffer != null) {
-      hostBuffer.close()
+    if (shb != null) {
+      shb.close()
+    }
+  }
+
+  /**
+   * Get the host buffer containing the serialized table data.
+   * @return a host buffer needs caller to close
+   */
+  def getHostBuffer: HostMemoryBuffer = {
+    if (shb == null) {
+      null
+    } else {
+      shb.getHostBuffer()
     }
   }
 
@@ -302,10 +314,10 @@ class SerializedTableColumn(
 
   override def numNulls(): Int = throw new IllegalStateException("should not be called")
 
-  override def sizeInBytes: Long = if (hostBuffer == null) {
+  override def sizeInBytes: Long = if (shb == null) {
     0L
   } else {
-    hostBuffer.getLength
+    shb.length
   }
 }
 
@@ -321,7 +333,14 @@ object SerializedTableColumn {
   def from(
       header: SerializedTableHeader,
       hostBuffer: HostMemoryBuffer = null): ColumnarBatch = {
-    val column = new SerializedTableColumn(header, hostBuffer)
+    val column =
+      if (hostBuffer == null) {
+        new SerializedTableColumn(header, null)
+      } else {
+        new SerializedTableColumn(header,
+          SpillableHostBuffer.apply(hostBuffer, hostBuffer.getLength,
+            SpillPriorities.ACTIVE_BATCHING_PRIORITY))
+      }
     new ColumnarBatch(Array(column), header.getNumRows)
   }
 
@@ -331,7 +350,7 @@ object SerializedTableColumn {
       val cv = batch.column(0)
       cv match {
         case serializedTableColumn: SerializedTableColumn =>
-          sum += Option(serializedTableColumn.hostBuffer).map(_.getLength).getOrElse(0L)
+          sum += Option(serializedTableColumn.shb).map(_.length).getOrElse(0L)
         case kudo: KudoSerializedTableColumn =>
           sum += Option(kudo.spillableKudoTable).map(_.length).getOrElse(0L)
         case _ =>
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala
index 5169bc8c08f..98a166793a9 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala
@@ -18,10 +18,11 @@ package com.nvidia.spark.rapids
 
 import java.util
 
+import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
-import ai.rapids.cudf.{JCudfSerialization, NvtxColor, NvtxRange}
-import ai.rapids.cudf.JCudfSerialization.HostConcatResult
+import ai.rapids.cudf.{HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange}
+import ai.rapids.cudf.JCudfSerialization.{HostConcatResult, SerializedTableHeader}
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.FileUtils.createTempFile
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
@@ -156,8 +157,8 @@ object GpuShuffleCoalesceUtils {
     cb.column(0) match {
       case col: KudoSerializedTableColumn => col.spillableKudoTable.length
       case serCol: SerializedTableColumn => {
-        val hmb = serCol.hostBuffer
-        if (hmb != null) hmb.getLength else 0L
+        val hmb = serCol.shb
+        if (hmb != null) hmb.length else 0L
       }
       case o => throw new IllegalStateException(s"unsupported type: ${o.getClass}")
     }
@@ -211,8 +212,14 @@ class JCudfTableOperator extends SerializedTableOperator[SerializedTableColumn]
       val totalRowsNum = tables.map(getNumRows).sum
       cudf_utils.HostConcatResultUtil.rowsOnlyHostConcatResult(totalRowsNum)
     } else {
-      val (headers, buffers) = tables.map(t => (t.header, t.hostBuffer)).unzip
-      JCudfSerialization.concatToHostBuffer(headers, buffers)
+      withResource(new ArrayBuffer[HostMemoryBuffer]) { buffers =>
+        val headers = new ArrayBuffer[SerializedTableHeader]
+        tables.foreach(t => {
+          headers.append(t.header)
+          buffers.append(t.getHostBuffer)
+        })
+        JCudfSerialization.concatToHostBuffer(headers.toArray, buffers.toArray)
+      }
     }
     new JCudfCoalescedHostResult(ret)
   }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala
index 9041c480c66..00196cc951a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala
@@ -1168,8 +1168,7 @@ object SpillableHostConcatResult {
           new KudoSpillableHostConcatResult(oldKudoTable.header, buffer)
         }
       case col: SerializedTableColumn =>
-        val buffer = col.hostBuffer
-        buffer.incRefCount()
+        val buffer = col.getHostBuffer
         new CudfSpillableHostConcatResult(col.header, buffer)
       case c =>
        throw new IllegalStateException(s"Expected SerializedTableColumn, got ${c.getClass}")
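
Note (not part of the patch): below is a minimal usage sketch of the ownership contract this change introduces. SerializedTableColumn now holds its serialized data in a SpillableHostBuffer; code that needs the raw HostMemoryBuffer takes a separate reference via getHostBuffer and closes it itself, while the column's own close() releases the spillable buffer. The names `header` and `buffer` are assumed to come from an earlier JCudfSerialization step, and it is assumed that SpillableHostBuffer.apply takes over the reference passed to SerializedTableColumn.from; neither is shown in the diff above.

import ai.rapids.cudf.HostMemoryBuffer
import ai.rapids.cudf.JCudfSerialization.SerializedTableHeader
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.SerializedTableColumn

// Sketch only: `header` and `buffer` are assumed to be produced by a prior
// serialization step and are owned by the batch once passed to from().
def serializedSizeOf(header: SerializedTableHeader, buffer: HostMemoryBuffer): Long = {
  // from() wraps the incoming buffer in a SpillableHostBuffer owned by the new column.
  withResource(SerializedTableColumn.from(header, buffer)) { batch =>
    val col = batch.column(0).asInstanceOf[SerializedTableColumn]
    // getHostBuffer hands back a reference the caller must close; the column's
    // SpillableHostBuffer is released separately when the batch is closed.
    withResource(col.getHostBuffer) { hmb =>
      hmb.getLength
    }
  }
}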