Skip to content

Commit

Permalink
Beginning work on auto-memoization
Browse files Browse the repository at this point in the history
  • Loading branch information
justingrubbs committed Jun 30, 2024
1 parent 10f21fb commit 6062daf
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/Disco/AST/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ data Core where
-- | A projection from a product type, i.e. @fst@ or @snd@.
CProj :: Side -> Core -> Core
-- | An anonymous function.
CAbs :: Bind [Name Core] Core -> Core
CAbs :: Bool -> Bind [Name Core] Core -> Core
-- | Function application.
CApp :: Core -> Core -> Core
-- | A "test frame" under which a test case is run. Records the
Expand Down Expand Up @@ -305,7 +305,7 @@ instance Pretty Core where
CUnit -> "unit"
CPair c1 c2 -> setPA initPA $ parens (pretty c1 <> ", " <> pretty c2)
CProj s c -> withPA funPA $ selectSide s "fst" "snd" <+> rt (pretty c)
CAbs lam -> withPA initPA $ do
CAbs _ lam -> withPA initPA $ do
lunbind lam $ \(xs, body) -> "λ" <> intercalate "," (map pretty xs) <> "." <+> lt (pretty body)
CApp c1 c2 -> withPA funPA $ lt (pretty c1) <+> rt (pretty c2)
CTest xs c -> "test" <+> prettyTestVars xs <+> pretty c
Expand Down
32 changes: 20 additions & 12 deletions src/Disco/Compile.hs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,12 @@ compileDTerm term@(DTAbs q _ _) = do
(xs, tys, body) <- unbindDeep term
cbody <- compileDTerm body
case q of

Lam -> return $ abstract xs cbody
-- Lam -> case _ of
-- _ -> return $ abstract xs cbody
-- _ -> return $ abstractMemo xs cbody

Ex -> return $ quantify (OExists tys) (abstract xs cbody)
All -> return $ quantify (OForall tys) (abstract xs cbody)
where
Expand All @@ -182,7 +187,10 @@ compileDTerm term@(DTAbs q _ _) = do
unbindDeep t = return ([], [], t)

abstract :: [Name DTerm] -> Core -> Core
abstract xs body = CAbs (bind (map coerce xs) body)
abstract xs body = CAbs False (bind (map coerce xs) body)

abstractMemo :: [Name DTerm] -> Core -> Core
abstractMemo xs body = CAbs True (bind (map coerce xs) body)

quantify :: Op -> Core -> Core
quantify op = CApp (CConst op)
Expand Down Expand Up @@ -234,13 +242,13 @@ compilePrim ty p@(PrimUOp _) = compilePrimErr p ty
compilePrim _ (PrimBOp Cons) = do
hd <- fresh (string2Name "hd")
tl <- fresh (string2Name "tl")
return $ CAbs $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl)))
return $ CAbs False $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl)))
compilePrim _ PrimLeft = do
a <- fresh (string2Name "a")
return $ CAbs $ bind [a] $ CInj L (CVar (localName a))
return $ CAbs False $ bind [a] $ CInj L (CVar (localName a))
compilePrim _ PrimRight = do
a <- fresh (string2Name "a")
return $ CAbs $ bind [a] $ CInj R (CVar (localName a))
return $ CAbs False $ bind [a] $ CInj R (CVar (localName a))
compilePrim (ty1 :*: ty2 :->: resTy) (PrimBOp bop) = return $ compileBOp ty1 ty2 resTy bop
compilePrim ty p@(PrimBOp _) = compilePrimErr p ty
compilePrim _ PrimSqrt = return $ CConst OSqrt
Expand Down Expand Up @@ -342,13 +350,13 @@ compilePrimErr p ty = error $ "Impossible! compilePrim " ++ show p ++ " on bad t
-- of type (Unit → τ), in order to delay evaluation until explicitly
-- applying it to the unit value.
compileCase :: Member Fresh r => [DBranch] -> Sem r Core
compileCase [] = return $ CAbs (bind [string2Name "_"] (CConst OMatchErr))
compileCase [] = return $ CAbs False (bind [string2Name "_"] (CConst OMatchErr))
-- empty case ==> λ _ . error

compileCase (b : bs) = do
c1 <- compileBranch b
c2 <- compileCase bs
return $ CAbs (bind [string2Name "_"] (CApp c1 c2))
return $ CAbs False (bind [string2Name "_"] (CApp c1 c2))

-- | Compile a branch of a case expression of type τ to a core
-- language expression of type (Unit → τ) → τ. The idea is that it
Expand All @@ -362,7 +370,7 @@ compileBranch b = do
c <- compileDTerm e
k <- fresh (string2Name "k") -- Fresh name for the failure continuation
bc <- compileGuards (fromTelescope gs) k c
return $ CAbs (bind [k] bc)
return $ CAbs False (bind [k] bc)

-- | 'compileGuards' takes a list of guards, the name of the failure
-- continuation of type (Unit → τ), and a Core term of type τ to
Expand All @@ -384,25 +392,25 @@ compileGuards (DGPat (unembed -> s) p : gs) k e = do
-- calls the failure continuation in the case of failure, or the
-- rest of the guards in the case of success.
compileMatch :: Member Fresh r => DPattern -> Core -> Name Core -> Core -> Sem r Core
compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs (bind [coerce x] e)) s
compileMatch (DPVar _ x) s _ e = return $ CApp (CAbs False (bind [coerce x] e)) s
-- Note in the below two cases that we can't just discard s since
-- that would result in a lazy semantics. With an eager/strict
-- semantics, we have to make sure s gets evaluated even if its
-- value is then discarded.
compileMatch (DPWild _) s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s
compileMatch DPUnit s _ e = return $ CApp (CAbs (bind [string2Name "_"] e)) s
compileMatch (DPWild _) s _ e = return $ CApp (CAbs False (bind [string2Name "_"] e)) s
compileMatch DPUnit s _ e = return $ CApp (CAbs False (bind [string2Name "_"] e)) s
compileMatch (DPPair _ x1 x2) s _ e = do
y <- fresh (string2Name "y")

