Skip to content

Commit

Permalink
add MultisetExpression._apply_variables() and finish re-implementat…
Browse files Browse the repository at this point in the history
…ion of `@multiset_function` #203
  • Loading branch information
HighDiceRoller committed Dec 27, 2024
1 parent 36ef0cb commit 090a52d
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/icepool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
from icepool.generator.compound_keep import CompoundKeepGenerator
from icepool.generator.mixture import MixtureGenerator

from icepool.multiset_expression import MultisetExpression, implicit_convert_to_expression, InitialMultisetGeneration, PopMultisetGeneration
from icepool.multiset_expression import MultisetExpression, implicit_convert_to_expression, InitialMultisetGeneration, PopMultisetGeneration, MultisetBindingError

from icepool.generator.multiset_generator import MultisetGenerator
from icepool.generator.alignment import Alignment
Expand Down
7 changes: 4 additions & 3 deletions src/icepool/evaluator/multiset_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,14 @@ def next_state(self, state, outcome, *counts):
else:
expressions, evaluator_state = state

evaluator_slice, bound_slice, free_slice = self._count_slices()
evaluator_slice, bound_slice, free_slice = self._count_slices
evaluator_counts = counts[evaluator_slice]
bound_counts = counts[bound_slice]
free_counts = counts[free_slice]

# ????
expression_counts = None
expressions, expression_counts = zip(
*(expression._apply_variables(outcome, bound_counts, free_counts)
for expression in expressions))
evaluator_state = self._evaluator.next_state(evaluator_state, outcome,
*evaluator_counts,
*expression_counts)
Expand Down
6 changes: 6 additions & 0 deletions src/icepool/generator/multiset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,9 @@ def _unbind(
result = icepool.MultisetVariable(False, len(bound_inputs))
bound_inputs.append(self)
return result

def _apply_variables(
self, outcome: T, bound_counts: tuple[int, ...],
free_counts: tuple[int, ...]) -> 'MultisetExpression[T]':
raise icepool.MultisetBindingError(
'_unbind should have been called before _apply_variables.')
30 changes: 30 additions & 0 deletions src/icepool/multiset_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
Sequence, int]]


class MultisetBindingError(TypeError):
"""Indicates a bound multiset variable was found where a free variable was expected, or vice versa."""


def implicit_convert_to_expression(
arg: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]'
) -> 'MultisetExpression[T]':
Expand All @@ -46,6 +50,14 @@ def implicit_convert_to_expression(
class MultisetExpression(ABC, Generic[T]):
"""Abstract base class representing an expression that operates on multisets.
There are three types of multiset expressions:
* `MultisetGenerator`, which produce raw outcomes and counts.
* `MultisetOperator`, which takes outcomes with one or more counts and
produces a count.
* `MultisetVariable`, which is a temporary placeholder for some other
expression.
Expression methods can be applied to `MultisetGenerator`s to do simple
evaluations. For joint evaluations, try `multiset_function`.
Expand Down Expand Up @@ -201,6 +213,24 @@ def _unbind(
the position of the expression they replaced in `bound_inputs`.
"""

@abstractmethod
def _apply_variables(
self, outcome: T, bound_counts: tuple[int, ...],
free_counts: tuple[int,
...]) -> 'tuple[MultisetExpression[T], int]':
"""Advances the state of this expression given counts emitted from variables and returns a count.
Args:
outcome: The current outcome being processed.
bound_counts: The counts emitted by bound expressions.
free_counts: The counts emitted by arguments to the
`@mulitset_function`.
Returns:
An expression representing the next state and the count produced by
this expression.
"""

@property
@abstractmethod
def _local_hash_key(self) -> Hashable:
Expand Down
27 changes: 16 additions & 11 deletions src/icepool/multiset_variable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__docformat__ = 'google'

import icepool
from icepool.order import Order, OrderReason
from icepool.multiset_expression import MultisetExpression, InitialMultisetGeneration, PopMultisetGeneration

Expand All @@ -8,10 +9,6 @@
from typing import Any, Hashable, Sequence


class MultisetBindingError(TypeError):
"""Indicates a bound multiset variable was found where a free variable was expected, or vice versa."""


class MultisetVariable(MultisetExpression[Any]):
"""A variable to be filled in with a concrete sub-expression."""

Expand All @@ -22,22 +19,22 @@ def __init__(self, is_free: bool, index: int):
self._index = index

def outcomes(self) -> Sequence:
raise MultisetBindingError()
raise icepool.MultisetBindingError()

def output_arity(self) -> int:
return 1

def _is_resolvable(self) -> bool:
raise MultisetBindingError()
raise icepool.MultisetBindingError()

def _generate_initial(self) -> InitialMultisetGeneration:
raise MultisetBindingError()
raise icepool.MultisetBindingError()

def _generate_min(self, min_outcome) -> PopMultisetGeneration:
raise MultisetBindingError()
raise icepool.MultisetBindingError()

def _generate_max(self, max_outcome) -> PopMultisetGeneration:
raise MultisetBindingError()
raise icepool.MultisetBindingError()

def local_order_preference(self) -> tuple[Order, OrderReason]:
return Order.Any, OrderReason.NoPreference
Expand All @@ -46,7 +43,7 @@ def has_free_variables(self) -> bool:
return self._is_free

def denominator(self) -> int:
raise MultisetBindingError()
raise icepool.MultisetBindingError()

def _unbind(
self,
Expand All @@ -55,9 +52,17 @@ def _unbind(
if self._is_free:
return self
else:
raise MultisetBindingError(
raise icepool.MultisetBindingError(
'Attempted to unbind an expression that was already unbound.')

def _apply_variables(
self, outcome, bound_counts: tuple[int, ...],
free_counts: tuple[int, ...]) -> 'tuple[MultisetExpression, int]':
if self._is_free:
return self, free_counts[self._index]
else:
return self, bound_counts[self._index]

@property
def _local_hash_key(self) -> Hashable:
return (MultisetVariable, self._is_free, self._index)
Expand Down
15 changes: 13 additions & 2 deletions src/icepool/operator/multiset_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ def _transform_next(
counts: One count per child.
Returns:
An expression representing the next state and the count produced by this expression.
An expression representing the next state and the count produced by
this expression.
Raises:
UnboundMultisetExpressionError if this is called on an expression with free variables.
UnboundMultisetExpressionError if this is called on an expression
with free variables.
"""

def outcomes(self) -> Sequence[T]:
Expand Down Expand Up @@ -98,3 +100,12 @@ def _unbind(
result = icepool.MultisetVariable(False, len(bound_inputs))
bound_inputs.append(self)
return result

def _apply_variables(
self, outcome: T, bound_counts: tuple[int, ...],
free_counts: tuple[int,
...]) -> 'tuple[MultisetExpression[T], int]':
new_children, counts = zip(
*(child._apply_variables(outcome, bound_counts, free_counts)
for child in self._children))
return self._transform_next(new_children, outcome, counts)

0 comments on commit 090a52d

Please sign in to comment.