From 67556fb8739dc12f27daddbf553376b93d45026d Mon Sep 17 00:00:00 2001 From: Justin Grubbs Date: Sat, 6 Jul 2024 19:11:13 -0500 Subject: [PATCH] Fixes for memo logic and other cleanup --- example/fib.disco | 25 ++++++++++ example/rsa.disco | 11 ----- src/Disco/AST/Core.hs | 5 +- src/Disco/Compile.hs | 98 ++++++++++++++++++------------------- src/Disco/Interpret/CESK.hs | 4 +- 5 files changed, 79 insertions(+), 64 deletions(-) create mode 100644 example/fib.disco diff --git a/example/fib.disco b/example/fib.disco new file mode 100644 index 00000000..b808e726 --- /dev/null +++ b/example/fib.disco @@ -0,0 +1,25 @@ +-- Normal, should be memoized +fib : N -> N +fib 0 = 0 +fib 1 = 1 +fib x = fib (x .- 1) + fib (x .- 2) + +-- Multiple arrows, should be memoized +fibA : N -> N -> N -> N +fibA 0 _ _ = 0 +fibA 1 _ _ = 1 +fibA x n1 n2 = fibA (x .- 1) n1 n2 + fibA (x .- 2) n1 n2 + +-- Higher-order, should not be memoized +fibH : N -> N -> (N -> N) -> N +fibH 0 _ _ = 0 +fibH 1 _ _ = 1 +fibH x n f = fibH (x .- 1) n f + fibH (x .- 2) n f + +fibHH : N -> N -> (N -> N -> (N -> N) -> N) -> N +fibHH 0 _ _ = 0 +fibHH 1 _ _ = 1 +fibHH x n f = fibHH (x .- 1) n f + fibHH (x .- 2) n f + +help : N -> N +help _ = 30 diff --git a/example/rsa.disco b/example/rsa.disco index f84eddd9..0d2c8a19 100644 --- a/example/rsa.disco +++ b/example/rsa.disco @@ -19,17 +19,6 @@ import list -- and `decrypt` functions can be used to encrypt and decrypt -- lists of natural numbers. -fib : N -> N -fib 0 = 0 -fib 1 = 1 -fib x = fib (x .- 1) + fib (x .- 2) - -foo : (N -> Bool) -> N -> Bool -foo f n = f n - -tru : N -> Bool -tru _ = True - encrypt : N * N -> List(N) -> List(N) encrypt key xs = each (encrypt1 key, xs) diff --git a/src/Disco/AST/Core.hs b/src/Disco/AST/Core.hs index 1c240397..2c7322c5 100644 --- a/src/Disco/AST/Core.hs +++ b/src/Disco/AST/Core.hs @@ -16,6 +16,7 @@ module Disco.AST.Core ( -- * Core AST RationalDisplay (..), + ShouldMemo (..), Core (..), Op (..), opArity, @@ -62,6 +63,8 @@ instance Monoid RationalDisplay where mempty = Fraction mappend = (P.<>) +data ShouldMemo = Memo | NoMemo deriving (Show, Generic, Data, Alpha) + -- | AST for the desugared, untyped core language. data Core where -- | A variable. @@ -87,7 +90,7 @@ data Core where -- | A projection from a product type, i.e. @fst@ or @snd@. CProj :: Side -> Core -> Core -- | An anonymous function. - CAbs :: Bool -> Bind [Name Core] Core -> Core + CAbs :: ShouldMemo -> Bind [Name Core] Core -> Core -- | Function application. CApp :: Core -> Core -> Core -- | A "test frame" under which a test case is run. Records the diff --git a/src/Disco/Compile.hs b/src/Disco/Compile.hs index 70e055a2..e38a8b34 100644 --- a/src/Disco/Compile.hs +++ b/src/Disco/Compile.hs @@ -170,10 +170,10 @@ compileDTerm term@(DTAbs q _ _) = do cbody <- compileDTerm body case q of Lam -> if canMemo tys - then return $ abstractMemo xs cbody - else return $ abstract xs cbody - Ex -> return $ quantify (OExists tys) (abstract xs cbody) - All -> return $ quantify (OForall tys) (abstract xs cbody) + then return $ abstract Memo xs cbody + else return $ abstract NoMemo xs cbody + Ex -> return $ quantify (OExists tys) (abstract NoMemo xs cbody) + All -> return $ quantify (OForall tys) (abstract NoMemo xs cbody) where -- Gather nested abstractions with the same quantifier. unbindDeep :: Member Fresh r => DTerm -> Sem r ([Name DTerm], [Type], DTerm) @@ -183,46 +183,44 @@ compileDTerm term@(DTAbs q _ _) = do return (name : ns, ty : tys, body) unbindDeep t = return ([], [], t) - abstract :: [Name DTerm] -> Core -> Core - abstract xs body = CAbs False (bind (map coerce xs) body) - - abstractMemo :: [Name DTerm] -> Core -> Core - abstractMemo xs body = CAbs True (bind (map coerce xs) body) + abstract :: ShouldMemo -> [Name DTerm] -> Core -> Core + abstract m xs body = CAbs m (bind (map coerce xs) body) quantify :: Op -> Core -> Core quantify op = CApp (CConst op) - -- Intend on changing this, but functional for now - canMemo :: [Type] -> Bool - canMemo [] = True - canMemo (x : xs) = case x of - TyAtom a -> checkAtom a && canMemo xs - TyCon CArr tys -> arrMemo tys && canMemo tys && canMemo xs - TyCon c tys -> checkCon c && canMemo tys && canMemo xs - - arrMemo :: [Type] -> Bool - arrMemo [] = True - arrMemo (x : xs) = case x of - TyCon CArr _ -> False - _ -> arrMemo xs - - checkCon :: Con -> Bool - checkCon (CUser _) = False - checkCon CGraph = False - checkCon CMap = False - checkCon (CContainer a) = checkAtom a - checkCon _ = True + canMemo :: [Type] -> Bool + canMemo = all canMemoTy - checkAtom :: Atom -> Bool - checkAtom (AVar _) = False - checkAtom (ABase b) = checkBase b - - checkBase :: BaseTy -> Bool - checkBase CtrList = False - checkBase CtrBag = False - checkBase CtrSet = False - checkBase Gen = False - checkBase _ = True + canMemoTy :: Type -> Bool + canMemoTy (TyAtom a) = canMemoAtom a + -- Anti-higher-order while permitting multiple arrows + canMemoTy (TyCon CArr tys@(t:_)) = case t of + TyCon CArr _ -> False + _ -> all canMemoTy tys + canMemoTy (TyCon c tys) = canMemoCon c && all canMemoTy tys + + canMemoCon :: Con -> Bool + canMemoCon = \case + CArr -> False + CUser _ -> False + CGraph -> False + CMap -> False + CContainer a -> canMemoAtom a + _ -> True + + canMemoAtom :: Atom -> Bool + canMemoAtom (AVar _) = False + canMemoAtom (ABase b) = canMemoBase b + + canMemoBase :: BaseTy -> Bool + canMemoBase = \case + Gen -> False + CtrSet -> False + CtrBag -> False + CtrList -> False + P -> False + _ -> True -- Special case for Cons, which compiles to a constructor application @@ -272,13 +270,13 @@ compilePrim ty p@(PrimUOp _) = compilePrimErr p ty compilePrim _ (PrimBOp Cons) = do hd <- fresh (string2Name "hd") tl <- fresh (string2Name "tl") - return $ CAbs False $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl))) + return $ CAbs NoMemo $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl))) compilePrim _ PrimLeft = do a <- fresh (string2Name "a") - return $ CAbs False $ bind [a] $ CInj L (CVar (localName a)) + return $ CAbs NoMemo $ bind [a] $ CInj L (CVar (localName a)) compilePrim _ PrimRight = do a <- fresh (string2Name "a") - return $ CAbs False $ bind [a] $ CInj R (CVar (localName a)) + return $ CAbs NoMemo $ bind [a] $ CInj R (CVar (localName a)) compilePrim (ty1 :*: ty2 :->: resTy) (PrimBOp bop) = return $ compileBOp ty1 ty2 resTy bop compilePrim ty p@(PrimBOp _) = compilePrimErr p ty compilePrim _ PrimSqrt = return $ CConst OSqrt @@ -380,13 +378,13 @@ compilePrimErr p ty = error $ "Impossible! compilePrim " ++ show p ++ " on bad t -- of type (Unit → τ), in order to delay evaluation until explicitly -- applying it to the unit value. compileCase :: Member Fresh r => [DBranch] -> Sem r Core -compileCase [] = return $ CAbs False (bind [string2Name "_"] (CConst OMatchErr)) +compileCase [] = return $ CAbs NoMemo (bind [string2Name "_"] (CConst OMatchErr)) -- empty case ==> λ _ . error compileCase (b : bs) = do c1 <- compileBranch b c2 <- compileCase bs - return $ CAbs False (bind [string2Name "_"] (CApp c1 c2)) + return $ CAbs NoMemo (bind [string2Name "_"] (CApp c1 c2)) -- | Compile a branch of a case expression of type τ to a core -- language expression of type (Unit → τ) → τ. The idea is that it @@ -400,7 +398,7 @@ compileBranch b = do c <- compileDTerm e k <- fresh (string2Name "k") -- Fresh name for the failure continuation bc <- compileGuards (fromTelescope gs) k c - return $ CAbs False (bind [k] bc) + return $ CAbs NoMemo (bind [k] bc) -- | 'compileGuards' takes a list of guards, the name of the failure -- continuation of type (Unit → τ), and a Core term of type τ to @@ -422,25 +420,25 @@ compileGuards (DGPat (unembed -> s) p : gs) k e = do -- calls the failure continuation in the case of failure, or the -- rest of the guards in the case of success. compileMatch :: Member Fresh r => DPattern -> Core -> Name Core -> Core -> Sem r Core -compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs False (bind [coerce x] e)) s +compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs NoMemo (bind [coerce x] e)) s -- Note in the below two cases that we can't just discard s since -- that would result in a lazy semantics. With an eager/strict -- semantics, we have to make sure s gets evaluated even if its -- value is then discarded. -compileMatch (DPWild _) s _ e = return $ CApp (CAbs False (bind [string2Name "_"] e)) s -compileMatch DPUnit s _ e = return $ CApp (CAbs False (bind [string2Name "_"] e)) s +compileMatch (DPWild _) s _ e = return $ CApp (CAbs NoMemo (bind [string2Name "_"] e)) s +compileMatch DPUnit s _ e = return $ CApp (CAbs NoMemo (bind [string2Name "_"] e)) s compileMatch (DPPair _ x1 x2) s _ e = do y <- fresh (string2Name "y") -- {? e when s is (x1,x2) ?} ==> (\y. (\x1.\x2. e) (fst y) (snd y)) s return $ CApp - ( CAbs False + ( CAbs NoMemo ( bind [y] ( CApp ( CApp - (CAbs False (bind [coerce x1, coerce x2] e)) + (CAbs NoMemo (bind [coerce x1, coerce x2] e)) (CProj L (CVar (localName y))) ) (CProj R (CVar (localName y))) diff --git a/src/Disco/Interpret/CESK.hs b/src/Disco/Interpret/CESK.hs index be2f8431..ddd31db4 100644 --- a/src/Disco/Interpret/CESK.hs +++ b/src/Disco/Interpret/CESK.hs @@ -179,11 +179,11 @@ step cesk = case cesk of (In (CAbs mem b) e k) -> do (xs, body) <- unbind b case mem of - True -> do + Memo -> do cell <- allocateValue (VTrie T.empty) -- cell <- allocateValue (VMap M.empty) return $ Out (VClo (Just (cell,[])) e xs body) k - False -> return $ Out (VClo Nothing e xs body) k + NoMemo -> return $ Out (VClo Nothing e xs body) k (In (CApp c1 c2) e k) -> return $ In c1 e (FArg e c2 : k) (In (CType ty) _ k) -> return $ Out (VType ty) k