diff --git a/pyproject.toml b/pyproject.toml index 6d0e0324bd..fb0c6eb53d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ dev = [ "pytest-rerunfailures", "pytest-timeout", "ruff", + "scanspec==0.7.2", "sphinx<7.4.0", # https://github.com/bluesky/ophyd-async/issues/459 "sphinx-autobuild", "autodoc-pydantic", diff --git a/src/ophyd_async/epics/motor.py b/src/ophyd_async/epics/motor.py index 645f6e224b..13191df88f 100644 --- a/src/ophyd_async/epics/motor.py +++ b/src/ophyd_async/epics/motor.py @@ -79,6 +79,7 @@ def __init__(self, prefix: str, name="") -> None: self.high_limit_travel = epics_signal_rw(float, prefix + ".HLM") self.motor_stop = epics_signal_x(prefix + ".STOP") + self.encoder_res = epics_signal_rw(float, prefix + ".ERES") # Whether set() should complete successfully or not self._set_success = True diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py index 9d1c1d429f..5bad33640c 100644 --- a/src/ophyd_async/fastcs/panda/__init__.py +++ b/src/ophyd_async/fastcs/panda/__init__.py @@ -22,6 +22,8 @@ ) from ._trigger import ( PcompInfo, + ScanSpecInfo, + ScanSpecSeqTableTriggerLogic, SeqTableInfo, StaticPcompTriggerLogic, StaticSeqTableTriggerLogic, @@ -37,6 +39,8 @@ "PcompBlock", "PcompDirectionOptions", "PulseBlock", + "ScanSpecInfo", + "ScanSpecSeqTableTriggerLogic", "SeqBlock", "TimeUnits", "HDFPanda", diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index ec2c1a5b8b..bda51905e8 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -79,6 +79,8 @@ def seq_table_from_rows(*rows: SeqTableRow): """ Constructs a sequence table from a series of rows. """ + if type(rows[0]) is tuple: + rows = rows[0] return seq_table_from_arrays( repeats=np.array([row.repeats for row in rows], dtype=np.uint16), trigger=[row.trigger for row in rows], diff --git a/src/ophyd_async/fastcs/panda/_trigger.py b/src/ophyd_async/fastcs/panda/_trigger.py index c79988a381..2429f3fb9b 100644 --- a/src/ophyd_async/fastcs/panda/_trigger.py +++ b/src/ophyd_async/fastcs/panda/_trigger.py @@ -1,12 +1,20 @@ import asyncio from typing import Optional +import numpy as np from pydantic import BaseModel, Field +from scanspec.specs import Frames, Path, Spec -from ophyd_async.core import TriggerLogic, wait_for_value +from ophyd_async.core import ( + AsyncStatus, + TriggerLogic, + WatchableAsyncStatus, + wait_for_value, +) +from ophyd_async.epics import motor from ._block import PcompBlock, PcompDirectionOptions, SeqBlock, TimeUnits -from ._table import SeqTable +from ._table import SeqTable, SeqTableRow, SeqTrigger, seq_table_from_rows class SeqTableInfo(BaseModel): @@ -15,6 +23,11 @@ class SeqTableInfo(BaseModel): prescale_as_us: float = Field(default=1, ge=0) # microseconds +class ScanSpecInfo(BaseModel): + spec: Spec[motor.Motor] = Field() + deadtime: float = Field() + + class StaticSeqTableTriggerLogic(TriggerLogic[SeqTableInfo]): def __init__(self, seq: SeqBlock) -> None: self.seq = seq @@ -42,6 +55,96 @@ async def stop(self): await wait_for_value(self.seq.active, False, timeout=1) +class ScanSpecSeqTableTriggerLogic(TriggerLogic[ScanSpecInfo]): + def __init__(self, seq: SeqBlock, name="") -> None: + self.seq = seq + self.name = name + + @AsyncStatus.wrap + async def prepare(self, value: ScanSpecInfo): + await asyncio.gather( + self.seq.prescale_units.set(TimeUnits.us), + self.seq.enable.set("ZERO"), + ) + path = Path(value.spec.calculate()) + chunk = path.consume() + gaps = self._calculate_gaps(chunk) + if gaps[0] == 0: + gaps = np.delete(gaps, 0) + scan_size = len(chunk) + + fast_axis = chunk.axes()[len(chunk.axes()) - 2] + gaps = np.append(gaps, scan_size) + start = 0 + rows: SeqTableRow = () + for gap in gaps: + # Wait for GPIO to go low + rows += (SeqTableRow(trigger=SeqTrigger.BITA_0),) + # Wait for GPIO to go high + rows += ( + SeqTableRow( + trigger=SeqTrigger.BITA_1, + ), + ) + # Wait for position + if chunk.midpoints[fast_axis][gap - 1] > chunk.midpoints[fast_axis][start]: + trig = SeqTrigger.POSA_GT + dir = False + + else: + trig = SeqTrigger.POSA_LT + dir = True + rows += ( + SeqTableRow( + trigger=trig, + position=chunk.lower[fast_axis][start] + / await fast_axis.encoder_res.get_value(), + ), + ) + + # Time based triggers + rows += ( + SeqTableRow( + repeats=gap - start, + trigger=SeqTrigger.IMMEDIATE, + time1=(chunk.midpoints["DURATION"][0] - value.deadtime) * 10**6, + time2=value.deadtime * 10**6, + outa1=True, + outb1=dir, + outa2=False, + outb2=False, + ), + ) + + start = gap + table: SeqTable = seq_table_from_rows(rows) + await asyncio.gather( + self.seq.prescale.set(0), + self.seq.repeats.set(1), + self.seq.table.set(table), + ) + + @AsyncStatus.wrap + async def kickoff(self) -> None: + await self.seq.enable.set("ONE") + await wait_for_value(self.seq.active, True, timeout=1) + + @WatchableAsyncStatus.wrap + async def complete(self) -> None: + await wait_for_value(self.seq.active, False, timeout=None) + + async def stop(self): + await self.seq.enable.set("ZERO") + await wait_for_value(self.seq.active, False, timeout=1) + + def _calculate_gaps(self, chunk: Frames[motor.Motor]): + inds = np.argwhere(chunk.gap) + if len(inds) == 0: + return len(chunk) + else: + return inds + + class PcompInfo(BaseModel): start_postion: int = Field(description="start position in counts") pulse_width: int = Field(description="width of a single pulse in counts", gt=0) diff --git a/tests/fastcs/panda/test_trigger.py b/tests/fastcs/panda/test_trigger.py index 1a76614afa..7fd4701e40 100644 --- a/tests/fastcs/panda/test_trigger.py +++ b/tests/fastcs/panda/test_trigger.py @@ -3,14 +3,19 @@ import numpy as np import pytest from pydantic import ValidationError +from scanspec.specs import Line, fly from ophyd_async.core import DEFAULT_TIMEOUT, DeviceCollector, set_mock_value +from ophyd_async.epics import motor from ophyd_async.epics.pvi import fill_pvi_entries from ophyd_async.fastcs.panda import ( CommonPandaBlocks, PcompInfo, + ScanSpecInfo, + ScanSpecSeqTableTriggerLogic, SeqTable, SeqTableInfo, + SeqTrigger, StaticPcompTriggerLogic, StaticSeqTableTriggerLogic, ) @@ -71,6 +76,52 @@ async def set_active(value: bool): await asyncio.gather(trigger_logic.complete(), set_active(False)) +@pytest.fixture +async def sim_x_motor(): + async with DeviceCollector(mock=True): + sim_motor = motor.Motor("BLxxI-MO-STAGE-01:X", name="sim_x_motor") + + set_mock_value(sim_motor.encoder_res, 0.2) + + yield sim_motor + + +@pytest.fixture +async def sim_y_motor(): + async with DeviceCollector(mock=True): + sim_motor = motor.Motor("BLxxI-MO-STAGE-01:Y", name="sim_x_motor") + + set_mock_value(sim_motor.encoder_res, 0.2) + + yield sim_motor + + +async def test_seq_scanspec_trigger_logic(mock_panda, sim_x_motor, sim_y_motor) -> None: + spec = fly(Line(sim_y_motor, 1, 2, 3) * ~Line(sim_x_motor, 1, 5, 5), 1) + info = ScanSpecInfo(spec=spec, deadtime=0.1) + trigger_logic = ScanSpecSeqTableTriggerLogic(mock_panda.seq[1]) + await trigger_logic.prepare(info) + out = await trigger_logic.seq.table.get_value() + assert (out["repeats"] == [1, 1, 1, 5, 1, 1, 1, 5, 1, 1, 1, 5]).all() + assert out["trigger"] == [ + SeqTrigger.BITA_0, + SeqTrigger.BITA_1, + SeqTrigger.POSA_GT, + SeqTrigger.IMMEDIATE, + SeqTrigger.BITA_0, + SeqTrigger.BITA_1, + SeqTrigger.POSA_LT, + SeqTrigger.IMMEDIATE, + SeqTrigger.BITA_0, + SeqTrigger.BITA_1, + SeqTrigger.POSA_GT, + SeqTrigger.IMMEDIATE, + ] + assert (out["position"] == [0, 0, 2, 0, 0, 0, 27, 0, 0, 0, 2, 0]).all() + assert (out["time1"] == [0, 0, 0, 900000, 0, 0, 0, 900000, 0, 0, 0, 900000]).all() + assert (out["time2"] == [0, 0, 0, 100000, 0, 0, 0, 100000, 0, 0, 0, 100000]).all() + + @pytest.mark.parametrize( ["kwargs", "error_msg"], [