Skip to content

Commit

Permalink
Unsuccessfull attempt at fixing SMC2
Browse files Browse the repository at this point in the history
  • Loading branch information
turion committed Jan 3, 2024
1 parent 084862f commit 4f6b297
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/Control/Monad/Bayes/Inference/RMSMC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ rmsmc ::
PopulationT m a
rmsmc (MCMCConfig {..}) (SMCConfig {..}) =
marginal
. S.sequentially (composeCopies numMCMCSteps (TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps
. S.hoistFirst (TrStat.hoist (withParticles numParticles))
. S.sequentially (composeCopies numMCMCSteps (TrStat.hoistModel (single . flatten) . TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps
. S.hoistFirst (TrStat.hoistModel (single . flatten) . TrStat.hoist (withParticles numParticles))

-- | Resample-move Sequential Monte Carlo with a more efficient
-- tracing representation.
Expand Down
29 changes: 26 additions & 3 deletions src/Control/Monad/Bayes/Inference/SMC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,38 @@ where

import Control.Monad.Bayes.Class (MonadDistribution, MonadMeasure)
import Control.Monad.Bayes.Population
( PopulationT,
( PopulationT (..),
flatten,
pushEvidence,
single,
withParticles,
)
import Control.Monad.Bayes.Population.Applicative qualified as Applicative
import Control.Monad.Bayes.Sequential.Coroutine as Coroutine
import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT
import Control.Monad.Bayes.Weighted (WeightedT (..), weightedT)
import Control.Monad.Coroutine
import Control.Monad.Trans.Free (FreeF (..), FreeT (..))

data SMCConfig m = SMCConfig
{ resampler :: forall x. PopulationT m x -> PopulationT m x,
numSteps :: Int,
numParticles :: Int
}

sequentialToPopulation :: (Monad m) => Coroutine.SequentialT (Applicative.PopulationT m) a -> PopulationT m a
sequentialToPopulation =
PopulationT
. weightedT
. coroutineToFree
. Coroutine.runSequentialT
where
coroutineToFree =
FreeT
. fmap (Free . fmap (\(cont, p) -> either (coroutineToFree . extract) (pure . (,p)) cont))
. Applicative.runPopulationT
. resume

-- | Sequential importance resampling.
-- Basically an SMC template that takes a custom resampler.
smc ::
Expand All @@ -42,12 +62,15 @@ smc ::
Coroutine.SequentialT (PopulationT m) a ->
PopulationT m a
smc SMCConfig {..} =
Coroutine.sequentially resampler numSteps
(single . flatten)
. Coroutine.sequentially resampler numSteps
. SequentialT.hoist (single . flatten)
. Coroutine.hoistFirst (withParticles numParticles)
. SequentialT.hoist (single . flatten)

-- | Sequential Monte Carlo with multinomial resampling at each timestep.
-- Weights are normalized at each timestep and the total weight is pushed
-- as a score into the transformed monad.
smcPush ::
(MonadMeasure m) => SMCConfig m -> Coroutine.SequentialT (PopulationT m) a -> PopulationT m a
smcPush config = smc config {resampler = (pushEvidence . resampler config)}
smcPush config = smc config {resampler = (single . flatten . pushEvidence . resampler config)}
12 changes: 10 additions & 2 deletions src/Control/Monad/Bayes/Inference/SMC2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush)
import Control.Monad.Bayes.Population as Pop (PopulationT, resampleMultinomial, runPopulationT)
import Control.Monad.Bayes.Population as Pop (PopulationT, flatten, resampleMultinomial, runPopulationT, single)
import Control.Monad.Bayes.Population qualified as PopulationT
import Control.Monad.Bayes.Sequential.Coroutine (SequentialT)
import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT
import Control.Monad.Bayes.Traced
import Control.Monad.Trans (MonadTrans (..))
import Numeric.Log (Log)
Expand Down Expand Up @@ -71,4 +73,10 @@ smc2 k n p t param m =
rmsmc
MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0}
SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial}
(param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . m)
(flattenSequentiallyTraced param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . flattenSMC2 . m)

flattenSequentiallyTraced :: (Monad m) => SequentialT (TracedT (PopulationT m)) a -> SequentialT (TracedT (PopulationT m)) a
flattenSequentiallyTraced = SequentialT.hoist $ hoistModel (single . flatten) . hoist (single . flatten)

flattenSMC2 :: (Monad m) => SequentialT (PopulationT (SMC2 m)) a -> SequentialT (PopulationT (SMC2 m)) a
flattenSMC2 = SequentialT.hoist $ single . flatten . PopulationT.hoist (SMC2 . flattenSequentiallyTraced . setup)
2 changes: 1 addition & 1 deletion src/Control/Monad/Bayes/Population.hs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ pushEvidence ::
(MonadFactor m) =>
PopulationT m a ->
PopulationT m a
pushEvidence = hoist applyWeight . extractEvidence
pushEvidence = single . flatten . hoist applyWeight . extractEvidence

-- | A properly weighted single sample, that is one picked at random according
-- to the weights, with the sum of all weights.
Expand Down
2 changes: 2 additions & 0 deletions src/Control/Monad/Bayes/Sequential/Coroutine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ module Control.Monad.Bayes.Sequential.Coroutine
hoist,
sequentially,
sis,
runSequentialT,
extract,
)
where

Expand Down
24 changes: 15 additions & 9 deletions src/Control/Monad/Bayes/Traced/Static.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
module Control.Monad.Bayes.Traced.Static
( TracedT (..),
hoist,
hoistModel,
marginal,
mhStep,
mh,
Expand All @@ -25,6 +26,7 @@ import Control.Monad.Bayes.Class
MonadMeasure,
)
import Control.Monad.Bayes.Density.Free (DensityT)
import Control.Monad.Bayes.Density.Free qualified as DensityT
import Control.Monad.Bayes.Traced.Common
( Trace (..),
bind,
Expand All @@ -33,6 +35,7 @@ import Control.Monad.Bayes.Traced.Common
singleton,
)
import Control.Monad.Bayes.Weighted (WeightedT)
import Control.Monad.Bayes.Weighted qualified as WeightedT
import Control.Monad.Trans (MonadTrans (..))
import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList)

