diff --git a/bridge/src/main/scala/protocbridge/frontend/MacPluginFrontend.scala b/bridge/src/main/scala/protocbridge/frontend/MacPluginFrontend.scala new file mode 100644 index 0000000..6f33cf8 --- /dev/null +++ b/bridge/src/main/scala/protocbridge/frontend/MacPluginFrontend.scala @@ -0,0 +1,36 @@ +package protocbridge.frontend + +import java.nio.file.attribute.PosixFilePermission +import java.nio.file.{Files, Path} +import java.{util => ju} + +/** PluginFrontend for macOS. + * + * Creates a server socket and uses `nc` to communicate with the socket. We use + * a server socket instead of named pipes because named pipes are unreliable on + * macOS: https://github.com/scalapb/protoc-bridge/issues/366. Since `nc` is + * widely available on macOS, this is the simplest and most reliable solution + * for macOS. + */ +object MacPluginFrontend extends SocketBasedPluginFrontend { + + protected def createShellScript(port: Int): Path = { + val shell = sys.env.getOrElse("PROTOCBRIDGE_SHELL", "/bin/sh") + // We use 127.0.0.1 instead of localhost for the (very unlikely) case that localhost is missing from /etc/hosts. + val scriptName = PluginFrontend.createTempFile( + "", + s"""|#!$shell + |set -e + |nc 127.0.0.1 $port + """.stripMargin + ) + val perms = new ju.HashSet[PosixFilePermission] + perms.add(PosixFilePermission.OWNER_EXECUTE) + perms.add(PosixFilePermission.OWNER_READ) + Files.setPosixFilePermissions( + scriptName, + perms + ) + scriptName + } +} diff --git a/bridge/src/main/scala/protocbridge/frontend/PluginFrontend.scala b/bridge/src/main/scala/protocbridge/frontend/PluginFrontend.scala index 7415f06..3b83cfa 100644 --- a/bridge/src/main/scala/protocbridge/frontend/PluginFrontend.scala +++ b/bridge/src/main/scala/protocbridge/frontend/PluginFrontend.scala @@ -5,8 +5,6 @@ import java.nio.file.{Files, Path} import protocbridge.{ProtocCodeGenerator, ExtraEnv} -import scala.util.Try - /** A PluginFrontend instance provides a platform-dependent way for protoc to * communicate with a JVM based ProtocCodeGenerator. * @@ -47,13 +45,7 @@ object PluginFrontend { gen: ProtocCodeGenerator, request: Array[Byte] ): Array[Byte] = { - Try { - gen.run(request) - }.recover { case throwable => - createCodeGeneratorResponseWithError( - throwable.toString + "\n" + getStackTrace(throwable) - ) - }.get + gen.run(request) } def createCodeGeneratorResponseWithError(error: String): Array[Byte] = { @@ -116,9 +108,17 @@ object PluginFrontend { gen: ProtocCodeGenerator, fsin: InputStream, env: ExtraEnv - ): Array[Byte] = { + ): Array[Byte] = try { val bytes = readInputStreamToByteArrayWithEnv(fsin, env) runWithBytes(gen, bytes) + } catch { + // This covers all Throwable including OutOfMemoryError, StackOverflowError, etc. + // We need to make a best effort to return a response to protoc, + // otherwise protoc can hang indefinitely. + case throwable: Throwable => + createCodeGeneratorResponseWithError( + throwable.toString + "\n" + getStackTrace(throwable) + ) } def createTempFile(extension: String, content: String): Path = { @@ -131,8 +131,13 @@ object PluginFrontend { def isWindows: Boolean = sys.props("os.name").startsWith("Windows") + def isMac: Boolean = sys.props("os.name").startsWith("Mac") || sys + .props("os.name") + .startsWith("Darwin") + def newInstance: PluginFrontend = { if (isWindows) WindowsPluginFrontend + else if (isMac) MacPluginFrontend else PosixPluginFrontend } } diff --git a/bridge/src/main/scala/protocbridge/frontend/PosixPluginFrontend.scala b/bridge/src/main/scala/protocbridge/frontend/PosixPluginFrontend.scala index 5f70120..65935e0 100644 --- a/bridge/src/main/scala/protocbridge/frontend/PosixPluginFrontend.scala +++ b/bridge/src/main/scala/protocbridge/frontend/PosixPluginFrontend.scala @@ -12,10 +12,13 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.sys.process._ import java.{util => ju} -/** PluginFrontend for Unix-like systems (Linux, Mac, etc) +/** PluginFrontend for Unix-like systems except macOS (Linux, FreeBSD, + * etc) * * Creates a pair of named pipes for input/output and a shell script that - * communicates with them. + * communicates with them. Compared with `SocketBasedPluginFrontend`, this + * frontend doesn't rely on `nc` that might not be available in some + * distributions. */ object PosixPluginFrontend extends PluginFrontend { case class InternalState( @@ -40,6 +43,11 @@ object PosixPluginFrontend extends PluginFrontend { val response = PluginFrontend.runWithInputStream(plugin, fsin, env) fsin.close() + // Note that the output pipe must be opened after the input pipe is consumed. + // Otherwise, there might be a deadlock that + // - The shell script is stuck writing to the input pipe (which has a full buffer), + // and doesn't open the write end of the output pipe. + // - This thread is stuck waiting for the write end of the output pipe to be opened. val fsout = Files.newOutputStream(outputPipe) fsout.write(response) fsout.close() diff --git a/bridge/src/main/scala/protocbridge/frontend/SocketBasedPluginFrontend.scala b/bridge/src/main/scala/protocbridge/frontend/SocketBasedPluginFrontend.scala new file mode 100644 index 0000000..6d1dd59 --- /dev/null +++ b/bridge/src/main/scala/protocbridge/frontend/SocketBasedPluginFrontend.scala @@ -0,0 +1,51 @@ +package protocbridge.frontend + +import protocbridge.{ExtraEnv, ProtocCodeGenerator} + +import java.net.ServerSocket +import java.nio.file.{Files, Path} +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.{Future, blocking} + +/** PluginFrontend for Windows and macOS where a server socket is used. + */ +abstract class SocketBasedPluginFrontend extends PluginFrontend { + case class InternalState(serverSocket: ServerSocket, shellScript: Path) + + override def prepare( + plugin: ProtocCodeGenerator, + env: ExtraEnv + ): (Path, InternalState) = { + val ss = new ServerSocket(0) // Bind to any available port. + val sh = createShellScript(ss.getLocalPort) + + Future { + blocking { + // Accept a single client connection from the shell script. + val client = ss.accept() + try { + val response = + PluginFrontend.runWithInputStream( + plugin, + client.getInputStream, + env + ) + client.getOutputStream.write(response) + } finally { + client.close() + } + } + } + + (sh, InternalState(ss, sh)) + } + + override def cleanup(state: InternalState): Unit = { + state.serverSocket.close() + if (sys.props.get("protocbridge.debug") != Some("1")) { + Files.delete(state.shellScript) + } + } + + protected def createShellScript(port: Int): Path +} diff --git a/bridge/src/main/scala/protocbridge/frontend/WindowsPluginFrontend.scala b/bridge/src/main/scala/protocbridge/frontend/WindowsPluginFrontend.scala index 490211d..adf9486 100644 --- a/bridge/src/main/scala/protocbridge/frontend/WindowsPluginFrontend.scala +++ b/bridge/src/main/scala/protocbridge/frontend/WindowsPluginFrontend.scala @@ -1,53 +1,15 @@ package protocbridge.frontend -import java.net.ServerSocket -import java.nio.file.{Files, Path, Paths} - -import protocbridge.ExtraEnv -import protocbridge.ProtocCodeGenerator - -import scala.concurrent.blocking - -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.Future +import java.nio.file.{Path, Paths} /** A PluginFrontend that binds a server socket to a local interface. The plugin * is a batch script that invokes BridgeApp.main() method, in a new JVM with * the same parameters as the currently running JVM. The plugin will * communicate its stdin and stdout to this socket. */ -object WindowsPluginFrontend extends PluginFrontend { - - case class InternalState(batFile: Path) - - override def prepare( - plugin: ProtocCodeGenerator, - env: ExtraEnv - ): (Path, InternalState) = { - val ss = new ServerSocket(0) - val state = createWindowsScript(ss.getLocalPort) - - Future { - blocking { - val client = ss.accept() - val response = - PluginFrontend.runWithInputStream(plugin, client.getInputStream, env) - client.getOutputStream.write(response) - client.close() - ss.close() - } - } - - (state.batFile, state) - } - - override def cleanup(state: InternalState): Unit = { - if (sys.props.get("protocbridge.debug") != Some("1")) { - Files.delete(state.batFile) - } - } +object WindowsPluginFrontend extends SocketBasedPluginFrontend { - private def createWindowsScript(port: Int): InternalState = { + protected def createShellScript(port: Int): Path = { val classPath = Paths.get(getClass.getProtectionDomain.getCodeSource.getLocation.toURI) val classPathBatchString = classPath.toString.replace("%", "%%") @@ -62,6 +24,6 @@ object WindowsPluginFrontend extends PluginFrontend { ].getName} $port """.stripMargin ) - InternalState(batchFile) + batchFile } } diff --git a/bridge/src/test/scala/protocbridge/frontend/MacPluginFrontendSpec.scala b/bridge/src/test/scala/protocbridge/frontend/MacPluginFrontendSpec.scala new file mode 100644 index 0000000..6e8b972 --- /dev/null +++ b/bridge/src/test/scala/protocbridge/frontend/MacPluginFrontendSpec.scala @@ -0,0 +1,15 @@ +package protocbridge.frontend + +class MacPluginFrontendSpec extends OsSpecificFrontendSpec { + if (PluginFrontend.isMac) { + it must "execute a program that forwards input and output to given stream" in { + val state = testSuccess(MacPluginFrontend) + state.serverSocket.isClosed mustBe true + } + + it must "not hang if there is an error in generator" in { + val state = testFailure(MacPluginFrontend) + state.serverSocket.isClosed mustBe true + } + } +} diff --git a/bridge/src/test/scala/protocbridge/frontend/OsSpecificFrontendSpec.scala b/bridge/src/test/scala/protocbridge/frontend/OsSpecificFrontendSpec.scala new file mode 100644 index 0000000..65f84ce --- /dev/null +++ b/bridge/src/test/scala/protocbridge/frontend/OsSpecificFrontendSpec.scala @@ -0,0 +1,113 @@ +package protocbridge.frontend + +import org.apache.commons.io.IOUtils +import org.scalatest.exceptions.TestFailedException +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.must.Matchers +import protocbridge.{ExtraEnv, ProtocCodeGenerator} + +import java.io.ByteArrayOutputStream +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration.DurationInt +import scala.concurrent.{Await, Future, TimeoutException} +import scala.sys.process.ProcessIO +import scala.util.Random + +class OsSpecificFrontendSpec extends AnyFlatSpec with Matchers { + + protected def testPluginFrontend( + frontend: PluginFrontend, + generator: ProtocCodeGenerator, + env: ExtraEnv, + request: Array[Byte] + ): (frontend.InternalState, Array[Byte]) = { + val (path, state) = frontend.prepare( + generator, + env + ) + val actualOutput = new ByteArrayOutputStream() + val process = sys.process + .Process(path.toAbsolutePath.toString) + .run( + new ProcessIO( + writeInput => { + writeInput.write(request) + writeInput.close() + }, + processOutput => { + IOUtils.copy(processOutput, actualOutput) + processOutput.close() + }, + processError => { + IOUtils.copy(processError, System.err) + processError.close() + } + ) + ) + try { + Await.result(Future { process.exitValue() }, 5.seconds) + } catch { + case _: TimeoutException => + System.err.println(s"Timeout") + process.destroy() + } + frontend.cleanup(state) + (state, actualOutput.toByteArray) + } + + protected def testSuccess( + frontend: PluginFrontend + ): frontend.InternalState = { + val random = new Random() + val toSend = Array.fill(123)(random.nextInt(256).toByte) + val toReceive = Array.fill(456)(random.nextInt(256).toByte) + val env = new ExtraEnv(secondaryOutputDir = "tmp") + + val fakeGenerator = new ProtocCodeGenerator { + override def run(request: Array[Byte]): Array[Byte] = { + request mustBe (toSend ++ env.toByteArrayAsField) + toReceive + } + } + // Repeat 100,000 times since named pipes on macOS are flaky. + val repeatCount = 100000 + for (i <- 1 until repeatCount) { + if (i % 100 == 1) println(s"Running iteration $i of $repeatCount") + val (state, response) = + testPluginFrontend(frontend, fakeGenerator, env, toSend) + try { + response mustBe toReceive + } catch { + case e: TestFailedException => + System.err.println(s"""Failed on iteration $i of $repeatCount: ${e.getMessage}""") + } + } + val (state, response) = + testPluginFrontend(frontend, fakeGenerator, env, toSend) + try { + response mustBe toReceive + } catch { + case e: TestFailedException => + System.err.println(s"""Failed on iteration $repeatCount of $repeatCount: ${e.getMessage}""") + } + state + } + + protected def testFailure( + frontend: PluginFrontend + ): frontend.InternalState = { + val random = new Random() + val toSend = Array.fill(123)(random.nextInt(256).toByte) + val env = new ExtraEnv(secondaryOutputDir = "tmp") + + val fakeGenerator = new ProtocCodeGenerator { + override def run(request: Array[Byte]): Array[Byte] = { + throw new OutOfMemoryError("test error") + } + } + val (state, response) = + testPluginFrontend(frontend, fakeGenerator, env, toSend) + response.length must be > 0 + state + } +} diff --git a/bridge/src/test/scala/protocbridge/frontend/PosixPluginFrontendSpec.scala b/bridge/src/test/scala/protocbridge/frontend/PosixPluginFrontendSpec.scala new file mode 100644 index 0000000..1c615d2 --- /dev/null +++ b/bridge/src/test/scala/protocbridge/frontend/PosixPluginFrontendSpec.scala @@ -0,0 +1,13 @@ +package protocbridge.frontend + +class PosixPluginFrontendSpec extends OsSpecificFrontendSpec { + if (!PluginFrontend.isWindows && !PluginFrontend.isMac) { + it must "execute a program that forwards input and output to given stream" in { + testSuccess(MacPluginFrontend) + } + + it must "not hang if there is an OOM in generator" in { + testFailure(MacPluginFrontend) + } + } +} diff --git a/bridge/src/test/scala/protocbridge/frontend/SocketAllocationSpec.scala b/bridge/src/test/scala/protocbridge/frontend/SocketAllocationSpec.scala new file mode 100644 index 0000000..531b44b --- /dev/null +++ b/bridge/src/test/scala/protocbridge/frontend/SocketAllocationSpec.scala @@ -0,0 +1,50 @@ +package protocbridge.frontend +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.must.Matchers + +import java.lang.management.ManagementFactory +import java.net.ServerSocket +import scala.collection.mutable +import scala.sys.process._ +import scala.util.{Failure, Success, Try} + +class SocketAllocationSpec extends AnyFlatSpec with Matchers { + it must "allocate an unused port" in { + val repeatCount = 100000 + + val currentPid = getCurrentPid + val portConflictCount = mutable.Map[Int, Int]() + + for (i <- 1 to repeatCount) { + if (i % 100 == 1) println(s"Running iteration $i of $repeatCount") + + val serverSocket = new ServerSocket(0) // Bind to any available port. + try { + val port = serverSocket.getLocalPort + Try { + s"lsof -i :$port -t".!!.trim + } match { + case Success(output) => + if (output.nonEmpty) { + val pids = output.split("\n").filterNot(_ == currentPid.toString) + if (pids.nonEmpty) { + System.err.println("Port conflict detected on port " + port + " with PIDs: " + pids.mkString(", ")) + portConflictCount(port) = portConflictCount.getOrElse(port, 0) + 1 + } + } + case Failure(_) => // Ignore failure and continue + } + } finally { + serverSocket.close() + } + } + + assert(portConflictCount.isEmpty, s"Found the following ports in use out of $repeatCount: $portConflictCount") + } + + private def getCurrentPid: Int = { + val jvmName = ManagementFactory.getRuntimeMXBean.getName + val pid = jvmName.split("@")(0) + pid.toInt + } +} diff --git a/bridge/src/test/scala/protocbridge/frontend/WindowsPluginFrontendSpec.scala b/bridge/src/test/scala/protocbridge/frontend/WindowsPluginFrontendSpec.scala index 6385ad7..db0bc65 100644 --- a/bridge/src/test/scala/protocbridge/frontend/WindowsPluginFrontendSpec.scala +++ b/bridge/src/test/scala/protocbridge/frontend/WindowsPluginFrontendSpec.scala @@ -1,38 +1,15 @@ package protocbridge.frontend -import java.io.ByteArrayInputStream - -import protocbridge.{ProtocCodeGenerator, ExtraEnv} - -import scala.sys.process.ProcessLogger -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.must.Matchers - -class WindowsPluginFrontendSpec extends AnyFlatSpec with Matchers { +class WindowsPluginFrontendSpec extends OsSpecificFrontendSpec { if (PluginFrontend.isWindows) { it must "execute a program that forwards input and output to given stream" in { - val toSend = "ping" - val toReceive = "pong" - val env = new ExtraEnv(secondaryOutputDir = "tmp") + val state = testSuccess(WindowsPluginFrontend) + state.serverSocket.isClosed mustBe true + } - val fakeGenerator = new ProtocCodeGenerator { - override def run(request: Array[Byte]): Array[Byte] = { - request mustBe (toSend.getBytes ++ env.toByteArrayAsField) - toReceive.getBytes - } - } - val (path, state) = WindowsPluginFrontend.prepare( - fakeGenerator, - env - ) - val actualOutput = scala.collection.mutable.Buffer.empty[String] - val process = sys.process - .Process(path.toAbsolutePath.toString) - .#<(new ByteArrayInputStream(toSend.getBytes)) - .run(ProcessLogger(o => actualOutput.append(o))) - process.exitValue() - actualOutput.mkString mustBe toReceive - WindowsPluginFrontend.cleanup(state) + it must "not hang if there is an OOM in generator" in { + val state = testFailure(WindowsPluginFrontend) + state.serverSocket.isClosed mustBe true } } } diff --git a/port_conflict.py b/port_conflict.py new file mode 100644 index 0000000..d7033c0 --- /dev/null +++ b/port_conflict.py @@ -0,0 +1,68 @@ +import os +import socket +import subprocess +import sys + +def is_port_in_use(port, current_pid): + """ + Check if the given port is in use by other processes, excluding the current process. + + :param port: Port number to check + :param pid: Current process ID to exclude from the result + :return: True if the port is in use by another process, False otherwise + """ + try: + # Run lsof command to check if any process is using the port + result = subprocess.run( + ['lsof', '-i', f':{port}', '-t'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + output = result.stdout.strip() + + if output: + # Check if the output contains lines with processes other than the current one + return [ + line + for line in output.split('\n') + if line != str(current_pid) + ] + return [] + except subprocess.CalledProcessError as e: + print(f"Error checking port: {e}", file=sys.stderr) + return [] + +def main(): + repeat_count = 10000 + + current_pid = os.getpid() # Get the current process ID + port_conflict_count = {} + + for i in range(1, repeat_count + 1): + if i % 100 == 1: + print(f"Running iteration {i} of {repeat_count}") + + # Bind to an available port (port 0) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(('', 0)) # Bind to port 0 to get an available port + port = sock.getsockname()[1] # Get the actual port number assigned + + # Check if the port is in use by any other process + pids = is_port_in_use(port, current_pid) + if pids: + print(f"Port conflict detected on port {port} with PIDs: {', '.join(pids)}", file=sys.stderr) + port_conflict_count[port] = port_conflict_count.get(port, 0) + 1 + + # Close the socket after checking + sock.close() + + if port_conflict_count: + print("Ports that were found to be in use and their collision counts:") + for port, count in port_conflict_count.items(): + print(f"Port {port} was found in use {count} times") + else: + print("No ports were found to be in use.") + +if __name__ == '__main__': + main()