Skip to content

Commit

Permalink
swapped to generic TypedDict
Browse files Browse the repository at this point in the history
  • Loading branch information
evalott100 committed Nov 25, 2024
1 parent 1f5366d commit 435e56f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 32 deletions.
43 changes: 23 additions & 20 deletions src/ophyd_async/core/_derived_signal.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@
import asyncio
import dataclasses
from abc import abstractmethod
from typing import Generic, Self, TypeVar, get_args
from typing import Generic, TypedDict, TypeVar, get_args

from ._device import Device
from ._protocol import AsyncMovable
from ._signal import SignalR, SignalRW
from ._signal_backend import SignalBackend, SignalDatatypeT


@dataclasses.dataclass
class TransformArgument(Generic[SignalDatatypeT]):
@classmethod
async def get_dataclass_from_signals(cls, device: Device) -> Self:
coros = {}
for field in dataclasses.fields(cls):
sig = getattr(device, field.name)
assert isinstance(
sig, SignalR
), f"{device.name}.{field.name} is {sig}, not a Signal"
coros[field.name] = sig.get_value()
results = await asyncio.gather(*coros.values())
kwargs = dict(zip(coros, results, strict=True))
return cls(**kwargs)
class TransformArgument(TypedDict, Generic[SignalDatatypeT]):
pass


T = TypeVar("T", bound=TransformArgument)


async def _get_dataclass_from_signals(cls: type[T], device: Device) -> T:
coros = {}
for name in cls.__annotations__:
signal = getattr(device, name)
assert isinstance(
signal, SignalR
), f"{device.name}.{name} is {signal}, not a Signal"
coros[name] = signal.get_value()
results = await asyncio.gather(*coros.values())
kwargs = dict(zip(coros, results, strict=True))
return cls(**kwargs)


RawT = TypeVar("RawT", bound=TransformArgument)
DerivedT = TypeVar("DerivedT", bound=TransformArgument)
ParametersT = TypeVar("ParametersT", bound=TransformArgument)
ParametersT = TypeVar("ParametersT")


class TransformMeta(type):
Expand Down Expand Up @@ -67,12 +70,12 @@ def __init__(
self._transform = transform

async def get_parameters(self) -> ParametersT:
return await self._transform.parameters_cls.get_dataclass_from_signals(
self._device
return await _get_dataclass_from_signals(
self._transform.parameters_cls, self._device
)

async def get_raw_values(self) -> RawT:
return await self._transform.raw_cls.get_dataclass_from_signals(self._device)
return await _get_dataclass_from_signals(self._transform.raw_cls, self._device)

async def get_derived_values(self) -> DerivedT:
raw, parameters = await asyncio.gather(
Expand Down
20 changes: 8 additions & 12 deletions tests/core/test_derived_signal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from dataclasses import dataclass
from typing import TypeVar

import numpy as np
Expand Down Expand Up @@ -36,7 +35,7 @@ class SomeTransform1(Transform[Raw, Derived]): ... # type: ignore
TypeError,
match=(
"Transform classes must be defined with Raw, Derived, "
"and Parameter args."
"and Parameter `TransformArgument`s."
),
):

Expand All @@ -54,39 +53,36 @@ class SomeTransform(Transform[Raw, Derived, Parameters]): ...
F = TypeVar("F", float, Array1D[np.float64])


@dataclass
class SlitsRaw(TransformArgument[F]):
top: F
bottom: F


@dataclass
class SlitsDerived(TransformArgument[F]):
gap: F
centre: F


@dataclass
class SlitsParameters(TransformArgument[float]):
class SlitsParameters(TransformArgument):
gap_offset: float


class SlitsTransform(Transform[SlitsRaw[F], SlitsDerived[F], SlitsParameters]):
@classmethod
def forward(cls, raw: SlitsRaw[F], parameters: SlitsParameters) -> SlitsDerived[F]:
return SlitsDerived(
gap=raw.top - raw.bottom + parameters.gap_offset,
centre=(raw.top + raw.bottom) / 2,
gap=raw["top"] - raw["bottom"] + parameters["gap_offset"],
centre=(raw["top"] + raw["bottom"]) / 2,
)

@classmethod
def inverse(
cls, derived: SlitsDerived[F], parameters: SlitsParameters
) -> SlitsRaw[F]:
half_gap = (derived.gap - parameters.gap_offset) / 2
half_gap = (derived["gap"] - parameters["gap_offset"]) / 2
return SlitsRaw(
top=derived.centre + half_gap,
bottom=derived.centre - half_gap,
top=derived["centre"] + half_gap,
bottom=derived["centre"] - half_gap,
)


Expand All @@ -106,7 +102,7 @@ def __init__(self, name=""):
@AsyncStatus.wrap
async def set(self, derived: SlitsDerived[float]) -> None:
raw: SlitsRaw[float] = await self._backend.calculate_raw_values(derived)
await asyncio.gather(self.top.set(raw.top), self.bottom.set(raw.bottom))
await asyncio.gather(self.top.set(raw["top"]), self.bottom.set(raw["bottom"]))


async def test_derived_signals():
Expand Down

0 comments on commit 435e56f

Please sign in to comment.