Skip to content

Commit

Permalink
Test some benchmark configurations against fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
Manuel Bärenz authored and turion committed May 17, 2023
1 parent 1f6d7bc commit 97fc400
Show file tree
Hide file tree
Showing 14 changed files with 121 additions and 55 deletions.
56 changes: 1 addition & 55 deletions benchmark/Single.hs
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ImportQualifiedPost #-}

import Control.Monad.Bayes.Class (MonadMeasure)
import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..), Proposal (SingleSiteMH))
import Control.Monad.Bayes.Inference.RMSMC (rmsmcBasic)
import Control.Monad.Bayes.Inference.SMC
( SMCConfig (SMCConfig, numParticles, numSteps, resampler),
smc,
)
import Control.Monad.Bayes.Population
import Control.Monad.Bayes.Sampler.Strict
import Control.Monad.Bayes.Traced hiding (model)
import Control.Monad.Bayes.Weighted
import Control.Monad.ST (runST)
import Data.Time (diffUTCTime, getCurrentTime)
import HMM qualified
import LDA qualified
import LogReg qualified
import Helper
import Options.Applicative
( Applicative (liftA2),
ParserInfo,
Expand All @@ -31,47 +18,6 @@ import Options.Applicative
short,
)

data Model = LR Int | HMM Int | LDA (Int, Int)
deriving stock (Show, Read)

parseModel :: String -> Maybe Model
parseModel s =
case s of
'L' : 'R' : n -> Just $ LR (read n)
'H' : 'M' : 'M' : n -> Just $ HMM (read n)
'L' : 'D' : 'A' : n -> Just $ LDA (5, read n)
_ -> Nothing

getModel :: MonadMeasure m => Model -> (Int, m String)
getModel model = (size model, program model)
where
size (LR n) = n
size (HMM n) = n
size (LDA (d, w)) = d * w
program (LR n) = show <$> (LogReg.logisticRegression (runST $ sampleSTfixed (LogReg.syntheticData n)))
program (HMM n) = show <$> (HMM.hmm (runST $ sampleSTfixed (HMM.syntheticData n)))
program (LDA (d, w)) = show <$> (LDA.lda (runST $ sampleSTfixed (LDA.syntheticData d w)))

data Alg = SMC | MH | RMSMC
deriving stock (Read, Show)

runAlg :: Model -> Alg -> SamplerIO String
runAlg model alg =
case alg of
SMC ->
let n = 100
(k, m) = getModel model
in show <$> population (smc SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic} m)
MH ->
let t = 100
(_, m) = getModel model
in show <$> unweighted (mh t m)
RMSMC ->
let n = 10
t = 1
(k, m) = getModel model
in show <$> population (rmsmcBasic MCMCConfig {numMCMCSteps = t, numBurnIn = 0, proposal = SingleSiteMH} (SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic}) m)

infer :: Model -> Alg -> IO ()
infer model alg = do
x <- sampleIOfixed (runAlg model alg)
Expand Down
70 changes: 70 additions & 0 deletions models/Helper.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ImportQualifiedPost #-}

module Helper where

import Control.Monad.Bayes.Class (MonadMeasure)
import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..), Proposal (SingleSiteMH))
import Control.Monad.Bayes.Inference.RMSMC (rmsmcBasic)
import Control.Monad.Bayes.Inference.SMC
( SMCConfig (SMCConfig, numParticles, numSteps, resampler),
smc,
)
import Control.Monad.Bayes.Population
import Control.Monad.Bayes.Sampler.Strict
import Control.Monad.Bayes.Traced hiding (model)
import Control.Monad.Bayes.Weighted
import Control.Monad.ST (runST)
import HMM qualified
import LDA qualified
import LogReg qualified

data Model = LR Int | HMM Int | LDA (Int, Int)
deriving stock (Show, Read)

parseModel :: String -> Maybe Model
parseModel s =
case s of
'L' : 'R' : n -> Just $ LR (read n)
'H' : 'M' : 'M' : n -> Just $ HMM (read n)
'L' : 'D' : 'A' : n -> Just $ LDA (5, read n)
_ -> Nothing

