Skip to content

Commit

Permalink
A benchmark of random-fu (replacing mwc-random)
Browse files Browse the repository at this point in the history
  • Loading branch information
idontgetoutmuch committed Oct 14, 2023
1 parent 19be760 commit 60a4e3d
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 33 deletions.
32 changes: 22 additions & 10 deletions benchmark/Speed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

module Main (main) where

import Control.Monad.Bayes.Class (MonadMeasure)
import Control.Monad (replicateM)
import Control.Monad.Bayes.Class (MonadMeasure, normal)
import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (MCMCConfig, numBurnIn, numMCMCSteps, proposal), Proposal (SingleSiteMH))
import Control.Monad.Bayes.Inference.RMSMC (rmsmcDynamic)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smc)
Expand Down Expand Up @@ -136,12 +137,21 @@ samplesBenchmarks lrData hmmData ldaData = benchmarks
m <- models
return (s, m, a)

normalBenchmarks :: [Benchmark]
normalBenchmarks = [ bench "Normal single sample monad bayes" $ nfIO $ do
sampleIOfixed (do xs <- replicateM 1000 $ normal 0.0 1.0
return $ sum xs)
]

speedLengthCSV :: FilePath
speedLengthCSV = "speed-length.csv"

speedSamplesCSV :: FilePath
speedSamplesCSV = "speed-samples.csv"

normalSamplesCSV :: FilePath
normalSamplesCSV = "normal-samples.csv"

rawDAT :: FilePath
rawDAT = "raw.dat"

Expand All @@ -158,12 +168,14 @@ removeIfExists file = do

main :: IO ()
main = do
cleanupLastRun
lrData <- sampleIOfixed (LogReg.syntheticData 1000)
hmmData <- sampleIOfixed (HMM.syntheticData 1000)
ldaData <- sampleIOfixed (LDA.syntheticData 5 1000)
let configLength = defaultConfig {csvFile = Just speedLengthCSV, rawDataFile = Just rawDAT}
defaultMainWith configLength (lengthBenchmarks lrData hmmData ldaData)
let configSamples = defaultConfig {csvFile = Just speedSamplesCSV, rawDataFile = Just rawDAT}
defaultMainWith configSamples (samplesBenchmarks lrData hmmData ldaData)
void $ runProcess "python plots.py"
-- cleanupLastRun
-- lrData <- sampleIOfixed (LogReg.syntheticData 1000)
-- hmmData <- sampleIOfixed (HMM.syntheticData 1000)
-- ldaData <- sampleIOfixed (LDA.syntheticData 5 1000)
-- let configLength = defaultConfig {csvFile = Just speedLengthCSV, rawDataFile = Just rawDAT}
-- defaultMainWith configLength (lengthBenchmarks lrData hmmData ldaData)
-- let configSamples = defaultConfig {csvFile = Just speedSamplesCSV, rawDataFile = Just rawDAT}
-- defaultMainWith configSamples (samplesBenchmarks lrData hmmData ldaData)
let configNormal = defaultConfig {csvFile = Just normalSamplesCSV, rawDataFile = Just rawDAT}
defaultMainWith configNormal normalBenchmarks
-- void $ runProcess "python plots.py"
1 change: 1 addition & 0 deletions monad-bayes.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ common deps
, pretty-simple ^>=4.1
, primitive >=0.7 && <0.9
, random ^>=1.2
, random-fu
, safe ^>=0.3.17
, scientific ^>=0.3
, statistics >=0.14.0 && <0.17
Expand Down
81 changes: 67 additions & 14 deletions shell.nix
Original file line number Diff line number Diff line change
@@ -1,14 +1,67 @@
(
import
(
let
lock = builtins.fromJSON (builtins.readFile ./flake.lock);
in
fetchTarball {
url = "https://github.com/edolstra/flake-compat/archive/${lock.nodes.flake-compat.locked.rev}.tar.gz";
sha256 = lock.nodes.flake-compat.locked.narHash;
}
)
{src = ./.;}
)
.shellNix
{ nixpkgs ? import <nixpkgs> {}, compiler ? "default", doBenchmark ? false }:

let

inherit (nixpkgs) pkgs;

