Skip to content

Commit

Permalink
remove dep for Utils.bytesToString
Browse files Browse the repository at this point in the history
  • Loading branch information
JeynmannZ committed Dec 20, 2023
1 parent bb4eeb5 commit 295bc2c
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque}

import org.openucx.jucx.ucp.{UcpContext, UcpMemMapParams, UcpMemory}
import org.openucx.jucx.ucs.UcsConstants
import org.apache.spark.shuffle.utils.UcxLogging
import org.apache.spark.shuffle.utils.{UcxUtils, UcxLogging}
import org.apache.spark.shuffle.ucx.MemoryBlock
import org.apache.spark.util.Utils

class UcxBounceBufferMemoryBlock(private[ucx] val memory: UcpMemory, private[ucx] val refCount: AtomicInteger,
private[ucx] val memPool: MemoryPool,
Expand Down Expand Up @@ -85,7 +84,7 @@ case class MemoryPool(ucxContext: UcpContext, memoryType: Int)

private[memory] def preallocate(numBuffers: Int): Unit = {
logTrace(s"PreAllocating $numBuffers of size $length, " +
s"totalSize: ${Utils.bytesToString(length * numBuffers)}.")
s"totalSize: ${UcxUtils.bytesToString(length * numBuffers)}.")
val memory = ucxContext.memoryMap(
new UcpMemMapParams().allocate().setMemoryType(memType).setLength(length * numBuffers))
val refCount = new AtomicInteger(numBuffers)
Expand All @@ -106,7 +105,7 @@ case class MemoryPool(ucxContext: UcpContext, memoryType: Int)
numBuffers += 1
})
logInfo(s"Closing $numBuffers buffers of size $length." +
s"Number of allocations: ${numAllocs.get()}. Total size: ${Utils.bytesToString(length * numBuffers)}")
s"Number of allocations: ${numAllocs.get()}. Total size: ${UcxUtils.bytesToString(length * numBuffers)}")
stack.clear()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,13 @@
*/
package org.apache.spark.shuffle.ucx.utils

import java.io.{IOException, EOFException, ObjectInputStream, ObjectOutputStream}
import java.io.{EOFException, ObjectInputStream, ObjectOutputStream}
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.charset.StandardCharsets
import scala.util.control.{ControlThrowable, NonFatal}

import org.apache.spark.shuffle.utils.UcxLogging

object UcxUtils extends UcxLogging {
def tryOrIOException[T](block: => T): T = {
try {
block
} catch {
case e: IOException =>
logError("Exception encountered", e)
throw e
case NonFatal(e) =>
logError("Exception encountered", e)
throw new IOException(e)
}
}
}
import org.apache.spark.shuffle.utils.{UcxUtils, UcxLogging}

/**
* A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make
Expand Down
60 changes: 60 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/utils/UcxUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.utils

import java.io.IOException
import java.math.{MathContext, RoundingMode}
import java.util.Locale
import scala.util.control.NonFatal

object UcxUtils extends UcxLogging {
def tryOrIOException[T](block: => T): T = {
try {
block
} catch {
case e: IOException =>
logError("Exception encountered", e)
throw e
case NonFatal(e) =>
logError("Exception encountered", e)
throw new IOException(e)
}
}

def bytesToString(size: Long): String = bytesToString(BigInt(size))

def bytesToString(size: BigInt): String = {
val EB = 1L << 60
val PB = 1L << 50
val TB = 1L << 40
val GB = 1L << 30
val MB = 1L << 20
val KB = 1L << 10

if (size >= BigInt(1L << 11) * EB) {
// The number is too large, show it in scientific notation.
BigDecimal(size, new MathContext(3, RoundingMode.HALF_UP)).toString() + " B"
} else {
val (value, unit) = {
if (size >= 2 * EB) {
(BigDecimal(size) / EB, "EB")
} else if (size >= 2 * PB) {
(BigDecimal(size) / PB, "PB")
} else if (size >= 2 * TB) {
(BigDecimal(size) / TB, "TB")
} else if (size >= 2 * GB) {
(BigDecimal(size) / GB, "GB")
} else if (size >= 2 * MB) {
(BigDecimal(size) / MB, "MB")
} else if (size >= 2 * KB) {
(BigDecimal(size) / KB, "KB")
} else {
(BigDecimal(size), "B")
}
}
"%.1f %s".formatLocal(Locale.US, value, unit)
}
}
}

0 comments on commit 295bc2c

Please sign in to comment.