1616 */
1717package io.bazel.kotlin.builder.tasks
1818
19- import com.google.devtools.build.lib.worker.WorkerProtocol
19+ import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest
20+ import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse
2021import io.bazel.kotlin.builder.utils.WorkingDirectoryContext
2122import io.bazel.kotlin.builder.utils.wasInterrupted
23+ import java.io.BufferedInputStream
2224import java.io.ByteArrayInputStream
2325import java.io.ByteArrayOutputStream
26+ import java.io.Closeable
2427import java.io.IOException
28+ import java.io.InputStream
2529import java.io.PrintStream
2630import java.nio.charset.StandardCharsets.UTF_8
2731import java.nio.file.Files
2832import java.nio.file.Path
2933import java.nio.file.Paths
34+ import java.util.logging.Level.SEVERE
35+ import java.util.logging.Logger
3036
3137/* *
3238 * Interface for command line programs.
3339 *
3440 * This is the same thing as a main function, except not static.
3541 */
42+ @FunctionalInterface
3643interface CommandLineProgram {
3744 /* *
3845 * Runs blocking program start to finish.
@@ -56,107 +63,164 @@ interface CommandLineProgram {
5663 * @param <T> delegate program type
5764 */
5865class BazelWorker (
59- private val delegate : CommandLineProgram ,
66+ private val commandLineProgram : CommandLineProgram ,
6067 private val output : PrintStream ,
6168 private val mnemonic : String
6269) {
63- companion object {
64- const val OK = 0
65- const val INTERRUPTED_STATUS = 1
66- const val ERROR_STATUS = 2
67- }
6870
6971 fun apply (args : List <String >): Int {
70- return if (args.contains(" --persistent_worker" ))
71- runAsPersistentWorker()
72- else WorkingDirectoryContext .newContext().use { ctx ->
73- delegate.apply (ctx.dir, loadArguments(args, false ))
72+ if (args.contains(" --persistent_worker" )) {
73+ return WorkerIO .open().use { io ->
74+ PersistentWorker (io, commandLineProgram).run (args)
75+ }
76+ } else {
77+ output.println (
78+ " HINT: $mnemonic will compile faster if you run: echo \" build --strategy=$mnemonic =worker\" >>~/.bazelrc"
79+ )
80+ return WorkerIO .noop().use { io ->
81+ InvocationWorker (io, commandLineProgram).run (args)
82+ }
7483 }
7584 }
85+ }
7686
77- private fun runAsPersistentWorker (): Int {
78- val realStdIn = System .`in `
79- val realStdOut = System .out
80- val realStdErr = System .err
81- try {
82- ByteArrayInputStream (ByteArray (0 )).use { emptyIn ->
83- ByteArrayOutputStream ().use { buffer ->
84- PrintStream (buffer).use { ps ->
85- System .setIn(emptyIn)
86- System .setOut(ps)
87- System .setErr(ps)
88- val invocationWorker = InvocationWorker (delegate, buffer)
89- while (true ) {
90- val status =
91- WorkerProtocol .WorkRequest .parseDelimitedFrom(realStdIn)?.let { request ->
92- invocationWorker.invoke(loadArguments(request.argumentsList, true ))
93- }?.also { (status, log) ->
94- with (WorkerProtocol .WorkResponse .newBuilder()) {
95- exitCode = status
96- output = log
97- build().writeDelimitedTo(realStdOut)
98- }
99- }?.let { (status, _) -> status } ? : OK
100-
101- if (status != OK ) {
102- return status
103- }
104- System .gc()
105- }
106- }
107- }
87+ private fun maybeExpand (args : List <String >): List <String > {
88+ if (args.isNotEmpty()) {
89+ val lastArg = args[args.size - 1 ]
90+ if (lastArg.startsWith(" @" )) {
91+ val pathElement = lastArg.substring(1 )
92+ val flagFile = Paths .get(pathElement)
93+ try {
94+ return Files .readAllLines(flagFile, UTF_8 )
95+ } catch (e: IOException ) {
96+ throw RuntimeException (e)
10897 }
109- } finally {
110- System .setIn(realStdIn)
111- System .setOut(realStdOut)
112- System .setErr(realStdErr)
11398 }
114- return OK
11599 }
100+ return args
101+ }
116102
117- private fun loadArguments (args : List <String >, isWorker : Boolean ): List <String > {
118- if (args.isNotEmpty()) {
119- val lastArg = args[args.size - 1 ]
120-
121- if (lastArg.startsWith(" @" )) {
122- val pathElement = lastArg.substring(1 )
123- val flagFile = Paths .get(pathElement)
124- if (isWorker && lastArg.startsWith(" @@" ) || Files .exists(flagFile)) {
125- if (! isWorker && mnemonic.isNotEmpty()) {
126- output.printf(
127- " HINT: %s will compile faster if you run: " + " echo \" build --strategy=%s=worker\" >>~/.bazelrc\n " ,
128- mnemonic, mnemonic
129- )
130- }
131- try {
132- return Files .readAllLines(flagFile, UTF_8 )
133- } catch (e: IOException ) {
134- throw RuntimeException (e)
135- }
136- }
103+ /* * Defines the common worker interface. */
104+ interface Worker {
105+ fun run (args : List <String >): Int
106+ }
107+
108+ class WorkerIO (
109+ val input : InputStream ,
110+ val output : PrintStream ,
111+ val execution : ByteArrayOutputStream ,
112+ private val restore : () -> Unit
113+ ) : Closeable {
114+ companion object {
115+ fun open (): WorkerIO {
116+ val stdErr = System .err
117+ val stdIn = BufferedInputStream (System .`in `)
118+ val stdOut = System .out
119+ val inputBuffer = ByteArrayInputStream (ByteArray (0 ))
120+ val execution = ByteArrayOutputStream ()
121+ val outputBuffer = PrintStream (execution)
122+
123+ // delegate the system defaults to capture execution information
124+ System .setErr(outputBuffer)
125+ System .setOut(outputBuffer)
126+ System .setIn(inputBuffer)
127+
128+ return WorkerIO (stdIn, stdOut, execution) {
129+ System .setOut(stdOut)
130+ System .setIn(stdIn)
131+ System .setErr(stdErr)
137132 }
138133 }
139- return args
134+
135+ fun noop (): WorkerIO {
136+ val inputBuffer = ByteArrayInputStream (ByteArray (0 ))
137+ val execution = ByteArrayOutputStream ()
138+ val outputBuffer = PrintStream (execution)
139+ return WorkerIO (inputBuffer, outputBuffer, execution) {}
140+ }
141+ }
142+
143+ override fun close () {
144+ restore.invoke()
140145 }
141146}
142147
143- class InvocationWorker (
144- private val delegate : CommandLineProgram ,
145- private val buffer : ByteArrayOutputStream
146- ) {
148+ /* * PersistentWorker follows the Bazel worker protocol and executes a CommandLineProgram. */
149+ class PersistentWorker (
150+ private val io : WorkerIO ,
151+ private val program : CommandLineProgram
152+ ) : Worker {
153+ private val logger = Logger .getLogger(PersistentWorker ::class .java.canonicalName)
154+
155+ enum class Status {
156+ OK , INTERRUPTED , ERROR
157+ }
158+
159+ override fun run (args : List <String >): Int {
160+ while (true ) {
161+ val request = WorkRequest .parseDelimitedFrom(io.input) ? : continue
147162
148- fun invoke (args : List <String >): Pair <Int , String > = WorkingDirectoryContext .newContext()
149- .use { wdCtx ->
150- return try {
151- delegate.apply (wdCtx.dir, args)
152- } catch (e: RuntimeException ) {
153- if (e.wasInterrupted()) BazelWorker .INTERRUPTED_STATUS
154- else BazelWorker .ERROR_STATUS .also {
155- System .err.println (
156- " ERROR: Worker threw uncaught exception with args: ${args} "
157- )
158- e.printStackTrace(System .err)
163+ val (status, exit) = WorkingDirectoryContext .newContext()
164+ .runCatching {
165+ request.argumentsList
166+ ?.let { maybeExpand(it) }
167+ .run {
168+ Status .OK to program.apply (dir, maybeExpand(request.argumentsList))
169+ }
170+ }
171+ .recover { e: Throwable ->
172+ io.execution.write((e.message ? : e.toString()).toByteArray(UTF_8 ))
173+ if (! e.wasInterrupted()) {
174+ logger.log(SEVERE ,
175+ " ERROR: Worker threw uncaught exception" ,
176+ e)
177+ Status .ERROR to 1
178+ } else {
179+ Status .INTERRUPTED to 1
159180 }
160- } to buffer.toString()
181+ }
182+ .getOrThrow()
183+
184+ val response = WorkResponse .newBuilder().apply {
185+ output = String (io.execution.toByteArray(), UTF_8 )
186+ exitCode = exit
187+ requestId = request.requestId
188+ }.build()
189+
190+ // return the response
191+ response.writeDelimitedTo(io.output)
192+ io.output.flush()
193+
194+ // clear execution logs
195+ io.execution.reset()
196+
197+ if (status == Status .INTERRUPTED ) {
198+ break
161199 }
200+ }
201+ logger.info(" Shutting down worker." )
202+ return 0
203+ }
204+ }
205+
206+ class InvocationWorker (
207+ private val io : WorkerIO ,
208+ private val program : CommandLineProgram
209+ ) : Worker {
210+ private val logger: Logger = Logger .getLogger(InvocationWorker ::class .java.canonicalName)
211+ override fun run (args : List <String >): Int = WorkingDirectoryContext .newContext()
212+ .runCatching { program.apply (dir, maybeExpand(args)) }
213+ .recover { e ->
214+ logger.log(SEVERE ,
215+ " ERROR: Worker threw uncaught exception with args: ${maybeExpand(args)} " ,
216+ e)
217+ return @recover 1 // return non-0 exitcode
218+ }
219+ .also {
220+ // print execution log
221+ println (String (io.execution.toByteArray(), UTF_8 ))
222+ }
223+ .getOrDefault(0 )
162224}
225+
226+
0 commit comments