Skip to content

Commit

Permalink
remove unneeded MultisetExpression.evaluate
Browse files Browse the repository at this point in the history
`MultisetEvaluator` takes care of determining whether the input expressions are fully
bound and either returning another `MultisetEvaluator` or a `Die`.
  • Loading branch information
HighDiceRoller committed Dec 27, 2024
1 parent 222afb5 commit 74a3314
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 72 deletions.
101 changes: 33 additions & 68 deletions src/icepool/multiset_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,8 +704,7 @@ def keep(
Use the `[]` operator for the same effect as this method.
"""
if isinstance(index, int):
return self.evaluate(
evaluator=icepool.evaluator.KeepEvaluator(index))
return icepool.evaluator.KeepEvaluator(index).evaluate(self)
else:
return icepool.operator.MultisetKeep(self, index=index)

Expand Down Expand Up @@ -959,29 +958,6 @@ def maximum_match_lowest(

# Evaluations.

def evaluate(
*expressions: 'MultisetExpression[T]',
evaluator: 'icepool.MultisetEvaluator[T, U]'
) -> 'icepool.Die[U] | icepool.MultisetEvaluator[T, U]':
"""Attaches a final `MultisetEvaluator` to expressions.
All of the `MultisetExpression` methods below are evaluations,
as are the operators `<, <=, >, >=, !=, ==`. This means if the
expression is fully bound, it will be evaluated to a `Die`.
Returns:
A `Die` if the expression is are fully bound.
A `MultisetEvaluator` otherwise.
"""
if all(expression._free_arity() == 0 for expression in expressions):
return evaluator.evaluate(*expressions)
evaluator = icepool.evaluator.MultisetFunctionEvaluator(
*expressions, evaluator=evaluator)
if evaluator._free_arity == 0:
return evaluator.evaluate()
else:
return evaluator

def expand(
self,
order: Order = Order.Ascending
Expand All @@ -994,18 +970,17 @@ def expand(
order: Whether the elements are in ascending (default) or descending
order.
"""
return self.evaluate(evaluator=icepool.evaluator.ExpandEvaluator(
order=order))
return icepool.evaluator.ExpandEvaluator(order=order).evaluate(self)

def sum(
self,
map: Callable[[T], U] | Mapping[T, U] | None = None
) -> 'icepool.Die[U] | icepool.MultisetEvaluator[T, U]':
"""Evaluation: The sum of all elements."""
if map is None:
return self.evaluate(evaluator=icepool.evaluator.sum_evaluator)
return icepool.evaluator.sum_evaluator.evaluate(self)
else:
return self.evaluate(evaluator=icepool.evaluator.SumEvaluator(map))
return icepool.evaluator.SumEvaluator(map).evaluate(self)

def count(self) -> 'icepool.Die[int] | icepool.MultisetEvaluator[T, int]':
"""Evaluation: The total number of elements in the multiset.
Expand All @@ -1018,11 +993,11 @@ def count(self) -> 'icepool.Die[int] | icepool.MultisetEvaluator[T, int]':
`(generator & [4, 5, 6]).count()` will count up to one each of
4, 5, and 6.
"""
return self.evaluate(evaluator=icepool.evaluator.count_evaluator)
return icepool.evaluator.count_evaluator.evaluate(self)

def any(self) -> 'icepool.Die[bool] | icepool.MultisetEvaluator[T, bool]':
"""Evaluation: Whether the multiset has at least one positive count. """
return self.evaluate(evaluator=icepool.evaluator.any_evaluator)
return icepool.evaluator.any_evaluator.evaluate(self)

def highest_outcome_and_count(
self
Expand All @@ -1031,8 +1006,8 @@ def highest_outcome_and_count(
If no outcomes have positive count, the min outcome will be returned with 0 count.
"""
return self.evaluate(
evaluator=icepool.evaluator.highest_outcome_and_count_evaluator)
return icepool.evaluator.highest_outcome_and_count_evaluator.evaluate(
self)

def all_counts(
self,
Expand All @@ -1053,21 +1028,20 @@ def all_counts(
output zero counts. So we might as well use the argument to do
both.
"""
return self.evaluate(evaluator=icepool.evaluator.AllCountsEvaluator(
filter=filter))
return icepool.evaluator.AllCountsEvaluator(
filter=filter).evaluate(self)

def largest_count(
self) -> 'icepool.Die[int] | icepool.MultisetEvaluator[T, int]':
"""Evaluation: The size of the largest matching set among the elements."""
return self.evaluate(
evaluator=icepool.evaluator.largest_count_evaluator)
return icepool.evaluator.largest_count_evaluator.evaluate(self)

def largest_count_and_outcome(
self
) -> 'icepool.Die[tuple[int, T]] | icepool.MultisetEvaluator[T, tuple[int, T]]':
"""Evaluation: The largest matching set among the elements and the corresponding outcome."""
return self.evaluate(
evaluator=icepool.evaluator.largest_count_and_outcome_evaluator)
return icepool.evaluator.largest_count_and_outcome_evaluator.evaluate(
self)

def __rfloordiv__(
self, other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]'
Expand Down Expand Up @@ -1095,9 +1069,8 @@ def count_subset(
empty_divisor_outcome is not set.
"""
divisor = implicit_convert_to_expression(divisor)
return self.evaluate(divisor,
evaluator=icepool.evaluator.CountSubsetEvaluator(
empty_divisor=empty_divisor))
return icepool.evaluator.CountSubsetEvaluator(
empty_divisor=empty_divisor).evaluate(self, divisor)

def largest_straight(
self: 'MultisetExpression[int]'
Expand All @@ -1106,8 +1079,7 @@ def largest_straight(
Outcomes must be `int`s.
"""
return self.evaluate(
evaluator=icepool.evaluator.largest_straight_evaluator)
return icepool.evaluator.largest_straight_evaluator.evaluate(self)

def largest_straight_and_outcome(
self: 'MultisetExpression[int]'
Expand All @@ -1116,8 +1088,8 @@ def largest_straight_and_outcome(
Outcomes must be `int`s.
"""
return self.evaluate(
evaluator=icepool.evaluator.largest_straight_and_outcome_evaluator)
return icepool.evaluator.largest_straight_and_outcome_evaluator.evaluate(
self)

def all_straights(
self: 'MultisetExpression[int]'
Expand All @@ -1130,8 +1102,7 @@ def all_straights(
elements can produces straights that overlap in outcomes. In this case,
elements are preferentially assigned to the longer straight.
"""
return self.evaluate(
evaluator=icepool.evaluator.all_straights_evaluator)
return icepool.evaluator.all_straights_evaluator.evaluate(self)

def all_straights_reduce_counts(
self: 'MultisetExpression[int]',
Expand All @@ -1143,9 +1114,8 @@ def all_straights_reduce_counts(
The result is a tuple of `(run_length, run_score)`s.
"""
return self.evaluate(
evaluator=icepool.evaluator.AllStraightsReduceCountsEvaluator(
reducer=reducer))
return icepool.evaluator.AllStraightsReduceCountsEvaluator(
reducer=reducer).evaluate(self)

def argsort(self: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]',
*args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]',
Expand All @@ -1171,9 +1141,9 @@ def argsort(self: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]',
"""
self = implicit_convert_to_expression(self)
converted_args = [implicit_convert_to_expression(arg) for arg in args]
return self.evaluate(*converted_args,
evaluator=icepool.evaluator.ArgsortEvaluator(
order=order, limit=limit))
return icepool.evaluator.ArgsortEvaluator(order=order,
limit=limit).evaluate(
self, *converted_args)

# Comparators.

Expand All @@ -1186,23 +1156,18 @@ def _compare(
) -> 'icepool.Die[bool] | icepool.MultisetEvaluator[T, bool]':
right = icepool.implicit_convert_to_expression(right)

if self._free_arity() == 0 and right._free_arity() == 0:
if truth_value_callback is not None:
if truth_value_callback is not None:

def data_callback() -> Counts[bool]:
die = cast('icepool.Die[bool]',
operation_class().evaluate(self, right))
if not isinstance(die, icepool.Die):
raise TypeError('Did not resolve to a die.')
return die._data
def data_callback() -> Counts[bool]:
die = cast('icepool.Die[bool]',
operation_class().evaluate(self, right))
if not isinstance(die, icepool.Die):
raise TypeError('Did not resolve to a die.')
return die._data

return icepool.DieWithTruth(data_callback,
truth_value_callback)
else:
return operation_class().evaluate(self, right)
return icepool.DieWithTruth(data_callback, truth_value_callback)
else:
return icepool.evaluator.MultisetFunctionEvaluator(
self, right, evaluator=operation_class())
return operation_class().evaluate(self, right)

def __lt__(self,
other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]',
Expand Down
6 changes: 3 additions & 3 deletions tests/deck_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def next_state(self, state, outcome, *counts):

def test_empty_deal():
deal = icepool.Deck(range(13), times=4).deal()
result = deal.evaluate(evaluator=TrivialEvaluator())
result = TrivialEvaluator().evaluate(deal)
assert result.equals(icepool.Die([0]))


Expand All @@ -42,7 +42,7 @@ def test_two_hand_sum_same_size():
result1 = deal1.sum()

deal2 = deck.deal(5, 5)
result2 = deal2.evaluate(evaluator=SumEachEvaluator())
result2 = SumEachEvaluator().evaluate(deal2)

assert deal2.denominator() == result2.denominator()
assert result1.equals(result2.marginals[0], simplify=True)
Expand All @@ -53,7 +53,7 @@ def test_two_hand_sum_diff_size():
deck = icepool.Deck(range(4), times=4)

deal = deck.deal(2, 4)
result = deal.evaluate(evaluator=SumEachEvaluator())
result = SumEachEvaluator().evaluate(deal)

assert deal.denominator() == result.denominator()
assert (result.marginals[0] * 2).mean() == result.marginals[1].mean()
Expand Down
2 changes: 1 addition & 1 deletion tests/vector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def extra_outcomes(self, *_):
return [1, 2, 3, 4, 5, 6]

result = 3 @ icepool.one_hot(6)
expected = icepool.d6.pool(3).evaluate(evaluator=OneHotEvaluator())
expected = OneHotEvaluator().evaluate(icepool.d6.pool(3))
assert result == expected


Expand Down

0 comments on commit 74a3314

Please sign in to comment.