Skip to content

Commit

Permalink
use getattr instead of isinstance inside cartesian_product and …
Browse files Browse the repository at this point in the history
…related methods #184
  • Loading branch information
HighDiceRoller committed Jun 10, 2024
1 parent 5dc5362 commit 4efabbf
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 36 deletions.
39 changes: 33 additions & 6 deletions src/icepool/collection/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,37 @@
import itertools
import math
import operator
from typing import Callable, Hashable, Iterable, Iterator, Sequence, Type, cast, overload
from typing import Any, Callable, Hashable, Iterable, Iterator, Sequence, Type, cast, overload

from icepool.typing import Outcome, S, T, T_co, U


def iter_cartesian_product(
*args: 'Outcome | icepool.Population | icepool.MultisetExpression'
) -> Iterator[tuple[tuple, int]]:
"""Yields the independent joint distribution of the arguments.
Args:
*args: These may be dice, which will be expanded into their joint
outcomes. Non-dice are left as-is.
Yields:
Tuples containing one outcome per arg and the joint quantity.
"""

def arg_items(arg) -> Sequence[tuple[Any, int]]:
items = getattr(arg, '_items_for_cartesian_product', None)
if items is not None:
return items
else:
return [(arg, 1)]

for t in itertools.product(*(arg_items(arg) for arg in args)):
outcomes, quantities = zip(*t)
final_quantity = math.prod(quantities)
yield outcomes, final_quantity


# Typing: there is currently no way to intersect a type bound, and Protocol
# can't be used with Sequence.
def cartesian_product(
Expand All @@ -26,12 +52,13 @@ def cartesian_product(
of the same type as the input `Population`, and the outcomes are
sequences with one element per argument.
"""
population_type = None
population_type: Type | None = None
for arg in args:
if isinstance(arg, icepool.Population):
new_type = getattr(arg, '_new_type', None)
if new_type is not None:
if population_type is None:
population_type = arg._new_type
elif population_type != arg._new_type:
population_type = new_type
elif population_type != new_type:
raise TypeError(
'Arguments to vector() of type Population must all be Die or all be Deck, not a mixture of the two.'
)
Expand All @@ -40,7 +67,7 @@ def cartesian_product(
return outcome_type(args) # type: ignore
else:
data = {}
for outcomes, final_quantity in icepool.iter_cartesian_product(*args):
for outcomes, final_quantity in iter_cartesian_product(*args):
data[outcome_type(outcomes)] = final_quantity # type: ignore
return population_type(data)

Expand Down
7 changes: 7 additions & 0 deletions src/icepool/expression/multiset_expression.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__docformat__ = 'google'

from functools import cached_property
from types import EllipsisType
import icepool
import icepool.evaluator
Expand Down Expand Up @@ -159,6 +160,12 @@ def _validate_output_arity(inner: 'MultisetExpression') -> None:
'Only generators with output arity of 1 may be bound to expressions.\nUse a multiset_function to select individual outputs.'
)

@cached_property
def _items_for_cartesian_product(self) -> Sequence[tuple[T_contra, int]]:
if self._free_arity() > 0:
raise ValueError('Expression must be fully bound.')
return self.expand().items() # type: ignore

# Binary operators.

def __add__(
Expand Down
31 changes: 1 addition & 30 deletions src/icepool/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import icepool
import icepool.population.markov_chain
from icepool.collection.vector import iter_cartesian_product
from icepool.typing import Outcome, T, U, guess_star

from fractions import Fraction
Expand Down Expand Up @@ -351,36 +352,6 @@ def accumulate(
yield result


def iter_cartesian_product(
*args: 'Outcome | icepool.Population | icepool.MultisetExpression'
) -> Iterator[tuple[tuple, int]]:
"""Yields the independent joint distribution of the arguments.
Args:
*args: These may be dice, which will be expanded into their joint
outcomes. Non-dice are left as-is.
Yields:
Tuples containing one outcome per arg and the joint quantity.
"""

def arg_items(arg) -> Sequence[tuple[Any, int]]:
if isinstance(arg, icepool.Population):
return arg.items()
elif isinstance(arg, icepool.MultisetExpression):
if arg._free_arity() > 0:
raise ValueError('Expression must be fully bound.')
# Expression evaluators are difficult to type.
return arg.expand().items() # type: ignore
else:
return [(arg, 1)]

for t in itertools.product(*(arg_items(arg) for arg in args)):
outcomes, quantities = zip(*t)
final_quantity = math.prod(quantities)
yield outcomes, final_quantity


def _canonicalize_transition_function(repl: 'Callable | Mapping',
arg_count: int,
star: bool | None) -> 'Callable':
Expand Down
4 changes: 4 additions & 0 deletions src/icepool/population/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def values(self) -> CountsValuesView:
def items(self) -> CountsItemsView[T_co]:
"""The (outcome, quantity)s of the population in sorted order."""

@property
def _items_for_cartesian_product(self) -> Sequence[tuple[T_co, int]]:
return self.items()

def _unary_operator(self, op: Callable, *args, **kwargs):
data: MutableMapping[Any, int] = defaultdict(int)
for outcome, quantity in self.items():
Expand Down

0 comments on commit 4efabbf

Please sign in to comment.