Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memoize #394

Merged
merged 17 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions example/fib.disco
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
-- | Normal: should be memoized.
!!! fib 100 == 354224848179261915075
fib : N -> N
fib 0 = 0
fib 1 = 1
fib x = fib (x .- 1) + fib (x .- 2)

-- | Multiple arrows: should be memoized.
!!! fibA 100 0 == 354224848179261915075
fibA : N -> N -> N
fibA 0 _ = 0
fibA 1 _ = 1
fibA x n = fibA (x .- 1) n + fibA (x .- 2) n

-- | Container types: can be memoized but may cause
-- performance issues depending on size of container.
!!! fibList 100 [1..100] == 354224848179261915075
fibList : N -> List(N) -> N
fibList 0 _ = 0
fibList 1 _ = 1
fibList x l = fibList (x .- 1) l + fibList (x .- 2) l

!!! fibSet 100 {1..100} == 354224848179261915075
fibSet : N -> Set(N) -> N
fibSet 0 _ = 0
fibSet 1 _ = 1
fibSet x s = fibSet (x .- 1) s + fibSet (x .- 2) s

!!! fibBag 100 ⟅1..100⟆ == 354224848179261915075
fibBag : N -> Bag(N) -> N
fibBag 0 _ = 0
fibBag 1 _ = 1
fibBag x b = fibBag (x .- 1) b + fibBag (x .- 2) b

-- | Higher-order: should not be memoized.
fibH : N -> (N -> N) -> N
fibH 0 _ = 0
fibH 1 _ = 1
fibH x f = fibH (x .- 1) f + fibH (x .- 2) f

fibHH : N -> (N -> (N -> N) -> N) -> N
fibHH 0 _ = 0
fibHH 1 _ = 1
fibHH x f = fibHH (x .- 1) f + fibHH (x .- 2) f

id : a -> a
id x = x
1 change: 0 additions & 1 deletion example/rsa.disco
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import list
-- and `decrypt` functions can be used to encrypt and decrypt
-- lists of natural numbers.


encrypt : N * N -> List(N) -> List(N)
encrypt key xs = each (encrypt1 key, xs)