serializeModel :: Model -> Maybe String
serializeModel (LR n) = Just $ "LR" ++ show n
serializeModel (HMM n) = Just $ "HMM" ++ show n
serializeModel (LDA (5, n)) = Just $ "LDA" ++ show n
serializeModel (LDA _) = Nothing

data Alg = SMC | MH | RMSMC
deriving stock (Read, Show, Eq, Ord, Enum, Bounded)

getModel :: MonadMeasure m => Model -> (Int, m String)
getModel model = (size model, program model)
where
size (LR n) = n
size (HMM n) = n
size (LDA (d, w)) = d * w
program (LR n) = show <$> (LogReg.logisticRegression (runST $ sampleSTfixed (LogReg.syntheticData n)))
program (HMM n) = show <$> (HMM.hmm (runST $ sampleSTfixed (HMM.syntheticData n)))
program (LDA (d, w)) = show <$> (LDA.lda (runST $ sampleSTfixed (LDA.syntheticData d w)))

runAlg :: Model -> Alg -> SamplerIO String
runAlg model alg =
case alg of
SMC ->
let n = 100
(k, m) = getModel model
in show <$> population (smc SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic} m)
MH ->
let t = 100
(_, m) = getModel model
in show <$> unweighted (mh t m)
RMSMC ->
let n = 10
t = 1
(k, m) = getModel model
in show <$> population (rmsmcBasic MCMCConfig {numMCMCSteps = t, numBurnIn = 0, proposal = SingleSiteMH} (SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic}) m)

runAlgFixed :: Model -> Alg -> IO String
runAlgFixed model alg = sampleIOfixed $ runAlg model alg
5 changes: 5 additions & 0 deletions monad-bayes.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ executable example
hs-source-dirs: benchmark models
other-modules:
Dice
Helper
HMM
LDA
LogReg
Expand Down Expand Up @@ -157,9 +158,13 @@ test-suite monad-bayes-test
other-modules:
BetaBin
ConjugatePriors
Helper
HMM
LDA
LogReg
Sprinkler
TestAdvanced
TestBenchmarks
TestDistribution
TestEnumerator
TestInference
Expand Down
3 changes: 3 additions & 0 deletions test/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Test.Hspec (context, describe, hspec, it, shouldBe)
import Test.Hspec.QuickCheck (prop)
import Test.QuickCheck (ioProperty, property, (==>))
import TestAdvanced qualified
import TestBenchmarks qualified
import TestDistribution qualified
import TestEnumerator qualified
import TestInference qualified
Expand Down Expand Up @@ -164,3 +165,5 @@ main = hspec do
passed6 `shouldBe` True
passed7 <- TestAdvanced.passed7
passed7 `shouldBe` True

TestBenchmarks.test
33 changes: 33 additions & 0 deletions test/TestBenchmarks.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
module TestBenchmarks where

