cache results for forward algorithm
add extra variants to match_test.py to improve forward algorithm coverage
HighDiceRoller committed Aug 24, 2024
1 parent 2778d6b commit a41ed19
Showing 2 changed files with 68 additions and 51 deletions.
95 changes: 56 additions & 39 deletions src/icepool/evaluator/multiset_evaluator.py
@@ -16,7 +16,7 @@

if TYPE_CHECKING:
from icepool.generator.alignment import Alignment
from icepool.expression import MultisetExpression
from icepool import MultisetExpression, MultisetGenerator

PREFERRED_ORDER_COST_FACTOR = 10
"""The preferred order will be favored this times as much."""
@@ -208,8 +208,16 @@ def validate_arity(self, arity: int) -> None:
"""

@cached_property
def _cache(self) -> MutableMapping[Any, Mapping[Any, int]]:
"""A cache of (order, generators) -> weight distribution over states. """
def _cache(
self
) -> 'MutableMapping[tuple[Order, Alignment, tuple[MultisetGenerator, ...], Hashable], Mapping[Any, int]]':
"""Cached results.
The key is `(order, alignment, generators, state)`.
The value is another mapping `final_state: quantity` representing the
state distribution produced by `order, alignment, generators` when
starting at state `state`.
"""
return {}

@overload
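As a rough illustration of the cache layout the new docstring describes (a sketch only, with made-up names rather than icepool's API): each entry maps an evaluation context plus a starting state to a distribution over final states, so a given subproblem is computed at most once.

from typing import Any, Hashable, Mapping, MutableMapping

# Shape of the memo described in the docstring above:
# (order, alignment, generators, state) -> {final_state: quantity}
Cache = MutableMapping[tuple[Any, Hashable, tuple[Hashable, ...], Hashable],
                       Mapping[Any, int]]

def memoized(cache: Cache, key, compute) -> Mapping[Any, int]:
    """Return the cached distribution for `key`, computing it only once."""
    if key not in cache:
        cache[key] = compute()
    return cache[key]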
@@ -359,14 +367,13 @@ def _select_algorithm(
return self._eval_internal, eval_order
else:
# Use the less-preferred algorithm.
return self._eval_internal_iterative, eval_order
return self._eval_internal_forward, eval_order

def _eval_internal(
self, order: Order, alignment: 'Alignment[T_contra]',
generators: 'tuple[icepool.MultisetGenerator[T_contra, Any], ...]'
) -> Mapping[Any, int]:
"""Internal algorithm for iterating in the more-preferred order,
i.e. giving outcomes to `next_state()` from wide to narrow.
"""Internal algorithm for iterating in the more-preferred order.
All intermediate return values are cached in the instance.
@@ -381,7 +388,7 @@ def _eval_internal(
A dict `{ state : weight }` describing the probability distribution
over states.
"""
cache_key = (order, alignment, generators)
cache_key = (order, alignment, generators, None)
if cache_key in self._cache:
return self._cache[cache_key]

@@ -407,42 +414,52 @@
self._cache[cache_key] = result
return result

def _eval_internal_iterative(
self, order: int, alignment: 'Alignment[T_contra]',
generators: 'tuple[icepool.MultisetGenerator[T_contra, Any], ...]'
) -> Mapping[Any, int]:
"""Internal algorithm for iterating in the less-preferred order,
i.e. giving outcomes to `next_state()` from narrow to wide.
def _eval_internal_forward(
self,
order: Order,
alignment: 'Alignment[T_contra]',
generators: 'tuple[icepool.MultisetGenerator[T_contra, Any], ...]',
state: Hashable = None) -> Mapping[Any, int]:
"""Internal algorithm for iterating in the less-preferred order.
All intermediate return values are cached in the instance.
This algorithm does not perform persistent memoization.
Arguments:
order: The order in which to send outcomes to `next_state()`.
alignment: As `alignment()`. Elements will be popped off this
during recursion.
generators: One or more `MultisetGenerator`s to evaluate. Elements
will be popped off this during recursion.
Returns:
A dict `{ state : weight }` describing the probability distribution
over states.
"""
cache_key = (order, alignment, generators, state)
if cache_key in self._cache:
return self._cache[cache_key]

result: MutableMapping[Any, int] = defaultdict(int)

if all(not generator.outcomes()
for generator in generators) and not alignment.outcomes():
return {None: 1}
dist: MutableMapping[Any, int] = defaultdict(int)
dist[None, alignment, generators] = 1
final_dist: MutableMapping[Any, int] = defaultdict(int)
while dist:
next_dist: MutableMapping[Any, int] = defaultdict(int)
for (prev_state, prev_alignment,
prev_generators), weight in dist.items():
# The order flip here is the only purpose of this algorithm.
outcome, alignment, iterators = MultisetEvaluator._pop_generators(
-order, prev_alignment, prev_generators)
for p in itertools.product(*iterators):
generators, counts, weights = zip(*p)
counts = tuple(itertools.chain.from_iterable(counts))
prod_weight = math.prod(weights)
state = self.next_state(prev_state, outcome, *counts)
if state is not icepool.Reroll:
if all(not generator.outcomes()
for generator in generators):
final_dist[state] += weight * prod_weight
else:
next_dist[state, alignment,
generators] += weight * prod_weight
dist = next_dist
return final_dist
result = {state: 1}
else:
outcome, next_alignment, iterators = MultisetEvaluator._pop_generators(
-order, alignment, generators)
for p in itertools.product(*iterators):
next_generators, counts, weights = zip(*p)
counts = tuple(itertools.chain.from_iterable(counts))
prod_weight = math.prod(weights)
next_state = self.next_state(state, outcome, *counts)
if next_state is not icepool.Reroll:
final_dist = self._eval_internal_forward(
order, next_alignment, next_generators, next_state)
for final_state, weight in final_dist.items():
result[final_state] += weight * prod_weight

self._cache[cache_key] = result
return result
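
The shape of the change above, for readers skimming the diff: the old `_eval_internal_iterative` advanced an explicit work list of `(state, alignment, generators)` entries one outcome at a time, while the new `_eval_internal_forward` recurses on the popped generators and memoizes every `(order, alignment, generators, state)` subproblem in `self._cache`. A stripped-down sketch of that memoized forward recursion, using plain die sizes in place of icepool's generators (`eval_forward` and the toy dice are illustrative only, not the library's code):

from collections import defaultdict

def eval_forward(cache, next_state, dice, state=None):
    """Distribution {final_state: weight} from rolling `dice` one at a time,
    memoized on (remaining dice, current state)."""
    key = (dice, state)
    if key in cache:
        return cache[key]
    result = defaultdict(int)
    if not dice:
        result[state] = 1
    else:
        first, rest = dice[0], dice[1:]
        for outcome in range(1, first + 1):  # each face of the first die, weight 1
            new_state = next_state(state, outcome)
            for final_state, weight in eval_forward(cache, next_state, rest,
                                                    new_state).items():
                result[final_state] += weight
    cache[key] = dict(result)
    return cache[key]

# Usage: distribution of the sum of 3d6. Repeated (remaining dice, partial sum)
# subproblems are computed once and then served from the cache.
cache = {}
dist = eval_forward(cache, lambda s, o: (s or 0) + o, (6, 6, 6))
assert sum(dist.values()) == 6 ** 3

The method above follows the same outline: a base case that records the current state with quantity 1, a recursive case that pops one outcome, calls `self.next_state`, and accumulates the weighted final distributions, and a cache write before returning.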

@staticmethod
def _initialize_generators(
24 changes: 12 additions & 12 deletions tests/match_test.py
Expand Up @@ -2,7 +2,7 @@
import operator
import pytest

from icepool import d6, Die, Order, map_function, Pool
from icepool import d4, d6, d8, Die, Order, map_function, Pool


def test_sort_match_example():
@@ -75,8 +75,10 @@ def compute_expected(left, right):


@pytest.mark.parametrize('op', sort_ops)
def test_sort_match_operators_expand(op):
result = d6.pool(3).highest(2).sort_match(op, d6.pool(2)).expand()
@pytest.mark.parametrize('left', [d6.pool(3), Pool([d4, d6, d8])])
@pytest.mark.parametrize('right', [d6.pool(2), Pool([d4, d6])])
def test_sort_match_operators_expand(op, left, right):
result = left.highest(2).sort_match(op, right).expand()

@map_function
def compute_expected(left, right):
@@ -86,7 +88,7 @@ def compute_expected(left, right):
result.append(l)
return tuple(sorted(result))

expected = compute_expected(d6.pool(3), d6.pool(2))
expected = compute_expected(left, right)
assert result == expected
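
One note on the broadened coverage above: stacking `@pytest.mark.parametrize` decorators runs every combination of `op`, `left`, and `right`, so each operator is now tested against mixed pools (`Pool([d4, d6, d8])` vs `Pool([d4, d6])`) as well as uniform ones. A minimal self-contained illustration of that cross-product behavior (a toy test, not part of this suite):

import pytest

@pytest.mark.parametrize('left', ['A', 'B'])
@pytest.mark.parametrize('right', ['x', 'y'])
def test_cross_product(left, right):
    # pytest generates one test case per combination:
    # (A, x), (A, y), (B, x), (B, y).
    assert (left, right)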


@@ -139,15 +141,13 @@ def compute_expected(left, right):


@pytest.mark.parametrize('op', maximum_ops)
def test_maximum_match_expand(op):
@pytest.mark.parametrize('left', [d6.pool(3), Pool([d4, d6, d8])])
@pytest.mark.parametrize('right', [d6.pool(2), Pool([d4, d6])])
def test_maximum_match_expand(op, left, right):
if op in ['<=', '<']:
result = d6.pool(3).maximum_match_highest(op,
d6.pool(2),
keep='matched').expand()
result = left.maximum_match_highest(op, right, keep='matched').expand()
else:
result = d6.pool(3).maximum_match_lowest(op,
d6.pool(2),
keep='matched').expand()
result = left.maximum_match_lowest(op, right, keep='matched').expand()

@map_function
def compute_expected(left, right):
@@ -166,5 +166,5 @@
left.pop(0)
return tuple(sorted(result))

expected = compute_expected(d6.pool(3), d6.pool(2))
expected = compute_expected(left, right)
assert result == expected