Expand Down
7 changes: 5 additions & 2 deletions src/Disco/AST/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
module Disco.AST.Core (
-- * Core AST
RationalDisplay (..),
ShouldMemo (..),
Core (..),
Op (..),
opArity,
Expand Down Expand Up @@ -62,6 +63,8 @@ instance Monoid RationalDisplay where
mempty = Fraction
mappend = (P.<>)

data ShouldMemo = Memo | NoMemo deriving (Show, Generic, Data, Alpha)

-- | AST for the desugared, untyped core language.
data Core where
-- | A variable.
Expand All @@ -87,7 +90,7 @@ data Core where
-- | A projection from a product type, i.e. @fst@ or @snd@.
CProj :: Side -> Core -> Core
-- | An anonymous function.
CAbs :: Bind [Name Core] Core -> Core
CAbs :: ShouldMemo -> Bind [Name Core] Core -> Core
-- | Function application.
CApp :: Core -> Core -> Core
-- | A "test frame" under which a test case is run. Records the
Expand Down Expand Up @@ -305,7 +308,7 @@ instance Pretty Core where
CUnit -> "unit"
CPair c1 c2 -> setPA initPA $ parens (pretty c1 <> ", " <> pretty c2)
CProj s c -> withPA funPA $ selectSide s "fst" "snd" <+> rt (pretty c)
CAbs lam -> withPA initPA $ do
CAbs _ lam -> withPA initPA $ do
lunbind lam $ \(xs, body) -> "λ" <> intercalate "," (map pretty xs) <> "." <+> lt (pretty body)
CApp c1 c2 -> withPA funPA $ lt (pretty c1) <+> rt (pretty c2)
CTest xs c -> "test" <+> prettyTestVars xs <+> pretty c
Expand Down
66 changes: 51 additions & 15 deletions src/Disco/Compile.hs
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ compileDTerm term@(DTAbs q _ _) = do
(xs, tys, body) <- unbindDeep term
cbody <- compileDTerm body
case q of
Lam -> return $ abstract xs cbody
Ex -> return $ quantify (OExists tys) (abstract xs cbody)
All -> return $ quantify (OForall tys) (abstract xs cbody)
Lam -> return $ abstract (canMemo tys) xs cbody
Ex -> return $ quantify (OExists tys) (abstract NoMemo xs cbody)
All -> return $ quantify (OForall tys) (abstract NoMemo xs cbody)
where
-- Gather nested abstractions with the same quantifier.
unbindDeep :: Member Fresh r => DTerm -> Sem r ([Name DTerm], [Type], DTerm)
Expand All @@ -181,12 +181,47 @@ compileDTerm term@(DTAbs q _ _) = do
return (name : ns, ty : tys, body)
unbindDeep t = return ([], [], t)

abstract :: [Name DTerm] -> Core -> Core
abstract xs body = CAbs (bind (map coerce xs) body)
abstract :: ShouldMemo -> [Name DTerm] -> Core -> Core
abstract m xs body = CAbs m (bind (map coerce xs) body)

quantify :: Op -> Core -> Core
quantify op = CApp (CConst op)

-- Given a function's arguments, determine if it is memoizable.
-- A function is memoizable if its arguments can be converted into
-- a simple value (Haskell Ord instance can be derived).
canMemo :: [Type] -> ShouldMemo
canMemo tys
| all canMemoTy tys = Memo
| otherwise = NoMemo

canMemoTy :: Type -> Bool
canMemoTy (TyAtom a) = canMemoAtom a
-- Anti-higher order while allowing for curried functions.
canMemoTy (TyCon CArr tys@(t : _)) = case t of
TyCon CArr _ -> False
_ -> all canMemoTy tys
canMemoTy (TyCon c tys) = canMemoCon c && all canMemoTy tys

canMemoCon :: Con -> Bool
canMemoCon = \case
CArr -> False
CUser _ -> False
CGraph -> False
CMap -> False
CContainer a -> canMemoAtom a
_ -> True

canMemoAtom :: Atom -> Bool
canMemoAtom (AVar _) = False
canMemoAtom (ABase b) = canMemoBase b

canMemoBase :: BaseTy -> Bool
canMemoBase = \case
Gen -> False
P -> False
_ -> True

-- Special case for Cons, which compiles to a constructor application
-- rather than a function application.
compileDTerm (DTApp _ (DTPrim _ (PrimBOp Cons)) (DTPair _ t1 t2)) =
Expand Down Expand Up @@ -234,13 +269,13 @@ compilePrim ty p@(PrimUOp _) = compilePrimErr p ty
compilePrim _ (PrimBOp Cons) = do
hd <- fresh (string2Name "hd")
tl <- fresh (string2Name "tl")
return $ CAbs $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl)))
return $ CAbs NoMemo $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl)))
compilePrim _ PrimLeft = do
a <- fresh (string2Name "a")
return $ CAbs $ bind [a] $ CInj L (CVar (localName a))
return $ CAbs NoMemo $ bind [a] $ CInj L (CVar (localName a))
compilePrim _ PrimRight = do
a <- fresh (string2Name "a")
return $ CAbs $ bind [a] $ CInj R (CVar (localName a))
return $ CAbs NoMemo $ bind [a] $ CInj R (CVar (localName a))
compilePrim (ty1 :*: ty2 :->: resTy) (PrimBOp bop) = return $ compileBOp ty1 ty2 resTy bop
compilePrim ty p@(PrimBOp _) = compilePrimErr p ty
compilePrim _ PrimSqrt = return $ CConst OSqrt
Expand Down Expand Up @@ -342,13 +377,13 @@ compilePrimErr p ty = error $ "Impossible! compilePrim " ++ show p ++ " on bad t
-- of type (Unit → τ), in order to delay evaluation until explicitly
-- applying it to the unit value.
compileCase :: Member Fresh r => [DBranch] -> Sem r Core
compileCase [] = return $ CAbs (bind [string2Name "_"] (CConst OMatchErr))
compileCase [] = return $ CAbs NoMemo (bind [string2Name "_"] (CConst OMatchErr))
-- empty case ==> λ _ . error

compileCase (b : bs) = do
c1 <- compileBranch b
c2 <- compileCase bs
return $ CAbs (bind [string2Name "_"] (CApp c1 c2))
return $ CAbs NoMemo (bind [string2Name "_"] (CApp c1 c2))

