From 0f0485281ebc47f9835d5720c18f806f0f7b9cb4 Mon Sep 17 00:00:00 2001 From: Andre Merzky Date: Sun, 24 Nov 2024 22:58:48 +0100 Subject: [PATCH] get better port numbers to simplify tunneling --- src/radical/utils/heartbeat.py | 2 +- src/radical/utils/misc.py | 30 +++++++++++++++++++++++++----- src/radical/utils/zmq/pipe.py | 23 ++++++++++++----------- src/radical/utils/zmq/pubsub.py | 22 +++++----------------- src/radical/utils/zmq/queue.py | 32 ++++++++++---------------------- src/radical/utils/zmq/utils.py | 23 +++++++++++++++++++---- tests/unittests/test_misc.py | 28 ++++++++++++++++++++++++++++ tests/unittests/test_zmq_pipe.py | 2 +- 8 files changed, 101 insertions(+), 61 deletions(-) diff --git a/src/radical/utils/heartbeat.py b/src/radical/utils/heartbeat.py index 0496fbb74..34c3cf901 100644 --- a/src/radical/utils/heartbeat.py +++ b/src/radical/utils/heartbeat.py @@ -171,7 +171,7 @@ def _watch(self): self._log.warn('hb %s fail %s: fatal (%d)', self._uid, uid, self._pid) os.kill(self._pid, signal.SIGTERM) - time.sleep(1) + time.sleep(0.1) os.kill(self._pid, signal.SIGKILL) else: diff --git a/src/radical/utils/misc.py b/src/radical/utils/misc.py index 23b88034f..10a1705bb 100644 --- a/src/radical/utils/misc.py +++ b/src/radical/utils/misc.py @@ -3,6 +3,7 @@ import sys import time import errno +import socket import tarfile import datetime import tempfile @@ -381,12 +382,12 @@ def get_env_ns(key, ns, default=None): ''' get an environment setting within a namespace. For example. - get_env_ns('verbose', 'radical.pilot.umgr'), + get_env_ns('verbose', 'radical.pilot.tmgr'), will return the value of the first found env variable from the following sequence: - RADICAL_PILOT_UMGR_LOG_LVL + RADICAL_PILOT_TMGR_LOG_LVL RADICAL_PILOT_LOG_LVL RADICAL_LOG_LVL @@ -705,7 +706,7 @@ def script_2_func(fpath): This method accepts a single parameter `fpath` which is expected to point to a file containing a self-sufficient Python script. The script will be read and stored, and a function handle will be returned which, upon calling, will - run that script in the currect Python interpreter`. It will be ensured that + run that script in the current Python interpreter. It will be ensured that `__name__` is set to `__main__`, and that any arguments passed to the callable are passed on as `sys.argv`. A single list argument is also allowed which is interpreted as argument list. @@ -714,7 +715,6 @@ def script_2_func(fpath): my_func = ru.script_2_func('/tmp/my_script.py') my_func('-f', 'foo', '-b', 'bar') - my_func('-f foo -b bar'.split()) # equivalent NOTE: calling the returned function handle will change `sys.argv` for the current Python interpreter. @@ -723,7 +723,6 @@ def script_2_func(fpath): prefix = [] postfix = [] - with ru_open(fpath, 'r') as fin: code_lines = fin.readlines() @@ -829,5 +828,26 @@ def ru_open(*args, **kwargs): return open(*args, **kwargs) +# ------------------------------------------------------------------------------ +# +def find_port(port_min=10000, port_max=65535): + ''' + Find a free port in the given range. The range defaults to 10000-65535. + Returns `None` if no free port could be found. + ''' + + for port in range(port_min, port_max): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + sock.bind(('', port)) + return port + + except socket.error: + pass + + finally: + sock.close() + + # ------------------------------------------------------------------------------ diff --git a/src/radical/utils/zmq/pipe.py b/src/radical/utils/zmq/pipe.py index ec747fe33..7c36dcf1b 100644 --- a/src/radical/utils/zmq/pipe.py +++ b/src/radical/utils/zmq/pipe.py @@ -3,6 +3,8 @@ from ..serialize import to_msgpack, from_msgpack +from .utils import zmq_bind + MODE_PUSH = 'push' MODE_PULL = 'pull' @@ -74,16 +76,13 @@ def _connect_push(self, url): if self._sock: raise RuntimeError('already connected at %s' % self._url) - if url: - bind = False - else: - bind = True - url = 'tcp://*:*' - self._sock = self._context.socket(zmq.PUSH) - if bind: self._sock.bind(url) - else : self._sock.connect(url) + if url: + self._sock.connect(url) + self._url = url + else: + self._url = zmq_bind(self._sock) self._url = self._sock.getsockopt(zmq.LAST_ENDPOINT) @@ -106,10 +105,12 @@ def _connect_pull(self, url): self._sock = self._context.socket(zmq.PULL) - if bind: self._sock.bind(url) - else : self._sock.connect(url) + if url: + self._sock.connect(url) + self._url = url + else: + self._url = zmq_bind(self._sock) - self._url = self._sock.getsockopt(zmq.LAST_ENDPOINT) self._poller.register(self._sock, zmq.POLLIN) diff --git a/src/radical/utils/zmq/pubsub.py b/src/radical/utils/zmq/pubsub.py index b625761d9..606a935fa 100644 --- a/src/radical/utils/zmq/pubsub.py +++ b/src/radical/utils/zmq/pubsub.py @@ -18,7 +18,7 @@ from ..serialize import to_msgpack, from_msgpack from .bridge import Bridge -from .utils import no_intr, log_bulk, LOG_ENABLED +from .utils import zmq_bind, no_intr, log_bulk, LOG_ENABLED # ------------------------------------------------------------------------------ @@ -94,30 +94,18 @@ def _bridge_initialize(self): self._log.info('initialize bridge %s', self._uid) - self._lock = mt.Lock() + self._lock = mt.Lock() - self._ctx = zmq.Context.instance() # rely on GC for destruction + self._ctx = zmq.Context.instance() # rely on GC for destruction self._xpub = self._ctx.socket(zmq.XSUB) self._xpub.linger = _LINGER_TIMEOUT self._xpub.hwm = _HIGH_WATER_MARK - self._xpub.bind('tcp://*:%s' % (self._cfg.get('port_pub') or '10000+')) + self._addr_pub = zmq_bind(self._xpub) self._xsub = self._ctx.socket(zmq.XPUB) self._xsub.linger = _LINGER_TIMEOUT self._xsub.hwm = _HIGH_WATER_MARK - self._xsub.bind('tcp://*:%s' % (self._cfg.get('port_sub') or '10000+')) - - # communicate the bridge ports to the parent process - _addr_pub = as_string(self._xpub.getsockopt(zmq.LAST_ENDPOINT)) - _addr_sub = as_string(self._xsub.getsockopt(zmq.LAST_ENDPOINT)) - - # store addresses - self._addr_pub = Url(_addr_pub) - self._addr_sub = Url(_addr_sub) - - # use the local hostip for bridge addresses - self._addr_pub.host = get_hostip() - self._addr_sub.host = get_hostip() + self._addr_sub = zmq_bind(self._xsub) self._log.info('bridge pub on %s: %s', self._uid, self._addr_pub) self._log.info(' sub on %s: %s', self._uid, self._addr_sub) diff --git a/src/radical/utils/zmq/queue.py b/src/radical/utils/zmq/queue.py index 302b28857..c519050c7 100644 --- a/src/radical/utils/zmq/queue.py +++ b/src/radical/utils/zmq/queue.py @@ -12,7 +12,7 @@ from ..config import Config from ..ids import generate_id, ID_CUSTOM from ..url import Url -from ..misc import as_string, as_bytes, as_list, noop +from ..misc import as_string, as_bytes, as_list, noop, find_port from ..host import get_hostip from ..logger import Logger from ..profile import Profiler @@ -20,7 +20,7 @@ from ..serialize import to_msgpack, from_msgpack from .bridge import Bridge -from .utils import no_intr +from .utils import zmq_bind, no_intr from .utils import log_bulk, LOG_ENABLED # from .utils import prof_bulk @@ -165,30 +165,18 @@ def _bridge_initialize(self): self._log.info('start bridge %s', self._uid) - self._lock = mt.Lock() + self._lock = mt.Lock() - self._ctx = zmq.Context() # rely on GC for destruction - self._put = self._ctx.socket(zmq.PULL) - self._put.linger = _LINGER_TIMEOUT - self._put.hwm = _HIGH_WATER_MARK - self._get.bind('tcp://*:%s' % (self._cfg.get('port_put') or find_port()) + self._ctx = zmq.Context() # rely on GC for destruction + self._put = self._ctx.socket(zmq.PULL) + self._put.linger = _LINGER_TIMEOUT + self._put.hwm = _HIGH_WATER_MARK + self._addr_put = zmq_bind(self._put) self._get = self._ctx.socket(zmq.REP) self._get.linger = _LINGER_TIMEOUT self._get.hwm = _HIGH_WATER_MARK - self._get.bind('tcp://*:%s' % (self._cfg.get('port_get') or find_port()) - - # communicate the bridge ports to the parent process - _addr_put = as_string(self._put.getsockopt(zmq.LAST_ENDPOINT)) - _addr_get = as_string(self._get.getsockopt(zmq.LAST_ENDPOINT)) - - # store addresses - self._addr_put = Url(_addr_put) - self._addr_get = Url(_addr_get) - - # use the local hostip for bridge addresses - self._addr_put.host = get_hostip() - self._addr_get.host = get_hostip() + self._addr_get = zmq_bind(self._get) self._log.info('bridge in %s: %s', self._uid, self._addr_put) self._log.info('bridge out %s: %s', self._uid, self._addr_get) @@ -450,7 +438,7 @@ def _listener(url, qname=None, uid=None): qname = 'default' assert url in Getter._callbacks - time.sleep(1) + time.sleep(0.1) try: term = Getter._callbacks.get(url, {}).get('term') diff --git a/src/radical/utils/zmq/utils.py b/src/radical/utils/zmq/utils.py index 19ec67ec8..75752b885 100644 --- a/src/radical/utils/zmq/utils.py +++ b/src/radical/utils/zmq/utils.py @@ -4,8 +4,8 @@ import errno from ..url import Url -from ..misc import as_list -from ..misc import ru_open +from ..host import get_hostip +from ..misc import as_list, as_string, find_port, ru_open # NOTE: this is ignoring `RADICAL_LOG_LVL` on purpose @@ -154,13 +154,28 @@ def sock_connect(sock, url, hop=None): if hop: from zmq import ssh - print('connect to %s via %s' % (url, hop)) ssh.tunnel_connection(sock, url, hop) - print('connected to %s via %s' % (url, hop)) else: sock.connect(url) +# ------------------------------------------------------------------------------ +# +def zmq_bind(sock): + + while True: + port = find_port() + try: + sock.bind('tcp://*:%s' % port) + addr = Url(as_string(sock.getsockopt(zmq.LAST_ENDPOINT))) + addr.host = get_hostip() + return addr + except: + pass + + raise RuntimeError('could not bind to any port') + + # ------------------------------------------------------------------------------ diff --git a/tests/unittests/test_misc.py b/tests/unittests/test_misc.py index 300317319..9e1b0a6ed 100755 --- a/tests/unittests/test_misc.py +++ b/tests/unittests/test_misc.py @@ -8,6 +8,7 @@ import os import copy import pytest +import socket import tempfile import radical.utils as ru @@ -295,6 +296,32 @@ def test_ru_open(): except: pass +# ------------------------------------------------------------------------------ +# +def test_find_port(): + + + s1 = None + s2 = None + try: + p1 = ru.find_port() + assert p1 > 0 + + s1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s1.bind(('', p1)) + + p2 = ru.find_port() + assert p2 > p1 + + with pytest.raises(socket.error): + s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s2.bind(('', p2 - 1)) + + finally: + if s1: s1.close() + if s2: s2.close() + + # ------------------------------------------------------------------------------ # run tests if called directly if __name__ == "__main__": @@ -309,6 +336,7 @@ def test_ru_open(): test_script_2_func() test_base() test_ru_open() + test_find_port() # ------------------------------------------------------------------------------ diff --git a/tests/unittests/test_zmq_pipe.py b/tests/unittests/test_zmq_pipe.py index 863651cc6..a4f185204 100755 --- a/tests/unittests/test_zmq_pipe.py +++ b/tests/unittests/test_zmq_pipe.py @@ -21,7 +21,7 @@ def test_zmq_pipe(): pipe_3 = ru.zmq.Pipe(ru.zmq.MODE_PULL, url) # let ZMQ settle - time.sleep(1) + time.sleep(0.1) for i in range(1000): pipe_1.put('foo %d' % i)