diff --git a/src/icepool/population/die.py b/src/icepool/population/die.py index a4888b55..f86bd9c7 100644 --- a/src/icepool/population/die.py +++ b/src/icepool/population/die.py @@ -560,6 +560,26 @@ def map_and_time( star=star, repeat=repeat) + def time_to_sum(self: 'Die[int]', target: int, /, + max_time: int) -> 'Die[int]': + """The number of rolls until the cumulative sum is greater or equal to the target. + + Args: + target: The number to stop at once reached. + max_time: The maximum number of rolls to run. + If the sum is not reached, the outcome is equal to max_time. + """ + if self.min_outcome() < 0: + raise ValueError('time_to_sum does not handle negative outcomes.') + + def step(total, roll): + return min(total + roll, target) + + result: 'Die[tuple[int, int]]' = Die([0]).map_and_time(step, + self, + repeat=max_time) + return result.marginals[1] + @cached_property def _mean_time_to_sum_cache(self) -> list[Fraction]: return [Fraction(0)] diff --git a/tests/map_test.py b/tests/map_test.py index 7590f7fd..87295b70 100644 --- a/tests/map_test.py +++ b/tests/map_test.py @@ -157,6 +157,10 @@ def test_mean_time_to_sum_d6(): assert d6.mean_time_to_sum(10) == expected +def test_time_to_sum_d6(): + assert d6.mean_time_to_sum(10) == d6.time_to_sum(10, 11).mean() + + def test_mean_time_to_sum_z6(): cdf = [] for i in range(11): @@ -179,7 +183,9 @@ def test_stochastic_round(): def test_map_and_time_extra_args(): + def test_function(current, roll): return min(current + roll, 10) + result = Die([0]).map_and_time(test_function, d6, repeat=10) assert result.marginals[1].mean() == d6.mean_time_to_sum(10)