-- | Compile a branch of a case expression of type τ to a core
-- language expression of type (Unit → τ) → τ. The idea is that it
Expand All @@ -362,7 +397,7 @@ compileBranch b = do
c <- compileDTerm e
k <- fresh (string2Name "k") -- Fresh name for the failure continuation
bc <- compileGuards (fromTelescope gs) k c
return $ CAbs (bind [k] bc)
return $ CAbs NoMemo (bind [k] bc)

-- | 'compileGuards' takes a list of guards, the name of the failure
-- continuation of type (Unit → τ), and a Core term of type τ to
Expand All @@ -384,25 +419,26 @@ compileGuards (DGPat (unembed -> s) p : gs) k e = do
-- calls the failure continuation in the case of failure, or the
-- rest of the guards in the case of success.
compileMatch :: Member Fresh r => DPattern -> Core -> Name Core -> Core -> Sem r Core
compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs (bind [coerce x] e)) s
compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs NoMemo (bind [coerce x] e)) s
-- Note in the below two cases that we can't just discard s since
-- that would result in a lazy semantics. With an eager/strict
-- semantics, we have to make sure s gets evaluated even if its
-- value is then discarded.
compileMatch (DPWild _) s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s
compileMatch DPUnit s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s
compileMatch (DPWild _) s _ e = return $ CApp (CAbs NoMemo (bind [string2Name "_"] e)) s
compileMatch DPUnit s _ e = return $ CApp (CAbs NoMemo (bind [string2Name "_"] e)) s
compileMatch (DPPair _ x1 x2) s _ e = do
y <- fresh (string2Name "y")

