From deedc792f4a733e9e2b7f8713fec3f6e0894e534 Mon Sep 17 00:00:00 2001 From: Albert Julius Liu Date: Sun, 8 Sep 2024 18:15:24 -0700 Subject: [PATCH] add `range_union` function, stop using `align` internally --- src/icepool/function.py | 8 ++++++++ src/icepool/population/base.py | 17 ++++++++--------- src/icepool/population/keep.py | 14 ++++++++------ tests/from_cumulative_test.py | 2 -- 4 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/icepool/function.py b/src/icepool/function.py index 522de93c..c8b60156 100644 --- a/src/icepool/function.py +++ b/src/icepool/function.py @@ -229,6 +229,14 @@ def max_outcome(*dice: 'T | icepool.Die[T]') -> T: converted_dice = [icepool.implicit_convert_to_die(die) for die in dice] return max(die.outcomes()[-1] for die in converted_dice) +def range_union(*args: Iterable[int]) -> Sequence[int]: + """Produces a sequence of consecutive ints covering the argument sets.""" + start = min((x for x in itertools.chain(*args)), default=None) + if start is None: + return () + stop = max(x for x in itertools.chain(*args)) + return tuple(range(start, stop+1)) + def sorted_union(*args: Iterable[T]) -> Sequence[T]: """Merge sets into a sorted sequence.""" if not args: diff --git a/src/icepool/population/base.py b/src/icepool/population/base.py index f14411aa..5aee9437 100644 --- a/src/icepool/population/base.py +++ b/src/icepool/population/base.py @@ -487,20 +487,19 @@ def modal_quantity(self) -> int: """The highest quantity of any single outcome. """ return max(self.quantities()) - def kolmogorov_smirnov(self, other) -> Fraction: + def kolmogorov_smirnov(self, other: 'Population') -> Fraction: """Kolmogorov–Smirnov statistic. The maximum absolute difference between CDFs. """ - a, b = icepool.align(self, other) + outcomes = icepool.sorted_union(self, other) return max( - abs(a - b) - for a, b in zip(a.probabilities('<='), b.probabilities('<='))) + abs(self.probability('<=', outcome) - other.probability('<=', outcome)) + for outcome in outcomes) - def cramer_von_mises(self, other) -> Fraction: + def cramer_von_mises(self, other: 'Population') -> Fraction: """Cramér-von Mises statistic. The sum-of-squares difference between CDFs. """ - a, b = icepool.align(self, other) + outcomes = icepool.sorted_union(self, other) return sum( - ((a - b)**2 - for a, b in zip(a.probabilities('<='), b.probabilities('<='))), - start=Fraction(0, 1)) + ((self.probability('<=', outcome) - other.probability('<=', outcome)) ** 2 + for outcome in outcomes), start=Fraction(0, 1)) def median(self): """The median, taking the mean in case of a tie. diff --git a/src/icepool/population/keep.py b/src/icepool/population/keep.py index 544fb6fb..929252fe 100644 --- a/src/icepool/population/keep.py +++ b/src/icepool/population/keep.py @@ -297,10 +297,11 @@ def _lowest_single(*args: 'T | icepool.Die[T]') -> 'icepool.Die[T]': dice = tuple(icepool.implicit_convert_to_die(arg) for arg in args) max_outcome = min(die.max_outcome() for die in dice) dice = tuple(die.clip(max_outcome=max_outcome) for die in dice) - dice = icepool.align(*dice) + outcomes = icepool.sorted_union(*dice) quantities_ge = tuple( - math.prod(t) for t in zip(*(die.quantities('>=') for die in dice))) - return icepool.from_cumulative(dice[0].outcomes(), + math.prod(die.quantity('>=', outcome) for die in dice) + for outcome in outcomes) + return icepool.from_cumulative(outcomes, quantities_ge, reverse=True) @@ -314,7 +315,8 @@ def _highest_single(*args: 'T | icepool.Die[T]') -> 'icepool.Die[T]': dice = tuple(icepool.implicit_convert_to_die(arg) for arg in args) min_outcome = max(die.min_outcome() for die in dice) dice = tuple(die.clip(min_outcome=min_outcome) for die in dice) - dice = icepool.align(*dice) + outcomes = icepool.sorted_union(*dice) quantities_le = tuple( - math.prod(t) for t in zip(*(die.quantities('<=') for die in dice))) - return icepool.from_cumulative(dice[0].outcomes(), quantities_le) + math.prod(die.quantity('<=', outcome) for die in dice) + for outcome in outcomes) + return icepool.from_cumulative(outcomes, quantities_le) diff --git a/tests/from_cumulative_test.py b/tests/from_cumulative_test.py index 785615fd..f8839262 100644 --- a/tests/from_cumulative_test.py +++ b/tests/from_cumulative_test.py @@ -27,7 +27,5 @@ def test_from_rv_norm(): 1000000, loc=die.mean(), scale=die.standard_deviation()) - die, norm_die = icepool.align(die, norm_die) - print(die.kolmogorov_smirnov(norm_die)) assert die.probabilities('<=') == pytest.approx( norm_die.probabilities('<='), abs=1e-3)