From 49f752f827bd8132409b68e9cf440e4bee0a7803 Mon Sep 17 00:00:00 2001 From: Albert Julius Liu Date: Wed, 25 Dec 2024 15:07:27 -0800 Subject: [PATCH] convert over adjust_counts #203 --- src/icepool/expression/adjust_counts.py | 162 ------------------- src/icepool/expression/binary_operator.py | 124 -------------- src/icepool/multiset_expression.py | 137 ++++++++++++++++ src/icepool/transform/__init__.py | 3 + src/icepool/transform/adjust_counts.py | 169 ++++++++++++++++++++ src/icepool/transform/binary_operator.py | 2 +- src/icepool/transform/multiset_transform.py | 2 +- 7 files changed, 311 insertions(+), 288 deletions(-) delete mode 100644 src/icepool/expression/adjust_counts.py delete mode 100644 src/icepool/expression/binary_operator.py create mode 100644 src/icepool/transform/adjust_counts.py diff --git a/src/icepool/expression/adjust_counts.py b/src/icepool/expression/adjust_counts.py deleted file mode 100644 index 302c6a5e..00000000 --- a/src/icepool/expression/adjust_counts.py +++ /dev/null @@ -1,162 +0,0 @@ -__docformat__ = 'google' - -import icepool - -from icepool.expression.multiset_expression import MultisetExpression - -import inspect -import operator -from abc import abstractmethod -from functools import cached_property, reduce - -from icepool.typing import Order, Outcome, T -from typing import Callable, Hashable, Literal, Sequence, cast, overload - - -class MapCountsExpression(MultisetExpression[T]): - """Expression that maps outcomes and counts to new counts.""" - - _function: Callable[..., int] - - def __init__(self, *children: MultisetExpression[T], - function: Callable[..., int]) -> None: - """Constructor. - - Args: - children: The children expression(s). - function: A function that takes `outcome, *counts` and produces a - combined count. - """ - for child in children: - self._validate_output_arity(child) - self._children = children - self._function = function - - def _make_unbound(self, *unbound_children) -> 'icepool.MultisetExpression': - return MapCountsExpression(*unbound_children, function=self._function) - - def _next_state(self, state, outcome: T, *counts: - int) -> tuple[Hashable, int]: - - child_states = state or (None, ) * len(self._children) - child_states, child_counts = zip( - *(child._next_state(child_state, outcome, *counts) - for child, child_state in zip(self._children, child_states))) - - count = self._function(outcome, *child_counts) - return state, count - - def order(self) -> Order: - return Order.merge(*(child.order() for child in self._children)) - - @cached_property - def _cached_arity(self) -> int: - return max(child._free_arity() for child in self._children) - - def _free_arity(self) -> int: - return self._cached_arity - - -class AdjustCountsExpression(MultisetExpression[T]): - - def __init__(self, child: MultisetExpression[T], /, *, - constant: int) -> None: - self._validate_output_arity(child) - self._children = (child, ) - self._constant = constant - - def _make_unbound(self, *unbound_children) -> 'icepool.MultisetExpression': - return type(self)(*unbound_children, constant=self._constant) - - @abstractmethod - def adjust_count(self, count: int, constant: int) -> int: - """Adjusts the count.""" - - def _next_state(self, state, outcome: T, *counts: - int) -> tuple[Hashable, int]: - state, count = self._children[0]._next_state(state, outcome, *counts) - count = self.adjust_count(count, self._constant) - return state, count - - def order(self) -> Order: - return self._children[0].order() - - def _free_arity(self) -> int: - return self._children[0]._free_arity() - - -class MultiplyCountsExpression(AdjustCountsExpression): - """Multiplies all counts by the constant.""" - - def adjust_count(self, count: int, constant: int) -> int: - return count * constant - - def __str__(self) -> str: - return f'({self._children[0]} * {self._constant})' - - -class FloorDivCountsExpression(AdjustCountsExpression): - """Divides all counts by the constant, rounding down.""" - - def adjust_count(self, count: int, constant: int) -> int: - return count // constant - - def __str__(self) -> str: - return f'({self._children[0]} // {self._constant})' - - -class ModuloCountsExpression(AdjustCountsExpression): - """Modulo all counts by the constant.""" - - def adjust_count(self, count: int, constant: int) -> int: - return count % constant - - def __str__(self) -> str: - return f'({self._children[0]} % {self._constant})' - - -class KeepCountsExpression(AdjustCountsExpression): - - def __init__(self, child: MultisetExpression[T], /, *, - comparison: Literal['==', '!=', '<=', '<', '>=', - '>'], constant: int): - super().__init__(child, constant=constant) - operators = { - '==': operator.eq, - '!=': operator.ne, - '<=': operator.le, - '<': operator.lt, - '>=': operator.ge, - '>': operator.gt, - } - if comparison not in operators: - raise ValueError(f'Invalid comparison {comparison}') - self._comparison = comparison - self._op = operators[comparison] - - def _make_unbound(self, *unbound_children) -> 'icepool.MultisetExpression': - return KeepCountsExpression(*unbound_children, - comparison=self._comparison, - constant=self._constant) - - def adjust_count(self, count: int, constant: int) -> int: - if self._op(count, constant): - return count - else: - return 0 - - def __str__(self) -> str: - return f"{self._children[0]}.keep_counts('{self._comparison}', {self._constant})" - - -class UniqueExpression(AdjustCountsExpression): - """Limits the count produced by each outcome.""" - - def adjust_count(self, count: int, constant: int) -> int: - return min(count, constant) - - def __str__(self) -> str: - if self._constant == 1: - return f'{self._children[0]}.unique()' - else: - return f'{self._children[0]}.unique({self._constant})' diff --git a/src/icepool/expression/binary_operator.py b/src/icepool/expression/binary_operator.py deleted file mode 100644 index a6c01038..00000000 --- a/src/icepool/expression/binary_operator.py +++ /dev/null @@ -1,124 +0,0 @@ -__docformat__ = 'google' - -import icepool - -from icepool.expression.multiset_expression import MultisetExpression - -import operator -from abc import abstractmethod -from functools import cached_property, reduce - -from typing import Hashable -from icepool.typing import Order, T - - -class BinaryOperatorExpression(MultisetExpression[T]): - - def __init__(self, *children: MultisetExpression[T]) -> None: - """Constructor. - - Args: - *children: Any number of expressions to feed into the operator. - If zero expressions are provided, the result will have all zero - counts. - If more than two expressions are provided, the counts will be - `reduce`d. - """ - for child in children: - self._validate_output_arity(child) - self._children = children - - def _make_unbound(self, *unbound_children) -> 'icepool.MultisetExpression': - return type(self)(*unbound_children) - - @staticmethod - @abstractmethod - def merge_counts(left: int, right: int) -> int: - """Merge counts produced by the left and right expression.""" - - @staticmethod - @abstractmethod - def symbol() -> str: - """A symbol representing this operation.""" - - def _next_state(self, state, outcome: T, *counts: - int) -> tuple[Hashable, int]: - if len(self._children) == 0: - return (), 0 - child_states = state or (None, ) * len(self._children) - - child_states, child_counts = zip( - *(child._next_state(child_state, outcome, *counts) - for child, child_state in zip(self._children, child_states))) - - count = reduce(self.merge_counts, child_counts) - return child_states, count - - def order(self) -> Order: - return Order.merge(*(child.order() for child in self._children)) - - @cached_property - def _cached_arity(self) -> int: - return max(child._free_arity() for child in self._children) - - def _free_arity(self) -> int: - return self._cached_arity - - def __str__(self) -> str: - return '(' + (' ' + self.symbol() + ' ').join( - str(child) for child in self._children) + ')' - - -class IntersectionExpression(BinaryOperatorExpression): - - @staticmethod - def merge_counts(left: int, right: int) -> int: - return min(left, right) - - @staticmethod - def symbol() -> str: - return '&' - - -class DifferenceExpression(BinaryOperatorExpression): - - @staticmethod - def merge_counts(left: int, right: int) -> int: - return left - right - - @staticmethod - def symbol() -> str: - return '-' - - -class UnionExpression(BinaryOperatorExpression): - - @staticmethod - def merge_counts(left: int, right: int) -> int: - return max(left, right) - - @staticmethod - def symbol() -> str: - return '|' - - -class AdditiveUnionExpression(BinaryOperatorExpression): - - @staticmethod - def merge_counts(left: int, right: int) -> int: - return left + right - - @staticmethod - def symbol() -> str: - return '+' - - -class SymmetricDifferenceExpression(BinaryOperatorExpression): - - @staticmethod - def merge_counts(left: int, right: int) -> int: - return abs(left - right) - - @staticmethod - def symbol() -> str: - return '^' diff --git a/src/icepool/multiset_expression.py b/src/icepool/multiset_expression.py index 5d3ac4d2..31749328 100644 --- a/src/icepool/multiset_expression.py +++ b/src/icepool/multiset_expression.py @@ -490,6 +490,143 @@ def symmetric_difference( other = implicit_convert_to_expression(other) return icepool.transform.MultisetSymmetricDifference(self, other) + # Adjust counts. + + def map_counts(*args: + 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', + function: Callable[..., int]) -> 'MultisetExpression[T]': + """Maps the counts to new counts. + + Args: + function: A function that takes `outcome, *counts` and produces a + combined count. + """ + expressions = tuple( + implicit_convert_to_expression(arg) for arg in args) + return icepool.transform.MultisetMapCounts(*expressions, + function=function) + + def __mul__(self, n: int) -> 'MultisetExpression[T]': + if not isinstance(n, int): + return NotImplemented + return self.multiply_counts(n) + + # Commutable in this case. + def __rmul__(self, n: int) -> 'MultisetExpression[T]': + if not isinstance(n, int): + return NotImplemented + return self.multiply_counts(n) + + def multiply_counts(self, n: int, /) -> 'MultisetExpression[T]': + """Multiplies all counts by n. + + Same as `self * n`. + + Example: + ```python + Pool([1, 2, 2, 3]) * 2 -> [1, 1, 2, 2, 2, 2, 3, 3] + ``` + """ + return icepool.transform.MultisetMultiplyCounts(self, constant=n) + + @overload + def __floordiv__(self, other: int) -> 'MultisetExpression[T]': + ... + + @overload + def __floordiv__( + self, other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]' + ) -> 'icepool.Die[int] | icepool.MultisetEvaluator[T, int]': + """Same as divide_counts().""" + + @overload + def __floordiv__( + self, + other: 'int | MultisetExpression[T] | Mapping[T, int] | Sequence[T]' + ) -> 'MultisetExpression[T] | icepool.Die[int] | icepool.MultisetEvaluator[T, int]': + """Same as count_subset().""" + + def __floordiv__( + self, + other: 'int | MultisetExpression[T] | Mapping[T, int] | Sequence[T]' + ) -> 'MultisetExpression[T] | icepool.Die[int] | icepool.MultisetEvaluator[T, int]': + if isinstance(other, int): + return self.divide_counts(other) + else: + return self.count_subset(other) + + def divide_counts(self, n: int, /) -> 'MultisetExpression[T]': + """Divides all counts by n (rounding down). + + Same as `self // n`. + + Example: + ```python + Pool([1, 2, 2, 3]) // 2 -> [2] + ``` + """ + return icepool.transform.MultisetFloordivCounts(self, constant=n) + + def __mod__(self, n: int, /) -> 'MultisetExpression[T]': + if not isinstance(n, int): + return NotImplemented + return icepool.transform.MultisetModuloCounts(self, constant=n) + + def modulo_counts(self, n: int, /) -> 'MultisetExpression[T]': + """Moduos all counts by n. + + Same as `self % n`. + + Example: + ```python + Pool([1, 2, 2, 3]) % 2 -> [1, 3] + ``` + """ + return self % n + + def __pos__(self) -> 'MultisetExpression[T]': + """Sets all negative counts to zero.""" + return icepool.transform.MultisetKeepCounts(self, + comparison='>=', + constant=0) + + def __neg__(self) -> 'MultisetExpression[T]': + """As -1 * self.""" + return -1 * self + + def keep_counts(self, comparison: Literal['==', '!=', '<=', '<', '>=', + '>'], n: int, + /) -> 'MultisetExpression[T]': + """Keeps counts fitting the comparison, treating the rest as zero. + + For example, `expression.keep_counts('>=', 2)` would keep pairs, + triplets, etc. and drop singles. + + ```python + Pool([1, 2, 2, 3, 3, 3]).keep_counts('>=', 2) -> [2, 2, 3, 3, 3] + ``` + + Args: + comparison: The comparison to use. + n: The number to compare counts against. + """ + return icepool.transform.MultisetKeepCounts(self, + comparison=comparison, + constant=n) + + def unique(self, n: int = 1, /) -> 'MultisetExpression[T]': + """Counts each outcome at most `n` times. + + For example, `generator.unique(2)` would count each outcome at most + twice. + + Example: + ```python + Pool([1, 2, 2, 3]).unique() -> [1, 2, 3] + ``` + """ + return icepool.transform.MultisetUnique(self, constant=n) + def _compare( self, right: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]', diff --git a/src/icepool/transform/__init__.py b/src/icepool/transform/__init__.py index 4224bb64..3aa5f8b5 100644 --- a/src/icepool/transform/__init__.py +++ b/src/icepool/transform/__init__.py @@ -3,3 +3,6 @@ MultisetUnion, MultisetAdditiveUnion, MultisetSymmetricDifference) +from icepool.transform.adjust_counts import ( + MultisetMapCounts, MultisetMultiplyCounts, MultisetModuloCounts, + MultisetFloordivCounts, MultisetKeepCounts, MultisetUnique) diff --git a/src/icepool/transform/adjust_counts.py b/src/icepool/transform/adjust_counts.py new file mode 100644 index 00000000..5b45e51e --- /dev/null +++ b/src/icepool/transform/adjust_counts.py @@ -0,0 +1,169 @@ +__docformat__ = 'google' + +import icepool + +from icepool.multiset_expression import MultisetExpression +from icepool.transform.multiset_transform import MultisetTransform + +import operator +from abc import abstractmethod +from functools import cached_property, reduce + +from typing import Callable, Hashable, Iterable, Literal +from icepool.typing import Order, T + + +class MultisetMapCounts(MultisetTransform[T]): + """Maps outcomes and counts to new counts.""" + + _function: Callable[..., int] + + def __init__(self, *children: MultisetExpression[T], + function: Callable[..., int]) -> None: + """Constructor. + + Args: + children: The children expression(s). + function: A function that takes `outcome, *counts` and produces a + combined count. + """ + self._children = children + self._function = function + + def _copy( + self, copy_children: 'Iterable[MultisetExpression[T]]' + ) -> 'MultisetExpression[T]': + return MultisetMapCounts(*copy_children, function=self._function) + + def _transform_next( + self, next_children: 'Iterable[MultisetExpression[T]]', outcome: T, + counts: 'tuple[int, ...]') -> 'tuple[MultisetExpression[T], int]': + count = self._function(outcome, *counts) + return MultisetMapCounts(*next_children, + function=self._function), count + + def local_order(self) -> Order: + return Order.Any + + def _local_hash_key(self) -> Hashable: + return MultisetMapCounts, self._function + + +class MultisetCountOperator(MultisetTransform[T]): + + def __init__(self, child: MultisetExpression[T], /, *, + constant: int) -> None: + self._children = (child, ) + self._constant = constant + + @abstractmethod + def operator(self, count: int) -> int: + """Operation to apply to the counts.""" + + def _copy( + self, copy_children: 'Iterable[MultisetExpression[T]]' + ) -> 'MultisetExpression[T]': + return type(self)(*copy_children, constant=self._constant) + + def _transform_next( + self, next_children: 'Iterable[MultisetExpression[T]]', outcome: T, + counts: 'tuple[int, ...]') -> 'tuple[MultisetExpression[T], int]': + count = self.operator(counts[0]) + return type(self)(*next_children, constant=self._constant), count + + def local_order(self) -> Order: + return Order.Any + + def _local_hash_key(self) -> Hashable: + return type(self), self._constant + + +class MultisetMultiplyCounts(MultisetCountOperator): + """Multiplies all counts by the constant.""" + + def operator(self, count: int) -> int: + return count * self._constant + + def __str__(self) -> str: + return f'({self._children[0]} * {self._constant})' + + +class MultisetFloordivCounts(MultisetCountOperator): + """Divides all counts by the constant, rounding down.""" + + def operator(self, count: int) -> int: + return count // self._constant + + def __str__(self) -> str: + return f'({self._children[0]} // {self._constant})' + + +class MultisetModuloCounts(MultisetCountOperator): + """Modulo all counts by the constant.""" + + def operator(self, count: int) -> int: + return count % self._constant + + def __str__(self) -> str: + return f'({self._children[0]} % {self._constant})' + + +class MultisetUnique(MultisetCountOperator): + """Limits the count produced by each outcome.""" + + def operator(self, count: int) -> int: + return min(count, self._constant) + + def __str__(self) -> str: + if self._constant == 1: + return f'{self._children[0]}.unique()' + else: + return f'{self._children[0]}.unique({self._constant})' + + +class MultisetKeepCounts(MultisetTransform[T]): + + def __init__(self, child: MultisetExpression[T], /, *, + comparison: Literal['==', '!=', '<=', '<', '>=', + '>'], constant: int): + self._children = (child, ) + self._constant = constant + operators = { + '==': operator.eq, + '!=': operator.ne, + '<=': operator.le, + '<': operator.lt, + '>=': operator.ge, + '>': operator.gt, + } + if comparison not in operators: + raise ValueError(f'Invalid comparison {comparison}') + self._comparison = comparison + self._op = operators[comparison] + + def _copy( + self, copy_children: 'Iterable[MultisetExpression[T]]' + ) -> 'MultisetExpression[T]': + return MultisetKeepCounts(*copy_children, + comparison=self._comparison, + constant=self._constant) + + def _transform_next( + self, next_children: 'Iterable[MultisetExpression[T]]', outcome: T, + counts: 'tuple[int, ...]') -> 'tuple[MultisetExpression[T], int]': + if self._op(counts[0], self._constant): + count = counts[0] + else: + count = 0 + return MultisetKeepCounts(*next_children, + comparison=self._comparison, + constant=self._constant), count + + def local_order(self) -> Order: + return Order.Any + + def _local_hash_key(self) -> Hashable: + return MultisetKeepCounts, self._constant + + def __str__(self) -> str: + return f"{self._children[0]}.keep_counts('{self._comparison}', {self._constant})" diff --git a/src/icepool/transform/binary_operator.py b/src/icepool/transform/binary_operator.py index 8819578c..59c8d186 100644 --- a/src/icepool/transform/binary_operator.py +++ b/src/icepool/transform/binary_operator.py @@ -44,7 +44,7 @@ def _copy( def _transform_next( self, next_children: 'Iterable[MultisetExpression[T]]', outcome: T, - counts: 'Iterable[int]') -> 'tuple[MultisetExpression[T], int]': + counts: 'tuple[int, ...]') -> 'tuple[MultisetExpression[T], int]': count = reduce(self.merge_counts, counts) return type(self)(*next_children), count diff --git a/src/icepool/transform/multiset_transform.py b/src/icepool/transform/multiset_transform.py index d6b4aa88..77c2b03b 100644 --- a/src/icepool/transform/multiset_transform.py +++ b/src/icepool/transform/multiset_transform.py @@ -29,7 +29,7 @@ def _copy( @abstractmethod def _transform_next( self, next_children: 'Iterable[MultisetExpression[T]]', outcome: T, - counts: 'Iterable[int]') -> 'tuple[MultisetExpression[T], int]': + counts: 'tuple[int, ...]') -> 'tuple[MultisetExpression[T], int]': """Produce the next state of this expression. Args: