Skip to content

Commit

Permalink
Reimplement DeviceCollector as auto_init_devices
Browse files Browse the repository at this point in the history
Functionality is the same, but the public interface is now a function rather than a class.
Also makes sure that Devices that exist on entry, and are redeclared in the context manager are found.
  • Loading branch information
coretl committed Dec 11, 2024
1 parent e27b6d7 commit 33f70d8
Show file tree
Hide file tree
Showing 29 changed files with 175 additions and 158 deletions.
4 changes: 2 additions & 2 deletions docs/examples/epics_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bluesky.utils import ProgressBarManager, register_transform
from ophyd import Component, Device, EpicsSignal, EpicsSignalRO

from ophyd_async.core import DeviceCollector
from ophyd_async.core import auto_init_devices
from ophyd_async.epics import demo

# Create a run engine, with plotting, progressbar and transform
Expand All @@ -31,7 +31,7 @@ class OldSensor(Device):
det_old = OldSensor(pv_prefix, name="det_old")

# Create ophyd-async devices
with DeviceCollector():
with auto_init_devices():
det = demo.Sensor(pv_prefix)
det_group = demo.SensorGroup(pv_prefix)
samp = demo.SampleStage(pv_prefix)
4 changes: 2 additions & 2 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
StandardDetector,
TriggerInfo,
)
from ._device import Device, DeviceCollector, DeviceConnector, DeviceVector
from ._device import Device, DeviceConnector, DeviceVector, auto_init_devices
from ._device_filler import DeviceFiller
from ._flyer import FlyerController, StandardFlyer
from ._hdf_dataset import HDFDataset, HDFFile
Expand Down Expand Up @@ -87,7 +87,7 @@
"TriggerInfo",
"Device",
"DeviceConnector",
"DeviceCollector",
"auto_init_devices",
"DeviceVector",
"DeviceFiller",
"StandardFlyer",
Expand Down
143 changes: 75 additions & 68 deletions src/ophyd_async/core/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import sys
from collections.abc import Coroutine, Iterator, Mapping, MutableMapping
from collections.abc import Awaitable, Callable, Iterator, Mapping, MutableMapping
from functools import cached_property
from logging import LoggerAdapter, getLogger
from typing import Any, TypeVar
Expand Down Expand Up @@ -254,54 +254,18 @@ def __hash__(self): # to allow DeviceVector to be used as dict keys and in sets
return hash(id(self))


class DeviceCollector:
"""Collector of top level Device instances to be used as a context manager
Parameters
----------
set_name:
If True, call ``device.set_name(variable_name)`` on all collected
Devices
child_name_separator:
Use this as a separator if we call ``set_name``.
connect:
If True, call ``device.connect(mock)`` in parallel on all
collected Devices
mock:
If True, connect Signals in simulation mode
timeout:
How long to wait for connect before logging an exception
Notes
-----
Example usage::
[async] with DeviceCollector():
t1x = motor.Motor("BLxxI-MO-TABLE-01:X")
t1y = motor.Motor("pva://BLxxI-MO-TABLE-01:Y")
# Names and connects devices here
assert t1x.comm.velocity.source
assert t1x.name == "t1x"
class DeviceContextManager:
"""Sync/Async Context Manager that finds all the Devices declared within it.
Used in `auto_init`
"""

def __init__(
self,
set_name=True,
child_name_separator: str = "-",
connect=True,
mock=False,
timeout: float = 10.0,
):
self._set_name = set_name
self._child_name_separator = child_name_separator
self._connect = connect
self._mock = mock
self._timeout = timeout
self._names_on_enter: set[str] = set()
self._objects_on_exit: dict[str, Any] = {}

def _caller_locals(self):
def __init__(self, process_devices: Callable[[dict[str, Device]], Awaitable[None]]):
self._process_devices = process_devices
self._locals_on_enter: dict[str, Any] = {}
self._locals_on_exit: dict[str, Any] = {}

def _caller_locals(self) -> dict[str, Any]:
"""Walk up until we find a stack frame that doesn't have us as self"""
try:
raise ValueError
Expand All @@ -314,34 +278,18 @@ def _caller_locals(self):
assert (
caller_frame
), "No previous frame to the one with self in it, this shouldn't happen"
return caller_frame.f_locals
return caller_frame.f_locals.copy()

def __enter__(self) -> DeviceCollector:
def __enter__(self) -> DeviceContextManager:
# Stash the names that were defined before we were called
self._names_on_enter = set(self._caller_locals())
self._locals_on_enter = self._caller_locals()
return self

