Skip to content

Commit

Permalink
fix various bugs in new MultisetExpression #203
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Dec 26, 2024
1 parent b810a60 commit a7ac9dd
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/icepool/evaluator/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self,
unbound_expressions.append(unbound_expression)
self._expressions = tuple(unbound_expressions)
self._truth_value = truth_value
raise NotImplementedError()

def next_state(self, state, outcome, *counts):
if state is None:
Expand Down
9 changes: 4 additions & 5 deletions src/icepool/multiset_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ def _hash_key(self) -> Hashable:
Used to implement `equals()` and `__hash__()`
"""
return (self._local_hash_key, ) + tuple(child._hash_key
for child in self._children)
return (self._local_hash_key,
tuple(child._hash_key for child in self._children))

def equals(self, other) -> bool:
"""Whether this expression is logically equal to another object."""
Expand All @@ -251,6 +251,7 @@ def _hash(self) -> int:
return hash(self._hash_key)

def __hash__(self) -> int:
print(self._hash_key)
return self._hash

def _iter_nodes(self) -> 'Iterator[MultisetExpression]':
Expand Down Expand Up @@ -978,9 +979,7 @@ def evaluate(
A `Die` if the expression is are fully bound.
A `MultisetEvaluator` otherwise.
"""
if all(
isinstance(expression, icepool.MultisetGenerator)
for expression in expressions):
if all(expression._free_arity() == 0 for expression in expressions):
return evaluator.evaluate(*expressions)
evaluator = icepool.evaluator.ExpressionEvaluator(*expressions,
evaluator=evaluator)
Expand Down
1 change: 1 addition & 0 deletions src/icepool/multiset_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def bad(a, b)
output an evaluator or a nested tuple of evaluators. Tuples will
result in a `JointEvaluator`.
"""
raise NotImplementedError()
parameters = inspect.signature(function, follow_wrapped=False).parameters
for parameter in parameters.values():
if parameter.kind not in [
Expand Down
1 change: 1 addition & 0 deletions src/icepool/multiset_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def denominator(self) -> int:
def _unbind(self, next_index: int) -> 'tuple[MultisetExpression, int]':
return self, next_index

@property
def _local_hash_key(self) -> Hashable:
return (MultisetVariable, self._index)

Expand Down
3 changes: 3 additions & 0 deletions src/icepool/operator/adjust_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _transform_next(
def local_order(self) -> Order:
return Order.Any

@property
def _local_hash_key(self) -> Hashable:
return MultisetMapCounts, self._function

Expand Down Expand Up @@ -75,6 +76,7 @@ def _transform_next(
def local_order(self) -> Order:
return Order.Any

@property
def _local_hash_key(self) -> Hashable:
return type(self), self._constant

Expand Down Expand Up @@ -164,6 +166,7 @@ def _transform_next(
def local_order(self) -> Order:
return Order.Any

@property
def _local_hash_key(self) -> Hashable:
return MultisetKeepCounts, self._constant

Expand Down
1 change: 1 addition & 0 deletions src/icepool/operator/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def _transform_next(
def local_order(self) -> Order:
return Order.Any

@property
def _local_hash_key(self) -> Hashable:
return type(self)

Expand Down
2 changes: 2 additions & 0 deletions src/icepool/operator/filter_outcomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _transform_next(
def local_order(self) -> Order:
return Order.Any

@property
def _local_hash_key(self) -> Hashable:
return MultisetFilterOutcomes, self._func, self._invert

Expand Down Expand Up @@ -127,6 +128,7 @@ def _transform_next(
def local_order(self) -> Order:
return Order.Any

@property
def _local_hash_key(self) -> Hashable:
return MultisetFilterOutcomesBinary, self._invert

Expand Down
1 change: 1 addition & 0 deletions src/icepool/operator/keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def _transform_next(
def local_order(self) -> Order:
return self._keep_order

@property
def _local_hash_key(self) -> Hashable:
return self._keep_order, self._keep_tuple, self._drop

Expand Down
2 changes: 2 additions & 0 deletions src/icepool/operator/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def _transform_next(
def local_order(self) -> Order:
return self._order

@property
def _local_hash_key(self) -> Hashable:
return (MultisetSortMatch, self._order, self._tie, self._left_first,
self._right_first, self._left_lead)
Expand Down Expand Up @@ -134,6 +135,7 @@ def _transform_next(
def local_order(self) -> Order:
return self._order

@property
def _local_hash_key(self) -> Hashable:
return (MultisetMaximumMatch, self._order, self._match_equal,
self._keep, self._prev_matchable)
6 changes: 4 additions & 2 deletions src/icepool/operator/multiset_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,16 @@ def _generate_min(self, min_outcome: T) -> PopMultisetGeneration:
for t in itertools.product(*(child._generate_min(min_outcome)
for child in self._children)):
new_children, counts, weights = zip(*t)
counts = tuple(c[0] for c in counts)
next_self, count = self._transform_next(new_children, min_outcome,
counts)
yield next_self, (count, ), math.prod(weights)

def _generate_max(self, max_outcome: T) -> PopMultisetGeneration:
for t in itertools.product(*(child._generate_min(max_outcome)
for t in itertools.product(*(child._generate_max(max_outcome)
for child in self._children)):
new_children, counts, weights = zip(*t)
counts = tuple(c[0] for c in counts)
next_self, count = self._transform_next(new_children, max_outcome,
counts)
yield next_self, (count, ), math.prod(weights)
Expand All @@ -95,5 +97,5 @@ def _unbind(self, next_index: int) -> 'tuple[MultisetExpression, int]':
for child in self._children:
unbound_child, next_index = child._unbind(next_index)
unbound_children.append(unbound_child)
unbound_expression = self._copy(unbound_children)
unbound_expression = self._copy(tuple(unbound_children))
return unbound_expression, next_index

0 comments on commit a7ac9dd

Please sign in to comment.