Skip to content

Commit

Permalink
add maximum_match example and some more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Aug 20, 2024
1 parent 2de725e commit faa434b
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 15 deletions.
14 changes: 7 additions & 7 deletions src/icepool/expression/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self, left: MultisetExpression[T_contra],

def _next_state(self, state, outcome: T_contra, *counts:
int) -> tuple[Hashable, int]:
left_state, right_state, pairable = state or (None, None, 0)
left_state, right_state, prev_matchable = state or (None, None, 0)
left_state, left_count = self._left._next_state(
state, outcome, *counts)
right_state, right_count = self._right._next_state(
Expand All @@ -108,15 +108,15 @@ def _next_state(self, state, outcome: T_contra, *counts:
'MaximumMatchedExpression does not support negative counts.')

if self._match_equal:
new_pairs = min(pairable + right_count, left_count)
new_matches = min(prev_matchable + right_count, left_count)
else:
new_pairs = min(pairable, left_count)
pairable += right_count - new_pairs
new_matches = min(prev_matchable, left_count)
prev_matchable += right_count - new_matches
if self._keep:
count = new_pairs
count = new_matches
else:
count = left_count - new_pairs
return (left_state, right_state, pairable), count
count = left_count - new_matches
return (left_state, right_state, prev_matchable), count

def order(self) -> Order:
return Order.merge(self._order, self._left.order(),
Expand Down
15 changes: 10 additions & 5 deletions src/icepool/expression/multiset_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,16 +740,21 @@ def maximum_match(
`comparison`. Contrast `sort_matched()`, which first creates pairs in
sorted order and then filters them by `comparison`.
Example:
An attacker rolls a pool of 4d6 and a defender rolls a pool of 3d6.
Defender dice can be used to block attacker dice of equal or lesser
Example: An attacker rolls a pool of 4d6 and a defender rolls a pool of
3d6. Defender dice can be used to block attacker dice of equal or lesser
value, and the defender prefers to block the highest attacker dice
possible. What is the sum of the attacker dice that were not blocked?
possible. Which attacker dice were not blocked?
```python
d6.pool(4).maximum_match('<=', d6.pool(3), keep='unmatched').sum()
```
Suppose the attacker rolls 6, 5, 3, 1 and the defender rolls 6, 4.
Then the result should be [5, 1].
```python
d6.pool([6, 5, 3, 1]).maximum_match('<=', [6, 4], keep='unmatched')
-> [5, 1]
```
Args:
comparison: One of the following:
* '==': The same as `intersection(other)` if `keep='matched'`,
Expand Down
58 changes: 55 additions & 3 deletions tests/match_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
import operator
import pytest

from icepool import d6, Die, Order, map_function
from icepool import d6, Die, Order, map_function, Pool


def test_sort_match_example():
result = Pool([5, 3, 2]).highest(2).sort_match('>', [6, 1]).expand()
assert result.simplify() == Die([(3, )])


def test_risk():
Expand Down Expand Up @@ -65,6 +70,28 @@ def compute_expected(left, right):
assert result == expected


@pytest.mark.parametrize('op', sort_ops)
def test_sort_match_operators_sum(op):
result = d6.pool(3).highest(2).sort_match(op, d6.pool(2)).sum()

@map_function
def compute_expected(left, right):
result = 0
for l, r in zip(reversed(left), reversed(right)):
if operators[op](l, r):
result += l
return result

expected = compute_expected(d6.pool(3), d6.pool(2))
assert result == expected


def test_maximum_match_example():
result = Pool([6, 5, 3, 1]).maximum_match('<=', [6, 4],
keep='unmatched').expand()
assert result.simplify() == Die([(1, 5)])


maximum_ops = ['<=', '<', '>=', '>']


Expand All @@ -76,7 +103,7 @@ def test_maximum_match(op):

@map_function
def compute_expected(left, right):
if op in ['>=', '>']:
if op in ['<=', '<']:
left = reversed(left)
right = reversed(right)
left = list(left)
Expand All @@ -88,9 +115,34 @@ def compute_expected(left, right):
left.pop(0)
right.pop(0)
else:
right.pop(0)
left.pop(0)
return result

expected = compute_expected(d6.pool(3), d6.pool(2))
assert result == expected
assert 3 - result == complement


@pytest.mark.parametrize('op', maximum_ops)
def test_maximum_match_sum(op):
result = d6.pool(3).maximum_match(op, d6.pool(2), keep='matched').sum()

@map_function
def compute_expected(left, right):
if op in ['<=', '<']:
left = reversed(left)
right = reversed(right)
left = list(left)
right = list(right)
result = 0
while left and right:
if operators[op](left[0], right[0]):
result += left[0]
left.pop(0)
right.pop(0)
else:
left.pop(0)
return result

expected = compute_expected(d6.pool(3), d6.pool(2))
assert result == expected

0 comments on commit faa434b

Please sign in to comment.