From a41ed19f7fa0acba5f556a75e36bb43608a303a6 Mon Sep 17 00:00:00 2001
From: Albert Julius Liu
Date: Sat, 24 Aug 2024 12:54:34 -0700
Subject: [PATCH] cache results for forward algorithm

add extra variants to match_test.py to improve forward algorithm coverage
---
 src/icepool/evaluator/multiset_evaluator.py | 95 ++++++++++++---------
 tests/match_test.py                         | 24 +++---
 2 files changed, 68 insertions(+), 51 deletions(-)

diff --git a/src/icepool/evaluator/multiset_evaluator.py b/src/icepool/evaluator/multiset_evaluator.py
index b4be4437..c6c052ff 100644
--- a/src/icepool/evaluator/multiset_evaluator.py
+++ b/src/icepool/evaluator/multiset_evaluator.py
@@ -16,7 +16,7 @@
 
 if TYPE_CHECKING:
     from icepool.generator.alignment import Alignment
-    from icepool.expression import MultisetExpression
+    from icepool import MultisetExpression, MultisetGenerator
 
 PREFERRED_ORDER_COST_FACTOR = 10
 """The preferred order will be favored this times as much."""
@@ -208,8 +208,16 @@ def validate_arity(self, arity: int) -> None:
         """
 
     @cached_property
-    def _cache(self) -> MutableMapping[Any, Mapping[Any, int]]:
-        """A cache of (order, generators) -> weight distribution over states. """
+    def _cache(
+        self
+    ) -> 'MutableMapping[tuple[Order, Alignment, tuple[MultisetGenerator, ...], Hashable], Mapping[Any, int]]':
+        """Cached results.
+
+        The key is `(order, alignment, generators, state)`.
+        The value is another mapping `final_state: quantity` representing the
+        state distribution produced by `order, alignment, generators` when
+        starting at state `state`.
+        """
         return {}
 
     @overload
@@ -359,14 +367,13 @@ def _select_algorithm(
             return self._eval_internal, eval_order
         else:
             # Use the less-preferred algorithm.
-            return self._eval_internal_iterative, eval_order
+            return self._eval_internal_forward, eval_order
 
     def _eval_internal(
             self, order: Order, alignment: 'Alignment[T_contra]',
             generators: 'tuple[icepool.MultisetGenerator[T_contra, Any], ...]'
     ) -> Mapping[Any, int]:
-        """Internal algorithm for iterating in the more-preferred order,
-        i.e. giving outcomes to `next_state()` from wide to narrow.
+        """Internal algorithm for iterating in the more-preferred order.
 
         All intermediate return values are cached in the instance.
 
@@ -381,7 +388,7 @@ def _eval_internal(
             A dict `{ state : weight }` describing the probability distribution
             over states.
         """
-        cache_key = (order, alignment, generators)
+        cache_key = (order, alignment, generators, None)
         if cache_key in self._cache:
             return self._cache[cache_key]
 
@@ -407,42 +414,52 @@ def _eval_internal(
         self._cache[cache_key] = result
         return result
 
-    def _eval_internal_iterative(
-            self, order: int, alignment: 'Alignment[T_contra]',
-            generators: 'tuple[icepool.MultisetGenerator[T_contra, Any], ...]'
-    ) -> Mapping[Any, int]:
-        """Internal algorithm for iterating in the less-preferred order,
-        i.e. giving outcomes to `next_state()` from narrow to wide.
+    def _eval_internal_forward(
+            self,
+            order: Order,
+            alignment: 'Alignment[T_contra]',
+            generators: 'tuple[icepool.MultisetGenerator[T_contra, Any], ...]',
+            state: Hashable = None) -> Mapping[Any, int]:
+        """Internal algorithm for iterating in the less-preferred order.
+
+        All intermediate return values are cached in the instance.
 
-        This algorithm does not perform persistent memoization.
+        Arguments:
+            order: The order in which to send outcomes to `next_state()`.
+            alignment: As `alignment()`. Elements will be popped off this
+                during recursion.
+            generators: One or more `MultisetGenerator`s to evaluate. Elements
+                will be popped off this during recursion.
+
+        Returns:
+            A dict `{ state : weight }` describing the probability distribution
+            over states.
         """
+        cache_key = (order, alignment, generators, state)
+        if cache_key in self._cache:
+            return self._cache[cache_key]
+
+        result: MutableMapping[Any, int] = defaultdict(int)
+
         if all(not generator.outcomes()
                for generator in generators) and not alignment.outcomes():
-            return {None: 1}
-        dist: MutableMapping[Any, int] = defaultdict(int)
-        dist[None, alignment, generators] = 1
-        final_dist: MutableMapping[Any, int] = defaultdict(int)
-        while dist:
-            next_dist: MutableMapping[Any, int] = defaultdict(int)
-            for (prev_state, prev_alignment,
-                 prev_generators), weight in dist.items():
-                # The order flip here is the only purpose of this algorithm.
-                outcome, alignment, iterators = MultisetEvaluator._pop_generators(
-                    -order, prev_alignment, prev_generators)
-                for p in itertools.product(*iterators):
-                    generators, counts, weights = zip(*p)
-                    counts = tuple(itertools.chain.from_iterable(counts))
-                    prod_weight = math.prod(weights)
-                    state = self.next_state(prev_state, outcome, *counts)
-                    if state is not icepool.Reroll:
-                        if all(not generator.outcomes()
-                               for generator in generators):
-                            final_dist[state] += weight * prod_weight
-                        else:
-                            next_dist[state, alignment,
-                                      generators] += weight * prod_weight
-            dist = next_dist
-        return final_dist
+            result = {state: 1}
+        else:
+            outcome, next_alignment, iterators = MultisetEvaluator._pop_generators(
+                -order, alignment, generators)
+            for p in itertools.product(*iterators):
+                next_generators, counts, weights = zip(*p)
+                counts = tuple(itertools.chain.from_iterable(counts))
+                prod_weight = math.prod(weights)
+                next_state = self.next_state(state, outcome, *counts)
+                if next_state is not icepool.Reroll:
+                    final_dist = self._eval_internal_forward(
+                        order, next_alignment, next_generators, next_state)
+                    for final_state, weight in final_dist.items():
+                        result[final_state] += weight * prod_weight
+
+        self._cache[cache_key] = result
+        return result
 
     @staticmethod
     def _initialize_generators(
diff --git a/tests/match_test.py b/tests/match_test.py
index f8054b8d..4e60fda3 100644
--- a/tests/match_test.py
+++ b/tests/match_test.py
@@ -2,7 +2,7 @@
 import operator
 import pytest
 
-from icepool import d6, Die, Order, map_function, Pool
+from icepool import d4, d6, d8, Die, Order, map_function, Pool
 
 
 def test_sort_match_example():
@@ -75,8 +75,10 @@ def compute_expected(left, right):
 
 
 @pytest.mark.parametrize('op', sort_ops)
-def test_sort_match_operators_expand(op):
-    result = d6.pool(3).highest(2).sort_match(op, d6.pool(2)).expand()
+@pytest.mark.parametrize('left', [d6.pool(3), Pool([d4, d6, d8])])
+@pytest.mark.parametrize('right', [d6.pool(2), Pool([d4, d6])])
+def test_sort_match_operators_expand(op, left, right):
+    result = left.highest(2).sort_match(op, right).expand()
 
     @map_function
     def compute_expected(left, right):
@@ -86,7 +88,7 @@ def compute_expected(left, right):
                 result.append(l)
         return tuple(sorted(result))
 
-    expected = compute_expected(d6.pool(3), d6.pool(2))
+    expected = compute_expected(left, right)
     assert result == expected
 
 
@@ -139,15 +141,13 @@ def compute_expected(left, right):
 
 
 @pytest.mark.parametrize('op', maximum_ops)
-def test_maximum_match_expand(op):
+@pytest.mark.parametrize('left', [d6.pool(3), Pool([d4, d6, d8])])
+@pytest.mark.parametrize('right', [d6.pool(2), Pool([d4, d6])])
+def test_maximum_match_expand(op, left, right):
     if op in ['<=', '<']:
-        result = d6.pool(3).maximum_match_highest(op,
-                                                  d6.pool(2),
-                                                  keep='matched').expand()
+        result = left.maximum_match_highest(op, right, keep='matched').expand()
     else:
-        result = d6.pool(3).maximum_match_lowest(op,
-                                                 d6.pool(2),
-                                                 keep='matched').expand()
+        result = left.maximum_match_lowest(op, right, keep='matched').expand()
 
     @map_function
     def compute_expected(left, right):
@@ -166,5 +166,5 @@ def compute_expected(left, right):
             left.pop(0)
         return tuple(sorted(result))
 
-    expected = compute_expected(d6.pool(3), d6.pool(2))
+    expected = compute_expected(left, right)
     assert result == expected
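
For reference, the kind of mixed-pool evaluation the new test variants add in
order to cover the forward algorithm can be reproduced directly from the
public API. This is only a minimal sketch using the same pools and comparison
strings as the tests above; the specific choice of '<=' and of
Pool([d4, d6, d8]) is illustrative, not part of the change itself.

    from icepool import d4, d6, d8, Pool

    # Mixed die sizes, as in the new test parametrizations.
    left = Pool([d4, d6, d8])
    right = Pool([d4, d6])

    # Same call shape as test_maximum_match_expand: keep the matched dice
    # and expand() into the distribution over the kept (sorted) outcomes.
    result = left.maximum_match_highest('<=', right, keep='matched').expand()
    print(result)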