diff --git a/src/icepool/multiset_expression.py b/src/icepool/multiset_expression.py index a13c8f8b..145230ce 100644 --- a/src/icepool/multiset_expression.py +++ b/src/icepool/multiset_expression.py @@ -704,8 +704,7 @@ def keep( Use the `[]` operator for the same effect as this method. """ if isinstance(index, int): - return self.evaluate( - evaluator=icepool.evaluator.KeepEvaluator(index)) + return icepool.evaluator.KeepEvaluator(index).evaluate(self) else: return icepool.operator.MultisetKeep(self, index=index) @@ -959,29 +958,6 @@ def maximum_match_lowest( # Evaluations. - def evaluate( - *expressions: 'MultisetExpression[T]', - evaluator: 'icepool.MultisetEvaluator[T, U]' - ) -> 'icepool.Die[U] | icepool.MultisetEvaluator[T, U]': - """Attaches a final `MultisetEvaluator` to expressions. - - All of the `MultisetExpression` methods below are evaluations, - as are the operators `<, <=, >, >=, !=, ==`. This means if the - expression is fully bound, it will be evaluated to a `Die`. - - Returns: - A `Die` if the expression is are fully bound. - A `MultisetEvaluator` otherwise. - """ - if all(expression._free_arity() == 0 for expression in expressions): - return evaluator.evaluate(*expressions) - evaluator = icepool.evaluator.MultisetFunctionEvaluator( - *expressions, evaluator=evaluator) - if evaluator._free_arity == 0: - return evaluator.evaluate() - else: - return evaluator - def expand( self, order: Order = Order.Ascending @@ -994,8 +970,7 @@ def expand( order: Whether the elements are in ascending (default) or descending order. """ - return self.evaluate(evaluator=icepool.evaluator.ExpandEvaluator( - order=order)) + return icepool.evaluator.ExpandEvaluator(order=order).evaluate(self) def sum( self, @@ -1003,9 +978,9 @@ def sum( ) -> 'icepool.Die[U] | icepool.MultisetEvaluator[T, U]': """Evaluation: The sum of all elements.""" if map is None: - return self.evaluate(evaluator=icepool.evaluator.sum_evaluator) + return icepool.evaluator.sum_evaluator.evaluate(self) else: - return self.evaluate(evaluator=icepool.evaluator.SumEvaluator(map)) + return icepool.evaluator.SumEvaluator(map).evaluate(self) def count(self) -> 'icepool.Die[int] | icepool.MultisetEvaluator[T, int]': """Evaluation: The total number of elements in the multiset. @@ -1018,11 +993,11 @@ def count(self) -> 'icepool.Die[int] | icepool.MultisetEvaluator[T, int]': `(generator & [4, 5, 6]).count()` will count up to one each of 4, 5, and 6. """ - return self.evaluate(evaluator=icepool.evaluator.count_evaluator) + return icepool.evaluator.count_evaluator.evaluate(self) def any(self) -> 'icepool.Die[bool] | icepool.MultisetEvaluator[T, bool]': """Evaluation: Whether the multiset has at least one positive count. """ - return self.evaluate(evaluator=icepool.evaluator.any_evaluator) + return icepool.evaluator.any_evaluator.evaluate(self) def highest_outcome_and_count( self @@ -1031,8 +1006,8 @@ def highest_outcome_and_count( If no outcomes have positive count, the min outcome will be returned with 0 count. """ - return self.evaluate( - evaluator=icepool.evaluator.highest_outcome_and_count_evaluator) + return icepool.evaluator.highest_outcome_and_count_evaluator.evaluate( + self) def all_counts( self, @@ -1053,21 +1028,20 @@ def all_counts( output zero counts. So we might as well use the argument to do both. """ - return self.evaluate(evaluator=icepool.evaluator.AllCountsEvaluator( - filter=filter)) + return icepool.evaluator.AllCountsEvaluator( + filter=filter).evaluate(self) def largest_count( self) -> 'icepool.Die[int] | icepool.MultisetEvaluator[T, int]': """Evaluation: The size of the largest matching set among the elements.""" - return self.evaluate( - evaluator=icepool.evaluator.largest_count_evaluator) + return icepool.evaluator.largest_count_evaluator.evaluate(self) def largest_count_and_outcome( self ) -> 'icepool.Die[tuple[int, T]] | icepool.MultisetEvaluator[T, tuple[int, T]]': """Evaluation: The largest matching set among the elements and the corresponding outcome.""" - return self.evaluate( - evaluator=icepool.evaluator.largest_count_and_outcome_evaluator) + return icepool.evaluator.largest_count_and_outcome_evaluator.evaluate( + self) def __rfloordiv__( self, other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]' @@ -1095,9 +1069,8 @@ def count_subset( empty_divisor_outcome is not set. """ divisor = implicit_convert_to_expression(divisor) - return self.evaluate(divisor, - evaluator=icepool.evaluator.CountSubsetEvaluator( - empty_divisor=empty_divisor)) + return icepool.evaluator.CountSubsetEvaluator( + empty_divisor=empty_divisor).evaluate(self, divisor) def largest_straight( self: 'MultisetExpression[int]' @@ -1106,8 +1079,7 @@ def largest_straight( Outcomes must be `int`s. """ - return self.evaluate( - evaluator=icepool.evaluator.largest_straight_evaluator) + return icepool.evaluator.largest_straight_evaluator.evaluate(self) def largest_straight_and_outcome( self: 'MultisetExpression[int]' @@ -1116,8 +1088,8 @@ def largest_straight_and_outcome( Outcomes must be `int`s. """ - return self.evaluate( - evaluator=icepool.evaluator.largest_straight_and_outcome_evaluator) + return icepool.evaluator.largest_straight_and_outcome_evaluator.evaluate( + self) def all_straights( self: 'MultisetExpression[int]' @@ -1130,8 +1102,7 @@ def all_straights( elements can produces straights that overlap in outcomes. In this case, elements are preferentially assigned to the longer straight. """ - return self.evaluate( - evaluator=icepool.evaluator.all_straights_evaluator) + return icepool.evaluator.all_straights_evaluator.evaluate(self) def all_straights_reduce_counts( self: 'MultisetExpression[int]', @@ -1143,9 +1114,8 @@ def all_straights_reduce_counts( The result is a tuple of `(run_length, run_score)`s. """ - return self.evaluate( - evaluator=icepool.evaluator.AllStraightsReduceCountsEvaluator( - reducer=reducer)) + return icepool.evaluator.AllStraightsReduceCountsEvaluator( + reducer=reducer).evaluate(self) def argsort(self: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', *args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', @@ -1171,9 +1141,9 @@ def argsort(self: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', """ self = implicit_convert_to_expression(self) converted_args = [implicit_convert_to_expression(arg) for arg in args] - return self.evaluate(*converted_args, - evaluator=icepool.evaluator.ArgsortEvaluator( - order=order, limit=limit)) + return icepool.evaluator.ArgsortEvaluator(order=order, + limit=limit).evaluate( + self, *converted_args) # Comparators. @@ -1186,23 +1156,18 @@ def _compare( ) -> 'icepool.Die[bool] | icepool.MultisetEvaluator[T, bool]': right = icepool.implicit_convert_to_expression(right) - if self._free_arity() == 0 and right._free_arity() == 0: - if truth_value_callback is not None: + if truth_value_callback is not None: - def data_callback() -> Counts[bool]: - die = cast('icepool.Die[bool]', - operation_class().evaluate(self, right)) - if not isinstance(die, icepool.Die): - raise TypeError('Did not resolve to a die.') - return die._data + def data_callback() -> Counts[bool]: + die = cast('icepool.Die[bool]', + operation_class().evaluate(self, right)) + if not isinstance(die, icepool.Die): + raise TypeError('Did not resolve to a die.') + return die._data - return icepool.DieWithTruth(data_callback, - truth_value_callback) - else: - return operation_class().evaluate(self, right) + return icepool.DieWithTruth(data_callback, truth_value_callback) else: - return icepool.evaluator.MultisetFunctionEvaluator( - self, right, evaluator=operation_class()) + return operation_class().evaluate(self, right) def __lt__(self, other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', diff --git a/tests/deck_test.py b/tests/deck_test.py index ed9015f0..3dd3c365 100644 --- a/tests/deck_test.py +++ b/tests/deck_test.py @@ -26,7 +26,7 @@ def next_state(self, state, outcome, *counts): def test_empty_deal(): deal = icepool.Deck(range(13), times=4).deal() - result = deal.evaluate(evaluator=TrivialEvaluator()) + result = TrivialEvaluator().evaluate(deal) assert result.equals(icepool.Die([0])) @@ -42,7 +42,7 @@ def test_two_hand_sum_same_size(): result1 = deal1.sum() deal2 = deck.deal(5, 5) - result2 = deal2.evaluate(evaluator=SumEachEvaluator()) + result2 = SumEachEvaluator().evaluate(deal2) assert deal2.denominator() == result2.denominator() assert result1.equals(result2.marginals[0], simplify=True) @@ -53,7 +53,7 @@ def test_two_hand_sum_diff_size(): deck = icepool.Deck(range(4), times=4) deal = deck.deal(2, 4) - result = deal.evaluate(evaluator=SumEachEvaluator()) + result = SumEachEvaluator().evaluate(deal) assert deal.denominator() == result.denominator() assert (result.marginals[0] * 2).mean() == result.marginals[1].mean() diff --git a/tests/vector_test.py b/tests/vector_test.py index 01f013de..2de11052 100644 --- a/tests/vector_test.py +++ b/tests/vector_test.py @@ -102,7 +102,7 @@ def extra_outcomes(self, *_): return [1, 2, 3, 4, 5, 6] result = 3 @ icepool.one_hot(6) - expected = icepool.d6.pool(3).evaluate(evaluator=OneHotEvaluator()) + expected = OneHotEvaluator().evaluate(icepool.d6.pool(3)) assert result == expected