Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memoize #394

Merged
merged 17 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions disco.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ library
polysemy-plugin >= 0.4 && < 0.5,
reflection >= 2.1.7 && < 2.2,
random >= 1.2.1.1 && < 1.3,
list-tries,
justingrubbs marked this conversation as resolved.
Show resolved Hide resolved
constraints >= 0.13.4 && < 0.15,
text >= 2.0.2 && < 2.2,
lens >= 4.14 && < 5.4,
Expand Down
10 changes: 10 additions & 0 deletions example/rsa.disco
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ import list
-- and `decrypt` functions can be used to encrypt and decrypt
-- lists of natural numbers.

fib : N -> N
justingrubbs marked this conversation as resolved.
Show resolved Hide resolved
fib 0 = 0
fib 1 = 1
fib x = fib (x .- 1) + fib (x .- 2)

foo : (N -> Bool) -> N -> Bool
foo f n = f n

tru : N -> Bool
tru _ = True

encrypt : N * N -> List(N) -> List(N)
encrypt key xs = each (encrypt1 key, xs)
Expand Down
4 changes: 2 additions & 2 deletions src/Disco/AST/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ data Core where
-- | A projection from a product type, i.e. @fst@ or @snd@.
CProj :: Side -> Core -> Core
-- | An anonymous function.
CAbs :: Bind [Name Core] Core -> Core
CAbs :: Bool -> Bind [Name Core] Core -> Core
justingrubbs marked this conversation as resolved.
Show resolved Hide resolved
-- | Function application.
CApp :: Core -> Core -> Core
-- | A "test frame" under which a test case is run. Records the
Expand Down Expand Up @@ -305,7 +305,7 @@ instance Pretty Core where
CUnit -> "unit"
CPair c1 c2 -> setPA initPA $ parens (pretty c1 <> ", " <> pretty c2)
CProj s c -> withPA funPA $ selectSide s "fst" "snd" <+> rt (pretty c)
CAbs lam -> withPA initPA $ do
CAbs _ lam -> withPA initPA $ do
lunbind lam $ \(xs, body) -> "λ" <> intercalate "," (map pretty xs) <> "." <+> lt (pretty body)
CApp c1 c2 -> withPA funPA $ lt (pretty c1) <+> rt (pretty c2)
CTest xs c -> "test" <+> prettyTestVars xs <+> pretty c
Expand Down
64 changes: 51 additions & 13 deletions src/Disco/Compile.hs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ compileDTerm term@(DTAbs q _ _) = do
(xs, tys, body) <- unbindDeep term
cbody <- compileDTerm body
case q of
Lam -> return $ abstract xs cbody
Lam -> if canMemo tys
justingrubbs marked this conversation as resolved.
Show resolved Hide resolved
then return $ abstractMemo xs cbody
else return $ abstract xs cbody
Ex -> return $ quantify (OExists tys) (abstract xs cbody)
All -> return $ quantify (OForall tys) (abstract xs cbody)
where
Expand All @@ -182,11 +184,47 @@ compileDTerm term@(DTAbs q _ _) = do
unbindDeep t = return ([], [], t)

abstract :: [Name DTerm] -> Core -> Core
abstract xs body = CAbs (bind (map coerce xs) body)
abstract xs body = CAbs False (bind (map coerce xs) body)

abstractMemo :: [Name DTerm] -> Core -> Core
abstractMemo xs body = CAbs True (bind (map coerce xs) body)
justingrubbs marked this conversation as resolved.
Show resolved Hide resolved

quantify :: Op -> Core -> Core
quantify op = CApp (CConst op)

-- Intend on changing this, but functional for now
canMemo :: [Type] -> Bool
justingrubbs marked this conversation as resolved.
Show resolved Hide resolved
canMemo [] = True
canMemo (x : xs) = case x of
TyAtom a -> checkAtom a && canMemo xs
TyCon CArr tys -> arrMemo tys && canMemo tys && canMemo xs
TyCon c tys -> checkCon c && canMemo tys && canMemo xs