async def __aenter__(self) -> DeviceCollector:
async def __aenter__(self) -> DeviceContextManager:
return self.__enter__()

async def _on_exit(self) -> None:
# Name and kick off connect for devices
connect_coroutines: dict[str, Coroutine] = {}
for name, obj in self._objects_on_exit.items():
if name not in self._names_on_enter and isinstance(obj, Device):
if self._set_name and not obj.name:
obj.set_name(name, child_name_separator=self._child_name_separator)
if self._connect:
connect_coroutines[name] = obj.connect(
self._mock, timeout=self._timeout
)

# Connect to all the devices
if connect_coroutines:
await wait_for_connection(**connect_coroutines)

async def __aexit__(self, type, value, traceback):
self._objects_on_exit = self._caller_locals()
self._locals_on_exit = self._caller_locals()
await self._on_exit()

def __exit__(self, type_, value, traceback):
Expand All @@ -350,7 +298,7 @@ def __exit__(self, type_, value, traceback):
"Cannot use DeviceConnector inside a plan, instead use "
"`yield from ophyd_async.plan_stubs.ensure_connected(device)`"
)
self._objects_on_exit = self._caller_locals()
self._locals_on_exit = self._caller_locals()
try:
fut = call_in_bluesky_event_loop(self._on_exit())
except RuntimeError as e:
Expand All @@ -360,3 +308,62 @@ def __exit__(self, type_, value, traceback):
"user/explanations/event-loop-choice.html for more info."
) from e
return fut

async def _on_exit(self) -> None:
# Find all the devices
devices = {
name: obj
for name, obj in self._locals_on_exit.items()
if isinstance(obj, Device) and self._locals_on_enter.get(name) is not obj
}
# Call the provided process function on them
await self._process_devices(devices)


def auto_init_devices(
set_name=True,
child_name_separator: str = "-",
connect=True,
mock=False,
timeout: float = 10.0,
) -> DeviceContextManager:
"""Auto initialise top level Device instances to be used as a context manager
Parameters
----------
set_name:
If True, call ``device.set_name(variable_name)`` on all Devices
created within the context manager that have an empty ``name``
child_name_separator:
Use this as a separator if we call ``set_name``.
connect:
If True, call ``device.connect(mock, timeout)`` in parallel on all
Devices created within the context manager
mock:
If True, connect Signals in mock mode
timeout:
How long to wait for connect before logging an exception
Notes
-----
Example usage::
[async] with auto_init_devices():
t1x = motor.Motor("BLxxI-MO-TABLE-01:X")
t1y = motor.Motor("pva://BLxxI-MO-TABLE-01:Y")
# Names and connects devices here
assert t1x.name == "t1x"
"""

async def process_devices(devices: dict[str, Device]):
if set_name:
for name, device in devices.items():
if not device.name:
device.set_name(name, child_name_separator=child_name_separator)
if connect:
coros = {
name: device.connect(mock, timeout) for name, device in devices.items()
}
await wait_for_connection(**coros)

return DeviceContextManager(process_devices)
6 changes: 3 additions & 3 deletions system_tests/epics/eiger/test_eiger_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from ophyd_async.core import (
DetectorTrigger,
Device,
DeviceCollector,
StaticPathProvider,
auto_init_devices,
)
from ophyd_async.epics.core import epics_signal_rw
from ophyd_async.epics.eiger import EigerDetector, EigerTriggerInfo
Expand Down Expand Up @@ -47,7 +47,7 @@ def RE():

