diff --git a/example/fib.disco b/example/fib.disco new file mode 100644 index 00000000..8ed7ecd3 --- /dev/null +++ b/example/fib.disco @@ -0,0 +1,47 @@ +-- | Normal: should be memoized. +!!! fib 100 == 354224848179261915075 +fib : N -> N +fib 0 = 0 +fib 1 = 1 +fib x = fib (x .- 1) + fib (x .- 2) + +-- | Multiple arrows: should be memoized. +!!! fibA 100 0 == 354224848179261915075 +fibA : N -> N -> N +fibA 0 _ = 0 +fibA 1 _ = 1 +fibA x n = fibA (x .- 1) n + fibA (x .- 2) n + +-- | Container types: can be memoized but may cause +-- performance issues depending on size of container. +!!! fibList 100 [1..100] == 354224848179261915075 +fibList : N -> List(N) -> N +fibList 0 _ = 0 +fibList 1 _ = 1 +fibList x l = fibList (x .- 1) l + fibList (x .- 2) l + +!!! fibSet 100 {1..100} == 354224848179261915075 +fibSet : N -> Set(N) -> N +fibSet 0 _ = 0 +fibSet 1 _ = 1 +fibSet x s = fibSet (x .- 1) s + fibSet (x .- 2) s + +!!! fibBag 100 ⟅1..100⟆ == 354224848179261915075 +fibBag : N -> Bag(N) -> N +fibBag 0 _ = 0 +fibBag 1 _ = 1 +fibBag x b = fibBag (x .- 1) b + fibBag (x .- 2) b + +-- | Higher-order: should not be memoized. +fibH : N -> (N -> N) -> N +fibH 0 _ = 0 +fibH 1 _ = 1 +fibH x f = fibH (x .- 1) f + fibH (x .- 2) f + +fibHH : N -> (N -> (N -> N) -> N) -> N +fibHH 0 _ = 0 +fibHH 1 _ = 1 +fibHH x f = fibHH (x .- 1) f + fibHH (x .- 2) f + +id : a -> a +id x = x diff --git a/example/rsa.disco b/example/rsa.disco index 9181700a..0d2c8a19 100644 --- a/example/rsa.disco +++ b/example/rsa.disco @@ -19,7 +19,6 @@ import list -- and `decrypt` functions can be used to encrypt and decrypt -- lists of natural numbers. - encrypt : N * N -> List(N) -> List(N) encrypt key xs = each (encrypt1 key, xs) diff --git a/src/Disco/AST/Core.hs b/src/Disco/AST/Core.hs index bb8b6ebf..2c7322c5 100644 --- a/src/Disco/AST/Core.hs +++ b/src/Disco/AST/Core.hs @@ -16,6 +16,7 @@ module Disco.AST.Core ( -- * Core AST RationalDisplay (..), + ShouldMemo (..), Core (..), Op (..), opArity, @@ -62,6 +63,8 @@ instance Monoid RationalDisplay where mempty = Fraction mappend = (P.<>) +data ShouldMemo = Memo | NoMemo deriving (Show, Generic, Data, Alpha) + -- | AST for the desugared, untyped core language. data Core where -- | A variable. @@ -87,7 +90,7 @@ data Core where -- | A projection from a product type, i.e. @fst@ or @snd@. CProj :: Side -> Core -> Core -- | An anonymous function. - CAbs :: Bind [Name Core] Core -> Core + CAbs :: ShouldMemo -> Bind [Name Core] Core -> Core -- | Function application. CApp :: Core -> Core -> Core -- | A "test frame" under which a test case is run. Records the @@ -305,7 +308,7 @@ instance Pretty Core where CUnit -> "unit" CPair c1 c2 -> setPA initPA $ parens (pretty c1 <> ", " <> pretty c2) CProj s c -> withPA funPA $ selectSide s "fst" "snd" <+> rt (pretty c) - CAbs lam -> withPA initPA $ do + CAbs _ lam -> withPA initPA $ do lunbind lam $ \(xs, body) -> "λ" <> intercalate "," (map pretty xs) <> "." <+> lt (pretty body) CApp c1 c2 -> withPA funPA $ lt (pretty c1) <+> rt (pretty c2) CTest xs c -> "test" <+> prettyTestVars xs <+> pretty c diff --git a/src/Disco/Compile.hs b/src/Disco/Compile.hs index 325589ca..f44a4f42 100644 --- a/src/Disco/Compile.hs +++ b/src/Disco/Compile.hs @@ -169,9 +169,9 @@ compileDTerm term@(DTAbs q _ _) = do (xs, tys, body) <- unbindDeep term cbody <- compileDTerm body case q of - Lam -> return $ abstract xs cbody - Ex -> return $ quantify (OExists tys) (abstract xs cbody) - All -> return $ quantify (OForall tys) (abstract xs cbody) + Lam -> return $ abstract (canMemo tys) xs cbody + Ex -> return $ quantify (OExists tys) (abstract NoMemo xs cbody) + All -> return $ quantify (OForall tys) (abstract NoMemo xs cbody) where -- Gather nested abstractions with the same quantifier. unbindDeep :: Member Fresh r => DTerm -> Sem r ([Name DTerm], [Type], DTerm) @@ -181,12 +181,47 @@ compileDTerm term@(DTAbs q _ _) = do return (name : ns, ty : tys, body) unbindDeep t = return ([], [], t) - abstract :: [Name DTerm] -> Core -> Core - abstract xs body = CAbs (bind (map coerce xs) body) + abstract :: ShouldMemo -> [Name DTerm] -> Core -> Core + abstract m xs body = CAbs m (bind (map coerce xs) body) quantify :: Op -> Core -> Core quantify op = CApp (CConst op) + -- Given a function's arguments, determine if it is memoizable. + -- A function is memoizable if its arguments can be converted into + -- a simple value (Haskell Ord instance can be derived). + canMemo :: [Type] -> ShouldMemo + canMemo tys + | all canMemoTy tys = Memo + | otherwise = NoMemo + + canMemoTy :: Type -> Bool + canMemoTy (TyAtom a) = canMemoAtom a + -- Anti-higher order while allowing for curried functions. + canMemoTy (TyCon CArr tys@(t : _)) = case t of + TyCon CArr _ -> False + _ -> all canMemoTy tys + canMemoTy (TyCon c tys) = canMemoCon c && all canMemoTy tys + + canMemoCon :: Con -> Bool + canMemoCon = \case + CArr -> False + CUser _ -> False + CGraph -> False + CMap -> False + CContainer a -> canMemoAtom a + _ -> True + + canMemoAtom :: Atom -> Bool + canMemoAtom (AVar _) = False + canMemoAtom (ABase b) = canMemoBase b + + canMemoBase :: BaseTy -> Bool + canMemoBase = \case + Gen -> False + P -> False + _ -> True + -- Special case for Cons, which compiles to a constructor application -- rather than a function application. compileDTerm (DTApp _ (DTPrim _ (PrimBOp Cons)) (DTPair _ t1 t2)) = @@ -234,13 +269,13 @@ compilePrim ty p@(PrimUOp _) = compilePrimErr p ty compilePrim _ (PrimBOp Cons) = do hd <- fresh (string2Name "hd") tl <- fresh (string2Name "tl") - return $ CAbs $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl))) + return $ CAbs NoMemo $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl))) compilePrim _ PrimLeft = do a <- fresh (string2Name "a") - return $ CAbs $ bind [a] $ CInj L (CVar (localName a)) + return $ CAbs NoMemo $ bind [a] $ CInj L (CVar (localName a)) compilePrim _ PrimRight = do a <- fresh (string2Name "a") - return $ CAbs $ bind [a] $ CInj R (CVar (localName a)) + return $ CAbs NoMemo $ bind [a] $ CInj R (CVar (localName a)) compilePrim (ty1 :*: ty2 :->: resTy) (PrimBOp bop) = return $ compileBOp ty1 ty2 resTy bop compilePrim ty p@(PrimBOp _) = compilePrimErr p ty compilePrim _ PrimSqrt = return $ CConst OSqrt @@ -342,13 +377,13 @@ compilePrimErr p ty = error $ "Impossible! compilePrim " ++ show p ++ " on bad t -- of type (Unit → τ), in order to delay evaluation until explicitly -- applying it to the unit value. compileCase :: Member Fresh r => [DBranch] -> Sem r Core -compileCase [] = return $ CAbs (bind [string2Name "_"] (CConst OMatchErr)) +compileCase [] = return $ CAbs NoMemo (bind [string2Name "_"] (CConst OMatchErr)) -- empty case ==> λ _ . error compileCase (b : bs) = do c1 <- compileBranch b c2 <- compileCase bs - return $ CAbs (bind [string2Name "_"] (CApp c1 c2)) + return $ CAbs NoMemo (bind [string2Name "_"] (CApp c1 c2)) -- | Compile a branch of a case expression of type τ to a core -- language expression of type (Unit → τ) → τ. The idea is that it @@ -362,7 +397,7 @@ compileBranch b = do c <- compileDTerm e k <- fresh (string2Name "k") -- Fresh name for the failure continuation bc <- compileGuards (fromTelescope gs) k c - return $ CAbs (bind [k] bc) + return $ CAbs NoMemo (bind [k] bc) -- | 'compileGuards' takes a list of guards, the name of the failure -- continuation of type (Unit → τ), and a Core term of type τ to @@ -384,13 +419,13 @@ compileGuards (DGPat (unembed -> s) p : gs) k e = do -- calls the failure continuation in the case of failure, or the -- rest of the guards in the case of success. compileMatch :: Member Fresh r => DPattern -> Core -> Name Core -> Core -> Sem r Core -compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs (bind [coerce x] e)) s +compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs NoMemo (bind [coerce x] e)) s -- Note in the below two cases that we can't just discard s since -- that would result in a lazy semantics. With an eager/strict -- semantics, we have to make sure s gets evaluated even if its -- value is then discarded. -compileMatch (DPWild _) s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s -compileMatch DPUnit s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s +compileMatch (DPWild _) s _ e = return $ CApp (CAbs NoMemo (bind [string2Name "_"] e)) s +compileMatch DPUnit s _ e = return $ CApp (CAbs NoMemo (bind [string2Name "_"] e)) s compileMatch (DPPair _ x1 x2) s _ e = do y <- fresh (string2Name "y") @@ -398,11 +433,12 @@ compileMatch (DPPair _ x1 x2) s _ e = do return $ CApp ( CAbs + NoMemo ( bind [y] ( CApp ( CApp - (CAbs (bind [coerce x1, coerce x2] e)) + (CAbs NoMemo (bind [coerce x1, coerce x2] e)) (CProj L (CVar (localName y))) ) (CProj R (CVar (localName y))) diff --git a/src/Disco/Interpret/CESK.hs b/src/Disco/Interpret/CESK.hs index f28d6bf1..63e2e011 100644 --- a/src/Disco/Interpret/CESK.hs +++ b/src/Disco/Interpret/CESK.hs @@ -26,6 +26,7 @@ import qualified Algebra.Graph.AdjacencyMap as AdjMap import Control.Arrow ((***), (>>>)) import Control.Monad ((>=>)) import Data.Bifunctor (first, second) +import Data.Functor (($>)) import Data.List (find) import qualified Data.List.Infinite as InfList import qualified Data.Map as M @@ -113,6 +114,9 @@ data Frame FUpdate Int | -- | Record the results of a test. FTest TestVars Env + | -- | Given the index of a memory cell and a function's arguments, + -- memoize the results of a function. + FMemo Int SimpleValue deriving (Show) ------------------------------------------------------------ @@ -172,9 +176,13 @@ step cesk = case cesk of (In CUnit _ k) -> return $ Out VUnit k (In (CPair c1 c2) e k) -> return $ In c1 e (FPairR e c2 : k) (In (CProj s c) e k) -> return $ In c e (FProj s : k) - (In (CAbs b) e k) -> do + (In (CAbs mem b) e k) -> do (xs, body) <- unbind b - return $ Out (VClo e xs body) k + case mem of + Memo -> do + cell <- allocateValue (VMap M.empty) + return $ Out (VClo (Just (cell, [])) e xs body) k + NoMemo -> return $ Out (VClo Nothing e xs body) k (In (CApp c1 c2) e k) -> return $ In c1 e (FArg e c2 : k) (In (CType ty) _ k) -> return $ Out (VType ty) k (In (CDelay b) e k) -> do @@ -194,15 +202,23 @@ step cesk = case cesk of (Out v2 (FPairL v1 : k)) -> return $ Out (VPair v1 v2) k (Out (VPair v1 v2) (FProj s : k)) -> return $ Out (selectSide s v1 v2) k (Out v (FArg e c2 : k)) -> return $ In c2 e (FApp v : k) - (Out v2 (FApp (VClo e [x] b) : k)) -> return $ In b (Ctx.insert (localName x) v2 e) k - (Out v2 (FApp (VClo e (x : xs) b) : k)) -> return $ Out (VClo (Ctx.insert (localName x) v2 e) xs b) k + (Out v (FMemo n sv : k)) -> memoSet n sv v $> Out v k + (Out v (FApp (VClo mi e [x] b) : k)) -> case mi of + Nothing -> return $ In b (Ctx.insert (localName x) v e) k + Just (n, mem) -> do + let sv = toSimpleValue $ foldr VPair VUnit (v : mem) + mv <- memoLookup n sv + case mv of + Nothing -> return $ In b (Ctx.insert (localName x) v e) (FMemo n sv : k) + Just v' -> return $ Out v' k + (Out v (FApp (VClo mi e (x : xs) b) : k)) -> return $ Out (VClo (second (v :) <$> mi) (Ctx.insert (localName x) v e) xs b) k (Out v2 (FApp (VConst op) : k)) -> appConst k op v2 (Out v2 (FApp (VFun f) : k)) -> return $ Out (f v2) k -- Annoying to repeat this code, not sure of a better way. -- The usual evaluation order (function then argument) doesn't work when -- we're applying a test function to randomly generated values. - (Out (VClo e [x] b) (FArgV v : k)) -> return $ In b (Ctx.insert (localName x) v e) k - (Out (VClo e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo (Ctx.insert (localName x) v e) xs b) k + (Out (VClo _ e [x] b) (FArgV v : k)) -> return $ In b (Ctx.insert (localName x) v e) k + (Out (VClo mi e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo mi (Ctx.insert (localName x) v e) xs b) k (Out (VConst op) (FArgV v : k)) -> appConst k op v (Out (VFun f) (FArgV v : k)) -> return $ Out (f v) k (Out (VRef n) (FForce : k)) -> do diff --git a/src/Disco/Value.hs b/src/Disco/Value.hs index 6472cdbb..93f774b2 100644 --- a/src/Disco/Value.hs +++ b/src/Disco/Value.hs @@ -68,9 +68,12 @@ module Disco.Value ( Mem, emptyMem, allocate, + allocateValue, allocateRec, lkup, + memoLookup, set, + memoSet, -- * Pretty-printing prettyValue', @@ -130,7 +133,7 @@ data Value where VPair :: Value -> Value -> Value -- | A closure, i.e. a function body together with its -- environment. - VClo :: Env -> [Name Core] -> Core -> Value + VClo :: Maybe (Int, [Value]) -> Env -> [Name Core] -> Core -> Value -- | A disco type can be a value. For now, there are only a very -- limited number of places this could ever show up (in -- particular, as an argument to @enumerate@ or @count@). @@ -451,6 +454,12 @@ allocate e t = do put $ Mem (n + 1) (IM.insert n (E e t) m) return n +allocateValue :: Members '[State Mem] r => Value -> Sem r Int +allocateValue v = do + Mem n m <- get + put $ Mem (n + 1) (IM.insert n (Disco.Value.V v) m) + return n + -- | Allocate new memory cells for a group of mutually recursive -- bindings, and return the indices of the allocate cells. allocateRec :: Members '[State Mem] r => Env -> [(QName Core, Core)] -> Sem r [Int] @@ -471,6 +480,19 @@ lkup n = gets (IM.lookup n . mu) set :: Members '[State Mem] r => Int -> Cell -> Sem r () set n c = modify $ \(Mem nxt m) -> Mem nxt (IM.insert n c m) +memoLookup :: Members '[State Mem] r => Int -> SimpleValue -> Sem r (Maybe Value) +memoLookup n sv = gets (mLookup . IM.lookup n . mu) + where + mLookup (Just (Disco.Value.V (VMap vmap))) = M.lookup sv vmap + mLookup _ = Nothing + +memoSet :: Members '[State Mem] r => Int -> SimpleValue -> Value -> Sem r () +memoSet n sv v = do + mc <- lkup n + case mc of + Just (Disco.Value.V (VMap vmap)) -> set n (Disco.Value.V (VMap (M.insert sv v vmap))) + _ -> return () + ------------------------------------------------------------ -- Pretty-printing values ------------------------------------------------------------