arrMemo :: [Type] -> Bool
justingrubbs marked this conversation as resolved.
Show resolved Hide resolved
arrMemo [] = True
arrMemo (x : xs) = case x of
TyCon CArr _ -> False
_ -> arrMemo xs

checkCon :: Con -> Bool
checkCon (CUser _) = False
checkCon CGraph = False
checkCon CMap = False
checkCon (CContainer a) = checkAtom a
checkCon _ = True

checkAtom :: Atom -> Bool
checkAtom (AVar _) = False
checkAtom (ABase b) = checkBase b

checkBase :: BaseTy -> Bool
checkBase CtrList = False
checkBase CtrBag = False
checkBase CtrSet = False
checkBase Gen = False
justingrubbs marked this conversation as resolved.
Show resolved Hide resolved
checkBase _ = True


-- Special case for Cons, which compiles to a constructor application
-- rather than a function application.
compileDTerm (DTApp _ (DTPrim _ (PrimBOp Cons)) (DTPair _ t1 t2)) =
Expand Down Expand Up @@ -234,13 +272,13 @@ compilePrim ty p@(PrimUOp _) = compilePrimErr p ty
compilePrim _ (PrimBOp Cons) = do
hd <- fresh (string2Name "hd")
tl <- fresh (string2Name "tl")
return $ CAbs $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl)))
return $ CAbs False $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl)))
compilePrim _ PrimLeft = do
a <- fresh (string2Name "a")
return $ CAbs $ bind [a] $ CInj L (CVar (localName a))
return $ CAbs False $ bind [a] $ CInj L (CVar (localName a))
compilePrim _ PrimRight = do
a <- fresh (string2Name "a")
return $ CAbs $ bind [a] $ CInj R (CVar (localName a))
return $ CAbs False $ bind [a] $ CInj R (CVar (localName a))
compilePrim (ty1 :*: ty2 :->: resTy) (PrimBOp bop) = return $ compileBOp ty1 ty2 resTy bop
compilePrim ty p@(PrimBOp _) = compilePrimErr p ty
compilePrim _ PrimSqrt = return $ CConst OSqrt
Expand Down Expand Up @@ -342,13 +380,13 @@ compilePrimErr p ty = error $ "Impossible! compilePrim " ++ show p ++ " on bad t
-- of type (Unit → τ), in order to delay evaluation until explicitly
-- applying it to the unit value.
compileCase :: Member Fresh r => [DBranch] -> Sem r Core
compileCase [] = return $ CAbs (bind [string2Name "_"] (CConst OMatchErr))
compileCase [] = return $ CAbs False (bind [string2Name "_"] (CConst OMatchErr))
-- empty case ==> λ _ . error

compileCase (b : bs) = do
c1 <- compileBranch b
c2 <- compileCase bs
return $ CAbs (bind [string2Name "_"] (CApp c1 c2))
return $ CAbs False (bind [string2Name "_"] (CApp c1 c2))

-- | Compile a branch of a case expression of type τ to a core
-- language expression of type (Unit → τ) → τ. The idea is that it
Expand All @@ -362,7 +400,7 @@ compileBranch b = do
c <- compileDTerm e
k <- fresh (string2Name "k") -- Fresh name for the failure continuation
bc <- compileGuards (fromTelescope gs) k c
return $ CAbs (bind [k] bc)
return $ CAbs False (bind [k] bc)

-- | 'compileGuards' takes a list of guards, the name of the failure
-- continuation of type (Unit → τ), and a Core term of type τ to
Expand All @@ -384,25 +422,25 @@ compileGuards (DGPat (unembed -> s) p : gs) k e = do
-- calls the failure continuation in the case of failure, or the
-- rest of the guards in the case of success.
compileMatch :: Member Fresh r => DPattern -> Core -> Name Core -> Core -> Sem r Core
compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs (bind [coerce x] e)) s
compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs False (bind [coerce x] e)) s
-- Note in the below two cases that we can't just discard s since
-- that would result in a lazy semantics. With an eager/strict
-- semantics, we have to make sure s gets evaluated even if its
-- value is then discarded.
compileMatch (DPWild _) s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s
compileMatch DPUnit s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s
compileMatch (DPWild _) s _ e = return $ CApp (CAbs False (bind [string2Name "_"] e)) s
compileMatch DPUnit s _ e = return $ CApp (CAbs False (bind [string2Name "_"] e)) s
compileMatch (DPPair _ x1 x2) s _ e = do
y <- fresh (string2Name "y")

