diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 3801245f5c..8cf282f463 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -467,6 +467,14 @@ async def observe_value( yield value +def _get_iteration_timeout( + timeout: float | None, overall_deadline: float | None +) -> float | None: + # Test this works if overall timeout <=0 + overall_deadline = overall_deadline - time.monotonic() if overall_deadline else None + return min([x for x in [overall_deadline, timeout] if x is not None], default=None) + + async def observe_signals_value( *signals: SignalR[SignalDatatypeT], timeout: float | None = None, @@ -504,12 +512,11 @@ async def observe_signals_value( q: asyncio.Queue[tuple[SignalR[SignalDatatypeT], SignalDatatypeT] | Status] = ( asyncio.Queue() ) - if timeout is None: - get_value = q.get - else: - async def get_value(): - return await asyncio.wait_for(q.get(), timeout) + async def get_value(timeout: float | None = None): + if timeout is None: + return await q.get() + return await asyncio.wait_for(q.get(), timeout) cbs: dict[SignalR, Callback] = {} for signal in signals: @@ -522,17 +529,19 @@ def queue_value(value: SignalDatatypeT, signal=signal): if done_status is not None: done_status.add_callback(q.put_nowait) - start_time = time.time() + overall_deadline = ( + time.monotonic() + except_after_time if except_after_time else None + ) try: while True: - if except_after_time and time.time() - start_time > except_after_time: + if overall_deadline and time.monotonic() >= overall_deadline: raise asyncio.TimeoutError( f"observe_value was still observing signals " f"{[signal.source for signal in signals]} after " f"timeout {except_after_time}s" ) - item = await asyncio.wait_for(q.get(), timeout) - item = await get_value() + + item = await get_value(_get_iteration_timeout(timeout, overall_deadline)) if done_status and item is done_status: if exc := done_status.exception(): raise exc diff --git a/tests/core/test_observe.py b/tests/core/test_observe.py index 14b9443ac2..8cd576034c 100644 --- a/tests/core/test_observe.py +++ b/tests/core/test_observe.py @@ -105,13 +105,13 @@ async def test_observe_value_times_out_with_no_external_task(): recv = [] - async def watch(): - async for val in observe_value(sig): + async def watch(except_after_time): + async for val in observe_value(sig, except_after_time=except_after_time): recv.append(val) setter(val + 1) start = time.time() with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(watch(), timeout=0.1) + await watch(except_after_time=0.1) assert recv assert time.time() - start == pytest.approx(0.1, abs=0.05) diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index df65384772..ac3ac4721e 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -940,21 +940,20 @@ def test_signal_module_emits_deprecation_warning(): @PARAMETERISE_PROTOCOLS async def test_observe_ticking_signal_with_busy_loop(ioc, protocol): sig = epics_signal_rw(int, f"{protocol}://{get_prefix(ioc, protocol)}ticking") - sig.set_name("hello") await sig.connect() recv = [] async def watch(): - async for val in observe_value(sig, except_after_time=0.35): - time.sleep(0.15) + async for val in observe_value(sig, except_after_time=0.4): + time.sleep(0.3) recv.append(val) start = time.time() with pytest.raises(asyncio.TimeoutError): await watch() - assert time.time() - start == pytest.approx(0.35, abs=0.15) + assert time.time() - start == pytest.approx(0.6, abs=0.1) assert len(recv) == 2 # Don't check values as CA and PVA have different algorithms for # dropping updates for slow callbacks