diff --git a/bridge/src/main/scala/protocbridge/frontend/SocketBasedPluginFrontend.scala b/bridge/src/main/scala/protocbridge/frontend/SocketBasedPluginFrontend.scala index 6d1dd59..0cc73ca 100644 --- a/bridge/src/main/scala/protocbridge/frontend/SocketBasedPluginFrontend.scala +++ b/bridge/src/main/scala/protocbridge/frontend/SocketBasedPluginFrontend.scala @@ -1,22 +1,28 @@ package protocbridge.frontend +import protocbridge.frontend.SocketBasedPluginFrontend.getFreeSocket import protocbridge.{ExtraEnv, ProtocCodeGenerator} +import java.lang.management.ManagementFactory import java.net.ServerSocket import java.nio.file.{Files, Path} +import scala.annotation.tailrec import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.{Future, blocking} +import scala.util.{Failure, Success, Try} + /** PluginFrontend for Windows and macOS where a server socket is used. - */ + */ abstract class SocketBasedPluginFrontend extends PluginFrontend { + case class InternalState(serverSocket: ServerSocket, shellScript: Path) override def prepare( - plugin: ProtocCodeGenerator, - env: ExtraEnv - ): (Path, InternalState) = { - val ss = new ServerSocket(0) // Bind to any available port. + plugin: ProtocCodeGenerator, + env: ExtraEnv + ): (Path, InternalState) = { + val ss = getFreeSocket() val sh = createShellScript(ss.getLocalPort) Future { @@ -49,3 +55,57 @@ abstract class SocketBasedPluginFrontend extends PluginFrontend { protected def createShellScript(port: Int): Path } + + +object SocketBasedPluginFrontend { + + private lazy val currentPid: Int = { + val jvmName = ManagementFactory.getRuntimeMXBean.getName + val pid = jvmName.split("@")(0) + pid.toInt + } + + private def isSocketConflict(currentPid: Int, port: Int): Boolean = { + import scala.sys.process._ + Try { + s"/usr/sbin/lsof -i :$port -t".!!.trim + } match { + case Success(output) => + if (output.nonEmpty) { + val otherPids = output.split("\n").filterNot(_ == currentPid.toString) + otherPids.nonEmpty + } else { + true + } + case Failure(e) => + System.err.println(s"Failure checking if port is busy: $e") + false + } + } + + @tailrec + private def getFreeSocketForMac(currentPid: Int, attemptsLeft: Int): ServerSocket = { + val sock = new ServerSocket(0) + if (isSocketConflict(currentPid, sock.getLocalPort)) { + if (attemptsLeft > 0) { + System.out.println(s"Socket conflict on port ${sock.getLocalPort}, retry, $attemptsLeft attempts left") + getFreeSocketForMac(currentPid, attemptsLeft - 1) + } else { + System.out.println(s"Socket conflict on port ${sock.getLocalPort}, no retries left, you're gonna get an error") + sock + } + } else { + sock + } + } + + def getFreeSocket(maxAttempts: Int = 5): ServerSocket = { + if (!PluginFrontend.isMac) { + // Bind to any available port. + new ServerSocket(0) + } else { + //new ServerSocket(0) on mac might return a socket already in use, so here's a hack + getFreeSocketForMac(currentPid, maxAttempts) + } + } +} diff --git a/bridge/src/test/scala/protocbridge/frontend/SocketAllocationSpec.scala b/bridge/src/test/scala/protocbridge/frontend/SocketAllocationSpec.scala new file mode 100644 index 0000000..55de593 --- /dev/null +++ b/bridge/src/test/scala/protocbridge/frontend/SocketAllocationSpec.scala @@ -0,0 +1,49 @@ +package protocbridge.frontend +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.must.Matchers + +import java.lang.management.ManagementFactory +import scala.collection.mutable +import scala.sys.process._ +import scala.util.{Failure, Success, Try} + +class SocketAllocationSpec extends AnyFlatSpec with Matchers { + it must "allocate an unused port" in { + val repeatCount = 100000 + + val currentPid = getCurrentPid + val portConflictCount = mutable.Map[Int, Int]() + + for (i <- 1 to repeatCount) { + if (i % 100 == 1) println(s"Running iteration $i of $repeatCount") + + val serverSocket = SocketBasedPluginFrontend.getFreeSocket() + try { + val port = serverSocket.getLocalPort + Try { + s"lsof -i :$port -t".!!.trim + } match { + case Success(output) => + if (output.nonEmpty) { + val pids = output.split("\n").filterNot(_ == currentPid.toString) + if (pids.nonEmpty) { + System.err.println("Port conflict detected on port " + port + " with PIDs: " + pids.mkString(", ")) + portConflictCount(port) = portConflictCount.getOrElse(port, 0) + 1 + } + } + case Failure(_) => // Ignore failure and continue + } + } finally { + serverSocket.close() + } + } + + assert(portConflictCount.isEmpty, s"Found the following ports in use out of $repeatCount: $portConflictCount") + } + + private def getCurrentPid: Int = { + val jvmName = ManagementFactory.getRuntimeMXBean.getName + val pid = jvmName.split("@")(0) + pid.toInt + } +}