Skip to content

Commit

Permalink
implement multiply_counts in KeepGenerator ABC
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Jun 16, 2024
1 parent a50ac91 commit 7f697f9
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 12 deletions.
4 changes: 0 additions & 4 deletions src/icepool/generator/compound_keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ def _set_keep_tuple(self, keep_tuple: tuple[int,
...]) -> 'KeepGenerator[T]':
return CompoundKeepGenerator(self._inners, keep_tuple)

def multiply_counts(self, constant: int, /) -> 'CompoundKeepGenerator[T]':
return CompoundKeepGenerator(
self._inners, tuple(x * constant for x in self.keep_tuple()))

@property
def _hash_key(self) -> Hashable:
return CompoundKeepGenerator, tuple(
Expand Down
8 changes: 4 additions & 4 deletions src/icepool/generator/keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ def _set_keep_tuple(self, keep_tuple: tuple[int,
...]) -> 'KeepGenerator[T]':
"""Produces a copy with a modified keep_tuple."""

@abstractmethod
def multiply_counts(self, constant: int, /) -> 'KeepGenerator[T]':
raise NotImplementedError()

@cached_property
def _keep_size(self) -> int:
return sum(self._keep_tuple)
Expand Down Expand Up @@ -251,6 +247,10 @@ def __mul__(self, other: int) -> 'KeepGenerator[T]':
return NotImplemented
return self.multiply_counts(other)

def multiply_counts(self, constant: int, /) -> 'KeepGenerator[T]':
return self._set_keep_tuple(
tuple(n * constant for n in self._keep_tuple))

# Commutable in this case.
def __rmul__(self, other: int) -> 'KeepGenerator[T]':
if not isinstance(other, int):
Expand Down
4 changes: 0 additions & 4 deletions src/icepool/generator/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,6 @@ def _set_keep_tuple(self, keep_tuple: tuple[int,
...]) -> 'KeepGenerator[T]':
return Pool._new_raw(self._dice, keep_tuple)

def multiply_counts(self, constant: int, /) -> 'Pool[T]':
return Pool._new_raw(self._dice,
tuple(x * constant for x in self.keep_tuple()))

def additive_union(
*args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]'
) -> 'MultisetExpression[T]':
Expand Down

0 comments on commit 7f697f9

Please sign in to comment.