Skip to content

Commit

Permalink
first draft of new match expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Aug 20, 2024
1 parent 775a2c0 commit 4a3f9b5
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 100 deletions.
5 changes: 3 additions & 2 deletions src/icepool/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
UniqueExpression)
from icepool.expression.filter_outcomes import FilterOutcomesExpression, FilterOutcomesBinaryExpression
from icepool.expression.keep import KeepExpression
from icepool.expression.pair import PairKeepExpression
from icepool.expression.match import SortMatchExpression, MaximumMatchExpression

from icepool.expression.multiset_function import multiset_function

Expand All @@ -23,5 +23,6 @@
'AdjustCountsExpression', 'MultiplyCountsExpression',
'FloorDivCountsExpression', 'ModuloCountsExpression',
'KeepCountsExpression', 'UniqueExpression', 'FilterOutcomesExpression',
'FilterOutcomesBinaryExpression', 'KeepExpression'
'FilterOutcomesBinaryExpression', 'KeepExpression', 'SortMatchExpression',
'MaximumMatchExpression'
]
150 changes: 150 additions & 0 deletions src/icepool/expression/match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
__docformat__ = 'google'

import icepool

from icepool.expression.multiset_expression import MultisetExpression

from functools import cached_property

from typing import Hashable
from icepool.typing import Order, T_contra


class SortMatchExpression(MultisetExpression[T_contra]):

def __init__(self, left: MultisetExpression[T_contra],
right: MultisetExpression[T_contra], *, order: Order,
tie: int, left_first: int, right_first: int):
if order == Order.Any:
order = Order.Descending
self._left = left
self._right = right
self._order = order
self._tie = tie
self._left_first = left_first
self._right_first = right_first

def _next_state(self, state, outcome: T_contra, *counts:
int) -> tuple[Hashable, int]:
left_state, right_state, left_lead = state or (None, None, 0)
left_state, left_count = self._left._next_state(
state, outcome, *counts)
right_state, right_count = self._right._next_state(
state, outcome, *counts)

if left_count < 0 or right_count < 0:
raise RuntimeError(
'SortMatchedExpression does not support negative counts.')

count = 0

if left_lead > 0:
count += max(min(right_count - left_lead, left_count),
0) * self._tie
elif left_lead < 0:
count += max(min(left_count + left_lead, right_count),
0) * self._tie
count += min(-left_lead, left_count) * self._right_first
else:
count += min(left_count, right_count) * self._tie

left_lead += left_count - right_count

if left_lead > 0:
count += min(left_lead, left_count) * self._left_first

return (left_state, right_state, left_lead), count

def order(self) -> Order:
return Order.merge(self._order, self._left.order(),
self._right.order())

@cached_property
def _cached_bound_generators(
self) -> 'tuple[icepool.MultisetGenerator, ...]':
return self._left._bound_generators() + self._right._bound_generators()

def _bound_generators(self) -> 'tuple[icepool.MultisetGenerator, ...]':
return self._cached_bound_generators

def _unbind(self, prefix_start: int,
free_start: int) -> 'tuple[MultisetExpression, int]':
unbound_left, prefix_start = self._left._unbind(
prefix_start, free_start)
unbound_right, prefix_start = self._right._unbind(
prefix_start, free_start)
unbound_expression: MultisetExpression = SortMatchExpression(
unbound_left,
unbound_right,
order=self._order,
tie=self._tie,
left_first=self._left_first,
right_first=self._right_first)
return unbound_expression, prefix_start

def _free_arity(self) -> int:
return max(self._left._free_arity(), self._right._free_arity())


class MaximumMatchExpression(MultisetExpression[T_contra]):

def __init__(self, left: MultisetExpression[T_contra],
right: MultisetExpression[T_contra], *, order: Order,
match_equal: bool, keep: bool):
self._left = left
self._right = right
self._order = order
self._match_equal = match_equal
self._keep = keep