f = { mkDerivation, abstract-par, base, brick, containers
, criterion, directory, foldl, free, histogram-fill, hspec, ieee754
, integration, lens, lib, linear, log-domain, math-functions
, matrix, monad-coroutine, monad-extras, mtl, mwc-random
, optparse-applicative, pipes, pretty-simple, primitive, process
, QuickCheck, random, random-fu, safe, scientific, statistics, text, time
, transformers, typed-process, vector, vty
}:
mkDerivation {
pname = "monad-bayes";
version = "1.2.0";
src = ./.;
isLibrary = true;
isExecutable = true;
libraryHaskellDepends = [
base brick containers foldl free histogram-fill ieee754 integration
lens linear log-domain math-functions matrix monad-coroutine
monad-extras mtl mwc-random pipes pretty-simple primitive random random-fu
safe scientific statistics text vector vty
];
executableHaskellDepends = [
abstract-par base brick containers criterion directory foldl free
histogram-fill hspec ieee754 integration lens linear log-domain
math-functions matrix monad-coroutine monad-extras mtl mwc-random
optparse-applicative pipes pretty-simple primitive process
QuickCheck random random-fu safe scientific statistics text time transformers
typed-process vector vty
];
testHaskellDepends = [
abstract-par base brick containers criterion directory foldl free
histogram-fill hspec ieee754 integration lens linear log-domain
math-functions matrix monad-coroutine monad-extras mtl mwc-random
optparse-applicative pipes pretty-simple primitive process
QuickCheck random random-fu safe scientific statistics text time transformers
typed-process vector vty
];
benchmarkHaskellDepends = [
abstract-par base brick containers criterion directory foldl free
histogram-fill hspec ieee754 integration lens linear log-domain
math-functions matrix monad-coroutine monad-extras mtl mwc-random
optparse-applicative pipes pretty-simple primitive process
QuickCheck random random-fu safe scientific statistics text time transformers
typed-process vector vty
];
homepage = "http://github.com/tweag/monad-bayes#readme";
description = "A library for probabilistic programming";
license = lib.licenses.mit;
mainProgram = "example";
};

haskellPackages = if compiler == "default"
then pkgs.haskellPackages
else pkgs.haskell.packages.${compiler};

variant = if doBenchmark then pkgs.haskell.lib.doBenchmark else pkgs.lib.id;

drv = variant (haskellPackages.callPackage f {});

in

if pkgs.lib.inNixShell then drv.env else drv
29 changes: 20 additions & 9 deletions src/Control/Monad/Bayes/Sampler/Strict.hs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,19 @@ import Control.Monad.Reader (MonadIO, ReaderT (..))
import Control.Monad.ST (ST)
import Numeric.Log (Log (ln))
import System.Random.MWC.Distributions qualified as MWC
import Data.Random qualified as RF
import Data.Random.Distribution qualified as RF
import Data.Random.Distribution.Normal qualified as RF
import Data.Random.Distribution.Gamma qualified as RF
import Data.Random.Distribution.Beta qualified as RF
import Data.Random.Distribution.Bernoulli qualified as RF
import Data.Random.Distribution.Uniform as RF
import System.Random.Stateful (IOGenM (..), STGenM, StatefulGen, StdGen, initStdGen, mkStdGen, newIOGenM, newSTGenM, uniformDouble01M, uniformRM)

import Control.Monad.State
import System.Random.Stateful
import Control.Monad.Reader.Class

-- | The sampling interpretation of a probabilistic program
-- Here m is typically IO or ST
newtype SamplerT g m a = SamplerT {runSamplerT :: ReaderT g m a} deriving (Functor, Applicative, Monad, MonadIO)
Expand All @@ -58,17 +69,17 @@ type SamplerIO = SamplerT (IOGenM StdGen) IO
-- to particular pairs of monad and RNG
type SamplerST s = SamplerT (STGenM StdGen s) (ST s)

instance (StatefulGen g m) => MonadDistribution (SamplerT g m) where
random = SamplerT (ReaderT uniformDouble01M)
instance StatefulGen g m => MonadDistribution (SamplerT g m) where
random = SamplerT (ReaderT $ RF.runRVar $ RF.stdUniform)

uniform a b = SamplerT (ReaderT $ uniformRM (a, b))
normal m s = SamplerT (ReaderT (MWC.normal m s))
gamma shape scale = SamplerT (ReaderT $ MWC.gamma shape scale)
beta a b = SamplerT (ReaderT $ MWC.beta a b)
uniform a b = SamplerT (ReaderT $ RF.runRVar $ RF.doubleUniform a b)
normal m s = SamplerT (ReaderT $ RF.runRVar $ RF.normal m s)
gamma shape scale = SamplerT (ReaderT $ RF.runRVar $ RF.gamma shape scale)
beta a b = SamplerT (ReaderT $ RF.runRVar $ RF.beta a b)

bernoulli p = SamplerT (ReaderT $ MWC.bernoulli p)
categorical ps = SamplerT (ReaderT $ MWC.categorical ps)
geometric p = SamplerT (ReaderT $ MWC.geometric0 p)
bernoulli p = SamplerT (ReaderT $ RF.runRVar $ RF.bernoulli p)
-- categorical ps = error "categorical"
-- geometric p = error "geometric"

-- | Sample with a random number generator of your choice e.g. the one
-- from `System.Random`.
Expand Down

0 comments on commit 60a4e3d

Please sign in to comment.