@pytest.fixture
async def setup_device(RE, ioc_prefixes):
async with DeviceCollector():
async with auto_init_devices():
device = SetupDevice(ioc_prefixes[0], ioc_prefixes[1] + "FP:")
await asyncio.gather(
device.header_detail.set("all"),
Expand All @@ -62,7 +62,7 @@ async def setup_device(RE, ioc_prefixes):
@pytest.fixture
async def test_eiger(RE, ioc_prefixes) -> EigerDetector:
provider = StaticPathProvider(lambda: "test_eiger", Path(SAVE_PATH))
async with DeviceCollector():
async with auto_init_devices():
test_eiger = EigerDetector("", provider, ioc_prefixes[0], ioc_prefixes[1])

return test_eiger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from ophyd_async.core import (
DEFAULT_TIMEOUT,
Device,
DeviceCollector,
NotConnected,
auto_init_devices,
)
from ophyd_async.epics import motor
from ophyd_async.testing import set_mock_value
Expand All @@ -33,10 +33,10 @@ async def connect(
async def set(self, new_position: float): ...


async def test_device_collector_handles_top_level_errors(caplog):
async def test_auto_init_handles_top_level_errors(caplog):
caplog.set_level(10)
with pytest.raises(NotConnected) as exc:
async with DeviceCollector():
async with auto_init_devices():
_ = FailingDevice("somename")

assert not exc.value.__cause__
Expand All @@ -52,9 +52,9 @@ async def test_device_collector_handles_top_level_errors(caplog):
assert device_log[0].levelname == "ERROR"


def test_sync_device_connector_no_run_engine_raises_error():
def test_sync_auto_init_no_run_engine_raises_error():
with pytest.raises(NotConnected) as e:
with DeviceCollector():
with auto_init_devices():
working_device = WorkingDevice("somename")
assert e.value._errors == (
"Could not connect devices. Is the bluesky event loop running? See "
Expand All @@ -64,26 +64,36 @@ def test_sync_device_connector_no_run_engine_raises_error():
assert not working_device.connected


def test_sync_device_connector_run_engine_created_connects(RE):
with DeviceCollector():
def test_sync_auto_init_run_engine_created_connects(RE):
with auto_init_devices():
working_device = WorkingDevice("somename")

assert working_device.connected


async def test_auto_init_detects_redeclared_devices():
original_working_device = working_device = WorkingDevice()

async with auto_init_devices():
working_device = WorkingDevice()
assert original_working_device is not working_device
assert working_device.connected and working_device.name == "working_device"
assert not original_working_device.connected and original_working_device.name == ""


def test_connecting_in_plan_raises(RE):
def bad_plan():
yield from bps.null()
with DeviceCollector():
with auto_init_devices():
working_device = WorkingDevice("somename") # noqa: F841

with pytest.raises(RuntimeError, match="Cannot use DeviceConnector inside a plan"):
RE(bad_plan())


def test_async_device_connector_run_engine_same_event_loop():
def test_async_auto_init_run_engine_same_event_loop():
async def set_up_device():
async with DeviceCollector(mock=True):
async with auto_init_devices(mock=True):
mock_motor = motor.Motor("BLxxI-MO-TABLE-01:X")
set_mock_value(mock_motor.velocity, 1)
return mock_motor
Expand Down Expand Up @@ -126,17 +136,17 @@ def my_plan():
"loop to set the value, unlike real signals."
)
)
def test_async_device_connector_run_engine_different_event_loop():
def test_async_auto_init_run_engine_different_event_loop():
async def set_up_device():
async with DeviceCollector(mock=True):
async with auto_init_devices(mock=True):
mock_motor = motor.Motor("BLxxI-MO-TABLE-01:X")
return mock_motor

device_connector_loop = asyncio.new_event_loop()
auto_init_loop = asyncio.new_event_loop()
run_engine_loop = asyncio.new_event_loop()
assert run_engine_loop is not device_connector_loop
assert run_engine_loop is not auto_init_loop

mock_motor = device_connector_loop.run_until_complete(set_up_device())
mock_motor = auto_init_loop.run_until_complete(set_up_device())

RE = RunEngine(loop=run_engine_loop)

Expand All @@ -147,7 +157,7 @@ def my_plan():

# The set should fail since the run engine is on a different event loop
assert (
device_connector_loop.run_until_complete(mock_motor.user_setpoint.read())[
auto_init_loop.run_until_complete(mock_motor.user_setpoint.read())[
"mock_motor-user_setpoint"
]["value"]
!= 3.14
Expand Down
6 changes: 3 additions & 3 deletions tests/core/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from ophyd_async.core import (
DEFAULT_TIMEOUT,
Device,
DeviceCollector,
DeviceVector,
NotConnected,
Reference,
SignalRW,
auto_init_devices,
soft_signal_rw,
wait_for_connection,
)
Expand Down Expand Up @@ -126,8 +126,8 @@ async def test_children_of_device_with_different_separator(
assert parent.dict_with_children[123].name == "parent_dict_with_children_123"


async def test_device_with_device_collector():
async with DeviceCollector(mock=True):
async def test_device_with_auto_init():
async with auto_init_devices(mock=True):
parent = DummyDeviceGroup("parent")

assert parent.name == "parent"
Expand Down
Loading

0 comments on commit 33f70d8

Please sign in to comment.