-- {? e when s is (x1,x2) ?} ==> (\y. (\x1.\x2. e) (fst y) (snd y)) s
return $
CApp
( CAbs
( CAbs False
( bind
[y]
( CApp
( CApp
(CAbs (bind [coerce x1, coerce x2] e))
(CAbs False (bind [coerce x1, coerce x2] e))
(CProj L (CVar (localName y)))
)
(CProj R (CVar (localName y)))
Expand Down
28 changes: 22 additions & 6 deletions src/Disco/Interpret/CESK.hs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ data Frame
FUpdate Int
| -- | Record the results of a test.
FTest TestVars Env
| -- | Memoize the results of a function call. (Memory cell - Args)
FMemo Int SimpleValue
deriving (Show)

------------------------------------------------------------
Expand Down Expand Up @@ -172,9 +174,14 @@ step cesk = case cesk of
(In CUnit _ k) -> return $ Out VUnit k
(In (CPair c1 c2) e k) -> return $ In c1 e (FPairR e c2 : k)
(In (CProj s c) e k) -> return $ In c e (FProj s : k)
(In (CAbs b) e k) -> do
(In (CAbs True b) e k) -> do
(xs, body) <- unbind b
return $ Out (VClo e xs body) k
-- Init memo map and memory loc
cell <- allocateV
return $ Out (VClo (Just cell) e xs body) k
(In (CAbs False b) e k) -> do
(xs, body) <- unbind b
return $ Out (VClo Nothing e xs body) k
(In (CApp c1 c2) e k) -> return $ In c1 e (FArg e c2 : k)
(In (CType ty) _ k) -> return $ Out (VType ty) k
(In (CDelay b) e k) -> do
Expand All @@ -194,15 +201,24 @@ step cesk = case cesk of
(Out v2 (FPairL v1 : k)) -> return $ Out (VPair v1 v2) k
(Out (VPair v1 v2) (FProj s : k)) -> return $ Out (selectSide s v1 v2) k
(Out v (FArg e c2 : k)) -> return $ In c2 e (FApp v : k)
(Out v2 (FApp (VClo e [x] b) : k)) -> return $ In b (Ctx.insert (localName x) v2 e) k
(Out v2 (FApp (VClo e (x : xs) b) : k)) -> return $ Out (VClo (Ctx.insert (localName x) v2 e) xs b) k

(Out v2 (FApp (VClo mi e [x] b) : k)) -> case mi of
Just n -> undefined
Nothing -> return $ In b (Ctx.insert (localName x) v2 e) k

(Out v2 (FApp (VClo mi e (x : xs) b) : k)) -> case mi of
Just n -> return $ Out (VClo mi (Ctx.insert (localName x) v2 e) xs b) (FMemo n _ : k)
Nothing -> return $ Out (VClo mi (Ctx.insert (localName x) v2 e) xs b) k

(Out v2 (FApp (VConst op) : k)) -> appConst k op v2
(Out v2 (FApp (VFun f) : k)) -> return $ Out (f v2) k
-- Annoying to repeat this code, not sure of a better way.
-- The usual evaluation order (function then argument) doesn't work when
-- we're applying a test function to randomly generated values.
(Out (VClo e [x] b) (FArgV v : k)) -> return $ In b (Ctx.insert (localName x) v e) k
(Out (VClo e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo (Ctx.insert (localName x) v e) xs b) k

(Out (VClo _ e [x] b) (FArgV v : k)) -> return $ In b (Ctx.insert (localName x) v e) k
(Out (VClo mi e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo mi (Ctx.insert (localName x) v e) xs b) k

(Out (VConst op) (FArgV v : k)) -> appConst k op v
(Out (VFun f) (FArgV v : k)) -> return $ Out (f v) k
(Out (VRef n) (FForce : k)) -> do
Expand Down
9 changes: 8 additions & 1 deletion src/Disco/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ module Disco.Value (
Mem,
emptyMem,
allocate,
allocateV,
allocateRec,
lkup,
set,
Expand Down Expand Up @@ -130,7 +131,7 @@ data Value where
VPair :: Value -> Value -> Value
-- | A closure, i.e. a function body together with its
-- environment.
VClo :: Env -> [Name Core] -> Core -> Value
VClo :: Maybe Int -> Env -> [Name Core] -> Core -> Value
-- | A disco type can be a value. For now, there are only a very
-- limited number of places this could ever show up (in
-- particular, as an argument to @enumerate@ or @count@).
Expand Down Expand Up @@ -451,6 +452,12 @@ allocate e t = do
put $ Mem (n + 1) (IM.insert n (E e t) m)
return n

allocateV :: Members '[State Mem] r => Sem r Int
allocateV = do
Mem n m <- get
put $ Mem (n + 1) (IM.insert n (Disco.Value.V (VMap M.empty)) m)
return n

-- | Allocate new memory cells for a group of mutually recursive
-- bindings, and return the indices of the allocate cells.
allocateRec :: Members '[State Mem] r => Env -> [(QName Core, Core)] -> Sem r [Int]
Expand Down

0 comments on commit 6062daf

Please sign in to comment.