Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add callbacks to pulling end of pipe #428

Merged
merged 6 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 61 additions & 12 deletions src/radical/utils/zmq/pipe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@

import zmq
import threading as mt

from ..serialize import to_msgpack, from_msgpack
from ..logger import Logger

from .utils import zmq_bind

Expand Down Expand Up @@ -32,7 +34,7 @@ class Pipe(object):

# --------------------------------------------------------------------------
#
def __init__(self, mode, url=None) -> None:
def __init__(self, mode, url=None, log=None) -> None:
'''
Create a `Pipe` instance which can be used for either sending (`put()`)
or receiving (`get()` / `get_nowait()`) data. according to the specified
Expand All @@ -43,11 +45,14 @@ def __init__(self, mode, url=None) -> None:
URL provided by the listening end (`Pipe.url`).
'''

self._context = zmq.Context.instance()
self._mode = mode
self._url = None
self._sock = None
self._poller = zmq.Poller()
self._context = zmq.Context.instance()
self._mode = mode
self._url = None
self._log = log
self._sock = None
self._poller = zmq.Poller()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the poller: For the PUSH mode, we don't need a Poller since we are only sending messages (and the socket doesn't need to wait for incoming messages), correct?

So maybe something like this:

    if mode == MODE_PUSH:
        self._connect_push(url)
        self._poller = None 

    elif mode == MODE_PULL:
        self._connect_pull(url)
        self._poller = zmq.Poller() 

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andre-merzky, did you see this comment? I am unsure if you agree on this (it is a minor change). Besides this, the PR looks good to me.

self._cbs = list()
self._listener = None

if mode == MODE_PUSH:
self._connect_push(url)
Expand All @@ -58,6 +63,10 @@ def __init__(self, mode, url=None) -> None:
else:
raise ValueError('unsupported pipe mode [%s]' % mode)

if not self._log:

self._log = Logger('radical.utils.pipe')


# --------------------------------------------------------------------------
#
Expand Down Expand Up @@ -97,12 +106,6 @@ def _connect_pull(self, url):
if self._sock:
raise RuntimeError('already connected at %s' % self._url)

if url:
bind = False
else:
bind = True
url = 'tcp://*:*'

self._sock = self._context.socket(zmq.PULL)

if url:
Expand All @@ -114,6 +117,49 @@ def _connect_pull(self, url):
self._poller.register(self._sock, zmq.POLLIN)


# --------------------------------------------------------------------------
#
def register_cb(self, cb):
'''
Register a callback for incoming messages. The callback will be called
with the message as argument.

Only a pipe in pull mode can have callbacks registered. Note that once
a callback is registered, the `get()` and `get_nowait()` methods must
not be used anymore.
'''

assert self._mode == MODE_PULL

self._cbs.append(cb)

if not self._listener:
self._listener = mt.Thread(target=self._listen)
self._listener.daemon = True
self._listener.start()


# --------------------------------------------------------------------------
#
def _listen(self):
'''
Listen for incoming messages, and call registered callbacks.
'''

while True:

socks = dict(self._poller.poll(timeout=10))

if self._sock in socks:
msg = from_msgpack(self._sock.recv())

for cb in self._cbs:
try:
cb(msg)
except:
self._log.exception('callback failed')


# --------------------------------------------------------------------------
#
def put(self, msg):
Expand All @@ -134,6 +180,8 @@ def get(self):
'''

assert self._mode == MODE_PULL
assert not self._cbs

return from_msgpack(self._sock.recv())


Expand All @@ -147,6 +195,7 @@ def get_nowait(self, timeout: float = 0):
'''

assert self._mode == MODE_PULL
assert not self._cbs

# zmq timeouts are in milliseconds
socks = dict(self._poller.poll(timeout=int(timeout * 1000)))
Expand Down
43 changes: 32 additions & 11 deletions tests/unittests/test_zmq_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,10 @@
def test_zmq_pipe():

pipe_1 = ru.zmq.Pipe(ru.zmq.MODE_PUSH)
pipe_2 = ru.zmq.Pipe(ru.zmq.MODE_PULL, pipe_1.url)
pipe_3 = ru.zmq.Pipe(ru.zmq.MODE_PULL, pipe_1.url)

url = pipe_1.url

pipe_2 = ru.zmq.Pipe(ru.zmq.MODE_PULL, url)
pipe_3 = ru.zmq.Pipe(ru.zmq.MODE_PULL, url)

# let ZMQ settle
time.sleep(0.1)
time.sleep(0.01)

for i in range(1000):
pipe_1.put('foo %d' % i)
Expand All @@ -34,24 +30,49 @@ def test_zmq_pipe():
result_3.append(pipe_3.get())

for i in range(100):
result_2.append(pipe_2.get_nowait(timeout=1.0))
result_3.append(pipe_3.get_nowait(timeout=1.0))
result_2.append(pipe_2.get_nowait(timeout=0.01))
result_3.append(pipe_3.get_nowait(timeout=0.01))

assert len(result_2) == 500
assert len(result_3) == 500

test_2 = result_2.append(pipe_2.get_nowait(timeout=1.0))
test_3 = result_3.append(pipe_3.get_nowait(timeout=1.0))
test_2 = result_2.append(pipe_2.get_nowait(timeout=0.01))
test_3 = result_3.append(pipe_3.get_nowait(timeout=0.01))

assert test_2 is None
assert test_3 is None


# ------------------------------------------------------------------------------
#
def test_zmq_pipe_cb():

pipe_1 = ru.zmq.Pipe(ru.zmq.MODE_PUSH)
pipe_2 = ru.zmq.Pipe(ru.zmq.MODE_PULL, pipe_1.url)
results = list()

time.sleep(0.01)

def cb(msg):
results.append(msg)

pipe_2.register_cb(cb)

n = 1000
for i in range(n):
pipe_1.put('foo %d' % i)

time.sleep(0.01)

assert len(results) == n, results


# ------------------------------------------------------------------------------
# run tests if called directly
if __name__ == '__main__':

test_zmq_pipe()
test_zmq_pipe_cb()


# ------------------------------------------------------------------------------
Expand Down