import Control.Monad (forM_)
import Data.Maybe (fromJust)
import Helper
import System.IO (readFile')
import System.IO.Error (catchIOError, isDoesNotExistError)
import Test.Hspec

fixtureToFilename :: Model -> Alg -> String
fixtureToFilename model alg = fromJust (serializeModel model) ++ "-" ++ show alg ++ ".txt"

models :: [Model]
models = [LR 10, HMM 10, LDA (5, 10)]

algs :: [Alg]
algs = [minBound .. maxBound]

test :: SpecWith ()
test = describe "Benchmarks" $ forM_ models $ \model -> forM_ algs $ testFixture model

testFixture :: Model -> Alg -> SpecWith ()
testFixture model alg = do
let filename = "test/fixtures/" ++ fixtureToFilename model alg
it ("should agree with the fixture " ++ filename) $ do
fixture <- catchIOError (readFile' filename) $ \e ->
if isDoesNotExistError e
then return ""
else ioError e
sampled <- runAlgFixed model alg
-- Reset in case of fixture update or creation
writeFile filename sampled
fixture `shouldBe` sampled
1 change: 1 addition & 0 deletions test/fixtures/HMM10-MH.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
["[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[2,2,1,1,1,1,1,2,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[1,1,2,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,1,1,1,1,2,1,1,2,0]","[1,1,1,1,1,2,1,1,2,0]","[1,1,1,1,1,2,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]"]
1 change: 1 addition & 0 deletions test/fixtures/HMM10-RMSMC.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[("[2,1,1,1,0,1,1,2,1,1]",2.4438034074800498e-8),("[1,1,1,1,1,2,1,1,1,0]",2.4438034074800498e-8),("[1,1,2,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,2,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,1,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,2,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,2,1,2,1,2,1,1,2]",2.4438034074800498e-8),("[1,1,2,1,2,1,2,1,1,2]",2.4438034074800498e-8),("[1,1,2,1,2,1,2,1,1,2]",2.4438034074800498e-8),("[1,2,1,1,2,2,0,1,1,1]",2.4438034074800498e-8)]
1 change: 1 addition & 0 deletions test/fixtures/HMM10-SMC.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[("[1,1,0,1,1,2,1,1,1,2]",2.964810681340389e-9),("[1,1,0,1,1,2,1,1,1,2]",2.964810681340389e-9),("[1,1,0,1,1,2,1,1,1,2]",2.964810681340389e-9),("[1,2,0,2,1,2,1,1,1,0]",2.964810681340389e-9),("[1,2,0,2,1,2,1,1,1,0]",2.964810681340389e-9),("[1,1,1,1,0,1,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,0,1,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,0,1,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,1,0,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,1,0,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,1,1,2,0,1,1]",2.964810681340389e-9),("[1,1,2,1,1,1,2,0,1,1]",2.964810681340389e-9),("[1,1,2,1,1,1,2,0,1,1]",2.964810681340389e-9),("[1,1,1,1,1,1,1,0,1,1]",2.964810681340389e-9),("[1,1,1,1,1,1,1,0,1,1]",2.964810681340389e-9),("[1,1,1,1,1,1,1,0,1,1]",2.964810681340389e-9),("[1,1,1,1,2,1,2,1,1,2]",2.964810681340389e-9),("[1,1,1,1,2,1,2,1,1,2]",2.964810681340389e-9),("[1,1,1,1,2,1,2,1,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,0]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,0]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,0]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,2]",2.964810681340389e-9),("[1,2,0,1,1,1,2,1,2,0]",2.964810681340389e-9),("[1,2,0,1,1,1,2,1,2,0]",2.964810681340389e-9),("[1,2,1,1,0,2,2,1,2,0]",2.964810681340389e-9),("[1,2,1,1,0,2,2,1,2,0]",2.964810681340389e-9),("[1,1,1,1,0,2,2,1,2,0]",2.964810681340389e-9),("[1,2,2,1,1,1,1,1,2,0]",2.964810681340389e-9),("[1,2,2,1,1,1,1,1,2,0]",2.964810681340389e-9),("[1,1,1,1,1,1,2,1,1,0]",2.964810681340389e-9),("[1,1,1,2,1,0,0,1,1,0]",2.964810681340389e-9),("[1,1,1,2,1,0,0,1,1,0]",2.964810681340389e-9),("[1,1,1,2,1,0,0,1,1,0]",2.964810681340389e-9),("[1,1,1,1,2,0,1,1,1,0]",2.964810681340389e-9),("[1,1,1,1,2,0,1,1,1,0]",2.964810681340389e-9),("[1,1,1,1,1,2,0,1,2,2]",2.964810681340389e-9),("[1,1,1,1,1,2,0,1,2,2]",2.964810681340389e-9),("[1,1,1,1,1,2,0,1,1,1]",2.964810681340389e-9),("[2,1,1,1,1,0,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,1,0,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,1,2,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,2,1,2,0,1,2]",2.964810681340389e-9),("[2,1,1,1,2,1,2,0,1,2]",2.964810681340389e-9),("[1,1,1,2,0,2,1,2,1,0]",2.964810681340389e-9),("[1,1,1,2,0,1,1,1,1,2]",2.964810681340389e-9),("[1,1,1,2,0,1,1,1,1,2]",2.964810681340389e-9),("[1,1,1,1,0,0,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,0,0,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,0,0,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,1,1,1,2,1,0]",2.964810681340389e-9),("[1,2,0,1,1,1,1,2,1,0]",2.964810681340389e-9),("[1,2,0,1,1,1,1,2,1,0]",2.964810681340389e-9),("[1,1,1,1,2,1,2,1,1,1]",2.964810681340389e-9),("[1,1,2,1,1,1,1,1,2,2]",2.964810681340389e-9),("[1,1,2,1,1,1,1,1,2,2]",2.964810681340389e-9),("[1,1,2,1,1,1,1,2,1,1]",2.964810681340389e-9),("[1,1,2,1,1,1,1,2,1,1]",2.964810681340389e-9),("[1,2,1,1,2,1,1,1,2,0]",2.964810681340389e-9),("[1,2,1,1,2,1,1,1,2,0]",2.964810681340389e-9),("[1,2,1,1,2,1,1,1,2,0]",2.964810681340389e-9),("[1,1,1,1,2,1,1,1,2,0]",2.964810681340389e-9),("[1,1,0,1,0,1,0,1,1,2]",2.964810681340389e-9),("[1,1,0,1,0,1,0,1,1,2]",2.964810681340389e-9),("[1,1,0,1,0,1,0,1,1,2]",2.964810681340389e-9),("[0,1,1,1,1,0,1,1,2,0]",2.964810681340389e-9),("[1,1,1,1,1,0,1,1,2,0]",2.964810681340389e-9),("[1,1,2,1,2,2,0,1,1,1]",2.964810681340389e-9),("[1,1,1,1,0,1,1,1,1,0]",2.964810681340389e-9),("[1,1,1,1,0,1,1,1,1,0]",2.964810681340389e-9),("[1,1,1,1,2,1,0,0,1,1]",2.964810681340389e-9),("[1,1,1,1,2,1,0,0,1,1]",2.964810681340389e-9),("[1,1,2,1,2,2,1,2,1,2]",2.964810681340389e-9),("[1,1,2,1,2,2,1,2,1,2]",2.964810681340389e-9),("[1,1,1,1,2,2,1,2,1,2]",2.964810681340389e-9),("[1,1,1,1,2,2,1,2,1,2]",2.964810681340389e-9),("[1,1,1,1,1,1,2,2,1,2]",2.964810681340389e-9),("[1,1,1,1,1,1,2,2,1,2]",2.964810681340389e-9),("[1,1,1,1,1,1,2,2,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,2,2,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,2,2,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,2,2,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,1,2,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,1,2,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,1,1,1,2,1,2,0]",2.964810681340389e-9),("[1,1,1,1,1,1,2,1,2,0]",2.964810681340389e-9),("[1,1,1,1,1,1,2,1,2,0]",2.964810681340389e-9),("[1,1,1,1,1,2,1,2,1,0]",2.964810681340389e-9),("[1,1,1,1,1,2,1,2,1,0]",2.964810681340389e-9),("[1,2,1,1,1,2,1,2,1,0]",2.964810681340389e-9),("[2,1,1,1,2,1,1,2,1,0]",2.964810681340389e-9),("[1,1,1,1,1,1,0,2,1,0]",2.964810681340389e-9),("[1,1,1,1,1,1,0,2,1,0]",2.964810681340389e-9),("[2,2,2,1,1,1,1,1,2,1]",2.964810681340389e-9)]
1 change: 1 addition & 0 deletions test/fixtures/LDA10-MH.txt

Large diffs are not rendered by default.

Loading

0 comments on commit 97fc400

Please sign in to comment.