Pool now stores explicit alignment; skips only remove dice, not outcomes
HighDiceRoller committed Sep 10, 2024
1 parent 4b1e191 commit d759ad5
Showing 5 changed files with 81 additions and 83 deletions.
25 changes: 7 additions & 18 deletions src/icepool/evaluator/multiset_evaluator.py
@@ -85,17 +85,11 @@ def next_state(self, state: Hashable, outcome: T_contra, /, *counts:
outcome: The current outcome.
`next_state` will see all rolled outcomes in monotonic order;
either ascending or descending depending on `order()`.
If there are multiple generators, the set of outcomes is the
union of the outcomes of the individual generators. All outcomes
with nonzero count will be seen. Outcomes with zero count
may or may not be seen. If you need to enforce that certain
outcomes are seen even if they have zero count, see
`alignment()`.
If there are multiple generators, the set of outcomes is at
least the union of the outcomes of the individual generators.
You can use `alignment()` to add additional outcomes.
*counts: One value (usually an `int`) for each generator output
indicating how many of the current outcome were produced.
All outcomes with nonzero count are guaranteed to be seen.
To guarantee that outcomes are seen even if they have zero
count, override `alignment()`.
Returns:
A hashable object indicating the next state.
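For context (not part of the diff), here is a minimal sketch of what a `next_state` implementation typically looks like, assuming the icepool API described above; the evaluator class, its target parameter, and the state layout are made up for this illustration:

```python
import icepool


class CountAtLeastEvaluator(icepool.MultisetEvaluator):
    """Counts how many dice rolled at least a target value (illustrative only)."""

    def __init__(self, target):
        self._target = target

    def next_state(self, state, outcome, count):
        # state is None on the first call; outcomes arrive in monotonic order.
        total = state or 0
        if outcome >= self._target:
            total += count
        return total


# Distribution of the number of dice out of 4d6 showing 5 or higher.
result = CountAtLeastEvaluator(5).evaluate(icepool.d6.pool(4))
```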
@@ -144,19 +138,14 @@ def order(self) -> Order:
def alignment(self, outcomes: Sequence[T_contra]) -> Collection[T_contra]:
"""Optional method to specify additional outcomes that should be seen by `next_state()`.
These will be seen by `next_state` even if they have zero count or do
not appear in the generator(s) at all.
The default implementation returns `()`; this means outcomes with zero
count may or may not be seen by `next_state`.
These will be seen by `next_state` even if they do not appear in the
generator(s). The default implementation returns `()`, or no additional
outcomes.
If you want `next_state` to see consecutive `int` outcomes, you can set
`alignment = icepool.MultisetEvaluator.range_alignment`.
See `range_alignment()` below.
If you want `next_state` to see all generator outcomes, you can return
`outcomes` as-is.
Args:
outcomes: The outcomes that could be produced by the generators, in
ascending order.
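As a hedged illustration (not part of this commit) of the `range_alignment` override mentioned above, assuming the usual `MultisetEvaluator` signatures; the evaluator name and state tuple are invented for this sketch:

```python
import icepool
from icepool import MultisetEvaluator


class LargestStraightEvaluator(MultisetEvaluator):
    """Length of the longest run of consecutive outcomes (illustrative only)."""

    def next_state(self, state, outcome, count):
        best, run = state or (0, 0)
        # Extend the current run on a nonzero count; a zero-count gap resets it.
        run = run + 1 if count > 0 else 0
        return max(best, run), run

    def final_outcome(self, final_state):
        return final_state[0]

    # Ensures gaps (zero-count consecutive ints) are still fed to next_state.
    alignment = MultisetEvaluator.range_alignment


result = LargestStraightEvaluator().evaluate(icepool.d6.pool(5))
```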
@@ -291,7 +280,7 @@ def evaluate(

# We use a separate class to guarantee all outcomes are visited.
outcomes = icepool.sorted_union(*(generator.outcomes()
for generator in generators))
for generator in generators))
alignment = Alignment(self.alignment(outcomes))

dist: MutableMapping[Any, int] = defaultdict(int)
13 changes: 8 additions & 5 deletions src/icepool/function.py
@@ -229,20 +229,23 @@ def max_outcome(*dice: 'T | icepool.Die[T]') -> T:
converted_dice = [icepool.implicit_convert_to_die(die) for die in dice]
return max(die.outcomes()[-1] for die in converted_dice)


def range_union(*args: Iterable[int]) -> Sequence[int]:
"""Produces a sequence of consecutive ints covering the argument sets."""
start = min((x for x in itertools.chain(*args)), default=None)
if start is None:
return ()
stop = max(x for x in itertools.chain(*args))
return tuple(range(start, stop+1))
return tuple(range(start, stop + 1))


def sorted_union(*args: Iterable[T]) -> Sequence[T]:
def sorted_union(*args: Iterable[T]) -> tuple[T, ...]:
"""Merge sets into a sorted sequence."""
if not args:
return ()
return tuple(sorted(set.union(*(set(arg) for arg in args))))
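Illustrative usage of the two helpers above, as a sketch based on the code as shown (not part of the diff); it assumes both functions are exported at the package level like other helpers in function.py:

```python
import icepool

icepool.range_union([1, 3], [6])      # (1, 2, 3, 4, 5, 6): consecutive ints covering all args
icepool.sorted_union([1, 3], [6, 3])  # (1, 3, 6): merged, deduplicated, sorted
icepool.range_union()                 # (): no arguments yields an empty tuple
```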


def align(*dice: 'T | icepool.Die[T]') -> tuple['icepool.Die[T]', ...]:
"""DEPRECATED: Pads dice with zero quantities so that all have the same set of outcomes.
@@ -277,7 +280,7 @@ def align_range(

def commonize_denominator(
*dice: 'T | icepool.Die[T]') -> tuple['icepool.Die[T]', ...]:
"""Scale the weights of the dice so that all of them have the same denominator.
"""Scale the quantities of the dice so that all of them have the same denominator.
Args:
*dice: Any number of dice or single outcomes convertible to dice.
@@ -292,8 +295,8 @@ def commonize_denominator(
if die.denominator() > 0))
return tuple(
die.multiply_quantities(denominator_lcm //
die.denominator() if die.denominator() > 0 else 1)
for die in converted_dice)
die.denominator() if die.denominator() >
0 else 1) for die in converted_dice)
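A brief usage sketch of `commonize_denominator` (not part of this commit), assuming it is exported at the package level:

```python
import icepool

a, b = icepool.commonize_denominator(icepool.d(4), icepool.d(6))
# Both dice now share denominator lcm(4, 6) = 12; relative probabilities are unchanged.
assert a.denominator() == b.denominator() == 12
```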


def reduce(
14 changes: 4 additions & 10 deletions src/icepool/generator/deal.py
@@ -53,14 +53,6 @@ def _new_raw(cls, deck: 'icepool.Deck[T]', hand_size: int,
self._keep_tuple = keep_tuple
return self

@classmethod
def _new_empty(cls):
self = super(Deal, cls).__new__(cls)
self._deck = icepool.Deck(())
self._hand_size = 0
self._keep_tuple = ()
return self

def _set_keep_tuple(self, keep_tuple: tuple[int, ...]) -> 'Deal[T]':
return Deal._new_raw(self._deck, self._hand_size, keep_tuple)

@@ -120,7 +112,8 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator:
yield popped_deal, (result_count, ), weight

if skip_weight is not None:
yield Deal._new_empty(), (sum(self.keep_tuple()), ), skip_weight
popped_deal = Deal._new_raw(popped_deck, 0, ())
yield popped_deal, (sum(self.keep_tuple()), ), skip_weight

def _generate_max(self, max_outcome) -> NextMultisetGenerator:
if not self.outcomes() or max_outcome != self.max_outcome():
@@ -146,7 +139,8 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator:
yield popped_deal, (result_count, ), weight

if skip_weight is not None:
yield Deal._new_empty(), (sum(self.keep_tuple()), ), skip_weight
popped_deal = Deal._new_raw(popped_deck, 0, ())
yield popped_deal, (sum(self.keep_tuple()), ), skip_weight

def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]:
lo_skip, hi_skip = icepool.generator.pop_order.lo_hi_skip(
77 changes: 37 additions & 40 deletions src/icepool/generator/pool.py
@@ -36,6 +36,7 @@ class Pool(KeepGenerator[T]):
"""

_dice: tuple[tuple['icepool.Die[T]', int]]
_outcomes: tuple[T, ...]

def __new__(
cls,
@@ -50,8 +51,9 @@ def __new__(
It is permissible to create a `Pool` without providing dice, but not all
evaluators will handle this case, especially if they depend on the
outcome type. In this case you may want to provide a die with zero
quantity.
outcome type. Dice may be in the pool zero times, in which case their
outcomes will be considered but without any count (unless another die
has that outcome).
Args:
dice: The dice to put in the `Pool`. This can be one of the following:
@@ -87,17 +89,19 @@

dice_counts: MutableMapping['icepool.Die[T]', int] = defaultdict(int)
for die, qty in zip(converted_dice, times):
if die.is_empty() and qty == 0:
# zero empty dice is considered to have no effect
if qty == 0:
continue
dice_counts[die] += qty
keep_tuple = (1, ) * sum(times)
return cls._new_from_mapping(dice_counts, keep_tuple)

# Includes dice with zero qty.
outcomes = icepool.sorted_union(*converted_dice)
return cls._new_from_mapping(dice_counts, outcomes, keep_tuple)

@classmethod
@cache
def _new_raw(cls, dice: tuple[tuple['icepool.Die[T]', int]],
keep_tuple: tuple[int, ...]) -> 'Pool[T]':
outcomes: tuple[T], keep_tuple: tuple[int, ...]) -> 'Pool[T]':
"""All pool creation ends up here. This method is cached.
Args:
@@ -106,20 +110,18 @@ def _new_raw(cls, dice: tuple[tuple['icepool.Die[T]', int]],
"""
self = super(Pool, cls).__new__(cls)
self._dice = dice
self._outcomes = outcomes
self._keep_tuple = keep_tuple
return self

@classmethod
def _new_empty(cls) -> 'Pool':
return cls._new_raw((), ())

@classmethod
def clear_cache(cls):
"""Clears the global pool cache."""
Pool._new_raw.cache_clear()

@classmethod
def _new_from_mapping(cls, dice_counts: Mapping['icepool.Die[T]', int],
outcomes: tuple[T, ...],
keep_tuple: Sequence[int]) -> 'Pool[T]':
"""Creates a new pool.
@@ -129,7 +131,7 @@ def _new_from_mapping(cls, dice_counts: Mapping['icepool.Die[T]', int],
"""
dice = tuple(
sorted(dice_counts.items(), key=lambda kv: kv[0]._hash_key))
return Pool._new_raw(dice, keep_tuple)
return Pool._new_raw(dice, outcomes, keep_tuple)

@cached_property
def _raw_size(self) -> int:
@@ -161,13 +163,6 @@ def unique_dice(self) -> Collection['icepool.Die[T]']:
"""The collection of unique dice in this pool."""
return self._unique_dice

@cached_property
def _outcomes(self) -> Sequence[T]:
outcome_set = set(
itertools.chain.from_iterable(die.outcomes()
for die in self.unique_dice()))
return tuple(sorted(outcome_set))

def outcomes(self) -> Sequence[T]:
"""The union of possible outcomes among all dice in this pool in ascending order."""
return self._outcomes
@@ -192,21 +187,13 @@ def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]:

return Order.Any, PopOrderReason.NoPreference

@cached_property
def _min_outcome(self) -> T:
return min(die.min_outcome() for die in self.unique_dice())

def min_outcome(self) -> T:
"""The min outcome among all dice in this pool."""
return self._min_outcome

@cached_property
def _max_outcome(self) -> T:
return max(die.max_outcome() for die in self.unique_dice())
return self._outcomes[0]

def max_outcome(self) -> T:
"""The max outcome among all dice in this pool."""
return self._max_outcome
return self._outcomes[-1]
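For orientation (not part of the diff), a small sketch of the `outcomes()`/`min_outcome()`/`max_outcome()` behavior described above, using standard icepool dice:

```python
import icepool

pool = icepool.Pool([icepool.d6, icepool.d8])
pool.outcomes()      # (1, 2, 3, 4, 5, 6, 7, 8): sorted union across the dice
pool.min_outcome()   # 1
pool.max_outcome()   # 8
```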

def _generate_initial(self) -> InitialMultisetGenerator:
yield self, 1
@@ -223,6 +210,9 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator:
if not self.outcomes():
yield self, (0, ), 1
return
if min_outcome != self.min_outcome():
yield self, (0, ), 1
return
generators = [
iter_die_pop_min(die, die_count, min_outcome)
for die, die_count in self._dice
@@ -233,13 +223,14 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator:
result_weight = 1
next_dice_counts: MutableMapping[Any, int] = defaultdict(int)
for popped_die, misses, hits, weight in pop:
if not popped_die.is_empty():
if not popped_die.is_empty() and misses > 0:
next_dice_counts[popped_die] += misses
total_hits += hits
result_weight *= weight
popped_keep_tuple, result_count = pop_min_from_keep_tuple(
self.keep_tuple(), total_hits)
popped_pool = Pool._new_from_mapping(next_dice_counts,
self._outcomes[1:],
popped_keep_tuple)
if not any(popped_keep_tuple):
# Dump all dice in exchange for the denominator.
Expand All @@ -250,7 +241,8 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator:
yield popped_pool, (result_count, ), result_weight

if skip_weight is not None:
yield Pool._new_empty(), (sum(self.keep_tuple()), ), skip_weight
popped_pool = Pool._new_raw((), self._outcomes[1:], ())
yield popped_pool, (sum(self.keep_tuple()), ), skip_weight

def _generate_max(self, max_outcome) -> NextMultisetGenerator:
"""Pops the given outcome from this pool, if it is the max outcome.
@@ -264,6 +256,9 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator:
if not self.outcomes():
yield self, (0, ), 1
return
if max_outcome != self.max_outcome():
yield self, (0, ), 1
return
generators = [
iter_die_pop_max(die, die_count, max_outcome)
for die, die_count in self._dice
@@ -274,13 +269,14 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator:
result_weight = 1
next_dice_counts: MutableMapping[Any, int] = defaultdict(int)
for popped_die, misses, hits, weight in pop:
if not popped_die.is_empty():
if not popped_die.is_empty() and misses > 0:
next_dice_counts[popped_die] += misses
total_hits += hits
result_weight *= weight
popped_keep_tuple, result_count = pop_max_from_keep_tuple(
self.keep_tuple(), total_hits)
popped_pool = Pool._new_from_mapping(next_dice_counts,
self._outcomes[:-1],
popped_keep_tuple)
if not any(popped_keep_tuple):
# Dump all dice in exchange for the denominator.
Expand All @@ -291,11 +287,12 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator:
yield popped_pool, (result_count, ), result_weight

if skip_weight is not None:
yield Pool._new_empty(), (sum(self.keep_tuple()), ), skip_weight
popped_pool = Pool._new_raw((), self._outcomes[:-1], ())
yield popped_pool, (sum(self.keep_tuple()), ), skip_weight

def _set_keep_tuple(self, keep_tuple: tuple[int,
...]) -> 'KeepGenerator[T]':
return Pool._new_raw(self._dice, keep_tuple)
return Pool._new_raw(self._dice, self._outcomes, keep_tuple)

def additive_union(
*args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]'
@@ -304,21 +301,21 @@ def additive_union(
icepool.expression.implicit_convert_to_expression(arg)
for arg in args)
if all(isinstance(arg, Pool) for arg in args):
pools = cast(tuple[Pool, ...], args)
pools = cast(tuple[Pool[T], ...], args)
keep_tuple: tuple[int, ...] = tuple(
reduce(operator.add, (pool.keep_tuple() for pool in pools),
()))
if len(keep_tuple) == 0:
# All empty.
return Pool._new_empty()
if all(x == keep_tuple[0] for x in keep_tuple):
if len(keep_tuple) == 0 or all(x == keep_tuple[0]
for x in keep_tuple):
# All sorted positions count the same, so we can merge the
# pools.
dice: 'MutableMapping[icepool.Die, int]' = defaultdict(int)
for pool in pools:
for die, die_count in pool._dice:
dice[die] += die_count
return Pool._new_from_mapping(dice, keep_tuple)
outcomes = icepool.sorted_union(*(pool.outcomes()
for pool in pools))
return Pool._new_from_mapping(dice, outcomes, keep_tuple)
return KeepGenerator.additive_union(*args)

def __str__(self) -> str:
@@ -329,7 +326,7 @@ def __str__(self) -> str:

@cached_property
def _hash_key(self) -> tuple:
return Pool, self._dice, self._keep_tuple
return Pool, self._dice, self._outcomes, self._keep_tuple


def standard_pool(
