From e5bc45bf2ed0a1d4e907e19ef7e86b6698549f34 Mon Sep 17 00:00:00 2001 From: James Souter Date: Wed, 23 Oct 2024 08:34:19 +0100 Subject: [PATCH 1/2] add run arg to signal subscribe methods to drop callbacks in Signal._notify --- src/ophyd_async/core/_signal.py | 18 +++++++++++++----- tests/core/test_signal.py | 27 +++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index d4e4d7bbb9..7d408a834b 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -123,6 +123,7 @@ def __init__(self, backend: SignalBackend[T], signal: Signal): self._valid = asyncio.Event() self._reading: Reading | None = None self._value: T | None = None + self._drop_next_notification: set[Callback] = set() self.backend = backend signal.log.debug(f"Making subscription on source {signal.source}") @@ -154,13 +155,18 @@ def _callback(self, reading: Reading, value: T): self._notify(function, want_value) def _notify(self, function: Callback, want_value: bool): + if function in self._drop_next_notification: + self._drop_next_notification.discard(function) + return if want_value: function(self._value) else: function({self._signal.name: self._reading}) - def subscribe(self, function: Callback, want_value: bool) -> None: + def subscribe(self, function: Callback, run: bool, want_value: bool) -> None: self._listeners[function] = want_value + if not run: + self._drop_next_notification.add(function) if self._valid.is_set(): self._notify(function, want_value) @@ -215,13 +221,15 @@ async def get_value(self, cached: bool | None = None) -> T: self.log.debug(f"get_value() on source {self.source} returned {value}") return value - def subscribe_value(self, function: Callback[T]): + def subscribe_value(self, function: Callback[T], run: bool = True): """Subscribe to updates in value of a device""" - self._get_cache().subscribe(function, want_value=True) + self._get_cache().subscribe(function, run=run, want_value=True) - def subscribe(self, function: Callback[dict[str, Reading]]) -> None: + def subscribe( + self, function: Callback[dict[str, Reading]], run: bool = True + ) -> None: """Subscribe to updates in the reading""" - self._get_cache().subscribe(function, want_value=False) + self._get_cache().subscribe(function, run=run, want_value=False) def clear_sub(self, function: Callback) -> None: """Remove a subscription.""" diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 133cfd5ee4..242d630b72 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -407,6 +407,33 @@ async def test_subscription_logs(caplog): assert "Closing subscription on source" in caplog.text +async def test_signal_subscription_run_false(): + mock_signal_no_notify = epics_signal_rw( + int, "mock_signal_no_notify", name="mock_signal_no_notify" + ) + mock_signal_notify = epics_signal_rw( + int, "mock_signal_notify", name="mock_signal_notify" + ) + await mock_signal_no_notify.connect(mock=True) + await mock_signal_notify.connect(mock=True) + + callbacks_called = set() + + def __my_callback(reading): + for name in reading.keys(): + callbacks_called.add(name) + + mock_signal_no_notify.subscribe(__my_callback, run=False) + mock_signal_notify.subscribe(__my_callback, run=True) + mock_signal_no_notify.set(1) + mock_signal_notify.set(1) + assert "mock_signal_no_notify" not in callbacks_called + assert "mock_signal_notify" in callbacks_called + # run=False callback gets delayed until next update of the value by the backend + set_mock_value(mock_signal_no_notify, 2) + assert "mock_signal_no_notify" in callbacks_called + + async def test_signal_unknown_datatype(): class SomeClass: def __init__(self): From 6b3b1e40e9e4b2a4a6c9a9d2416db0d19f2675ea Mon Sep 17 00:00:00 2001 From: James Souter Date: Wed, 23 Oct 2024 07:40:48 +0000 Subject: [PATCH 2/2] Add test for bluesky suspenders --- tests/core/test_signal.py | 66 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 242d630b72..1502d58a4d 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -6,7 +6,18 @@ import numpy import pytest +from bluesky import Msg from bluesky.protocols import Reading +from bluesky.run_engine import call_in_bluesky_event_loop +from bluesky.suspenders import ( + SuspendBoolHigh, + SuspendBoolLow, + SuspendCeil, + SuspendFloor, + SuspendInBand, + SuspendOutBand, + SuspendWhenOutsideBand, +) from ophyd_async.core import ( DEFAULT_TIMEOUT, @@ -457,3 +468,58 @@ def some_function(self): assert isinstance((await signal.get_value()), SomeClass) await signal.set(1) assert (await signal.get_value()) == 1 + + +@pytest.mark.parametrize( + "klass,sc_args,start_val,fail_val,resume_val,wait_time", + [ + (SuspendBoolHigh, (), 0, 1, 0, 0.2), + (SuspendBoolLow, (), 1, 0, 1, 0.2), + (SuspendFloor, (0.5,), 1, 0, 1, 0.2), + (SuspendCeil, (0.5,), 0, 1, 0, 0.2), + (SuspendWhenOutsideBand, (0.5, 1.5), 1, 0, 1, 0.2), + ((SuspendInBand, True), (0.5, 1.5), 1, 0, 1, 0.2), # renamed to WhenOutsideBand + ((SuspendOutBand, True), (0.5, 1.5), 0, 1, 0, 0.2), + ], +) # deprecated +async def test_bluesky_suspenders( + klass, sc_args, start_val, fail_val, resume_val, wait_time, RE +): + sleep_time = 0.2 + fail_time = 0.1 + resume_time = 0.5 + signal = epics_signal_rw(int, "mock_signal") + await signal.connect(mock=True) + try: + klass, deprecated = klass + except TypeError: + deprecated = False + if deprecated: + with pytest.warns(UserWarning): + suspender = klass(signal, *sc_args, sleep=wait_time) + else: + suspender = klass(signal, *sc_args, sleep=wait_time) + + RE.install_suspender(suspender) + + await signal.set(start_val) + + async def _set_after_time(): + await asyncio.sleep(fail_time) + await signal.set(fail_val) + await asyncio.sleep(resume_time - fail_time) + await signal.set(resume_val) + + start = time.time() + + # loop = RE.loop + + call_in_bluesky_event_loop(_set_after_time()) + # task = RE.loop.create_task(_set_after_time()) + + RE([Msg("checkpoint"), Msg("sleep", None, sleep_time)]) + + stop = time.time() + delta = stop - start + # await task + assert delta >= resume_time + sleep_time + wait_time