Skip to content

Commit

Permalink
Merge pull request #256 from turion/dev_fixture_test_benchmarks
Browse files Browse the repository at this point in the history
Test some benchmark configurations against fixtures
  • Loading branch information
turion authored Jul 17, 2023
2 parents 8ded731 + 3307f9b commit f2df3a9
Show file tree
Hide file tree
Showing 24 changed files with 230 additions and 80 deletions.
30 changes: 6 additions & 24 deletions benchmark/SSM.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Main where

import Control.Monad (forM_)
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.PMMH as PMMH (pmmh)
import Control.Monad.Bayes.Inference.RMSMC (rmsmcDynamic)
Expand All @@ -11,33 +12,14 @@ import Control.Monad.Bayes.Sampler.Strict (sampleIO, sampleIOfixed, sampleWith)
import Control.Monad.Bayes.Weighted (unweighted)
import Control.Monad.IO.Class (MonadIO (liftIO))
import NonlinearSSM (generateData, model, param)
import NonlinearSSM.Algorithms
import System.Random.Stateful (mkStdGen, newIOGenM)

main :: IO ()
main = sampleIOfixed $ do
let t = 5
dat <- generateData t
let ys = map snd dat
liftIO $ print "SMC"
smcRes <- population $ smc SMCConfig {numSteps = t, numParticles = 10, resampler = resampleMultinomial} (param >>= model ys)
liftIO $ print $ show smcRes
liftIO $ print "RM-SMC"
smcrmRes <-
population $
rmsmcDynamic
MCMCConfig {numMCMCSteps = 10, numBurnIn = 0, proposal = SingleSiteMH}
SMCConfig {numSteps = t, numParticles = 10, resampler = resampleSystematic}
(param >>= model ys)
liftIO $ print $ show smcrmRes
liftIO $ print "PMMH"
pmmhRes <-
unweighted $
pmmh
MCMCConfig {numMCMCSteps = 2, numBurnIn = 0, proposal = SingleSiteMH}
SMCConfig {numSteps = t, numParticles = 3, resampler = resampleSystematic}
param
(model ys)
liftIO $ print $ show pmmhRes
liftIO $ print "SMC2"
smc2Res <- population $ smc2 t 3 2 1 param (model ys)
liftIO $ print $ show smc2Res
forM_ [SMC, RMSMCDynamic, PMMH, SMC2] $ \alg -> do
liftIO $ print alg
result <- runAlgFixed ys alg
liftIO $ putStrLn result
56 changes: 1 addition & 55 deletions benchmark/Single.hs
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ImportQualifiedPost #-}

import Control.Monad.Bayes.Class (MonadMeasure)
import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..), Proposal (SingleSiteMH))
import Control.Monad.Bayes.Inference.RMSMC (rmsmcBasic)
import Control.Monad.Bayes.Inference.SMC
( SMCConfig (SMCConfig, numParticles, numSteps, resampler),
smc,
)
import Control.Monad.Bayes.Population
import Control.Monad.Bayes.Sampler.Strict
import Control.Monad.Bayes.Traced hiding (model)
import Control.Monad.Bayes.Weighted
import Control.Monad.ST (runST)
import Data.Time (diffUTCTime, getCurrentTime)
import HMM qualified
import LDA qualified
import LogReg qualified
import Helper
import Options.Applicative
( Applicative (liftA2),
ParserInfo,
Expand All @@ -31,47 +18,6 @@ import Options.Applicative
short,
)

data Model = LR Int | HMM Int | LDA (Int, Int)
deriving stock (Show, Read)

parseModel :: String -> Maybe Model
parseModel s =
case s of
'L' : 'R' : n -> Just $ LR (read n)
'H' : 'M' : 'M' : n -> Just $ HMM (read n)
'L' : 'D' : 'A' : n -> Just $ LDA (5, read n)
_ -> Nothing

getModel :: MonadMeasure m => Model -> (Int, m String)
getModel model = (size model, program model)
where
size (LR n) = n
size (HMM n) = n
size (LDA (d, w)) = d * w
program (LR n) = show <$> (LogReg.logisticRegression (runST $ sampleSTfixed (LogReg.syntheticData n)))
program (HMM n) = show <$> (HMM.hmm (runST $ sampleSTfixed (HMM.syntheticData n)))
program (LDA (d, w)) = show <$> (LDA.lda (runST $ sampleSTfixed (LDA.syntheticData d w)))

