Skip to content

Commit

Permalink
Use overall timeout for get value too
Browse files Browse the repository at this point in the history
  • Loading branch information
olliesilvester committed Nov 29, 2024
1 parent 2d96cca commit 76773b4
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 16 deletions.
27 changes: 18 additions & 9 deletions src/ophyd_async/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,14 @@ async def observe_value(
yield value


def _get_iteration_timeout(
timeout: float | None, overall_deadline: float | None
) -> float | None:
# Test this works if overall timeout <=0
overall_deadline = overall_deadline - time.monotonic() if overall_deadline else None
return min([x for x in [overall_deadline, timeout] if x is not None], default=None)


async def observe_signals_value(
*signals: SignalR[SignalDatatypeT],
timeout: float | None = None,
Expand Down Expand Up @@ -504,12 +512,11 @@ async def observe_signals_value(
q: asyncio.Queue[tuple[SignalR[SignalDatatypeT], SignalDatatypeT] | Status] = (
asyncio.Queue()
)
if timeout is None:
get_value = q.get
else:

async def get_value():
return await asyncio.wait_for(q.get(), timeout)
async def get_value(timeout: float | None = None):
if timeout is None:
return await q.get()
return await asyncio.wait_for(q.get(), timeout)

cbs: dict[SignalR, Callback] = {}
for signal in signals:
Expand All @@ -522,17 +529,19 @@ def queue_value(value: SignalDatatypeT, signal=signal):

if done_status is not None:
done_status.add_callback(q.put_nowait)
start_time = time.time()
overall_deadline = (
time.monotonic() + except_after_time if except_after_time else None
)
try:
while True:
if except_after_time and time.time() - start_time > except_after_time:
if overall_deadline and time.monotonic() >= overall_deadline:
raise asyncio.TimeoutError(
f"observe_value was still observing signals "
f"{[signal.source for signal in signals]} after "
f"timeout {except_after_time}s"
)
item = await asyncio.wait_for(q.get(), timeout)
item = await get_value()

item = await get_value(_get_iteration_timeout(timeout, overall_deadline))
if done_status and item is done_status:
if exc := done_status.exception():
raise exc
Expand Down
6 changes: 3 additions & 3 deletions tests/core/test_observe.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ async def test_observe_value_times_out_with_no_external_task():

recv = []

async def watch():
async for val in observe_value(sig):
async def watch(except_after_time):
async for val in observe_value(sig, except_after_time=except_after_time):
recv.append(val)
setter(val + 1)

start = time.time()
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(watch(), timeout=0.1)
await watch(except_after_time=0.1)
assert recv
assert time.time() - start == pytest.approx(0.1, abs=0.05)
7 changes: 3 additions & 4 deletions tests/epics/signal/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,21 +940,20 @@ def test_signal_module_emits_deprecation_warning():
@PARAMETERISE_PROTOCOLS
async def test_observe_ticking_signal_with_busy_loop(ioc, protocol):
sig = epics_signal_rw(int, f"{protocol}://{get_prefix(ioc, protocol)}ticking")
sig.set_name("hello")
await sig.connect()

recv = []

async def watch():
async for val in observe_value(sig, except_after_time=0.35):
time.sleep(0.15)
async for val in observe_value(sig, except_after_time=0.4):
time.sleep(0.3)
recv.append(val)

start = time.time()

with pytest.raises(asyncio.TimeoutError):
await watch()
assert time.time() - start == pytest.approx(0.35, abs=0.15)
assert time.time() - start == pytest.approx(0.6, abs=0.1)
assert len(recv) == 2
# Don't check values as CA and PVA have different algorithms for
# dropping updates for slow callbacks

0 comments on commit 76773b4

Please sign in to comment.