def _next_state(self, state, outcome: T_contra, *counts:
int) -> tuple[Hashable, int]:
left_state, right_state, pairable = state or (None, None, 0)
left_state, left_count = self._left._next_state(
state, outcome, *counts)
right_state, right_count = self._right._next_state(
state, outcome, *counts)

if left_count < 0 or right_count < 0:
raise RuntimeError(
'MaximumMatchedExpression does not support negative counts.')

if self._match_equal:
new_pairs = min(pairable + right_count, left_count)
else:
new_pairs = min(pairable, left_count)
pairable += right_count - new_pairs
if self._keep:
count = new_pairs
else:
count = left_count - new_pairs
return (left_state, right_state, pairable), count

def order(self) -> Order:
return Order.merge(self._order, self._left.order(),
self._right.order())

@cached_property
def _cached_bound_generators(
self) -> 'tuple[icepool.MultisetGenerator, ...]':
return self._left._bound_generators() + self._right._bound_generators()

def _bound_generators(self) -> 'tuple[icepool.MultisetGenerator, ...]':
return self._cached_bound_generators

def _unbind(self, prefix_start: int,
free_start: int) -> 'tuple[MultisetExpression, int]':
unbound_left, prefix_start = self._left._unbind(
prefix_start, free_start)
unbound_right, prefix_start = self._right._unbind(
prefix_start, free_start)
unbound_expression: MultisetExpression = MaximumMatchExpression(
unbound_left,
unbound_right,
order=self._order,
match_equal=self._match_equal,
keep=self._keep)
return unbound_expression, prefix_start

