Skip to content

Commit

Permalink
update to bps.collect_while_completeing
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Dec 4, 2024
1 parent e4b18e5 commit c90e403
Show file tree
Hide file tree
Showing 83 changed files with 1,551 additions and 1,217 deletions.
2 changes: 2 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ Fixes #ISSUE

### Checks for reviewer
- [ ] Would the PR title make sense to a user on a set of release notes
- [ ] If the change requires a bump in an IOC version, is that specified in a `##Changes` section in the body of the PR
- [ ] If the change requires a bump in the PandABlocks-ioc version, is the `ophyd_async.fastcs.panda._hdf_panda.MINIMUM_PANDA_IOC` variable updated to match
1 change: 1 addition & 0 deletions .github/workflows/_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ jobs:
with:
python-version: ${{ inputs.python-version }}
pip-install: ".[dev]"

- name: Run tests
run: tox -e tests

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
strategy:
matrix:
runs-on: ["ubuntu-latest", "windows-latest"] # can add macos-latest
python-version: ["3.10","3.11"] # 3.12 should be added when p4p is updated
python-version: ["3.10", "3.11"] # 3.12 should be added when p4p is updated
include:
# Include one that runs in the dev environment
- runs-on: "ubuntu-latest"
Expand Down
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,11 @@ repos:
entry: ruff format --force-exclude
types: [python]
require_serial: true

