Client/Server use wakeup feature and their own thread #27

Open
Wants to merge 36 commits into base: master

Commits:
084fd6d
bind worker to thread
JeynmannZ Apr 25, 2023
761792d
use wakeup/yield in server/client worker
JeynmannZ Apr 25, 2023
87b19e0
use asyn
JeynmannZ Apr 25, 2023
c075505
use loop
JeynmannZ Apr 25, 2023
760a276
Merge branch 'NVIDIA:master' into bind_worker_wake_asyn
jeynmann Apr 26, 2023
fb0bb34
connect executors in setup thread
JeynmannZ May 4, 2023
2f44bc0
Merge branch 'con_in_setup' into bind_worker_wake_asyn_conn
JeynmannZ May 8, 2023
6b4d43e
yield in loop
JeynmannZ May 10, 2023
d6c1978
reverse client bind
JeynmannZ May 23, 2023
731d5d8
replace while with if in client progress
JeynmannZ May 31, 2023
8e49765
Merge branch 'bind_worker_wake_asyn_conl' into bind_worker_wake_asyn
JeynmannZ May 31, 2023
1279bc8
remove useless modifications
JeynmannZ May 31, 2023
a1e09ec
use sync send
JeynmannZ Jun 6, 2023
9267414
bind worker to thread
JeynmannZ Jun 12, 2023
f94359f
bind client worker
JeynmannZ Jun 13, 2023
4c3ba6e
use rr instead of blocked queue to select client
JeynmannZ Jun 13, 2023
d8c12f1
add progressBlocked for flush
JeynmannZ Jun 13, 2023
893ee5c
add asyn
JeynmannZ Jun 14, 2023
1546d73
add pre connection
JeynmannZ Jun 14, 2023
3f14725
submit req after fetch
JeynmannZ Jun 13, 2023
931fba1
bind client worker to fixed thread
JeynmannZ Jun 20, 2023
6a49fd6
temp
JeynmannZ Jun 25, 2023
e393aa3
split context
JeynmannZ Jun 26, 2023
7c47170
client use monitor
JeynmannZ Jun 26, 2023
cdc9992
rm monitor
JeynmannZ Jun 27, 2023
b04c225
add sync in worker thread
JeynmannZ Jun 30, 2023
8267cf7
connection monitor
JeynmannZ Jul 3, 2023
1b5ffd5
reduce synchronized region
JeynmannZ Jul 5, 2023
4b95547
rm use code
JeynmannZ Jul 17, 2023
a608bb1
reset RR code
JeynmannZ Jul 17, 2023
c55a642
add latch for perf bench
JeynmannZ Jul 17, 2023
bb17180
add thread pool + progress thread
JeynmannZ May 29, 2024
8418905
shuffle: simplify
JeynmannZ Jun 3, 2024
d9d315e
Merge branch 'master' into client_server_use_own_thread
JeynmannZ Jun 3, 2024
9a931e9
use awaitUcxTransport instead of sleep in UcxLocalDiskShuffleExecutor…
JeynmannZ Jun 3, 2024
2d1c464
simplify code
JeynmannZ Jun 3, 2024
Changes from all commits:
@@ -50,36 +50,8 @@ class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C],
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

// Ucx shuffle logic
// Java reflection to get access to private results queue
val queueField = wrappedStreams.getClass.getDeclaredField(
"org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
queueField.setAccessible(true)
val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]]

// Do progress if queue is empty before calling next on ShuffleIterator
val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
override def next(): (BlockId, InputStream) = {
val startTime = System.currentTimeMillis()
while (resultQueue.isEmpty) {
transport.progress()
}
shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
wrappedStreams.next()
}

override def hasNext: Boolean = {
val result = wrappedStreams.hasNext
if (!result) {
shuffleClient.close()
}
result
}
}
// End of ucx shuffle logic

val serializerInstance = dep.serializer.newInstance()
val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
@@ -23,9 +23,7 @@ class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf)

override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = {
val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
while (ucxShuffleManager.ucxTransport == null) {
Thread.sleep(5)
}
ucxShuffleManager.awaitUcxTransport()
blockResolver = ucxShuffleManager.shuffleBlockResolver
}

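awaitUcxTransport() is introduced elsewhere in this PR; only the call site appears in this hunk. One plausible shape for it is sketched below, under the assumption that the transport setup thread releases a latch once UcxShuffleTransport exists; the class, field and helper names here are illustrative, not the PR's actual code.

import java.util.concurrent.{CountDownLatch, TimeUnit}

// Illustrative replacement for the removed Thread.sleep(5) polling loop:
// executor initialization blocks on a latch instead of busy-waiting.
class UcxTransportReadiness {
  private val transportReady = new CountDownLatch(1)

  // Called from initializeExecutor(); returns once the transport is available.
  def awaitUcxTransport(): Unit = {
    if (!transportReady.await(30, TimeUnit.SECONDS)) {
      throw new IllegalStateException("UCX transport was not initialized in time")
    }
  }

  // Called by the setup thread right after the transport has been created.
  def markTransportReady(): Unit = transportReady.countDown()
}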
@@ -41,7 +41,6 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma
}
val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size)
transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks)
transport.progress()
}

override def close(): Unit = {
@@ -104,40 +104,10 @@ private[spark] class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C],

val wrappedStreams = shuffleIterator.toCompletionIterator


// Ucx shuffle logic
// Java reflection to get access to private results queue
val queueField = shuffleIterator.getClass.getDeclaredField(
"org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
queueField.setAccessible(true)
val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]]

// Do progress if queue is empty before calling next on ShuffleIterator
val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
override def next(): (BlockId, InputStream) = {
val startTime = System.nanoTime()
while (resultQueue.isEmpty) {
transport.progress()
}
val fetchWaitTime = System.nanoTime() - startTime
readMetrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(fetchWaitTime))
wrappedStreams.next()
}

override def hasNext: Boolean = {
val result = wrappedStreams.hasNext
if (!result) {
shuffleClient.close()
}
result
}
}
// End of ucx shuffle logic

val serializerInstance = dep.serializer.newInstance()

// Create a key/value iterator for each stream
val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
@@ -150,13 +150,6 @@ trait ShuffleTransport {
*/
def unregister(blockId: BlockId): Unit

/**
* Batch version of [[ fetchBlocksByBlockIds ]].
*/
def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId],
resultBufferAllocator: BufferAllocator,
callbacks: Seq[OperationCallback]): Seq[Request]

/**
* Progress outstanding operations. This routine is blocking (though may poll for event).
* It's required to call this routine within same thread that submitted [[ fetchBlocksByBlockIds ]].
@@ -167,3 +160,7 @@ trait ShuffleTransport {
def progress(): Unit

}

class UcxFetchState(val callbacks: Seq[OperationCallback],
val request: UcxRequest,
val timestamp: Long) {}
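
UcxFetchState is new in this PR; the map that stores it lives in the worker wrapper, which is not part of this diff. The sketch below is only a hedged illustration of how such state could be tracked per reply tag: FetchStateTracker, nextTag and inFlight are assumptions, and only UcxFetchState, OperationCallback and UcxRequest come from the code above (the sketch is assumed to sit in the same package as those types).

import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap

// Illustrative bookkeeping: one UcxFetchState per outstanding fetch,
// keyed by the reply tag that travels with the fetch request.
class FetchStateTracker {
  private val nextTag = new AtomicInteger(0)
  private val inFlight = new TrieMap[Int, UcxFetchState]()

  // Register a fetch before sending it and get the tag to embed in the request.
  def register(callbacks: Seq[OperationCallback], request: UcxRequest): Int = {
    val tag = nextTag.incrementAndGet()
    inFlight.put(tag, new UcxFetchState(callbacks, request, System.nanoTime()))
    tag
  }

  // When the tagged reply arrives, return the saved state so the caller can
  // run the callbacks and compute the fetch latency from the stored timestamp.
  def complete(tag: Int): Option[UcxFetchState] = inFlight.remove(tag)
}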
@@ -9,13 +9,15 @@ import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.ucx.memory.UcxHostBounceBuffersPool
import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread
import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.shuffle.utils.UnsafeUtils
import org.openucx.jucx.UcxException
import org.openucx.jucx.ucp._
import org.openucx.jucx.ucs.UcsConstants

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap
import scala.collection.mutable

@@ -53,7 +55,7 @@ class UcxStats extends OperationStats {
}

case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
override def serializedSize: Int = 12
override def serializedSize: Int = UcxShuffleBockId.serializedSize

override def serialize(byteBuffer: ByteBuffer): Unit = {
byteBuffer.putInt(shuffleId)
@@ -63,6 +65,8 @@ case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends B
}

object UcxShuffleBockId {
val serializedSize: Int = 12

def deserialize(byteBuffer: ByteBuffer): UcxShuffleBockId = {
val shuffleId = byteBuffer.getInt
val mapId = byteBuffer.getInt
@@ -88,6 +92,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
private var progressThread: Thread = _
var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _

private[spark] lazy val replyThreadPool = ThreadUtils.newForkJoinPool(
"UcxListenerThread", ucxShuffleConf.numListenerThreads)

private val errorHandler = new UcpEndpointErrorHandler {
override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
if (errorCode == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) {
@@ -126,7 +133,7 @@
logInfo(s"Allocating ${ucxShuffleConf.numListenerThreads} server workers")
for (i <- 0 until ucxShuffleConf.numListenerThreads) {
val worker = ucxContext.newWorker(ucpWorkerParams)
allocatedServerWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = false)
allocatedServerWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong)
}

val Array(host, port) = ucxShuffleConf.listenerAddress.split(":")
@@ -150,6 +157,8 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId)
}

allocatedServerWorkers.foreach(_.progressStart())
allocatedClientWorkers.foreach(_.progressStart())
initialized = true
logInfo(s"Started listener on ${listener.getAddress}")
SerializationUtils.serializeInetAddress(listener.getAddress)
@@ -166,7 +175,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
hostBounceBufferMemoryPool.close()

allocatedClientWorkers.foreach(_.close())
allocatedServerWorkers.foreach(_.close())

if (listener != null) {
listener.close()
@@ -183,6 +191,8 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
globalWorker = null
}

allocatedServerWorkers.foreach(_.close())

if (ucxContext != null) {
ucxContext.close()
ucxContext = null
@@ -196,10 +206,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
*/
override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = {
executorAddresses.put(executorId, workerAddress)
allocatedClientWorkers.foreach(w => {
w.getConnection(executorId)
w.progressConnect()
})
allocatedClientWorkers.foreach(_.getConnection(executorId))
}

def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = {
@@ -265,36 +272,35 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
/**
* Batch version of [[ fetchBlocksByBlockIds ]].
*/
override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId],
def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId],
resultBufferAllocator: BufferAllocator,
callbacks: Seq[OperationCallback]): Seq[Request] = {
allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt)
.fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks)
callbacks: Seq[OperationCallback]): Unit = {
selectClientWorker.fetchBlocksByBlockIds(executorId, blockIds,
resultBufferAllocator, callbacks)
}

def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = {
allocatedServerWorkers.foreach(w => w.connectByWorkerAddress(executorId, workerAddress))
}

def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = {
val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt)
val blockIds = mutable.ArrayBuffer.empty[BlockId]

// 1. Deserialize blockIds from header
while (buffer.remaining() > 0) {
val blockId = UcxShuffleBockId.deserialize(buffer)
if (!registeredBlocks.contains(blockId)) {
throw new UcxException(s"$blockId is not registered")
def handleFetchBlockRequest(replyTag: Int, blockIds: Seq[BlockId],
replyExecutor: Long): Unit = {
replyThreadPool.submit(new Runnable {
override def run(): Unit = {
val blocks = blockIds.map(bid => registeredBlocks(bid))
selectServerWorker.handleFetchBlockRequest(blocks, replyTag,
replyExecutor)
}
blockIds += blockId
}

val blocks = blockIds.map(bid => registeredBlocks(bid))
amData.close()
allocatedServerWorkers((Thread.currentThread().getId % allocatedServerWorkers.length).toInt)
.handleFetchBlockRequest(blocks, replyTag, replyExecutor)
})
}

@inline
def selectClientWorker(): UcxWorkerWrapper = allocatedClientWorkers(
(Thread.currentThread().getId % allocatedClientWorkers.length).toInt)

@inline
def selectServerWorker(): UcxWorkerWrapper = allocatedServerWorkers(
(Thread.currentThread().getId % allocatedServerWorkers.length).toInt)

/**
* Progress outstanding operations. This routine is blocking (though may poll for event).
@@ -304,10 +310,5 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
* But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed!
*/
override def progress(): Unit = {
allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt).progress()
}

def progressConnect(): Unit = {
allocatedClientWorkers.par.foreach(_.progressConnect())
}
}
}
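
The progressStart() calls above refer to a method added on UcxWorkerWrapper in another file of this PR; its body is not shown here. The sketch below is an illustrative start/stop lifecycle only, reusing the same drain-then-wait loop shown after the commit list; the class and field names are assumptions, and signal() is assumed to be what unblocks waitForEvents() during shutdown.

import org.openucx.jucx.ucp.UcpWorker

// Illustrative per-worker progress thread lifecycle, matching the
// progressStart()/close() call sites in the transport above.
class WorkerProgressLifecycle(worker: UcpWorker, id: Long) {
  @volatile private var running = false
  private var thread: Thread = _

  def progressStart(): Unit = {
    running = true
    thread = new Thread(s"UCX-worker-progress-$id") {
      override def run(): Unit = {
        while (running) {
          while (worker.progress() != 0) {}  // drain completed operations
          worker.waitForEvents()             // block until the next wakeup
        }
      }
    }
    thread.setDaemon(true)
    thread.start()
  }

  def close(): Unit = {
    running = false
    worker.signal()  // wake the progress thread so it observes running == false
    thread.join()
    worker.close()
  }
}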