Skip to content

Commit

Permalink
allow single iterable argument for pointwise_highest
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Oct 12, 2024
1 parent 2b8a1ab commit ee28e13
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 10 deletions.
46 changes: 36 additions & 10 deletions src/icepool/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,15 @@ def _iter_outcomes(
else:
yield arg

def pointwise_highest(*dice: 'icepool.Die[T]') -> 'icepool.Die[T]':
@overload
def pointwise_highest(arg0: 'Iterable[icepool.Die[T]]', /,) -> 'icepool.Die[T]':
...

@overload
def pointwise_highest(arg0: 'icepool.Die[T]', arg1: 'icepool.Die[T]', /, *args: 'icepool.Die[T]') -> 'icepool.Die[T]':
...

def pointwise_highest(arg0, /, *more_args: 'icepool.Die[T]') -> 'icepool.Die[T]':
"""Selects the highest chance of rolling >= each outcome among the arguments.
Specifically, for each outcome, the chance of the result rolling >= to that
Expand All @@ -250,14 +258,27 @@ def pointwise_highest(*dice: 'icepool.Die[T]') -> 'icepool.Die[T]':
situation.
Args:
dice: Any number of dice.
dice: Either an iterable of dice, or two or more dice as separate
arguments.
"""
dice = commonize_denominator(*dice)
outcomes = sorted_union(*dice)
cumulative = [min(die.quantity('<=', outcome) for die in dice) for outcome in outcomes]
if len(more_args) == 0:
args = arg0
else:
args = (arg0, ) + more_args
args = commonize_denominator(*args)
outcomes = sorted_union(*args)
cumulative = [min(die.quantity('<=', outcome) for die in args) for outcome in outcomes]
return from_cumulative(outcomes, cumulative)

def pointwise_lowest(*dice: 'icepool.Die[T]') -> 'icepool.Die[T]':
@overload
def pointwise_lowest(arg0: 'Iterable[icepool.Die[T]]', /,) -> 'icepool.Die[T]':
...

@overload
def pointwise_lowest(arg0: 'icepool.Die[T]', arg1: 'icepool.Die[T]', /, *args: 'icepool.Die[T]') -> 'icepool.Die[T]':
...

def pointwise_lowest(arg0, /, *more_args: 'icepool.Die[T]') -> 'icepool.Die[T]':
"""Selects the highest chance of rolling <= each outcome among the arguments.
Specifically, for each outcome, the chance of the result rolling <= to that
Expand All @@ -272,11 +293,16 @@ def pointwise_lowest(*dice: 'icepool.Die[T]') -> 'icepool.Die[T]':
situation.
Args:
dice: Any number of dice.
dice: Either an iterable of dice, or two or more dice as separate
arguments.
"""
dice = commonize_denominator(*dice)
outcomes = sorted_union(*dice)
cumulative = [max(die.quantity('<=', outcome) for die in dice) for outcome in outcomes]
if len(more_args) == 0:
args = arg0
else:
args = (arg0, ) + more_args
args = commonize_denominator(*args)
outcomes = sorted_union(*args)
cumulative = [max(die.quantity('<=', outcome) for die in args) for outcome in outcomes]
return from_cumulative(outcomes, cumulative)

@overload
Expand Down
14 changes: 14 additions & 0 deletions tests/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,24 @@ def test_pointwise_highest():
assert result.probability('>=', outcome) == max(
(3 @ d6).probability('>=', outcome),
d20.probability('>=', outcome))

def test_pointwise_highest_single_argument():
result = pointwise_highest([3 @ d6, d20])
for outcome in range(1, 21):
assert result.probability('>=', outcome) == max(
(3 @ d6).probability('>=', outcome),
d20.probability('>=', outcome))

def test_pointwise_lowest():
result = pointwise_lowest(3 @ d6, d20)
for outcome in range(1, 21):
assert result.probability('<=', outcome) == max(
(3 @ d6).probability('<=', outcome),
d20.probability('<=', outcome))

def test_pointwise_lowest_single_argument():
result = pointwise_lowest([3 @ d6, d20])
for outcome in range(1, 21):
assert result.probability('<=', outcome) == max(
(3 @ d6).probability('<=', outcome),
d20.probability('<=', outcome))

0 comments on commit ee28e13

Please sign in to comment.