- id: import-contracts
name: Ensure import directionality
pass_filenames: false
language: system
entry: lint-imports
types: [python]
require_serial: false
51 changes: 51 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dev = [
"inflection",
"ipython",
"ipywidgets",
"import-linter",
"matplotlib",
"myst-parser",
"numpydoc",
Expand Down Expand Up @@ -164,3 +165,53 @@ lint.preview = true # so that preview mode PLC2701 is enabled
# See https://github.com/DiamondLightSource/python-copier-template/issues/154
# Remove this line to forbid private member access in tests
"tests/**/*" = ["SLF001"]


[tool.importlinter]
root_package = "ophyd_async"

[[tool.importlinter.contracts]]
name = "Core is independent"
type = "independence"
modules = "ophyd_async.core"

[[tool.importlinter.contracts]]
name = "Epics depends only on core"
type = "forbidden"
source_modules = "ophyd_async.epics"
forbidden_modules = [
"ophyd_async.fastcs",
"ophyd_async.plan_stubs",
"ophyd_async.sim",
"ophyd_async.tango",
]

[[tool.importlinter.contracts]]
name = "tango depends only on core"
type = "forbidden"
source_modules = "ophyd_async.tango"
forbidden_modules = [
"ophyd_async.epics",
"ophyd_async.fastcs",
"ophyd_async.plan_stubs",
"ophyd_async.sim",
]


[[tool.importlinter.contracts]]
name = "sim depends only on core"
type = "forbidden"
source_modules = "ophyd_async.sim"
forbidden_modules = [
"ophyd_async.epics",
"ophyd_async.fastcs",
"ophyd_async.plan_stubs",
"ophyd_async.tango",
]


[[tool.importlinter.contracts]]
name = "Fastcs depends only on core, epics, tango"
type = "forbidden"
source_modules = "ophyd_async.fastcs"
forbidden_modules = ["ophyd_async.plan_stubs", "ophyd_async.sim"]
18 changes: 9 additions & 9 deletions src/ophyd_async/core/_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ class DetectorTrigger(StrictEnum):
"""Type of mechanism for triggering a detector to take frames"""

#: Detector generates internal trigger for given rate
internal = "internal"
INTERNAL = "internal"
#: Expect a series of arbitrary length trigger signals
edge_trigger = "edge_trigger"
EDGE_TRIGGER = "edge_trigger"
#: Expect a series of constant width external gate signals
constant_gate = "constant_gate"
CONSTANT_GATE = "constant_gate"
#: Expect a series of variable width external gate signals
variable_gate = "variable_gate"
VARIABLE_GATE = "variable_gate"


class TriggerInfo(BaseModel):
Expand All @@ -53,7 +53,7 @@ class TriggerInfo(BaseModel):
#: - 3 times for final flat field images
number_of_triggers: NonNegativeInt | list[NonNegativeInt]
#: Sort of triggers that will be sent
trigger: DetectorTrigger = Field(default=DetectorTrigger.internal)
trigger: DetectorTrigger = Field(default=DetectorTrigger.INTERNAL)
#: What is the minimum deadtime between triggers
deadtime: float | None = Field(default=None, ge=0)
#: What is the maximum high time of the triggers
Expand Down Expand Up @@ -265,14 +265,14 @@ async def trigger(self) -> None:
await self.prepare(
TriggerInfo(
number_of_triggers=1,
trigger=DetectorTrigger.internal,
trigger=DetectorTrigger.INTERNAL,
deadtime=None,
livetime=None,
frame_timeout=None,
)
)
assert self._trigger_info
assert self._trigger_info.trigger is DetectorTrigger.internal
assert self._trigger_info.trigger is DetectorTrigger.INTERNAL
# Arm the detector and wait for it to finish.
indices_written = await self.writer.get_indices_written()
await self.controller.arm()
Expand Down Expand Up @@ -303,7 +303,7 @@ async def prepare(self, value: TriggerInfo) -> None:
Args:
value: TriggerInfo describing how to trigger the detector
"""
if value.trigger != DetectorTrigger.internal:
if value.trigger != DetectorTrigger.INTERNAL:
assert (
value.deadtime
), "Deadtime must be supplied when in externally triggered mode"
Expand All @@ -323,7 +323,7 @@ async def prepare(self, value: TriggerInfo) -> None:
self._describe, _ = await asyncio.gather(
self.writer.open(value.multiplier), self.controller.prepare(value)
)
if value.trigger != DetectorTrigger.internal:
if value.trigger != DetectorTrigger.INTERNAL:
await self.controller.arm()
self._fly_start = time.monotonic()

Expand Down
32 changes: 26 additions & 6 deletions src/ophyd_async/core/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,16 @@ class Device(HasName, Connectable):
_connect_task: asyncio.Task | None = None
# The mock if we have connected in mock mode
_mock: LazyMock | None = None
# The separator to use when making child names
_child_name_separator: str = "-"

def __init__(
self, name: str = "", connector: DeviceConnector | None = None
) -> None:
self._connector = connector or DeviceConnector()
self._connector.create_children_from_annotations(self)
self.set_name(name)
if name:
self.set_name(name)

@property
def name(self) -> str:
Expand All @@ -97,21 +100,30 @@ def log(self) -> LoggerAdapter:
getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name}
)

def set_name(self, name: str):
def set_name(self, name: str, *, child_name_separator: str | None = None) -> None:
"""Set ``self.name=name`` and each ``self.child.name=name+"-child"``.
Parameters
----------
name:
New name to set
child_name_separator:
Use this as a separator instead of "-". Use "_" instead to make the same
names as the equivalent ophyd sync device.
"""
self._name = name
if child_name_separator:
self._child_name_separator = child_name_separator
# Ensure logger is recreated after a name change
if "log" in self.__dict__:
del self.log
for child_name, child in self.children():
child_name = f"{self.name}-{child_name.strip('_')}" if self.name else ""
child.set_name(child_name)
for attr_name, child in self.children():
child_name = (
f"{self.name}{self._child_name_separator}{attr_name}"
if self.name
else ""
)
child.set_name(child_name, child_name_separator=self._child_name_separator)

def __setattr__(self, name: str, value: Any) -> None:
# Bear in mind that this function is called *a lot*, so
Expand Down Expand Up @@ -147,6 +159,10 @@ async def connect(
timeout:
Time to wait before failing with a TimeoutError.
"""
assert hasattr(self, "_connector"), (
f"{self}: doesn't have attribute `_connector`,"
" did you call `super().__init__` in your `__init__` method?"
)
if mock:
# Always connect in mock mode serially
if isinstance(mock, LazyMock):
Expand Down Expand Up @@ -247,6 +263,8 @@ class DeviceCollector:
set_name:
If True, call ``device.set_name(variable_name)`` on all collected
Devices
child_name_separator:
Use this as a separator if we call ``set_name``.
connect:
If True, call ``device.connect(mock)`` in parallel on all
collected Devices
Expand All @@ -271,11 +289,13 @@ class DeviceCollector:
def __init__(
self,
set_name=True,
child_name_separator: str = "-",
connect=True,
mock=False,
timeout: float = 10.0,
):
self._set_name = set_name
self._child_name_separator = child_name_separator
self._connect = connect
self._mock = mock
self._timeout = timeout
Expand Down Expand Up @@ -311,7 +331,7 @@ async def _on_exit(self) -> None:
for name, obj in self._objects_on_exit.items():
if name not in self._names_on_enter and isinstance(obj, Device):
if self._set_name and not obj.name:
obj.set_name(name)
obj.set_name(name, child_name_separator=self._child_name_separator)
if self._connect:
connect_coroutines[name] = obj.connect(
self._mock, timeout=self._timeout
Expand Down
6 changes: 3 additions & 3 deletions src/ophyd_async/core/_mock_signal_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Awaitable, Callable, Iterable
from contextlib import asynccontextmanager, contextmanager
from contextlib import contextmanager
from unittest.mock import AsyncMock, Mock

from ._device import Device
Expand Down Expand Up @@ -40,8 +40,8 @@ def set_mock_put_proceeds(signal: Signal, proceeds: bool):
backend.put_proceeds.clear()


@asynccontextmanager
async def mock_puts_blocked(*signals: Signal):
@contextmanager
def mock_puts_blocked(*signals: Signal):
for signal in signals:
set_mock_put_proceeds(signal, False)
yield
Expand Down
49 changes: 36 additions & 13 deletions src/ophyd_async/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import functools
import time
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from typing import Any, Generic, cast

Expand Down Expand Up @@ -122,7 +123,7 @@ async def get_value(self) -> SignalDatatypeT:

def _callback(self, reading: Reading[SignalDatatypeT]):
self._signal.log.debug(
f"Updated subscription: reading of source {self._signal.source} changed"
f"Updated subscription: reading of source {self._signal.source} changed "
f"from {self._reading} to {reading}"
)
self._reading = reading
Expand Down Expand Up @@ -425,6 +426,7 @@ async def observe_value(
signal: SignalR[SignalDatatypeT],
timeout: float | None = None,
done_status: Status | None = None,
done_timeout: float | None = None,
) -> AsyncGenerator[SignalDatatypeT, None]:
"""Subscribe to the value of a signal so it can be iterated from.
Expand All @@ -439,25 +441,44 @@ async def observe_value(
done_status:
If this status is complete, stop observing and make the iterator return.
If it raises an exception then this exception will be raised by the iterator.
done_timeout:
If given, the maximum time to watch a signal, in seconds. If the loop is still
being watched after this length, raise asyncio.TimeoutError. This should be used
instead of on an 'asyncio.wait_for' timeout
Notes
-----
Due to a rare condition with busy signals, it is not recommended to use this
function with asyncio.timeout, including in an 'asyncio.wait_for' loop. Instead,
this timeout should be given to the done_timeout parameter.
Example usage::
async for value in observe_value(sig):
do_something_with(value)
"""

async for _, value in observe_signals_value(
signal, timeout=timeout, done_status=done_status
signal,
timeout=timeout,
done_status=done_status,
done_timeout=done_timeout,
):
yield value


def _get_iteration_timeout(
timeout: float | None, overall_deadline: float | None
) -> float | None:
overall_deadline = overall_deadline - time.monotonic() if overall_deadline else None
return min([x for x in [overall_deadline, timeout] if x is not None], default=None)


async def observe_signals_value(
*signals: SignalR[SignalDatatypeT],
timeout: float | None = None,
done_status: Status | None = None,
done_timeout: float | None = None,
) -> AsyncGenerator[tuple[SignalR[SignalDatatypeT], SignalDatatypeT], None]:
"""Subscribe to the value of a signal so it can be iterated from.
Expand All @@ -472,6 +493,10 @@ async def observe_signals_value(
done_status:
If this status is complete, stop observing and make the iterator return.
If it raises an exception then this exception will be raised by the iterator.
done_timeout:
If given, the maximum time to watch a signal, in seconds. If the loop is still
being watched after this length, raise asyncio.TimeoutError. This should be used
instead of on an 'asyncio.wait_for' timeout
Notes
-----
Expand All @@ -486,12 +511,6 @@ async def observe_signals_value(
q: asyncio.Queue[tuple[SignalR[SignalDatatypeT], SignalDatatypeT] | Status] = (
asyncio.Queue()
)
if timeout is None:
get_value = q.get
else:

async def get_value():
return await asyncio.wait_for(q.get(), timeout)

cbs: dict[SignalR, Callback] = {}
for signal in signals:
Expand All @@ -504,13 +523,17 @@ def queue_value(value: SignalDatatypeT, signal=signal):

if done_status is not None:
done_status.add_callback(q.put_nowait)

overall_deadline = time.monotonic() + done_timeout if done_timeout else None
try:
while True:
# yield here in case something else is filling the queue
# like in test_observe_value_times_out_with_no_external_task()
await asyncio.sleep(0)
item = await get_value()
if overall_deadline and time.monotonic() >= overall_deadline:
raise asyncio.TimeoutError(
f"observe_value was still observing signals "
f"{[signal.source for signal in signals]} after "
f"timeout {done_timeout}s"
)
iteration_timeout = _get_iteration_timeout(timeout, overall_deadline)
item = await asyncio.wait_for(q.get(), iteration_timeout)
if done_status and item is done_status:
if exc := done_status.exception():
raise exc
Expand Down
Loading

0 comments on commit c90e403

Please sign in to comment.