Skip to content

Commit

Permalink
add range_union function, stop using align internally
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Sep 9, 2024
1 parent 15aae72 commit deedc79
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 17 deletions.
8 changes: 8 additions & 0 deletions src/icepool/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ def max_outcome(*dice: 'T | icepool.Die[T]') -> T:
converted_dice = [icepool.implicit_convert_to_die(die) for die in dice]
return max(die.outcomes()[-1] for die in converted_dice)

def range_union(*args: Iterable[int]) -> Sequence[int]:
"""Produces a sequence of consecutive ints covering the argument sets."""
start = min((x for x in itertools.chain(*args)), default=None)
if start is None:
return ()
stop = max(x for x in itertools.chain(*args))
return tuple(range(start, stop+1))

def sorted_union(*args: Iterable[T]) -> Sequence[T]:
"""Merge sets into a sorted sequence."""
if not args:
Expand Down
17 changes: 8 additions & 9 deletions src/icepool/population/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,20 +487,19 @@ def modal_quantity(self) -> int:
"""The highest quantity of any single outcome. """
return max(self.quantities())

def kolmogorov_smirnov(self, other) -> Fraction:
def kolmogorov_smirnov(self, other: 'Population') -> Fraction:
"""Kolmogorov–Smirnov statistic. The maximum absolute difference between CDFs. """
a, b = icepool.align(self, other)
outcomes = icepool.sorted_union(self, other)
return max(
abs(a - b)
for a, b in zip(a.probabilities('<='), b.probabilities('<=')))
abs(self.probability('<=', outcome) - other.probability('<=', outcome))
for outcome in outcomes)

def cramer_von_mises(self, other) -> Fraction:
def cramer_von_mises(self, other: 'Population') -> Fraction:
"""Cramér-von Mises statistic. The sum-of-squares difference between CDFs. """
a, b = icepool.align(self, other)
outcomes = icepool.sorted_union(self, other)
return sum(
((a - b)**2
for a, b in zip(a.probabilities('<='), b.probabilities('<='))),
start=Fraction(0, 1))
((self.probability('<=', outcome) - other.probability('<=', outcome)) ** 2
for outcome in outcomes), start=Fraction(0, 1))

def median(self):
"""The median, taking the mean in case of a tie.
Expand Down
14 changes: 8 additions & 6 deletions src/icepool/population/keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,11 @@ def _lowest_single(*args: 'T | icepool.Die[T]') -> 'icepool.Die[T]':
dice = tuple(icepool.implicit_convert_to_die(arg) for arg in args)
max_outcome = min(die.max_outcome() for die in dice)
dice = tuple(die.clip(max_outcome=max_outcome) for die in dice)
dice = icepool.align(*dice)
outcomes = icepool.sorted_union(*dice)
quantities_ge = tuple(
math.prod(t) for t in zip(*(die.quantities('>=') for die in dice)))
return icepool.from_cumulative(dice[0].outcomes(),
math.prod(die.quantity('>=', outcome) for die in dice)
for outcome in outcomes)
return icepool.from_cumulative(outcomes,
quantities_ge,
reverse=True)

Expand All @@ -314,7 +315,8 @@ def _highest_single(*args: 'T | icepool.Die[T]') -> 'icepool.Die[T]':
dice = tuple(icepool.implicit_convert_to_die(arg) for arg in args)
min_outcome = max(die.min_outcome() for die in dice)
dice = tuple(die.clip(min_outcome=min_outcome) for die in dice)
dice = icepool.align(*dice)
outcomes = icepool.sorted_union(*dice)
quantities_le = tuple(
math.prod(t) for t in zip(*(die.quantities('<=') for die in dice)))
return icepool.from_cumulative(dice[0].outcomes(), quantities_le)
math.prod(die.quantity('<=', outcome) for die in dice)
for outcome in outcomes)
return icepool.from_cumulative(outcomes, quantities_le)
2 changes: 0 additions & 2 deletions tests/from_cumulative_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,5 @@ def test_from_rv_norm():
1000000,
loc=die.mean(),
scale=die.standard_deviation())
die, norm_die = icepool.align(die, norm_die)
print(die.kolmogorov_smirnov(norm_die))
assert die.probabilities('<=') == pytest.approx(
norm_die.probabilities('<='), abs=1e-3)

0 comments on commit deedc79

Please sign in to comment.