diff --git a/benchmark/Speed.hs b/benchmark/Speed.hs index 0d72808c..ac9fa0d5 100644 --- a/benchmark/Speed.hs +++ b/benchmark/Speed.hs @@ -4,7 +4,8 @@ module Main (main) where -import Control.Monad.Bayes.Class (MonadMeasure) +import Control.Monad (replicateM) +import Control.Monad.Bayes.Class (MonadMeasure, normal) import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (MCMCConfig, numBurnIn, numMCMCSteps, proposal), Proposal (SingleSiteMH)) import Control.Monad.Bayes.Inference.RMSMC (rmsmcDynamic) import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smc) @@ -136,12 +137,21 @@ samplesBenchmarks lrData hmmData ldaData = benchmarks m <- models return (s, m, a) +normalBenchmarks :: [Benchmark] +normalBenchmarks = [ bench "Normal single sample monad bayes" $ nfIO $ do + sampleIOfixed (do xs <- replicateM 1000 $ normal 0.0 1.0 + return $ sum xs) + ] + speedLengthCSV :: FilePath speedLengthCSV = "speed-length.csv" speedSamplesCSV :: FilePath speedSamplesCSV = "speed-samples.csv" +normalSamplesCSV :: FilePath +normalSamplesCSV = "normal-samples.csv" + rawDAT :: FilePath rawDAT = "raw.dat" @@ -158,12 +168,14 @@ removeIfExists file = do main :: IO () main = do - cleanupLastRun - lrData <- sampleIOfixed (LogReg.syntheticData 1000) - hmmData <- sampleIOfixed (HMM.syntheticData 1000) - ldaData <- sampleIOfixed (LDA.syntheticData 5 1000) - let configLength = defaultConfig {csvFile = Just speedLengthCSV, rawDataFile = Just rawDAT} - defaultMainWith configLength (lengthBenchmarks lrData hmmData ldaData) - let configSamples = defaultConfig {csvFile = Just speedSamplesCSV, rawDataFile = Just rawDAT} - defaultMainWith configSamples (samplesBenchmarks lrData hmmData ldaData) - void $ runProcess "python plots.py" + -- cleanupLastRun + -- lrData <- sampleIOfixed (LogReg.syntheticData 1000) + -- hmmData <- sampleIOfixed (HMM.syntheticData 1000) + -- ldaData <- sampleIOfixed (LDA.syntheticData 5 1000) + -- let configLength = defaultConfig {csvFile = Just speedLengthCSV, rawDataFile = Just rawDAT} + -- defaultMainWith configLength (lengthBenchmarks lrData hmmData ldaData) + -- let configSamples = defaultConfig {csvFile = Just speedSamplesCSV, rawDataFile = Just rawDAT} + -- defaultMainWith configSamples (samplesBenchmarks lrData hmmData ldaData) + let configNormal = defaultConfig {csvFile = Just normalSamplesCSV, rawDataFile = Just rawDAT} + defaultMainWith configNormal normalBenchmarks + -- void $ runProcess "python plots.py" diff --git a/monad-bayes.cabal b/monad-bayes.cabal index c576c005..4fed4d5c 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -57,6 +57,7 @@ common deps , pretty-simple ^>=4.1 , primitive >=0.7 && <0.9 , random ^>=1.2 + , random-fu , safe ^>=0.3.17 , scientific ^>=0.3 , statistics >=0.14.0 && <0.17 diff --git a/shell.nix b/shell.nix index e6d91731..c19e7938 100644 --- a/shell.nix +++ b/shell.nix @@ -1,14 +1,67 @@ -( - import - ( - let - lock = builtins.fromJSON (builtins.readFile ./flake.lock); - in - fetchTarball { - url = "https://github.com/edolstra/flake-compat/archive/${lock.nodes.flake-compat.locked.rev}.tar.gz"; - sha256 = lock.nodes.flake-compat.locked.narHash; - } - ) - {src = ./.;} -) -.shellNix +{ nixpkgs ? import {}, compiler ? "default", doBenchmark ? false }: + +let + + inherit (nixpkgs) pkgs; + + f = { mkDerivation, abstract-par, base, brick, containers + , criterion, directory, foldl, free, histogram-fill, hspec, ieee754 + , integration, lens, lib, linear, log-domain, math-functions + , matrix, monad-coroutine, monad-extras, mtl, mwc-random + , optparse-applicative, pipes, pretty-simple, primitive, process + , QuickCheck, random, random-fu, safe, scientific, statistics, text, time + , transformers, typed-process, vector, vty + }: + mkDerivation { + pname = "monad-bayes"; + version = "1.2.0"; + src = ./.; + isLibrary = true; + isExecutable = true; + libraryHaskellDepends = [ + base brick containers foldl free histogram-fill ieee754 integration + lens linear log-domain math-functions matrix monad-coroutine + monad-extras mtl mwc-random pipes pretty-simple primitive random random-fu + safe scientific statistics text vector vty + ]; + executableHaskellDepends = [ + abstract-par base brick containers criterion directory foldl free + histogram-fill hspec ieee754 integration lens linear log-domain + math-functions matrix monad-coroutine monad-extras mtl mwc-random + optparse-applicative pipes pretty-simple primitive process + QuickCheck random random-fu safe scientific statistics text time transformers + typed-process vector vty + ]; + testHaskellDepends = [ + abstract-par base brick containers criterion directory foldl free + histogram-fill hspec ieee754 integration lens linear log-domain + math-functions matrix monad-coroutine monad-extras mtl mwc-random + optparse-applicative pipes pretty-simple primitive process + QuickCheck random random-fu safe scientific statistics text time transformers + typed-process vector vty + ]; + benchmarkHaskellDepends = [ + abstract-par base brick containers criterion directory foldl free + histogram-fill hspec ieee754 integration lens linear log-domain + math-functions matrix monad-coroutine monad-extras mtl mwc-random + optparse-applicative pipes pretty-simple primitive process + QuickCheck random random-fu safe scientific statistics text time transformers + typed-process vector vty + ]; + homepage = "http://github.com/tweag/monad-bayes#readme"; + description = "A library for probabilistic programming"; + license = lib.licenses.mit; + mainProgram = "example"; + }; + + haskellPackages = if compiler == "default" + then pkgs.haskellPackages + else pkgs.haskell.packages.${compiler}; + + variant = if doBenchmark then pkgs.haskell.lib.doBenchmark else pkgs.lib.id; + + drv = variant (haskellPackages.callPackage f {}); + +in + + if pkgs.lib.inNixShell then drv.env else drv diff --git a/src/Control/Monad/Bayes/Sampler/Strict.hs b/src/Control/Monad/Bayes/Sampler/Strict.hs index f5f9b645..61dd2457 100644 --- a/src/Control/Monad/Bayes/Sampler/Strict.hs +++ b/src/Control/Monad/Bayes/Sampler/Strict.hs @@ -44,8 +44,19 @@ import Control.Monad.Reader (MonadIO, ReaderT (..)) import Control.Monad.ST (ST) import Numeric.Log (Log (ln)) import System.Random.MWC.Distributions qualified as MWC +import Data.Random qualified as RF +import Data.Random.Distribution qualified as RF +import Data.Random.Distribution.Normal qualified as RF +import Data.Random.Distribution.Gamma qualified as RF +import Data.Random.Distribution.Beta qualified as RF +import Data.Random.Distribution.Bernoulli qualified as RF +import Data.Random.Distribution.Uniform as RF import System.Random.Stateful (IOGenM (..), STGenM, StatefulGen, StdGen, initStdGen, mkStdGen, newIOGenM, newSTGenM, uniformDouble01M, uniformRM) +import Control.Monad.State +import System.Random.Stateful +import Control.Monad.Reader.Class + -- | The sampling interpretation of a probabilistic program -- Here m is typically IO or ST newtype SamplerT g m a = SamplerT {runSamplerT :: ReaderT g m a} deriving (Functor, Applicative, Monad, MonadIO) @@ -58,17 +69,17 @@ type SamplerIO = SamplerT (IOGenM StdGen) IO -- to particular pairs of monad and RNG type SamplerST s = SamplerT (STGenM StdGen s) (ST s) -instance (StatefulGen g m) => MonadDistribution (SamplerT g m) where - random = SamplerT (ReaderT uniformDouble01M) +instance StatefulGen g m => MonadDistribution (SamplerT g m) where + random = SamplerT (ReaderT $ RF.runRVar $ RF.stdUniform) - uniform a b = SamplerT (ReaderT $ uniformRM (a, b)) - normal m s = SamplerT (ReaderT (MWC.normal m s)) - gamma shape scale = SamplerT (ReaderT $ MWC.gamma shape scale) - beta a b = SamplerT (ReaderT $ MWC.beta a b) + uniform a b = SamplerT (ReaderT $ RF.runRVar $ RF.doubleUniform a b) + normal m s = SamplerT (ReaderT $ RF.runRVar $ RF.normal m s) + gamma shape scale = SamplerT (ReaderT $ RF.runRVar $ RF.gamma shape scale) + beta a b = SamplerT (ReaderT $ RF.runRVar $ RF.beta a b) - bernoulli p = SamplerT (ReaderT $ MWC.bernoulli p) - categorical ps = SamplerT (ReaderT $ MWC.categorical ps) - geometric p = SamplerT (ReaderT $ MWC.geometric0 p) + bernoulli p = SamplerT (ReaderT $ RF.runRVar $ RF.bernoulli p) + -- categorical ps = error "categorical" + -- geometric p = error "geometric" -- | Sample with a random number generator of your choice e.g. the one -- from `System.Random`.