diff --git a/widget_state/basic_state.py b/widget_state/basic_state.py index ff41bdc..9d6143c 100644 --- a/widget_state/basic_state.py +++ b/widget_state/basic_state.py @@ -6,14 +6,16 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Any, Callable, Optional +from typing import Any, Callable, Generic, Optional, TypeVar from .state import State from .list_state import ListState from .types import Serializable +T = TypeVar("T") +R = TypeVar("R") -class BasicState(State): +class BasicState(State, Generic[T]): """ A basic state contains a single value. @@ -22,7 +24,7 @@ class BasicState(State): if the value changed on reassignment. """ - def __init__(self, value: Any, verify_change: bool = True) -> None: + def __init__(self, value: T, verify_change: bool = True) -> None: """ Initialize a basic state: @@ -39,7 +41,7 @@ def __init__(self, value: Any, verify_change: bool = True) -> None: self.value = value - def __setattr__(self, name: str, new_value: Any) -> None: + def __setattr__(self, name: str, new_value: T) -> None: # ignore private attributes (begin with an underscore) if name[0] == "_": super().__setattr__(name, new_value) @@ -63,7 +65,7 @@ def __setattr__(self, name: str, new_value: Any) -> None: # notify that the value changed self.notify_change() - def set(self, value: Any) -> None: + def set(self, value: T) -> None: """ Simple function for the assignment of the value. @@ -79,7 +81,7 @@ def set(self, value: Any) -> None: def depends_on( self, states: Iterable[State], - compute_value: Callable[[], Any], + compute_value: Callable[[], T], element_wise: bool = False, ) -> None: """ @@ -109,8 +111,8 @@ def depends_on( self.set(compute_value()) def transform( - self, self_to_other: Callable[[BasicState], BasicState] - ) -> BasicState: + self, self_to_other: Callable[[BasicState[T]], BasicState[R]] + ) -> BasicState[R]: """ Transform this state into another state. @@ -141,11 +143,11 @@ def __repr__(self) -> str: def serialize(self) -> Serializable: raise NotImplementedError("Unable to serialize abstract basic state") - def deserialize(self, _dict: Serializable): + def deserialize(self, _dict: Serializable) -> None: raise NotImplementedError("Unable to deserialize abstract basic state") -class IntState(BasicState): +class IntState(BasicState[int]): """ Implementation of the `BasicState` for an int. """ @@ -158,7 +160,7 @@ def serialize(self) -> int: return self.value -class FloatState(BasicState): +class FloatState(BasicState[float]): """ Implementation of the `BasicState` for a float. @@ -182,7 +184,7 @@ def serialize(self) -> float: return self.value -class StringState(BasicState): +class StringState(BasicState[str]): """ Implementation of the `BasicState` for a string. """ @@ -198,7 +200,7 @@ def __repr__(self) -> str: return f'{type(self).__name__}[value="{self.value}"]' -class BoolState(BasicState): +class BoolState(BasicState[bool]): """ Implementation of the `BasicState` for a bool. """ @@ -211,7 +213,7 @@ def serialize(self) -> bool: return self.value -class ObjectState(BasicState): +class ObjectState(BasicState[Any]): """ Implementation of the `BasicState` for objects. diff --git a/widget_state/dict_state.py b/widget_state/dict_state.py index e11bdc2..855bf8f 100644 --- a/widget_state/dict_state.py +++ b/widget_state/dict_state.py @@ -27,7 +27,7 @@ def __init__(self) -> None: super().__init__() self._labels: list[str] = [] - def __setattr__(self, name: str, new_value: Any | BasicState) -> None: + def __setattr__(self, name: str, new_value: Any | BasicState[Any]) -> None: super().__setattr__(name, new_value) if name[0] == "_": @@ -36,12 +36,12 @@ def __setattr__(self, name: str, new_value: Any | BasicState) -> None: if name not in self._labels: self._labels.append(name) - def __getitem__(self, i: int) -> BasicState: + def __getitem__(self, i: int) -> BasicState[Any]: item = self.__getattribute__(self._labels[i]) assert isinstance(item, BasicState) return item - def __iter__(self) -> Iterator[BasicState]: + def __iter__(self) -> Iterator[BasicState[Any]]: return iter(map(self.__getattribute__, self._labels)) def __len__(self) -> int: @@ -53,7 +53,7 @@ def values(self) -> list[Any]: """ return [attr.value for attr in self] - def set(self, *args: BasicState | Primitive) -> None: + def set(self, *args: BasicState[Any] | Primitive) -> None: """ Reassign all internal basic state values and only trigger a notification afterwards. @@ -62,4 +62,4 @@ def set(self, *args: BasicState | Primitive) -> None: with self: for i, arg in enumerate(args): - self[i].value = arg + self[i].value = arg.value if isinstance(arg, BasicState) else arg diff --git a/widget_state/list_state.py b/widget_state/list_state.py index d785501..f5f950c 100644 --- a/widget_state/list_state.py +++ b/widget_state/list_state.py @@ -7,7 +7,7 @@ from collections.abc import Iterator import typing -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Generic, Optional, TypeVar, Union from .state import State from .types import Serializable @@ -15,13 +15,14 @@ if typing.TYPE_CHECKING: from _typeshed import SupportsDunderLT, SupportsDunderGT +T = TypeVar("T", bound=State) class _ElementObserver: """ Utility class that keeps track of all callbacks observing element-wise changes of a list state. """ - def __init__(self, list_state: ListState) -> None: + def __init__(self, list_state: ListState[T]) -> None: """ Initialize an `_ElementObserver`. @@ -38,12 +39,12 @@ def __call__(self, state: State) -> None: cb(self._list_state) -class ListState(State): +class ListState(State, Generic[T]): """ A list of states. """ - def __init__(self, _list: Optional[list[State]] = None) -> None: + def __init__(self, _list: Optional[list[T]] = None) -> None: """ Initial a `ListState`. @@ -56,7 +57,7 @@ def __init__(self, _list: Optional[list[State]] = None) -> None: self._elem_obs = _ElementObserver(self) - self._list: list[State] = [] + self._list: list[T] = [] self.extend(_list if _list is not None else []) def on_change( @@ -82,7 +83,7 @@ def remove_callback( if cb in self._elem_obs._callbacks: self._elem_obs._callbacks.remove(cb) - def append(self, elem: State) -> None: + def append(self, elem: T) -> None: """ Append a `State` to the list and notify. @@ -110,7 +111,7 @@ def clear(self) -> None: self.notify_change() - def extend(self, _list: list[State]) -> None: + def extend(self, _list: list[T]) -> None: """ Extend the list and notify. @@ -124,7 +125,7 @@ def extend(self, _list: list[State]) -> None: for elem in _list: self.append(elem) - def insert(self, index: int, elem: State) -> None: + def insert(self, index: int, elem: T) -> None: """ Insert an element at `index` into the list and notify. @@ -165,7 +166,7 @@ def pop(self, index: int = -1) -> State: return elem - def remove(self, elem: State) -> None: + def remove(self, elem: T) -> None: """ Remove an element from the list and notify. @@ -189,7 +190,7 @@ def reverse(self) -> None: self.notify_change() def sort( - self, key: Callable[[State], SupportsDunderLT[Any] | SupportsDunderGT[Any]] + self, key: Callable[[T], SupportsDunderLT[Any] | SupportsDunderGT[Any]] ) -> None: """ Wrapper to the sort method of the internal list. @@ -205,10 +206,10 @@ def sort( self._list.sort(key=key) self.notify_change() - def __getitem__(self, i: int) -> State: + def __getitem__(self, i: int) -> T: return self._list[i] - def index(self, elem: State) -> int: + def index(self, elem: T) -> int: """ Wrapper to the index method of the internal list. @@ -223,7 +224,7 @@ def index(self, elem: State) -> int: """ return self._list.index(elem) - def __iter__(self) -> Iterator[State]: + def __iter__(self) -> Iterator[T]: return iter(self._list) def __len__(self) -> int: diff --git a/widget_state/util.py b/widget_state/util.py index 7f26660..4dd7b08 100644 --- a/widget_state/util.py +++ b/widget_state/util.py @@ -12,7 +12,7 @@ from .basic_state import BasicState from .state import State -T = TypeVar("T", bound=BasicState) +T = TypeVar("T", bound=BasicState[Any]) P = ParamSpec("P")