diff --git a/benchmark/Speed.hs b/benchmark/Speed.hs index 0d72808c..b767acef 100644 --- a/benchmark/Speed.hs +++ b/benchmark/Speed.hs @@ -22,6 +22,7 @@ import Criterion.Main ) import Criterion.Types (Config (csvFile, rawDataFile)) import Data.Functor (void) +import Data.Maybe (listToMaybe) import Data.Text qualified as T import HMM qualified import LDA qualified @@ -38,7 +39,7 @@ data Model = LR [(Double, Bool)] | HMM [Double] | LDA [[T.Text]] instance Show Model where show (LR xs) = "LR" ++ show (length xs) show (HMM xs) = "HMM" ++ show (length xs) - show (LDA xs) = "LDA" ++ show (length $ head xs) + show (LDA xs) = "LDA" ++ show (maybe 0 length $ listToMaybe xs) buildModel :: (MonadMeasure m) => Model -> m String buildModel (LR dataset) = show <$> LogReg.logisticRegression dataset diff --git a/flake.nix b/flake.nix index 838884af..bc6b150a 100644 --- a/flake.nix +++ b/flake.nix @@ -85,6 +85,7 @@ "ghc927" "ghc945" "ghc964" + "ghc982" ]; buildForVersion = ghcVersion: (builtins.getAttr ghcVersion pkgs.haskell.packages).developPackage opts; in lib.attrsets.genAttrs ghcs buildForVersion; diff --git a/models/LDA.hs b/models/LDA.hs index 29a9eac3..1589e559 100644 --- a/models/LDA.hs +++ b/models/LDA.hs @@ -81,4 +81,4 @@ syntheticData d w = List.replicateM d (List.replicateM w syntheticWord) runLDA :: IO () runLDA = do s <- sampleIOfixed $ unweighted $ mh 1000 $ lda documents - pPrint (head s) + pPrint $ take 1 s diff --git a/monad-bayes.cabal b/monad-bayes.cabal index 6907cac8..48d6a179 100644 --- a/monad-bayes.cabal +++ b/monad-bayes.cabal @@ -7,7 +7,7 @@ copyright: 2015-2020 Adam Scibior maintainer: dominic.steinitz@tweag.io author: Adam Scibior stability: experimental -tested-with: GHC ==9.0.2 || ==9.2.7 || ==9.4.5 || ==9.6.4 +tested-with: GHC ==9.0.2 || ==9.2.7 || ==9.4.5 || ==9.6.4 || ==9.8.2 homepage: http://github.com/tweag/monad-bayes#readme bug-reports: https://github.com/tweag/monad-bayes/issues synopsis: A library for probabilistic programming. @@ -38,7 +38,7 @@ flag dev common deps build-depends: - , base >=4.15 && <4.19 + , base >=4.15 && <4.20 , brick ^>=2.3.1 , containers >=0.5.10 && <0.7 , foldl ^>=1.4 @@ -62,7 +62,7 @@ common deps , safe ^>=0.3.17 , scientific ^>=0.3 , statistics >=0.14.0 && <0.17 - , text >=1.2 && <2.1 + , text >=1.2 && <2.2 , transformers >=0.5.6 && <0.7 , vector >=0.12.0 && <0.14 , vty ^>=6.1 diff --git a/src/Control/Monad/Bayes/Inference/Lazy/WIS.hs b/src/Control/Monad/Bayes/Inference/Lazy/WIS.hs index 2848f61e..5b494528 100644 --- a/src/Control/Monad/Bayes/Inference/Lazy/WIS.hs +++ b/src/Control/Monad/Bayes/Inference/Lazy/WIS.hs @@ -1,7 +1,9 @@ module Control.Monad.Bayes.Inference.Lazy.WIS where +import Control.Monad (guard) import Control.Monad.Bayes.Sampler.Lazy (SamplerT, weightedSamples) import Control.Monad.Bayes.Weighted (WeightedT) +import Data.Maybe (mapMaybe) import Numeric.Log (Log (Exp)) import System.Random (Random (randoms), getStdGen, newStdGen) @@ -16,7 +18,7 @@ lwis n m = do let max' = snd $ last xws' _ <- newStdGen rs <- randoms <$> getStdGen - return $ fmap (\r -> fst $ head $ filter ((>= Exp (log r) * max') . snd) xws') rs + return $ take 1 =<< fmap (\r -> mapMaybe (\(a, p) -> guard (p >= Exp (log r) * max') >> Just a) xws') rs where accumulate :: (Num t) => [(a, t)] -> t -> [(a, t)] accumulate ((x, w) : xws) a = (x, w + a) : (x, w + a) : accumulate xws (w + a) diff --git a/src/Control/Monad/Bayes/Inference/TUI.hs b/src/Control/Monad/Bayes/Inference/TUI.hs index 97a4aa54..058dea77 100644 --- a/src/Control/Monad/Bayes/Inference/TUI.hs +++ b/src/Control/Monad/Bayes/Inference/TUI.hs @@ -21,6 +21,7 @@ import Control.Monad.Bayes.Sampler.Strict (SamplerIO, sampleIO) import Control.Monad.Bayes.Traced (TracedT) import Control.Monad.Bayes.Traced.Common hiding (burnIn) import Control.Monad.Bayes.Weighted +import Data.Maybe (listToMaybe) import Data.Scientific (FPFormat (Exponent), formatScientific, fromFloatDigits) import Data.Text qualified as T import Data.Text.Lazy qualified as TL @@ -70,7 +71,7 @@ drawUI handleSamples state = [ui] ] ) $ B.progressBar - (Just $ "Mean likelihood for last 1000 samples: " <> take 10 (show (head $ lk state <> [0]))) + (Just $ "Mean likelihood for last 1000 samples: " <> take 10 (maybe "(error)" show (listToMaybe $ lk state <> [0]))) (double2Float (Fold.fold Fold.mean $ take 1000 $ lk state) / double2Float (maximum $ 0 : lk state)) displayStep c = Just $ "Step " <> show c @@ -108,7 +109,7 @@ showEmpirical = . toEmpirical showVal :: (Show a) => [a] -> Widget n -showVal = txt . T.pack . (\case [] -> ""; a -> show $ head a) +showVal = txt . T.pack . (\case [] -> ""; a -> maybe "(error)" show $ listToMaybe a) -- | handler for events received by the TUI appEvent :: B.BrickEvent n s -> B.EventM n s ()