From 7d4244eb1f02d52ab8a7370a20ea871c607d7ae0 Mon Sep 17 00:00:00 2001 From: Tom Trafford Date: Tue, 8 Oct 2024 09:39:22 +0000 Subject: [PATCH] Correct implementation from recent changes to ophy-async --- src/ophyd_async/fastcs/panda/_trigger.py | 58 ++++++++++-------------- tests/fastcs/panda/test_trigger.py | 12 ++--- 2 files changed, 30 insertions(+), 40 deletions(-) diff --git a/src/ophyd_async/fastcs/panda/_trigger.py b/src/ophyd_async/fastcs/panda/_trigger.py index 345500a237..a369d461c2 100644 --- a/src/ophyd_async/fastcs/panda/_trigger.py +++ b/src/ophyd_async/fastcs/panda/_trigger.py @@ -8,7 +8,7 @@ from ophyd_async.epics import motor from ._block import PcompBlock, PcompDirectionOptions, SeqBlock, TimeUnits -from ._table import SeqTable, SeqTableRow, SeqTrigger, seq_table_from_rows +from ._table import SeqTable, SeqTrigger class SeqTableInfo(BaseModel): @@ -18,7 +18,7 @@ class SeqTableInfo(BaseModel): class ScanSpecInfo(BaseModel): - spec: Spec[motor.Motor] = Field() + spec: Spec = Field(default=None) deadtime: float = Field() @@ -49,12 +49,11 @@ async def stop(self): await wait_for_value(self.seq.active, False, timeout=1) -class ScanSpecSeqTableTriggerLogic(TriggerLogic[ScanSpecInfo]): +class ScanSpecSeqTableTriggerLogic(FlyerController[ScanSpecInfo]): def __init__(self, seq: SeqBlock, name="") -> None: self.seq = seq self.name = name - @AsyncStatus.wrap async def prepare(self, value: ScanSpecInfo): await asyncio.gather( self.seq.prescale_units.set(TimeUnits.us), @@ -70,16 +69,11 @@ async def prepare(self, value: ScanSpecInfo): fast_axis = chunk.axes()[len(chunk.axes()) - 2] gaps = np.append(gaps, scan_size) start = 0 - rows: SeqTableRow = () + # Wait for GPIO to go low + rows = SeqTable.row(trigger=SeqTrigger.BITA_0) for gap in gaps: - # Wait for GPIO to go low - rows += (SeqTableRow(trigger=SeqTrigger.BITA_0),) # Wait for GPIO to go high - rows += ( - SeqTableRow( - trigger=SeqTrigger.BITA_1, - ), - ) + rows += SeqTable.row(trigger=SeqTrigger.BITA_1) # Wait for position if chunk.midpoints[fast_axis][gap - 1] > chunk.midpoints[fast_axis][start]: trig = SeqTrigger.POSA_GT @@ -88,42 +82,38 @@ async def prepare(self, value: ScanSpecInfo): else: trig = SeqTrigger.POSA_LT dir = True - rows += ( - SeqTableRow( - trigger=trig, - position=chunk.lower[fast_axis][start] - / await fast_axis.encoder_res.get_value(), - ), + rows += SeqTable.row( + trigger=trig, + position=chunk.lower[fast_axis][start] + / await fast_axis.encoder_res.get_value(), ) # Time based triggers - rows += ( - SeqTableRow( - repeats=gap - start, - trigger=SeqTrigger.IMMEDIATE, - time1=(chunk.midpoints["DURATION"][0] - value.deadtime) * 10**6, - time2=value.deadtime * 10**6, - outa1=True, - outb1=dir, - outa2=False, - outb2=False, - ), + rows += SeqTable.row( + repeats=gap - start, + trigger=SeqTrigger.IMMEDIATE, + time1=(chunk.midpoints["DURATION"][0] - value.deadtime) * 10**6, + time2=int(value.deadtime * 10**6), + outa1=True, + outb1=dir, + outa2=False, + outb2=False, ) + # Wait for GPIO to go low + rows += SeqTable.row(trigger=SeqTrigger.BITA_0) + start = gap - table: SeqTable = seq_table_from_rows(rows) await asyncio.gather( self.seq.prescale.set(0), self.seq.repeats.set(1), - self.seq.table.set(table), + self.seq.table.set(rows), ) - @AsyncStatus.wrap async def kickoff(self) -> None: await self.seq.enable.set("ONE") await wait_for_value(self.seq.active, True, timeout=1) - @WatchableAsyncStatus.wrap async def complete(self) -> None: await wait_for_value(self.seq.active, False, timeout=None) @@ -134,7 +124,7 @@ async def stop(self): def _calculate_gaps(self, chunk: Frames[motor.Motor]): inds = np.argwhere(chunk.gap) if len(inds) == 0: - return len(chunk) + return [len(chunk)] else: return inds diff --git a/tests/fastcs/panda/test_trigger.py b/tests/fastcs/panda/test_trigger.py index 3615b05204..cc05cdd188 100644 --- a/tests/fastcs/panda/test_trigger.py +++ b/tests/fastcs/panda/test_trigger.py @@ -19,7 +19,6 @@ StaticPcompTriggerLogic, StaticSeqTableTriggerLogic, ) -from ophyd_async.fastcs.panda._table import SeqTrigger @pytest.fixture @@ -106,8 +105,8 @@ async def test_seq_scanspec_trigger_logic(mock_panda, sim_x_motor, sim_y_motor) trigger_logic = ScanSpecSeqTableTriggerLogic(mock_panda.seq[1]) await trigger_logic.prepare(info) out = await trigger_logic.seq.table.get_value() - assert (out["repeats"] == [1, 1, 1, 5, 1, 1, 1, 5, 1, 1, 1, 5]).all() - assert out["trigger"] == [ + assert (out.repeats == [1, 1, 1, 5, 1, 1, 1, 5, 1, 1, 1, 5, 1]).all() + assert out.trigger == [ SeqTrigger.BITA_0, SeqTrigger.BITA_1, SeqTrigger.POSA_GT, @@ -120,10 +119,11 @@ async def test_seq_scanspec_trigger_logic(mock_panda, sim_x_motor, sim_y_motor) SeqTrigger.BITA_1, SeqTrigger.POSA_GT, SeqTrigger.IMMEDIATE, + SeqTrigger.BITA_0, ] - assert (out["position"] == [0, 0, 2, 0, 0, 0, 27, 0, 0, 0, 2, 0]).all() - assert (out["time1"] == [0, 0, 0, 900000, 0, 0, 0, 900000, 0, 0, 0, 900000]).all() - assert (out["time2"] == [0, 0, 0, 100000, 0, 0, 0, 100000, 0, 0, 0, 100000]).all() + assert (out.position == [0, 0, 2, 0, 0, 0, 27, 0, 0, 0, 2, 0, 0]).all() + assert (out.time1 == [0, 0, 0, 900000, 0, 0, 0, 900000, 0, 0, 0, 900000, 0]).all() + assert (out.time2 == [0, 0, 0, 100000, 0, 0, 0, 100000, 0, 0, 0, 100000, 0]).all() @pytest.mark.parametrize(