Skip to content

Commit

Permalink
add Population.to_one_hot()
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Nov 18, 2023
1 parent 4fb1923 commit e79b553
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 24 deletions.
61 changes: 43 additions & 18 deletions src/icepool/population/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import icepool
from icepool.collection.counts import CountsKeysView, CountsValuesView, CountsItemsView
from icepool.collection.vector import Vector
from icepool.math import try_fraction
from icepool.typing import U, Outcome, T_co, count_positional_parameters

Expand Down Expand Up @@ -185,13 +186,14 @@ def denominator(self) -> int:
"""
return self._denominator

# Quantities.

def scale_quantities(self: C, scale: int) -> C:
    """Scales all quantities by an integer.

    Args:
        scale: The integer factor applied to every quantity.

    Returns:
        `self` unchanged if `scale == 1`; otherwise a new population built
        via `self._new_type` with each quantity multiplied by `scale`.
    """
    if scale == 1:
        # Nothing would change, so skip building a new population.
        return self
    data = {
        outcome: quantity * scale
        for outcome, quantity in self.items()
    }
    return self._new_type(data)

def has_zero_quantities(self) -> bool:
Expand Down Expand Up @@ -277,8 +279,8 @@ def quantities_lt(self, outcomes: Sequence | None = None) -> Sequence[int]:
outcomes: If provided, the quantities corresponding to these
outcomes will be returned (or 0 if not present).
"""
return tuple(
self.denominator() - x for x in self.quantities_ge(outcomes))
return tuple(self.denominator() - x
for x in self.quantities_ge(outcomes))

def quantities_gt(self, outcomes: Sequence | None = None) -> Sequence[int]:
"""The quantity > each outcome in order.
Expand All @@ -287,8 +289,8 @@ def quantities_gt(self, outcomes: Sequence | None = None) -> Sequence[int]:
outcomes: If provided, the quantities corresponding to these
outcomes will be returned (or 0 if not present).
"""
return tuple(
self.denominator() - x for x in self.quantities_le(outcomes))
return tuple(self.denominator() - x
for x in self.quantities_le(outcomes))

# Probabilities.

Expand Down Expand Up @@ -383,8 +385,8 @@ def probabilities_le(self,

@overload
def probabilities_le(self,
outcomes: Sequence |
None = None) -> Sequence[Fraction]:
outcomes: Sequence | None = None
) -> Sequence[Fraction]:
...

