Skip to content

Commit

Permalink
Merge pull request #2 from pLeminoq/feature/improve_computed_states
Browse files Browse the repository at this point in the history
Feature/improve computed states
  • Loading branch information
pLeminoq authored Dec 13, 2024
2 parents c2a36cd + 0b6e177 commit e32ccef
Show file tree
Hide file tree
Showing 16 changed files with 237 additions and 65 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ jobs:
- name: Validate formatting with black
run: |
black --check widget_state tests
- name: Validate types with mypy
run: |
mypy widget_state tests
# - name: Validate types with mypy
# run: |
# mypy widget_state tests
- name: Lint with flake8
run: |
flake8 --max-line-length 127 widget_state tests
Expand Down
6 changes: 3 additions & 3 deletions check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
echo "Black:" &&
black --check widget_state tests &&
echo "" &&
echo "MyPy:" &&
mypy widget_state tests &&
echo "" &&
# echo "MyPy:" &&
# mypy widget_state tests &&
# echo "" &&
echo "Flake8:" &&
flake8 --max-line-length 127 widget_state tests &&
echo "" &&
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "widget_state"
version = "0.0.3"
version = "0.0.4"
authors = [
{ name="Tamino Huxohl", email="[email protected]" },
]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_basic_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ def test_depends_on(callback: MockCallback) -> None:
list_state = ListState([IntState(1), IntState(2)])
float_state = FloatState(3.5)

def compute_sum() -> float:
def compute_sum() -> FloatState:
_sum = sum(map(lambda _state: _state.value, [float_state, *list_state]))
assert isinstance(_sum, float)
return _sum
return FloatState(_sum)

res_state.depends_on(
[list_state, float_state],
compute_value=compute_sum,
element_wise=True,
kwargs={list_state: {"element_wise": True}},
)
assert res_state.value == (1 + 2 + 3.5)

Expand Down
5 changes: 4 additions & 1 deletion tests/test_dict_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class VectorState(DictState):

def __init__(self, x: int, y: int, z: int):
super().__init__()

Expand Down Expand Up @@ -36,3 +35,7 @@ def test_set(vector_state: VectorState) -> None:
assert isinstance(vector_state.x, IntState) and vector_state.x.value == 1
assert isinstance(vector_state.y, IntState) and vector_state.y.value == 2
assert isinstance(vector_state.z, IntState) and vector_state.z.value == 3

vector_state.set(*[], y=IntState(10), z=5)
assert isinstance(vector_state.y, IntState) and vector_state.y.value == 10
assert isinstance(vector_state.z, IntState) and vector_state.z.value == 5
33 changes: 31 additions & 2 deletions tests/test_higher_order_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
StringState,
ObjectState,
HigherOrderState,
computed,
)

from .util import MockCallback
Expand All @@ -22,14 +23,12 @@ def callback() -> MockCallback:


class NestedState(HigherOrderState):

def __init__(self) -> None:
super().__init__()
self.length = FloatState(3.141)


class SuperState(HigherOrderState):

def __init__(self) -> None:
super().__init__()
self.name = StringState("Higher")
Expand Down Expand Up @@ -99,3 +98,33 @@ def test_deserialize(super_state: SuperState, callback: MockCallback) -> None:
def test_to_str(super_state: SuperState) -> None:
assert super_state.to_str() == _str
assert str(super_state) == _str


def test_copy_from(super_state: SuperState) -> None:
new_state = SuperState()
new_state.name.value = "Test"
new_state.count.value = 2
new_state.nested.length.value = 2.71

super_state.copy_from(new_state)
assert super_state.name.value == "Test"
assert super_state.count.value == 2
assert super_state.nested.length.value == 2.71


def test_computed() -> None:
class ExampleState(HigherOrderState):
def __init__(self):
super().__init__()

self.a = IntState(0)
self.b = IntState(1)

@computed
def sum(self, a: IntState, b: IntState) -> IntState:
return IntState(a.value + b.value)

