From 034c3e8cb4cf30d25d56122e0ecb72865afa6644 Mon Sep 17 00:00:00 2001 From: Sirui Lu Date: Wed, 8 Nov 2023 01:28:12 -0800 Subject: [PATCH] :sparkles: add SafeSymRotate class --- grisette.cabal | 2 + src/Grisette/Core/Data/Class/SafeSymRotate.hs | 110 +++++++ .../Core/Data/Class/SafeSymRotateTests.hs | 301 ++++++++++++++++++ test/Main.hs | 2 + 4 files changed, 415 insertions(+) create mode 100644 src/Grisette/Core/Data/Class/SafeSymRotate.hs create mode 100644 test/Grisette/Core/Data/Class/SafeSymRotateTests.hs diff --git a/grisette.cabal b/grisette.cabal index 1e9cf443..22a85974 100644 --- a/grisette.cabal +++ b/grisette.cabal @@ -72,6 +72,7 @@ library Grisette.Core.Data.Class.ModelOps Grisette.Core.Data.Class.SafeDivision Grisette.Core.Data.Class.SafeLinearArith + Grisette.Core.Data.Class.SafeSymRotate Grisette.Core.Data.Class.SafeSymShift Grisette.Core.Data.Class.SEq Grisette.Core.Data.Class.SignConversion @@ -219,6 +220,7 @@ test-suite spec Grisette.Core.Data.Class.GenSymTests Grisette.Core.Data.Class.GPrettyTests Grisette.Core.Data.Class.MergeableTests + Grisette.Core.Data.Class.SafeSymRotateTests Grisette.Core.Data.Class.SafeSymShiftTests Grisette.Core.Data.Class.SEqTests Grisette.Core.Data.Class.SimpleMergeableTests diff --git a/src/Grisette/Core/Data/Class/SafeSymRotate.hs b/src/Grisette/Core/Data/Class/SafeSymRotate.hs new file mode 100644 index 00000000..c0bb9e8d --- /dev/null +++ b/src/Grisette/Core/Data/Class/SafeSymRotate.hs @@ -0,0 +1,110 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module Grisette.Core.Data.Class.SafeSymRotate (SafeSymRotate (..)) where + +import Control.Exception (ArithException (Overflow)) +import Control.Monad.Error.Class (MonadError) +import Data.Bits (Bits (rotateL, rotateR), FiniteBits (finiteBitSize)) +import Data.Int (Int16, Int32, Int64, Int8) +import Data.Word (Word16, Word32, Word64, Word8) +import GHC.TypeLits (KnownNat, type (<=)) +import Grisette.Core.Control.Monad.Union (MonadUnion) +import Grisette.Core.Data.BV (IntN, WordN) +import Grisette.Core.Data.Class.Mergeable (Mergeable) +import Grisette.Core.Data.Class.SOrd (SOrd ((<~))) +import Grisette.Core.Data.Class.SimpleMergeable (UnionLike, mrgIf) +import Grisette.Core.Data.Class.SymRotate (SymRotate) +import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits + ( pevalRotateLeftTerm, + pevalRotateRightTerm, + ) +import Grisette.IR.SymPrim.Data.SymPrim + ( SymIntN (SymIntN), + SymWordN (SymWordN), + ) +import Grisette.Lib.Control.Monad (mrgReturn) +import Grisette.Lib.Control.Monad.Except (mrgThrowError) + +class (SymRotate a) => SafeSymRotate e a | a -> e where + safeSymRotateL :: (MonadError e m, UnionLike m) => a -> a -> m a + safeSymRotateL = safeSymRotateL' id + safeSymRotateR :: (MonadError e m, UnionLike m) => a -> a -> m a + safeSymRotateR = safeSymRotateR' id + safeSymRotateL' :: + (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a + safeSymRotateR' :: + (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a + {-# MINIMAL safeSymRotateL', safeSymRotateR' #-} + +-- | This function handles the case when the shift amount is out the range of +-- `Int` correctly. +safeSymRotateLConcreteNum :: + (MonadError e m, MonadUnion m, Integral a, FiniteBits a, Mergeable a) => + e -> + a -> + a -> + m a +safeSymRotateLConcreteNum e _ s | s < 0 = mrgThrowError e +safeSymRotateLConcreteNum _ a s = + mrgReturn $ rotateL a (fromIntegral $ s `rem` fromIntegral (finiteBitSize s)) + +-- | This function handles the case when the shift amount is out the range of +-- `Int` correctly. +safeSymRotateRConcreteNum :: + (MonadError e m, MonadUnion m, Integral a, FiniteBits a, Mergeable a) => + e -> + a -> + a -> + m a +safeSymRotateRConcreteNum e _ s | s < 0 = mrgThrowError e +safeSymRotateRConcreteNum _ a s = + mrgReturn $ rotateR a (fromIntegral $ s `rem` fromIntegral (finiteBitSize s)) + +#define SAFE_SYM_ROTATE_CONCRETE(T) \ + instance SafeSymRotate ArithException T where \ + safeSymRotateL' f = safeSymRotateLConcreteNum (f Overflow); \ + safeSymRotateR' f = safeSymRotateRConcreteNum (f Overflow) \ + +#if 1 +SAFE_SYM_ROTATE_CONCRETE(Word8) +SAFE_SYM_ROTATE_CONCRETE(Word16) +SAFE_SYM_ROTATE_CONCRETE(Word32) +SAFE_SYM_ROTATE_CONCRETE(Word64) +SAFE_SYM_ROTATE_CONCRETE(Word) +SAFE_SYM_ROTATE_CONCRETE(Int8) +SAFE_SYM_ROTATE_CONCRETE(Int16) +SAFE_SYM_ROTATE_CONCRETE(Int32) +SAFE_SYM_ROTATE_CONCRETE(Int64) +SAFE_SYM_ROTATE_CONCRETE(Int) +#endif + +instance (KnownNat n, 1 <= n) => SafeSymRotate ArithException (WordN n) where + safeSymRotateL' f = safeSymRotateLConcreteNum (f Overflow) + safeSymRotateR' f = safeSymRotateRConcreteNum (f Overflow) + +instance (KnownNat n, 1 <= n) => SafeSymRotate ArithException (IntN n) where + safeSymRotateL' f = safeSymRotateLConcreteNum (f Overflow) + safeSymRotateR' f = safeSymRotateRConcreteNum (f Overflow) + +instance (KnownNat n, 1 <= n) => SafeSymRotate ArithException (SymWordN n) where + safeSymRotateL' _ (SymWordN ta) (SymWordN tr) = + mrgReturn $ SymWordN $ pevalRotateLeftTerm ta tr + safeSymRotateR' _ (SymWordN ta) (SymWordN tr) = + mrgReturn $ SymWordN $ pevalRotateRightTerm ta tr + +instance (KnownNat n, 1 <= n) => SafeSymRotate ArithException (SymIntN n) where + safeSymRotateL' f (SymIntN ta) r@(SymIntN tr) = + mrgIf + (r <~ 0) + (mrgThrowError $ f Overflow) + (mrgReturn $ SymIntN $ pevalRotateLeftTerm ta tr) + safeSymRotateR' f (SymIntN ta) r@(SymIntN tr) = + mrgIf + (r <~ 0) + (mrgThrowError $ f Overflow) + (mrgReturn $ SymIntN $ pevalRotateRightTerm ta tr) diff --git a/test/Grisette/Core/Data/Class/SafeSymRotateTests.hs b/test/Grisette/Core/Data/Class/SafeSymRotateTests.hs new file mode 100644 index 00000000..8d2488fd --- /dev/null +++ b/test/Grisette/Core/Data/Class/SafeSymRotateTests.hs @@ -0,0 +1,301 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module Grisette.Core.Data.Class.SafeSymRotateTests + ( safeSymRotateTests, + ) +where + +import Control.Exception (ArithException (Overflow)) +import Control.Monad.Except (ExceptT) +import Data.Bits (Bits (rotateL, rotateR), FiniteBits (finiteBitSize)) +import Data.Int (Int16, Int32, Int64, Int8) +import Data.Typeable (Proxy (Proxy), Typeable) +import Data.Word (Word16, Word32, Word64, Word8) +import Grisette.Core.Control.Monad.UnionM (UnionM) +import Grisette.Core.Data.BV (IntN, WordN) +import Grisette.Core.Data.Class.Mergeable (Mergeable) +import Grisette.Core.Data.Class.SafeSymRotate + ( SafeSymRotate (safeSymRotateL, safeSymRotateR), + ) +import Grisette.Core.Data.Class.Solvable (Solvable (con)) +import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term (LinkedRep) +import Grisette.IR.SymPrim.Data.SymPrim (SymIntN, SymWordN) +import Grisette.Lib.Control.Monad (mrgReturn) +import Grisette.Lib.Control.Monad.Except (mrgThrowError) +import Test.Framework (Test, testGroup) +import Test.Framework.Providers.HUnit (testCase) +import Test.Framework.Providers.QuickCheck2 (testProperty) +import Test.HUnit ((@?=)) +import Test.QuickCheck (Arbitrary, ioProperty) +import Test.QuickCheck.Gen (chooseInt) +import Test.QuickCheck.Property (forAll) + +type EM a = ExceptT ArithException UnionM a + +overflowError :: (Mergeable a) => EM a +overflowError = mrgThrowError Overflow + +concreteTypeSafeSymRotateTests :: + forall proxy a. + ( Arbitrary a, + Show a, + Num a, + Eq a, + SafeSymRotate ArithException a, + FiniteBits a, + Bounded a, + Typeable a, + Integral a, + Mergeable a + ) => + proxy a -> + [Test] +concreteTypeSafeSymRotateTests _ = + [ testProperty "In bound" $ \(x :: a) -> do + let b = fromIntegral (maxBound :: a) :: Integer + let bs = 2 * fromIntegral (finiteBitSize x) :: Integer + let maxRotateAmount = fromIntegral (min b bs) + forAll (chooseInt (0, maxRotateAmount)) $ + \(s :: Int) -> + ioProperty $ do + let rotateAmount = fromIntegral s + let rotateLExpected = mrgReturn (rotateL x s) :: EM a + let rotateRExpected = mrgReturn (rotateR x s) :: EM a + safeSymRotateL x rotateAmount @?= rotateLExpected + safeSymRotateR x rotateAmount @?= rotateRExpected + ] + +concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests :: + forall proxy a. + ( Arbitrary a, + Show a, + Num a, + Eq a, + SafeSymRotate ArithException a, + FiniteBits a, + Bounded a, + Typeable a, + Integral a, + Mergeable a + ) => + proxy a -> + [Test] +concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests p = + testCase + "Min bound" + ( do + let x = -1 :: a + let rotateAmount = minBound :: a + safeSymRotateL x rotateAmount @?= overflowError + safeSymRotateR x rotateAmount @?= overflowError + ) + : concreteTypeSafeSymRotateTests p + +concreteUnsignedSymTypeSafeSymRotateTests :: + forall proxy c s. + ( Arbitrary c, + Show s, + Num s, + Eq s, + SafeSymRotate ArithException s, + FiniteBits c, + FiniteBits s, + Bounded c, + Typeable s, + Integral c, + LinkedRep c s, + Solvable c s, + Mergeable s + ) => + proxy s -> + [Test] +concreteUnsignedSymTypeSafeSymRotateTests _ = + [ testProperty "In bound" $ \(x :: c) -> do + let b = fromIntegral (maxBound :: c) :: Integer + let bs = 2 * fromIntegral (finiteBitSize x) :: Integer + let maxRotateAmount = fromIntegral (min b bs) + forAll (chooseInt (0, maxRotateAmount)) $ + \(s :: Int) -> + ioProperty $ do + let rotateAmount = fromIntegral s + let rotateLExpected = mrgReturn (con (rotateL x s)) :: EM s + let rotateRExpected = mrgReturn (con (rotateR x s)) :: EM s + safeSymRotateL (con x) rotateAmount @?= rotateLExpected + safeSymRotateR (con x) rotateAmount @?= rotateRExpected + ] + +concreteSignedAtLeastThreeBitsSymTypeSafeSymRotateTests :: + forall proxy c s. + ( Arbitrary c, + Show s, + Num s, + Eq s, + SafeSymRotate ArithException s, + FiniteBits c, + FiniteBits s, + Bounded c, + Typeable s, + Integral c, + LinkedRep c s, + Solvable c s, + Mergeable s + ) => + proxy s -> + [Test] +concreteSignedAtLeastThreeBitsSymTypeSafeSymRotateTests p = + testCase + "Min bound" + ( do + let x = con (-1 :: c) + let rotateAmount = con (minBound :: c) + safeSymRotateL x rotateAmount @?= (overflowError :: EM s) + safeSymRotateR x rotateAmount @?= overflowError + ) + : concreteUnsignedSymTypeSafeSymRotateTests p + +safeSymRotateTests :: Test +safeSymRotateTests = + testGroup + "SafeSymRotate" + [ testGroup "Word8" $ concreteTypeSafeSymRotateTests (Proxy @Word8), + testGroup "Word16" $ concreteTypeSafeSymRotateTests (Proxy @Word16), + testGroup "Word32" $ concreteTypeSafeSymRotateTests (Proxy @Word32), + testGroup "Word64" $ concreteTypeSafeSymRotateTests (Proxy @Word64), + testGroup "Word" $ concreteTypeSafeSymRotateTests (Proxy @Word), + testGroup "WordN 1" $ concreteTypeSafeSymRotateTests (Proxy @(WordN 1)), + testGroup "WordN 2" $ concreteTypeSafeSymRotateTests (Proxy @(WordN 2)), + testGroup "WordN 3" $ concreteTypeSafeSymRotateTests (Proxy @(WordN 3)), + testGroup "WordN 63" $ concreteTypeSafeSymRotateTests (Proxy @(WordN 63)), + testGroup "WordN 64" $ concreteTypeSafeSymRotateTests (Proxy @(WordN 64)), + testGroup "WordN 65" $ concreteTypeSafeSymRotateTests (Proxy @(WordN 65)), + testGroup "WordN 128" $ + concreteTypeSafeSymRotateTests (Proxy @(WordN 128)), + testGroup "SymWordN 1" $ + concreteUnsignedSymTypeSafeSymRotateTests (Proxy @(SymWordN 1)), + testGroup "SymWordN 2" $ + concreteUnsignedSymTypeSafeSymRotateTests (Proxy @(SymWordN 2)), + testGroup "SymWordN 3" $ + concreteUnsignedSymTypeSafeSymRotateTests (Proxy @(SymWordN 3)), + testGroup "SymWordN 63" $ + concreteUnsignedSymTypeSafeSymRotateTests (Proxy @(SymWordN 63)), + testGroup "SymWordN 64" $ + concreteUnsignedSymTypeSafeSymRotateTests (Proxy @(SymWordN 64)), + testGroup "SymWordN 65" $ + concreteUnsignedSymTypeSafeSymRotateTests (Proxy @(SymWordN 65)), + testGroup "SymWordN 128" $ + concreteUnsignedSymTypeSafeSymRotateTests (Proxy @(SymWordN 128)), + testGroup "Int8" $ + concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests (Proxy @Int8), + testGroup "Int16" $ + concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests (Proxy @Int16), + testGroup "Int32" $ + concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests (Proxy @Int32), + testGroup "Int64" $ + concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests (Proxy @Int64), + testGroup "Int" $ + concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests (Proxy @Int), + testGroup + "IntN 1" + [ testGroup + "SafeSymRotate" + [ testGroup + "rotate left" + [ testCase "By 0" $ do + safeSymRotateL (-1) 0 @?= (mrgReturn $ -1 :: EM (IntN 1)) + safeSymRotateR (-1) 0 @?= (mrgReturn $ -1 :: EM (IntN 1)), + testCase "By -1" $ do + safeSymRotateL (-1) (-1 :: IntN 1) @?= overflowError + safeSymRotateR (-1) (-1 :: IntN 1) @?= overflowError + ] + ] + ], + testGroup + "IntN 2" + [ testGroup + "SafeSymRotate" + [ testGroup + "rotate left" + [ testCase "By 0" $ do + safeSymRotateL (-2) 0 @?= (mrgReturn $ -2 :: EM (IntN 2)) + safeSymRotateR (-2) 0 @?= (mrgReturn $ -2 :: EM (IntN 2)), + testCase "By 1" $ do + safeSymRotateL (-2) 1 @?= (mrgReturn 1 :: EM (IntN 2)) + safeSymRotateR (-2) 1 @?= (mrgReturn 1 :: EM (IntN 2)), + testCase "By -1" $ do + safeSymRotateL (-1) (-1 :: IntN 2) @?= overflowError + safeSymRotateR (-1) (-1 :: IntN 2) @?= overflowError, + testCase "By -2" $ do + safeSymRotateL (-1) (-2 :: IntN 2) @?= overflowError + safeSymRotateR (-1) (-2 :: IntN 2) @?= overflowError + ] + ] + ], + testGroup "IntN 3" $ + concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests (Proxy @(IntN 3)), + testGroup "IntN 63" $ + concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests (Proxy @(IntN 63)), + testGroup "IntN 64" $ + concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests (Proxy @(IntN 64)), + testGroup "IntN 65" $ + concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests (Proxy @(IntN 65)), + testGroup "IntN 128" $ + concreteSignedAtLeastThreeBitsTypeSafeSymRotateTests + (Proxy @(IntN 128)), + testGroup "SymIntN 3" $ + concreteSignedAtLeastThreeBitsSymTypeSafeSymRotateTests + (Proxy @(SymIntN 3)), + testGroup "SymIntN 63" $ + concreteSignedAtLeastThreeBitsSymTypeSafeSymRotateTests + (Proxy @(SymIntN 63)), + testGroup "SymIntN 64" $ + concreteSignedAtLeastThreeBitsSymTypeSafeSymRotateTests + (Proxy @(SymIntN 64)), + testGroup "SymIntN 65" $ + concreteSignedAtLeastThreeBitsSymTypeSafeSymRotateTests + (Proxy @(SymIntN 65)), + testGroup "SymIntN 128" $ + concreteSignedAtLeastThreeBitsSymTypeSafeSymRotateTests + (Proxy @(SymIntN 128)), + testGroup + "SymIntN 1" + [ testGroup + "SafeSymRotate" + [ testGroup + "rotate left" + [ testCase "By 0" $ do + safeSymRotateL (-1) 0 @?= (mrgReturn $ -1 :: EM (SymIntN 1)) + safeSymRotateR (-1) 0 + @?= (mrgReturn $ -1 :: EM (SymIntN 1)), + testCase "By -1" $ do + safeSymRotateL (-1) (-1 :: SymIntN 1) @?= overflowError + safeSymRotateR (-1) (-1 :: SymIntN 1) @?= overflowError + ] + ] + ], + testGroup + "SymIntN 2" + [ testGroup + "SafeSymRotate" + [ testGroup + "rotate left" + [ testCase "By 0" $ do + safeSymRotateL (-2) 0 @?= (mrgReturn $ -2 :: EM (SymIntN 2)) + safeSymRotateR (-2) 0 + @?= (mrgReturn $ -2 :: EM (SymIntN 2)), + testCase "By 1" $ do + safeSymRotateL (-2) 1 @?= (mrgReturn 1 :: EM (IntN 2)) + safeSymRotateR (-2) 1 @?= (mrgReturn 1 :: EM (IntN 2)), + testCase "By -1" $ do + safeSymRotateL (-1) (-1 :: SymIntN 2) @?= overflowError + safeSymRotateR (-1) (-1 :: SymIntN 2) @?= overflowError, + testCase "By -2" $ do + safeSymRotateL (-1) (-2 :: SymIntN 2) @?= overflowError + safeSymRotateR (-1) (-2 :: SymIntN 2) @?= overflowError + ] + ] + ] + ] diff --git a/test/Main.hs b/test/Main.hs index 2c5b81a0..91b6b7b7 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -19,6 +19,7 @@ import Grisette.Core.Data.Class.GenSymTests (genSymTests) import Grisette.Core.Data.Class.MergeableTests (mergeableTests) import Grisette.Core.Data.Class.SEqTests (seqTests) import Grisette.Core.Data.Class.SOrdTests (sordTests) +import Grisette.Core.Data.Class.SafeSymRotateTests (safeSymRotateTests) import Grisette.Core.Data.Class.SafeSymShiftTests (safeSymShiftTests) import Grisette.Core.Data.Class.SimpleMergeableTests (simpleMergeableTests) import Grisette.Core.Data.Class.SubstituteSymTests (substituteSymTests) @@ -91,6 +92,7 @@ coreTests = gprettyTests, mergeableTests, safeSymShiftTests, + safeSymRotateTests, seqTests, sordTests, simpleMergeableTests,