Skip to content

Commit

Permalink
Correct implementation from recent changes to ophy-async
Browse files Browse the repository at this point in the history
  • Loading branch information
tomtrafford committed Oct 8, 2024
1 parent 390b120 commit 7d4244e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 40 deletions.
58 changes: 24 additions & 34 deletions src/ophyd_async/fastcs/panda/_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ophyd_async.epics import motor

from ._block import PcompBlock, PcompDirectionOptions, SeqBlock, TimeUnits
from ._table import SeqTable, SeqTableRow, SeqTrigger, seq_table_from_rows
from ._table import SeqTable, SeqTrigger


class SeqTableInfo(BaseModel):
Expand All @@ -18,7 +18,7 @@ class SeqTableInfo(BaseModel):


class ScanSpecInfo(BaseModel):
spec: Spec[motor.Motor] = Field()
spec: Spec = Field(default=None)
deadtime: float = Field()


Expand Down Expand Up @@ -49,12 +49,11 @@ async def stop(self):
await wait_for_value(self.seq.active, False, timeout=1)


class ScanSpecSeqTableTriggerLogic(TriggerLogic[ScanSpecInfo]):
class ScanSpecSeqTableTriggerLogic(FlyerController[ScanSpecInfo]):
def __init__(self, seq: SeqBlock, name="") -> None:
self.seq = seq
self.name = name

@AsyncStatus.wrap
async def prepare(self, value: ScanSpecInfo):
await asyncio.gather(
self.seq.prescale_units.set(TimeUnits.us),
Expand All @@ -70,16 +69,11 @@ async def prepare(self, value: ScanSpecInfo):
fast_axis = chunk.axes()[len(chunk.axes()) - 2]
gaps = np.append(gaps, scan_size)
start = 0
rows: SeqTableRow = ()
# Wait for GPIO to go low
rows = SeqTable.row(trigger=SeqTrigger.BITA_0)
for gap in gaps:
# Wait for GPIO to go low
rows += (SeqTableRow(trigger=SeqTrigger.BITA_0),)
# Wait for GPIO to go high
rows += (
SeqTableRow(
trigger=SeqTrigger.BITA_1,
),
)
rows += SeqTable.row(trigger=SeqTrigger.BITA_1)
# Wait for position
if chunk.midpoints[fast_axis][gap - 1] > chunk.midpoints[fast_axis][start]:
trig = SeqTrigger.POSA_GT
Expand All @@ -88,42 +82,38 @@ async def prepare(self, value: ScanSpecInfo):
else:
trig = SeqTrigger.POSA_LT
dir = True
rows += (
SeqTableRow(
trigger=trig,
position=chunk.lower[fast_axis][start]
/ await fast_axis.encoder_res.get_value(),
),
rows += SeqTable.row(
trigger=trig,
position=chunk.lower[fast_axis][start]
/ await fast_axis.encoder_res.get_value(),
)

# Time based triggers
rows += (
SeqTableRow(
repeats=gap - start,
trigger=SeqTrigger.IMMEDIATE,
time1=(chunk.midpoints["DURATION"][0] - value.deadtime) * 10**6,
time2=value.deadtime * 10**6,
outa1=True,
outb1=dir,
outa2=False,
outb2=False,
),
rows += SeqTable.row(
repeats=gap - start,
trigger=SeqTrigger.IMMEDIATE,
time1=(chunk.midpoints["DURATION"][0] - value.deadtime) * 10**6,
time2=int(value.deadtime * 10**6),
outa1=True,
outb1=dir,
outa2=False,
outb2=False,
)

# Wait for GPIO to go low
rows += SeqTable.row(trigger=SeqTrigger.BITA_0)

start = gap
table: SeqTable = seq_table_from_rows(rows)
await asyncio.gather(
self.seq.prescale.set(0),
self.seq.repeats.set(1),
self.seq.table.set(table),
self.seq.table.set(rows),
)

@AsyncStatus.wrap
async def kickoff(self) -> None:
await self.seq.enable.set("ONE")
await wait_for_value(self.seq.active, True, timeout=1)

@WatchableAsyncStatus.wrap
async def complete(self) -> None:
await wait_for_value(self.seq.active, False, timeout=None)

Expand All @@ -134,7 +124,7 @@ async def stop(self):
def _calculate_gaps(self, chunk: Frames[motor.Motor]):
inds = np.argwhere(chunk.gap)
if len(inds) == 0:
return len(chunk)
return [len(chunk)]
else:
return inds

Expand Down
12 changes: 6 additions & 6 deletions tests/fastcs/panda/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
StaticPcompTriggerLogic,
StaticSeqTableTriggerLogic,
)
from ophyd_async.fastcs.panda._table import SeqTrigger


@pytest.fixture
Expand Down Expand Up @@ -106,8 +105,8 @@ async def test_seq_scanspec_trigger_logic(mock_panda, sim_x_motor, sim_y_motor)
trigger_logic = ScanSpecSeqTableTriggerLogic(mock_panda.seq[1])
await trigger_logic.prepare(info)
out = await trigger_logic.seq.table.get_value()
assert (out["repeats"] == [1, 1, 1, 5, 1, 1, 1, 5, 1, 1, 1, 5]).all()
assert out["trigger"] == [
assert (out.repeats == [1, 1, 1, 5, 1, 1, 1, 5, 1, 1, 1, 5, 1]).all()
assert out.trigger == [
SeqTrigger.BITA_0,
SeqTrigger.BITA_1,
SeqTrigger.POSA_GT,
Expand All @@ -120,10 +119,11 @@ async def test_seq_scanspec_trigger_logic(mock_panda, sim_x_motor, sim_y_motor)
SeqTrigger.BITA_1,
SeqTrigger.POSA_GT,
SeqTrigger.IMMEDIATE,
SeqTrigger.BITA_0,
]
assert (out["position"] == [0, 0, 2, 0, 0, 0, 27, 0, 0, 0, 2, 0]).all()
assert (out["time1"] == [0, 0, 0, 900000, 0, 0, 0, 900000, 0, 0, 0, 900000]).all()
assert (out["time2"] == [0, 0, 0, 100000, 0, 0, 0, 100000, 0, 0, 0, 100000]).all()
assert (out.position == [0, 0, 2, 0, 0, 0, 27, 0, 0, 0, 2, 0, 0]).all()
assert (out.time1 == [0, 0, 0, 900000, 0, 0, 0, 900000, 0, 0, 0, 900000, 0]).all()
assert (out.time2 == [0, 0, 0, 100000, 0, 0, 0, 100000, 0, 0, 0, 100000, 0]).all()


@pytest.mark.parametrize(
Expand Down

0 comments on commit 7d4244e

Please sign in to comment.