ex = ExampleState()
assert ex.sum.value == 1
ex.a.value = 5
assert ex.sum.value == 6
9 changes: 9 additions & 0 deletions tests/test_list_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,12 @@ def test_serialize(list_state: ListState) -> None:
def test_deserialize(list_state: ListState) -> None:
with pytest.raises(NotImplementedError):
list_state.deserialize([0, 1, 2])


def test_copy_from(list_state: ListState) -> None:
new_list: ListState[IntState] = ListState()
new_list.copy_from(list_state)

assert len(new_list) == len(list_state)
for i in range(len(new_list)):
assert new_list[i].value == list_state[i].value
5 changes: 5 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@ def test_serialize(state: State) -> None:
def test_deserialize(state: State) -> None:
with pytest.raises(NotImplementedError):
state.deserialize(0)


def test_copy_from(state: State) -> None:
with pytest.raises(NotImplementedError):
state.copy_from(state)
1 change: 0 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class Sum(HigherOrderState):

def __init__(self) -> None:
super().__init__()

Expand Down
1 change: 0 additions & 1 deletion tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class MockCallback:

def __init__(self) -> None:
self.n_calls = 0
self.arg: Optional[State] = None
Expand Down
3 changes: 2 additions & 1 deletion widget_state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
ObjectState,
)
from .dict_state import DictState
from .higher_order_state import HigherOrderState
from .higher_order_state import HigherOrderState, computed
from .list_state import ListState
from .state import State
from .types import Serializable, Primitive
Expand All @@ -35,4 +35,5 @@
"Serializable",
"Primitive",
"computed_state",
"computed",
]
41 changes: 7 additions & 34 deletions widget_state/basic_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Callable, Generic, Optional, TypeVar
from typing_extensions import Self

from .state import State
from .list_state import ListState
from .types import Serializable

T = TypeVar("T")
Expand Down Expand Up @@ -79,38 +78,6 @@ def set(self, value: T) -> None:
"""
self.value = value

def depends_on(
self,
states: Iterable[State],
compute_value: Callable[[], T],
element_wise: bool = False,
) -> None:
"""
Declare that this state depends on other states.
This state is updated by the `compute_value` callable whenever one of the
states it depends on changes.
Parameters
----------
states: iterator of states
the states self depends on
compute_value: callable
function which computes the value of this state
element_wise: bool
trigger on element-wise changes of `ListState`
"""
for state in states:
if isinstance(state, ListState):
state.on_change(
lambda _: self.set(compute_value()), element_wise=element_wise
)
continue

state.on_change(lambda _: self.set(compute_value()))

self.set(compute_value())

def transform(
self, self_to_other: Callable[[BasicState[T]], BasicState[R]]
) -> BasicState[R]:
Expand Down Expand Up @@ -147,6 +114,12 @@ def serialize(self) -> Serializable:
def deserialize(self, _dict: Serializable) -> None:
raise NotImplementedError("Unable to deserialize abstract basic state")

def copy_from(self, other: Self) -> None:
assert type(self) is type(
other
), "`copy_from` needs other[type(self)] to be same type as self[{type(self)}]"
self.value = other.value


class IntState(BasicState[int]):
"""
Expand Down
16 changes: 13 additions & 3 deletions widget_state/dict_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,23 @@ def values(self) -> list[Any]:
"""
return [attr.value for attr in self]

def set(self, *args: BasicState[Any] | Primitive) -> None:
def set(
self,
*args: BasicState[Any] | Primitive,
**kwargs: BasicState[Any] | Primitive,
) -> None:
"""
Reassign all internal basic state values and only
Reassign internal basic state values and only
trigger a notification afterwards.
"""
assert len(args) == len(self)
assert len(args) <= len(self)

with self:
for i, arg in enumerate(args):
self[i].value = arg.value if isinstance(arg, BasicState) else arg

_dict = self.dict()
for name, kwarg in kwargs.items():
attr = _dict[name]
assert isinstance(attr, BasicState)
attr.value = kwarg.value if isinstance(kwarg, BasicState) else kwarg
Loading

0 comments on commit e32ccef

Please sign in to comment.