diff --git a/src/ophyd_async/core/_derived_signal.py b/src/ophyd_async/core/_derived_signal.py index bb1f953139..044c30f90f 100644 --- a/src/ophyd_async/core/_derived_signal.py +++ b/src/ophyd_async/core/_derived_signal.py @@ -1,7 +1,6 @@ import asyncio -import dataclasses from abc import abstractmethod -from typing import Generic, Self, TypeVar, get_args +from typing import Generic, TypedDict, TypeVar, get_args from ._device import Device from ._protocol import AsyncMovable @@ -9,25 +8,26 @@ from ._signal_backend import SignalBackend, SignalDatatypeT -@dataclasses.dataclass -class TransformArgument(Generic[SignalDatatypeT]): - @classmethod - async def get_dataclass_from_signals(cls, device: Device) -> Self: - coros = {} - for field in dataclasses.fields(cls): - sig = getattr(device, field.name) - assert isinstance( - sig, SignalR - ), f"{device.name}.{field.name} is {sig}, not a Signal" - coros[field.name] = sig.get_value() - results = await asyncio.gather(*coros.values()) - kwargs = dict(zip(coros, results, strict=True)) - return cls(**kwargs) +class TransformArgument(TypedDict, Generic[SignalDatatypeT]): ... + + +T = TypeVar("T", bound=TransformArgument) + + +async def _get_dataclass_from_signals(cls: type[T], device: Device) -> T: + coros = {} + for name in cls.__annotations__: + sig = getattr(device, name) + assert isinstance(sig, SignalR), f"{device.name}.{name} is {sig}, not a Signal" + coros[name] = sig.get_value() + results = await asyncio.gather(*coros.values()) + kwargs = dict(zip(coros, results, strict=True)) + return cls(**kwargs) RawT = TypeVar("RawT", bound=TransformArgument) DerivedT = TypeVar("DerivedT", bound=TransformArgument) -ParametersT = TypeVar("ParametersT", bound=TransformArgument) +ParametersT = TypeVar("ParametersT") class TransformMeta(type): @@ -67,12 +67,12 @@ def __init__( self._transform = transform async def get_parameters(self) -> ParametersT: - return await self._transform.parameters_cls.get_dataclass_from_signals( - self._device + return await _get_dataclass_from_signals( + self._transform.parameters_cls, self._device ) async def get_raw_values(self) -> RawT: - return await self._transform.raw_cls.get_dataclass_from_signals(self._device) + return await _get_dataclass_from_signals(self._transform.raw_cls, self._device) async def get_derived_values(self) -> DerivedT: raw, parameters = await asyncio.gather( diff --git a/tests/core/test_derived_signal.py b/tests/core/test_derived_signal.py index cf90bc4437..fd1c68faea 100644 --- a/tests/core/test_derived_signal.py +++ b/tests/core/test_derived_signal.py @@ -1,5 +1,4 @@ import asyncio -from dataclasses import dataclass from typing import TypeVar import numpy as np @@ -36,7 +35,7 @@ class SomeTransform1(Transform[Raw, Derived]): ... # type: ignore TypeError, match=( "Transform classes must be defined with Raw, Derived, " - "and Parameter args." + "and Parameter `TransformArgument`s." ), ): @@ -54,20 +53,17 @@ class SomeTransform(Transform[Raw, Derived, Parameters]): ... F = TypeVar("F", float, Array1D[np.float64]) -@dataclass class SlitsRaw(TransformArgument[F]): top: F bottom: F -@dataclass class SlitsDerived(TransformArgument[F]): gap: F centre: F -@dataclass -class SlitsParameters(TransformArgument[float]): +class SlitsParameters(TransformArgument): gap_offset: float @@ -75,18 +71,18 @@ class SlitsTransform(Transform[SlitsRaw[F], SlitsDerived[F], SlitsParameters]): @classmethod def forward(cls, raw: SlitsRaw[F], parameters: SlitsParameters) -> SlitsDerived[F]: return SlitsDerived( - gap=raw.top - raw.bottom + parameters.gap_offset, - centre=(raw.top + raw.bottom) / 2, + gap=raw["top"] - raw["bottom"] + parameters["gap_offset"], + centre=(raw["top"] + raw["bottom"]) / 2, ) @classmethod def inverse( cls, derived: SlitsDerived[F], parameters: SlitsParameters ) -> SlitsRaw[F]: - half_gap = (derived.gap - parameters.gap_offset) / 2 + half_gap = (derived["gap"] - parameters["gap_offset"]) / 2 return SlitsRaw( - top=derived.centre + half_gap, - bottom=derived.centre - half_gap, + top=derived["centre"] + half_gap, + bottom=derived["centre"] - half_gap, ) @@ -106,7 +102,7 @@ def __init__(self, name=""): @AsyncStatus.wrap async def set(self, derived: SlitsDerived[float]) -> None: raw: SlitsRaw[float] = await self._backend.calculate_raw_values(derived) - await asyncio.gather(self.top.set(raw.top), self.bottom.set(raw.bottom)) + await asyncio.gather(self.top.set(raw["top"]), self.bottom.set(raw["bottom"])) async def test_derived_signals():