Expand Down Expand Up @@ -72,6 +75,9 @@ instance (MonadMeasure m) => MonadMeasure (TracedT m)
hoist :: (forall x. m x -> m x) -> TracedT m a -> TracedT m a
hoist f (TracedT m d) = TracedT m (f d)

hoistModel :: (Monad m) => (forall x. m x -> m x) -> TracedT m a -> TracedT m a
hoistModel f (TracedT m d) = TracedT (WeightedT.hoist (DensityT.hoist f) m) d

-- | Discard the trace and supporting infrastructure.
marginal :: (Monad m) => TracedT m a -> m a
marginal (TracedT _ d) = fmap output d
Expand All @@ -98,15 +104,15 @@ mhStep (TracedT m d) = TracedT m d'
-- * What is the probability that it is the weekend?
--
-- >>> :{
-- let
-- bus = do x <- bernoulli (2/7)
-- let rate = if x then 3 else 10
-- factor $ poissonPdf rate 4
-- return x
-- mhRunBusSingleObs = do
-- let nSamples = 2
-- sampleIOfixed $ unweighted $ mh nSamples bus
-- in mhRunBusSingleObs
-- let
-- bus = do x <- bernoulli (2/7)
-- let rate = if x then 3 else 10
-- factor $ poissonPdf rate 4
-- return x
-- mhRunBusSingleObs = do
-- let nSamples = 2
-- sampleIOfixed $ unweighted $ mh nSamples bus
-- in mhRunBusSingleObs
-- :}
-- [True,True,True]
--
Expand Down

0 comments on commit 4f6b297

Please sign in to comment.