From 0e2a5747d68231301b8f0cc1ce93eeaed801517e Mon Sep 17 00:00:00 2001 From: Albert Julius Liu Date: Sun, 10 Dec 2023 18:07:05 -0800 Subject: [PATCH] add `Die.mean_time_to_sum()` --- src/icepool/population/die.py | 125 ++++++++++++++++++++++------------ tests/map_test.py | 23 ++++++- 2 files changed, 105 insertions(+), 43 deletions(-) diff --git a/src/icepool/population/die.py b/src/icepool/population/die.py index 7bca9550..a3595666 100644 --- a/src/icepool/population/die.py +++ b/src/icepool/population/die.py @@ -11,6 +11,7 @@ import bisect from collections import defaultdict +from fractions import Fraction from functools import cached_property import itertools import math @@ -352,12 +353,14 @@ def reroll(self, # Need TypeVarTuple to check this. outcome_set = { - outcome for outcome in self.outcomes() + outcome + for outcome in self.outcomes() if which(*outcome) # type: ignore } else: outcome_set = { - outcome for outcome in self.outcomes() if which(outcome) + outcome + for outcome in self.outcomes() if which(outcome) } else: # Collection. @@ -372,17 +375,18 @@ def reroll(self, elif depth < 0: raise ValueError('reroll depth cannot be negative.') else: - total_reroll_quantity = sum( - quantity for outcome, quantity in self.items() - if outcome in outcome_set) + total_reroll_quantity = sum(quantity + for outcome, quantity in self.items() + if outcome in outcome_set) total_stop_quantity = self.denominator() - total_reroll_quantity rerollable_factor = total_reroll_quantity**depth - stop_factor = (self.denominator()**(depth + 1) - rerollable_factor * - total_reroll_quantity) // total_stop_quantity + stop_factor = (self.denominator()**(depth + 1) - rerollable_factor + * total_reroll_quantity) // total_stop_quantity data = { outcome: (rerollable_factor * quantity if outcome in outcome_set else stop_factor * - quantity) for outcome, quantity in self.items() + quantity) + for outcome, quantity in self.items() } return icepool.Die(data) @@ -419,17 +423,19 @@ def filter(self, if star: not_outcomes = { - outcome for outcome in self.outcomes() + outcome + for outcome in self.outcomes() if not which(*outcome) # type: ignore } else: not_outcomes = { - outcome for outcome in self.outcomes() if not which(outcome) + outcome + for outcome in self.outcomes() if not which(outcome) } else: not_outcomes = { - not_outcome for not_outcome in self.outcomes() - if not_outcome not in which + not_outcome + for not_outcome in self.outcomes() if not_outcome not in which } return self.reroll(not_outcomes, depth=depth) @@ -540,18 +546,18 @@ def _pop_max(self) -> tuple['Die[T_co]', int]: """ return self._popped_max - # Mixtures. + # Processes. def map( - self, - repl: + self, + repl: 'Callable[..., U | Die[U] | icepool.RerollType | icepool.AgainExpression] | Mapping[T_co, U | Die[U] | icepool.RerollType | icepool.AgainExpression]', - /, - *extra_args, - star: bool | None = None, - repeat: int | None = 1, - again_depth: int = 1, - again_end: 'U | Die[U] | icepool.RerollType | None' = None + /, + *extra_args, + star: bool | None = None, + repeat: int | None = 1, + again_depth: int = 1, + again_end: 'U | Die[U] | icepool.RerollType | None' = None ) -> 'Die[U]': """Maps outcomes of the `Die` to other outcomes. @@ -588,6 +594,38 @@ def map_and_time( star=star, repeat=repeat) + @cached_property + def _mean_time_to_sum_cache(self) -> list[Fraction]: + return [Fraction(0)] + + def mean_time_to_sum(self: 'Die[int]', target: int, /) -> Fraction: + """The mean number of rolls until the cumulative sum is greater or equal to the target. + + Args: + target: The target sum. + + Raises: + ValueError: If `target < 0` or if `self` has negative outcomes. + ZeroDivisionError: If `self.mean() == 0`. + """ + target = max(target, 0) + + if target < len(self._mean_time_to_sum_cache): + return self._mean_time_to_sum_cache[target] + + if self.min_outcome() < 0: + raise ValueError( + 'mean_time_to_sum does not handle negative outcomes.') + zero_scale = Fraction(self.denominator(), + self.denominator() - self.quantity(0)) + + for i in range(len(self._mean_time_to_sum_cache), target + 1): + result = zero_scale * 1 + (self.reroll( + [0]).map(lambda x: self.mean_time_to_sum(i - x)).mean()) + self._mean_time_to_sum_cache.append(result) + + return result + def explode(self, which: Collection[T_co] | Callable[..., bool] | None = None, *, @@ -621,12 +659,14 @@ def explode(self, if star: # Need TypeVarTuple to type-check this. outcome_set = { - outcome for outcome in self.outcomes() + outcome + for outcome in self.outcomes() if which(*outcome) # type: ignore } else: outcome_set = { - outcome for outcome in self.outcomes() if which(outcome) + outcome + for outcome in self.outcomes() if which(outcome) } else: if not which: @@ -647,12 +687,12 @@ def map_final(outcome): return self.map(map_final, again_depth=depth, again_end=end) def if_else( - self, - outcome_if_true: U | 'Die[U]', - outcome_if_false: U | 'Die[U]', - *, - again_depth: int = 1, - again_end: 'U | Die[U] | icepool.RerollType | None' = None + self, + outcome_if_true: U | 'Die[U]', + outcome_if_false: U | 'Die[U]', + *, + again_depth: int = 1, + again_end: 'U | Die[U] | icepool.RerollType | None' = None ) -> 'Die[U]': """Ternary conditional operator. @@ -780,9 +820,9 @@ def _lowest_single(self, rolls: int, /) -> 'Die': """Roll this die several times and keep the lowest.""" if rolls == 0: return self.zero().simplify() - return icepool.from_cumulative(self.outcomes(), - [x**rolls for x in self.quantities_ge()], - reverse=True) + return icepool.from_cumulative( + self.outcomes(), [x**rolls for x in self.quantities_ge()], + reverse=True) def highest(self, rolls: int, @@ -814,15 +854,16 @@ def _highest_single(self, rolls: int, /) -> 'Die[T_co]': """Roll this die several times and keep the highest.""" if rolls == 0: return self.zero().simplify() - return icepool.from_cumulative(self.outcomes(), - [x**rolls for x in self.quantities_le()]) + return icepool.from_cumulative( + self.outcomes(), [x**rolls for x in self.quantities_le()]) - def middle(self, - rolls: int, - /, - keep: int = 1, - *, - tie: Literal['error', 'high', 'low'] = 'error') -> 'icepool.Die': + def middle( + self, + rolls: int, + /, + keep: int = 1, + *, + tie: Literal['error', 'high', 'low'] = 'error') -> 'icepool.Die': """Roll several of this `Die` and sum the sorted results in the middle. The outcomes should support addition and multiplication if `keep != 1`. @@ -1191,6 +1232,6 @@ def equals(self, other, *, simplify: bool = False) -> bool: # Strings. def __repr__(self) -> str: - inner = ', '.join( - f'{outcome}: {weight}' for outcome, weight in self.items()) + inner = ', '.join(f'{outcome}: {weight}' + for outcome, weight in self.items()) return type(self).__qualname__ + '({' + inner + '})' diff --git a/tests/map_test.py b/tests/map_test.py index e4ca3639..1a440e52 100644 --- a/tests/map_test.py +++ b/tests/map_test.py @@ -1,7 +1,8 @@ import icepool import pytest -from icepool import d6, Die, coin +from icepool import d, d6, Die, coin +from fractions import Fraction expected_d6x1 = icepool.Die(range(1, 13), times=[6, 6, 6, 6, 6, 0, 1, 1, 1, 1, 1, 1]).trim() @@ -147,3 +148,23 @@ def test_deck_map_size_increase(): result = icepool.Deck(range(13)).map({12: icepool.Deck(range(12))}) expected = icepool.Deck(range(12), times=2) assert result == expected + + +def test_mean_time_to_sum_d6(): + cdf = [] + for i in range(11): + cdf.append(i @ d6 >= 10) + expected = icepool.from_cumulative(range(11), cdf).mean() + assert d6.mean_time_to_sum(10) == expected + + +def test_mean_time_to_sum_z6(): + cdf = [] + for i in range(11): + cdf.append(i @ d(5) >= 10) + expected = icepool.from_cumulative(range(11), cdf).mean() * Fraction(6, 5) + assert (d6 - 1).mean_time_to_sum(10) == expected + + +def test_mean_time_to_sum_coin(): + assert icepool.coin(1, 2).mean_time_to_sum(10) == 20