data Alg = SMC | MH | RMSMC
deriving stock (Read, Show)

runAlg :: Model -> Alg -> SamplerIO String
runAlg model alg =
case alg of
SMC ->
let n = 100
(k, m) = getModel model
in show <$> population (smc SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic} m)
MH ->
let t = 100
(_, m) = getModel model
in show <$> unweighted (mh t m)
RMSMC ->
let n = 10
t = 1
(k, m) = getModel model
in show <$> population (rmsmcBasic MCMCConfig {numMCMCSteps = t, numBurnIn = 0, proposal = SingleSiteMH} (SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic}) m)

infer :: Model -> Alg -> IO ()
infer model alg = do
x <- sampleIOfixed (runAlg model alg)
Expand Down
4 changes: 4 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@
name = "monad-bayes";
root = src;
cabal2nixOptions = "--benchmark -fdev";

# https://github.com/tweag/monad-bayes/pull/256: Don't run tests on Mac because of machine precision issues
modifier = drv: if system == "x86_64-linux" then drv else pkgs.haskell.lib.dontCheck drv;
overrides = self: super: with pkgs.haskell.lib; { # Please check after flake.lock updates whether some of these overrides can be removed
string-qq = dontCheck super.string-qq;
hspec = super.hspec_2_11_1;
Expand All @@ -93,6 +96,7 @@
in lib.attrsets.genAttrs ghcs buildForVersion;

monad-bayes = monad-bayes-per-ghc.ghc902;

monad-bayes-all-ghcs = pkgs.linkFarm "monad-bayes-all-ghcs" monad-bayes-per-ghc;

jupyterEnvironment = mkJupyterlabFromPath ./kernels {inherit pkgs monad-bayes;};
Expand Down
70 changes: 70 additions & 0 deletions models/Helper.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ImportQualifiedPost #-}

module Helper where

import Control.Monad.Bayes.Class (MonadMeasure)
import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..), Proposal (SingleSiteMH))
import Control.Monad.Bayes.Inference.RMSMC (rmsmcBasic)
import Control.Monad.Bayes.Inference.SMC
( SMCConfig (SMCConfig, numParticles, numSteps, resampler),
smc,
)
import Control.Monad.Bayes.Population
import Control.Monad.Bayes.Sampler.Strict
import Control.Monad.Bayes.Traced hiding (model)
import Control.Monad.Bayes.Weighted
import Control.Monad.ST (runST)
import HMM qualified
import LDA qualified
import LogReg qualified

data Model = LR Int | HMM Int | LDA (Int, Int)
deriving stock (Show, Read)

parseModel :: String -> Maybe Model
parseModel s =
case s of
'L' : 'R' : n -> Just $ LR (read n)
'H' : 'M' : 'M' : n -> Just $ HMM (read n)
'L' : 'D' : 'A' : n -> Just $ LDA (5, read n)
_ -> Nothing

serializeModel :: Model -> Maybe String
serializeModel (LR n) = Just $ "LR" ++ show n
serializeModel (HMM n) = Just $ "HMM" ++ show n
serializeModel (LDA (5, n)) = Just $ "LDA" ++ show n
serializeModel (LDA _) = Nothing

data Alg = SMC | MH | RMSMC
deriving stock (Read, Show, Eq, Ord, Enum, Bounded)

getModel :: MonadMeasure m => Model -> (Int, m String)
getModel model = (size model, program model)
where
size (LR n) = n
size (HMM n) = n
size (LDA (d, w)) = d * w
program (LR n) = show <$> (LogReg.logisticRegression (runST $ sampleSTfixed (LogReg.syntheticData n)))
program (HMM n) = show <$> (HMM.hmm (runST $ sampleSTfixed (HMM.syntheticData n)))
program (LDA (d, w)) = show <$> (LDA.lda (runST $ sampleSTfixed (LDA.syntheticData d w)))