def probabilities_le(
Expand Down Expand Up @@ -436,8 +438,8 @@ def probabilities_ge(self,

@overload
def probabilities_ge(self,
outcomes: Sequence |
None = None) -> Sequence[Fraction]:
outcomes: Sequence | None = None
) -> Sequence[Fraction]:
...

def probabilities_ge(
Expand Down Expand Up @@ -484,8 +486,8 @@ def probabilities_lt(self,

@overload
def probabilities_lt(self,
outcomes: Sequence |
None = None) -> Sequence[Fraction]:
outcomes: Sequence | None = None
) -> Sequence[Fraction]:
...

def probabilities_lt(
Expand Down Expand Up @@ -528,8 +530,8 @@ def probabilities_gt(self,

@overload
def probabilities_gt(self,
outcomes: Sequence |
None = None) -> Sequence[Fraction]:
outcomes: Sequence | None = None
) -> Sequence[Fraction]:
...

def probabilities_gt(
Expand Down Expand Up @@ -690,8 +692,8 @@ def entropy(self, base: float = 2.0) -> float:
base: The logarithm base to use. Default is 2.0, which gives the
entropy in bits.
"""
return -sum(
p * math.log(p, base) for p in self.probabilities() if p > 0.0)
return -sum(p * math.log(p, base)
for p in self.probabilities() if p > 0.0)

# Joint statistics.

Expand Down Expand Up @@ -746,6 +748,29 @@ def correlation(
sd_j = self.marginals[j].standard_deviation()
return self.covariance(i, j) / (sd_i * sd_j)

# Transformations.

def to_one_hot(self: C, outcomes: Sequence[T_co] | None = None) -> C:
    """Converts the outcomes of this population to a one-hot representation.

    Args:
        outcomes: If provided, each outcome will be mapped to a `Vector`
            where the element at `outcomes.index(outcome)` is set to `True`
            and the rest to `False`, or all `False` if the outcome is not
            in `outcomes`.
            If not provided, `self.outcomes()` is used.

    Returns:
        A new population built via `self._new_type` whose outcomes are
        `Vector`s of `bool` with length `len(outcomes)`. Quantities of
        outcomes that map to the same `Vector` are summed.
    """
    if outcomes is None:
        outcomes = self.outcomes()

    # Resolve every position once up front instead of scanning `outcomes`
    # twice (`in` followed by `.index`) for each outcome of `self`.
    # `setdefault` keeps the first position of a repeated entry, which is
    # the position `.index` would have reported.
    positions: dict[T_co, int] = {}
    for position, candidate in enumerate(outcomes):
        positions.setdefault(candidate, position)
    width = len(outcomes)

    data: MutableMapping[Vector[bool], int] = defaultdict(int)
    for outcome, quantity in zip(self.outcomes(), self.quantities()):
        value = [False] * width
        position = positions.get(outcome)
        if position is not None:
            value[position] = True
        # Outcomes absent from `outcomes` all collapse onto the all-False
        # vector, so their quantities accumulate there.
        data[Vector(value)] += quantity
    return self._new_type(data)

def sample(self) -> T_co:
"""A single random sample from this population.
Expand Down Expand Up @@ -820,4 +845,4 @@ def __format__(self, format_spec: str, /) -> str:
return self.format(format_spec)

def __str__(self) -> str:
    """Returns the default string form of this population.

    Formatting `self` with an empty format spec goes through `__format__`.
    """
    return f'{self}'
21 changes: 15 additions & 6 deletions tests/vector_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import icepool
import pytest

from icepool import d6, d8, vectorize, Die, Vector
from icepool import d4, d6, d8, vectorize, Die, Vector


def test_cartesian_product():
Expand Down Expand Up @@ -31,17 +31,17 @@ def test_vector_matmul():


def test_nested_unary_elementwise():
    """Negating a die of triply nested vectors negates the innermost value."""
    result = icepool.Die([vectorize(vectorize(vectorize(1, ), ), )])
    result = -result
    assert result.marginals[0].marginals[0].marginals[0].equals(
        icepool.Die([-1]))


def test_nested_binary_elementwise():
    """Adding a die of triply nested vectors to itself doubles the innermost value."""
    result = icepool.Die([vectorize(vectorize(vectorize(1, ), ), )])
    result = result + result
    assert result.marginals[0].marginals[0].marginals[0].equals(
        icepool.Die([2]))


def test_binary_op_mismatch_outcome_len():
Expand Down Expand Up @@ -92,7 +92,7 @@ class OneHotEvaluator(icepool.MultisetEvaluator):
def next_state(self, state, _, count):
    """Appends `count` to the running tuple of counts.

    A `state` of `None` is treated as an empty tuple. The second positional
    argument is ignored.
    """
    if state is None:
        state = ()
    return state + (count, )

def final_outcome(self, final_state):
    """Wraps the final state in an `icepool.Vector`."""
    as_vector = icepool.Vector(final_state)
    return as_vector
Expand Down Expand Up @@ -129,3 +129,12 @@ def test_vector_append():

def test_vector_concatenate():
    """Concatenating an iterable places its elements after the existing ones."""
    combined = Vector((1, 2)).concatenate(range(2))
    expected = Vector((1, 2, 0, 1))
    assert combined == expected


def test_to_one_hot():
    """Each face of a d4 maps to the unit vector for its position."""
    expected = Die([
        Vector([1, 0, 0, 0]),
        Vector([0, 1, 0, 0]),
        Vector([0, 0, 1, 0]),
        Vector([0, 0, 0, 1]),
    ])
    assert d4.to_one_hot() == expected

0 comments on commit e79b553

Please sign in to comment.