
Commit

add Die.mean_time_to_sum()
HighDiceRoller committed Dec 11, 2023
1 parent 9d69079 commit 0e2a574
Showing 2 changed files with 105 additions and 43 deletions.
125 changes: 83 additions & 42 deletions src/icepool/population/die.py
@@ -11,6 +11,7 @@

import bisect
from collections import defaultdict
from fractions import Fraction
from functools import cached_property
import itertools
import math
@@ -352,12 +353,14 @@ def reroll(self,

# Need TypeVarTuple to check this.
outcome_set = {
outcome for outcome in self.outcomes()
outcome
for outcome in self.outcomes()
if which(*outcome) # type: ignore
}
else:
outcome_set = {
outcome for outcome in self.outcomes() if which(outcome)
outcome
for outcome in self.outcomes() if which(outcome)
}
else:
# Collection.
@@ -372,17 +375,18 @@ def reroll(self,
elif depth < 0:
raise ValueError('reroll depth cannot be negative.')
else:
total_reroll_quantity = sum(
quantity for outcome, quantity in self.items()
if outcome in outcome_set)
total_reroll_quantity = sum(quantity
for outcome, quantity in self.items()
if outcome in outcome_set)
total_stop_quantity = self.denominator() - total_reroll_quantity
rerollable_factor = total_reroll_quantity**depth
stop_factor = (self.denominator()**(depth + 1) - rerollable_factor *
total_reroll_quantity) // total_stop_quantity
stop_factor = (self.denominator()**(depth + 1) - rerollable_factor
* total_reroll_quantity) // total_stop_quantity
data = {
outcome: (rerollable_factor *
quantity if outcome in outcome_set else stop_factor *
quantity) for outcome, quantity in self.items()
quantity)
for outcome, quantity in self.items()
}
return icepool.Die(data)

@@ -419,17 +423,19 @@ def filter(self,
if star:

not_outcomes = {
outcome for outcome in self.outcomes()
outcome
for outcome in self.outcomes()
if not which(*outcome) # type: ignore
}
else:
not_outcomes = {
outcome for outcome in self.outcomes() if not which(outcome)
outcome
for outcome in self.outcomes() if not which(outcome)
}
else:
not_outcomes = {
not_outcome for not_outcome in self.outcomes()
if not_outcome not in which
not_outcome
for not_outcome in self.outcomes() if not_outcome not in which
}
return self.reroll(not_outcomes, depth=depth)

@@ -540,18 +546,18 @@ def _pop_max(self) -> tuple['Die[T_co]', int]:
"""
return self._popped_max

# Mixtures.
# Processes.

def map(
self,
repl:
self,
repl:
'Callable[..., U | Die[U] | icepool.RerollType | icepool.AgainExpression] | Mapping[T_co, U | Die[U] | icepool.RerollType | icepool.AgainExpression]',
/,
*extra_args,
star: bool | None = None,
repeat: int | None = 1,
again_depth: int = 1,
again_end: 'U | Die[U] | icepool.RerollType | None' = None
/,
*extra_args,
star: bool | None = None,
repeat: int | None = 1,
again_depth: int = 1,
again_end: 'U | Die[U] | icepool.RerollType | None' = None
) -> 'Die[U]':
"""Maps outcomes of the `Die` to other outcomes.
@@ -588,6 +594,38 @@ def map_and_time(
star=star,
repeat=repeat)

@cached_property
def _mean_time_to_sum_cache(self) -> list[Fraction]:
return [Fraction(0)]

def mean_time_to_sum(self: 'Die[int]', target: int, /) -> Fraction:
"""The mean number of rolls until the cumulative sum is greater or equal to the target.
Args:
target: The target sum.
Raises:
ValueError: If `target < 0` or if `self` has negative outcomes.
ZeroDivisionError: If `self.mean() == 0`.
"""
target = max(target, 0)

if target < len(self._mean_time_to_sum_cache):
return self._mean_time_to_sum_cache[target]

if self.min_outcome() < 0:
raise ValueError(
'mean_time_to_sum does not handle negative outcomes.')
zero_scale = Fraction(self.denominator(),
self.denominator() - self.quantity(0))

for i in range(len(self._mean_time_to_sum_cache), target + 1):
result = zero_scale * 1 + (self.reroll(
[0]).map(lambda x: self.mean_time_to_sum(i - x)).mean())
self._mean_time_to_sum_cache.append(result)

return result
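
    # Illustrative note (not part of this commit): the loop above implements
    # the recurrence
    #     m(t) = d / (d - q0) + E[m(t - x) | x nonzero],  with m(t) = 0 for t <= 0,
    # where d = self.denominator() and q0 = self.quantity(0); zero outcomes
    # only rescale the cost of each effective roll. A minimal usage sketch,
    # assuming just the method defined above:
    #
    #     from fractions import Fraction
    #     from icepool import d, d6
    #
    #     d6.mean_time_to_sum(0)    # Fraction(0), the cached base case
    #     d6.mean_time_to_sum(10)   # exact Fraction: expected rolls of a d6 to total 10
    #     # A zero face rescales the conditioned die's answer by 1 / P(nonzero):
    #     assert (d6 - 1).mean_time_to_sum(10) == d(5).mean_time_to_sum(10) * Fraction(6, 5)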

def explode(self,
which: Collection[T_co] | Callable[..., bool] | None = None,
*,
@@ -621,12 +659,14 @@ def explode(self,
if star:
# Need TypeVarTuple to type-check this.
outcome_set = {
outcome for outcome in self.outcomes()
outcome
for outcome in self.outcomes()
if which(*outcome) # type: ignore
}
else:
outcome_set = {
outcome for outcome in self.outcomes() if which(outcome)
outcome
for outcome in self.outcomes() if which(outcome)
}
else:
if not which:
Expand All @@ -647,12 +687,12 @@ def map_final(outcome):
return self.map(map_final, again_depth=depth, again_end=end)

def if_else(
self,
outcome_if_true: U | 'Die[U]',
outcome_if_false: U | 'Die[U]',
*,
again_depth: int = 1,
again_end: 'U | Die[U] | icepool.RerollType | None' = None
self,
outcome_if_true: U | 'Die[U]',
outcome_if_false: U | 'Die[U]',
*,
again_depth: int = 1,
again_end: 'U | Die[U] | icepool.RerollType | None' = None
) -> 'Die[U]':
"""Ternary conditional operator.
@@ -780,9 +820,9 @@ def _lowest_single(self, rolls: int, /) -> 'Die':
"""Roll this die several times and keep the lowest."""
if rolls == 0:
return self.zero().simplify()
return icepool.from_cumulative(self.outcomes(),
[x**rolls for x in self.quantities_ge()],
reverse=True)
return icepool.from_cumulative(
self.outcomes(), [x**rolls for x in self.quantities_ge()],
reverse=True)

def highest(self,
rolls: int,
@@ -814,15 +854,16 @@ def _highest_single(self, rolls: int, /) -> 'Die[T_co]':
"""Roll this die several times and keep the highest."""
if rolls == 0:
return self.zero().simplify()
return icepool.from_cumulative(self.outcomes(),
[x**rolls for x in self.quantities_le()])
return icepool.from_cumulative(
self.outcomes(), [x**rolls for x in self.quantities_le()])

def middle(self,
rolls: int,
/,
keep: int = 1,
*,
tie: Literal['error', 'high', 'low'] = 'error') -> 'icepool.Die':
def middle(
self,
rolls: int,
/,
keep: int = 1,
*,
tie: Literal['error', 'high', 'low'] = 'error') -> 'icepool.Die':
"""Roll several of this `Die` and sum the sorted results in the middle.
The outcomes should support addition and multiplication if `keep != 1`.
@@ -1191,6 +1232,6 @@ def equals(self, other, *, simplify: bool = False) -> bool:
# Strings.

def __repr__(self) -> str:
inner = ', '.join(
f'{outcome}: {weight}' for outcome, weight in self.items())
inner = ', '.join(f'{outcome}: {weight}'
for outcome, weight in self.items())
return type(self).__qualname__ + '({' + inner + '})'
23 changes: 22 additions & 1 deletion tests/map_test.py
@@ -1,7 +1,8 @@
import icepool
import pytest

from icepool import d6, Die, coin
from icepool import d, d6, Die, coin
from fractions import Fraction

expected_d6x1 = icepool.Die(range(1, 13),
times=[6, 6, 6, 6, 6, 0, 1, 1, 1, 1, 1, 1]).trim()
@@ -147,3 +148,23 @@ def test_deck_map_size_increase():
result = icepool.Deck(range(13)).map({12: icepool.Deck(range(12))})
expected = icepool.Deck(range(12), times=2)
assert result == expected


def test_mean_time_to_sum_d6():
cdf = []
for i in range(11):
cdf.append(i @ d6 >= 10)
expected = icepool.from_cumulative(range(11), cdf).mean()
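    # The die built from this cdf is distributed as the number of rolls T
    # needed to reach a sum of 10: P(T <= i) is the chance that i rolls of a
    # d6 already total at least 10, and T cannot exceed 10 since every roll
    # adds at least 1. Its mean is therefore the expected number of rolls.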
assert d6.mean_time_to_sum(10) == expected


def test_mean_time_to_sum_z6():
cdf = []
for i in range(11):
cdf.append(i @ d(5) >= 10)
expected = icepool.from_cumulative(range(11), cdf).mean() * Fraction(6, 5)
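    # d6 - 1 rolls 0 through 5; conditioned on a nonzero result it behaves
    # like d(5), and waiting for a nonzero result costs 6/5 rolls on average,
    # hence the Fraction(6, 5) factor on the d(5) expectation.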
assert (d6 - 1).mean_time_to_sum(10) == expected


def test_mean_time_to_sum_coin():
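    # Each roll of coin(1, 2) adds 1 with probability 1/2, so reaching a sum
    # of 10 takes 10 / (1/2) = 20 rolls on average.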
assert icepool.coin(1, 2).mean_time_to_sum(10) == 20
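

# Illustrative sketch (not part of this commit): cross-check mean_time_to_sum
# against the plain recurrence m(t) = 1 + mean of m(t - x) over the six faces,
# with m(t) = 0 for t <= 0, computed with Fraction arithmetic only.
def test_mean_time_to_sum_d6_direct_recurrence():
    m = [Fraction(0)]
    for t in range(1, 11):
        m.append(1 + sum(m[max(t - x, 0)] for x in range(1, 7)) / Fraction(6))
    assert d6.mean_time_to_sum(10) == m[10]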
