Skip to content

Commit

Permalink
add keep_counts_lt, keep_counts_gt
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Feb 1, 2024
1 parent fa3934b commit 5e8a82a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
26 changes: 26 additions & 0 deletions src/icepool/expression/multiset_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,19 @@ def keep_counts_le(self, n: int, /) -> 'MultisetExpression[T_contra]':
"""
return icepool.expression.KeepCountsExpression(self, n, operator.le)

def keep_counts_lt(self, n: int, /) -> 'MultisetExpression[T_contra]':
"""Keeps counts that are < n, treating the rest as zero.
For example, `expression.keep_counts_lt(2)` would remove doubles,
triplets...
Example:
```
Pool([1, 2, 2, 3, 3, 3]).keep_counts_lt(2) -> [1]
```
"""
return icepool.expression.KeepCountsExpression(self, n, operator.lt)

def keep_counts_ge(self, n: int, /) -> 'MultisetExpression[T_contra]':
"""Keeps counts that are >= n, treating the rest as zero.
Expand All @@ -542,6 +555,19 @@ def keep_counts_ge(self, n: int, /) -> 'MultisetExpression[T_contra]':
"""
return icepool.expression.KeepCountsExpression(self, n, operator.ge)

def keep_counts_gt(self, n: int, /) -> 'MultisetExpression[T_contra]':
"""Keeps counts that are < n, treating the rest as zero.
For example, `expression.keep_counts_gt(2)` would remove singles and
doubles.
Example:
```
Pool([1, 2, 2, 3, 3, 3]).keep_counts_gt(2) -> [3, 3, 3]
```
"""
return icepool.expression.KeepCountsExpression(self, n, operator.gt)

def keep_counts_eq(self, n: int, /) -> 'MultisetExpression[T_contra]':
"""Keeps counts that are == n, treating the rest as zero.
Expand Down
18 changes: 15 additions & 3 deletions tests/generator_operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,27 @@ def test_example_divide_counts():
assert result == expected


def test_example_keep_counts_le():
result = Pool([1, 2, 2, 3, 3, 3]).keep_counts_le(2).expand(Order.Ascending)
expected = Die([(1, 2, 2)])
assert result == expected


def test_example_keep_counts_lt():
result = Pool([1, 2, 2, 3, 3, 3]).keep_counts_lt(2).expand(Order.Ascending)
expected = Die([(1, )])
assert result == expected


def test_example_keep_counts_ge():
result = Pool([1, 2, 2, 3, 3, 3]).keep_counts_ge(2).expand(Order.Ascending)
expected = Die([(2, 2, 3, 3, 3)])
assert result == expected


def test_example_keep_counts_le():
result = Pool([1, 2, 2, 3, 3, 3]).keep_counts_le(2).expand(Order.Ascending)
expected = Die([(1, 2, 2)])
def test_example_keep_counts_gt():
result = Pool([1, 2, 2, 3, 3, 3]).keep_counts_gt(2).expand(Order.Ascending)
expected = Die([(3, 3, 3)])
assert result == expected


Expand Down

0 comments on commit 5e8a82a

Please sign in to comment.