runAlg :: Model -> Alg -> SamplerIO String
runAlg model alg =
case alg of
SMC ->
let n = 100
(k, m) = getModel model
in show <$> population (smc SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic} m)
MH ->
let t = 100
(_, m) = getModel model
in show <$> unweighted (mh t m)
RMSMC ->
let n = 10
t = 1
(k, m) = getModel model
in show <$> population (rmsmcBasic MCMCConfig {numMCMCSteps = t, numBurnIn = 0, proposal = SingleSiteMH} (SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic}) m)

runAlgFixed :: Model -> Alg -> IO String
runAlgFixed model alg = sampleIOfixed $ runAlg model alg
56 changes: 56 additions & 0 deletions models/NonlinearSSM/Algorithms.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
module NonlinearSSM.Algorithms where

import Control.Monad.Bayes.Class (MonadDistribution)
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.PMMH as PMMH (pmmh)
import Control.Monad.Bayes.Inference.RMSMC (rmsmc, rmsmcBasic, rmsmcDynamic)
import Control.Monad.Bayes.Inference.SMC
import Control.Monad.Bayes.Inference.SMC2 as SMC2 (smc2)
import Control.Monad.Bayes.Population
import Control.Monad.Bayes.Weighted (unweighted)
import NonlinearSSM

data Alg = SMC | RMSMC | RMSMCDynamic | RMSMCBasic | PMMH | SMC2
deriving (Show, Read, Eq, Ord, Enum, Bounded)

algs :: [Alg]
algs = [minBound .. maxBound]

type SSMData = [Double]

t :: Int
t = 5

-- FIXME refactor such that it can be reused in ssm benchmark
runAlgFixed :: MonadDistribution m => SSMData -> Alg -> m String
runAlgFixed ys SMC = fmap show $ population $ smc SMCConfig {numSteps = t, numParticles = 10, resampler = resampleMultinomial} (param >>= model ys)
runAlgFixed ys RMSMC =
fmap show $
population $
rmsmc
MCMCConfig {numMCMCSteps = 10, numBurnIn = 0, proposal = SingleSiteMH}
SMCConfig {numSteps = t, numParticles = 10, resampler = resampleSystematic}
(param >>= model ys)
runAlgFixed ys RMSMCBasic =
fmap show $
population $
rmsmcBasic
MCMCConfig {numMCMCSteps = 10, numBurnIn = 0, proposal = SingleSiteMH}
SMCConfig {numSteps = t, numParticles = 10, resampler = resampleSystematic}
(param >>= model ys)
runAlgFixed ys RMSMCDynamic =
fmap show $
population $
rmsmcDynamic
MCMCConfig {numMCMCSteps = 10, numBurnIn = 0, proposal = SingleSiteMH}
SMCConfig {numSteps = t, numParticles = 10, resampler = resampleSystematic}
(param >>= model ys)
runAlgFixed ys PMMH =
fmap show $
unweighted $
pmmh
MCMCConfig {numMCMCSteps = 2, numBurnIn = 0, proposal = SingleSiteMH}
SMCConfig {numSteps = t, numParticles = 3, resampler = resampleSystematic}
param
(model ys)
runAlgFixed ys SMC2 = fmap show $ population $ smc2 t 3 2 1 param (model ys)
13 changes: 12 additions & 1 deletion monad-bayes.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ executable example
hs-source-dirs: benchmark models
other-modules:
Dice
Helper
HMM
LDA
LogReg
Expand Down Expand Up @@ -161,9 +162,15 @@ test-suite monad-bayes-test
other-modules:
BetaBin
ConjugatePriors
Helper
HMM
LDA
LogReg
NonlinearSSM
NonlinearSSM.Algorithms
Sprinkler
TestAdvanced
TestBenchmarks
TestDistribution
TestEnumerator
TestInference
Expand All @@ -172,6 +179,7 @@ test-suite monad-bayes-test
TestPopulation
TestSampler
TestSequential
TestSSMFixtures
TestStormerVerlet
TestWeighted

Expand All @@ -198,7 +206,10 @@ benchmark ssm-bench
type: exitcode-stdio-1.0
main-is: SSM.hs
hs-source-dirs: models benchmark
other-modules: NonlinearSSM
other-modules:
NonlinearSSM
NonlinearSSM.Algorithms

