Skip to content

Commit

Permalink
select worker round robin
Browse files Browse the repository at this point in the history
  • Loading branch information
JeynmannZ committed Jul 17, 2023
1 parent c55a642 commit 0a3706a
Showing 1 changed file with 6 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import org.openucx.jucx.ucs.UcsConstants

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap
import scala.collection.mutable

Expand Down Expand Up @@ -82,7 +83,10 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
val executorAddresses = new TrieMap[ExecutorId, ByteBuffer]

private var allocatedClientThreads: Array[UcxWorkerThread] = _
private val clientThreadId = new AtomicInteger()

private var allocatedServerThreads: Array[UcxWorkerThread] = _
private val serverThreadId = new AtomicInteger()

private val registeredBlocks = new TrieMap[BlockId, Block]
private var progressThread: Thread = _
Expand Down Expand Up @@ -313,11 +317,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo

@inline
def selectClientThread(): UcxWorkerThread = allocatedClientThreads(
(Thread.currentThread().getId % allocatedClientThreads.length).toInt)
(clientThreadId.incrementAndGet() % allocatedClientThreads.length).abs)

@inline
def selectServerThread(): UcxWorkerThread = allocatedServerThreads(
(Thread.currentThread().getId % allocatedServerThreads.length).toInt)
(serverThreadId.incrementAndGet() % allocatedServerThreads.length).abs)

/**
* Progress outstanding operations. This routine is blocking (though may poll for event).
Expand Down

0 comments on commit 0a3706a

Please sign in to comment.