diff --git a/src/icepool/evaluator/multiset_evaluator.py b/src/icepool/evaluator/multiset_evaluator.py index 463776a7..404dc824 100644 --- a/src/icepool/evaluator/multiset_evaluator.py +++ b/src/icepool/evaluator/multiset_evaluator.py @@ -85,17 +85,11 @@ def next_state(self, state: Hashable, outcome: T_contra, /, *counts: outcome: The current outcome. `next_state` will see all rolled outcomes in monotonic order; either ascending or descending depending on `order()`. - If there are multiple generators, the set of outcomes is the - union of the outcomes of the invididual generators. All outcomes - with nonzero count will be seen. Outcomes with zero count - may or may not be seen. If you need to enforce that certain - outcomes are seen even if they have zero count, see - `alignment()`. + If there are multiple generators, the set of outcomes is at + least the union of the outcomes of the invididual generators. + You can use `alignment()` to add additional outcomes. *counts: One value (usually an `int`) for each generator output indicating how many of the current outcome were produced. - All outcomes with nonzero count are guaranteed to be seen. - To guarantee that outcomes are seen even if they have zero - count, override `alignment()`. Returns: A hashable object indicating the next state. @@ -144,19 +138,14 @@ def order(self) -> Order: def alignment(self, outcomes: Sequence[T_contra]) -> Collection[T_contra]: """Optional method to specify additional outcomes that should be seen by `next_state()`. - These will be seen by `next_state` even if they have zero count or do - not appear in the generator(s) at all. - - The default implementation returns `()`; this means outcomes with zero - count may or may not be seen by `next_state`. + These will be seen by `next_state` even if they do not appear in the + generator(s). The default implementation returns `()`, or no additional + outcomes. If you want `next_state` to see consecutive `int` outcomes, you can set `alignment = icepool.MultisetEvaluator.range_alignment`. See `range_alignment()` below. - If you want `next_state` to see all generator outcomes, you can return - `outcomes` as-is. - Args: outcomes: The outcomes that could be produced by the generators, in ascending order. @@ -291,7 +280,7 @@ def evaluate( # We use a separate class to guarantee all outcomes are visited. outcomes = icepool.sorted_union(*(generator.outcomes() - for generator in generators)) + for generator in generators)) alignment = Alignment(self.alignment(outcomes)) dist: MutableMapping[Any, int] = defaultdict(int) diff --git a/src/icepool/function.py b/src/icepool/function.py index 512d0d93..28392d3c 100644 --- a/src/icepool/function.py +++ b/src/icepool/function.py @@ -229,20 +229,23 @@ def max_outcome(*dice: 'T | icepool.Die[T]') -> T: converted_dice = [icepool.implicit_convert_to_die(die) for die in dice] return max(die.outcomes()[-1] for die in converted_dice) + def range_union(*args: Iterable[int]) -> Sequence[int]: """Produces a sequence of consecutive ints covering the argument sets.""" start = min((x for x in itertools.chain(*args)), default=None) if start is None: return () stop = max(x for x in itertools.chain(*args)) - return tuple(range(start, stop+1)) + return tuple(range(start, stop + 1)) + -def sorted_union(*args: Iterable[T]) -> Sequence[T]: +def sorted_union(*args: Iterable[T]) -> tuple[T, ...]: """Merge sets into a sorted sequence.""" if not args: return () return tuple(sorted(set.union(*(set(arg) for arg in args)))) + def align(*dice: 'T | icepool.Die[T]') -> tuple['icepool.Die[T]', ...]: """DEPRECATED: Pads dice with zero quantities so that all have the same set of outcomes. @@ -277,7 +280,7 @@ def align_range( def commonize_denominator( *dice: 'T | icepool.Die[T]') -> tuple['icepool.Die[T]', ...]: - """Scale the weights of the dice so that all of them have the same denominator. + """Scale the quantities of the dice so that all of them have the same denominator. Args: *dice: Any number of dice or single outcomes convertible to dice. @@ -292,8 +295,8 @@ def commonize_denominator( if die.denominator() > 0)) return tuple( die.multiply_quantities(denominator_lcm // - die.denominator() if die.denominator() > 0 else 1) - for die in converted_dice) + die.denominator() if die.denominator() > + 0 else 1) for die in converted_dice) def reduce( diff --git a/src/icepool/generator/deal.py b/src/icepool/generator/deal.py index 3929b299..9b9396c7 100644 --- a/src/icepool/generator/deal.py +++ b/src/icepool/generator/deal.py @@ -53,14 +53,6 @@ def _new_raw(cls, deck: 'icepool.Deck[T]', hand_size: int, self._keep_tuple = keep_tuple return self - @classmethod - def _new_empty(cls): - self = super(Deal, cls).__new__(cls) - self._deck = icepool.Deck(()) - self._hand_size = 0 - self._keep_tuple = () - return self - def _set_keep_tuple(self, keep_tuple: tuple[int, ...]) -> 'Deal[T]': return Deal._new_raw(self._deck, self._hand_size, keep_tuple) @@ -120,7 +112,8 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator: yield popped_deal, (result_count, ), weight if skip_weight is not None: - yield Deal._new_empty(), (sum(self.keep_tuple()), ), skip_weight + popped_deal = Deal._new_raw(popped_deck, 0, ()) + yield popped_deal, (sum(self.keep_tuple()), ), skip_weight def _generate_max(self, max_outcome) -> NextMultisetGenerator: if not self.outcomes() or max_outcome != self.max_outcome(): @@ -146,7 +139,8 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator: yield popped_deal, (result_count, ), weight if skip_weight is not None: - yield Deal._new_empty(), (sum(self.keep_tuple()), ), skip_weight + popped_deal = Deal._new_raw(popped_deck, 0, ()) + yield popped_deal, (sum(self.keep_tuple()), ), skip_weight def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: lo_skip, hi_skip = icepool.generator.pop_order.lo_hi_skip( diff --git a/src/icepool/generator/pool.py b/src/icepool/generator/pool.py index 114f3226..6b4b8d32 100644 --- a/src/icepool/generator/pool.py +++ b/src/icepool/generator/pool.py @@ -36,6 +36,7 @@ class Pool(KeepGenerator[T]): """ _dice: tuple[tuple['icepool.Die[T]', int]] + _outcomes: tuple[T, ...] def __new__( cls, @@ -50,8 +51,9 @@ def __new__( It is permissible to create a `Pool` without providing dice, but not all evaluators will handle this case, especially if they depend on the - outcome type. In this case you may want to provide a die with zero - quantity. + outcome type. Dice may be in the pool zero times, in which case their + outcomes will be considered but without any count (unless another die + has that outcome). Args: dice: The dice to put in the `Pool`. This can be one of the following: @@ -87,17 +89,19 @@ def __new__( dice_counts: MutableMapping['icepool.Die[T]', int] = defaultdict(int) for die, qty in zip(converted_dice, times): - if die.is_empty() and qty == 0: - # zero empty dice is considered to have no effect + if qty == 0: continue dice_counts[die] += qty keep_tuple = (1, ) * sum(times) - return cls._new_from_mapping(dice_counts, keep_tuple) + + # Includes dice with zero qty. + outcomes = icepool.sorted_union(*converted_dice) + return cls._new_from_mapping(dice_counts, outcomes, keep_tuple) @classmethod @cache def _new_raw(cls, dice: tuple[tuple['icepool.Die[T]', int]], - keep_tuple: tuple[int, ...]) -> 'Pool[T]': + outcomes: tuple[T], keep_tuple: tuple[int, ...]) -> 'Pool[T]': """All pool creation ends up here. This method is cached. Args: @@ -106,13 +110,10 @@ def _new_raw(cls, dice: tuple[tuple['icepool.Die[T]', int]], """ self = super(Pool, cls).__new__(cls) self._dice = dice + self._outcomes = outcomes self._keep_tuple = keep_tuple return self - @classmethod - def _new_empty(cls) -> 'Pool': - return cls._new_raw((), ()) - @classmethod def clear_cache(cls): """Clears the global pool cache.""" @@ -120,6 +121,7 @@ def clear_cache(cls): @classmethod def _new_from_mapping(cls, dice_counts: Mapping['icepool.Die[T]', int], + outcomes: tuple[T, ...], keep_tuple: Sequence[int]) -> 'Pool[T]': """Creates a new pool. @@ -129,7 +131,7 @@ def _new_from_mapping(cls, dice_counts: Mapping['icepool.Die[T]', int], """ dice = tuple( sorted(dice_counts.items(), key=lambda kv: kv[0]._hash_key)) - return Pool._new_raw(dice, keep_tuple) + return Pool._new_raw(dice, outcomes, keep_tuple) @cached_property def _raw_size(self) -> int: @@ -161,13 +163,6 @@ def unique_dice(self) -> Collection['icepool.Die[T]']: """The collection of unique dice in this pool.""" return self._unique_dice - @cached_property - def _outcomes(self) -> Sequence[T]: - outcome_set = set( - itertools.chain.from_iterable(die.outcomes() - for die in self.unique_dice())) - return tuple(sorted(outcome_set)) - def outcomes(self) -> Sequence[T]: """The union of possible outcomes among all dice in this pool in ascending order.""" return self._outcomes @@ -192,21 +187,13 @@ def _preferred_pop_order(self) -> tuple[Order | None, PopOrderReason]: return Order.Any, PopOrderReason.NoPreference - @cached_property - def _min_outcome(self) -> T: - return min(die.min_outcome() for die in self.unique_dice()) - def min_outcome(self) -> T: """The min outcome among all dice in this pool.""" - return self._min_outcome - - @cached_property - def _max_outcome(self) -> T: - return max(die.max_outcome() for die in self.unique_dice()) + return self._outcomes[0] def max_outcome(self) -> T: """The max outcome among all dice in this pool.""" - return self._max_outcome + return self._outcomes[-1] def _generate_initial(self) -> InitialMultisetGenerator: yield self, 1 @@ -223,6 +210,9 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator: if not self.outcomes(): yield self, (0, ), 1 return + if min_outcome != self.min_outcome(): + yield self, (0, ), 1 + return generators = [ iter_die_pop_min(die, die_count, min_outcome) for die, die_count in self._dice @@ -233,13 +223,14 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator: result_weight = 1 next_dice_counts: MutableMapping[Any, int] = defaultdict(int) for popped_die, misses, hits, weight in pop: - if not popped_die.is_empty(): + if not popped_die.is_empty() and misses > 0: next_dice_counts[popped_die] += misses total_hits += hits result_weight *= weight popped_keep_tuple, result_count = pop_min_from_keep_tuple( self.keep_tuple(), total_hits) popped_pool = Pool._new_from_mapping(next_dice_counts, + self._outcomes[1:], popped_keep_tuple) if not any(popped_keep_tuple): # Dump all dice in exchange for the denominator. @@ -250,7 +241,8 @@ def _generate_min(self, min_outcome) -> NextMultisetGenerator: yield popped_pool, (result_count, ), result_weight if skip_weight is not None: - yield Pool._new_empty(), (sum(self.keep_tuple()), ), skip_weight + popped_pool = Pool._new_raw((), self._outcomes[1:], ()) + yield popped_pool, (sum(self.keep_tuple()), ), skip_weight def _generate_max(self, max_outcome) -> NextMultisetGenerator: """Pops the given outcome from this pool, if it is the max outcome. @@ -264,6 +256,9 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator: if not self.outcomes(): yield self, (0, ), 1 return + if max_outcome != self.max_outcome(): + yield self, (0, ), 1 + return generators = [ iter_die_pop_max(die, die_count, max_outcome) for die, die_count in self._dice @@ -274,13 +269,14 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator: result_weight = 1 next_dice_counts: MutableMapping[Any, int] = defaultdict(int) for popped_die, misses, hits, weight in pop: - if not popped_die.is_empty(): + if not popped_die.is_empty() and misses > 0: next_dice_counts[popped_die] += misses total_hits += hits result_weight *= weight popped_keep_tuple, result_count = pop_max_from_keep_tuple( self.keep_tuple(), total_hits) popped_pool = Pool._new_from_mapping(next_dice_counts, + self._outcomes[:-1], popped_keep_tuple) if not any(popped_keep_tuple): # Dump all dice in exchange for the denominator. @@ -291,11 +287,12 @@ def _generate_max(self, max_outcome) -> NextMultisetGenerator: yield popped_pool, (result_count, ), result_weight if skip_weight is not None: - yield Pool._new_empty(), (sum(self.keep_tuple()), ), skip_weight + popped_pool = Pool._new_raw((), self._outcomes[:-1], ()) + yield popped_pool, (sum(self.keep_tuple()), ), skip_weight def _set_keep_tuple(self, keep_tuple: tuple[int, ...]) -> 'KeepGenerator[T]': - return Pool._new_raw(self._dice, keep_tuple) + return Pool._new_raw(self._dice, self._outcomes, keep_tuple) def additive_union( *args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]' @@ -304,21 +301,21 @@ def additive_union( icepool.expression.implicit_convert_to_expression(arg) for arg in args) if all(isinstance(arg, Pool) for arg in args): - pools = cast(tuple[Pool, ...], args) + pools = cast(tuple[Pool[T], ...], args) keep_tuple: tuple[int, ...] = tuple( reduce(operator.add, (pool.keep_tuple() for pool in pools), ())) - if len(keep_tuple) == 0: - # All empty. - return Pool._new_empty() - if all(x == keep_tuple[0] for x in keep_tuple): + if len(keep_tuple) == 0 or all(x == keep_tuple[0] + for x in keep_tuple): # All sorted positions count the same, so we can merge the # pools. dice: 'MutableMapping[icepool.Die, int]' = defaultdict(int) for pool in pools: for die, die_count in pool._dice: dice[die] += die_count - return Pool._new_from_mapping(dice, keep_tuple) + outcomes = icepool.sorted_union(*(pool.outcomes() + for pool in pools)) + return Pool._new_from_mapping(dice, outcomes, keep_tuple) return KeepGenerator.additive_union(*args) def __str__(self) -> str: @@ -329,7 +326,7 @@ def __str__(self) -> str: @cached_property def _hash_key(self) -> tuple: - return Pool, self._dice, self._keep_tuple + return Pool, self._dice, self._outcomes, self._keep_tuple def standard_pool( diff --git a/tests/pool_alignment_test.py b/tests/pool_alignment_test.py index 0dd35377..a521ec34 100644 --- a/tests/pool_alignment_test.py +++ b/tests/pool_alignment_test.py @@ -4,7 +4,16 @@ from icepool import d4, d6, d8, d10, d12, MultisetEvaluator, Pool -class CallPathLength(MultisetEvaluator): +class OutcomeCountEvaluator(MultisetEvaluator): + + def next_state(self, state, outcome, *pools): + return (state or 0) + 1 + + +outcome_count = OutcomeCountEvaluator() + + +class OutcomeRangeEvaluator(MultisetEvaluator): def next_state(self, state, outcome, *pools): return (state or 0) + 1 @@ -12,52 +21,58 @@ def next_state(self, state, outcome, *pools): alignment = MultisetEvaluator.range_alignment -call_path_length = CallPathLength() +outcome_range = OutcomeRangeEvaluator() def test_simple_pool(): pool = d6.pool(5) - result = call_path_length.evaluate(pool) + result = outcome_range.evaluate(pool) assert result.outcomes() == (6, ) def test_individual_outcomes(): pool = Pool([1, 5, 6]) - result = call_path_length.evaluate(pool) + result = outcome_range.evaluate(pool) assert result.outcomes() == (6, ) def test_mixed_pool(): pool = Pool([d4, d6, d6, d8]) - result = call_path_length.evaluate(pool) + result = outcome_range.evaluate(pool) assert result.outcomes() == (8, ) def test_simple_pool_keep_tuple(): pool = d6.pool(5)[-2:] - result = call_path_length.evaluate(pool) + result = outcome_range.evaluate(pool) assert result.outcomes() == (6, ) def test_simple_pool_keep_tuple_low(): pool = d6.pool(5)[:2] - result = call_path_length.evaluate(pool) + result = outcome_range.evaluate(pool) assert result.outcomes() == (6, ) def test_mixed_pool_keep_tuple(): pool = Pool([d4, d6, d6, d8])[-2:] - result = call_path_length.evaluate(pool) + result = outcome_range.evaluate(pool) assert result.outcomes() == (8, ) def test_mixed_pool_keep_tuple_low(): pool = Pool([d4, d6, d6, d8])[:2] - result = call_path_length.evaluate(pool) + result = outcome_range.evaluate(pool) assert result.outcomes() == (8, ) +def test_zero_dice_count(): + pool = Pool({d6: 0}) + result = outcome_count.evaluate(pool) + assert result.outcomes() == (6, ) + + def test_range_alignment_non_int(): pool = Pool([0.5]) with pytest.raises(TypeError): - result = call_path_length.evaluate(pool) + result = outcome_range.evaluate(pool)