default-language: Haskell2010
build-depends:
, base
Expand Down
5 changes: 5 additions & 0 deletions test/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ import Test.Hspec (context, describe, hspec, it, shouldBe)
import Test.Hspec.QuickCheck (prop)
import Test.QuickCheck (ioProperty, property, (==>))
import TestAdvanced qualified
import TestBenchmarks qualified
import TestDistribution qualified
import TestEnumerator qualified
import TestInference qualified
import TestIntegrator qualified
import TestPipes (hmms)
import TestPipes qualified
import TestPopulation qualified
import TestSSMFixtures qualified
import TestSampler qualified
import TestSequential qualified
import TestStormerVerlet qualified
Expand Down Expand Up @@ -166,3 +168,6 @@ main = hspec do
passed6 `shouldBe` True
passed7 <- TestAdvanced.passed7
passed7 `shouldBe` True

TestBenchmarks.test
TestSSMFixtures.test
33 changes: 33 additions & 0 deletions test/TestBenchmarks.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
module TestBenchmarks where

import Control.Monad (forM_)
import Data.Maybe (fromJust)
import Helper
import System.IO (readFile')
import System.IO.Error (catchIOError, isDoesNotExistError)
import Test.Hspec

fixtureToFilename :: Model -> Alg -> String
fixtureToFilename model alg = fromJust (serializeModel model) ++ "-" ++ show alg ++ ".txt"

models :: [Model]
models = [LR 10, HMM 10, LDA (5, 10)]

algs :: [Alg]
algs = [minBound .. maxBound]

test :: SpecWith ()
test = describe "Benchmarks" $ forM_ models $ \model -> forM_ algs $ testFixture model

testFixture :: Model -> Alg -> SpecWith ()
testFixture model alg = do
let filename = "test/fixtures/" ++ fixtureToFilename model alg
it ("should agree with the fixture " ++ filename) $ do
fixture <- catchIOError (readFile' filename) $ \e ->
if isDoesNotExistError e
then return ""
else ioError e
sampled <- runAlgFixed model alg
-- Reset in case of fixture update or creation
writeFile filename sampled
fixture `shouldBe` sampled
28 changes: 28 additions & 0 deletions test/TestSSMFixtures.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module TestSSMFixtures where

import Control.Monad.Bayes.Sampler.Strict (sampleIOfixed)
import NonlinearSSM
import NonlinearSSM.Algorithms
import System.IO (readFile')
import System.IO.Error (catchIOError, isDoesNotExistError)
import Test.Hspec

fixtureToFilename :: Alg -> FilePath
fixtureToFilename alg = "test/fixtures/SSM-" ++ show alg ++ ".txt"

testFixture :: Alg -> SpecWith ()
testFixture alg = do
let filename = fixtureToFilename alg
it ("should agree with the fixture " ++ filename) $ do
ys <- sampleIOfixed $ generateData t
fixture <- catchIOError (readFile' filename) $ \e ->
if isDoesNotExistError e
then return ""
else ioError e
sampled <- sampleIOfixed $ runAlgFixed (map fst ys) alg
-- Reset in case of fixture update or creation
writeFile filename sampled
fixture `shouldBe` sampled

test :: SpecWith ()
test = describe "TestSSMFixtures" $ mapM_ testFixture algs
1 change: 1 addition & 0 deletions test/fixtures/HMM10-MH.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
["[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[2,2,1,1,1,1,1,2,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[1,1,2,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,1,1,1,1,2,1,1,2,0]","[1,1,1,1,1,2,1,1,2,0]","[1,1,1,1,1,2,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]"]
1 change: 1 addition & 0 deletions test/fixtures/HMM10-RMSMC.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[("[2,1,1,1,0,1,1,2,1,1]",2.4438034074800498e-8),("[1,1,1,1,1,2,1,1,1,0]",2.4438034074800498e-8),("[1,1,2,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,2,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,1,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,2,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,2,1,2,1,2,1,1,2]",2.4438034074800498e-8),("[1,1,2,1,2,1,2,1,1,2]",2.4438034074800498e-8),("[1,1,2,1,2,1,2,1,1,2]",2.4438034074800498e-8),("[1,2,1,1,2,2,0,1,1,1]",2.4438034074800498e-8)]
Loading

0 comments on commit f2df3a9

Please sign in to comment.