Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/improve computed states #2

Merged
merged 9 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ jobs:
- name: Validate formatting with black
run: |
black --check widget_state tests
- name: Validate types with mypy
run: |
mypy widget_state tests
# - name: Validate types with mypy
# run: |
# mypy widget_state tests
- name: Lint with flake8
run: |
flake8 --max-line-length 127 widget_state tests
Expand Down
6 changes: 3 additions & 3 deletions check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
echo "Black:" &&
black --check widget_state tests &&
echo "" &&
echo "MyPy:" &&
mypy widget_state tests &&
echo "" &&
# echo "MyPy:" &&
# mypy widget_state tests &&
# echo "" &&
echo "Flake8:" &&
flake8 --max-line-length 127 widget_state tests &&
echo "" &&
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "widget_state"
version = "0.0.3"
version = "0.0.4"
authors = [
{ name="Tamino Huxohl", email="[email protected]" },
]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_basic_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ def test_depends_on(callback: MockCallback) -> None:
list_state = ListState([IntState(1), IntState(2)])
float_state = FloatState(3.5)

def compute_sum() -> float:
def compute_sum() -> FloatState:
_sum = sum(map(lambda _state: _state.value, [float_state, *list_state]))
assert isinstance(_sum, float)
return _sum
return FloatState(_sum)

res_state.depends_on(
[list_state, float_state],
compute_value=compute_sum,
element_wise=True,
kwargs={list_state: {"element_wise": True}},
)
assert res_state.value == (1 + 2 + 3.5)

Expand Down
5 changes: 4 additions & 1 deletion tests/test_dict_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class VectorState(DictState):

def __init__(self, x: int, y: int, z: int):
super().__init__()

Expand Down Expand Up @@ -36,3 +35,7 @@ def test_set(vector_state: VectorState) -> None:
assert isinstance(vector_state.x, IntState) and vector_state.x.value == 1
assert isinstance(vector_state.y, IntState) and vector_state.y.value == 2
assert isinstance(vector_state.z, IntState) and vector_state.z.value == 3

vector_state.set(*[], y=IntState(10), z=5)
assert isinstance(vector_state.y, IntState) and vector_state.y.value == 10
assert isinstance(vector_state.z, IntState) and vector_state.z.value == 5
33 changes: 31 additions & 2 deletions tests/test_higher_order_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
StringState,
ObjectState,
HigherOrderState,
computed,
)

from .util import MockCallback
Expand All @@ -22,14 +23,12 @@ def callback() -> MockCallback:


class NestedState(HigherOrderState):

def __init__(self) -> None:
super().__init__()
self.length = FloatState(3.141)


class SuperState(HigherOrderState):

def __init__(self) -> None:
super().__init__()
self.name = StringState("Higher")
Expand Down Expand Up @@ -99,3 +98,33 @@ def test_deserialize(super_state: SuperState, callback: MockCallback) -> None:
def test_to_str(super_state: SuperState) -> None:
assert super_state.to_str() == _str
assert str(super_state) == _str


def test_copy_from(super_state: SuperState) -> None:
new_state = SuperState()
new_state.name.value = "Test"
new_state.count.value = 2
new_state.nested.length.value = 2.71

super_state.copy_from(new_state)
assert super_state.name.value == "Test"
assert super_state.count.value == 2
assert super_state.nested.length.value == 2.71


def test_computed() -> None:
class ExampleState(HigherOrderState):
def __init__(self):
super().__init__()

self.a = IntState(0)
self.b = IntState(1)

@computed
def sum(self, a: IntState, b: IntState) -> IntState:
return IntState(a.value + b.value)

ex = ExampleState()
assert ex.sum.value == 1
ex.a.value = 5
assert ex.sum.value == 6
9 changes: 9 additions & 0 deletions tests/test_list_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,12 @@ def test_serialize(list_state: ListState) -> None:
def test_deserialize(list_state: ListState) -> None:
with pytest.raises(NotImplementedError):
list_state.deserialize([0, 1, 2])


def test_copy_from(list_state: ListState) -> None:
new_list: ListState[IntState] = ListState()
new_list.copy_from(list_state)

assert len(new_list) == len(list_state)
for i in range(len(new_list)):
assert new_list[i].value == list_state[i].value
5 changes: 5 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@ def test_serialize(state: State) -> None:
def test_deserialize(state: State) -> None:
with pytest.raises(NotImplementedError):
state.deserialize(0)


def test_copy_from(state: State) -> None:
with pytest.raises(NotImplementedError):
state.copy_from(state)
1 change: 0 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class Sum(HigherOrderState):

def __init__(self) -> None:
super().__init__()

Expand Down
1 change: 0 additions & 1 deletion tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class MockCallback:

def __init__(self) -> None:
self.n_calls = 0
self.arg: Optional[State] = None
Expand Down
3 changes: 2 additions & 1 deletion widget_state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
ObjectState,
)
from .dict_state import DictState
from .higher_order_state import HigherOrderState
from .higher_order_state import HigherOrderState, computed
from .list_state import ListState
from .state import State
from .types import Serializable, Primitive
Expand All @@ -35,4 +35,5 @@
"Serializable",
"Primitive",
"computed_state",
"computed",
]
41 changes: 7 additions & 34 deletions widget_state/basic_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Callable, Generic, Optional, TypeVar
from typing_extensions import Self

from .state import State
from .list_state import ListState
from .types import Serializable

T = TypeVar("T")
Expand Down Expand Up @@ -79,38 +78,6 @@ def set(self, value: T) -> None:
"""
self.value = value

def depends_on(
self,
states: Iterable[State],
compute_value: Callable[[], T],
element_wise: bool = False,
) -> None:
"""
Declare that this state depends on other states.

This state is updated by the `compute_value` callable whenever one of the
states it depends on changes.

Parameters
----------
states: iterator of states
the states self depends on
compute_value: callable
function which computes the value of this state
element_wise: bool
trigger on element-wise changes of `ListState`
"""
for state in states:
if isinstance(state, ListState):
state.on_change(
lambda _: self.set(compute_value()), element_wise=element_wise
)
continue

state.on_change(lambda _: self.set(compute_value()))

self.set(compute_value())

def transform(
self, self_to_other: Callable[[BasicState[T]], BasicState[R]]
) -> BasicState[R]:
Expand Down Expand Up @@ -147,6 +114,12 @@ def serialize(self) -> Serializable:
def deserialize(self, _dict: Serializable) -> None:
raise NotImplementedError("Unable to deserialize abstract basic state")

def copy_from(self, other: Self) -> None:
assert type(self) is type(
other
), "`copy_from` needs other[type(self)] to be same type as self[{type(self)}]"
self.value = other.value


class IntState(BasicState[int]):
"""
Expand Down
16 changes: 13 additions & 3 deletions widget_state/dict_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,23 @@ def values(self) -> list[Any]:
"""
return [attr.value for attr in self]

def set(self, *args: BasicState[Any] | Primitive) -> None:
def set(
self,
*args: BasicState[Any] | Primitive,
**kwargs: BasicState[Any] | Primitive,
) -> None:
"""
Reassign all internal basic state values and only
Reassign internal basic state values and only
trigger a notification afterwards.
"""
assert len(args) == len(self)
assert len(args) <= len(self)

with self:
for i, arg in enumerate(args):
self[i].value = arg.value if isinstance(arg, BasicState) else arg

_dict = self.dict()
for name, kwarg in kwargs.items():
attr = _dict[name]
assert isinstance(attr, BasicState)
attr.value = kwarg.value if isinstance(kwarg, BasicState) else kwarg
Loading
Loading