Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add callbacks to pulling end of pipe #428

Merged
merged 6 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/radical/utils/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ def attach_pudb(log=None):

if log:
log.info('debugger open: telnet %s %d', host, port)
else:
print('debugger open: telnet %s %d' % (host, port))

print('debugger open: telnet %s %d' % (host, port))

try:
import pudb # pylint: disable=E0401
Expand All @@ -352,7 +352,9 @@ def attach_pudb(log=None):

except Exception as e:
if log:
log.warning('failed to attach pudb (%s)', e)
log.exception('failed to attach pudb')
else:
print('failed to attach pudb (%s)' % repr(e))


# ------------------------------------------------------------------------------
Expand Down
124 changes: 105 additions & 19 deletions src/radical/utils/zmq/pipe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@

import zmq
import threading as mt

from ..serialize import to_msgpack, from_msgpack
from ..logger import Logger

from .utils import zmq_bind

Expand Down Expand Up @@ -32,7 +34,7 @@ class Pipe(object):

# --------------------------------------------------------------------------
#
def __init__(self, mode, url=None) -> None:
def __init__(self, mode, url=None, log=None) -> None:
'''
Create a `Pipe` instance which can be used for either sending (`put()`)
or receiving (`get()` / `get_nowait()`) data. according to the specified
Expand All @@ -46,8 +48,12 @@ def __init__(self, mode, url=None) -> None:
self._context = zmq.Context.instance()
self._mode = mode
self._url = None
self._log = log
self._sock = None
self._poller = zmq.Poller()
self._cbs = list()
self._thread = None
self._term = mt.Event()

if mode == MODE_PUSH:
self._connect_push(url)
Expand All @@ -58,6 +64,10 @@ def __init__(self, mode, url=None) -> None:
else:
raise ValueError('unsupported pipe mode [%s]' % mode)

if not self._log:

self._log = Logger('radical.utils.pipe')


# --------------------------------------------------------------------------
#
Expand Down Expand Up @@ -97,12 +107,6 @@ def _connect_pull(self, url):
if self._sock:
raise RuntimeError('already connected at %s' % self._url)

if url:
bind = False
else:
bind = True
url = 'tcp://*:*'

self._sock = self._context.socket(zmq.PULL)

if url:
Expand All @@ -114,18 +118,6 @@ def _connect_pull(self, url):
self._poller.register(self._sock, zmq.POLLIN)


# --------------------------------------------------------------------------
#
def put(self, msg):
'''
Send a message - if receiving endpoints are connected, exactly one of
them will be able to receive that message.
'''

assert self._mode == MODE_PUSH
self._sock.send(to_msgpack(msg))


# --------------------------------------------------------------------------
#
def get(self):
Expand All @@ -134,6 +126,8 @@ def get(self):
'''

assert self._mode == MODE_PULL
assert not self._cbs

return from_msgpack(self._sock.recv())


Expand All @@ -147,6 +141,7 @@ def get_nowait(self, timeout: float = 0):
'''

assert self._mode == MODE_PULL
assert not self._cbs

# zmq timeouts are in milliseconds
socks = dict(self._poller.poll(timeout=int(timeout * 1000)))
Expand All @@ -155,5 +150,96 @@ def get_nowait(self, timeout: float = 0):
return from_msgpack(self._sock.recv())


# --------------------------------------------------------------------------
#
def _listener(self):
'''
Listen for incoming messages, and call registered callbacks.
'''

while not self._term.is_set():

socks = dict(self._poller.poll(timeout=10))

if self._sock in socks:
msg = from_msgpack(self._sock.recv())

for cb in self._cbs:
try:
cb(msg)
except:
self._log.exception('callback failed')


# --------------------------------------------------------------------------
#
def register_cb(self, cb):
'''
Register a callback for incoming messages. The callback will be called
with the message as argument.

