Skip to content

Commit 04bda24

Browse files
authored
Support multiple files to perf benchmark (#19)
Signed-off-by: Huaxiang Fan <[email protected]>
1 parent deeb9c4 commit 04bda24

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import java.io.{File, RandomAccessFile}
88
import java.net.InetSocketAddress
99
import java.nio.ByteBuffer
1010
import java.nio.charset.StandardCharsets
11+
import java.nio.channels.FileChannel
1112
import java.util.concurrent.atomic.AtomicInteger
1213

1314
import org.apache.commons.cli.{GnuParser, HelpFormatter, Options}
@@ -21,7 +22,7 @@ import org.apache.spark.util.ShutdownHookManager
2122
object UcxPerfBenchmark extends App with Logging {
2223

2324
case class PerfOptions(remoteAddress: InetSocketAddress, numBlocks: Int, blockSize: Long,
24-
numIterations: Int, file: File, numOutstanding: Int)
25+
numIterations: Int, files: Array[File], numOutstanding: Int)
2526

2627
private val HELP_OPTION = "h"
2728
private val ADDRESS_OPTION = "a"
@@ -47,7 +48,7 @@ object UcxPerfBenchmark extends App with Logging {
4748
"number of iterations. Default: 1")
4849
options.addOption(OUTSTANDING_OPTION, "num-outstanding", true,
4950
"number of outstanding requests. Default: 1")
50-
options.addOption(FILE_OPTION, "file", true, "File to transfer")
51+
options.addOption(FILE_OPTION, "files", true, "Files to transfer")
5152
options
5253
}
5354

@@ -68,17 +69,27 @@ object UcxPerfBenchmark extends App with Logging {
6869
null
6970
}
7071

71-
val file = if (cmd.hasOption(FILE_OPTION)) {
72-
new File(cmd.getOptionValue(FILE_OPTION))
73-
} else {
74-
null
72+
var files = Array[File]()
73+
if (cmd.hasOption(FILE_OPTION)) {
74+
for (name <- cmd.getOptionValue(FILE_OPTION).split(",")) {
75+
val f = new File(name)
76+
if (!f.exists()) {
77+
System.err.println(s"File $name does not exist.")
78+
System.exit(-1)
79+
}
80+
files +:= f
81+
}
82+
}
83+
if (files.size == 0) {
84+
System.err.println(s"No file.")
85+
System.exit(-1)
7586
}
7687

7788
PerfOptions(inetAddress,
7889
Integer.parseInt(cmd.getOptionValue(NUM_BLOCKS_OPTION, "1")),
7990
JavaUtils.byteStringAsBytes(cmd.getOptionValue(SIZE_OPTION, "1m")),
8091
Integer.parseInt(cmd.getOptionValue(ITER_OPTION, "1")),
81-
file,
92+
files,
8293
Integer.parseInt(cmd.getOptionValue(OUTSTANDING_OPTION, "1")))
8394
}
8495

@@ -101,7 +112,9 @@ object UcxPerfBenchmark extends App with Logging {
101112
for (b <- 0 until options.numBlocks by options.numOutstanding) {
102113
requestInFlight.set(options.numOutstanding)
103114
for (o <- 0 until options.numOutstanding) {
104-
blocks(o) = UcxShuffleBockId(0, 0, (b+o) % options.numBlocks)
115+
val fileIdx = (b+o) % options.files.size
116+
val blockIdx = (b+o) / options.files.size
117+
blocks(o) = UcxShuffleBockId(0, fileIdx, blockIdx)
105118
callbacks(o) = (result: OperationResult) => {
106119
result.getData.close()
107120
val stats = result.getStats.get
@@ -126,17 +139,21 @@ object UcxPerfBenchmark extends App with Logging {
126139
ucxTransport.init()
127140
val currentThread = Thread.currentThread()
128141

129-
val channel = new RandomAccessFile(options.file, "rw").getChannel
142+
var channels = Array[FileChannel]()
143+
options.files.foreach(channels +:= new RandomAccessFile(_, "r").getChannel)
130144

131145
ShutdownHookManager.addShutdownHook(()=>{
132146
currentThread.interrupt()
133147
ucxTransport.close()
134148
})
135149

136150
for (b <- 0 until options.numBlocks) {
137-
val blockId = UcxShuffleBockId(0, 0, b)
151+
val fileIdx = b % options.files.size
152+
val blockIdx = b / options.files.size
153+
val blockId = UcxShuffleBockId(0, fileIdx, blockIdx)
138154
val block = new Block {
139-
private val fileOffset = b * options.blockSize
155+
private val channel = channels(fileIdx)
156+
private val fileOffset = blockIdx * options.blockSize
140157

141158
override def getMemoryBlock: MemoryBlock = {
142159
val startTime = System.nanoTime()

0 commit comments

Comments
 (0)