Skip to content

Commit

Permalink
add generics to typting
Browse files Browse the repository at this point in the history
  • Loading branch information
pLeminoq committed Dec 5, 2024
1 parent 93a43db commit 4e9ddd4
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 33 deletions.
30 changes: 16 additions & 14 deletions widget_state/basic_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Callable, Optional
from typing import Any, Callable, Generic, Optional, TypeVar

from .state import State
from .list_state import ListState
from .types import Serializable

T = TypeVar("T")
R = TypeVar("R")

class BasicState(State):
class BasicState(State, Generic[T]):
"""
A basic state contains a single value.
Expand All @@ -22,7 +24,7 @@ class BasicState(State):
if the value changed on reassignment.
"""

def __init__(self, value: Any, verify_change: bool = True) -> None:
def __init__(self, value: T, verify_change: bool = True) -> None:
"""
Initialize a basic state:
Expand All @@ -39,7 +41,7 @@ def __init__(self, value: Any, verify_change: bool = True) -> None:

self.value = value

def __setattr__(self, name: str, new_value: Any) -> None:
def __setattr__(self, name: str, new_value: T) -> None:
# ignore private attributes (begin with an underscore)
if name[0] == "_":
super().__setattr__(name, new_value)
Expand All @@ -63,7 +65,7 @@ def __setattr__(self, name: str, new_value: Any) -> None:
# notify that the value changed
self.notify_change()

def set(self, value: Any) -> None:
def set(self, value: T) -> None:
"""
Simple function for the assignment of the value.
Expand All @@ -79,7 +81,7 @@ def set(self, value: Any) -> None:
def depends_on(
self,
states: Iterable[State],
compute_value: Callable[[], Any],
compute_value: Callable[[], T],
element_wise: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -109,8 +111,8 @@ def depends_on(
self.set(compute_value())

def transform(
self, self_to_other: Callable[[BasicState], BasicState]
) -> BasicState:
self, self_to_other: Callable[[BasicState[T]], BasicState[R]]
) -> BasicState[R]:
"""
Transform this state into another state.
Expand Down Expand Up @@ -141,11 +143,11 @@ def __repr__(self) -> str:
def serialize(self) -> Serializable:
raise NotImplementedError("Unable to serialize abstract basic state")

def deserialize(self, _dict: Serializable):
def deserialize(self, _dict: Serializable) -> None:
raise NotImplementedError("Unable to deserialize abstract basic state")


class IntState(BasicState):
class IntState(BasicState[int]):
"""
Implementation of the `BasicState` for an int.
"""
Expand All @@ -158,7 +160,7 @@ def serialize(self) -> int:
return self.value


class FloatState(BasicState):
class FloatState(BasicState[float]):
"""
Implementation of the `BasicState` for a float.
Expand All @@ -182,7 +184,7 @@ def serialize(self) -> float:
return self.value


class StringState(BasicState):
class StringState(BasicState[str]):
"""
Implementation of the `BasicState` for a string.
"""
Expand All @@ -198,7 +200,7 @@ def __repr__(self) -> str:
return f'{type(self).__name__}[value="{self.value}"]'


class BoolState(BasicState):
class BoolState(BasicState[bool]):
"""
Implementation of the `BasicState` for a bool.
"""
Expand All @@ -211,7 +213,7 @@ def serialize(self) -> bool:
return self.value


class ObjectState(BasicState):
class ObjectState(BasicState[Any]):
"""
Implementation of the `BasicState` for objects.
Expand Down
10 changes: 5 additions & 5 deletions widget_state/dict_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self) -> None:
super().__init__()
self._labels: list[str] = []

def __setattr__(self, name: str, new_value: Any | BasicState) -> None:
def __setattr__(self, name: str, new_value: Any | BasicState[Any]) -> None:
super().__setattr__(name, new_value)

if name[0] == "_":
Expand All @@ -36,12 +36,12 @@ def __setattr__(self, name: str, new_value: Any | BasicState) -> None:
if name not in self._labels:
self._labels.append(name)

def __getitem__(self, i: int) -> BasicState:
def __getitem__(self, i: int) -> BasicState[Any]:
item = self.__getattribute__(self._labels[i])
assert isinstance(item, BasicState)
return item

def __iter__(self) -> Iterator[BasicState]:
def __iter__(self) -> Iterator[BasicState[Any]]:
return iter(map(self.__getattribute__, self._labels))

def __len__(self) -> int:
Expand All @@ -53,7 +53,7 @@ def values(self) -> list[Any]:
"""
return [attr.value for attr in self]

def set(self, *args: BasicState | Primitive) -> None:
def set(self, *args: BasicState[Any] | Primitive) -> None:
"""
Reassign all internal basic state values and only
trigger a notification afterwards.
Expand All @@ -62,4 +62,4 @@ def set(self, *args: BasicState | Primitive) -> None:

with self:
for i, arg in enumerate(args):
self[i].value = arg
self[i].value = arg.value if isinstance(arg, BasicState) else arg
27 changes: 14 additions & 13 deletions widget_state/list_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,22 @@

from collections.abc import Iterator
import typing
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Generic, Optional, TypeVar, Union

from .state import State
from .types import Serializable

if typing.TYPE_CHECKING:
from _typeshed import SupportsDunderLT, SupportsDunderGT

T = TypeVar("T", bound=State)

class _ElementObserver:
"""
Utility class that keeps track of all callbacks observing element-wise changes of a list state.
"""

def __init__(self, list_state: ListState) -> None:
def __init__(self, list_state: ListState[T]) -> None:
"""
Initialize an `_ElementObserver`.
Expand All @@ -38,12 +39,12 @@ def __call__(self, state: State) -> None:
cb(self._list_state)


class ListState(State):
class ListState(State, Generic[T]):
"""
A list of states.
"""

def __init__(self, _list: Optional[list[State]] = None) -> None:
def __init__(self, _list: Optional[list[T]] = None) -> None:
"""
Initial a `ListState`.
Expand All @@ -56,7 +57,7 @@ def __init__(self, _list: Optional[list[State]] = None) -> None:

self._elem_obs = _ElementObserver(self)

self._list: list[State] = []
self._list: list[T] = []
self.extend(_list if _list is not None else [])

def on_change(
Expand All @@ -82,7 +83,7 @@ def remove_callback(
if cb in self._elem_obs._callbacks:
self._elem_obs._callbacks.remove(cb)

def append(self, elem: State) -> None:
def append(self, elem: T) -> None:
"""
Append a `State` to the list and notify.
Expand Down Expand Up @@ -110,7 +111,7 @@ def clear(self) -> None:

self.notify_change()

def extend(self, _list: list[State]) -> None:
def extend(self, _list: list[T]) -> None:
"""
Extend the list and notify.
Expand All @@ -124,7 +125,7 @@ def extend(self, _list: list[State]) -> None:
for elem in _list:
self.append(elem)

def insert(self, index: int, elem: State) -> None:
def insert(self, index: int, elem: T) -> None:
"""
Insert an element at `index` into the list and notify.
Expand Down Expand Up @@ -165,7 +166,7 @@ def pop(self, index: int = -1) -> State:

return elem

def remove(self, elem: State) -> None:
def remove(self, elem: T) -> None:
"""
Remove an element from the list and notify.
Expand All @@ -189,7 +190,7 @@ def reverse(self) -> None:
self.notify_change()

def sort(
self, key: Callable[[State], SupportsDunderLT[Any] | SupportsDunderGT[Any]]
self, key: Callable[[T], SupportsDunderLT[Any] | SupportsDunderGT[Any]]
) -> None:
"""
Wrapper to the sort method of the internal list.
Expand All @@ -205,10 +206,10 @@ def sort(
self._list.sort(key=key)
self.notify_change()

def __getitem__(self, i: int) -> State:
def __getitem__(self, i: int) -> T:
return self._list[i]

def index(self, elem: State) -> int:
def index(self, elem: T) -> int:
"""
Wrapper to the index method of the internal list.
Expand All @@ -223,7 +224,7 @@ def index(self, elem: State) -> int:
"""
return self._list.index(elem)

def __iter__(self) -> Iterator[State]:
def __iter__(self) -> Iterator[T]:
return iter(self._list)

def __len__(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion widget_state/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .basic_state import BasicState
from .state import State

T = TypeVar("T", bound=BasicState)
T = TypeVar("T", bound=BasicState[Any])
P = ParamSpec("P")


Expand Down

0 comments on commit 4e9ddd4

Please sign in to comment.