From a5732955337bb092cfe2da1c82ed2378c9433ea7 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Sat, 6 Jan 2024 09:22:22 -0800 Subject: [PATCH] :sparkles: Add liftFresh --- CHANGELOG.md | 2 + src/Grisette/Core.hs | 4 ++ src/Grisette/Core/Data/Class/GenSym.hs | 71 ++++++++++++++------ test/Grisette/Core/Data/Class/GenSymTests.hs | 20 +++++- 4 files changed, 76 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd3c4b41..f63b8cdf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `Grisette.Data.Class.SignConversion.SignConversion` for types from `Data.Int` and `Data.Word`. ([#142](https://github.com/lsrcz/grisette/pull/142)) - Added shift functions by symbolic shift amounts. ([#151](https://github.com/lsrcz/grisette/pull/151)) - Added `apply` for uninterpreted functions. ([#155](https://github.com/lsrcz/grisette/pull/155)) +- Added `liftFresh` to lift a `Fresh` into `MonadFresh`. ([#156](https://github.com/lsrcz/grisette/pull/156)) ### Removed @@ -45,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - [Breaking] Moved `Grisette.Data.Class.Evaluate` to `Grisette.Data.Class.EvaluateSym`. ([#146](https://github.com/lsrcz/grisette/pull/146)) - [Breaking] Moved `Grisette.Data.Class.Substitute` to `Grisette.Data.Class.SubstituteSym`. ([#146](https://github.com/lsrcz/grisette/pull/146)) - [Breaking] Split the `Grisette.Data.Class.SafeArith` module to `Grisette.Data.Class.SafeDivision` and `Grisette.Data.Class.SafeLinearArith`. ([#146](https://github.com/lsrcz/grisette/pull/146)) +- [Breaking] Changed the API to `MonadFresh`. ([#156](https://github.com/lsrcz/grisette/pull/156)) ## [0.3.1.1] -- 2023-09-29 diff --git a/src/Grisette/Core.hs b/src/Grisette/Core.hs index ef5a04e6..6ccc9a05 100644 --- a/src/Grisette/Core.hs +++ b/src/Grisette/Core.hs @@ -711,6 +711,8 @@ module Grisette.Core -- ** Symbolic Generation Monad MonadFresh (..), + nextFreshIndex, + liftFresh, Fresh, FreshT (..), runFresh, @@ -1109,9 +1111,11 @@ import Grisette.Core.Data.Class.GenSym derivedSameShapeSimpleFresh, genSym, genSymSimple, + liftFresh, mrgRunFreshT, name, nameWithInfo, + nextFreshIndex, runFresh, runFreshT, ) diff --git a/src/Grisette/Core/Data/Class/GenSym.hs b/src/Grisette/Core/Data/Class/GenSym.hs index c4a5bf2e..d409ae69 100644 --- a/src/Grisette/Core/Data/Class/GenSym.hs +++ b/src/Grisette/Core/Data/Class/GenSym.hs @@ -36,6 +36,8 @@ module Grisette.Core.Data.Class.GenSym -- * Monad for fresh symbolic value generation MonadFresh (..), + nextFreshIndex, + liftFresh, FreshT (FreshT, runFreshTFromIndex), Fresh, runFreshT, @@ -83,7 +85,7 @@ import Control.Monad.RWS.Class ) import qualified Control.Monad.RWS.Lazy as RWSLazy import qualified Control.Monad.RWS.Strict as RWSStrict -import Control.Monad.Reader (ReaderT (ReaderT)) +import Control.Monad.Reader (ReaderT) import Control.Monad.Signatures (Catch) import qualified Control.Monad.State.Lazy as StateLazy import qualified Control.Monad.State.Strict as StateStrict @@ -279,12 +281,32 @@ nameWithInfo = FreshIdentWithInfo -- The monad should be a reader monad for the 'FreshIdent' and a state monad for -- the 'FreshIndex'. class (Monad m) => MonadFresh m where - -- | Increase the index by one and return the new index. - nextFreshIndex :: m FreshIndex + -- | Get the current index for fresh variable generation. + getFreshIndex :: m FreshIndex + + -- | Set the current index for fresh variable generation. + setFreshIndex :: FreshIndex -> m () -- | Get the identifier. getFreshIdent :: m FreshIdent +-- | Get the next fresh index and increase the current index. +nextFreshIndex :: (MonadFresh m) => m FreshIndex +nextFreshIndex = do + curr <- getFreshIndex + let new = curr + 1 + setFreshIndex new + return curr + +-- | Lifts an @`Fresh` a@ into any `MonadFresh`. +liftFresh :: (MonadFresh m) => Fresh a -> m a +liftFresh (FreshT f) = do + index <- nextFreshIndex + ident <- getFreshIdent + let (a, newIdx) = runIdentity $ f ident index + setFreshIndex newIdx + return a + -- | A symbolic generation monad transformer. -- It is a reader monad transformer for identifiers and -- a state monad transformer for indices. @@ -385,36 +407,44 @@ instance (MonadReader r m) => MonadReader r (FreshT m) where instance (MonadRWS r w s m) => MonadRWS r w s (FreshT m) instance (MonadFresh m) => MonadFresh (ExceptT e m) where - nextFreshIndex = ExceptT $ Right <$> nextFreshIndex - getFreshIdent = ExceptT $ Right <$> getFreshIdent + getFreshIndex = lift getFreshIndex + setFreshIndex newIdx = lift $ setFreshIndex newIdx + getFreshIdent = lift getFreshIdent instance (MonadFresh m, Monoid w) => MonadFresh (WriterLazy.WriterT w m) where - nextFreshIndex = WriterLazy.WriterT $ (,mempty) <$> nextFreshIndex - getFreshIdent = WriterLazy.WriterT $ (,mempty) <$> getFreshIdent + getFreshIndex = lift getFreshIndex + setFreshIndex newIdx = lift $ setFreshIndex newIdx + getFreshIdent = lift getFreshIdent instance (MonadFresh m, Monoid w) => MonadFresh (WriterStrict.WriterT w m) where - nextFreshIndex = WriterStrict.WriterT $ (,mempty) <$> nextFreshIndex - getFreshIdent = WriterStrict.WriterT $ (,mempty) <$> getFreshIdent + getFreshIndex = lift getFreshIndex + setFreshIndex newIdx = lift $ setFreshIndex newIdx + getFreshIdent = lift getFreshIdent instance (MonadFresh m) => MonadFresh (StateLazy.StateT s m) where - nextFreshIndex = StateLazy.StateT $ \s -> (,s) <$> nextFreshIndex - getFreshIdent = StateLazy.StateT $ \s -> (,s) <$> getFreshIdent + getFreshIndex = lift getFreshIndex + setFreshIndex newIdx = lift $ setFreshIndex newIdx + getFreshIdent = lift getFreshIdent instance (MonadFresh m) => MonadFresh (StateStrict.StateT s m) where - nextFreshIndex = StateStrict.StateT $ \s -> (,s) <$> nextFreshIndex - getFreshIdent = StateStrict.StateT $ \s -> (,s) <$> getFreshIdent + getFreshIndex = lift getFreshIndex + setFreshIndex newIdx = lift $ setFreshIndex newIdx + getFreshIdent = lift getFreshIdent instance (MonadFresh m) => MonadFresh (ReaderT r m) where - nextFreshIndex = ReaderT $ const nextFreshIndex - getFreshIdent = ReaderT $ const getFreshIdent + getFreshIndex = lift getFreshIndex + setFreshIndex newIdx = lift $ setFreshIndex newIdx + getFreshIdent = lift getFreshIdent instance (MonadFresh m, Monoid w) => MonadFresh (RWSLazy.RWST r w s m) where - nextFreshIndex = RWSLazy.RWST $ \_ s -> (,s,mempty) <$> nextFreshIndex - getFreshIdent = RWSLazy.RWST $ \_ s -> (,s,mempty) <$> getFreshIdent + getFreshIndex = lift getFreshIndex + setFreshIndex newIdx = lift $ setFreshIndex newIdx + getFreshIdent = lift getFreshIdent instance (MonadFresh m, Monoid w) => MonadFresh (RWSStrict.RWST r w s m) where - nextFreshIndex = RWSStrict.RWST $ \_ s -> (,s,mempty) <$> nextFreshIndex - getFreshIdent = RWSStrict.RWST $ \_ s -> (,s,mempty) <$> getFreshIdent + getFreshIndex = lift getFreshIndex + setFreshIndex newIdx = lift $ setFreshIndex newIdx + getFreshIdent = lift getFreshIdent -- | 'FreshT' specialized with Identity. type Fresh = FreshT Identity @@ -424,7 +454,8 @@ runFresh :: Fresh a -> FreshIdent -> a runFresh m ident = runIdentity $ runFreshT m ident instance (Monad m) => MonadFresh (FreshT m) where - nextFreshIndex = FreshT $ \_ idx -> return (idx, idx + 1) + getFreshIndex = FreshT $ \_ idx -> return (idx, idx) + setFreshIndex newIdx = FreshT $ \_ _ -> return ((), newIdx) getFreshIdent = FreshT $ curry return -- | Class of types in which symbolic values can be generated with respect to some specification. diff --git a/test/Grisette/Core/Data/Class/GenSymTests.hs b/test/Grisette/Core/Data/Class/GenSymTests.hs index 22880299..2e10aeb1 100644 --- a/test/Grisette/Core/Data/Class/GenSymTests.hs +++ b/test/Grisette/Core/Data/Class/GenSymTests.hs @@ -9,6 +9,9 @@ import Grisette.Core.Control.Monad.UnionM (UnionM) import Grisette.Core.Data.Class.GenSym ( EnumGenBound (EnumGenBound), EnumGenUpperBound (EnumGenUpperBound), + Fresh, + FreshT, + GenSymSimple (simpleFresh), ListSpec (ListSpec), SimpleListSpec (SimpleListSpec), choose, @@ -19,7 +22,9 @@ import Grisette.Core.Data.Class.GenSym chooseUnionFresh, genSym, genSymSimple, + liftFresh, runFresh, + runFreshT, ) import Grisette.Core.Data.Class.ITEOp (ITEOp (ites)) import Grisette.Core.Data.Class.SimpleMergeable @@ -1242,6 +1247,19 @@ genSymTests = (isymBool "a" 1) (mrgIf (ssymBool "x") 2 3) (mrgIf (ssymBool "x") 3 4) - ) + ), + testCase "liftFresh" $ do + let orig = simpleFresh () :: Fresh (SymBool, SymBool) + let actual = flip runFreshT "a" $ do + r1 <- liftFresh orig + r2 <- liftFresh orig + return (r1, r2) :: + FreshT UnionM ((SymBool, SymBool), (SymBool, SymBool)) + let expected = + return + ( (isymBool "a" 0, isymBool "a" 1), + (isymBool "a" 2, isymBool "a" 3) + ) + actual @?= expected ] ]