Skip to content

Commit

Permalink
🐛 Fix race condition in stableMemo
Browse files Browse the repository at this point in the history
  • Loading branch information
lsrcz committed Sep 7, 2024
1 parent 61e9bfb commit 1ba1356
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions src/Grisette/Internal/Core/Data/MemoUtils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ module Grisette.Internal.Core.Data.MemoUtils
where

import Control.Applicative (Const (Const, getConst))
import Control.Concurrent (MVar, newMVar, putMVar, takeMVar)
import Control.Monad.Fix (fix)
import Data.HashTable.IO (BasicHashTable)
import qualified Data.HashTable.IO as H
Expand Down Expand Up @@ -70,12 +71,15 @@ instance Ref Strong where
deRef (Strong x _) = return $ Just x
finalize (Strong _ weak) = Weak.finalize weak

finalizer :: StableName (f Any) -> Weak (MemoTable ref f g) -> IO ()
finalizer sn weakTbl = do
finalizer :: StableName (f Any) -> MVar () -> Weak (MemoTable ref f g) -> IO ()
finalizer sn lock weakTbl = do
r <- Weak.deRefWeak weakTbl
case r of
Nothing -> return ()
Just tbl -> HashTable.delete tbl sn
Just tbl -> do
takeMVar lock
HashTable.delete tbl sn
putMVar lock ()

unsafeToAny :: f a -> f Any
unsafeToAny = unsafeCoerce
Expand All @@ -89,11 +93,13 @@ memo' ::
Proxy ref ->
(forall a. f a -> g a) ->
MemoTable ref f g ->
MVar () ->
Weak (MemoTable ref f g) ->
f b ->
g b
memo' _ f tbl weakTbl !x = unsafePerformIO $ do
memo' _ f tbl lock weakTbl !x = unsafePerformIO $ do
sn <- makeStableName $ unsafeToAny x
takeMVar lock
lkp <- HashTable.lookup tbl sn
case lkp of
Nothing -> notFound sn
Expand All @@ -102,12 +108,14 @@ memo' _ f tbl weakTbl !x = unsafePerformIO $ do
case maybeVal of
Nothing -> notFound sn
Just val -> do
putMVar lock ()
return $ unsafeFromAny val
where
notFound sn = do
let y = f x
weak <- mkRef x (unsafeToAny y) $ finalizer sn weakTbl
weak <- mkRef x (unsafeToAny y) $ finalizer sn lock weakTbl
HashTable.insert tbl sn $ O weak
putMVar lock ()
return y

tableFinalizer :: (Ref ref) => MemoTable ref f g -> IO ()
Expand All @@ -121,11 +129,12 @@ memo0 ::
f b ->
g b
memo0 p f =
let (tbl, weak) = unsafePerformIO $ do
let (tbl, lock, weak) = unsafePerformIO $ do
tbl' <- HashTable.new
lock' <- newMVar ()
weak' <- Weak.mkWeakPtr tbl . Just $ tableFinalizer tbl
return (tbl', weak')
in memo' p f tbl weak
return (tbl', lock', weak')
in memo' p f tbl lock weak

-- | Memoize a unary function.
stableMemo :: (a -> b) -> (a -> b)
Expand Down

0 comments on commit 1ba1356

Please sign in to comment.