-- {? e when s is (x1,x2) ?} ==> (\y. (\x1.\x2. e) (fst y) (snd y)) s
return $
CApp
( CAbs
( CAbs False
( bind
[y]
( CApp
( CApp
(CAbs (bind [coerce x1, coerce x2] e))
(CAbs False (bind [coerce x1, coerce x2] e))
(CProj L (CVar (localName y)))
)
(CProj R (CVar (localName y)))
Expand Down
40 changes: 34 additions & 6 deletions src/Disco/Interpret/CESK.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import qualified Data.List.Infinite as InfList
import qualified Data.Map as M
import Data.Maybe (isJust)
import Data.Ratio
import qualified Data.ListTrie.Map as T
import Disco.AST.Core
import Disco.AST.Generic (
Ellipsis (..),
Expand Down Expand Up @@ -113,6 +114,8 @@ data Frame
FUpdate Int
| -- | Record the results of a test.
FTest TestVars Env
| -- | Memoize the result of a function.
justingrubbs marked this conversation as resolved.
Show resolved Hide resolved
FMemo Int SimpleValue
deriving (Show)

------------------------------------------------------------
Expand Down Expand Up @@ -172,9 +175,16 @@ step cesk = case cesk of
(In CUnit _ k) -> return $ Out VUnit k
(In (CPair c1 c2) e k) -> return $ In c1 e (FPairR e c2 : k)
(In (CProj s c) e k) -> return $ In c e (FProj s : k)
(In (CAbs b) e k) -> do

(In (CAbs mem b) e k) -> do
(xs, body) <- unbind b
return $ Out (VClo e xs body) k
case mem of
True -> do
cell <- allocateValue (VTrie T.empty)
-- cell <- allocateValue (VMap M.empty)
return $ Out (VClo (Just (cell,[])) e xs body) k
False -> return $ Out (VClo Nothing e xs body) k

(In (CApp c1 c2) e k) -> return $ In c1 e (FArg e c2 : k)
(In (CType ty) _ k) -> return $ Out (VType ty) k
(In (CDelay b) e k) -> do
Expand All @@ -194,15 +204,33 @@ step cesk = case cesk of
(Out v2 (FPairL v1 : k)) -> return $ Out (VPair v1 v2) k
(Out (VPair v1 v2) (FProj s : k)) -> return $ Out (selectSide s v1 v2) k
(Out v (FArg e c2 : k)) -> return $ In c2 e (FApp v : k)
(Out v2 (FApp (VClo e [x] b) : k)) -> return $ In b (Ctx.insert (localName x) v2 e) k
(Out v2 (FApp (VClo e (x : xs) b) : k)) -> return $ Out (VClo (Ctx.insert (localName x) v2 e) xs b) k



(Out v (FMemo n sv : k)) -> memoSet n sv v *> (return $ Out v k)

(Out v (FApp (VClo mi e [x] b) : k)) -> case mi of
Nothing -> return $ In b (Ctx.insert (localName x) v e) k
Just (n,mem) -> do
let sv = toSimpleValue $ foldr VPair VUnit (v : mem)
mv <- memoLookup n sv
case mv of
Nothing -> return $ In b (Ctx.insert (localName x) v e) (FMemo n sv : k)
Just v' -> return $ Out v' k

(Out v (FApp (VClo mi e (x : xs) b) : k)) -> case mi of
Just (n,mem) -> return $ Out (VClo (Just (n,v:mem)) (Ctx.insert (localName x) v e) xs b) k
Nothing -> return $ Out (VClo mi (Ctx.insert (localName x) v e) xs b) k
justingrubbs marked this conversation as resolved.
Show resolved Hide resolved



(Out v2 (FApp (VConst op) : k)) -> appConst k op v2
(Out v2 (FApp (VFun f) : k)) -> return $ Out (f v2) k
-- Annoying to repeat this code, not sure of a better way.
-- The usual evaluation order (function then argument) doesn't work when
-- we're applying a test function to randomly generated values.
(Out (VClo e [x] b) (FArgV v : k)) -> return $ In b (Ctx.insert (localName x) v e) k
(Out (VClo e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo (Ctx.insert (localName x) v e) xs b) k
(Out (VClo _ e [x] b) (FArgV v : k)) -> return $ In b (Ctx.insert (localName x) v e) k
(Out (VClo mi e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo mi (Ctx.insert (localName x) v e) xs b) k
(Out (VConst op) (FArgV v : k)) -> appConst k op v
(Out (VFun f) (FArgV v : k)) -> return $ Out (f v) k
(Out (VRef n) (FForce : k)) -> do
Expand Down
30 changes: 29 additions & 1 deletion src/Disco/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,12 @@ module Disco.Value (
Mem,
emptyMem,
allocate,
allocateValue,
allocateRec,
lkup,
memoLookup,
set,
memoSet,

-- * Pretty-printing
prettyValue',
Expand All @@ -86,6 +89,7 @@ import Data.Char (chr, ord)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM
import Data.List (foldl')
import qualified Data.ListTrie.Map as T
import Data.Map (Map)
import qualified Data.Map as M
import Data.Ratio
Expand Down Expand Up @@ -130,7 +134,7 @@ data Value where
VPair :: Value -> Value -> Value
-- | A closure, i.e. a function body together with its
-- environment.
VClo :: Env -> [Name Core] -> Core -> Value
VClo :: Maybe (Int,[Value]) -> Env -> [Name Core] -> Core -> Value
-- | A disco type can be a value. For now, there are only a very
-- limited number of places this could ever show up (in
-- particular, as an argument to @enumerate@ or @count@).
Expand Down Expand Up @@ -163,6 +167,9 @@ data Value where
-- actually construct the set of entries, while functions only have this
-- property when the key type is finite.
VMap :: Map SimpleValue Value -> Value

VTrie :: T.TrieMap Map SimpleValue Value -> Value

VGen :: StdGen -> Value
deriving (Show)

Expand Down Expand Up @@ -451,6 +458,12 @@ allocate e t = do
put $ Mem (n + 1) (IM.insert n (E e t) m)
return n

allocateValue :: Members '[State Mem] r => Value -> Sem r Int
allocateValue v = do
Mem n m <- get
put $ Mem (n + 1) (IM.insert n (Disco.Value.V v) m)
return n

-- | Allocate new memory cells for a group of mutually recursive
-- bindings, and return the indices of the allocate cells.
allocateRec :: Members '[State Mem] r => Env -> [(QName Core, Core)] -> Sem r [Int]
Expand All @@ -471,6 +484,21 @@ lkup n = gets (IM.lookup n . mu)
set :: Members '[State Mem] r => Int -> Cell -> Sem r ()
set n c = modify $ \(Mem nxt m) -> Mem nxt (IM.insert n c m)

memoLookup :: Members '[State Mem] r => Int -> SimpleValue -> Sem r (Maybe Value)
memoLookup n sv = gets (mLookup . IM.lookup n . mu)
where
mLookup (Just (Disco.Value.V (VMap vmap))) = M.lookup sv vmap
mLookup (Just (Disco.Value.V (VTrie vtrie))) = T.lookup [sv] vtrie
mLookup _ = Nothing

memoSet :: Members '[State Mem] r => Int -> SimpleValue -> Value -> Sem r ()
memoSet n sv v = do
mc <- lkup n
case mc of
Just (Disco.Value.V (VMap vmap)) -> set n (Disco.Value.V (VMap (M.insert sv v vmap)))
Just (Disco.Value.V (VTrie trie)) -> set n (Disco.Value.V (VTrie (T.insert [sv] v trie)))
_ -> undefined

------------------------------------------------------------
-- Pretty-printing values
------------------------------------------------------------
Expand Down
Loading