From 820735be4ea09c7a4404dfb045623bf483c3f496 Mon Sep 17 00:00:00 2001 From: Tom Cobb Date: Thu, 28 Nov 2024 12:51:13 +0000 Subject: [PATCH] Switch from intelligent save to intelligent load Use the concept of Settings that can be: - Created from a Device - Stored to YAML - Retrieved from YAML - Added to other Settings - Applied to a Device --- .gitignore | 4 +- pyproject.toml | 50 +-- src/ophyd_async/core/__init__.py | 27 +- src/ophyd_async/core/_device_save_loader.py | 274 --------------- src/ophyd_async/core/_settings.py | 70 ++++ src/ophyd_async/core/_signal.py | 34 +- src/ophyd_async/core/_table.py | 3 - src/ophyd_async/core/_yaml_settings.py | 64 ++++ src/ophyd_async/fastcs/panda/__init__.py | 2 - src/ophyd_async/fastcs/panda/_utils.py | 16 - src/ophyd_async/plan_stubs/__init__.py | 14 + .../plan_stubs/_ensure_connected.py | 28 +- src/ophyd_async/plan_stubs/_panda.py | 12 + src/ophyd_async/plan_stubs/_settings.py | 103 ++++++ src/ophyd_async/plan_stubs/_wait_for_one.py | 12 + src/ophyd_async/sim/testing/__init__.py | 13 + .../sim/testing/_one_of_everything.py | 115 +++++++ tests/core/test_device_save_loader.py | 322 ------------------ tests/core/test_table.py | 1 + tests/epics/signal/test_signals.py | 17 +- tests/fastcs/panda/test_panda_utils.py | 28 +- tests/plan_stubs/test_settings.py | 96 ++++++ tests/test_data/test_yaml_save.yaml | 119 +++++++ tests/test_data/test_yaml_save.yml | 42 --- 24 files changed, 711 insertions(+), 755 deletions(-) delete mode 100644 src/ophyd_async/core/_device_save_loader.py create mode 100644 src/ophyd_async/core/_settings.py create mode 100644 src/ophyd_async/core/_yaml_settings.py delete mode 100644 src/ophyd_async/fastcs/panda/_utils.py create mode 100644 src/ophyd_async/plan_stubs/_panda.py create mode 100644 src/ophyd_async/plan_stubs/_settings.py create mode 100644 src/ophyd_async/plan_stubs/_wait_for_one.py create mode 100644 src/ophyd_async/sim/testing/_one_of_everything.py delete mode 100644 tests/core/test_device_save_loader.py create mode 100644 tests/plan_stubs/test_settings.py create mode 100644 tests/test_data/test_yaml_save.yaml delete mode 100644 tests/test_data/test_yaml_save.yml diff --git a/.gitignore b/.gitignore index ef6d127ea9..8d2b6b388a 100644 --- a/.gitignore +++ b/.gitignore @@ -87,6 +87,8 @@ docs/savefig # generated version number ophyd_async/_version.py - # ruff cache .ruff_cache/ + +# import linter cache +.import_linter_cache/ diff --git a/pyproject.toml b/pyproject.toml index 9246ec9ba7..f07aef8504 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,6 +159,9 @@ lint.select = [ "SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self "PLC2701", # private import - https://docs.astral.sh/ruff/rules/import-private-name/ ] +lint.ignore = [ + "B901", # Return in a generator is needed for plans +] lint.preview = true # so that preview mode PLC2701 is enabled [tool.ruff.lint.per-file-ignores] @@ -172,47 +175,26 @@ lint.preview = true # so that preview mode PLC2701 is enabled root_package = "ophyd_async" [[tool.importlinter.contracts]] -name = "Core is independent" -type = "independence" -modules = "ophyd_async.core" +name = "All runtime modules are in layers" +type = "layers" +containers = ["ophyd_async"] +layers = ["plan_stubs", "fastcs", "epics | tango | sim", "core"] +exhaustive = true +exhaustive_ignores = ["testing", "_version", "__main__"] [[tool.importlinter.contracts]] -name = "Epics depends only on core" -type = "forbidden" -source_modules = "ophyd_async.epics" -forbidden_modules = [ - "ophyd_async.fastcs", - "ophyd_async.plan_stubs", - "ophyd_async.sim", - "ophyd_async.tango", -] +name = "Testing modules" +type = "layers" +containers = ["ophyd_async"] +layers = ["testing", "core"] [[tool.importlinter.contracts]] -name = "tango depends only on core" +name = "Testing modules are not used at runtime" type = "forbidden" -source_modules = "ophyd_async.tango" +source_modules = "ophyd_async.testing" forbidden_modules = [ - "ophyd_async.epics", - "ophyd_async.fastcs", "ophyd_async.plan_stubs", - "ophyd_async.sim", -] - - -[[tool.importlinter.contracts]] -name = "sim depends only on core" -type = "forbidden" -source_modules = "ophyd_async.sim" -forbidden_modules = [ - "ophyd_async.epics", "ophyd_async.fastcs", - "ophyd_async.plan_stubs", + "ophyd_async.epics", "ophyd_async.tango", ] - - -[[tool.importlinter.contracts]] -name = "Fastcs depends only on core, epics, tango" -type = "forbidden" -source_modules = "ophyd_async.fastcs" -forbidden_modules = ["ophyd_async.plan_stubs", "ophyd_async.sim"] diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 8b3be801db..3154572417 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -7,16 +7,6 @@ ) from ._device import Device, DeviceCollector, DeviceConnector, DeviceVector from ._device_filler import DeviceFiller -from ._device_save_loader import ( - all_at_once, - get_signal_values, - load_device, - load_from_yaml, - save_device, - save_to_yaml, - set_signal_values, - walk_rw_signals, -) from ._flyer import FlyerController, StandardFlyer from ._hdf_dataset import HDFDataset, HDFFile from ._log import config_ophyd_async_logging @@ -41,6 +31,7 @@ StandardReadable, StandardReadableFormat, ) +from ._settings import Settings, SettingsProvider from ._signal import ( Signal, SignalConnector, @@ -55,9 +46,11 @@ soft_signal_r_and_setter, soft_signal_rw, wait_for_value, + walk_rw_signals, ) from ._signal_backend import ( Array1D, + DTypeScalar_co, SignalBackend, SignalDatatype, SignalDatatypeT, @@ -84,6 +77,7 @@ in_micros, wait_for_connection, ) +from ._yaml_settings import YamlSettingsProvider __all__ = [ "DetectorController", @@ -96,14 +90,6 @@ "DeviceCollector", "DeviceVector", "DeviceFiller", - "all_at_once", - "get_signal_values", - "load_device", - "load_from_yaml", - "save_device", - "save_to_yaml", - "set_signal_values", - "walk_rw_signals", "StandardFlyer", "FlyerController", "HDFDataset", @@ -128,6 +114,8 @@ "HintedSignal", "StandardReadable", "StandardReadableFormat", + "Settings", + "SettingsProvider", "Signal", "SignalConnector", "SignalR", @@ -141,7 +129,9 @@ "soft_signal_r_and_setter", "soft_signal_rw", "wait_for_value", + "walk_rw_signals", "Array1D", + "DTypeScalar_co", "SignalBackend", "make_datakey", "StrictEnum", @@ -168,4 +158,5 @@ "in_micros", "wait_for_connection", "completed_status", + "YamlSettingsProvider", ] diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py deleted file mode 100644 index 86e479d136..0000000000 --- a/src/ophyd_async/core/_device_save_loader.py +++ /dev/null @@ -1,274 +0,0 @@ -from collections.abc import Callable, Generator, Sequence -from enum import Enum -from pathlib import Path -from typing import Any - -import numpy as np -import numpy.typing as npt -import yaml -from bluesky.plan_stubs import abs_set, wait -from bluesky.protocols import Location -from bluesky.utils import Msg -from pydantic import BaseModel - -from ._device import Device -from ._signal import SignalRW - - -def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node: - return dumper.represent_sequence( - "tag:yaml.org,2002:seq", array.tolist(), flow_style=True - ) - - -def pydantic_model_abstraction_representer( - dumper: yaml.Dumper, model: BaseModel -) -> yaml.Node: - return dumper.represent_data(model.model_dump(mode="python")) - - -def enum_representer(dumper: yaml.Dumper, enum: Enum) -> yaml.Node: - return dumper.represent_data(enum.value) - - -def get_signal_values( - signals: dict[str, SignalRW[Any]], ignore: list[str] | None = None -) -> Generator[Msg, Sequence[Location[Any]], dict[str, Any]]: - """Get signal values in bulk. - - Used as part of saving the signals of a device to a yaml file. - - Parameters - ---------- - signals : Dict[str, SignalRW] - Dictionary with pv names and matching SignalRW values. Often the direct result - of :func:`walk_rw_signals`. - - ignore : Optional[List[str]] - Optional list of PVs that should be ignored. - - Returns - ------- - Dict[str, Any] - A dictionary containing pv names and their associated values. Ignored pvs are - set to None. - - See Also - -------- - :func:`ophyd_async.core.walk_rw_signals` - :func:`ophyd_async.core.save_to_yaml` - """ - - ignore = ignore or [] - selected_signals = { - key: signal for key, signal in signals.items() if key not in ignore - } - selected_values = yield Msg("locate", *selected_signals.values()) - - assert selected_values is not None, "No signalRW's were able to be located" - named_values = { - key: value["setpoint"] - for key, value in zip(selected_signals, selected_values, strict=False) - } - # Ignored values place in with value None so we know which ones were ignored - named_values.update(dict.fromkeys(ignore)) - return named_values - - -def walk_rw_signals( - device: Device, path_prefix: str | None = "" -) -> dict[str, SignalRW[Any]]: - """Retrieve all SignalRWs from a device. - - Stores retrieved signals with their dotted attribute paths in a dictionary. Used as - part of saving and loading a device. - - Parameters - ---------- - device : Device - Ophyd device to retrieve read-write signals from. - - path_prefix : str - For internal use, leave blank when calling the method. - - Returns - ------- - SignalRWs : dict - A dictionary matching the string attribute path of a SignalRW with the - signal itself. - - See Also - -------- - :func:`ophyd_async.core.get_signal_values` - :func:`ophyd_async.core.save_to_yaml` - - """ - - if not path_prefix: - path_prefix = "" - - signals: dict[str, SignalRW[Any]] = {} - - for attr_name, attr in device.children(): - dot_path = f"{path_prefix}{attr_name}" - if type(attr) is SignalRW: - signals[dot_path] = attr - attr_signals = walk_rw_signals(attr, path_prefix=dot_path + ".") - signals.update(attr_signals) - return signals - - -def save_to_yaml(phases: Sequence[dict[str, Any]], save_path: str | Path) -> None: - """Plan which serialises a phase or set of phases of SignalRWs to a yaml file. - - Parameters - ---------- - phases : dict or list of dicts - The values to save. Each item in the list is a seperate phase used when loading - a device. In general this variable be the return value of `get_signal_values`. - - save_path : str - Path of the yaml file to write to - - See Also - -------- - :func:`ophyd_async.core.walk_rw_signals` - :func:`ophyd_async.core.get_signal_values` - :func:`ophyd_async.core.load_from_yaml` - """ - - yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper) - yaml.add_multi_representer( - BaseModel, - pydantic_model_abstraction_representer, - Dumper=yaml.Dumper, - ) - yaml.add_multi_representer(Enum, enum_representer, Dumper=yaml.Dumper) - - with open(save_path, "w") as file: - yaml.dump(phases, file) - - -def load_from_yaml(save_path: str) -> Sequence[dict[str, Any]]: - """Plan that returns a list of dicts with saved signal values from a yaml file. - - Parameters - ---------- - save_path : str - Path of the yaml file to load from - - See Also - -------- - :func:`ophyd_async.core.save_to_yaml` - :func:`ophyd_async.core.set_signal_values` - """ - with open(save_path) as file: - return yaml.full_load(file) - - -def set_signal_values( - signals: dict[str, SignalRW[Any]], values: Sequence[dict[str, Any]] -) -> Generator[Msg, None, None]: - """Maps signals from a yaml file into device signals. - - ``values`` contains signal values in phases, which are loaded in sequentially - into the provided signals, to ensure signals are set in the correct order. - - Parameters - ---------- - signals : Dict[str, SignalRW[Any]] - Dictionary of named signals to be updated if value found in values argument. - Can be the output of :func:`walk_rw_signals()` for a device. - - values : Sequence[Dict[str, Any]] - List of dictionaries of signal name and value pairs, if a signal matches - the name of one in the signals argument, sets the signal to that value. - The groups of signals are loaded in their list order. - Can be the output of :func:`load_from_yaml()` for a yaml file. - - See Also - -------- - :func:`ophyd_async.core.load_from_yaml` - :func:`ophyd_async.core.walk_rw_signals` - """ - # For each phase, set all the signals, - # load them to the correct value and wait for the load to complete - for phase_number, phase in enumerate(values): - # Key is signal name - for key, value in phase.items(): - # Skip ignored values - if value is None: - continue - - if key in signals: - yield from abs_set( - signals[key], value, group=f"load-phase{phase_number}" - ) - - yield from wait(f"load-phase{phase_number}") - - -def load_device(device: Device, path: str): - """Plan which loads PVs from a yaml file into a device. - - Parameters - ---------- - device: Device - The device to load PVs into - path: str - Path of the yaml file to load - - See Also - -------- - :func:`ophyd_async.core.save_device` - """ - values = load_from_yaml(path) - signals_to_set = walk_rw_signals(device) - yield from set_signal_values(signals_to_set, values) - - -def all_at_once(values: dict[str, Any]) -> Sequence[dict[str, Any]]: - """Sort all the values into a single phase so they are set all at once""" - return [values] - - -def save_device( - device: Device, - path: str, - sorter: Callable[[dict[str, Any]], Sequence[dict[str, Any]]] = all_at_once, - ignore: list[str] | None = None, -): - """Plan that saves the state of all PV's on a device using a sorter. - - The default sorter assumes all saved PVs can be loaded at once, and therefore - can be saved at one time, i.e. all PVs will appear on one list in the - resulting yaml file. - - This can be a problem, because when the yaml is ingested with - :func:`ophyd_async.core.load_device`, it will set all of those PVs at once. - However, some PV's need to be set before others - this is device specific. - - Therefore, users should consider the order of device loading and write their - own sorter algorithms accordingly. - - See :func:`ophyd_async.fastcs.panda.phase_sorter` for a valid implementation of the - sorter. - - Parameters - ---------- - device : Device - The device whose PVs should be saved. - - path : str - The path where the resulting yaml should be saved to - - sorter : Callable[[Dict[str, Any]], Sequence[Dict[str, Any]]] - - ignore : Optional[List[str]] - - See Also - -------- - :func:`ophyd_async.core.load_device` - """ - values = yield from get_signal_values(walk_rw_signals(device), ignore=ignore) - save_to_yaml(sorter(values), path) diff --git a/src/ophyd_async/core/_settings.py b/src/ophyd_async/core/_settings.py new file mode 100644 index 0000000000..bac37a1071 --- /dev/null +++ b/src/ophyd_async/core/_settings.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Callable, Iterator, MutableMapping +from typing import Any + +from ._device import Device +from ._signal import SignalRW +from ._signal_backend import SignalDatatypeT + + +class Settings(MutableMapping[SignalRW[Any], Any]): + def __init__( + self, device: Device, settings: MutableMapping[SignalRW, Any] | None = None + ): + self.device = device + self._settings = {} + self.update(settings or {}) + + def __getitem__(self, key: SignalRW[SignalDatatypeT]) -> SignalDatatypeT: + return self._settings[key] + + def _is_in_device(self, device: Device) -> bool: + while device.parent and device.parent is not self.device: + # While we have a parent that is not the right device + # continue searching up the tree + device = device.parent + return device.parent is self.device + + def __setitem__( + self, key: SignalRW[SignalDatatypeT], value: SignalDatatypeT | None + ) -> None: + # Check the types on entry to dict to make sure we can't accidentally + # add a non-signal type + if not isinstance(key, SignalRW): + raise TypeError(f"Expected SignalRW, got {key}") + if not self._is_in_device(key): + raise KeyError(f"Signal {key} is not a child of {self.device}") + self._settings[key] = value + + def __delitem__(self, key: SignalRW) -> None: + del self._settings[key] + + def __iter__(self) -> Iterator[SignalRW]: + yield from iter(self._settings) + + def __len__(self) -> int: + return len(self._settings) + + def __or__(self, other: MutableMapping[SignalRW, Any]) -> Settings: + if isinstance(other, Settings) and not self._is_in_device(other.device): + raise ValueError(f"{other.device} is not a child of {self.device}") + return Settings(self.device, self._settings | dict(other)) + + def partition( + self, predicate: Callable[[SignalRW], bool] + ) -> tuple[Settings, Settings]: + where_true, where_false = Settings(self.device), Settings(self.device) + for signal, value in self.items(): + dest = where_true if predicate(signal) else where_false + dest[signal] = value + return where_true, where_false + + +class SettingsProvider: + @abstractmethod + async def store(self, name: str, data: dict[str, Any]): ... + + @abstractmethod + async def retrieve(self, name: str) -> dict[str, Any]: ... diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index b8c5a258d4..a8c2faba49 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -4,7 +4,7 @@ import functools import time from collections.abc import AsyncGenerator, Awaitable, Callable -from typing import Generic, cast +from typing import Any, Generic, cast from bluesky.protocols import ( Locatable, @@ -598,3 +598,35 @@ async def set_and_wait_for_value( status_timeout, wait_for_set_completion, ) + + +def walk_rw_signals(device: Device, path_prefix: str = "") -> dict[str, SignalRW[Any]]: + """Retrieve all SignalRWs from a device. + + Stores retrieved signals with their dotted attribute paths in a dictionary. Used as + part of saving and loading a device. + + Parameters + ---------- + device : Device + Ophyd device to retrieve read-write signals from. + + path_prefix : str + For internal use, leave blank when calling the method. + + Returns + ------- + SignalRWs : dict + A dictionary matching the string attribute path of a SignalRW with the + signal itself. + + """ + signals: dict[str, SignalRW[Any]] = {} + + for attr_name, attr in device.children(): + dot_path = f"{path_prefix}{attr_name}" + if type(attr) is SignalRW: + signals[dot_path] = attr + attr_signals = walk_rw_signals(attr, path_prefix=dot_path + ".") + signals.update(attr_signals) + return signals diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py index bf912fea22..62b7b1233a 100644 --- a/src/ophyd_async/core/_table.py +++ b/src/ophyd_async/core/_table.py @@ -78,9 +78,6 @@ def __add__(self, right: TableSubclass) -> TableSubclass: } ) - def __eq__(self, value: object) -> bool: - return super().__eq__(value) - def numpy_dtype(self) -> np.dtype: dtype = [] for k, v in self: diff --git a/src/ophyd_async/core/_yaml_settings.py b/src/ophyd_async/core/_yaml_settings.py new file mode 100644 index 0000000000..2cba3660b0 --- /dev/null +++ b/src/ophyd_async/core/_yaml_settings.py @@ -0,0 +1,64 @@ +import warnings +from enum import Enum +from pathlib import Path +from typing import Any + +import numpy as np +import numpy.typing as npt +import yaml +from pydantic import BaseModel + +from ._settings import SettingsProvider + + +def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node: + return dumper.represent_sequence( + "tag:yaml.org,2002:seq", array.tolist(), flow_style=True + ) + + +def pydantic_model_abstraction_representer( + dumper: yaml.Dumper, model: BaseModel +) -> yaml.Node: + return dumper.represent_data(model.model_dump(mode="python")) + + +def enum_representer(dumper: yaml.Dumper, enum: Enum) -> yaml.Node: + return dumper.represent_data(enum.value) + + +class YamlSettingsProvider(SettingsProvider): + def __init__(self, directory: Path | str): + self._directory = Path(directory) + + def _file_path(self, name: str) -> Path: + return self._directory / (name + ".yaml") + + async def store(self, name: str, data: dict[str, Any]): + yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper) + yaml.add_multi_representer( + BaseModel, + pydantic_model_abstraction_representer, + Dumper=yaml.Dumper, + ) + yaml.add_multi_representer(Enum, enum_representer, Dumper=yaml.Dumper) + with open(self._file_path(name), "w") as file: + yaml.dump(data, file) + + async def retrieve(self, name: str) -> dict[str, Any]: + with open(self._file_path(name)) as file: + data = yaml.full_load(file) + if isinstance(data, list): + warnings.warn( + DeprecationWarning( + "Found old save file. Re-save your yaml settings file " + f"{self._file_path(name)} using " + "ophyd_async.plan_stubs.store_settings" + ), + stacklevel=2, + ) + merge = {} + for d in data: + merge.update(d) + return merge + return data diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py index 29b27d557b..acb16e5179 100644 --- a/src/ophyd_async/fastcs/panda/__init__.py +++ b/src/ophyd_async/fastcs/panda/__init__.py @@ -23,7 +23,6 @@ StaticPcompTriggerLogic, StaticSeqTableTriggerLogic, ) -from ._utils import phase_sorter from ._writer import PandaHDFWriter __all__ = [ @@ -47,5 +46,4 @@ "SeqTableInfo", "StaticPcompTriggerLogic", "StaticSeqTableTriggerLogic", - "phase_sorter", ] diff --git a/src/ophyd_async/fastcs/panda/_utils.py b/src/ophyd_async/fastcs/panda/_utils.py deleted file mode 100644 index e960b5c7dd..0000000000 --- a/src/ophyd_async/fastcs/panda/_utils.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence -from typing import Any - - -def phase_sorter(panda_signal_values: dict[str, Any]) -> Sequence[dict[str, Any]]: - # Panda has two load phases. If the signal name ends in the string "UNITS", - # it needs to be loaded first so put in first phase - phase_1, phase_2 = {}, {} - - for key, value in panda_signal_values.items(): - if key.endswith("units"): - phase_1[key] = value - else: - phase_2[key] = value - - return [phase_1, phase_2] diff --git a/src/ophyd_async/plan_stubs/__init__.py b/src/ophyd_async/plan_stubs/__init__.py index 170ee7b518..f549dd27f9 100644 --- a/src/ophyd_async/plan_stubs/__init__.py +++ b/src/ophyd_async/plan_stubs/__init__.py @@ -5,6 +5,14 @@ time_resolved_fly_and_collect_with_static_seq_table, ) from ._nd_attributes import setup_ndattributes, setup_ndstats_sum +from ._panda import apply_panda_settings +from ._settings import ( + apply_settings, + apply_settings_if_different, + get_current_settings, + retrieve_settings, + store_settings, +) __all__ = [ "fly_and_collect", @@ -13,4 +21,10 @@ "ensure_connected", "setup_ndattributes", "setup_ndstats_sum", + "apply_panda_settings", + "apply_settings", + "apply_settings_if_different", + "get_current_settings", + "retrieve_settings", + "store_settings", ] diff --git a/src/ophyd_async/plan_stubs/_ensure_connected.py b/src/ophyd_async/plan_stubs/_ensure_connected.py index 2d9a8cc85a..9199ec893a 100644 --- a/src/ophyd_async/plan_stubs/_ensure_connected.py +++ b/src/ophyd_async/plan_stubs/_ensure_connected.py @@ -1,10 +1,11 @@ -from collections.abc import Awaitable - -import bluesky.plan_stubs as bps +from bluesky.utils import plan from ophyd_async.core import DEFAULT_TIMEOUT, Device, LazyMock, wait_for_connection +from ._wait_for_one import wait_for_one + +@plan def ensure_connected( *devices: Device, mock: bool | LazyMock = False, @@ -17,17 +18,10 @@ def ensure_connected( } if non_unique: raise ValueError(f"Devices do not have unique names {non_unique}") - - def connect_devices() -> Awaitable[None]: - coros = { - device.name: device.connect( - mock=mock, timeout=timeout, force_reconnect=force_reconnect - ) - for device in devices - } - return wait_for_connection(**coros) - - (connect_task,) = yield from bps.wait_for([connect_devices]) - - if connect_task and connect_task.exception() is not None: - raise connect_task.exception() + coros = { + device.name: device.connect( + mock=mock, timeout=timeout, force_reconnect=force_reconnect + ) + for device in devices + } + yield from wait_for_one(wait_for_connection(**coros)) diff --git a/src/ophyd_async/plan_stubs/_panda.py b/src/ophyd_async/plan_stubs/_panda.py new file mode 100644 index 0000000000..6d381fdc88 --- /dev/null +++ b/src/ophyd_async/plan_stubs/_panda.py @@ -0,0 +1,12 @@ +from bluesky.utils import MsgGenerator, plan + +from ophyd_async.core import Settings + +from ._settings import apply_settings + + +@plan +def apply_panda_settings(settings: Settings) -> MsgGenerator[None]: + units, others = settings.partition(lambda signal: signal.name.endswith("_units")) + yield from apply_settings(units) + yield from apply_settings(others) diff --git a/src/ophyd_async/plan_stubs/_settings.py b/src/ophyd_async/plan_stubs/_settings.py new file mode 100644 index 0000000000..6bd10edf4b --- /dev/null +++ b/src/ophyd_async/plan_stubs/_settings.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Callable, Mapping +from typing import Any + +import bluesky.plan_stubs as bps +import numpy as np +from bluesky.utils import MsgGenerator, plan + +from ophyd_async.core import ( + Device, + Settings, + SettingsProvider, + SignalRW, + T, + walk_rw_signals, +) +from ophyd_async.core._table import Table + +from ._wait_for_one import wait_for_one + + +@plan +def _get_values_of_signals( + signals: Mapping[T, SignalRW], +) -> MsgGenerator[dict[T, Any]]: + coros = [sig.get_value() for sig in signals.values()] + values = yield from wait_for_one(asyncio.gather(*coros)) + named_values = dict(zip(signals, values, strict=True)) + return named_values + + +@plan +def get_current_settings(device: Device) -> MsgGenerator[Settings]: + signals = walk_rw_signals(device) + named_values = yield from _get_values_of_signals(signals) + signal_values = {signals[name]: value for name, value in named_values.items()} + return Settings(device, signal_values) + + +@plan +def store_settings( + provider: SettingsProvider, name: str, device: Device +) -> MsgGenerator[None]: + """Plan to recursively walk a Device to find SignalRWs and write a YAML of their + values. + """ + signals = walk_rw_signals(device) + named_values = yield from _get_values_of_signals(signals) + yield from wait_for_one(provider.store(name, named_values)) + + +@plan +def retrieve_settings( + provider: SettingsProvider, name: str, device: Device +) -> MsgGenerator[Settings]: + named_values = yield from wait_for_one(provider.retrieve(name)) + signals = walk_rw_signals(device) + signal_values = {signals[name]: value for name, value in named_values.items()} + return Settings(device, signal_values) + + +@plan +def apply_settings(settings: Settings) -> MsgGenerator[None]: + signal_values = { + signal: value for signal, value in settings.items() if value is not None + } + if signal_values: + for signal, value in signal_values.items(): + yield from bps.abs_set(signal, value, group="apply_settings") + yield from bps.wait("apply_settings") + + +@plan +def apply_settings_if_different( + settings: Settings, + apply_plan: Callable[[Settings], MsgGenerator[None]], + current_settings: Settings | None = None, +) -> MsgGenerator[None]: + if current_settings is None: + signal_values = yield from _get_values_of_signals( + {sig: sig for sig in settings} + ) + current_settings = Settings(settings.device, signal_values) + + def _is_different(current, required) -> bool: + if isinstance(current, Table): + current = current.model_dump() + if isinstance(required, Table): + required = required.model_dump() + return current.keys() != required.keys() or any( + _is_different(current[k], required[k]) for k in current + ) + elif isinstance(current, np.ndarray): + return not np.array_equal(current, required) + else: + return current != required + + settings_to_change, _ = settings.partition( + lambda sig: _is_different(current_settings[sig], settings[sig]) + ) + yield from apply_plan(settings_to_change) diff --git a/src/ophyd_async/plan_stubs/_wait_for_one.py b/src/ophyd_async/plan_stubs/_wait_for_one.py new file mode 100644 index 0000000000..1d67ada095 --- /dev/null +++ b/src/ophyd_async/plan_stubs/_wait_for_one.py @@ -0,0 +1,12 @@ +from collections.abc import Awaitable + +import bluesky.plan_stubs as bps +from bluesky.utils import MsgGenerator, plan + +from ophyd_async.core import T + + +@plan +def wait_for_one(coro: Awaitable[T]) -> MsgGenerator[T]: + (task,) = yield from bps.wait_for([lambda: coro]) + return task.result() diff --git a/src/ophyd_async/sim/testing/__init__.py b/src/ophyd_async/sim/testing/__init__.py index e69de29bb2..90c6b8f836 100644 --- a/src/ophyd_async/sim/testing/__init__.py +++ b/src/ophyd_async/sim/testing/__init__.py @@ -0,0 +1,13 @@ +from ._one_of_everything import ( + ExampleEnum, + ExampleTable, + OneOfEverythingDevice, + ParentOfEverythingDevice, +) + +__all__ = [ + "ExampleEnum", + "ExampleTable", + "OneOfEverythingDevice", + "ParentOfEverythingDevice", +] diff --git a/src/ophyd_async/sim/testing/_one_of_everything.py b/src/ophyd_async/sim/testing/_one_of_everything.py new file mode 100644 index 0000000000..48e6c587e0 --- /dev/null +++ b/src/ophyd_async/sim/testing/_one_of_everything.py @@ -0,0 +1,115 @@ +from collections.abc import Sequence +from typing import Any + +import numpy as np + +from ophyd_async.core import ( + Array1D, + Device, + DTypeScalar_co, + SignalRW, + StrictEnum, + Table, + soft_signal_r_and_setter, + soft_signal_rw, +) +from ophyd_async.core._device import DeviceVector + + +class ExampleEnum(StrictEnum): + A = "Aaa" + B = "Bbb" + C = "Ccc" + + +class ExampleTable(Table): + bool: Array1D[np.bool_] + int: Array1D[np.int32] + float: Array1D[np.float64] + str: Sequence[str] + enum: Sequence[ExampleEnum] + + +def int_array_signal(dtype: type[DTypeScalar_co]) -> SignalRW[Array1D[DTypeScalar_co]]: + iinfo = np.iinfo(dtype) # type: ignore + value = np.array([iinfo.min, iinfo.max, 0, 1, 2, 3, 4], dtype=dtype) + return soft_signal_rw(Array1D[dtype], value) + + +def float_array_signal( + dtype: type[DTypeScalar_co], +) -> SignalRW[Array1D[DTypeScalar_co]]: + finfo = np.finfo(dtype) # type: ignore + value = np.array( + [ + finfo.min, + finfo.max, + finfo.smallest_normal, + finfo.smallest_subnormal, + 0, + 1.234, + 2.34e5, + 3.45e-6, + ], + dtype=dtype, + ) + return soft_signal_rw(Array1D[dtype], value) + + +class OneOfEverythingDevice(Device): + def __init__(self, name=""): + self.int = soft_signal_rw(int, 1) + self.float = soft_signal_rw(float, 1.234) + self.str = soft_signal_rw(str, "test_string") + self.bool = soft_signal_rw(bool, True) + self.enum = soft_signal_rw(ExampleEnum, ExampleEnum.B) + self.int8a = int_array_signal(np.int8) + self.uint8a = int_array_signal(np.uint8) + self.int16a = int_array_signal(np.int16) + self.uint16a = int_array_signal(np.uint16) + self.int32a = int_array_signal(np.int32) + self.uint32a = int_array_signal(np.uint32) + self.int64a = int_array_signal(np.int64) + self.uint64a = int_array_signal(np.uint64) + self.float32a = float_array_signal(np.float32) + self.float64a = float_array_signal(np.float64) + self.stra = soft_signal_rw(Sequence[str], ["one", "two", "three"]) + self.enuma = soft_signal_rw( + Sequence[ExampleEnum], [ExampleEnum.A, ExampleEnum.C] + ) + self.table = soft_signal_rw( + ExampleTable, + ExampleTable( + bool=np.array([False, False, True, True], np.bool_), + int=np.array([1, 8, -9, 32], np.int32), + float=np.array([1.8, 8.2, -6, 32.9887], np.float64), + str=["Hello", "World", "Foo", "Bar"], + enum=[ExampleEnum.A, ExampleEnum.B, ExampleEnum.A, ExampleEnum.C], + ), + ) + self.ndarray = soft_signal_rw(np.ndarray, np.array(([1, 2, 3], [4, 5, 6]))) + super().__init__(name=name) + + +async def _get_signal_values(child: Device) -> dict[SignalRW, Any]: + if isinstance(child, SignalRW): + return {child: await child.get_value()} + ret = {} + for _, c in child.children(): + ret.update(await _get_signal_values(c)) + return ret + + +class ParentOfEverythingDevice(Device): + def __init__(self, name=""): + self.child = OneOfEverythingDevice() + self.vector = DeviceVector( + {1: OneOfEverythingDevice(), 3: OneOfEverythingDevice()} + ) + self.sig_rw = soft_signal_rw(str, "Top level SignalRW") + self.sig_r, _ = soft_signal_r_and_setter(str, "Top level SignalR") + self._sig_rw = soft_signal_rw(str, "Top level private SignalRW") + super().__init__(name=name) + + async def get_signal_values(self): + return await _get_signal_values(self) diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py deleted file mode 100644 index d83b236b4e..0000000000 --- a/tests/core/test_device_save_loader.py +++ /dev/null @@ -1,322 +0,0 @@ -from os import path -from typing import Any -from unittest.mock import patch - -import numpy as np -import pytest -import yaml -from bluesky.run_engine import RunEngine - -from ophyd_async.core import ( - Array1D, - Device, - SignalRW, - StrictEnum, - Table, - all_at_once, - get_signal_values, - load_device, - load_from_yaml, - save_device, - save_to_yaml, - set_signal_values, - walk_rw_signals, -) -from ophyd_async.epics.core import epics_signal_r, epics_signal_rw -from ophyd_async.epics.testing import ExampleEnum, ExamplePvaDevice, ExampleTable - - -class EnumTest(StrictEnum): - VAL1 = "val1" - VAL2 = "val2" - - -class DummyChildDevice(Device): - def __init__(self) -> None: - self.str_sig = epics_signal_rw(str, "StrSignal") - super().__init__() - - -class DummyDeviceGroup(Device): - def __init__(self, name: str): - self.child1 = DummyChildDevice() - self.child2 = DummyChildDevice() - self.str_sig = epics_signal_rw(str, "ParentValue1") - self.parent_sig2 = epics_signal_r( - int, "ParentValue2" - ) # Ensure only RW are found - self.table_sig = epics_signal_rw(Table, "TableSignal") - self.array_sig = epics_signal_rw(Array1D[np.uint32], "ArraySignal") - self.enum_sig = epics_signal_rw(EnumTest, "EnumSignal") - super().__init__(name) - - -@pytest.fixture -async def device() -> DummyDeviceGroup: - device = DummyDeviceGroup("parent") - await device.connect(mock=True) - return device - - -@pytest.fixture -async def device_all_types() -> ExamplePvaDevice: - device = ExamplePvaDevice("parent") - await device.connect(mock=True) - return device - - -# Dummy function to check different phases save properly -def sort_signal_by_phase(values: dict[str, Any]) -> list[dict[str, Any]]: - phase_1 = {"child1.str_sig": values["child1.str_sig"]} - phase_2 = {"child2.str_sig": values["child2.str_sig"]} - phase_3 = { - key: value - for key, value in values.items() - if key not in phase_1 and key not in phase_2 - } - return [phase_1, phase_2, phase_3] - - -async def test_enum_yaml_formatting(tmp_path): - enums = [EnumTest.VAL1, EnumTest.VAL2] - save_to_yaml(enums, path.join(tmp_path, "test_file.yaml")) - with open(path.join(tmp_path, "test_file.yaml")) as file: - saved_enums = yaml.load(file, yaml.Loader) - # check that save/load reduces from enum to str - assert all(isinstance(value, str) for value in saved_enums) - # check values of enums same - assert saved_enums == enums - - -# Long string more than 40 chars -ls1 = "/here/is/a/long/string/of/more/than/40/chars" - - -async def test_save_device_all_types( - RE: RunEngine, device_all_types: ExamplePvaDevice, tmp_path -): - # Populate fake device with PV's... - await device_all_types.my_int.set(1) - await device_all_types.my_float.set(1.234) - await device_all_types.my_str.set("test_string") - await device_all_types.longstr.set(ls1) - await device_all_types.enum.set(ExampleEnum.B) - await device_all_types.enum2.set("Bbb") - for pv, dtype in { - device_all_types.int8a: np.int8, - device_all_types.uint8a: np.uint8, - device_all_types.int16a: np.int16, - device_all_types.uint16a: np.uint16, - device_all_types.int32a: np.int32, - device_all_types.uint32a: np.uint32, - device_all_types.int64a: np.int64, - device_all_types.uint64a: np.uint64, - }.items(): - await pv.set( - np.array( - [np.iinfo(dtype).min, np.iinfo(dtype).max, 0, 1, 2, 3, 4], dtype=dtype - ) - ) - for pv, dtype in { - device_all_types.float32a: np.float32, - device_all_types.float64a: np.float64, - }.items(): - finfo = np.finfo(dtype) - data = np.array( - [ - finfo.min, - finfo.max, - finfo.smallest_normal, - finfo.smallest_subnormal, - 0, - 1.234, - 2.34e5, - 3.45e-6, - ], - dtype=dtype, - ) - - await pv.set(data) - await device_all_types.stra.set( - ["one", "two", "three"], - ) - await device_all_types.table.set( - ExampleTable( - bool=np.array([False, False, True, True], np.bool_), - int=np.array([1, 8, -9, 32], np.int32), - float=np.array([1.8, 8.2, -6, 32.9887], np.float64), - str=["Hello", "World", "Foo", "Bar"], - enum=[ExampleEnum.A, ExampleEnum.B, ExampleEnum.A, ExampleEnum.C], - ) - ) - - # Create save plan from utility functions - def save_my_device(): - signalRWs = walk_rw_signals(device_all_types) - values = yield from get_signal_values(signalRWs) - - save_to_yaml([values], path.join(tmp_path, "test_file.yaml")) - - RE(save_my_device()) - - actual_file_path = path.join(tmp_path, "test_file.yaml") - with open(actual_file_path) as actual_file: - with open("tests/test_data/test_yaml_save.yml") as expected_file: - assert yaml.safe_load(actual_file) == yaml.safe_load(expected_file) - - -async def test_save_device(RE: RunEngine, device: DummyDeviceGroup, tmp_path): - # Populate fake device with PV's... - await device.child1.str_sig.set("test_string") - # Test tables PVs - table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "VAL2": np.array([1, 1, 1, 1, 1])} - array_pv = np.array([2, 2, 2, 2, 2]) - await device.array_sig.set(array_pv) - await device.table_sig.set(table_pv) - await device.enum_sig.set(EnumTest.VAL2) - - # Create save plan from utility functions - def save_my_device(): - signalRWs = walk_rw_signals(device) - - assert list(signalRWs.keys()) == [ - "child1.str_sig", - "child2.str_sig", - "str_sig", - "table_sig", - "array_sig", - "enum_sig", - ] - assert all(isinstance(signal, SignalRW) for signal in list(signalRWs.values())) - - values = yield from get_signal_values(signalRWs, ignore=["str_sig"]) - assert np.array_equal(values["array_sig"], array_pv) - assert values["enum_sig"] == "val2" - assert values["table_sig"] == Table(**table_pv) - assert values["str_sig"] is None - assert values["child1.str_sig"] == "test_string" - assert values["child2.str_sig"] == "" - - save_to_yaml([values], path.join(tmp_path, "test_file.yaml")) - - RE(save_my_device()) - - with open(path.join(tmp_path, "test_file.yaml")) as file: - yaml_content = yaml.load(file, yaml.Loader)[0] - assert yaml_content["child1.str_sig"] == "test_string" - assert yaml_content["child2.str_sig"] == "" - assert np.array_equal(yaml_content["table_sig"]["VAL1"], table_pv["VAL1"]) - assert np.array_equal(yaml_content["table_sig"]["VAL2"], table_pv["VAL2"]) - assert np.array_equal(yaml_content["array_sig"], array_pv) - assert yaml_content["enum_sig"] == "val2" - assert yaml_content["str_sig"] is None - - -async def test_yaml_formatting(RE: RunEngine, device: DummyDeviceGroup, tmp_path): - file_path = path.join(tmp_path, "test_file.yaml") - await device.child1.str_sig.set("test_string") - table = {"VAL1": np.array([1, 2, 3, 4, 5]), "VAL2": np.array([6, 7, 8, 9, 10])} - await device.array_sig.set(np.array([11, 12, 13, 14, 15])) - await device.table_sig.set(table) - await device.enum_sig.set(EnumTest.VAL2) - RE(save_device(device, file_path, sorter=sort_signal_by_phase)) - - with open(file_path) as file: - expected = """\ -- child1.str_sig: test_string -- child2.str_sig: '' -- array_sig: [11, 12, 13, 14, 15] - enum_sig: val2 - str_sig: '' - table_sig: - VAL1: [1, 2, 3, 4, 5] - VAL2: [6, 7, 8, 9, 10] -""" - # assert False, file.read() - assert file.read() == expected - - -async def test_load_from_yaml(RE: RunEngine, device: DummyDeviceGroup, tmp_path): - file_path = path.join(tmp_path, "test_file.yaml") - - array = np.array([1, 1, 1, 1, 1]) - table = {"VAL1": np.array([1, 2, 3, 4, 5]), "VAL2": np.array([6, 7, 8, 9, 10])} - await device.child1.str_sig.set("initial_string") - await device.array_sig.set(array) - await device.str_sig.set(None) - await device.enum_sig.set(EnumTest.VAL2) - await device.table_sig.set(table) - RE(save_device(device, file_path, sorter=sort_signal_by_phase)) - - values = load_from_yaml(file_path) - assert values[0]["child1.str_sig"] == "initial_string" - assert values[1]["child2.str_sig"] == "" - assert values[2]["str_sig"] == "" - assert values[2]["enum_sig"] == "val2" - assert np.array_equal(values[2]["array_sig"], array) - assert np.array_equal(values[2]["table_sig"]["VAL1"], table["VAL1"]) - assert np.array_equal(values[2]["table_sig"]["VAL2"], table["VAL2"]) - - -async def test_set_signal_values_restores_value( - RE: RunEngine, device: DummyDeviceGroup, tmp_path -): - file_path = path.join(tmp_path, "test_file.yaml") - - await device.str_sig.set("initial_string") - await device.array_sig.set(np.array([1, 1, 1, 1, 1])) - RE(save_device(device, file_path, sorter=sort_signal_by_phase)) - - await device.str_sig.set("changed_string") - await device.array_sig.set(np.array([2, 2, 2, 2, 2])) - string_value = await device.str_sig.get_value() - array_value = await device.array_sig.get_value() - assert string_value == "changed_string" - assert np.array_equal(array_value, np.array([2, 2, 2, 2, 2])) - - values = load_from_yaml(file_path) - signals_to_set = walk_rw_signals(device) - - RE(set_signal_values(signals_to_set, values)) - - string_value = await device.str_sig.get_value() - array_value = await device.array_sig.get_value() - assert string_value == "initial_string" - assert np.array_equal(array_value, np.array([1, 1, 1, 1, 1])) - - -@patch("ophyd_async.core._device_save_loader.load_from_yaml") -@patch("ophyd_async.core._device_save_loader.walk_rw_signals") -@patch("ophyd_async.core._device_save_loader.set_signal_values") -async def test_load_device( - mock_set_signal_values, - mock_walk_rw_signals, - mock_load_from_yaml, - device: DummyDeviceGroup, -): - RE = RunEngine() - RE(load_device(device, "path")) - mock_load_from_yaml.assert_called_once() - mock_walk_rw_signals.assert_called_once() - mock_set_signal_values.assert_called_once() - - -async def test_set_signal_values_skips_ignored_values(device: DummyDeviceGroup): - RE = RunEngine() - array = np.array([1, 1, 1, 1, 1]) - - await device.child1.str_sig.set("initial_string") - await device.array_sig.set(array) - await device.str_sig.set(None) - - signals_of_device = walk_rw_signals(device) - values_to_set = [{"child1.str_sig": None, "array_sig": np.array([2, 3, 4])}] - - RE(set_signal_values(signals_of_device, values_to_set)) - - assert np.all(await device.array_sig.get_value() == np.array([2, 3, 4])) - assert await device.child1.str_sig.get_value() == "initial_string" - - -def test_all_at_once_sorter(): - assert all_at_once({"child1.str_sig": 0}) == [{"child1.str_sig": 0}] diff --git a/tests/core/test_table.py b/tests/core/test_table.py index fe48629dbe..4818f1976f 100644 --- a/tests/core/test_table.py +++ b/tests/core/test_table.py @@ -55,3 +55,4 @@ def test_table_coerces(kwargs): t = MyTable(**kwargs) for k, v in t: assert v == pytest.approx(kwargs[k]) + assert t == pytest.approx(t) diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index 992a822cd9..eb3b0cca8f 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -25,9 +25,8 @@ SubsetEnum, T, Table, - load_from_yaml, + YamlSettingsProvider, observe_value, - save_to_yaml, ) from ophyd_async.epics.core import ( EpicsDevice, @@ -324,11 +323,10 @@ async def assert_backend_get_put_monitor( initial_value, datatype=None, ) - - yaml_path = tmp_path / "test.yaml" - save_to_yaml([{"test": put_value}], yaml_path) - loaded = load_from_yaml(yaml_path) - assert np.all(loaded[0]["test"] == put_value) + provider = YamlSettingsProvider(tmp_path) + await provider.store("test", {"test": put_value}) + loaded = await provider.retrieve("test") + assert np.all(loaded["test"] == put_value) @PARAMETERISE_PROTOCOLS @@ -504,11 +502,6 @@ async def test_bool_conversion_of_enum( bool, ) - yaml_path = tmp_path / "test.yaml" - save_to_yaml([{"test": False}], yaml_path) - loaded = load_from_yaml(yaml_path) - assert np.all(loaded[0]["test"] is False) - @PARAMETERISE_PROTOCOLS async def test_error_raised_on_disconnected_PV(ioc, protocol) -> None: diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index fc662fde5c..9a0855ad33 100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -2,15 +2,14 @@ import yaml from bluesky import RunEngine -from ophyd_async.core import DeviceCollector, load_device, save_device +from ophyd_async.core import DeviceCollector, YamlSettingsProvider from ophyd_async.epics.core import epics_signal_rw from ophyd_async.fastcs.core import fastcs_connector -from ophyd_async.fastcs.panda import ( - CommonPandaBlocks, - DataBlock, - SeqTable, - TimeUnits, - phase_sorter, +from ophyd_async.fastcs.panda import CommonPandaBlocks, DataBlock, SeqTable, TimeUnits +from ophyd_async.plan_stubs import ( + apply_panda_settings, + retrieve_settings, + store_settings, ) @@ -30,8 +29,8 @@ def __init__(self, uri: str, name: str = ""): async def test_save_load_panda(tmp_path, RE: RunEngine): mock_panda1 = await get_mock_panda() await mock_panda1.seq[1].table.set(SeqTable.row(repeats=1)) - - RE(save_device(mock_panda1, str(tmp_path / "panda.yaml"), sorter=phase_sorter)) + provider = YamlSettingsProvider(tmp_path) + RE(store_settings(provider, "panda", mock_panda1)) def check_equal_with_seq_tables(actual, expected): assert actual.model_fields_set == expected.model_fields_set @@ -43,7 +42,12 @@ def check_equal_with_seq_tables(actual, expected): check_equal_with_seq_tables( (await mock_panda2.seq[1].table.get_value()), SeqTable() ) - RE(load_device(mock_panda2, str(tmp_path / "panda.yaml"))) + + def load_panda(): + settings = yield from retrieve_settings(provider, "panda", mock_panda2) + yield from apply_panda_settings(settings) + + RE(load_panda()) check_equal_with_seq_tables( await mock_panda2.seq[1].table.get_value(), @@ -57,12 +61,10 @@ def check_equal_with_seq_tables(actual, expected): # Parse the YAML content parsed_yaml = yaml.safe_load(yaml_content) - assert parsed_yaml[0] == { + assert parsed_yaml == { "phase_1_signal_units": 0, "seq.1.prescale_units": TimeUnits("min"), "seq.2.prescale_units": TimeUnits("min"), - } - assert parsed_yaml[1] == { "data.capture": False, "data.capture_mode": "FIRST_N", "data.create_directory": 0, diff --git a/tests/plan_stubs/test_settings.py b/tests/plan_stubs/test_settings.py new file mode 100644 index 0000000000..4f8c7caf4a --- /dev/null +++ b/tests/plan_stubs/test_settings.py @@ -0,0 +1,96 @@ +from pathlib import Path +from unittest.mock import call + +import bluesky.plan_stubs as bps +import pytest +import yaml + +from ophyd_async.core import Settings, YamlSettingsProvider +from ophyd_async.plan_stubs import ( + apply_settings, + apply_settings_if_different, + get_current_settings, + retrieve_settings, + store_settings, +) +from ophyd_async.sim.testing import ExampleTable, ParentOfEverythingDevice +from ophyd_async.testing import get_mock + +TEST_DATA = Path(__file__).absolute().parent.parent / "test_data" + + +@pytest.fixture +async def parent_device() -> ParentOfEverythingDevice: + device = ParentOfEverythingDevice("parent") + await device.connect(mock=True) + return device + + +async def test_get_current_settings(RE, parent_device: ParentOfEverythingDevice): + expected_values = await parent_device.get_signal_values() + + def my_plan(): + current_settings = yield from get_current_settings(parent_device) + assert dict(current_settings) == expected_values + + RE(my_plan()) + + +async def test_store_settings(RE, parent_device: ParentOfEverythingDevice, tmp_path): + provider = YamlSettingsProvider(tmp_path) + + def my_plan(): + yield from store_settings(provider, "test_file", parent_device) + with open(tmp_path / "test_file.yaml") as actual_file: + with open(TEST_DATA / "test_yaml_save.yaml") as expected_file: + assert yaml.safe_load(actual_file) == yaml.safe_load(expected_file) + + RE(my_plan()) + + +async def test_retrieve_and_apply_settings(RE, parent_device: ParentOfEverythingDevice): + provider = YamlSettingsProvider(TEST_DATA) + expected_values = await parent_device.get_signal_values() + serialized_values = {} + # Override the table to be the serialized version so it compares equal + for sig, value in expected_values.items(): + if isinstance(value, ExampleTable): + serialized_values[sig] = { + k: pytest.approx(v) for k, v in value.model_dump().items() + } + else: + serialized_values[sig] = pytest.approx(value) + + def my_plan(): + m = get_mock(parent_device) + assert not m.mock_calls + settings = yield from retrieve_settings( + provider, "test_yaml_save", parent_device + ) + assert dict(settings) == serialized_values + assert not m.mock_calls + yield from apply_settings(settings) + assert len(m.mock_calls) == 59 + m.reset_mock() + assert not m.mock_calls + yield from apply_settings_if_different(settings, apply_settings) + assert not m.mock_calls + yield from bps.abs_set(parent_device.sig_rw, "foo", wait=True) + assert m.mock_calls == [call.sig_rw.put("foo", wait=True)] + m.reset_mock() + yield from apply_settings_if_different(settings, apply_settings) + assert m.mock_calls == [call.sig_rw.put("Top level SignalRW", wait=True)] + + RE(my_plan()) + + +async def test_ignored_settings(RE, parent_device: ParentOfEverythingDevice): + def my_plan(): + m = get_mock(parent_device) + settings = Settings( + parent_device, {parent_device.sig_rw: "foo", parent_device._sig_rw: None} + ) + yield from apply_settings(settings) + assert m.mock_calls == [call.sig_rw.put("foo", wait=True)] + + RE(my_plan()) diff --git a/tests/test_data/test_yaml_save.yaml b/tests/test_data/test_yaml_save.yaml new file mode 100644 index 0000000000..974042cba1 --- /dev/null +++ b/tests/test_data/test_yaml_save.yaml @@ -0,0 +1,119 @@ +_sig_rw: Top level private SignalRW +child.bool: true +child.enum: Bbb +child.enuma: +- Aaa +- Ccc +child.float: 1.234 +child.float32a: [-3.4028234663852886e+38, 3.4028234663852886e+38, 1.1754943508222875e-38, + 1.401298464324817e-45, 0.0, 1.2339999675750732, 234000.0, 3.4499998946557753e-06] +child.float64a: [-1.7976931348623157e+308, 1.7976931348623157e+308, 2.2250738585072014e-308, + 5.0e-324, 0.0, 1.234, 234000.0, 3.45e-06] +child.int: 1 +child.int16a: [-32768, 32767, 0, 1, 2, 3, 4] +child.int32a: [-2147483648, 2147483647, 0, 1, 2, 3, 4] +child.int64a: [-9223372036854775808, 9223372036854775807, 0, 1, 2, 3, 4] +child.int8a: [-128, 127, 0, 1, 2, 3, 4] +child.ndarray: [[1, 2, 3], [4, 5, 6]] +child.str: test_string +child.stra: +- one +- two +- three +child.table: + bool: [false, false, true, true] + enum: + - Aaa + - Bbb + - Aaa + - Ccc + float: [1.8, 8.2, -6.0, 32.9887] + int: [1, 8, -9, 32] + str: + - Hello + - World + - Foo + - Bar +child.uint16a: [0, 65535, 0, 1, 2, 3, 4] +child.uint32a: [0, 4294967295, 0, 1, 2, 3, 4] +child.uint64a: [0, 18446744073709551615, 0, 1, 2, 3, 4] +child.uint8a: [0, 255, 0, 1, 2, 3, 4] +sig_rw: Top level SignalRW +vector.1.bool: true +vector.1.enum: Bbb +vector.1.enuma: +- Aaa +- Ccc +vector.1.float: 1.234 +vector.1.float32a: [-3.4028234663852886e+38, 3.4028234663852886e+38, 1.1754943508222875e-38, + 1.401298464324817e-45, 0.0, 1.2339999675750732, 234000.0, 3.4499998946557753e-06] +vector.1.float64a: [-1.7976931348623157e+308, 1.7976931348623157e+308, 2.2250738585072014e-308, + 5.0e-324, 0.0, 1.234, 234000.0, 3.45e-06] +vector.1.int: 1 +vector.1.int16a: [-32768, 32767, 0, 1, 2, 3, 4] +vector.1.int32a: [-2147483648, 2147483647, 0, 1, 2, 3, 4] +vector.1.int64a: [-9223372036854775808, 9223372036854775807, 0, 1, 2, 3, 4] +vector.1.int8a: [-128, 127, 0, 1, 2, 3, 4] +vector.1.ndarray: [[1, 2, 3], [4, 5, 6]] +vector.1.str: test_string +vector.1.stra: +- one +- two +- three +vector.1.table: + bool: [false, false, true, true] + enum: + - Aaa + - Bbb + - Aaa + - Ccc + float: [1.8, 8.2, -6.0, 32.9887] + int: [1, 8, -9, 32] + str: + - Hello + - World + - Foo + - Bar +vector.1.uint16a: [0, 65535, 0, 1, 2, 3, 4] +vector.1.uint32a: [0, 4294967295, 0, 1, 2, 3, 4] +vector.1.uint64a: [0, 18446744073709551615, 0, 1, 2, 3, 4] +vector.1.uint8a: [0, 255, 0, 1, 2, 3, 4] +vector.3.bool: true +vector.3.enum: Bbb +vector.3.enuma: +- Aaa +- Ccc +vector.3.float: 1.234 +vector.3.float32a: [-3.4028234663852886e+38, 3.4028234663852886e+38, 1.1754943508222875e-38, + 1.401298464324817e-45, 0.0, 1.2339999675750732, 234000.0, 3.4499998946557753e-06] +vector.3.float64a: [-1.7976931348623157e+308, 1.7976931348623157e+308, 2.2250738585072014e-308, + 5.0e-324, 0.0, 1.234, 234000.0, 3.45e-06] +vector.3.int: 1 +vector.3.int16a: [-32768, 32767, 0, 1, 2, 3, 4] +vector.3.int32a: [-2147483648, 2147483647, 0, 1, 2, 3, 4] +vector.3.int64a: [-9223372036854775808, 9223372036854775807, 0, 1, 2, 3, 4] +vector.3.int8a: [-128, 127, 0, 1, 2, 3, 4] +vector.3.ndarray: [[1, 2, 3], [4, 5, 6]] +vector.3.str: test_string +vector.3.stra: +- one +- two +- three +vector.3.table: + bool: [false, false, true, true] + enum: + - Aaa + - Bbb + - Aaa + - Ccc + float: [1.8, 8.2, -6.0, 32.9887] + int: [1, 8, -9, 32] + str: + - Hello + - World + - Foo + - Bar +vector.3.uint16a: [0, 65535, 0, 1, 2, 3, 4] +vector.3.uint32a: [0, 4294967295, 0, 1, 2, 3, 4] +vector.3.uint64a: [0, 18446744073709551615, 0, 1, 2, 3, 4] +vector.3.uint8a: [0, 255, 0, 1, 2, 3, 4] diff --git a/tests/test_data/test_yaml_save.yml b/tests/test_data/test_yaml_save.yml deleted file mode 100644 index 349cbca975..0000000000 --- a/tests/test_data/test_yaml_save.yml +++ /dev/null @@ -1,42 +0,0 @@ -- bool_unnamed: false - enum: Bbb - enum2: Bbb - float32a: [-3.4028234663852886e+38, 3.4028234663852886e+38, 1.1754943508222875e-38, - 1.401298464324817e-45, 0.0, 1.2339999675750732, 234000.0, 3.4499998946557753e-06] - float64a: [-1.7976931348623157e+308, 1.7976931348623157e+308, 2.2250738585072014e-308, - 5.0e-324, 0.0, 1.234, 234000.0, 3.45e-06] - int16a: [-32768, 32767, 0, 1, 2, 3, 4] - int32a: [-2147483648, 2147483647, 0, 1, 2, 3, 4] - int64a: [-9223372036854775808, 9223372036854775807, 0, 1, 2, 3, 4] - int8a: [-128, 127, 0, 1, 2, 3, 4] - lessint: 0 - longstr: '/here/is/a/long/string/of/more/than/40/chars' - longstr2: '' - my_bool: false - my_float: 1.234 - my_int: 1 - my_str: test_string - ntndarray_data: [] - partialint: 0 - stra: - - one - - two - - three - table: - bool: [false, false, true, true] - enum: - - Aaa - - Bbb - - Aaa - - Ccc - float: [1.8, 8.2, -6.0, 32.9887] - int: [1, 8, -9, 32] - str: - - Hello - - World - - Foo - - Bar - uint16a: [0, 65535, 0, 1, 2, 3, 4] - uint32a: [0, 4294967295, 0, 1, 2, 3, 4] - uint64a: [0, 18446744073709551615, 0, 1, 2, 3, 4] - uint8a: [0, 255, 0, 1, 2, 3, 4]