-- {? e when s is (x1,x2) ?} ==> (\y. (\x1.\x2. e) (fst y) (snd y)) s
return $
CApp
( CAbs
NoMemo
( bind
[y]
( CApp
( CApp
(CAbs (bind [coerce x1, coerce x2] e))
(CAbs NoMemo (bind [coerce x1, coerce x2] e))
(CProj L (CVar (localName y)))
)
(CProj R (CVar (localName y)))
Expand Down
28 changes: 22 additions & 6 deletions src/Disco/Interpret/CESK.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import qualified Algebra.Graph.AdjacencyMap as AdjMap
import Control.Arrow ((***), (>>>))
import Control.Monad ((>=>))
import Data.Bifunctor (first, second)
import Data.Functor (($>))
import Data.List (find)
import qualified Data.List.Infinite as InfList
import qualified Data.Map as M
Expand Down Expand Up @@ -113,6 +114,9 @@ data Frame
FUpdate Int
| -- | Record the results of a test.
FTest TestVars Env
| -- | Given the index of a memory cell and a function's arguments,
-- memoize the results of a function.
FMemo Int SimpleValue
deriving (Show)

------------------------------------------------------------
Expand Down Expand Up @@ -172,9 +176,13 @@ step cesk = case cesk of
(In CUnit _ k) -> return $ Out VUnit k
(In (CPair c1 c2) e k) -> return $ In c1 e (FPairR e c2 : k)
(In (CProj s c) e k) -> return $ In c e (FProj s : k)
(In (CAbs b) e k) -> do
(In (CAbs mem b) e k) -> do
(xs, body) <- unbind b
return $ Out (VClo e xs body) k
case mem of
Memo -> do
cell <- allocateValue (VMap M.empty)
return $ Out (VClo (Just (cell, [])) e xs body) k
NoMemo -> return $ Out (VClo Nothing e xs body) k
(In (CApp c1 c2) e k) -> return $ In c1 e (FArg e c2 : k)
(In (CType ty) _ k) -> return $ Out (VType ty) k
(In (CDelay b) e k) -> do
Expand All @@ -194,15 +202,23 @@ step cesk = case cesk of
(Out v2 (FPairL v1 : k)) -> return $ Out (VPair v1 v2) k
(Out (VPair v1 v2) (FProj s : k)) -> return $ Out (selectSide s v1 v2) k
(Out v (FArg e c2 : k)) -> return $ In c2 e (FApp v : k)
(Out v2 (FApp (VClo e [x] b) : k)) -> return $ In b (Ctx.insert (localName x) v2 e) k
(Out v2 (FApp (VClo e (x : xs) b) : k)) -> return $ Out (VClo (Ctx.insert (localName x) v2 e) xs b) k
(Out v (FMemo n sv : k)) -> memoSet n sv v $> Out v k
(Out v (FApp (VClo mi e [x] b) : k)) -> case mi of
Nothing -> return $ In b (Ctx.insert (localName x) v e) k
Just (n, mem) -> do
let sv = toSimpleValue $ foldr VPair VUnit (v : mem)
mv <- memoLookup n sv
case mv of
Nothing -> return $ In b (Ctx.insert (localName x) v e) (FMemo n sv : k)
Just v' -> return $ Out v' k
(Out v (FApp (VClo mi e (x : xs) b) : k)) -> return $ Out (VClo (second (v :) <$> mi) (Ctx.insert (localName x) v e) xs b) k
(Out v2 (FApp (VConst op) : k)) -> appConst k op v2
(Out v2 (FApp (VFun f) : k)) -> return $ Out (f v2) k
-- Annoying to repeat this code, not sure of a better way.
-- The usual evaluation order (function then argument) doesn't work when
-- we're applying a test function to randomly generated values.
(Out (VClo e [x] b) (FArgV v : k)) -> return $ In b (Ctx.insert (localName x) v e) k
(Out (VClo e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo (Ctx.insert (localName x) v e) xs b) k
(Out (VClo _ e [x] b) (FArgV v : k)) -> return $ In b (Ctx.insert (localName x) v e) k
(Out (VClo mi e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo mi (Ctx.insert (localName x) v e) xs b) k
(Out (VConst op) (FArgV v : k)) -> appConst k op v
(Out (VFun f) (FArgV v : k)) -> return $ Out (f v) k
(Out (VRef n) (FForce : k)) -> do
Expand Down
24 changes: 23 additions & 1 deletion src/Disco/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,12 @@ module Disco.Value (
Mem,
emptyMem,
allocate,
allocateValue,
allocateRec,
lkup,
memoLookup,
set,
memoSet,

-- * Pretty-printing
prettyValue',
Expand Down Expand Up @@ -130,7 +133,7 @@ data Value where
VPair :: Value -> Value -> Value
-- | A closure, i.e. a function body together with its
-- environment.
VClo :: Env -> [Name Core] -> Core -> Value
VClo :: Maybe (Int, [Value]) -> Env -> [Name Core] -> Core -> Value
-- | A disco type can be a value. For now, there are only a very
-- limited number of places this could ever show up (in
-- particular, as an argument to @enumerate@ or @count@).
Expand Down Expand Up @@ -451,6 +454,12 @@ allocate e t = do
put $ Mem (n + 1) (IM.insert n (E e t) m)
return n

allocateValue :: Members '[State Mem] r => Value -> Sem r Int
allocateValue v = do
Mem n m <- get
put $ Mem (n + 1) (IM.insert n (Disco.Value.V v) m)
return n

-- | Allocate new memory cells for a group of mutually recursive
-- bindings, and return the indices of the allocate cells.
allocateRec :: Members '[State Mem] r => Env -> [(QName Core, Core)] -> Sem r [Int]
Expand All @@ -471,6 +480,19 @@ lkup n = gets (IM.lookup n . mu)
set :: Members '[State Mem] r => Int -> Cell -> Sem r ()
set n c = modify $ \(Mem nxt m) -> Mem nxt (IM.insert n c m)

memoLookup :: Members '[State Mem] r => Int -> SimpleValue -> Sem r (Maybe Value)
memoLookup n sv = gets (mLookup . IM.lookup n . mu)
where
mLookup (Just (Disco.Value.V (VMap vmap))) = M.lookup sv vmap
mLookup _ = Nothing

memoSet :: Members '[State Mem] r => Int -> SimpleValue -> Value -> Sem r ()
memoSet n sv v = do
mc <- lkup n
case mc of
Just (Disco.Value.V (VMap vmap)) -> set n (Disco.Value.V (VMap (M.insert sv v vmap)))
_ -> return ()

------------------------------------------------------------
-- Pretty-printing values
------------------------------------------------------------
Expand Down