diff --git a/widget_state/__init__.py b/widget_state/__init__.py index f468bf6..c6f7c49 100644 --- a/widget_state/__init__.py +++ b/widget_state/__init__.py @@ -18,7 +18,7 @@ from .list_state import ListState from .state import State from .types import Serializable, Primitive -from .util import computed_state +from .util import computed_state, compute __all__ = [ "BASIC_STATE_DICT", @@ -36,4 +36,5 @@ "Primitive", "computed_state", "computed", + "compute", ] diff --git a/widget_state/util.py b/widget_state/util.py index 4dd7b08..26cd7ef 100644 --- a/widget_state/util.py +++ b/widget_state/util.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Any, Callable, ParamSpec, TypeVar +from typing import Any, Callable, Iterable, ParamSpec, TypeVar from .basic_state import BasicState from .state import State @@ -15,6 +15,8 @@ T = TypeVar("T", bound=BasicState[Any]) P = ParamSpec("P") +S = TypeVar("S", bound=State) + def computed_state( func: Callable[P, T], @@ -64,3 +66,17 @@ def _on_change(_: Any) -> None: return computed_value return wrapped + + +def compute( + states: Iterable[State], + compute_value: Callable[[], S], + kwargs: dict[State, dict[str, Any]] = {}, +) -> S: + res = compute_value() + + for state in states: + _kwargs = {} if state not in kwargs else kwargs[state] + state.on_change(lambda _: res.copy_from(compute_value()), **_kwargs) + + return res