From 2a57d0ca447a79a0fdf53834a6d54514f9680dc8 Mon Sep 17 00:00:00 2001 From: Bell Le Date: Fri, 13 Sep 2024 13:23:32 -0700 Subject: [PATCH] Add SocketAllocationSpec --- .../frontend/SocketAllocationSpec.scala | 50 ++++++++++++++ port_conflict.py | 68 +++++++++++++++++++ 2 files changed, 118 insertions(+) create mode 100644 bridge/src/test/scala/protocbridge/frontend/SocketAllocationSpec.scala create mode 100644 port_conflict.py diff --git a/bridge/src/test/scala/protocbridge/frontend/SocketAllocationSpec.scala b/bridge/src/test/scala/protocbridge/frontend/SocketAllocationSpec.scala new file mode 100644 index 0000000..531b44b --- /dev/null +++ b/bridge/src/test/scala/protocbridge/frontend/SocketAllocationSpec.scala @@ -0,0 +1,50 @@ +package protocbridge.frontend +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.must.Matchers + +import java.lang.management.ManagementFactory +import java.net.ServerSocket +import scala.collection.mutable +import scala.sys.process._ +import scala.util.{Failure, Success, Try} + +class SocketAllocationSpec extends AnyFlatSpec with Matchers { + it must "allocate an unused port" in { + val repeatCount = 100000 + + val currentPid = getCurrentPid + val portConflictCount = mutable.Map[Int, Int]() + + for (i <- 1 to repeatCount) { + if (i % 100 == 1) println(s"Running iteration $i of $repeatCount") + + val serverSocket = new ServerSocket(0) // Bind to any available port. + try { + val port = serverSocket.getLocalPort + Try { + s"lsof -i :$port -t".!!.trim + } match { + case Success(output) => + if (output.nonEmpty) { + val pids = output.split("\n").filterNot(_ == currentPid.toString) + if (pids.nonEmpty) { + System.err.println("Port conflict detected on port " + port + " with PIDs: " + pids.mkString(", ")) + portConflictCount(port) = portConflictCount.getOrElse(port, 0) + 1 + } + } + case Failure(_) => // Ignore failure and continue + } + } finally { + serverSocket.close() + } + } + + assert(portConflictCount.isEmpty, s"Found the following ports in use out of $repeatCount: $portConflictCount") + } + + private def getCurrentPid: Int = { + val jvmName = ManagementFactory.getRuntimeMXBean.getName + val pid = jvmName.split("@")(0) + pid.toInt + } +} diff --git a/port_conflict.py b/port_conflict.py new file mode 100644 index 0000000..d7033c0 --- /dev/null +++ b/port_conflict.py @@ -0,0 +1,68 @@ +import os +import socket +import subprocess +import sys + +def is_port_in_use(port, current_pid): + """ + Check if the given port is in use by other processes, excluding the current process. + + :param port: Port number to check + :param pid: Current process ID to exclude from the result + :return: True if the port is in use by another process, False otherwise + """ + try: + # Run lsof command to check if any process is using the port + result = subprocess.run( + ['lsof', '-i', f':{port}', '-t'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + output = result.stdout.strip() + + if output: + # Check if the output contains lines with processes other than the current one + return [ + line + for line in output.split('\n') + if line != str(current_pid) + ] + return [] + except subprocess.CalledProcessError as e: + print(f"Error checking port: {e}", file=sys.stderr) + return [] + +def main(): + repeat_count = 10000 + + current_pid = os.getpid() # Get the current process ID + port_conflict_count = {} + + for i in range(1, repeat_count + 1): + if i % 100 == 1: + print(f"Running iteration {i} of {repeat_count}") + + # Bind to an available port (port 0) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(('', 0)) # Bind to port 0 to get an available port + port = sock.getsockname()[1] # Get the actual port number assigned + + # Check if the port is in use by any other process + pids = is_port_in_use(port, current_pid) + if pids: + print(f"Port conflict detected on port {port} with PIDs: {', '.join(pids)}", file=sys.stderr) + port_conflict_count[port] = port_conflict_count.get(port, 0) + 1 + + # Close the socket after checking + sock.close() + + if port_conflict_count: + print("Ports that were found to be in use and their collision counts:") + for port, count in port_conflict_count.items(): + print(f"Port {port} was found in use {count} times") + else: + print("No ports were found to be in use.") + +if __name__ == '__main__': + main()