Skip to content

Commit

Permalink
Merge pull request #4 from quickdudley/observing
Browse files Browse the repository at this point in the history
Add 'observing' function (Bayesian update rule)
  • Loading branch information
redelmann authored Jun 29, 2017
2 parents c1adf50 + 79b4c97 commit 7560c72
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion Data/Distribution/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ module Data.Distribution.Core
-- ** Transformation
, select
, assuming
, observing
-- ** Combination
, combineWith
-- ** Sequences
Expand Down Expand Up @@ -272,6 +273,39 @@ assuming f (Distribution xs) = Distribution $ fmap adjust filtered
total = sum $ Map.elems filtered


-- | Returns a new distribution using the Bayesian update rule.
--
-- Using this example:
-- https://en.wikipedia.org/wiki/Bayesian_inference#Probability_of_a_hypothesis
--
-- > data CookieBowl = Bowl1 | Bowl2 deriving (Eq,Ord)
-- > data CookieType = Plain | ChocolateChip deriving (Eq,Ord)
-- >
-- > assumption :: Distribution CookieBowl
-- > assumption = uniform [Bowl1,Bowl2]
-- >
-- > update :: Cookie -> Distribution CookieBowl -> Distribution CookieBowl
-- > update c = observing f where
-- > f b = case b of
-- > -- Bowl #1 contains 10 chocolate chip cookies and 30 plain cookies
-- > Bowl1 -> fromList [(c == ChocolateChip,10),(c == Plain,30)]
-- > -- Bowl #2 contains 20 of each flavour of cookie
-- > Bowl2 -> fromList [(c == ChocolateChip,20),(c == Plain,20)]
--
-- The "update" function in this example can be used to update the probability
-- distribution of which bowl you have based on observing a random cookie inside
-- the bowl.
observing :: (a -> Distribution Bool) -> Distribution a -> Distribution a
observing f (Distribution xs) = Distribution $ fmap adjust filtered
where
filtered = Map.filter (/= 0) $ Map.mapWithKey tweak xs
tweak x p = let
Distribution px = f x
pt = fromMaybe 0 $ Map.lookup True px
in pt * p
adjust x = x * (1 / total)
total = sum $ Map.elems filtered

-- Combination

combineWith :: (Ord b) => (a -> a -> b) -> Distribution a -> Distribution a -> Distribution b
Expand Down Expand Up @@ -388,4 +422,4 @@ andThen (Distribution xs) f = Distribution $
-- A distribution is valid if and only if its domain is non-empty.
-- Invalid distributions may arise from the use of 'assuming' for instance.
isValid :: Distribution a -> Bool
isValid (Distribution xs) = not $ Map.null xs
isValid (Distribution xs) = not $ Map.null xs

0 comments on commit 7560c72

Please sign in to comment.