diff --git a/check.sh b/check.sh index 9bac561..e24f01e 100755 --- a/check.sh +++ b/check.sh @@ -3,9 +3,9 @@ echo "Black:" && black --check widget_state tests && echo "" && -echo "MyPy:" && -mypy widget_state tests && -echo "" && +# echo "MyPy:" && +# mypy widget_state tests && +# echo "" && echo "Flake8:" && flake8 --max-line-length 127 widget_state tests && echo "" && diff --git a/widget_state/__init__.py b/widget_state/__init__.py index e31f64b..f468bf6 100644 --- a/widget_state/__init__.py +++ b/widget_state/__init__.py @@ -14,7 +14,7 @@ ObjectState, ) from .dict_state import DictState -from .higher_order_state import HigherOrderState +from .higher_order_state import HigherOrderState, computed from .list_state import ListState from .state import State from .types import Serializable, Primitive @@ -35,4 +35,5 @@ "Serializable", "Primitive", "computed_state", + "computed", ] diff --git a/widget_state/higher_order_state.py b/widget_state/higher_order_state.py index 860033b..ff34eb2 100644 --- a/widget_state/higher_order_state.py +++ b/widget_state/higher_order_state.py @@ -67,17 +67,28 @@ class HigherOrderState(State): def __init__(self): super().__init__() - self._computed_states = {} + self._computed_states = dict( + filter( + lambda member: inspect.ismethod(member[1]) + and hasattr(member[1], "is_computed_state"), + inspect.getmembers(self), + ) + ) + + def _update_computed_state(self, name: str) -> None: + func = self._computed_states[name] + params = list(map(lambda param_name: getattr(self, param_name), func.params)) + self.__dict__[name].copy_from(func(*params)) - def __setattr__(self, name: str, new_value: Union[Any, State]) -> None: + def __setattr__(self, name: str, value: Union[Any, State]) -> None: # ignore private attributes (begin with an underscore) if name[0] == "_": - super().__setattr__(name, new_value) + super().__setattr__(name, value) return # wrap non-state values into states - if not isinstance(new_value, State): - new_value = BASIC_STATE_DICT.get(type(new_value), ObjectState)(new_value) + if not isinstance(value, State): + value = BASIC_STATE_DICT.get(type(value), ObjectState)(value) # assert that states are not reassigned as only their values should change assert not hasattr(self, name) or callable( @@ -85,17 +96,44 @@ def __setattr__(self, name: str, new_value: Union[Any, State]) -> None: ), f"Reassignment of value {name} in state {self}" # assert that all attributes are states assert isinstance( - new_value, State - ), f"Values of higher states must be states not {type(new_value)}" + value, State + ), f"Values of higher states must be states not {type(value)}" # update the attribute - super().__setattr__(name, new_value) + super().__setattr__(name, value) # set self as parent of added state to build a state hierarchy - new_value._parent = self + value._parent = self # register notification to the internal state - new_value.on_change(lambda _: self.notify_change()) + value.on_change(lambda _: self.notify_change()) + + # update computed states + for computed_state_name, func in self._computed_states.items(): + # skip computed states that are already initialized + if computed_state_name in self.__dict__: + continue + + # skip computed states whose params are not yet available + if not all(map(lambda param_name: hasattr(self, param_name), func.params)): + continue + + # initialize computed state + params = list( + map(lambda param_name: getattr(self, param_name), func.params) + ) + self.__dict__[computed_state_name] = func(*params) + + # re-compute every time a parameter changes + for param in params: + # print( + # f" - Regster callback with {computed_state_name=} for param {param}" + # ) + param.on_change( + lambda _, _name=computed_state_name: self._update_computed_state( + _name + ) + ) def dict(self) -> dict[str, State]: """