From 6d9cbd6af4f737f3317557a14c7988a3e755a47b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20B=C3=A4renz?= Date: Wed, 3 Jan 2024 10:55:58 +0100 Subject: [PATCH] Bump mtl -> 2.3, transformers -> 0.6 --- monad-bayes.cabal | 4 +-- src/Control/Applicative/List.hs | 23 +++++++++++++ src/Control/Monad/Bayes/Class.hs | 14 ++------ src/Control/Monad/Bayes/Enumerator.hs | 2 ++ src/Control/Monad/Bayes/Population.hs | 48 ++++++++++++++++++++++++--- test/TestWeighted.hs | 2 +- 6 files changed, 73 insertions(+), 20 deletions(-) create mode 100644 src/Control/Applicative/List.hs diff --git a/monad-bayes.cabal b/monad-bayes.cabal index 8dc7b22b..fb8c764e 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -53,7 +53,7 @@ common deps , matrix ^>=0.3 , monad-coroutine ^>=0.9.0 , monad-extras ^>=0.6 - , mtl ^>=2.2.2 + , mtl ^>=2.3 , mwc-random >=0.13.6 && <0.16 , pipes ^>=4.3 , pretty-simple ^>=4.1 @@ -63,6 +63,7 @@ common deps , scientific ^>=0.3 , statistics >=0.14.0 && <0.17 , text >=1.2 && <2.1 + , transformers ^>=0.6 , vector >=0.12.0 && <0.14 , vty ^>=5.38 @@ -77,7 +78,6 @@ common test-deps , process ^>=1.6 , QuickCheck ^>=2.14 , time >=1.9 && <1.13 - , transformers ^>=0.5.6 , typed-process ^>=0.2 autogen-modules: Paths_monad_bayes diff --git a/src/Control/Applicative/List.hs b/src/Control/Applicative/List.hs new file mode 100644 index 00000000..a6a0a99a --- /dev/null +++ b/src/Control/Applicative/List.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE StandaloneDeriving #-} + +module Control.Applicative.List where + +-- base +import Control.Applicative +import Data.Functor.Compose + +-- * Applicative ListT + +-- | _Applicative_ transformer adding a list/nondeterminism/choice effect. +-- It is not a valid monad transformer, but it is a valid 'Applicative'. +newtype ListT m a = ListT {getListT :: Compose m [] a} + deriving newtype (Functor, Applicative, Alternative) + +listT :: m [a] -> ListT m a +listT = ListT . Compose + +lift :: (Functor m) => m a -> ListT m a +lift = ListT . Compose . fmap pure + +runListT :: ListT m a -> m [a] +runListT = getCompose . getListT diff --git a/src/Control/Monad/Bayes/Class.hs b/src/Control/Monad/Bayes/Class.hs index 6a8c1803..90bf1491 100644 --- a/src/Control/Monad/Bayes/Class.hs +++ b/src/Control/Monad/Bayes/Class.hs @@ -77,11 +77,11 @@ where import Control.Arrow (Arrow (second)) import Control.Monad (replicateM, when) import Control.Monad.Cont (ContT) -import Control.Monad.Except (ExceptT, lift) +import Control.Monad.Except (ExceptT) import Control.Monad.Identity (IdentityT) -import Control.Monad.List (ListT) import Control.Monad.Reader (ReaderT) import Control.Monad.State (StateT) +import Control.Monad.Trans.Class (lift) import Control.Monad.Writer (WriterT) import Data.Histogram qualified as H import Data.Histogram.Fill qualified as H @@ -390,16 +390,6 @@ instance (MonadFactor m) => MonadFactor (StateT s m) where instance (MonadMeasure m) => MonadMeasure (StateT s m) -instance (MonadDistribution m) => MonadDistribution (ListT m) where - random = lift random - bernoulli = lift . bernoulli - categorical = lift . categorical - -instance (MonadFactor m) => MonadFactor (ListT m) where - score = lift . score - -instance (MonadMeasure m) => MonadMeasure (ListT m) - instance (MonadDistribution m) => MonadDistribution (ContT r m) where random = lift random diff --git a/src/Control/Monad/Bayes/Enumerator.hs b/src/Control/Monad/Bayes/Enumerator.hs index cd7e7e1a..50e091e5 100644 --- a/src/Control/Monad/Bayes/Enumerator.hs +++ b/src/Control/Monad/Bayes/Enumerator.hs @@ -32,6 +32,7 @@ where import Control.Applicative (Alternative) import Control.Arrow (second) +import Control.Monad (MonadPlus) import Control.Monad.Bayes.Class ( MonadDistribution (bernoulli, categorical, logCategorical, random), MonadFactor (..), @@ -42,6 +43,7 @@ import Data.AEq (AEq, (===), (~==)) import Data.List (sortOn) import Data.Map qualified as Map import Data.Maybe (fromMaybe) +import Data.Monoid (Product (..)) import Data.Ord (Down (Down)) import Data.Vector qualified as VV import Data.Vector.Generic qualified as V diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index 2a384f77..9dff3152 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} {-# OPTIONS_GHC -Wno-deprecations #-} @@ -37,11 +38,12 @@ module Control.Monad.Bayes.Population ) where +import Control.Applicative (Alternative) import Control.Arrow (second) -import Control.Monad (replicateM) +import Control.Monad (forM, replicateM) import Control.Monad.Bayes.Class - ( MonadDistribution (categorical, logCategorical, random, uniform), - MonadFactor, + ( MonadDistribution (..), + MonadFactor (..), MonadMeasure, factor, ) @@ -52,7 +54,9 @@ import Control.Monad.Bayes.Weighted runWeightedT, weightedT, ) -import Control.Monad.List (ListT (..), MonadIO, MonadTrans (..)) +import Control.Monad.IO.Class +import Control.Monad.Trans.Class +import Data.Functor.Compose import Data.List (unfoldr) import Data.List qualified import Data.Maybe (catMaybes) @@ -62,6 +66,40 @@ import Numeric.Log (Log, ln, sum) import Numeric.Log qualified as Log import Prelude hiding (all, sum) +-- | The old-fashioned, broken list transformer, adding a list/nondeterminism/choice effect. +-- It is not a valid monad transformer, but it is a valid 'Applicative'. +newtype ListT m a = ListT {getListT :: Compose m [] a} + deriving newtype (Functor, Applicative, Alternative) + +listT :: m [a] -> ListT m a +listT = ListT . Compose + +runListT :: ListT m a -> m [a] +runListT = getCompose . getListT + +-- | This monad instance is _unlawful_, +-- it is only by accident and careful construction that it can be used here. +instance (Monad m) => Monad (ListT m) where + ma >>= f = ListT $ Compose $ do + as <- runListT ma + fmap concat $ forM as $ runListT . f + +instance MonadTrans ListT where + lift = ListT . Compose . fmap pure + +instance (MonadIO m) => MonadIO (ListT m) where + liftIO = lift . liftIO + +instance (MonadDistribution m) => MonadDistribution (ListT m) where + random = lift random + bernoulli = lift . bernoulli + categorical = lift . categorical + +instance (MonadFactor m) => MonadFactor (ListT m) where + score = lift . score + +instance (MonadMeasure m) => MonadMeasure (ListT m) + -- | A collection of weighted samples, or particles. newtype PopulationT m a = PopulationT {getPopulationT :: WeightedT (ListT m) a} deriving newtype (Functor, Applicative, Monad, MonadIO, MonadDistribution, MonadFactor, MonadMeasure) @@ -80,7 +118,7 @@ explicitPopulation = fmap (map (second (exp . ln))) . runPopulationT -- | Initialize 'PopulationT' with a concrete weighted sample. fromWeightedList :: (Monad m) => m [(a, Log Double)] -> PopulationT m a -fromWeightedList = PopulationT . weightedT . ListT +fromWeightedList = PopulationT . weightedT . listT -- | Increase the sample size by a given factor. -- The weights are adjusted such that their sum is preserved. diff --git a/test/TestWeighted.hs b/test/TestWeighted.hs index 3e068e0f..1b420219 100644 --- a/test/TestWeighted.hs +++ b/test/TestWeighted.hs @@ -2,6 +2,7 @@ module TestWeighted (check, passed, result, model) where +import Control.Monad (unless, when) import Control.Monad.Bayes.Class ( MonadDistribution (normal, uniformD), MonadMeasure, @@ -9,7 +10,6 @@ import Control.Monad.Bayes.Class ) import Control.Monad.Bayes.Sampler.Strict (sampleIOfixed) import Control.Monad.Bayes.Weighted (runWeightedT) -import Control.Monad.State (unless, when) import Data.AEq (AEq ((~==))) import Data.Bifunctor (second) import Numeric.Log (Log (Exp, ln))