From 6062daf58a8d7c67972ccff06cf61f15adf51ac2 Mon Sep 17 00:00:00 2001 From: JustinGrubbs Date: Sun, 30 Jun 2024 18:51:19 -0500 Subject: [PATCH] Beginning work on auto-memoization --- src/Disco/AST/Core.hs | 4 ++-- src/Disco/Compile.hs | 32 ++++++++++++++++++++------------ src/Disco/Interpret/CESK.hs | 28 ++++++++++++++++++++++------ src/Disco/Value.hs | 9 ++++++++- 4 files changed, 52 insertions(+), 21 deletions(-) diff --git a/src/Disco/AST/Core.hs b/src/Disco/AST/Core.hs index bb8b6ebf..1c240397 100644 --- a/src/Disco/AST/Core.hs +++ b/src/Disco/AST/Core.hs @@ -87,7 +87,7 @@ data Core where -- | A projection from a product type, i.e. @fst@ or @snd@. CProj :: Side -> Core -> Core -- | An anonymous function. - CAbs :: Bind [Name Core] Core -> Core + CAbs :: Bool -> Bind [Name Core] Core -> Core -- | Function application. CApp :: Core -> Core -> Core -- | A "test frame" under which a test case is run. Records the @@ -305,7 +305,7 @@ instance Pretty Core where CUnit -> "unit" CPair c1 c2 -> setPA initPA $ parens (pretty c1 <> ", " <> pretty c2) CProj s c -> withPA funPA $ selectSide s "fst" "snd" <+> rt (pretty c) - CAbs lam -> withPA initPA $ do + CAbs _ lam -> withPA initPA $ do lunbind lam $ \(xs, body) -> "λ" <> intercalate "," (map pretty xs) <> "." <+> lt (pretty body) CApp c1 c2 -> withPA funPA $ lt (pretty c1) <+> rt (pretty c2) CTest xs c -> "test" <+> prettyTestVars xs <+> pretty c diff --git a/src/Disco/Compile.hs b/src/Disco/Compile.hs index 325589ca..b345a6c3 100644 --- a/src/Disco/Compile.hs +++ b/src/Disco/Compile.hs @@ -169,7 +169,12 @@ compileDTerm term@(DTAbs q _ _) = do (xs, tys, body) <- unbindDeep term cbody <- compileDTerm body case q of + Lam -> return $ abstract xs cbody + -- Lam -> case _ of + -- _ -> return $ abstract xs cbody + -- _ -> return $ abstractMemo xs cbody + Ex -> return $ quantify (OExists tys) (abstract xs cbody) All -> return $ quantify (OForall tys) (abstract xs cbody) where @@ -182,7 +187,10 @@ compileDTerm term@(DTAbs q _ _) = do unbindDeep t = return ([], [], t) abstract :: [Name DTerm] -> Core -> Core - abstract xs body = CAbs (bind (map coerce xs) body) + abstract xs body = CAbs False (bind (map coerce xs) body) + + abstractMemo :: [Name DTerm] -> Core -> Core + abstractMemo xs body = CAbs True (bind (map coerce xs) body) quantify :: Op -> Core -> Core quantify op = CApp (CConst op) @@ -234,13 +242,13 @@ compilePrim ty p@(PrimUOp _) = compilePrimErr p ty compilePrim _ (PrimBOp Cons) = do hd <- fresh (string2Name "hd") tl <- fresh (string2Name "tl") - return $ CAbs $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl))) + return $ CAbs False $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl))) compilePrim _ PrimLeft = do a <- fresh (string2Name "a") - return $ CAbs $ bind [a] $ CInj L (CVar (localName a)) + return $ CAbs False $ bind [a] $ CInj L (CVar (localName a)) compilePrim _ PrimRight = do a <- fresh (string2Name "a") - return $ CAbs $ bind [a] $ CInj R (CVar (localName a)) + return $ CAbs False $ bind [a] $ CInj R (CVar (localName a)) compilePrim (ty1 :*: ty2 :->: resTy) (PrimBOp bop) = return $ compileBOp ty1 ty2 resTy bop compilePrim ty p@(PrimBOp _) = compilePrimErr p ty compilePrim _ PrimSqrt = return $ CConst OSqrt @@ -342,13 +350,13 @@ compilePrimErr p ty = error $ "Impossible! compilePrim " ++ show p ++ " on bad t -- of type (Unit → τ), in order to delay evaluation until explicitly -- applying it to the unit value. compileCase :: Member Fresh r => [DBranch] -> Sem r Core -compileCase [] = return $ CAbs (bind [string2Name "_"] (CConst OMatchErr)) +compileCase [] = return $ CAbs False (bind [string2Name "_"] (CConst OMatchErr)) -- empty case ==> λ _ . error compileCase (b : bs) = do c1 <- compileBranch b c2 <- compileCase bs - return $ CAbs (bind [string2Name "_"] (CApp c1 c2)) + return $ CAbs False (bind [string2Name "_"] (CApp c1 c2)) -- | Compile a branch of a case expression of type τ to a core -- language expression of type (Unit → τ) → τ. The idea is that it @@ -362,7 +370,7 @@ compileBranch b = do c <- compileDTerm e k <- fresh (string2Name "k") -- Fresh name for the failure continuation bc <- compileGuards (fromTelescope gs) k c - return $ CAbs (bind [k] bc) + return $ CAbs False (bind [k] bc) -- | 'compileGuards' takes a list of guards, the name of the failure -- continuation of type (Unit → τ), and a Core term of type τ to @@ -384,25 +392,25 @@ compileGuards (DGPat (unembed -> s) p : gs) k e = do -- calls the failure continuation in the case of failure, or the -- rest of the guards in the case of success. compileMatch :: Member Fresh r => DPattern -> Core -> Name Core -> Core -> Sem r Core -compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs (bind [coerce x] e)) s +compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs False (bind [coerce x] e)) s -- Note in the below two cases that we can't just discard s since -- that would result in a lazy semantics. With an eager/strict -- semantics, we have to make sure s gets evaluated even if its -- value is then discarded. -compileMatch (DPWild _) s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s -compileMatch DPUnit s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s +compileMatch (DPWild _) s _ e = return $ CApp (CAbs False (bind [string2Name "_"] e)) s +compileMatch DPUnit s _ e = return $ CApp (CAbs False (bind [string2Name "_"] e)) s compileMatch (DPPair _ x1 x2) s _ e = do y <- fresh (string2Name "y") -- {? e when s is (x1,x2) ?} ==> (\y. (\x1.\x2. e) (fst y) (snd y)) s return $ CApp - ( CAbs + ( CAbs False ( bind [y] ( CApp ( CApp - (CAbs (bind [coerce x1, coerce x2] e)) + (CAbs False (bind [coerce x1, coerce x2] e)) (CProj L (CVar (localName y))) ) (CProj R (CVar (localName y))) diff --git a/src/Disco/Interpret/CESK.hs b/src/Disco/Interpret/CESK.hs index f28d6bf1..56866526 100644 --- a/src/Disco/Interpret/CESK.hs +++ b/src/Disco/Interpret/CESK.hs @@ -113,6 +113,8 @@ data Frame FUpdate Int | -- | Record the results of a test. FTest TestVars Env + | -- | Memoize the results of a function call. (Memory cell - Args) + FMemo Int SimpleValue deriving (Show) ------------------------------------------------------------ @@ -172,9 +174,14 @@ step cesk = case cesk of (In CUnit _ k) -> return $ Out VUnit k (In (CPair c1 c2) e k) -> return $ In c1 e (FPairR e c2 : k) (In (CProj s c) e k) -> return $ In c e (FProj s : k) - (In (CAbs b) e k) -> do + (In (CAbs True b) e k) -> do (xs, body) <- unbind b - return $ Out (VClo e xs body) k + -- Init memo map and memory loc + cell <- allocateV + return $ Out (VClo (Just cell) e xs body) k + (In (CAbs False b) e k) -> do + (xs, body) <- unbind b + return $ Out (VClo Nothing e xs body) k (In (CApp c1 c2) e k) -> return $ In c1 e (FArg e c2 : k) (In (CType ty) _ k) -> return $ Out (VType ty) k (In (CDelay b) e k) -> do @@ -194,15 +201,24 @@ step cesk = case cesk of (Out v2 (FPairL v1 : k)) -> return $ Out (VPair v1 v2) k (Out (VPair v1 v2) (FProj s : k)) -> return $ Out (selectSide s v1 v2) k (Out v (FArg e c2 : k)) -> return $ In c2 e (FApp v : k) - (Out v2 (FApp (VClo e [x] b) : k)) -> return $ In b (Ctx.insert (localName x) v2 e) k - (Out v2 (FApp (VClo e (x : xs) b) : k)) -> return $ Out (VClo (Ctx.insert (localName x) v2 e) xs b) k + + (Out v2 (FApp (VClo mi e [x] b) : k)) -> case mi of + Just n -> undefined + Nothing -> return $ In b (Ctx.insert (localName x) v2 e) k + + (Out v2 (FApp (VClo mi e (x : xs) b) : k)) -> case mi of + Just n -> return $ Out (VClo mi (Ctx.insert (localName x) v2 e) xs b) (FMemo n _ : k) + Nothing -> return $ Out (VClo mi (Ctx.insert (localName x) v2 e) xs b) k + (Out v2 (FApp (VConst op) : k)) -> appConst k op v2 (Out v2 (FApp (VFun f) : k)) -> return $ Out (f v2) k -- Annoying to repeat this code, not sure of a better way. -- The usual evaluation order (function then argument) doesn't work when -- we're applying a test function to randomly generated values. - (Out (VClo e [x] b) (FArgV v : k)) -> return $ In b (Ctx.insert (localName x) v e) k - (Out (VClo e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo (Ctx.insert (localName x) v e) xs b) k + + (Out (VClo _ e [x] b) (FArgV v : k)) -> return $ In b (Ctx.insert (localName x) v e) k + (Out (VClo mi e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo mi (Ctx.insert (localName x) v e) xs b) k + (Out (VConst op) (FArgV v : k)) -> appConst k op v (Out (VFun f) (FArgV v : k)) -> return $ Out (f v) k (Out (VRef n) (FForce : k)) -> do diff --git a/src/Disco/Value.hs b/src/Disco/Value.hs index 6472cdbb..78de8c15 100644 --- a/src/Disco/Value.hs +++ b/src/Disco/Value.hs @@ -68,6 +68,7 @@ module Disco.Value ( Mem, emptyMem, allocate, + allocateV, allocateRec, lkup, set, @@ -130,7 +131,7 @@ data Value where VPair :: Value -> Value -> Value -- | A closure, i.e. a function body together with its -- environment. - VClo :: Env -> [Name Core] -> Core -> Value + VClo :: Maybe Int -> Env -> [Name Core] -> Core -> Value -- | A disco type can be a value. For now, there are only a very -- limited number of places this could ever show up (in -- particular, as an argument to @enumerate@ or @count@). @@ -451,6 +452,12 @@ allocate e t = do put $ Mem (n + 1) (IM.insert n (E e t) m) return n +allocateV :: Members '[State Mem] r => Sem r Int +allocateV = do + Mem n m <- get + put $ Mem (n + 1) (IM.insert n (Disco.Value.V (VMap M.empty)) m) + return n + -- | Allocate new memory cells for a group of mutually recursive -- bindings, and return the indices of the allocate cells. allocateRec :: Members '[State Mem] r => Env -> [(QName Core, Core)] -> Sem r [Int]