From 8efbc241c0d1636079e4dddbe1652adc5df78659 Mon Sep 17 00:00:00 2001 From: zizhao Date: Thu, 4 Jan 2024 15:34:30 +0200 Subject: [PATCH] fix invalid size for getByteBufferView of 2G block --- .../apache/spark/shuffle/ucx/UcxWorkerWrapper.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 66cefd76..f117003e 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -123,8 +123,8 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i override def onSuccess(r: UcpRequest): Unit = { request.completed = true stats.endTime = System.nanoTime() - logDebug(s"Received rndv data of size: ${mem.size} for tag $i in " + - s"${stats.getElapsedTimeNs} ns " + + logDebug(s"Received rndv data of size: ${ucpAmData.getLength}" + + s" for tag $i in ${stats.getElapsedTimeNs} ns " + s"time from amHandle: ${System.nanoTime() - stats.amHandleTime} ns") for (b <- 0 until numBlocks) { val blockSize = headerBuffer.getInt @@ -280,10 +280,11 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): Unit = try { val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blocks.length - val resultMemory = transport.hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum) + val msgSize = tagAndSizes + blocks.map(_.getSize).sum + val resultMemory = transport.hostBounceBufferMemoryPool.get(msgSize) .asInstanceOf[UcxBounceBufferMemoryBlock] val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, - resultMemory.size) + msgSize) resultBuffer.putInt(replyTag) var offset = 0 @@ -311,9 +312,9 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i val startTime = System.nanoTime() val req = connections(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes, - resultMemory.address + tagAndSizes, resultMemory.size - tagAndSizes, 0, new UcxCallback { + resultMemory.address + tagAndSizes, msgSize - tagAndSizes, 0, new UcxCallback { override def onSuccess(request: UcpRequest): Unit = { - logTrace(s"Sent ${blocks.length} blocks of size: ${resultMemory.size} " + + logTrace(s"Sent ${blocks.length} blocks of size: ${msgSize} " + s"to tag $replyTag in ${System.nanoTime() - startTime} ns.") transport.hostBounceBufferMemoryPool.put(resultMemory) }