diff --git a/src/icepool/population/base.py b/src/icepool/population/base.py
index c4039043..126ba557 100644
--- a/src/icepool/population/base.py
+++ b/src/icepool/population/base.py
@@ -2,6 +2,7 @@
 
 import icepool
 from icepool.collection.counts import CountsKeysView, CountsValuesView, CountsItemsView
+from icepool.collection.vector import Vector
 from icepool.math import try_fraction
 from icepool.typing import U, Outcome, T_co, count_positional_parameters
 
@@ -185,13 +186,14 @@ def denominator(self) -> int:
         """
         return self._denominator
 
-    # Quantities.
-
     def scale_quantities(self: C, scale: int) -> C:
         """Scales all quantities by an integer."""
         if scale == 1:
             return self
-        data = {outcome: quantity * scale for outcome, quantity in self.items()}
+        data = {
+            outcome: quantity * scale
+            for outcome, quantity in self.items()
+        }
         return self._new_type(data)
 
     def has_zero_quantities(self) -> bool:
@@ -277,8 +279,8 @@ def quantities_lt(self, outcomes: Sequence | None = None) -> Sequence[int]:
             outcomes: If provided, the quantities corresponding to these
                 outcomes will be returned (or 0 if not present).
         """
-        return tuple(
-            self.denominator() - x for x in self.quantities_ge(outcomes))
+        return tuple(self.denominator() - x
+                     for x in self.quantities_ge(outcomes))
 
     def quantities_gt(self, outcomes: Sequence | None = None) -> Sequence[int]:
         """The quantity > each outcome in order.
@@ -287,8 +289,8 @@ def quantities_gt(self, outcomes: Sequence | None = None) -> Sequence[int]:
             outcomes: If provided, the quantities corresponding to these
                 outcomes will be returned (or 0 if not present).
         """
-        return tuple(
-            self.denominator() - x for x in self.quantities_le(outcomes))
+        return tuple(self.denominator() - x
+                     for x in self.quantities_le(outcomes))
 
     # Probabilities.
 