def _free_arity(self) -> int:
return max(self._left._free_arity(), self._right._free_arity())
169 changes: 150 additions & 19 deletions src/icepool/expression/multiset_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,31 +655,162 @@ def highest(self,
index = highest_slice(keep, drop)
return self.keep(index)

# Pairing.
# Matching.

def pair_le(self,
other: 'MultisetExpression[T_contra]',
*,
keep: bool = True):
"""EXPERIMENTAL: Make pairs of elements such that `self <= other`, then keep or drop the paired elements from `self`.
def sort_match(
self,
comparison: Literal['==', '!=', '<=', '<', '>=', '>'],
other: 'MultisetExpression[T_contra]',
/,
order: Order = Order.Descending) -> 'MultisetExpression[T_contra]':
"""EXPERIMENTAL: Matches elements of `self` with elements of `other` in sorted order, then keeps elements from `self` that fit `comparison` with their partner.
Contrast `maximum_matched()`, which first creates the maximum number of
pairs that fit the comparison, not necessarily in sorted order.
Example: An attacker rolls 3d6 versus a defender's 2d6 in the game of
*RISK*. Which pairs did the attacker win?
```python
d6.pool(3).highest(2).sort_matched('>', d6.pool(2))
```
Suppose the attacker rolled 5, 3, 2 and the defender 6, 1.
In this case the attacker's 2 would be dropped by `highest`,
and then the 5 would be dropped since the attacker lost that pair,
leaving the attacker's 3.
```python
Pool([5, 3, 2]).highest(2).sort_matched('>', [6, 1]) -> [3]
```
Extra elements: If `self` has more elements than `other`, whether the
extra elements are kept depends on the `order` and `comparison`:
* Descending: `>=`, `>`
* Ascending: `<=`, `<`
First, make as many pairs of one element of `self` and one element of
`other` such that:
* In each pair, the element from `self` <= the element from `other`.
* The element from `self` is as great as possible otherwise.
Keep the elements from `self` that were paired and drop the rest.
Args:
comparison: The comparison to filter by. If you want to drop rather
than keep, use the complementary comparison:
* '==' vs. '!='
* '<=' vs. '>'
* '>=' vs. '<'
other: The other multiset to match elements with.
order: The order in which to sort before forming matches.
Default is descending.
"""
other = implicit_convert_to_expression(other)

if comparison == '==':
lesser, tie, greater = 0, 1, 0
elif comparison == '!=':
lesser, tie, greater = 1, 0, 1
elif comparison == '<=':
lesser, tie, greater = 1, 1, 0
elif comparison == '<':
lesser, tie, greater = 1, 0, 0
elif comparison == '>=':
lesser, tie, greater = 0, 1, 1
elif comparison == '>':
lesser, tie, greater = 0, 0, 1
else:
raise ValueError(f'Invalid comparison {comparison}')

if order > 0:
left_first = lesser
right_first = greater
else:
left_first = greater
right_first = lesser

return icepool.expression.SortMatchExpression(self,
other,
order=order,
tie=tie,
left_first=left_first,
right_first=right_first)

def maximum_match(
self, comparison: Literal['==', '<=', '<', '>=',
'>'], other: 'MultisetExpression[T_contra]',
/, *, keep: Literal['matched',
'unmatched']) -> 'MultisetExpression[T_contra]':
"""EXPERIMENTAL: Match elements of `self` with elements of `other` fitting the comparison, then keeps the matched or unmatched elements from `self`.
As many pairs of elements will be matched as possible that fit the
`comparison`. Contrast `sort_matched()`, which first creates pairs in
sorted order and then filters them by `comparison`.
Example:
An attacker rolls a pool of 4d6 and a defender rolls a pool of 3d6.
Defender dice can be used to block attacker dice of equal or lesser
value, and the defender prefers to block the highest attacker dice
possible. What is the sum of the attacker dice that were not blocked?
```python
d6.pool(4).maximum_matched('<=', d6.pool(3), keep='unmatched').sum()
```
Args:
other: The other multiset to pair elements with.
keep: If `True` (default), the paired elements from `self` will be
kept. Otherwise, the unpaired elements will be kept.
comparison: One of the following:
* '==': The same as `intersection(other)` if `keep='matched'`,
or `difference(other)` if `keep=unmatched`.
* '<=': Elements of `self` will be matched with elements of
`other` such that the element from `self` is <= the element
from `other`, but is otherwise as high as possible.
This requires that outcomes be evaluated in descending
order.
* `<`: Elements of `self` will be matched with elements of
`other` such that the element from `self` is < the element
from `other`, but is otherwise as high as possible.
This requires that outcomes be evaluated in descending
order.
* '>=': Elements of `self` will be matched with elements of
`other` such that the element from `self` is >= the element
from `other`, but is otherwise as low as possible.
This requires that outcomes be evaluated in ascending
order.
* `>`: Elements of `self` will be matched with elements of
`other` such that the element from `self` is > the element
from `other`, but is otherwise as low as possible.
This requires that outcomes be evaluated in ascending
order.
other: The other multiset to match elements with.
keep: Whether 'matched' or 'unmatched' elements are to be kept.
"""
if keep == 'matched':
keep_boolean = True
elif keep == 'unmatched':
keep_boolean = False
else:
raise ValueError(f"keep must be either 'matched' or 'unmatched'")

if comparison == '==':
if keep_boolean:
return self.intersection(other)
else:
return self.difference(other)

other = implicit_convert_to_expression(other)
return icepool.expression.PairKeepExpression(self,
other,
order=Order.Descending,
allow_equal=True,
keep=keep)
if comparison == '<=':
order = Order.Descending
match_equal = True
elif comparison == '<':
order = Order.Descending
match_equal = False
elif comparison == '>=':
order = Order.Ascending
match_equal = True
elif comparison == '>':
order = Order.Ascending
match_equal = False
else:
raise ValueError(f'Invalid comparison {comparison}')

return icepool.expression.MaximumMatchExpression(
self,
other,
order=order,
match_equal=match_equal,
keep=keep_boolean)

# Evaluations.

Expand Down
Loading

0 comments on commit 4a3f9b5

Please sign in to comment.