Skip to content

Commit

Permalink
merge_order_preferences now always returns an Order or raises `Co…
Browse files Browse the repository at this point in the history
…nflictingOrderError`
  • Loading branch information
HighDiceRoller committed Dec 27, 2024
1 parent dd6551b commit 3d57dd6
Show file tree
Hide file tree
Showing 16 changed files with 97 additions and 100 deletions.
97 changes: 36 additions & 61 deletions src/icepool/evaluator/multiset_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from typing import Callable, Collection, TypeAlias, overload

from icepool.order import Order
from icepool.order import Order, OrderReason, merge_order_preferences
from icepool.typing import T, U_co

NestedTupleOrEvaluator: TypeAlias = MultisetEvaluator[T, U_co] | tuple[
Expand Down Expand Up @@ -129,56 +129,41 @@ def bad(a, b)
raise ValueError(
'Callable must take only a fixed number of positional arguments.'
)
tuple_or_evaluator = function(*(MV(index=i)
tuple_or_evaluator = function(*(MV(is_free=True, index=i)
for i in range(len(parameters))))
evaluator = replace_tuples_with_joint_evaluator(tuple_or_evaluator)
return update_wrapper(evaluator, function)
# This is not actually a function.
return update_wrapper(evaluator, function) # type: ignore


class MultisetFunctionEvaluator(MultisetEvaluator[T, U_co]):
"""Assigns an expression to be evaluated first to each input of an evaluator."""

def __init__(self,
*expressions:
'icepool.multiset_expression.MultisetExpression[T]',
evaluator: MultisetEvaluator[T, U_co],
truth_value: bool | None = None) -> None:
def __init__(self, *inputs: 'icepool.MultisetExpression[T]',
evaluator: MultisetEvaluator[T, U_co]) -> None:
self._evaluator = evaluator
self._bound_inputs = tuple(
itertools.chain.from_iterable(expression._bound_inputs
for expression in expressions))
self._bound_arity = len(self._bound_inputs)
self._free_arity = max(
(expression._free_arity() for expression in expressions),
default=0)

unbound_expressions: 'list[icepool.expression.MultisetExpression[T]]' = []
extra_start = 0
for expression in expressions:
unbound_expression, extra_start = expression._unbind(extra_start)
unbound_expressions.append(unbound_expression)
self._expressions = tuple(unbound_expressions)
self._truth_value = truth_value
raise NotImplementedError()
bound_inputs: 'list[icepool.MultisetExpression]' = []
self._expressions = tuple(
input._unbind(bound_inputs) for input in inputs)
self._bound_inputs = tuple(bound_inputs)

def next_state(self, state, outcome, *counts):
if state is None:
expression_states = (None, ) * len(self._expressions)
expressions = self._expressions
evaluator_state = None
else:
expression_states, evaluator_state = state
expressions, evaluator_state = state

extra_counts = counts[:len(self._evaluator.bound_inputs())]
counts = counts[len(self._evaluator.bound_inputs()):]
evaluator_slice, bound_slice, free_slice = self._count_slices()
evaluator_counts = counts[evaluator_slice]
bound_counts = counts[bound_slice]
free_counts = counts[free_slice]

expression_states, expression_counts = zip(
*(expression._next_state(expression_state, outcome, *counts)
for expression, expression_state in zip(self._expressions,
expression_states)))
# ????
expression_counts = None
evaluator_state = self._evaluator.next_state(evaluator_state, outcome,
*extra_counts,
*evaluator_counts,
*expression_counts)
return expression_states, evaluator_state
return expressions, evaluator_state

def final_outcome(
self,
Expand All @@ -190,38 +175,28 @@ def final_outcome(
return self._evaluator.final_outcome(evaluator_state)

def order(self) -> Order:
expression_order = Order.merge(*(expression.order()
for expression in self._expressions))
return Order.merge(expression_order, self._evaluator.order())
expression_order, expression_order_reason = merge_order_preferences(
*(expression.order_preference()
for expression in self._expressions),
(self._evaluator.order(), OrderReason.Mandatory))
return expression_order

def extra_outcomes(self, *generators) -> Collection[T]:
return self._evaluator.extra_outcomes(*generators)

@cached_property
def _bound_inputs(self) -> 'tuple[icepool.MultisetExpression, ...]':
return self._bound_inputs + self._evaluator.bound_inputs()

def bound_inputs(self) -> 'tuple[icepool.MultisetExpression, ...]':
return self._bound_inputs

def validate_arity(self, arity: int) -> None:
if arity < self._free_arity:
raise ValueError(
f'Expected arity of {self._free_arity}, got {arity}.')
return self._evaluator.bound_inputs() + self._bound_inputs

def __bool__(self) -> bool:
if self._truth_value is None:
raise TypeError(
'MultisetExpression only has a truth value if it is the result of the == or != operator.'
)
return self._truth_value
@cached_property
def _count_slices(self) -> 'tuple[slice, slice, slice]':
evaluator_slice = slice(None, len(self._evaluator.bound_inputs()))
bound_slice = slice(evaluator_slice.stop,
evaluator_slice.stop + len(self._bound_inputs))
free_slice = slice(bound_slice.stop, None)
return evaluator_slice, bound_slice, free_slice

def __str__(self) -> str:
input_string = f'{self._bound_arity} bound, {self._free_arity} free'
if len(self._expressions) == 1:
expression_string = f'{self._expressions[0]}'
else:
expression_string = ', '.join(
str(expression) for expression in self._expressions)
expression_string = ', '.join(
str(expression) for expression in self._expressions)
output_string = str(self._evaluator)
return f'Expression: {input_string} -> {expression_string} -> {output_string}'
return f'MultisetFunctionEvaluator: {expression_string} -> {output_string}'
2 changes: 1 addition & 1 deletion src/icepool/generator/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _generate_max(self, max_outcome) -> AlignmentGenerator:
else:
yield Alignment(self.outcomes()[:-1]), (), 1

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return Order.Any, OrderReason.NoPreference

def denominator(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion src/icepool/generator/compound_keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _generate_max(self, max_outcome) -> PopMultisetGeneration:
yield CompoundKeepGenerator(
generators, popped_keep_tuple), (result_count, ), total_weight

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return merge_order_preferences(*(inner.local_order_preference()
for inner in self._inner_generators))

Expand Down
2 changes: 1 addition & 1 deletion src/icepool/generator/deal.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _generate_max(self, max_outcome) -> PopMultisetGeneration:
popped_deal = Deal._new_raw(popped_deck, 0, ())
yield popped_deal, (sum(self.keep_tuple()), ), skip_weight

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
lo_skip, hi_skip = icepool.order.lo_hi_skip(self.keep_tuple())
if lo_skip > hi_skip:
return Order.Descending, OrderReason.KeepSkip
Expand Down
2 changes: 1 addition & 1 deletion src/icepool/generator/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _generate_max(self, max_outcome) -> PopMultisetGeneration:
'MixtureMultisetGenerator should have decayed to another generator type by this point.'
)

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return merge_order_preferences(*(inner.local_order_preference()
for inner in self._inner_generators))

Expand Down
2 changes: 1 addition & 1 deletion src/icepool/generator/multi_deal.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _generate_max(self, max_outcome) -> PopMultisetGeneration:

yield from self._generate_common(popped_deck, deck_count)

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return Order.Any, OrderReason.NoPreference

@cached_property
Expand Down
2 changes: 1 addition & 1 deletion src/icepool/generator/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def outcomes(self) -> Sequence[T]:
def output_arity(self) -> int:
return 1

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
can_truncate_min, can_truncate_max = icepool.order.can_truncate(
self.unique_dice())
if can_truncate_min and not can_truncate_max:
Expand Down
4 changes: 2 additions & 2 deletions src/icepool/multiset_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def _generate_max(self, max_outcome: T) -> PopMultisetGeneration:
"""

@abstractmethod
def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
"""Any ordering that is preferred or required by this expression node."""

@abstractmethod
Expand Down Expand Up @@ -245,7 +245,7 @@ def _iter_nodes(self) -> 'Iterator[MultisetExpression]':
yield from child._iter_nodes()
yield self

def order_preference(self) -> tuple[Order | None, OrderReason]:
def order_preference(self) -> tuple[Order, OrderReason]:
return merge_order_preferences(*(node.local_order_preference()
for node in self._iter_nodes()))

Expand Down
2 changes: 1 addition & 1 deletion src/icepool/multiset_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _generate_min(self, min_outcome) -> PopMultisetGeneration:
def _generate_max(self, max_outcome) -> PopMultisetGeneration:
raise MultisetBindingError()

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return Order.Any, OrderReason.NoPreference

def has_free_variables(self) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/icepool/operator/adjust_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _transform_next(
count = self._function(outcome, *counts)
return MultisetMapCounts(*new_children, function=self._function), count

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return Order.Any, OrderReason.NoPreference

@property
Expand Down Expand Up @@ -74,7 +74,7 @@ def _transform_next(
count = self.operator(counts[0])
return type(self)(*new_children, constant=self._constant), count

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return Order.Any, OrderReason.NoPreference

@property
Expand Down Expand Up @@ -164,7 +164,7 @@ def _transform_next(
comparison=self._comparison,
constant=self._constant), count

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return Order.Any, OrderReason.NoPreference

@property
Expand Down
2 changes: 1 addition & 1 deletion src/icepool/operator/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _transform_next(
count = reduce(self.merge_counts, counts)
return type(self)(*new_children), count

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return Order.Any, OrderReason.NoPreference

@property
Expand Down
4 changes: 2 additions & 2 deletions src/icepool/operator/filter_outcomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _transform_next(
target=self._func,
invert=self._invert), count

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return Order.Any, OrderReason.NoPreference

@property
Expand Down Expand Up @@ -126,7 +126,7 @@ def _transform_next(
return MultisetFilterOutcomesBinary(*new_children,
invert=self._invert), count

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return Order.Any, OrderReason.NoPreference

@property
Expand Down
2 changes: 1 addition & 1 deletion src/icepool/operator/keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _transform_next(
keep_tuple=(),
drop=next_drop), count

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return self._keep_order, OrderReason.Mandatory

@property
Expand Down
4 changes: 2 additions & 2 deletions src/icepool/operator/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _transform_next(
right_first=self._right_first,
left_lead=next_left_lead), count

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return self._order, OrderReason.Mandatory

@property
Expand Down Expand Up @@ -134,7 +134,7 @@ def _transform_next(
keep=self._keep,
prev_matchable=next_prev_matchable), count

def local_order_preference(self) -> tuple[Order | None, OrderReason]:
def local_order_preference(self) -> tuple[Order, OrderReason]:
return self._order, OrderReason.Mandatory

@property
Expand Down
36 changes: 25 additions & 11 deletions src/icepool/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from typing import Collection


class ConflictingOrderError(Exception):
"""Indicates that two conflicting mandatory outcome orderings were encountered."""


class Order(enum.IntEnum):
"""Can be used to define what order outcomes are seen in by MultisetEvaluators."""
Ascending = 1
Expand All @@ -22,13 +26,13 @@ def merge(*orders: 'Order') -> 'Order':
`Descending` if there is at least one `Descending` in the arguments.
Raises:
`ValueError` if both `Ascending` and `Descending` are in the
arguments.
`ConflictingOrderError` if both `Ascending` and `Descending` are in
the arguments.
"""
result = Order.Any
for order in orders:
if (result > 0 and order < 0) or (result < 0 and order > 0):
raise ValueError(
raise ConflictingOrderError(
f'Conflicting orders {orders}.\n' +
'Tip: If you are using highest(keep=k), try using lowest(drop=n-k) instead, or vice versa.'
)
Expand All @@ -39,8 +43,8 @@ def merge(*orders: 'Order') -> 'Order':

class OrderReason(enum.IntEnum):
"""Greater values represent higher priorities, which strictly override lower priorities."""
Mandatory = 127
"""The object requires this pop order."""
Mandatory = 3
"""Something strictly requires this pop order."""
PoolComposition = 2
"""The composition of dice in the pool favor this pop order."""
KeepSkip = 1
Expand All @@ -50,15 +54,19 @@ class OrderReason(enum.IntEnum):


def merge_order_preferences(
*preferences: tuple[Order | None, OrderReason],
) -> tuple[Order | None, OrderReason]:
*preferences: tuple[Order, OrderReason], ) -> tuple[Order, OrderReason]:
"""Returns a pop order that fits the highest priority preferences.
Greater priorities strictly outrank lower priorities.
An order of `None` represents conflicting orders and can occur in the
argument and/or return value.
Conflicting orders of the same priority are equal to an `Order.Any` of the
next-higher priority, except for conflicitng `Mandatory` orders, which
produces an exception.
Raises:
`ConflictingOrderError` if both `Ascending` and `Descending` are in
the arguments with `Mandatory` reason.
"""
result_order: Order | None = Order.Any
result_order = Order.Any
result_reason = OrderReason.NoPreference
for order, reason in preferences:
if order == Order.Any or reason == OrderReason.NoPreference:
Expand All @@ -71,8 +79,14 @@ def merge_order_preferences(
result_order = order
elif result_order == order:
continue
elif result_reason < OrderReason.Mandatory:
result_order = Order.Any
result_reason = OrderReason(result_reason + 1)
else:
result_order = None
raise ConflictingOrderError(
f'Conflicting order preferences {preferences}.\n' +
'Tip: If you are using highest(keep=k), try using lowest(drop=n-k) instead, or vice versa.'
)
return result_order, result_reason


Expand Down
Loading

0 comments on commit 3d57dd6

Please sign in to comment.