diff --git a/src/Control/Monad/Bayes/Inference/RMSMC.hs b/src/Control/Monad/Bayes/Inference/RMSMC.hs index 626eeae0..f86a4c33 100644 --- a/src/Control/Monad/Bayes/Inference/RMSMC.hs +++ b/src/Control/Monad/Bayes/Inference/RMSMC.hs @@ -51,8 +51,8 @@ rmsmc :: PopulationT m a rmsmc (MCMCConfig {..}) (SMCConfig {..}) = marginal - . S.sequentially (composeCopies numMCMCSteps (TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps - . S.hoistFirst (TrStat.hoist (withParticles numParticles)) + . S.sequentially (composeCopies numMCMCSteps (TrStat.hoistModel (single . flatten) . TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps + . S.hoistFirst (TrStat.hoistModel (single . flatten) . TrStat.hoist (withParticles numParticles)) -- | Resample-move Sequential Monte Carlo with a more efficient -- tracing representation. diff --git a/src/Control/Monad/Bayes/Inference/SMC.hs b/src/Control/Monad/Bayes/Inference/SMC.hs index 3f3a30b2..a729dc85 100644 --- a/src/Control/Monad/Bayes/Inference/SMC.hs +++ b/src/Control/Monad/Bayes/Inference/SMC.hs @@ -22,11 +22,18 @@ where import Control.Monad.Bayes.Class (MonadDistribution, MonadMeasure) import Control.Monad.Bayes.Population - ( PopulationT, + ( PopulationT (..), + flatten, pushEvidence, + single, withParticles, ) +import Control.Monad.Bayes.Population.Applicative qualified as Applicative import Control.Monad.Bayes.Sequential.Coroutine as Coroutine +import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT +import Control.Monad.Bayes.Weighted (WeightedT (..), weightedT) +import Control.Monad.Coroutine +import Control.Monad.Trans.Free (FreeF (..), FreeT (..)) data SMCConfig m = SMCConfig { resampler :: forall x. PopulationT m x -> PopulationT m x, @@ -34,6 +41,19 @@ data SMCConfig m = SMCConfig numParticles :: Int } +sequentialToPopulation :: (Monad m) => Coroutine.SequentialT (Applicative.PopulationT m) a -> PopulationT m a +sequentialToPopulation = + PopulationT + . weightedT + . coroutineToFree + . Coroutine.runSequentialT + where + coroutineToFree = + FreeT + . fmap (Free . fmap (\(cont, p) -> either (coroutineToFree . extract) (pure . (,p)) cont)) + . Applicative.runPopulationT + . resume + -- | Sequential importance resampling. -- Basically an SMC template that takes a custom resampler. smc :: @@ -42,12 +62,15 @@ smc :: Coroutine.SequentialT (PopulationT m) a -> PopulationT m a smc SMCConfig {..} = - Coroutine.sequentially resampler numSteps + (single . flatten) + . Coroutine.sequentially resampler numSteps + . SequentialT.hoist (single . flatten) . Coroutine.hoistFirst (withParticles numParticles) + . SequentialT.hoist (single . flatten) -- | Sequential Monte Carlo with multinomial resampling at each timestep. -- Weights are normalized at each timestep and the total weight is pushed -- as a score into the transformed monad. smcPush :: (MonadMeasure m) => SMCConfig m -> Coroutine.SequentialT (PopulationT m) a -> PopulationT m a -smcPush config = smc config {resampler = (pushEvidence . resampler config)} +smcPush config = smc config {resampler = (single . flatten . pushEvidence . resampler config)} diff --git a/src/Control/Monad/Bayes/Inference/SMC2.hs b/src/Control/Monad/Bayes/Inference/SMC2.hs index 5570a2ba..530d8932 100644 --- a/src/Control/Monad/Bayes/Inference/SMC2.hs +++ b/src/Control/Monad/Bayes/Inference/SMC2.hs @@ -27,8 +27,10 @@ import Control.Monad.Bayes.Class import Control.Monad.Bayes.Inference.MCMC import Control.Monad.Bayes.Inference.RMSMC (rmsmc) import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush) -import Control.Monad.Bayes.Population as Pop (PopulationT, resampleMultinomial, runPopulationT) +import Control.Monad.Bayes.Population as Pop (PopulationT, flatten, resampleMultinomial, runPopulationT, single) +import Control.Monad.Bayes.Population qualified as PopulationT import Control.Monad.Bayes.Sequential.Coroutine (SequentialT) +import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT import Control.Monad.Bayes.Traced import Control.Monad.Trans (MonadTrans (..)) import Numeric.Log (Log) @@ -71,4 +73,10 @@ smc2 k n p t param m = rmsmc MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0} SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial} - (param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . m) + (flattenSequentiallyTraced param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . flattenSMC2 . m) + +flattenSequentiallyTraced :: (Monad m) => SequentialT (TracedT (PopulationT m)) a -> SequentialT (TracedT (PopulationT m)) a +flattenSequentiallyTraced = SequentialT.hoist $ hoistModel (single . flatten) . hoist (single . flatten) + +flattenSMC2 :: (Monad m) => SequentialT (PopulationT (SMC2 m)) a -> SequentialT (PopulationT (SMC2 m)) a + flattenSMC2 = SequentialT.hoist $ single . flatten . PopulationT.hoist (SMC2 . flattenSequentiallyTraced . setup) diff --git a/src/Control/Monad/Bayes/Population.hs b/src/Control/Monad/Bayes/Population.hs index aa177de1..bad1d01e 100644 --- a/src/Control/Monad/Bayes/Population.hs +++ b/src/Control/Monad/Bayes/Population.hs @@ -238,7 +238,7 @@ pushEvidence :: (MonadFactor m) => PopulationT m a -> PopulationT m a -pushEvidence = hoist applyWeight . extractEvidence +pushEvidence = single . flatten . hoist applyWeight . extractEvidence -- | A properly weighted single sample, that is one picked at random according -- to the weights, with the sum of all weights. diff --git a/src/Control/Monad/Bayes/Sequential/Coroutine.hs b/src/Control/Monad/Bayes/Sequential/Coroutine.hs index 926c3db1..8b3b5fc1 100644 --- a/src/Control/Monad/Bayes/Sequential/Coroutine.hs +++ b/src/Control/Monad/Bayes/Sequential/Coroutine.hs @@ -22,6 +22,8 @@ module Control.Monad.Bayes.Sequential.Coroutine hoist, sequentially, sis, + runSequentialT, + extract, ) where diff --git a/src/Control/Monad/Bayes/Traced/Static.hs b/src/Control/Monad/Bayes/Traced/Static.hs index fc99b327..209b507d 100644 --- a/src/Control/Monad/Bayes/Traced/Static.hs +++ b/src/Control/Monad/Bayes/Traced/Static.hs @@ -12,6 +12,7 @@ module Control.Monad.Bayes.Traced.Static ( TracedT (..), hoist, + hoistModel, marginal, mhStep, mh, @@ -25,6 +26,7 @@ import Control.Monad.Bayes.Class MonadMeasure, ) import Control.Monad.Bayes.Density.Free (DensityT) +import Control.Monad.Bayes.Density.Free qualified as DensityT import Control.Monad.Bayes.Traced.Common ( Trace (..), bind, @@ -33,6 +35,7 @@ import Control.Monad.Bayes.Traced.Common singleton, ) import Control.Monad.Bayes.Weighted (WeightedT) +import Control.Monad.Bayes.Weighted qualified as WeightedT import Control.Monad.Trans (MonadTrans (..)) import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList) @@ -72,6 +75,9 @@ instance (MonadMeasure m) => MonadMeasure (TracedT m) hoist :: (forall x. m x -> m x) -> TracedT m a -> TracedT m a hoist f (TracedT m d) = TracedT m (f d) +hoistModel :: (Monad m) => (forall x. m x -> m x) -> TracedT m a -> TracedT m a +hoistModel f (TracedT m d) = TracedT (WeightedT.hoist (DensityT.hoist f) m) d + -- | Discard the trace and supporting infrastructure. marginal :: (Monad m) => TracedT m a -> m a marginal (TracedT _ d) = fmap output d @@ -98,15 +104,15 @@ mhStep (TracedT m d) = TracedT m d' -- * What is the probability that it is the weekend? -- -- >>> :{ --- let --- bus = do x <- bernoulli (2/7) --- let rate = if x then 3 else 10 --- factor $ poissonPdf rate 4 --- return x --- mhRunBusSingleObs = do --- let nSamples = 2 --- sampleIOfixed $ unweighted $ mh nSamples bus --- in mhRunBusSingleObs +-- let +-- bus = do x <- bernoulli (2/7) +-- let rate = if x then 3 else 10 +-- factor $ poissonPdf rate 4 +-- return x +-- mhRunBusSingleObs = do +-- let nSamples = 2 +-- sampleIOfixed $ unweighted $ mh nSamples bus +-- in mhRunBusSingleObs -- :} -- [True,True,True] --