@@ -383,8 +385,8 @@ def probabilities_le(self,
 
     @overload
     def probabilities_le(self,
-                         outcomes: Sequence |
-                         None = None) -> Sequence[Fraction]:
+                         outcomes: Sequence | None = None
+                         ) -> Sequence[Fraction]:
         ...
 
     def probabilities_le(
@@ -436,8 +438,8 @@ def probabilities_ge(self,
 
     @overload
     def probabilities_ge(self,
-                         outcomes: Sequence |
-                         None = None) -> Sequence[Fraction]:
+                         outcomes: Sequence | None = None
+                         ) -> Sequence[Fraction]:
         ...
 
     def probabilities_ge(
@@ -484,8 +486,8 @@ def probabilities_lt(self,
 
     @overload
     def probabilities_lt(self,
-                         outcomes: Sequence |
-                         None = None) -> Sequence[Fraction]:
+                         outcomes: Sequence | None = None
+                         ) -> Sequence[Fraction]:
         ...
 
     def probabilities_lt(
@@ -528,8 +530,8 @@ def probabilities_gt(self,
 
     @overload
     def probabilities_gt(self,
-                         outcomes: Sequence |
-                         None = None) -> Sequence[Fraction]:
+                         outcomes: Sequence | None = None
+                         ) -> Sequence[Fraction]:
         ...
 
     def probabilities_gt(
@@ -690,8 +692,8 @@ def entropy(self, base: float = 2.0) -> float:
             base: The logarithm base to use. Default is 2.0, which gives the
                 entropy in bits.
         """
-        return -sum(
-            p * math.log(p, base) for p in self.probabilities() if p > 0.0)
+        return -sum(p * math.log(p, base)
+                    for p in self.probabilities() if p > 0.0)
 
     # Joint statistics.
 
@@ -746,6 +748,40 @@ def correlation(
         sd_j = self.marginals[j].standard_deviation()
         return self.covariance(i, j) / (sd_i * sd_j)
 
+    # Transformations.
+
+    def to_one_hot(self: C, outcomes: Sequence[T_co] | None = None) -> C:
+        """Converts the outcomes of this population to a one-hot representation.
+
+        Quantities of outcomes that map to the same vector (e.g. several
+        outcomes that are all missing from `outcomes`) are added together.
+
+        Args:
+            outcomes: If provided, each outcome will be mapped to a `Vector`
+                where the element at `outcomes.index(outcome)` is set to `True`
+                and the rest to `False`, or all `False` if the outcome is not
+                in `outcomes`.
+                If not provided, `self.outcomes()` is used.
+        """
+        if outcomes is None:
+            outcomes = self.outcomes()
+
+        # Position of the first occurrence of each listed outcome, matching
+        # `outcomes.index()`. Built once so that the loop below does not
+        # rescan `outcomes` for every outcome of this population.
+        positions: MutableMapping[T_co, int] = {}
+        for position, listed in enumerate(outcomes):
+            positions.setdefault(listed, position)
+
+        data: MutableMapping[Vector[bool], int] = defaultdict(int)
+        for outcome, quantity in zip(self.outcomes(), self.quantities()):
+            value = [False] * len(outcomes)
+            position = positions.get(outcome)
+            if position is not None:
+                value[position] = True
+            data[Vector(value)] += quantity
+        return self._new_type(data)
+
     def sample(self) -> T_co:
         """A single random sample from this population.
 
@@ -820,4 +856,4 @@ def __format__(self, format_spec: str, /) -> str:
         return self.format(format_spec)
 
     def __str__(self) -> str:
-        return f'{self}'
\ No newline at end of file
+        return f'{self}'
diff --git a/tests/vector_test.py b/tests/vector_test.py
index a1700280..f154e3ac 100644
--- a/tests/vector_test.py
+++ b/tests/vector_test.py
@@ -1,7 +1,7 @@
 import icepool
 import pytest
 
-from icepool import d6, d8, vectorize, Die, Vector
+from icepool import d4, d6, d8, vectorize, Die, Vector
 
 
 def test_cartesian_product():
@@ -31,17 +31,17 @@ def test_vector_matmul():
 
 
 def test_nested_unary_elementwise():
-    result = icepool.Die([vectorize(vectorize(vectorize(1,),),)])
+    result = icepool.Die([vectorize(vectorize(vectorize(1, ), ), )])
     result = -result
     assert result.marginals[0].marginals[0].marginals[0].equals(
         icepool.Die([-1]))
 
 
 def test_nested_binary_elementwise():
-    result = icepool.Die([vectorize(vectorize(vectorize(1,),),)])
+    result = icepool.Die([vectorize(vectorize(vectorize(1, ), ), )])
     result = result + result
-    assert result.marginals[0].marginals[0].marginals[0].equals(icepool.Die([2
-                                                                             ]))
+    assert result.marginals[0].marginals[0].marginals[0].equals(
+        icepool.Die([2]))
 
 
 def test_binary_op_mismatch_outcome_len():
@@ -92,7 +92,7 @@ class OneHotEvaluator(icepool.MultisetEvaluator):
     def next_state(self, state, _, count):
         if state is None:
             state = ()
-        return state + (count,)
+        return state + (count, )
 
     def final_outcome(self, final_state):
         return icepool.Vector(final_state)
@@ -129,3 +129,12 @@ def test_vector_append():
 
 def test_vector_concatenate():
     assert Vector((1, 2)).concatenate(range(2)) == Vector((1, 2, 0, 1))
+
+
+def test_to_one_hot():
+    assert d4.to_one_hot() == Die([
+        Vector([1, 0, 0, 0]),
+        Vector([0, 1, 0, 0]),
+        Vector([0, 0, 1, 0]),
+        Vector([0, 0, 0, 1]),
+    ])