diff --git a/src/icepool/function.py b/src/icepool/function.py index 40bb15a7..88360385 100644 --- a/src/icepool/function.py +++ b/src/icepool/function.py @@ -235,7 +235,15 @@ def _iter_outcomes( else: yield arg -def pointwise_highest(*dice: 'icepool.Die[T]') -> 'icepool.Die[T]': +@overload +def pointwise_highest(arg0: 'Iterable[icepool.Die[T]]', /,) -> 'icepool.Die[T]': + ... + +@overload +def pointwise_highest(arg0: 'icepool.Die[T]', arg1: 'icepool.Die[T]', /, *args: 'icepool.Die[T]') -> 'icepool.Die[T]': + ... + +def pointwise_highest(arg0, /, *more_args: 'icepool.Die[T]') -> 'icepool.Die[T]': """Selects the highest chance of rolling >= each outcome among the arguments. Specifically, for each outcome, the chance of the result rolling >= to that @@ -250,14 +258,27 @@ def pointwise_highest(*dice: 'icepool.Die[T]') -> 'icepool.Die[T]': situation. Args: - dice: Any number of dice. + dice: Either an iterable of dice, or two or more dice as separate + arguments. """ - dice = commonize_denominator(*dice) - outcomes = sorted_union(*dice) - cumulative = [min(die.quantity('<=', outcome) for die in dice) for outcome in outcomes] + if len(more_args) == 0: + args = arg0 + else: + args = (arg0, ) + more_args + args = commonize_denominator(*args) + outcomes = sorted_union(*args) + cumulative = [min(die.quantity('<=', outcome) for die in args) for outcome in outcomes] return from_cumulative(outcomes, cumulative) -def pointwise_lowest(*dice: 'icepool.Die[T]') -> 'icepool.Die[T]': +@overload +def pointwise_lowest(arg0: 'Iterable[icepool.Die[T]]', /,) -> 'icepool.Die[T]': + ... + +@overload +def pointwise_lowest(arg0: 'icepool.Die[T]', arg1: 'icepool.Die[T]', /, *args: 'icepool.Die[T]') -> 'icepool.Die[T]': + ... + +def pointwise_lowest(arg0, /, *more_args: 'icepool.Die[T]') -> 'icepool.Die[T]': """Selects the highest chance of rolling <= each outcome among the arguments. Specifically, for each outcome, the chance of the result rolling <= to that @@ -272,11 +293,16 @@ def pointwise_lowest(*dice: 'icepool.Die[T]') -> 'icepool.Die[T]': situation. Args: - dice: Any number of dice. + dice: Either an iterable of dice, or two or more dice as separate + arguments. """ - dice = commonize_denominator(*dice) - outcomes = sorted_union(*dice) - cumulative = [max(die.quantity('<=', outcome) for die in dice) for outcome in outcomes] + if len(more_args) == 0: + args = arg0 + else: + args = (arg0, ) + more_args + args = commonize_denominator(*args) + outcomes = sorted_union(*args) + cumulative = [max(die.quantity('<=', outcome) for die in args) for outcome in outcomes] return from_cumulative(outcomes, cumulative) @overload diff --git a/tests/function_test.py b/tests/function_test.py index 4594604a..e482c203 100644 --- a/tests/function_test.py +++ b/tests/function_test.py @@ -33,6 +33,13 @@ def test_pointwise_highest(): assert result.probability('>=', outcome) == max( (3 @ d6).probability('>=', outcome), d20.probability('>=', outcome)) + +def test_pointwise_highest_single_argument(): + result = pointwise_highest([3 @ d6, d20]) + for outcome in range(1, 21): + assert result.probability('>=', outcome) == max( + (3 @ d6).probability('>=', outcome), + d20.probability('>=', outcome)) def test_pointwise_lowest(): result = pointwise_lowest(3 @ d6, d20) @@ -40,3 +47,10 @@ def test_pointwise_lowest(): assert result.probability('<=', outcome) == max( (3 @ d6).probability('<=', outcome), d20.probability('<=', outcome)) + +def test_pointwise_lowest_single_argument(): + result = pointwise_lowest([3 @ d6, d20]) + for outcome in range(1, 21): + assert result.probability('<=', outcome) == max( + (3 @ d6).probability('<=', outcome), + d20.probability('<=', outcome))