Skip to content

Retry getting a free port on macos #418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
package protocbridge.frontend

import protocbridge.frontend.SocketBasedPluginFrontend.getFreeSocket
import protocbridge.{ExtraEnv, ProtocCodeGenerator}

import java.lang.management.ManagementFactory
import java.net.ServerSocket
import java.nio.file.{Files, Path}
import scala.annotation.tailrec
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.{Future, blocking}
import scala.util.{Failure, Success, Try}


/** PluginFrontend for Windows and macOS where a server socket is used.
*/
*/
abstract class SocketBasedPluginFrontend extends PluginFrontend {

case class InternalState(serverSocket: ServerSocket, shellScript: Path)

override def prepare(
plugin: ProtocCodeGenerator,
env: ExtraEnv
): (Path, InternalState) = {
val ss = new ServerSocket(0) // Bind to any available port.
plugin: ProtocCodeGenerator,
env: ExtraEnv
): (Path, InternalState) = {
val ss = getFreeSocket()
val sh = createShellScript(ss.getLocalPort)

Future {
Expand Down Expand Up @@ -49,3 +55,57 @@ abstract class SocketBasedPluginFrontend extends PluginFrontend {

protected def createShellScript(port: Int): Path
}


object SocketBasedPluginFrontend {

private lazy val currentPid: Int = {
val jvmName = ManagementFactory.getRuntimeMXBean.getName
val pid = jvmName.split("@")(0)
pid.toInt
}

private def isSocketConflict(currentPid: Int, port: Int): Boolean = {
import scala.sys.process._
Try {
s"/usr/sbin/lsof -i :$port -t".!!.trim
} match {
case Success(output) =>
if (output.nonEmpty) {
val otherPids = output.split("\n").filterNot(_ == currentPid.toString)
otherPids.nonEmpty
} else {
true
}
case Failure(e) =>
System.err.println(s"Failure checking if port is busy: $e")
false
}
}

@tailrec
private def getFreeSocketForMac(currentPid: Int, attemptsLeft: Int): ServerSocket = {
val sock = new ServerSocket(0)
if (isSocketConflict(currentPid, sock.getLocalPort)) {
if (attemptsLeft > 0) {
System.out.println(s"Socket conflict on port ${sock.getLocalPort}, retry, $attemptsLeft attempts left")
getFreeSocketForMac(currentPid, attemptsLeft - 1)
} else {
System.out.println(s"Socket conflict on port ${sock.getLocalPort}, no retries left, you're gonna get an error")
sock
}
} else {
sock
}
}

def getFreeSocket(maxAttempts: Int = 5): ServerSocket = {
if (!PluginFrontend.isMac) {
// Bind to any available port.
new ServerSocket(0)
} else {
//new ServerSocket(0) on mac might return a socket already in use, so here's a hack
getFreeSocketForMac(currentPid, maxAttempts)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package protocbridge.frontend
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

import java.lang.management.ManagementFactory
import scala.collection.mutable
import scala.sys.process._
import scala.util.{Failure, Success, Try}

class SocketAllocationSpec extends AnyFlatSpec with Matchers {
it must "allocate an unused port" in {
val repeatCount = 100000

val currentPid = getCurrentPid
val portConflictCount = mutable.Map[Int, Int]()

for (i <- 1 to repeatCount) {
if (i % 100 == 1) println(s"Running iteration $i of $repeatCount")

val serverSocket = SocketBasedPluginFrontend.getFreeSocket()
try {
val port = serverSocket.getLocalPort
Try {
s"lsof -i :$port -t".!!.trim
} match {
case Success(output) =>
if (output.nonEmpty) {
val pids = output.split("\n").filterNot(_ == currentPid.toString)
if (pids.nonEmpty) {
System.err.println("Port conflict detected on port " + port + " with PIDs: " + pids.mkString(", "))
portConflictCount(port) = portConflictCount.getOrElse(port, 0) + 1
}
}
case Failure(_) => // Ignore failure and continue
}
} finally {
serverSocket.close()
}
}

assert(portConflictCount.isEmpty, s"Found the following ports in use out of $repeatCount: $portConflictCount")
}

private def getCurrentPid: Int = {
val jvmName = ManagementFactory.getRuntimeMXBean.getName
val pid = jvmName.split("@")(0)
pid.toInt
}
}