Only a pipe in pull mode can have callbacks registered. Note that once
a callback is registered, the `get()` and `get_nowait()` methods must
not be used anymore.
'''

assert self._mode == MODE_PULL

self._cbs.append(cb)

if not self._thread:
self._thread = mt.Thread(target=self._listener)
self._thread.daemon = True
self._thread.start()


# --------------------------------------------------------------------------
#
def unregister_cb(self, cb):
'''
Unregister a callback. If no callback remains registered, the listener
thread will be stopped.
'''

assert self._mode == MODE_PULL
assert cb in self._cbs

self._cbs.remove(cb)

if not self._cbs:
self._stop_listener()


# --------------------------------------------------------------------------
#
def put(self, msg):
'''
Send a message - if receiving endpoints are connected, exactly one of
them will be able to receive that message.
'''

assert self._mode == MODE_PUSH
self._sock.send(to_msgpack(msg))


# --------------------------------------------------------------------------
#
def _stop_listener(self):

if self._thread:
self._term.set()
self._thread.join()
self._term.clear()
self._thread = None


# --------------------------------------------------------------------------
#
def stop(self):

self._stop_listener()



# ------------------------------------------------------------------------------

43 changes: 21 additions & 22 deletions src/radical/utils/zmq/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
from ..atfork import atfork
from ..config import Config
from ..ids import generate_id, ID_CUSTOM
from ..url import Url
from ..misc import as_string, as_bytes, as_list, noop
from ..host import get_hostip
from ..logger import Logger
from ..profile import Profiler
from ..serialize import to_msgpack, from_msgpack
Expand Down Expand Up @@ -230,8 +228,8 @@ def put(self, topic, msg):

assert isinstance(topic, str), 'invalid topic type'

self._log.debug_9('=== put %s : %s: %s', topic, self.channel, msg)
# self._log.debug_9('=== put %s: %s', msg, get_stacktrace())
self._log.debug_9('put %s : %s: %s', topic, self.channel, msg)
# self._log.debug_9('put %s: %s', msg, get_stacktrace())
# self._prof.prof('put', uid=self._uid, msg=msg)
log_bulk(self._log, '-> %s' % topic, [msg])

Expand Down Expand Up @@ -397,22 +395,22 @@ def channel(self):
def _start_listener(self):

# only start if needed
if self._thread:
return
with self._lock:

lock = self._lock
term = self._term
callbacks = self._callbacks
if self._thread:
return

self._log.info('start listener for %s', self._channel)
lock = self._lock
term = self._term
callbacks = self._callbacks

t = mt.Thread(target=Subscriber._listener,
args=[self._sock, lock, term, callbacks,
self._log, self._prof])
t.daemon = True
t.start()
self._log.info('start listener for %s', self._channel)

self._thread = t
self._thread = mt.Thread(target=Subscriber._listener,
args=[self._sock, lock, term, callbacks,
self._log, self._prof])
self._thread.daemon = True
self._thread.start()


# --------------------------------------------------------------------------
Expand All @@ -421,11 +419,12 @@ def _stop_listener(self, force=False):

# only stop listener if no callbacks remain registered (unless forced)
if force or not self._callbacks:
if self._thread:
self._term.set()
self._thread.join()
self._term.clear()
self._thread = None
with self._lock:
if self._thread:
self._term.set()
self._thread.join()
self._term.clear()
self._thread = None


# --------------------------------------------------------------------------
Expand All @@ -451,7 +450,7 @@ def subscribe(self, topic, cb=None, lock=None):
log_bulk(self._log, '~~2 %s' % topic, [topic])

with self._lock:
self._log.debug_9('==== subscribe for %s', topic)
self._log.debug_9('subscribe for %s', topic)
no_intr(self._sock.setsockopt, zmq.SUBSCRIBE, as_bytes(topic))

if topic not in self._topics:
Expand Down
10 changes: 4 additions & 6 deletions src/radical/utils/zmq/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
from ..atfork import atfork
from ..config import Config
from ..ids import generate_id, ID_CUSTOM
from ..url import Url
from ..misc import as_string, as_bytes, as_list, noop, find_port
from ..host import get_hostip
from ..misc import as_string, as_bytes, as_list, noop
from ..logger import Logger
from ..profile import Profiler
from ..debug import print_exception_trace
Expand Down Expand Up @@ -405,7 +403,7 @@ def _get_nowait(url, qname=None, timeout=None, uid=None): # timeout in ms

# send the request *once* per recieval (got lock above)
# FIXME: why is this sent repeatedly?
logger.debug_9('=== => from %s[%s]', uid, qname)
logger.debug_9('=> from %s[%s]', uid, qname)
no_intr(info['socket'].send, as_bytes(qname))
info['requested'] = True

Expand Down Expand Up @@ -674,7 +672,7 @@ def get(self, qname=None):
if not self._requested:
with self._lock:
if not self._requested:
self._log.debug_9('=== => from %s[%s]', self._channel, qname)
self._log.debug_9('=> from %s[%s]', self._channel, qname)
no_intr(self._q.send, as_bytes(qname))
self._requested = True

Expand Down Expand Up @@ -710,7 +708,7 @@ def get_nowait(self, qname=None, timeout=None): # timeout in ms
if not self._requested:
with self._lock: # need to protect self._requested
if not self._requested:
self._log.debug_9('=== => from %s[%s]', self._channel, qname)
self._log.debug_9('=> from %s[%s]', self._channel, qname)
no_intr(self._q.send_multipart, [as_bytes(qname)])
self._requested = True

Expand Down
2 changes: 1 addition & 1 deletion src/radical/utils/zmq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_channel_url(ep_type, channel=None, url=None):
#
def log_bulk(log, token, msgs):

if log._num_level > 1:
if log.num_level > 1:
# log level `debug_9` disabled
return

Expand Down
Loading
Loading