diff --git a/.restyled.yaml b/.restyled.yaml index 017c83a6..6669b356 100644 --- a/.restyled.yaml +++ b/.restyled.yaml @@ -1,3 +1,7 @@ +restylers_version: stable restylers: - - stylish-haskell + - fourmolu: + image: 'restyled/restyler-fourmolu:v0.13.0.0' + arguments: + [] - hlint diff --git a/fourmolu.yaml b/fourmolu.yaml new file mode 100644 index 00000000..56d9ce84 --- /dev/null +++ b/fourmolu.yaml @@ -0,0 +1,13 @@ +indentation: 2 +comma-style: leading +record-brace-space: true +indent-wheres: false # 'false' means save space by only half-indenting the 'where' keyword +diff-friendly-import-export: true +let-style: inline +respectful: true +single-constraint-parens: auto +haddock-style: single-line +newlines-between-decls: 1 +reexports: + - module Text.Megaparsec exports Control.Applicative + - module Options.Applicative exports Control.Applicative diff --git a/repl/REPL.hs b/repl/REPL.hs index 0abef810..d72c129b 100644 --- a/repl/REPL.hs +++ b/repl/REPL.hs @@ -1,5 +1,5 @@ ----------------------------------------------------------------------------- --- | +-- \| -- Module : REPL -- Copyright : disco team and contributors -- Maintainer : byorgey@gmail.com @@ -7,10 +7,9 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- A text-based REPL for disco. --- ----------------------------------------------------------------------------- -import Disco.Interactive.CmdLine +import Disco.Interactive.CmdLine main :: IO () main = discoMain diff --git a/src/Disco/AST/Core.hs b/src/Disco/AST/Core.hs index ffef0768..e8178646 100644 --- a/src/Disco/AST/Core.hs +++ b/src/Disco/AST/Core.hs @@ -1,10 +1,9 @@ -{-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE NondecreasingIndentation #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE UndecidableInstances #-} ------------------------------------------------------------------------------ -- | -- Module : Disco.AST.Core -- Copyright : disco team and contributors @@ -14,34 +13,35 @@ -- -- Abstract syntax trees representing the desugared, untyped core -- language for Disco. ------------------------------------------------------------------------------ +module Disco.AST.Core ( + -- * Core AST + RationalDisplay (..), + Core (..), + Op (..), + opArity, + substQC, + substsQC, +) +where -module Disco.AST.Core - ( -- * Core AST - RationalDisplay(..) - , Core(..) - , Op(..), opArity, substQC, substsQC - ) - where +import Control.Lens.Plated +import Data.Data (Data) +import Data.Data.Lens (uniplate) +import qualified Data.Set as S +import GHC.Generics +import Unbound.Generics.LocallyNameless hiding (LFresh, lunbind) +import Prelude hiding ((<>)) +import qualified Prelude as P -import Control.Lens.Plated -import Data.Data (Data) -import Data.Data.Lens (uniplate) -import qualified Data.Set as S -import GHC.Generics -import Prelude hiding ((<>)) -import qualified Prelude as P -import Unbound.Generics.LocallyNameless hiding (LFresh, lunbind) +import Disco.Effects.LFresh +import Polysemy (Members, Sem) +import Polysemy.Reader -import Disco.Effects.LFresh -import Polysemy (Members, Sem) -import Polysemy.Reader - -import Data.Ratio -import Disco.AST.Generic (Side, selectSide) -import Disco.Names (QName) -import Disco.Pretty -import Disco.Types +import Data.Ratio +import Disco.AST.Generic (Side, selectSide) +import Disco.Names (QName) +import Disco.Pretty +import Disco.Types -- | A type of flags specifying whether to display a rational number -- as a fraction or a decimal. @@ -51,7 +51,7 @@ data RationalDisplay = Fraction | Decimal instance Semigroup RationalDisplay where Decimal <> _ = Decimal _ <> Decimal = Decimal - _ <> _ = Fraction + _ <> _ = Fraction -- | The 'Monoid' instance for 'RationalDisplay' corresponds to the -- idea that the result should be displayed as a decimal if any @@ -232,8 +232,7 @@ data Op OShouldEq Type | -- Other primitives OShouldLt Type - | - -- | Error for non-exhaustive pattern match + | -- | Error for non-exhaustive pattern match OMatchErr | -- | Crash with a user-supplied message OCrash @@ -252,7 +251,6 @@ data Op | -- | Not the Boolean `Impl`, but instead a propositional BOp -- | Should only be seen and used with Props. OImpl - deriving (Show, Generic, Data, Alpha, Eq, Ord) -- | Get the arity (desired number of arguments) of a function @@ -260,8 +258,8 @@ data Op -- uncurried and hence has arity 1. opArity :: Op -> Int opArity OEmptyGraph = 0 -opArity OMatchErr = 0 -opArity _ = 1 +opArity OMatchErr = 0 +opArity _ = 1 substQC :: QName Core -> Core -> Core -> Core substQC x s = transform $ \case @@ -274,64 +272,68 @@ substsQC :: [(QName Core, Core)] -> Core -> Core substsQC xs = transform $ \case CVar y -> case P.lookup y xs of Just c -> c - _ -> CVar y + _ -> CVar y t -> t instance Pretty Core where pretty = \case - CVar qn -> pretty qn + CVar qn -> pretty qn CNum _ r | denominator r == 1 -> text (show (numerator r)) - | otherwise -> text (show (numerator r)) <> "/" <> text (show (denominator r)) + | otherwise -> text (show (numerator r)) <> "/" <> text (show (denominator r)) CApp (CConst op) (CPair c1 c2) | isInfix op -> parens (pretty c1 <+> text (opToStr op) <+> pretty c2) CApp (CConst op) c | isPrefix op -> text (opToStr op) <> pretty c | isPostfix op -> pretty c <> text (opToStr op) - CConst op -> pretty op - CInj s c -> withPA funPA $ selectSide s "left" "right" <+> rt (pretty c) - CCase c l r -> do + CConst op -> pretty op + CInj s c -> withPA funPA $ selectSide s "left" "right" <+> rt (pretty c) + CCase c l r -> do lunbind l $ \(x, lc) -> do - lunbind r $ \(y, rc) -> do - "case" <+> pretty c <+> "of {" - $+$ nest 2 ( - vcat - [ withPA funPA $ "left" <+> rt (pretty x) <+> "->" <+> pretty lc - , withPA funPA $ "right" <+> rt (pretty y) <+> "->" <+> pretty rc - ]) - $+$ "}" - CUnit -> "unit" - CPair c1 c2 -> setPA initPA $ parens (pretty c1 <> ", " <> pretty c2) - CProj s c -> withPA funPA $ selectSide s "fst" "snd" <+> rt (pretty c) - CAbs lam -> withPA initPA $ do + lunbind r $ \(y, rc) -> do + "case" + <+> pretty c + <+> "of {" + $+$ nest + 2 + ( vcat + [ withPA funPA $ "left" <+> rt (pretty x) <+> "->" <+> pretty lc + , withPA funPA $ "right" <+> rt (pretty y) <+> "->" <+> pretty rc + ] + ) + $+$ "}" + CUnit -> "unit" + CPair c1 c2 -> setPA initPA $ parens (pretty c1 <> ", " <> pretty c2) + CProj s c -> withPA funPA $ selectSide s "fst" "snd" <+> rt (pretty c) + CAbs lam -> withPA initPA $ do lunbind lam $ \(xs, body) -> "λ" <> intercalate "," (map pretty xs) <> "." <+> lt (pretty body) - CApp c1 c2 -> withPA funPA $ lt (pretty c1) <+> rt (pretty c2) - CTest xs c -> "test" <+> prettyTestVars xs <+> pretty c - CType ty -> pretty ty - CDelay d -> withPA initPA $ do + CApp c1 c2 -> withPA funPA $ lt (pretty c1) <+> rt (pretty c2) + CTest xs c -> "test" <+> prettyTestVars xs <+> pretty c + CType ty -> pretty ty + CDelay d -> withPA initPA $ do lunbind d $ \(xs, bodies) -> "delay" <+> intercalate "," (map pretty xs) <> "." <+> pretty (toTuple bodies) - CForce c -> withPA funPA $ "force" <+> rt (pretty c) + CForce c -> withPA funPA $ "force" <+> rt (pretty c) toTuple :: [Core] -> Core toTuple = foldr CPair CUnit prettyTestVars :: Members '[Reader PA, LFresh] r => [(String, Type, Name Core)] -> Sem r Doc prettyTestVars = brackets . intercalate "," . map prettyTestVar - where - prettyTestVar (s, ty, n) = parens (intercalate "," [text s, pretty ty, pretty n]) + where + prettyTestVar (s, ty, n) = parens (intercalate "," [text s, pretty ty, pretty n]) isInfix, isPrefix, isPostfix :: Op -> Bool -isInfix OShouldEq{} = True -isInfix OShouldLt{} = True -isInfix op = op `S.member` S.fromList - [ OAdd, OMul, ODiv, OExp, OMod, ODivides, OMultinom, OEq, OLt, OAnd, OOr, OImpl] - +isInfix OShouldEq {} = True +isInfix OShouldLt {} = True +isInfix op = + op + `S.member` S.fromList + [OAdd, OMul, ODiv, OExp, OMod, ODivides, OMultinom, OEq, OLt, OAnd, OOr, OImpl] isPrefix ONeg = True -isPrefix _ = False - +isPrefix _ = False isPostfix OFact = True -isPostfix _ = False +isPostfix _ = False instance Pretty Op where pretty (OForall tys) = "∀" <> intercalate "," (map pretty tys) <> "." @@ -340,67 +342,67 @@ instance Pretty Op where | isInfix op = "~" <> text (opToStr op) <> "~" | isPrefix op = text (opToStr op) <> "~" | isPostfix op = "~" <> text (opToStr op) - | otherwise = text (opToStr op) + | otherwise = text (opToStr op) opToStr :: Op -> String opToStr = \case - OAdd -> "+" - ONeg -> "-" - OSqrt -> "sqrt" - OFloor -> "floor" - OCeil -> "ceil" - OAbs -> "abs" - OMul -> "*" - ODiv -> "/" - OExp -> "^" - OMod -> "mod" - ODivides -> "divides" - OMultinom -> "choose" - OFact -> "!" - OEq -> "==" - OLt -> "<" - OEnum -> "enumerate" - OCount -> "count" - OPower -> "power" - OBagElem -> "elem_bag" - OListElem -> "elem_list" - OEachBag -> "each_bag" - OEachSet -> "each_set" - OFilterBag -> "filter_bag" - OMerge -> "merge" - OBagUnions -> "unions_bag" - OSummary -> "summary" - OEmptyGraph -> "emptyGraph" - OVertex -> "vertex" - OOverlay -> "overlay" - OConnect -> "connect" - OInsert -> "insert" - OLookup -> "lookup" - OUntil -> "until" - OSetToList -> "set2list" - OBagToSet -> "bag2set" - OBagToList -> "bag2list" - OListToSet -> "list2set" - OListToBag -> "list2bag" - OBagToCounts -> "bag2counts" - OCountsToBag -> "counts2bag" + OAdd -> "+" + ONeg -> "-" + OSqrt -> "sqrt" + OFloor -> "floor" + OCeil -> "ceil" + OAbs -> "abs" + OMul -> "*" + ODiv -> "/" + OExp -> "^" + OMod -> "mod" + ODivides -> "divides" + OMultinom -> "choose" + OFact -> "!" + OEq -> "==" + OLt -> "<" + OEnum -> "enumerate" + OCount -> "count" + OPower -> "power" + OBagElem -> "elem_bag" + OListElem -> "elem_list" + OEachBag -> "each_bag" + OEachSet -> "each_set" + OFilterBag -> "filter_bag" + OMerge -> "merge" + OBagUnions -> "unions_bag" + OSummary -> "summary" + OEmptyGraph -> "emptyGraph" + OVertex -> "vertex" + OOverlay -> "overlay" + OConnect -> "connect" + OInsert -> "insert" + OLookup -> "lookup" + OUntil -> "until" + OSetToList -> "set2list" + OBagToSet -> "bag2set" + OBagToList -> "bag2list" + OListToSet -> "list2set" + OListToBag -> "list2bag" + OBagToCounts -> "bag2counts" + OCountsToBag -> "counts2bag" OUnsafeCountsToBag -> "ucounts2bag" - OMapToSet -> "map2set" - OSetToMap -> "set2map" - OIsPrime -> "isPrime" - OFactor -> "factor" - OFrac -> "frac" - OHolds -> "holds" - ONotProp -> "not" - OShouldEq _ -> "=!=" - OShouldLt _ -> "!<" - OMatchErr -> "matchErr" - OCrash -> "crash" - OId -> "id" - OLookupSeq -> "lookupSeq" - OExtendSeq -> "extendSeq" - OForall{} -> "∀" - OExists{} -> "∃" - OAnd -> "and" - OOr -> "or" - OImpl -> "implies" + OMapToSet -> "map2set" + OSetToMap -> "set2map" + OIsPrime -> "isPrime" + OFactor -> "factor" + OFrac -> "frac" + OHolds -> "holds" + ONotProp -> "not" + OShouldEq _ -> "=!=" + OShouldLt _ -> "!<" + OMatchErr -> "matchErr" + OCrash -> "crash" + OId -> "id" + OLookupSeq -> "lookupSeq" + OExtendSeq -> "extendSeq" + OForall {} -> "∀" + OExists {} -> "∃" + OAnd -> "and" + OOr -> "or" + OImpl -> "implies" diff --git a/src/Disco/AST/Desugared.hs b/src/Disco/AST/Desugared.hs index e89d142f..4662b1e9 100644 --- a/src/Disco/AST/Desugared.hs +++ b/src/Disco/AST/Desugared.hs @@ -1,6 +1,5 @@ {-# LANGUAGE PatternSynonyms #-} ------------------------------------------------------------------------------ -- | -- Module : Disco.AST.Desugared -- Copyright : disco team and contributors @@ -10,57 +9,51 @@ -- -- Typed abstract syntax trees representing the typechecked, desugared -- Disco language. --- ------------------------------------------------------------------------------ - -module Disco.AST.Desugared - ( -- * Desugared, type-annotated terms - DTerm - , pattern DTVar - , pattern DTPrim - , pattern DTUnit - , pattern DTBool - , pattern DTChar - , pattern DTNat - , pattern DTRat - , pattern DTAbs - , pattern DTApp - , pattern DTPair - , pattern DTCase - , pattern DTTyOp - , pattern DTNil - , pattern DTTest - - , Container(..) - , DBinding - , pattern DBinding - -- * Branches and guards - , DBranch - - , DGuard - , pattern DGPat - - , DPattern - , pattern DPVar - , pattern DPWild - , pattern DPUnit - , pattern DPPair - , pattern DPInj - - , DProperty - ) - where - -import GHC.Generics - -import Data.Void -import Unbound.Generics.LocallyNameless - -import Disco.AST.Generic -import Disco.Names (QName (..)) -import Disco.Syntax.Operators -import Disco.Syntax.Prims -import Disco.Types +module Disco.AST.Desugared ( + -- * Desugared, type-annotated terms + DTerm, + pattern DTVar, + pattern DTPrim, + pattern DTUnit, + pattern DTBool, + pattern DTChar, + pattern DTNat, + pattern DTRat, + pattern DTAbs, + pattern DTApp, + pattern DTPair, + pattern DTCase, + pattern DTTyOp, + pattern DTNil, + pattern DTTest, + Container (..), + DBinding, + pattern DBinding, + + -- * Branches and guards + DBranch, + DGuard, + pattern DGPat, + DPattern, + pattern DPVar, + pattern DPWild, + pattern DPUnit, + pattern DPPair, + pattern DPInj, + DProperty, +) +where + +import GHC.Generics + +import Data.Void +import Unbound.Generics.LocallyNameless + +import Disco.AST.Generic +import Disco.Names (QName (..)) +import Disco.Syntax.Operators +import Disco.Syntax.Prims +import Disco.Types data DS @@ -69,36 +62,37 @@ type DProperty = Property_ DS -- | A @DTerm@ is a term which has been typechecked and desugared, so -- it has fewer constructors and complex features than 'ATerm', but -- still retains typing information. - type DTerm = Term_ DS -type instance X_Binder DS = Name DTerm - -type instance X_TVar DS = Void -- names are qualified -type instance X_TPrim DS = Type -type instance X_TLet DS = Void -- Let gets translated to lambda -type instance X_TUnit DS = () -type instance X_TBool DS = Type -type instance X_TChar DS = () -type instance X_TString DS = Void -type instance X_TNat DS = Type -type instance X_TRat DS = () -type instance X_TAbs DS = Type -- For lambas this is the function type but - -- for forall/exists it's the argument type -type instance X_TApp DS = Type -type instance X_TCase DS = Type -type instance X_TChain DS = Void -- Chains are translated into conjunctions of - -- binary comparisons -type instance X_TTyOp DS = Type -type instance X_TContainer DS = Void -- Literal containers are desugared into - -- conversion functions applied to list literals +type instance X_Binder DS = Name DTerm + +type instance X_TVar DS = Void -- names are qualified +type instance X_TPrim DS = Type +type instance X_TLet DS = Void -- Let gets translated to lambda +type instance X_TUnit DS = () +type instance X_TBool DS = Type +type instance X_TChar DS = () +type instance X_TString DS = Void +type instance X_TNat DS = Type +type instance X_TRat DS = () +type instance X_TAbs DS = Type -- For lambas this is the function type but +-- for forall/exists it's the argument type + +type instance X_TApp DS = Type +type instance X_TCase DS = Type +type instance X_TChain DS = Void -- Chains are translated into conjunctions of +-- binary comparisons + +type instance X_TTyOp DS = Type +type instance X_TContainer DS = Void -- Literal containers are desugared into +-- conversion functions applied to list literals type instance X_TContainerComp DS = Void -- Container comprehensions are translated - -- into monadic chains +-- into monadic chains -type instance X_TAscr DS = Void -- No type ascriptions -type instance X_TTup DS = Void -- No tuples, only pairs -type instance X_TParens DS = Void -- No explicit parens +type instance X_TAscr DS = Void -- No type ascriptions +type instance X_TTup DS = Void -- No tuples, only pairs +type instance X_TParens DS = Void -- No explicit parens -- Extra constructors type instance X_Term DS = X_DTerm @@ -125,7 +119,7 @@ pattern DTUnit = TUnit_ () pattern DTBool :: Type -> Bool -> DTerm pattern DTBool ty bool = TBool_ ty bool -pattern DTNat :: Type -> Integer -> DTerm +pattern DTNat :: Type -> Integer -> DTerm pattern DTNat ty int = TNat_ ty int pattern DTRat :: Rational -> DTerm @@ -137,7 +131,7 @@ pattern DTChar c = TChar_ () c pattern DTAbs :: Quantifier -> Type -> Bind (Name DTerm) DTerm -> DTerm pattern DTAbs q ty lam = TAbs_ q ty lam -pattern DTApp :: Type -> DTerm -> DTerm -> DTerm +pattern DTApp :: Type -> DTerm -> DTerm -> DTerm pattern DTApp ty term1 term2 = TApp_ ty term1 term2 pattern DTPair :: Type -> DTerm -> DTerm -> DTerm @@ -158,9 +152,22 @@ pattern DTNil ty = XTerm_ (DTNil_ ty) pattern DTTest :: [(String, Type, Name DTerm)] -> DTerm -> DTerm pattern DTTest ns t = XTerm_ (DTTest_ ns t) -{-# COMPLETE DTVar, DTPrim, DTUnit, DTBool, DTChar, DTNat, DTRat, - DTAbs, DTApp, DTPair, DTCase, DTTyOp, - DTNil, DTTest #-} +{-# COMPLETE + DTVar + , DTPrim + , DTUnit + , DTBool + , DTChar + , DTNat + , DTRat + , DTAbs + , DTApp + , DTPair + , DTCase + , DTTyOp + , DTNil + , DTTest + #-} type instance X_TLink DS = Void @@ -175,9 +182,9 @@ type DBranch = Bind (Telescope DGuard) DTerm type DGuard = Guard_ DS -type instance X_GBool DS = Void -- Boolean guards get desugared to pattern-matching -type instance X_GPat DS = () -type instance X_GLet DS = Void -- Let gets desugared to 'when' with a variable +type instance X_GBool DS = Void -- Boolean guards get desugared to pattern-matching +type instance X_GPat DS = () +type instance X_GLet DS = Void -- Let gets desugared to 'when' with a variable pattern DGPat :: Embed DTerm -> DPattern -> DGuard pattern DGPat embedt pat = GPat_ () embedt pat @@ -186,23 +193,23 @@ pattern DGPat embedt pat = GPat_ () embedt pat type DPattern = Pattern_ DS -type instance X_PVar DS = Embed Type -type instance X_PWild DS = Embed Type -type instance X_PAscr DS = Void -type instance X_PUnit DS = () -type instance X_PBool DS = Void -type instance X_PChar DS = Void -type instance X_PString DS = Void -type instance X_PTup DS = Void -type instance X_PInj DS = Void -type instance X_PNat DS = Void -type instance X_PCons DS = Void -type instance X_PList DS = Void -type instance X_PAdd DS = Void -type instance X_PMul DS = Void -type instance X_PSub DS = Void -type instance X_PNeg DS = Void -type instance X_PFrac DS = Void +type instance X_PVar DS = Embed Type +type instance X_PWild DS = Embed Type +type instance X_PAscr DS = Void +type instance X_PUnit DS = () +type instance X_PBool DS = Void +type instance X_PChar DS = Void +type instance X_PString DS = Void +type instance X_PTup DS = Void +type instance X_PInj DS = Void +type instance X_PNat DS = Void +type instance X_PCons DS = Void +type instance X_PList DS = Void +type instance X_PAdd DS = Void +type instance X_PMul DS = Void +type instance X_PSub DS = Void +type instance X_PNeg DS = Void +type instance X_PFrac DS = Void -- In the desugared language, constructor patterns (DPPair, DPInj) can -- only contain variables, not nested patterns. This means that the @@ -210,10 +217,11 @@ type instance X_PFrac DS = Void -- exploding nested patterns into sequential guards, which makes the -- interpreter simpler. -type instance X_Pattern DS = - Either - (Embed Type, Name DTerm, Name DTerm) -- DPPair - (Embed Type, Side, Name DTerm) -- DPInj +type instance + X_Pattern DS = + Either + (Embed Type, Name DTerm, Name DTerm) -- DPPair + (Embed Type, Side, Name DTerm) -- DPInj pattern DPVar :: Type -> Name DTerm -> DPattern pattern DPVar ty name <- PVar_ (unembed -> ty) name @@ -228,19 +236,19 @@ pattern DPWild ty <- PWild_ (unembed -> ty) pattern DPUnit :: DPattern pattern DPUnit = PUnit_ () -pattern DPPair :: Type -> Name DTerm -> Name DTerm -> DPattern +pattern DPPair :: Type -> Name DTerm -> Name DTerm -> DPattern pattern DPPair ty x1 x2 <- XPattern_ (Left (unembed -> ty, x1, x2)) where DPPair ty x1 x2 = XPattern_ (Left (embed ty, x1, x2)) -pattern DPInj :: Type -> Side -> Name DTerm -> DPattern +pattern DPInj :: Type -> Side -> Name DTerm -> DPattern pattern DPInj ty s x <- XPattern_ (Right (unembed -> ty, s, x)) where DPInj ty s x = XPattern_ (Right (embed ty, s, x)) {-# COMPLETE DPVar, DPWild, DPUnit, DPPair, DPInj #-} -type instance X_QBind DS = Void +type instance X_QBind DS = Void type instance X_QGuard DS = Void ------------------------------------------------------------ @@ -248,25 +256,25 @@ type instance X_QGuard DS = Void ------------------------------------------------------------ instance HasType DTerm where - getType (DTVar ty _) = ty - getType (DTPrim ty _) = ty - getType DTUnit = TyUnit - getType (DTBool ty _) = ty - getType (DTChar _) = TyC - getType (DTNat ty _) = ty - getType (DTRat _) = TyF + getType (DTVar ty _) = ty + getType (DTPrim ty _) = ty + getType DTUnit = TyUnit + getType (DTBool ty _) = ty + getType (DTChar _) = TyC + getType (DTNat ty _) = ty + getType (DTRat _) = TyF getType (DTAbs Lam ty _) = ty - getType DTAbs{} = TyProp - getType (DTApp ty _ _) = ty - getType (DTPair ty _ _) = ty - getType (DTCase ty _) = ty - getType (DTTyOp ty _ _) = ty - getType (DTNil ty) = ty - getType (DTTest _ _) = TyProp + getType DTAbs {} = TyProp + getType (DTApp ty _ _) = ty + getType (DTPair ty _ _) = ty + getType (DTCase ty _) = ty + getType (DTTyOp ty _ _) = ty + getType (DTNil ty) = ty + getType (DTTest _ _) = TyProp instance HasType DPattern where - getType (DPVar ty _) = ty - getType (DPWild ty) = ty - getType DPUnit = TyUnit + getType (DPVar ty _) = ty + getType (DPWild ty) = ty + getType DPUnit = TyUnit getType (DPPair ty _ _) = ty - getType (DPInj ty _ _) = ty + getType (DPInj ty _ _) = ty diff --git a/src/Disco/AST/Generic.hs b/src/Disco/AST/Generic.hs index ef7cfc74..3c9827dd 100644 --- a/src/Disco/AST/Generic.hs +++ b/src/Disco/AST/Generic.hs @@ -1,15 +1,19 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveDataTypeable #-} -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE UndecidableInstances #-} - -- Orphan Alpha Void instance {-# OPTIONS_GHC -fno-warn-orphans #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + +-- SPDX-License-Identifier: BSD-3-Clause + -- | -- Module : Disco.AST.Generic -- Copyright : disco team and contributors @@ -36,128 +40,117 @@ -- underlying type. Particular instantiations of the generic -- framework here can be found in "Disco.AST.Surface", -- "Disco.AST.Typed", and "Disco.AST.Desugared". ------------------------------------------------------------------------------ - --- SPDX-License-Identifier: BSD-3-Clause - -module Disco.AST.Generic - ( -- * Telescopes - - Telescope (..), telCons - , foldTelescope, mapTelescope - , traverseTelescope - , toTelescope, fromTelescope - - -- * Utility types - - , Side (..), selectSide, fromSide - , Container (..) - , Ellipsis (..) - - -- * Term - - , Term_ (..) - - , X_TVar - , X_TPrim - , X_TLet - , X_TParens - , X_TUnit - , X_TBool - , X_TNat - , X_TRat - , X_TChar - , X_TString - , X_TAbs - , X_TApp - , X_TTup - , X_TCase - , X_TChain - , X_TTyOp - , X_TContainer - , X_TContainerComp - , X_TAscr - , X_Term - - , ForallTerm - - -- * Link - - , Link_ (..) - , X_TLink - , ForallLink - - -- * Qual - - , Qual_ (..) - , X_QBind - , X_QGuard - , ForallQual - - -- * Binding - - , Binding_ (..) - - -- * Branch - , Branch_ - - -- * Guard - - , Guard_ (..) - , X_GBool - , X_GPat - , X_GLet - , ForallGuard - - -- * Pattern - - , Pattern_ (..) - , X_PVar - , X_PWild - , X_PAscr - , X_PUnit - , X_PBool - , X_PTup - , X_PInj - , X_PNat - , X_PChar - , X_PString - , X_PCons - , X_PList - , X_PAdd - , X_PMul - , X_PSub - , X_PNeg - , X_PFrac - , X_Pattern - , ForallPattern - - -- * Quantifiers - - , Quantifier(..) - , Binder_ - , X_Binder - - -- * Property - - , Property_ - ) - where - -import Control.Lens.Plated -import Data.Data (Data) -import Data.Data.Lens (uniplate) -import Data.Typeable -import GHC.Exts (Constraint) -import GHC.Generics (Generic) - -import Data.Void -import Unbound.Generics.LocallyNameless - -import Disco.Pretty -import Disco.Syntax.Operators -import Disco.Syntax.Prims -import Disco.Types +module Disco.AST.Generic ( + -- * Telescopes + Telescope (..), + telCons, + foldTelescope, + mapTelescope, + traverseTelescope, + toTelescope, + fromTelescope, + + -- * Utility types + Side (..), + selectSide, + fromSide, + Container (..), + Ellipsis (..), + + -- * Term + Term_ (..), + X_TVar, + X_TPrim, + X_TLet, + X_TParens, + X_TUnit, + X_TBool, + X_TNat, + X_TRat, + X_TChar, + X_TString, + X_TAbs, + X_TApp, + X_TTup, + X_TCase, + X_TChain, + X_TTyOp, + X_TContainer, + X_TContainerComp, + X_TAscr, + X_Term, + ForallTerm, + + -- * Link + Link_ (..), + X_TLink, + ForallLink, + + -- * Qual + Qual_ (..), + X_QBind, + X_QGuard, + ForallQual, + + -- * Binding + Binding_ (..), + + -- * Branch + Branch_, + + -- * Guard + Guard_ (..), + X_GBool, + X_GPat, + X_GLet, + ForallGuard, + + -- * Pattern + Pattern_ (..), + X_PVar, + X_PWild, + X_PAscr, + X_PUnit, + X_PBool, + X_PTup, + X_PInj, + X_PNat, + X_PChar, + X_PString, + X_PCons, + X_PList, + X_PAdd, + X_PMul, + X_PSub, + X_PNeg, + X_PFrac, + X_Pattern, + ForallPattern, + + -- * Quantifiers + Quantifier (..), + Binder_, + X_Binder, + + -- * Property + Property_, +) +where + +import Control.Lens.Plated +import Data.Data (Data) +import Data.Data.Lens (uniplate) +import Data.Typeable +import GHC.Exts (Constraint) +import GHC.Generics (Generic) + +import Data.Void +import Unbound.Generics.LocallyNameless + +import Disco.Pretty +import Disco.Syntax.Operators +import Disco.Syntax.Prims +import Disco.Types ------------------------------------------------------------ -- Telescopes @@ -166,13 +159,11 @@ import Disco.Types -- | A telescope is essentially a list, except that each item can bind -- names in the rest of the list. data Telescope b where - -- | The empty telescope. TelEmpty :: Telescope b - -- | A binder of type @b@ followed by zero or more @b@'s. This @b@ -- can bind variables in the subsequent @b@'s. - TelCons :: Rebind b (Telescope b) -> Telescope b + TelCons :: Rebind b (Telescope b) -> Telescope b deriving (Show, Generic, Alpha, Subst t, Data) -- | Add a new item to the beginning of a 'Telescope'. @@ -182,17 +173,19 @@ telCons b tb = TelCons (rebind b tb) -- | Fold a telescope given a combining function and a value to use -- for the empty telescope. Analogous to 'foldr' for lists. foldTelescope :: Alpha b => (b -> r -> r) -> r -> Telescope b -> r -foldTelescope _ z TelEmpty = z -foldTelescope f z (TelCons (unrebind -> (b,bs))) = f b (foldTelescope f z bs) +foldTelescope _ z TelEmpty = z +foldTelescope f z (TelCons (unrebind -> (b, bs))) = f b (foldTelescope f z bs) -- | Apply a function to every item in a telescope. mapTelescope :: (Alpha a, Alpha b) => (a -> b) -> Telescope a -> Telescope b mapTelescope f = toTelescope . map f . fromTelescope -- | Traverse over a telescope. -traverseTelescope - :: (Applicative f, Alpha a, Alpha b) - => (a -> f b) -> Telescope a -> f (Telescope b) +traverseTelescope :: + (Applicative f, Alpha a, Alpha b) => + (a -> f b) -> + Telescope a -> + f (Telescope b) traverseTelescope f = foldTelescope (\a ftb -> telCons <$> f a <*> ftb) (pure TelEmpty) -- | Convert a list to a telescope. @@ -230,8 +223,8 @@ fromSide s = selectSide s False True -- lists, bags, and sets. data Container where ListContainer :: Container - BagContainer :: Container - SetContainer :: Container + BagContainer :: Container + SetContainer :: Container deriving (Show, Eq, Enum, Generic, Data, Alpha, Subst t) -- | An ellipsis is an "omitted" part of a literal container (such as @@ -240,7 +233,7 @@ data Container where -- containers must be finite. data Ellipsis t where -- | 'Until' represents an ellipsis with a given endpoint, as in @[3 .. 20]@. - Until :: t -> Ellipsis t -- @.. t@ + Until :: t -> Ellipsis t -- @.. t@ deriving (Show, Generic, Functor, Foldable, Traversable, Alpha, Subst a, Data) ------------------------------------------------------------ @@ -276,43 +269,32 @@ type family X_Term e -- example, in the typed phase many constructors store an extra -- type, giving the type of the term. data Term_ e where - -- | A term variable. - TVar_ :: X_TVar e -> Name (Term_ e) -> Term_ e - + TVar_ :: X_TVar e -> Name (Term_ e) -> Term_ e -- | A primitive, /i.e./ a constant which is interpreted specially -- at runtime. See "Disco.Syntax.Prims". - TPrim_ :: X_TPrim e -> Prim -> Term_ e - + TPrim_ :: X_TPrim e -> Prim -> Term_ e -- | A (non-recursive) let expression, @let x1 = t1, x2 = t2, ... in t@. - TLet_ :: X_TLet e -> Bind (Telescope (Binding_ e)) (Term_ e) -> Term_ e - + TLet_ :: X_TLet e -> Bind (Telescope (Binding_ e)) (Term_ e) -> Term_ e -- | Explicit parentheses. We need to keep track of these in the -- surface syntax in order to syntactically distinguish -- multiplication and function application. However, note that -- these disappear after the surface syntax phase. TParens_ :: X_TParens e -> Term_ e -> Term_ e - -- | The unit value, (), of type Unit. - TUnit_ :: X_TUnit e -> Term_ e - + TUnit_ :: X_TUnit e -> Term_ e -- | A boolean value. - TBool_ :: X_TBool e -> Bool -> Term_ e - + TBool_ :: X_TBool e -> Bool -> Term_ e -- | A natural number. - TNat_ :: X_TNat e -> Integer -> Term_ e - + TNat_ :: X_TNat e -> Integer -> Term_ e -- | A nonnegative rational number, parsed as a decimal. (Note -- syntax like @3/5@ does not parse as a rational, but rather as -- the application of a division operator to two natural numbers.) - TRat_ :: X_TRat e -> Rational -> Term_ e - + TRat_ :: X_TRat e -> Rational -> Term_ e -- | A literal unicode character, /e.g./ @'d'@. - TChar_ :: X_TChar e -> Char -> Term_ e - + TChar_ :: X_TChar e -> Char -> Term_ e -- | A string literal, /e.g./ @"disco"@. TString_ :: X_TString e -> [Char] -> Term_ e - -- | A binding abstraction, of the form @Q vars. expr@ where @Q@ is -- a quantifier and @vars@ is a list of bound variables and -- optional type annotations. In particular, this could be a @@ -320,77 +302,68 @@ data Term_ e where -- (y:N). 2x + y@), a universal quantifier (@forall x, (y:N). x^2 + -- y > 0@), or an existential quantifier (@exists x, (y:N). x^2 + y -- == 0@). - TAbs_ :: Quantifier -> X_TAbs e -> Binder_ e (Term_ e) -> Term_ e - + TAbs_ :: Quantifier -> X_TAbs e -> Binder_ e (Term_ e) -> Term_ e -- | Function application, @t1 t2@. - TApp_ :: X_TApp e -> Term_ e -> Term_ e -> Term_ e - + TApp_ :: X_TApp e -> Term_ e -> Term_ e -> Term_ e -- | An n-tuple, @(t1, ..., tn)@. - TTup_ :: X_TTup e -> [Term_ e] -> Term_ e - + TTup_ :: X_TTup e -> [Term_ e] -> Term_ e -- | A case expression. - TCase_ :: X_TCase e -> [Branch_ e] -> Term_ e - + TCase_ :: X_TCase e -> [Branch_ e] -> Term_ e -- | A chained comparison, consisting of a term followed by one or -- more "links", where each link is a comparison operator and -- another term. TChain_ :: X_TChain e -> Term_ e -> [Link_ e] -> Term_ e - -- | An application of a type operator. - TTyOp_ :: X_TTyOp e -> TyOp -> Type -> Term_ e - + TTyOp_ :: X_TTyOp e -> TyOp -> Type -> Term_ e -- | A containter literal (set, bag, or list). TContainer_ :: X_TContainer e -> Container -> [(Term_ e, Maybe (Term_ e))] -> Maybe (Ellipsis (Term_ e)) -> Term_ e - -- | A container comprehension. TContainerComp_ :: X_TContainerComp e -> Container -> Bind (Telescope (Qual_ e)) (Term_ e) -> Term_ e - -- | Type ascription, @(Term_ e : type)@. - TAscr_ :: X_TAscr e -> Term_ e -> PolyType -> Term_ e - + TAscr_ :: X_TAscr e -> Term_ e -> PolyType -> Term_ e -- | A data constructor with an extension descriptor that a "concrete" -- implementation of a generic AST may use to carry extra information. - XTerm_ :: X_Term e -> Term_ e + XTerm_ :: X_Term e -> Term_ e deriving (Generic) -- A type that abstracts over constraints for generic data constructors. -- This makes it easier to derive typeclass instances for generic types. -type ForallTerm (a :: * -> Constraint) e - = ( a (X_TVar e) - , a (X_TPrim e) - , a (X_TLet e) - , a (X_TParens e) - , a (X_TUnit e) - , a (X_TBool e) - , a (X_TNat e) - , a (X_TRat e) - , a (X_TChar e) - , a (X_TString e) - , a (X_TAbs e) - , a (X_TApp e) - , a (X_TCase e) - , a (X_TTup e) - , a (X_TChain e) - , a (X_TTyOp e) - , a (X_TContainer e) - , a (X_TContainerComp e) - , a (X_TAscr e) - , a (X_Term e) - , a (Qual_ e) - , a (Guard_ e) - , a (Link_ e) - , a (Binding_ e) - , a (Pattern_ e) - , a (Binder_ e (Term_ e)) - ) +type ForallTerm (a :: * -> Constraint) e = + ( a (X_TVar e) + , a (X_TPrim e) + , a (X_TLet e) + , a (X_TParens e) + , a (X_TUnit e) + , a (X_TBool e) + , a (X_TNat e) + , a (X_TRat e) + , a (X_TChar e) + , a (X_TString e) + , a (X_TAbs e) + , a (X_TApp e) + , a (X_TCase e) + , a (X_TTup e) + , a (X_TChain e) + , a (X_TTyOp e) + , a (X_TContainer e) + , a (X_TContainerComp e) + , a (X_TAscr e) + , a (X_Term e) + , a (Qual_ e) + , a (Guard_ e) + , a (Link_ e) + , a (Binding_ e) + , a (Pattern_ e) + , a (Binder_ e (Term_ e)) + ) deriving instance ForallTerm Show e => Show (Term_ e) instance ( Typeable e , ForallTerm (Subst Type) e , ForallTerm Alpha e - ) - => Subst Type (Term_ e) + ) => + Subst Type (Term_ e) instance (Typeable e, ForallTerm Alpha e) => Alpha (Term_ e) deriving instance (Data e, Typeable e, ForallTerm Data e) => Data (Term_ e) @@ -407,19 +380,18 @@ type family X_TLink e -- followed by a sequence of links makes up a comparison chain, such -- as @2 < x < y < 10@. data Link_ e where - -- | Note that although the type of 'TLink_' says it can hold any -- 'BOp', it should really only hold comparison operators. TLink_ :: X_TLink e -> BOp -> Term_ e -> Link_ e - deriving Generic + deriving (Generic) -type ForallLink (a :: * -> Constraint) e - = ( a (X_TLink e) - , a (Term_ e) - ) +type ForallLink (a :: * -> Constraint) e = + ( a (X_TLink e) + , a (Term_ e) + ) -deriving instance ForallLink Show e => Show (Link_ e) -instance ForallLink (Subst Type) e => Subst Type (Link_ e) +deriving instance ForallLink Show e => Show (Link_ e) +instance ForallLink (Subst Type) e => Subst Type (Link_ e) instance (Typeable e, Show (Link_ e), ForallLink Alpha e) => Alpha (Link_ e) deriving instance (Typeable e, Data e, ForallLink Data e) => Data (Link_ e) @@ -434,23 +406,20 @@ type family X_QGuard e -- of qualifiers. Each qualifier either binds a variable to some -- collection or consists of a boolean guard. data Qual_ e where - -- | A binding qualifier (i.e. @x in t@). - QBind_ :: X_QBind e -> Name (Term_ e) -> Embed (Term_ e) -> Qual_ e - + QBind_ :: X_QBind e -> Name (Term_ e) -> Embed (Term_ e) -> Qual_ e -- | A boolean guard qualfier (i.e. @x + y > 4@). - QGuard_ :: X_QGuard e -> Embed (Term_ e) -> Qual_ e - - deriving Generic + QGuard_ :: X_QGuard e -> Embed (Term_ e) -> Qual_ e + deriving (Generic) -type ForallQual (a :: * -> Constraint) e - = ( a (X_QBind e) - , a (X_QGuard e) - , a (Term_ e) - ) +type ForallQual (a :: * -> Constraint) e = + ( a (X_QBind e) + , a (X_QGuard e) + , a (Term_ e) + ) -deriving instance ForallQual Show e => Show (Qual_ e) -instance ForallQual (Subst Type) e => Subst Type (Qual_ e) +deriving instance ForallQual Show e => Show (Qual_ e) +instance ForallQual (Subst Type) e => Subst Type (Qual_ e) instance (Typeable e, ForallQual Alpha e) => Alpha (Qual_ e) deriving instance (Typeable e, Data e, ForallQual Data e) => Data (Qual_ e) @@ -463,7 +432,7 @@ deriving instance (Typeable e, Data e, ForallQual Data e) => Data (Qual_ e) data Binding_ e = Binding_ (Maybe (Embed PolyType)) (Name (Term_ e)) (Embed (Term_ e)) deriving (Generic) -deriving instance ForallTerm Show e => Show (Binding_ e) +deriving instance ForallTerm Show e => Show (Binding_ e) instance Subst Type (Term_ e) => Subst Type (Binding_ e) instance (Typeable e, Show (Binding_ e), Alpha (Term_ e)) => Alpha (Binding_ e) deriving instance (Typeable e, Data e, ForallTerm Data e) => Data (Binding_ e) @@ -475,7 +444,6 @@ deriving instance (Typeable e, Data e, ForallTerm Data e) => Data (Binding_ e) -- | A branch of a case is a list of guards with an accompanying term. -- The guards scope over the term. Additionally, each guard scopes -- over subsequent guards. - type Branch_ e = Bind (Telescope (Guard_ e)) (Term_ e) ------------------------------------------------------------ @@ -488,29 +456,25 @@ type family X_GLet e -- | Guards in case expressions. data Guard_ e where - -- | Boolean guard (@if @) GBool_ :: X_GBool e -> Embed (Term_ e) -> Guard_ e - -- | Pattern guard (@when term = pat@) - GPat_ :: X_GPat e -> Embed (Term_ e) -> Pattern_ e -> Guard_ e - + GPat_ :: X_GPat e -> Embed (Term_ e) -> Pattern_ e -> Guard_ e -- | Let (@let x = term@) - GLet_ :: X_GLet e -> Binding_ e -> Guard_ e - - deriving Generic + GLet_ :: X_GLet e -> Binding_ e -> Guard_ e + deriving (Generic) -type ForallGuard (a :: * -> Constraint) e - = ( a (X_GBool e) - , a (X_GPat e) - , a (X_GLet e) - , a (Term_ e) - , a (Pattern_ e) - , a (Binding_ e) - ) +type ForallGuard (a :: * -> Constraint) e = + ( a (X_GBool e) + , a (X_GPat e) + , a (X_GLet e) + , a (Term_ e) + , a (Pattern_ e) + , a (Binding_ e) + ) -deriving instance ForallGuard Show e => Show (Guard_ e) -instance ForallGuard (Subst Type) e => Subst Type (Guard_ e) +deriving instance ForallGuard Show e => Show (Guard_ e) +instance ForallGuard (Subst Type) e => Subst Type (Guard_ e) instance (Typeable e, Show (Guard_ e), ForallGuard Alpha e) => Alpha (Guard_ e) deriving instance (Typeable e, Data e, ForallGuard Data e) => Data (Guard_ e) @@ -539,92 +503,72 @@ type family X_Pattern e -- | Patterns. data Pattern_ e where - -- | Variable pattern: matches anything and binds the variable. - PVar_ :: X_PVar e -> Name (Term_ e) -> Pattern_ e - + PVar_ :: X_PVar e -> Name (Term_ e) -> Pattern_ e -- | Wildcard pattern @_@: matches anything. PWild_ :: X_PWild e -> Pattern_ e - -- | Type ascription pattern @pat : ty@. PAscr_ :: X_PAscr e -> Pattern_ e -> Type -> Pattern_ e - -- | Unit pattern @()@: matches @()@. PUnit_ :: X_PUnit e -> Pattern_ e - -- | Literal boolean pattern. PBool_ :: X_PBool e -> Bool -> Pattern_ e - -- | Tuple pattern @(pat1, .. , patn)@. - PTup_ :: X_PTup e -> [Pattern_ e] -> Pattern_ e - + PTup_ :: X_PTup e -> [Pattern_ e] -> Pattern_ e -- | Injection pattern (@inl pat@ or @inr pat@). - PInj_ :: X_PInj e -> Side -> Pattern_ e -> Pattern_ e - + PInj_ :: X_PInj e -> Side -> Pattern_ e -> Pattern_ e -- | Literal natural number pattern. - PNat_ :: X_PNat e -> Integer -> Pattern_ e - + PNat_ :: X_PNat e -> Integer -> Pattern_ e -- | Unicode character pattern PChar_ :: X_PChar e -> Char -> Pattern_ e - -- | String pattern. PString_ :: X_PString e -> String -> Pattern_ e - -- | Cons pattern @p1 :: p2@. PCons_ :: X_PCons e -> Pattern_ e -> Pattern_ e -> Pattern_ e - -- | List pattern @[p1, .., pn]@. PList_ :: X_PList e -> [Pattern_ e] -> Pattern_ e - -- | Addition pattern, @p + t@ or @t + p@ - PAdd_ :: X_PAdd e -> Side -> Pattern_ e -> Term_ e -> Pattern_ e - + PAdd_ :: X_PAdd e -> Side -> Pattern_ e -> Term_ e -> Pattern_ e -- | Multiplication pattern, @p * t@ or @t * p@ - PMul_ :: X_PMul e -> Side -> Pattern_ e -> Term_ e -> Pattern_ e - + PMul_ :: X_PMul e -> Side -> Pattern_ e -> Term_ e -> Pattern_ e -- | Subtraction pattern, @p - t@ - PSub_ :: X_PSub e -> Pattern_ e -> Term_ e -> Pattern_ e - + PSub_ :: X_PSub e -> Pattern_ e -> Term_ e -> Pattern_ e -- | Negation pattern, @-p@ - PNeg_ :: X_PNeg e -> Pattern_ e -> Pattern_ e - + PNeg_ :: X_PNeg e -> Pattern_ e -> Pattern_ e -- | Fraction pattern, @p1/p2@ PFrac_ :: X_PFrac e -> Pattern_ e -> Pattern_ e -> Pattern_ e - -- | A special placeholder node for a nonlinear occurrence of a -- variable; we can only detect this at parse time but need to -- generate an error later. PNonlinear_ :: Embed (Pattern_ e) -> Name (Term_ e) -> Pattern_ e - -- | Expansion slot. XPattern_ :: X_Pattern e -> Pattern_ e - deriving (Generic) -type ForallPattern (a :: * -> Constraint) e - = ( a (X_PVar e) - , a (X_PWild e) - , a (X_PAscr e) - , a (X_PUnit e) - , a (X_PBool e) - , a (X_PNat e) - , a (X_PChar e) - , a (X_PString e) - , a (X_PTup e) - , a (X_PInj e) - , a (X_PCons e) - , a (X_PList e) - , a (X_PAdd e) - , a (X_PMul e) - , a (X_PSub e) - , a (X_PNeg e) - , a (X_PFrac e) - , a (X_Pattern e) - , a (Term_ e) - ) - -deriving instance ForallPattern Show e => Show (Pattern_ e) -instance ForallPattern (Subst Type) e => Subst Type (Pattern_ e) +type ForallPattern (a :: * -> Constraint) e = + ( a (X_PVar e) + , a (X_PWild e) + , a (X_PAscr e) + , a (X_PUnit e) + , a (X_PBool e) + , a (X_PNat e) + , a (X_PChar e) + , a (X_PString e) + , a (X_PTup e) + , a (X_PInj e) + , a (X_PCons e) + , a (X_PList e) + , a (X_PAdd e) + , a (X_PMul e) + , a (X_PSub e) + , a (X_PNeg e) + , a (X_PFrac e) + , a (X_Pattern e) + , a (Term_ e) + ) + +deriving instance ForallPattern Show e => Show (Pattern_ e) +instance ForallPattern (Subst Type) e => Subst Type (Pattern_ e) instance (Typeable e, Show (Pattern_ e), ForallPattern Alpha e) => Alpha (Pattern_ e) deriving instance (Typeable e, Data e, ForallPattern Data e) => Data (Pattern_ e) diff --git a/src/Disco/AST/Surface.hs b/src/Disco/AST/Surface.hs index f5d0f1a7..c15c43e8 100644 --- a/src/Disco/AST/Surface.hs +++ b/src/Disco/AST/Surface.hs @@ -1,8 +1,11 @@ -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE UndecidableInstances #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.AST.Surface -- Copyright : disco team and contributors @@ -12,157 +15,172 @@ -- -- Abstract syntax trees representing the surface syntax of the Disco -- language. --- ------------------------------------------------------------------------------ - -module Disco.AST.Surface - ( -- * Modules - Module(..), emptyModule, TopLevel(..) - -- ** Documentation - , Docs, DocThing(..), Property - -- ** Declarations - , TypeDecl(..), TermDefn(..), TypeDefn(..) - , Decl(..), partitionDecls, prettyTyDecl - - -- * Terms - , UD - , Term - , pattern TVar - , pattern TPrim - , pattern TUn - , pattern TBin - , pattern TLet - , pattern TParens - , pattern TUnit - , pattern TBool - , pattern TChar - , pattern TString - , pattern TNat - , pattern TRat - , pattern TAbs - , pattern TApp - , pattern TTup - , pattern TCase - , pattern TChain - , pattern TTyOp - , pattern TContainerComp - , pattern TContainer - , pattern TAscr - , pattern TWild - , pattern TList - , pattern TListComp - - , Quantifier(..) - - -- ** Telescopes - , Telescope(..), foldTelescope, mapTelescope, toTelescope, fromTelescope - -- ** Expressions - , Side(..) - - , Link - , pattern TLink - - , Binding - - -- ** Lists - , Qual - , pattern QBind - , pattern QGuard - - , Container(..) - - , Ellipsis(..) - - -- ** Case expressions and patterns - , Branch - - , Guard - , pattern GBool - , pattern GPat - , pattern GLet - - , Pattern - , pattern PVar - , pattern PWild - , pattern PAscr - , pattern PUnit - , pattern PBool - , pattern PChar - , pattern PString - , pattern PTup - , pattern PInj - , pattern PNat - , pattern PCons - , pattern PList - , pattern PAdd - , pattern PMul - , pattern PSub - , pattern PNeg - , pattern PFrac - , pattern PNonlinear - - , pattern Binding - ) - where - -import Prelude hiding ((<>)) - -import Control.Lens (_1, _2, _3, (%~)) -import Data.Char (toLower) -import qualified Data.Map as M -import Data.Set (Set) -import qualified Data.Set as S -import Data.Void - -import Disco.Effects.LFresh -import Polysemy hiding (Embed, embed) -import Polysemy.Reader - -import Disco.AST.Generic -import Disco.Extensions -import Disco.Pretty -import Disco.Syntax.Operators -import Disco.Syntax.Prims -import Disco.Types -import Unbound.Generics.LocallyNameless hiding (LFresh (..), lunbind) +module Disco.AST.Surface ( + -- * Modules + Module (..), + emptyModule, + TopLevel (..), + + -- ** Documentation + Docs, + DocThing (..), + Property, + + -- ** Declarations + TypeDecl (..), + TermDefn (..), + TypeDefn (..), + Decl (..), + partitionDecls, + prettyTyDecl, + + -- * Terms + UD, + Term, + pattern TVar, + pattern TPrim, + pattern TUn, + pattern TBin, + pattern TLet, + pattern TParens, + pattern TUnit, + pattern TBool, + pattern TChar, + pattern TString, + pattern TNat, + pattern TRat, + pattern TAbs, + pattern TApp, + pattern TTup, + pattern TCase, + pattern TChain, + pattern TTyOp, + pattern TContainerComp, + pattern TContainer, + pattern TAscr, + pattern TWild, + pattern TList, + pattern TListComp, + Quantifier (..), + + -- ** Telescopes + Telescope (..), + foldTelescope, + mapTelescope, + toTelescope, + fromTelescope, + + -- ** Expressions + Side (..), + Link, + pattern TLink, + Binding, + + -- ** Lists + Qual, + pattern QBind, + pattern QGuard, + Container (..), + Ellipsis (..), + + -- ** Case expressions and patterns + Branch, + Guard, + pattern GBool, + pattern GPat, + pattern GLet, + Pattern, + pattern PVar, + pattern PWild, + pattern PAscr, + pattern PUnit, + pattern PBool, + pattern PChar, + pattern PString, + pattern PTup, + pattern PInj, + pattern PNat, + pattern PCons, + pattern PList, + pattern PAdd, + pattern PMul, + pattern PSub, + pattern PNeg, + pattern PFrac, + pattern PNonlinear, + pattern Binding, +) +where + +import Prelude hiding ((<>)) + +import Control.Lens ((%~), _1, _2, _3) +import Data.Char (toLower) +import qualified Data.Map as M +import Data.Set (Set) +import qualified Data.Set as S +import Data.Void + +import Disco.Effects.LFresh +import Polysemy hiding (Embed, embed) +import Polysemy.Reader + +import Disco.AST.Generic +import Disco.Extensions +import Disco.Pretty +import Disco.Syntax.Operators +import Disco.Syntax.Prims +import Disco.Types +import Unbound.Generics.LocallyNameless hiding (LFresh (..), lunbind) -- | The extension descriptor for Surface specific AST types. data UD -- | A module contains all the information from one disco source file. data Module = Module - { modExts :: Set Ext -- ^ Enabled extensions - , modImports :: [String] -- ^ Module imports - , modDecls :: [Decl] -- ^ Declarations - , modDocs :: [(Name Term, Docs)] -- ^ Documentation - , modTerms :: [Term] -- ^ Top-level (bare) terms + { modExts :: Set Ext + -- ^ Enabled extensions + , modImports :: [String] + -- ^ Module imports + , modDecls :: [Decl] + -- ^ Declarations + , modDocs :: [(Name Term, Docs)] + -- ^ Documentation + , modTerms :: [Term] + -- ^ Top-level (bare) terms } -deriving instance ForallTerm Show UD => Show Module + +deriving instance ForallTerm Show UD => Show Module emptyModule :: Module -emptyModule = Module - { modExts = S.empty - , modImports = [] - , modDecls = [] - , modDocs = [] - , modTerms = [] - } +emptyModule = + Module + { modExts = S.empty + , modImports = [] + , modDecls = [] + , modDocs = [] + , modTerms = [] + } -- | A @TopLevel@ is either documentation (a 'DocThing') or a -- declaration ('Decl'). data TopLevel = TLDoc DocThing | TLDecl Decl | TLExpr Term -deriving instance ForallTerm Show UD => Show TopLevel + +deriving instance ForallTerm Show UD => Show TopLevel -- | Convenient synonym for a list of 'DocThing's. type Docs = [DocThing] -- | An item of documentation. data DocThing - = DocString [String] -- ^ A documentation string, i.e. a block - -- of @||| text@ items - | DocProperty Property -- ^ An example/doctest/property of the - -- form @!!! forall (x1:ty1) ... . property@ -deriving instance ForallTerm Show UD => Show DocThing + = -- | A documentation string, i.e. a block + -- of @||| text@ items + DocString [String] + | -- | An example/doctest/property of the + -- form @!!! forall (x1:ty1) ... . property@ + DocProperty Property + +deriving instance ForallTerm Show UD => Show DocThing -- | A property is a universally quantified term of the form -- @forall v1 : T1, v2 : T2. term@. @@ -180,24 +198,24 @@ data TermDefn = TermDefn (Name Term) [Bind [Pattern] Term] -- -- @type T arg1 arg2 ... = body data TypeDefn = TypeDefn String [String] Type - deriving Show + deriving (Show) -- | A declaration is either a type declaration, a term definition, or -- a type definition. data Decl where - DType :: TypeDecl -> Decl - DDefn :: TermDefn -> Decl + DType :: TypeDecl -> Decl + DDefn :: TermDefn -> Decl DTyDef :: TypeDefn -> Decl -deriving instance ForallTerm Show UD => Show TypeDecl -deriving instance ForallTerm Show UD => Show TermDefn -deriving instance ForallTerm Show UD => Show Decl +deriving instance ForallTerm Show UD => Show TypeDecl +deriving instance ForallTerm Show UD => Show TermDefn +deriving instance ForallTerm Show UD => Show Decl partitionDecls :: [Decl] -> ([TypeDecl], [TermDefn], [TypeDefn]) -partitionDecls (DType tyDecl : ds) = (_1 %~ (tyDecl:)) (partitionDecls ds) -partitionDecls (DDefn def : ds) = (_2 %~ (def:)) (partitionDecls ds) -partitionDecls (DTyDef def : ds) = (_3 %~ (def:)) (partitionDecls ds) -partitionDecls [] = ([], [], []) +partitionDecls (DType tyDecl : ds) = (_1 %~ (tyDecl :)) (partitionDecls ds) +partitionDecls (DDefn def : ds) = (_2 %~ (def :)) (partitionDecls ds) +partitionDecls (DTyDef def : ds) = (_3 %~ (def :)) (partitionDecls ds) +partitionDecls [] = ([], [], []) ------------------------------------------------------------ -- Pretty-printing top-level declarations @@ -207,10 +225,10 @@ partitionDecls [] = ([], [], []) instance Pretty Decl where pretty = \case - DType (TypeDecl x ty) -> pretty x <+> text ":" <+> pretty ty + DType (TypeDecl x ty) -> pretty x <+> text ":" <+> pretty ty DTyDef (TypeDefn x args body) -> text "type" <+> text x <+> hsep (map text args) <+> text "=" <+> pretty body - DDefn (TermDefn x bs) -> vcat $ map (pretty . (x,)) bs + DDefn (TermDefn x bs) -> vcat $ map (pretty . (x,)) bs -- | Pretty-print a single clause in a definition. instance Pretty (Name a, Bind [Pattern] Term) where @@ -230,28 +248,28 @@ type Term = Term_ UD -- (nonempty) list of patterns. Each pattern might contain any -- number of variables, and might have type annotations on some -- of its components. -type instance X_Binder UD = [Pattern] - -type instance X_TVar UD = () -type instance X_TPrim UD = () -type instance X_TLet UD = () -type instance X_TParens UD = () -type instance X_TUnit UD = () -type instance X_TBool UD = () -type instance X_TNat UD = () -type instance X_TRat UD = () -type instance X_TChar UD = () -type instance X_TString UD = () -type instance X_TAbs UD = () -type instance X_TApp UD = () -type instance X_TTup UD = () -type instance X_TCase UD = () -type instance X_TChain UD = () -type instance X_TTyOp UD = () -type instance X_TContainer UD = () -type instance X_TContainerComp UD = () -type instance X_TAscr UD = () -type instance X_Term UD = () -- TWild +type instance X_Binder UD = [Pattern] + +type instance X_TVar UD = () +type instance X_TPrim UD = () +type instance X_TLet UD = () +type instance X_TParens UD = () +type instance X_TUnit UD = () +type instance X_TBool UD = () +type instance X_TNat UD = () +type instance X_TRat UD = () +type instance X_TChar UD = () +type instance X_TString UD = () +type instance X_TAbs UD = () +type instance X_TApp UD = () +type instance X_TTup UD = () +type instance X_TCase UD = () +type instance X_TChain UD = () +type instance X_TTyOp UD = () +type instance X_TContainer UD = () +type instance X_TContainerComp UD = () +type instance X_TAscr UD = () +type instance X_Term UD = () -- TWild pattern TVar :: Name Term -> Term pattern TVar name = TVar_ () name @@ -269,7 +287,7 @@ pattern TLet :: Bind (Telescope Binding) Term -> Term pattern TLet bind = TLet_ () bind pattern TParens :: Term -> Term -pattern TParens term = TParens_ () term +pattern TParens term = TParens_ () term pattern TUnit :: Term pattern TUnit = TUnit_ () @@ -277,7 +295,7 @@ pattern TUnit = TUnit_ () pattern TBool :: Bool -> Term pattern TBool bool = TBool_ () bool -pattern TNat :: Integer -> Term +pattern TNat :: Integer -> Term pattern TNat int = TNat_ () int pattern TRat :: Rational -> Term @@ -292,7 +310,7 @@ pattern TString s = TString_ () s pattern TAbs :: Quantifier -> Bind [Pattern] Term -> Term pattern TAbs q bind = TAbs_ q () bind -pattern TApp :: Term -> Term -> Term +pattern TApp :: Term -> Term -> Term pattern TApp term1 term2 = TApp_ () term1 term2 pattern TTup :: [Term] -> Term @@ -322,9 +340,28 @@ pattern TAscr term ty = TAscr_ () term ty pattern TWild :: Term pattern TWild = XTerm_ () -{-# COMPLETE TVar, TPrim, TLet, TParens, TUnit, TBool, TNat, TRat, TChar, - TString, TAbs, TApp, TTup, TCase, TChain, TTyOp, - TContainer, TContainerComp, TAscr, TWild #-} +{-# COMPLETE + TVar + , TPrim + , TLet + , TParens + , TUnit + , TBool + , TNat + , TRat + , TChar + , TString + , TAbs + , TApp + , TTup + , TCase + , TChain + , TTyOp + , TContainer + , TContainerComp + , TAscr + , TWild + #-} pattern TList :: [Term] -> Maybe (Ellipsis Term) -> Term pattern TList ts e <- TContainer_ () ListContainer (map fst -> ts) e @@ -368,8 +405,8 @@ type Branch = Branch_ UD type Guard = Guard_ UD type instance X_GBool UD = () -type instance X_GPat UD = () -type instance X_GLet UD = () +type instance X_GPat UD = () +type instance X_GLet UD = () pattern GBool :: Embed Term -> Guard pattern GBool embedt = GBool_ () embedt @@ -384,23 +421,23 @@ pattern GLet b = GLet_ () b type Pattern = Pattern_ UD -type instance X_PVar UD = () -type instance X_PWild UD = () -type instance X_PAscr UD = () -type instance X_PUnit UD = () -type instance X_PBool UD = () -type instance X_PTup UD = () -type instance X_PInj UD = () -type instance X_PNat UD = () -type instance X_PChar UD = () +type instance X_PVar UD = () +type instance X_PWild UD = () +type instance X_PAscr UD = () +type instance X_PUnit UD = () +type instance X_PBool UD = () +type instance X_PTup UD = () +type instance X_PInj UD = () +type instance X_PNat UD = () +type instance X_PChar UD = () type instance X_PString UD = () -type instance X_PCons UD = () -type instance X_PList UD = () -type instance X_PAdd UD = () -type instance X_PMul UD = () -type instance X_PSub UD = () -type instance X_PNeg UD = () -type instance X_PFrac UD = () +type instance X_PCons UD = () +type instance X_PList UD = () +type instance X_PAdd UD = () +type instance X_PMul UD = () +type instance X_PSub UD = () +type instance X_PNeg UD = () +type instance X_PFrac UD = () type instance X_Pattern UD = Void pattern PVar :: Name Term -> Pattern @@ -409,8 +446,8 @@ pattern PVar name = PVar_ () name pattern PWild :: Pattern pattern PWild = PWild_ () - -- (?) TAscr uses a PolyType, but without higher rank types - -- I think we can't possibly need that here. +-- (?) TAscr uses a PolyType, but without higher rank types +-- I think we can't possibly need that here. pattern PAscr :: Pattern -> Type -> Pattern pattern PAscr p ty = PAscr_ () p ty @@ -418,7 +455,7 @@ pattern PUnit :: Pattern pattern PUnit = PUnit_ () pattern PBool :: Bool -> Pattern -pattern PBool b = PBool_ () b +pattern PBool b = PBool_ () b pattern PChar :: Char -> Pattern pattern PChar c = PChar_ () c @@ -426,17 +463,17 @@ pattern PChar c = PChar_ () c pattern PString :: String -> Pattern pattern PString s = PString_ () s -pattern PTup :: [Pattern] -> Pattern +pattern PTup :: [Pattern] -> Pattern pattern PTup lp = PTup_ () lp -pattern PInj :: Side -> Pattern -> Pattern +pattern PInj :: Side -> Pattern -> Pattern pattern PInj s p = PInj_ () s p -pattern PNat :: Integer -> Pattern +pattern PNat :: Integer -> Pattern pattern PNat n = PNat_ () n pattern PCons :: Pattern -> Pattern -> Pattern -pattern PCons p1 p2 = PCons_ () p1 p2 +pattern PCons p1 p2 = PCons_ () p1 p2 pattern PList :: [Pattern] -> Pattern pattern PList lp = PList_ () lp @@ -457,11 +494,29 @@ pattern PFrac :: Pattern -> Pattern -> Pattern pattern PFrac p1 p2 = PFrac_ () p1 p2 pattern PNonlinear :: Pattern -> Name Term -> Pattern -pattern PNonlinear p x <- PNonlinear_ (unembed -> p) x where - PNonlinear p x = PNonlinear_ (embed p) x - -{-# COMPLETE PVar, PWild, PAscr, PUnit, PBool, PTup, PInj, PNat, - PChar, PString, PCons, PList, PAdd, PMul, PSub, PNeg, PFrac #-} +pattern PNonlinear p x <- PNonlinear_ (unembed -> p) x + where + PNonlinear p x = PNonlinear_ (embed p) x + +{-# COMPLETE + PVar + , PWild + , PAscr + , PUnit + , PBool + , PTup + , PInj + , PNat + , PChar + , PString + , PCons + , PList + , PAdd + , PMul + , PSub + , PNeg + , PFrac + #-} ------------------------------------------------------------ -- Pretty-printing for surface-syntax terms @@ -472,85 +527,87 @@ pattern PNonlinear p x <- PNonlinear_ (unembed -> p) x where -- | Pretty-print a term with guaranteed parentheses. prettyTermP :: Members '[LFresh, Reader PA] r => Term -> Sem r Doc -prettyTermP t@TTup{} = setPA initPA $ pretty t +prettyTermP t@TTup {} = setPA initPA $ pretty t -- prettyTermP t@TContainer{} = setPA initPA $ "" <+> prettyTerm t -prettyTermP t = withPA initPA $ pretty t +prettyTermP t = withPA initPA $ pretty t instance Pretty Term where pretty = \case - TVar x -> pretty x + TVar x -> pretty x TPrim (PrimUOp uop) -> case M.lookup uop uopMap of - Just (OpInfo (UOpF Pre _) (syn:_) _) -> text syn <> text "~" - Just (OpInfo (UOpF Post _) (syn:_) _) -> text "~" <> text syn + Just (OpInfo (UOpF Pre _) (syn : _) _) -> text syn <> text "~" + Just (OpInfo (UOpF Post _) (syn : _) _) -> text "~" <> text syn _ -> error $ "pretty @Term: " ++ show uop ++ " is not in the uopMap!" TPrim (PrimBOp bop) -> text "~" <> pretty bop <> text "~" TPrim p -> case M.lookup p primMap of - Just (PrimInfo _ nm True) -> text nm + Just (PrimInfo _ nm True) -> text nm Just (PrimInfo _ nm False) -> text "$" <> text nm Nothing -> error $ "pretty @Term: Prim " ++ show p ++ " is not in the primMap!" - TParens t -> pretty t - TUnit -> text "■" - (TBool b) -> text (map toLower $ show b) - TChar c -> text (show c) - TString cs -> doubleQuotes $ text cs - TAbs q bnd -> withPA initPA $ + TParens t -> pretty t + TUnit -> text "■" + (TBool b) -> text (map toLower $ show b) + TChar c -> text (show c) + TString cs -> doubleQuotes $ text cs + TAbs q bnd -> withPA initPA $ lunbind bnd $ \(args, body) -> - prettyQ q - <> (hsep =<< punctuate (text ",") (map pretty args)) - <> text "." - <+> lt (pretty body) - where - prettyQ Lam = text "λ" - prettyQ All = text "∀" - prettyQ Ex = text "∃" + prettyQ q + <> (hsep =<< punctuate (text ",") (map pretty args)) + <> text "." + <+> lt (pretty body) + where + prettyQ Lam = text "λ" + prettyQ All = text "∀" + prettyQ Ex = text "∃" -- special case for fully applied unary operators TApp (TPrim (PrimUOp uop)) t -> case M.lookup uop uopMap of - Just (OpInfo (UOpF Post _) _ _) -> withPA (ugetPA uop) $ - lt (pretty t) <> pretty uop - Just (OpInfo (UOpF Pre _) _ _) -> withPA (ugetPA uop) $ - pretty uop <> rt (pretty t) + Just (OpInfo (UOpF Post _) _ _) -> + withPA (ugetPA uop) $ + lt (pretty t) <> pretty uop + Just (OpInfo (UOpF Pre _) _ _) -> + withPA (ugetPA uop) $ + pretty uop <> rt (pretty t) _ -> error $ "pretty @Term: uopMap doesn't contain " ++ show uop - -- special case for fully applied binary operators - TApp (TPrim (PrimBOp bop)) (TTup [t1, t2]) -> withPA (getPA bop) $ - hsep - [ lt (pretty t1) - , pretty bop - , rt (pretty t2) - ] - + TApp (TPrim (PrimBOp bop)) (TTup [t1, t2]) -> + withPA (getPA bop) $ + hsep + [ lt (pretty t1) + , pretty bop + , rt (pretty t2) + ] -- Always pretty-print function applications with parentheses - TApp t1 t2 -> withPA funPA $ - lt (pretty t1) <> prettyTermP t2 - - TTup ts -> setPA initPA $ do + TApp t1 t2 -> + withPA funPA $ + lt (pretty t1) <> prettyTermP t2 + TTup ts -> setPA initPA $ do ds <- punctuate (text ",") (map pretty ts) parens (hsep ds) - TContainer c ts e -> setPA initPA $ do + TContainer c ts e -> setPA initPA $ do ds <- punctuate (text ",") (map prettyCount ts) let pe = case e of - Nothing -> [] - Just (Until t) -> [text "..", pretty t] + Nothing -> [] + Just (Until t) -> [text "..", pretty t] containerDelims c (hsep (ds ++ pe)) - where - prettyCount (t, Nothing) = pretty t - prettyCount (t, Just n) = lt (pretty t) <+> text "#" <+> rt (pretty n) + where + prettyCount (t, Nothing) = pretty t + prettyCount (t, Just n) = lt (pretty t) <+> text "#" <+> rt (pretty n) TContainerComp c bqst -> - lunbind bqst $ \(qs,t) -> - setPA initPA $ containerDelims c (hsep [pretty t, text "|", pretty qs]) - TNat n -> integer n - TChain t lks -> withPA (getPA Eq) . hsep $ + lunbind bqst $ \(qs, t) -> + setPA initPA $ containerDelims c (hsep [pretty t, text "|", pretty qs]) + TNat n -> integer n + TChain t lks -> + withPA (getPA Eq) . hsep $ lt (pretty t) - : concatMap prettyLink lks - where - prettyLink (TLink op t2) = - [ pretty op - , setPA (getPA op) . rt $ pretty t2 - ] + : concatMap prettyLink lks + where + prettyLink (TLink op t2) = + [ pretty op + , setPA (getPA op) . rt $ pretty t2 + ] TLet bnd -> withPA initPA $ lunbind bnd $ \(bs, t2) -> do ds <- punctuate (text ",") (map pretty (fromTelescope bs)) @@ -560,46 +617,47 @@ instance Pretty Term where , text "in" , pretty t2 ] - - TCase b -> withPA initPA $ - (text "{?" <+> prettyBranches b) $+$ text "?}" - TAscr t ty -> withPA ascrPA $ - lt (pretty t) <+> text ":" <+> rt (pretty ty) - TRat r -> text (prettyDecimal r) - TTyOp op ty -> withPA funPA $ - pretty op <+> pretty ty + TCase b -> + withPA initPA $ + (text "{?" <+> prettyBranches b) $+$ text "?}" + TAscr t ty -> + withPA ascrPA $ + lt (pretty t) <+> text ":" <+> rt (pretty ty) + TRat r -> text (prettyDecimal r) + TTyOp op ty -> + withPA funPA $ + pretty op <+> pretty ty TWild -> text "_" -- | Print appropriate delimiters for a container literal. containerDelims :: Member (Reader PA) r => Container -> (Sem r Doc -> Sem r Doc) containerDelims ListContainer = brackets -containerDelims BagContainer = bag -containerDelims SetContainer = braces +containerDelims BagContainer = bag +containerDelims SetContainer = braces prettyBranches :: Members '[Reader PA, LFresh] r => [Branch] -> Sem r Doc prettyBranches = \case [] -> error "Empty branches are disallowed." - b:bs -> + b : bs -> pretty b - $+$ - foldr (($+$) . (text "," <+>) . pretty) empty bs + $+$ foldr (($+$) . (text "," <+>) . pretty) empty bs -- | Pretty-print a single branch in a case expression. instance Pretty Branch where - pretty br = lunbind br $ \(gs,t) -> + pretty br = lunbind br $ \(gs, t) -> pretty t <+> pretty gs -- | Pretty-print the guards in a single branch of a case expression. instance Pretty (Telescope Guard) where pretty = \case TelEmpty -> text "otherwise" - gs -> foldr (\g r -> pretty g <+> r) (text "") (fromTelescope gs) + gs -> foldr (\g r -> pretty g <+> r) (text "") (fromTelescope gs) instance Pretty Guard where pretty = \case - GBool et -> text "if" <+> pretty (unembed et) + GBool et -> text "if" <+> pretty (unembed et) GPat et p -> text "when" <+> pretty (unembed et) <+> text "is" <+> pretty p - GLet b -> text "let" <+> pretty b + GLet b -> text "let" <+> pretty b -- | Pretty-print a binding, i.e. a pairing of a name (with optional -- type annotation) and term. @@ -620,12 +678,12 @@ instance Pretty (Telescope Qual) where instance Pretty Qual where pretty = \case QBind x (unembed -> t) -> hsep [pretty x, text "in", pretty t] - QGuard (unembed -> t) -> pretty t + QGuard (unembed -> t) -> pretty t -- | Pretty-print a pattern with guaranteed parentheses. prettyPatternP :: Members '[LFresh, Reader PA] r => Pattern -> Sem r Doc -prettyPatternP p@PTup{} = setPA initPA $ pretty p -prettyPatternP p = withPA initPA $ pretty p +prettyPatternP p@PTup {} = setPA initPA $ pretty p +prettyPatternP p = withPA initPA $ pretty p -- We could probably alternatively write a function to turn a pattern -- back into a term, and pretty-print that instead of the below. @@ -633,37 +691,46 @@ prettyPatternP p = withPA initPA $ pretty p instance Pretty Pattern where pretty = \case - PVar x -> pretty x - PWild -> text "_" - PAscr p ty -> withPA ascrPA $ - lt (pretty p) <+> text ":" <+> rt (pretty ty) - PUnit -> text "■" - PBool b -> text $ map toLower $ show b - PChar c -> text (show c) - PString s -> text (show s) - PTup ts -> setPA initPA $ do + PVar x -> pretty x + PWild -> text "_" + PAscr p ty -> + withPA ascrPA $ + lt (pretty p) <+> text ":" <+> rt (pretty ty) + PUnit -> text "■" + PBool b -> text $ map toLower $ show b + PChar c -> text (show c) + PString s -> text (show s) + PTup ts -> setPA initPA $ do ds <- punctuate (text ",") (map pretty ts) parens (hsep ds) - PInj s p -> withPA funPA $ - pretty s <> prettyPatternP p - PNat n -> integer n - PCons p1 p2 -> withPA (getPA Cons) $ - lt (pretty p1) <+> text "::" <+> rt (pretty p2) - PList ps -> setPA initPA $ do + PInj s p -> + withPA funPA $ + pretty s <> prettyPatternP p + PNat n -> integer n + PCons p1 p2 -> + withPA (getPA Cons) $ + lt (pretty p1) <+> text "::" <+> rt (pretty p2) + PList ps -> setPA initPA $ do ds <- punctuate (text ",") (map pretty ps) brackets (hsep ds) - PAdd L p t -> withPA (getPA Add) $ - lt (pretty p) <+> text "+" <+> rt (pretty t) - PAdd R p t -> withPA (getPA Add) $ - lt (pretty t) <+> text "+" <+> rt (pretty p) - PMul L p t -> withPA (getPA Mul) $ - lt (pretty p) <+> text "*" <+> rt (pretty t) - PMul R p t -> withPA (getPA Mul) $ - lt (pretty t) <+> text "*" <+> rt (pretty p) - PSub p t -> withPA (getPA Sub) $ - lt (pretty p) <+> text "-" <+> rt (pretty t) - PNeg p -> withPA (ugetPA Neg) $ - text "-" <> rt (pretty p) - PFrac p1 p2 -> withPA (getPA Div) $ - lt (pretty p1) <+> text "/" <+> rt (pretty p2) - + PAdd L p t -> + withPA (getPA Add) $ + lt (pretty p) <+> text "+" <+> rt (pretty t) + PAdd R p t -> + withPA (getPA Add) $ + lt (pretty t) <+> text "+" <+> rt (pretty p) + PMul L p t -> + withPA (getPA Mul) $ + lt (pretty p) <+> text "*" <+> rt (pretty t) + PMul R p t -> + withPA (getPA Mul) $ + lt (pretty t) <+> text "*" <+> rt (pretty p) + PSub p t -> + withPA (getPA Sub) $ + lt (pretty p) <+> text "-" <+> rt (pretty t) + PNeg p -> + withPA (ugetPA Neg) $ + text "-" <> rt (pretty p) + PFrac p1 p2 -> + withPA (getPA Div) $ + lt (pretty p1) <+> text "/" <+> rt (pretty p2) diff --git a/src/Disco/AST/Typed.hs b/src/Disco/AST/Typed.hs index 988e1e13..964d7768 100644 --- a/src/Disco/AST/Typed.hs +++ b/src/Disco/AST/Typed.hs @@ -1,7 +1,10 @@ {-# LANGUAGE DeriveDataTypeable #-} -{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PatternSynonyms #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.AST.Typed -- Copyright : disco team and contributors @@ -12,100 +15,91 @@ -- Typed abstract syntax trees representing the typechecked surface -- syntax of the Disco language. Each tree node is annotated with the -- type of its subtree. --- ------------------------------------------------------------------------------ - -module Disco.AST.Typed - ( -- * Type-annotated terms - ATerm - , pattern ATVar - , pattern ATPrim - , pattern ATLet - , pattern ATUnit - , pattern ATBool - , pattern ATNat - , pattern ATRat - , pattern ATChar - , pattern ATString - , pattern ATAbs - , pattern ATApp - , pattern ATTup - , pattern ATCase - , pattern ATChain - , pattern ATTyOp - , pattern ATContainer - , pattern ATContainerComp - , pattern ATList - , pattern ATListComp - , pattern ATTest - - , ALink - , pattern ATLink - - , Container(..) - , ABinding - -- * Branches and guards - , ABranch - - , AGuard - , pattern AGBool - , pattern AGPat - , pattern AGLet - - , AQual - , pattern AQBind - , pattern AQGuard - - , APattern - , pattern APVar - , pattern APWild - , pattern APUnit - , pattern APBool - , pattern APTup - , pattern APInj - , pattern APNat - , pattern APChar - , pattern APString - , pattern APCons - , pattern APList - , pattern APAdd - , pattern APMul - , pattern APSub - , pattern APNeg - , pattern APFrac - - , pattern ABinding - -- * Utilities - , varsBound - , getType - , setType - , substQT - - , AProperty - ) - where - -import Unbound.Generics.LocallyNameless -import Unbound.Generics.LocallyNameless.Unsafe - -import Control.Arrow ((***)) -import Data.Coerce (coerce) -import Data.Data (Data) -import Data.Void - -import Control.Lens.Plated (transform) -import Disco.AST.Generic -import Disco.AST.Surface -import Disco.Names -import Disco.Pretty -import Disco.Syntax.Operators -import Disco.Syntax.Prims -import Disco.Types +module Disco.AST.Typed ( + -- * Type-annotated terms + ATerm, + pattern ATVar, + pattern ATPrim, + pattern ATLet, + pattern ATUnit, + pattern ATBool, + pattern ATNat, + pattern ATRat, + pattern ATChar, + pattern ATString, + pattern ATAbs, + pattern ATApp, + pattern ATTup, + pattern ATCase, + pattern ATChain, + pattern ATTyOp, + pattern ATContainer, + pattern ATContainerComp, + pattern ATList, + pattern ATListComp, + pattern ATTest, + ALink, + pattern ATLink, + Container (..), + ABinding, + + -- * Branches and guards + ABranch, + AGuard, + pattern AGBool, + pattern AGPat, + pattern AGLet, + AQual, + pattern AQBind, + pattern AQGuard, + APattern, + pattern APVar, + pattern APWild, + pattern APUnit, + pattern APBool, + pattern APTup, + pattern APInj, + pattern APNat, + pattern APChar, + pattern APString, + pattern APCons, + pattern APList, + pattern APAdd, + pattern APMul, + pattern APSub, + pattern APNeg, + pattern APFrac, + pattern ABinding, + + -- * Utilities + varsBound, + getType, + setType, + substQT, + AProperty, +) +where + +import Unbound.Generics.LocallyNameless +import Unbound.Generics.LocallyNameless.Unsafe + +import Control.Arrow ((***)) +import Data.Coerce (coerce) +import Data.Data (Data) +import Data.Void + +import Control.Lens.Plated (transform) +import Disco.AST.Generic +import Disco.AST.Surface +import Disco.Names +import Disco.Pretty +import Disco.Syntax.Operators +import Disco.Syntax.Prims +import Disco.Types -- | The extension descriptor for Typed specific AST types. - data TY - deriving Data + deriving (Data) type AProperty = Property_ TY @@ -116,34 +110,33 @@ type AProperty = Property_ TY -- | An @ATerm@ is a typechecked term where every node in the tree has -- been annotated with the type of the subterm rooted at that node. - type ATerm = Term_ TY -type instance X_Binder TY = [APattern] - -type instance X_TVar TY = Void -- Names are now qualified -type instance X_TPrim TY = Type -type instance X_TLet TY = Type -type instance X_TUnit TY = () -type instance X_TBool TY = Type -type instance X_TNat TY = Type -type instance X_TRat TY = () -type instance X_TChar TY = () -type instance X_TString TY = () -type instance X_TAbs TY = Type -type instance X_TApp TY = Type -type instance X_TCase TY = Type -type instance X_TChain TY = Type -type instance X_TTyOp TY = Type -type instance X_TContainer TY = Type -type instance X_TContainerComp TY = Type -type instance X_TAscr TY = Void -- No more type ascriptions in typechecked terms -type instance X_TTup TY = Type -type instance X_TParens TY = Void -- No more explicit parens - - -- A test frame for reporting counterexamples in a test. These don't appear - -- in source programs, but because the deugarer manipulates partly-desugared - -- terms it helps to be able to represent these in 'ATerm'. +type instance X_Binder TY = [APattern] + +type instance X_TVar TY = Void -- Names are now qualified +type instance X_TPrim TY = Type +type instance X_TLet TY = Type +type instance X_TUnit TY = () +type instance X_TBool TY = Type +type instance X_TNat TY = Type +type instance X_TRat TY = () +type instance X_TChar TY = () +type instance X_TString TY = () +type instance X_TAbs TY = Type +type instance X_TApp TY = Type +type instance X_TCase TY = Type +type instance X_TChain TY = Type +type instance X_TTyOp TY = Type +type instance X_TContainer TY = Type +type instance X_TContainerComp TY = Type +type instance X_TAscr TY = Void -- No more type ascriptions in typechecked terms +type instance X_TTup TY = Type +type instance X_TParens TY = Void -- No more explicit parens + +-- A test frame for reporting counterexamples in a test. These don't appear +-- in source programs, but because the deugarer manipulates partly-desugared +-- terms it helps to be able to represent these in 'ATerm'. type instance X_Term TY = Either ([(String, Type, Name ATerm)], ATerm) (Type, QName ATerm) pattern ATVar :: Type -> QName ATerm -> ATerm @@ -161,7 +154,7 @@ pattern ATUnit = TUnit_ () pattern ATBool :: Type -> Bool -> ATerm pattern ATBool ty bool = TBool_ ty bool -pattern ATNat :: Type -> Integer -> ATerm +pattern ATNat :: Type -> Integer -> ATerm pattern ATNat ty int = TNat_ ty int pattern ATRat :: Rational -> ATerm @@ -176,7 +169,7 @@ pattern ATString s = TString_ () s pattern ATAbs :: Quantifier -> Type -> Bind [APattern] ATerm -> ATerm pattern ATAbs q ty bind = TAbs_ q ty bind -pattern ATApp :: Type -> ATerm -> ATerm -> ATerm +pattern ATApp :: Type -> ATerm -> ATerm -> ATerm pattern ATApp ty term1 term2 = TApp_ ty term1 term2 pattern ATTup :: Type -> [ATerm] -> ATerm @@ -200,9 +193,26 @@ pattern ATContainerComp ty c b = TContainerComp_ ty c b pattern ATTest :: [(String, Type, Name ATerm)] -> ATerm -> ATerm pattern ATTest ns t = XTerm_ (Left (ns, t)) -{-# COMPLETE ATVar, ATPrim, ATLet, ATUnit, ATBool, ATNat, ATRat, ATChar, - ATString, ATAbs, ATApp, ATTup, ATCase, ATChain, ATTyOp, - ATContainer, ATContainerComp, ATTest #-} +{-# COMPLETE + ATVar + , ATPrim + , ATLet + , ATUnit + , ATBool + , ATNat + , ATRat + , ATChar + , ATString + , ATAbs + , ATApp + , ATTup + , ATCase + , ATChain + , ATTyOp + , ATContainer + , ATContainerComp + , ATTest + #-} pattern ATList :: Type -> [ATerm] -> Maybe (Ellipsis ATerm) -> ATerm pattern ATList t xs e <- ATContainer t ListContainer (map fst -> xs) e @@ -221,13 +231,11 @@ pattern ATLink bop term = TLink_ () bop term {-# COMPLETE ATLink #-} - type AQual = Qual_ TY type instance X_QBind TY = () type instance X_QGuard TY = () - pattern AQBind :: Name ATerm -> Embed ATerm -> AQual pattern AQBind namet embedt = QBind_ () namet embedt @@ -248,8 +256,8 @@ type ABranch = Bind (Telescope AGuard) ATerm type AGuard = Guard_ TY type instance X_GBool TY = () -type instance X_GPat TY = () -type instance X_GLet TY = () -- ??? Type? +type instance X_GPat TY = () +type instance X_GLet TY = () -- ??? Type? pattern AGBool :: Embed ATerm -> AGuard pattern AGBool embedt = GBool_ () embedt @@ -267,25 +275,25 @@ type APattern = Pattern_ TY -- We have to use Embed Type because we don't want any type variables -- inside the types being treated as binders! -type instance X_PVar TY = Embed Type -type instance X_PWild TY = Embed Type -type instance X_PAscr TY = Void -- No more ascriptions in typechecked patterns. -type instance X_PUnit TY = () -type instance X_PBool TY = () -type instance X_PChar TY = () -type instance X_PString TY = () -type instance X_PTup TY = Embed Type -type instance X_PInj TY = Embed Type -type instance X_PNat TY = Embed Type -type instance X_PCons TY = Embed Type -type instance X_PList TY = Embed Type -type instance X_PAdd TY = Embed Type -type instance X_PMul TY = Embed Type -type instance X_PSub TY = Embed Type -type instance X_PNeg TY = Embed Type -type instance X_PFrac TY = Embed Type - -type instance X_Pattern TY = () +type instance X_PVar TY = Embed Type +type instance X_PWild TY = Embed Type +type instance X_PAscr TY = Void -- No more ascriptions in typechecked patterns. +type instance X_PUnit TY = () +type instance X_PBool TY = () +type instance X_PChar TY = () +type instance X_PString TY = () +type instance X_PTup TY = Embed Type +type instance X_PInj TY = Embed Type +type instance X_PNat TY = Embed Type +type instance X_PCons TY = Embed Type +type instance X_PList TY = Embed Type +type instance X_PAdd TY = Embed Type +type instance X_PMul TY = Embed Type +type instance X_PSub TY = Embed Type +type instance X_PNeg TY = Embed Type +type instance X_PFrac TY = Embed Type + +type instance X_Pattern TY = () pattern APVar :: Type -> Name ATerm -> APattern pattern APVar ty name <- PVar_ (unembed -> ty) name @@ -301,25 +309,25 @@ pattern APUnit :: APattern pattern APUnit = PUnit_ () pattern APBool :: Bool -> APattern -pattern APBool b = PBool_ () b +pattern APBool b = PBool_ () b pattern APChar :: Char -> APattern -pattern APChar c = PChar_ () c +pattern APChar c = PChar_ () c pattern APString :: String -> APattern pattern APString s = PString_ () s -pattern APTup :: Type -> [APattern] -> APattern +pattern APTup :: Type -> [APattern] -> APattern pattern APTup ty lp <- PTup_ (unembed -> ty) lp where APTup ty lp = PTup_ (embed ty) lp -pattern APInj :: Type -> Side -> APattern -> APattern +pattern APInj :: Type -> Side -> APattern -> APattern pattern APInj ty s p <- PInj_ (unembed -> ty) s p where APInj ty s p = PInj_ (embed ty) s p -pattern APNat :: Type -> Integer -> APattern +pattern APNat :: Type -> Integer -> APattern pattern APNat ty n <- PNat_ (unembed -> ty) n where APNat ty n = PNat_ (embed ty) n @@ -359,87 +367,103 @@ pattern APFrac ty p1 p2 <- PFrac_ (unembed -> ty) p1 p2 where APFrac ty p1 p2 = PFrac_ (embed ty) p1 p2 -{-# COMPLETE APVar, APWild, APUnit, APBool, APChar, APString, - APTup, APInj, APNat, APCons, APList, APAdd, APMul, APSub, APNeg, APFrac #-} +{-# COMPLETE + APVar + , APWild + , APUnit + , APBool + , APChar + , APString + , APTup + , APInj + , APNat + , APCons + , APList + , APAdd + , APMul + , APSub + , APNeg + , APFrac + #-} varsBound :: APattern -> [(Name ATerm, Type)] -varsBound (APVar ty n) = [(n, ty)] -varsBound (APWild _) = [] -varsBound APUnit = [] -varsBound (APBool _) = [] -varsBound (APChar _) = [] -varsBound (APString _) = [] -varsBound (APTup _ ps) = varsBound =<< ps -varsBound (APInj _ _ p) = varsBound p -varsBound (APNat _ _) = [] -varsBound (APCons _ p q) = varsBound p ++ varsBound q -varsBound (APList _ ps) = varsBound =<< ps +varsBound (APVar ty n) = [(n, ty)] +varsBound (APWild _) = [] +varsBound APUnit = [] +varsBound (APBool _) = [] +varsBound (APChar _) = [] +varsBound (APString _) = [] +varsBound (APTup _ ps) = varsBound =<< ps +varsBound (APInj _ _ p) = varsBound p +varsBound (APNat _ _) = [] +varsBound (APCons _ p q) = varsBound p ++ varsBound q +varsBound (APList _ ps) = varsBound =<< ps varsBound (APAdd _ _ p _) = varsBound p varsBound (APMul _ _ p _) = varsBound p -varsBound (APSub _ p _) = varsBound p -varsBound (APNeg _ p) = varsBound p -varsBound (APFrac _ p q) = varsBound p ++ varsBound q +varsBound (APSub _ p _) = varsBound p +varsBound (APNeg _ p) = varsBound p +varsBound (APFrac _ p q) = varsBound p ++ varsBound q ------------------------------------------------------------ -- getType ------------------------------------------------------------ instance HasType ATerm where - getType (ATVar ty _) = ty - getType (ATPrim ty _) = ty - getType ATUnit = TyUnit - getType (ATBool ty _) = ty - getType (ATNat ty _) = ty - getType (ATRat _) = TyF - getType (ATChar _) = TyC - getType (ATString _) = TyList TyC - getType (ATAbs _ ty _) = ty - getType (ATApp ty _ _) = ty - getType (ATTup ty _) = ty - getType (ATTyOp ty _ _) = ty - getType (ATChain ty _ _) = ty - getType (ATContainer ty _ _ _) = ty + getType (ATVar ty _) = ty + getType (ATPrim ty _) = ty + getType ATUnit = TyUnit + getType (ATBool ty _) = ty + getType (ATNat ty _) = ty + getType (ATRat _) = TyF + getType (ATChar _) = TyC + getType (ATString _) = TyList TyC + getType (ATAbs _ ty _) = ty + getType (ATApp ty _ _) = ty + getType (ATTup ty _) = ty + getType (ATTyOp ty _ _) = ty + getType (ATChain ty _ _) = ty + getType (ATContainer ty _ _ _) = ty getType (ATContainerComp ty _ _) = ty - getType (ATLet ty _) = ty - getType (ATCase ty _) = ty - getType (ATTest _ _ ) = TyProp - - setType ty (ATVar _ x ) = ATVar ty x - setType ty (ATPrim _ x ) = ATPrim ty x - setType _ ATUnit = ATUnit - setType ty (ATBool _ b) = ATBool ty b - setType ty (ATNat _ x ) = ATNat ty x - setType _ (ATRat r) = ATRat r - setType _ (ATChar c) = ATChar c - setType _ (ATString cs) = ATString cs - setType ty (ATAbs q _ x ) = ATAbs q ty x - setType ty (ATApp _ x y ) = ATApp ty x y - setType ty (ATTup _ x ) = ATTup ty x - setType ty (ATTyOp _ x y ) = ATTyOp ty x y - setType ty (ATChain _ x y ) = ATChain ty x y - setType ty (ATContainer _ x y z) = ATContainer ty x y z + getType (ATLet ty _) = ty + getType (ATCase ty _) = ty + getType (ATTest _ _) = TyProp + + setType ty (ATVar _ x) = ATVar ty x + setType ty (ATPrim _ x) = ATPrim ty x + setType _ ATUnit = ATUnit + setType ty (ATBool _ b) = ATBool ty b + setType ty (ATNat _ x) = ATNat ty x + setType _ (ATRat r) = ATRat r + setType _ (ATChar c) = ATChar c + setType _ (ATString cs) = ATString cs + setType ty (ATAbs q _ x) = ATAbs q ty x + setType ty (ATApp _ x y) = ATApp ty x y + setType ty (ATTup _ x) = ATTup ty x + setType ty (ATTyOp _ x y) = ATTyOp ty x y + setType ty (ATChain _ x y) = ATChain ty x y + setType ty (ATContainer _ x y z) = ATContainer ty x y z setType ty (ATContainerComp _ x y) = ATContainerComp ty x y - setType ty (ATLet _ x ) = ATLet ty x - setType ty (ATCase _ x ) = ATCase ty x - setType _ (ATTest vs x) = ATTest vs x + setType ty (ATLet _ x) = ATLet ty x + setType ty (ATCase _ x) = ATCase ty x + setType _ (ATTest vs x) = ATTest vs x instance HasType APattern where - getType (APVar ty _) = ty - getType (APWild ty) = ty - getType APUnit = TyUnit - getType (APBool _) = TyBool - getType (APChar _) = TyC - getType (APString _) = TyList TyC - getType (APTup ty _) = ty - getType (APInj ty _ _) = ty - getType (APNat ty _) = ty - getType (APCons ty _ _) = ty - getType (APList ty _) = ty + getType (APVar ty _) = ty + getType (APWild ty) = ty + getType APUnit = TyUnit + getType (APBool _) = TyBool + getType (APChar _) = TyC + getType (APString _) = TyList TyC + getType (APTup ty _) = ty + getType (APInj ty _ _) = ty + getType (APNat ty _) = ty + getType (APCons ty _ _) = ty + getType (APList ty _) = ty getType (APAdd ty _ _ _) = ty getType (APMul ty _ _ _) = ty - getType (APSub ty _ _) = ty - getType (APNeg ty _) = ty - getType (APFrac ty _ _) = ty + getType (APSub ty _ _) = ty + getType (APNeg ty _) = ty + getType (APFrac ty _ _) = ty instance HasType ABranch where getType = getType . snd . unsafeUnbind @@ -451,7 +475,7 @@ instance HasType ABranch where substQT :: QName ATerm -> ATerm -> ATerm -> ATerm substQT x s = transform $ \case t@(ATVar _ y) - | x == y -> s + | x == y -> s | otherwise -> t t -> t @@ -464,32 +488,34 @@ instance Pretty ATerm where explode :: ATerm -> Term explode = \case - ATVar ty x -> TAscr (TVar (coerce (qname x))) (toPolyType ty) - ATPrim ty x -> TAscr (TPrim x) (toPolyType ty) - ATLet ty tel -> TAscr (TLet (explodeTelescope explodeBinding tel)) (toPolyType ty) - ATUnit -> TUnit - ATBool _ty b -> TBool b - ATNat ty x -> TAscr (TNat x) (toPolyType ty) - ATRat r -> TRat r - ATChar c -> TChar c - ATString cs -> TString cs - ATAbs q ty a -> TAscr (TAbs q (explodeAbs a)) (toPolyType ty) - ATApp ty x y -> TAscr (TApp (explode x) (explode y)) (toPolyType ty) - ATTup ty xs -> TAscr (TTup (map explode xs)) (toPolyType ty) - ATCase ty bs -> TAscr (TCase (map explodeBranch bs)) (toPolyType ty) - ATChain ty t ls -> TAscr (TChain (explode t) (map explodeLink ls)) (toPolyType ty) - ATTyOp ty x y -> TAscr (TTyOp x y) (toPolyType ty) + ATVar ty x -> TAscr (TVar (coerce (qname x))) (toPolyType ty) + ATPrim ty x -> TAscr (TPrim x) (toPolyType ty) + ATLet ty tel -> TAscr (TLet (explodeTelescope explodeBinding tel)) (toPolyType ty) + ATUnit -> TUnit + ATBool _ty b -> TBool b + ATNat ty x -> TAscr (TNat x) (toPolyType ty) + ATRat r -> TRat r + ATChar c -> TChar c + ATString cs -> TString cs + ATAbs q ty a -> TAscr (TAbs q (explodeAbs a)) (toPolyType ty) + ATApp ty x y -> TAscr (TApp (explode x) (explode y)) (toPolyType ty) + ATTup ty xs -> TAscr (TTup (map explode xs)) (toPolyType ty) + ATCase ty bs -> TAscr (TCase (map explodeBranch bs)) (toPolyType ty) + ATChain ty t ls -> TAscr (TChain (explode t) (map explodeLink ls)) (toPolyType ty) + ATTyOp ty x y -> TAscr (TTyOp x y) (toPolyType ty) ATContainer ty c ts el -> TAscr (TContainer c (map (explode *** fmap explode) ts) (fmap (fmap explode) el)) (toPolyType ty) ATContainerComp ty c b -> TAscr (TContainerComp c (explodeTelescope explodeQual b)) (toPolyType ty) - ATTest _vs x -> TAscr (explode x) (toPolyType TyProp) + ATTest _vs x -> TAscr (explode x) (toPolyType TyProp) -explodeTelescope - :: (Alpha a, Alpha b) - => (a -> b) -> Bind (Telescope a) ATerm -> Bind (Telescope b) Term -explodeTelescope explodeBinder (unsafeUnbind -> (xs,at)) = bind (mapTelescope explodeBinder xs) (explode at) +explodeTelescope :: + (Alpha a, Alpha b) => + (a -> b) -> + Bind (Telescope a) ATerm -> + Bind (Telescope b) Term +explodeTelescope explodeBinder (unsafeUnbind -> (xs, at)) = bind (mapTelescope explodeBinder xs) (explode at) explodeBinding :: ABinding -> Binding explodeBinding (ABinding m b (unembed -> n)) = Binding m (coerce b) (embed (explode n)) @@ -499,34 +525,34 @@ explodeAbs (unsafeUnbind -> (aps, at)) = bind (map explodePattern aps) (explode explodePattern :: APattern -> Pattern explodePattern = \case - APVar ty x -> PAscr (PVar (coerce x)) ty - APWild ty -> PAscr PWild ty - APUnit -> PUnit - APBool b -> PBool b - APChar c -> PChar c - APString s -> PString s - APTup ty ps -> PAscr (PTup (map explodePattern ps)) ty - APInj ty s p -> PAscr (PInj s (explodePattern p)) ty - APNat ty n -> PAscr (PNat n) ty + APVar ty x -> PAscr (PVar (coerce x)) ty + APWild ty -> PAscr PWild ty + APUnit -> PUnit + APBool b -> PBool b + APChar c -> PChar c + APString s -> PString s + APTup ty ps -> PAscr (PTup (map explodePattern ps)) ty + APInj ty s p -> PAscr (PInj s (explodePattern p)) ty + APNat ty n -> PAscr (PNat n) ty APCons ty p1 p2 -> PAscr (PCons (explodePattern p1) (explodePattern p2)) ty - APList ty ps -> PAscr (PList (map explodePattern ps)) ty - APAdd ty s p t -> PAscr (PAdd s (explodePattern p) (explode t)) ty - APMul ty s p t -> PAscr (PMul s (explodePattern p) (explode t)) ty - APSub ty p t -> PAscr (PSub (explodePattern p) (explode t)) ty - APNeg ty p -> PAscr (PNeg (explodePattern p)) ty - APFrac ty p q -> PAscr (PFrac (explodePattern p) (explodePattern q)) ty + APList ty ps -> PAscr (PList (map explodePattern ps)) ty + APAdd ty s p t -> PAscr (PAdd s (explodePattern p) (explode t)) ty + APMul ty s p t -> PAscr (PMul s (explodePattern p) (explode t)) ty + APSub ty p t -> PAscr (PSub (explodePattern p) (explode t)) ty + APNeg ty p -> PAscr (PNeg (explodePattern p)) ty + APFrac ty p q -> PAscr (PFrac (explodePattern p) (explodePattern q)) ty explodeBranch :: ABranch -> Branch explodeBranch = explodeTelescope explodeGuard explodeGuard :: AGuard -> Guard -explodeGuard (AGBool (unembed -> at)) = GBool (embed (explode at)) +explodeGuard (AGBool (unembed -> at)) = GBool (embed (explode at)) explodeGuard (AGPat (unembed -> at) ap) = GPat (embed (explode at)) (explodePattern ap) -explodeGuard (AGLet ab) = GLet (explodeBinding ab) +explodeGuard (AGLet ab) = GLet (explodeBinding ab) explodeLink :: ALink -> Link explodeLink (ATLink bop at) = TLink bop (explode at) explodeQual :: AQual -> Qual explodeQual (AQBind x (unembed -> at)) = QBind (coerce x) (embed (explode at)) -explodeQual (AQGuard (unembed -> at)) = QGuard (embed (explode at)) +explodeQual (AQGuard (unembed -> at)) = QGuard (embed (explode at)) diff --git a/src/Disco/Compile.hs b/src/Disco/Compile.hs index fcb84765..1bca9398 100644 --- a/src/Disco/Compile.hs +++ b/src/Disco/Compile.hs @@ -1,4 +1,7 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Compile -- Copyright : disco team and contributors @@ -8,37 +11,39 @@ -- -- Compiling the typechecked, desugared AST to the untyped core -- language. ------------------------------------------------------------------------------ - module Disco.Compile where -import Control.Monad ((<=<)) -import Data.Bool (bool) -import Data.Coerce -import qualified Data.Map as M -import Data.Ratio -import Data.Set (Set) -import qualified Data.Set as S -import Data.Set.Lens (setOf) - -import Disco.Effects.Fresh -import Polysemy (Member, Sem, run) -import Unbound.Generics.LocallyNameless (Name, bind, string2Name, - unembed) - -import Disco.AST.Core -import Disco.AST.Desugared -import Disco.AST.Generic -import Disco.AST.Typed -import Disco.Context as Ctx -import Disco.Desugar -import Disco.Module -import Disco.Names -import Disco.Syntax.Operators -import Disco.Syntax.Prims -import qualified Disco.Typecheck.Graph as G -import Disco.Types -import Disco.Util +import Control.Monad ((<=<)) +import Data.Bool (bool) +import Data.Coerce +import qualified Data.Map as M +import Data.Ratio +import Data.Set (Set) +import qualified Data.Set as S +import Data.Set.Lens (setOf) + +import Disco.Effects.Fresh +import Polysemy (Member, Sem, run) +import Unbound.Generics.LocallyNameless ( + Name, + bind, + string2Name, + unembed, + ) + +import Disco.AST.Core +import Disco.AST.Desugared +import Disco.AST.Generic +import Disco.AST.Typed +import Disco.Context as Ctx +import Disco.Desugar +import Disco.Module +import Disco.Names +import Disco.Syntax.Operators +import Disco.Syntax.Prims +import qualified Disco.Typecheck.Graph as G +import Disco.Types +import Disco.Util ------------------------------------------------------------ -- Convenience operations @@ -106,19 +111,18 @@ compileDefnGroup [(f, defn)] -- have -- -- fT = force (delay fL. [force fL / fT] body) - | f `S.member` setOf fvQ defn = return . (:[]) $ - (fT, CForce (CProj L (CDelay (bind [qname fL] [substQC fT (CForce (CVar fL)) cdefn])))) - + | f `S.member` setOf fvQ defn = + return . (: []) $ + (fT, CForce (CProj L (CDelay (bind [qname fL] [substQC fT (CForce (CVar fL)) cdefn])))) -- A non-recursive definition just compiles simply. | otherwise = - return [(fT, cdefn)] + return [(fT, cdefn)] + where + fT, fL :: QName Core + fT = coerce f + fL = localName (coerce (qname f)) - where - fT, fL :: QName Core - fT = coerce f - fL = localName (coerce (qname f)) - - cdefn = compileThing desugarDefn defn + cdefn = compileThing desugarDefn defn -- A group of mutually recursive definitions {f = fbody, g = gbody, ...} -- compiles to @@ -140,12 +144,12 @@ compileDefnGroup defs = do bodies' :: [Core] bodies' = map (substsQC forceVars . compileThing desugarDefn) bodies return $ - (grp, CDelay (bind (map qname varsL) bodies')) : - zip varsT (for [0 ..] $ CForce . flip proj (CVar grp)) - where - proj :: Int -> Core -> Core - proj 0 = CProj L - proj n = proj (n -1) . CProj R + (grp, CDelay (bind (map qname varsL) bodies')) + : zip varsT (for [0 ..] $ CForce . flip proj (CVar grp)) + where + proj :: Int -> Core -> Core + proj 0 = CProj L + proj n = proj (n - 1) . CProj R ------------------------------------------------------------ -- Compiling terms @@ -166,22 +170,22 @@ compileDTerm term@(DTAbs q _ _) = do cbody <- compileDTerm body case q of Lam -> return $ abstract xs cbody - Ex -> return $ quantify (OExists tys) (abstract xs cbody) + Ex -> return $ quantify (OExists tys) (abstract xs cbody) All -> return $ quantify (OForall tys) (abstract xs cbody) - where - -- Gather nested abstractions with the same quantifier. - unbindDeep :: Member Fresh r => DTerm -> Sem r ([Name DTerm], [Type], DTerm) - unbindDeep (DTAbs q' ty l) | q == q' = do - (name, inner) <- unbind l - (ns, tys, body) <- unbindDeep inner - return (name : ns, ty : tys, body) - unbindDeep t = return ([], [], t) + where + -- Gather nested abstractions with the same quantifier. + unbindDeep :: Member Fresh r => DTerm -> Sem r ([Name DTerm], [Type], DTerm) + unbindDeep (DTAbs q' ty l) | q == q' = do + (name, inner) <- unbind l + (ns, tys, body) <- unbindDeep inner + return (name : ns, ty : tys, body) + unbindDeep t = return ([], [], t) - abstract :: [Name DTerm] -> Core -> Core - abstract xs body = CAbs (bind (map coerce xs) body) + abstract :: [Name DTerm] -> Core -> Core + abstract xs body = CAbs (bind (map coerce xs) body) - quantify :: Op -> Core -> Core - quantify op = CApp (CConst op) + quantify :: Op -> Core -> Core + quantify op = CApp (CConst op) -- Special case for Cons, which compiles to a constructor application -- rather than a function application. @@ -197,12 +201,12 @@ compileDTerm (DTPair _ t1 t2) = CPair <$> compileDTerm t1 <*> compileDTerm t2 compileDTerm (DTCase _ bs) = CApp <$> compileCase bs <*> pure CUnit compileDTerm (DTTyOp _ op ty) = return $ CApp (CConst (tyOps ! op)) (CType ty) - where - tyOps = - M.fromList - [ Enumerate ==> OEnum, - Count ==> OCount - ] + where + tyOps = + M.fromList + [ Enumerate ==> OEnum + , Count ==> OCount + ] compileDTerm (DTNil _) = return $ CInj L CUnit compileDTerm (DTTest info t) = CTest (coerce info) <$> compileDTerm t @@ -231,27 +235,27 @@ compilePrim _ (PrimBOp Cons) = do hd <- fresh (string2Name "hd") tl <- fresh (string2Name "tl") return $ CAbs $ bind [hd, tl] $ CInj R (CPair (CVar (localName hd)) (CVar (localName tl))) - compilePrim _ PrimLeft = do a <- fresh (string2Name "a") return $ CAbs $ bind [a] $ CInj L (CVar (localName a)) - compilePrim _ PrimRight = do a <- fresh (string2Name "a") return $ CAbs $ bind [a] $ CInj R (CVar (localName a)) - compilePrim (ty1 :*: ty2 :->: resTy) (PrimBOp bop) = return $ compileBOp ty1 ty2 resTy bop compilePrim ty p@(PrimBOp _) = compilePrimErr p ty compilePrim _ PrimSqrt = return $ CConst OSqrt compilePrim _ PrimFloor = return $ CConst OFloor compilePrim _ PrimCeil = return $ CConst OCeil -compilePrim (TySet _ :->: _) PrimAbs = return $ - CVar (Named Stdlib "container" .- string2Name "setSize") -compilePrim (TyBag _ :->: _) PrimAbs = return $ - CVar (Named Stdlib "container" .- string2Name "bagSize") -compilePrim (TyList _ :->: _) PrimAbs = return $ - CVar (Named Stdlib "list" .- string2Name "length") -compilePrim _ PrimAbs = return $ CConst OAbs +compilePrim (TySet _ :->: _) PrimAbs = + return $ + CVar (Named Stdlib "container" .- string2Name "setSize") +compilePrim (TyBag _ :->: _) PrimAbs = + return $ + CVar (Named Stdlib "container" .- string2Name "bagSize") +compilePrim (TyList _ :->: _) PrimAbs = + return $ + CVar (Named Stdlib "list" .- string2Name "length") +compilePrim _ PrimAbs = return $ CConst OAbs compilePrim (TySet _ :->: _) PrimPower = return $ CConst OPower compilePrim (TyBag _ :->: _) PrimPower = return $ CConst OPower compilePrim ty PrimPower = compilePrimErr PrimPower ty @@ -284,34 +288,38 @@ compilePrim ty PrimOverlay = compilePrimErr PrimOverlay ty compilePrim ty PrimConnect = compilePrimErr PrimConnect ty compilePrim _ PrimInsert = return $ CConst OInsert compilePrim _ PrimLookup = return $ CConst OLookup -compilePrim (_ :*: TyList _ :->: _) PrimEach = return $ - CVar (Named Stdlib "list" .- string2Name "eachlist") +compilePrim (_ :*: TyList _ :->: _) PrimEach = + return $ + CVar (Named Stdlib "list" .- string2Name "eachlist") compilePrim (_ :*: TyBag _ :->: TyBag _) PrimEach = return $ CConst OEachBag compilePrim (_ :*: TySet _ :->: TySet _) PrimEach = return $ CConst OEachSet compilePrim ty PrimEach = compilePrimErr PrimEach ty compilePrim (_ :*: _ :*: TyList _ :->: _) PrimReduce = return $ CVar (Named Stdlib "list" .- string2Name "foldr") -compilePrim (_ :*: _ :*: TyBag _ :->: _) PrimReduce = return $ - CVar (Named Stdlib "container" .- string2Name "reducebag") -compilePrim (_ :*: _ :*: TySet _ :->: _) PrimReduce = return $ - CVar (Named Stdlib "container" .- string2Name "reduceset") +compilePrim (_ :*: _ :*: TyBag _ :->: _) PrimReduce = + return $ + CVar (Named Stdlib "container" .- string2Name "reducebag") +compilePrim (_ :*: _ :*: TySet _ :->: _) PrimReduce = + return $ + CVar (Named Stdlib "container" .- string2Name "reduceset") compilePrim ty PrimReduce = compilePrimErr PrimReduce ty -compilePrim (_ :*: TyList _ :->: _) PrimFilter = return $ - CVar (Named Stdlib "list" .- string2Name "filterlist") +compilePrim (_ :*: TyList _ :->: _) PrimFilter = + return $ + CVar (Named Stdlib "list" .- string2Name "filterlist") compilePrim (_ :*: TyBag _ :->: _) PrimFilter = return $ CConst OFilterBag compilePrim (_ :*: TySet _ :->: _) PrimFilter = return $ CConst OFilterBag compilePrim ty PrimFilter = compilePrimErr PrimFilter ty -compilePrim (_ :->: TyList _) PrimJoin = return $ - CVar (Named Stdlib "list" .- string2Name "concat") +compilePrim (_ :->: TyList _) PrimJoin = + return $ + CVar (Named Stdlib "list" .- string2Name "concat") compilePrim (_ :->: TyBag _) PrimJoin = return $ CConst OBagUnions -compilePrim (_ :->: TySet _) PrimJoin = return $ - CVar (Named Stdlib "container" .- string2Name "unions") +compilePrim (_ :->: TySet _) PrimJoin = + return $ + CVar (Named Stdlib "container" .- string2Name "unions") compilePrim ty PrimJoin = compilePrimErr PrimJoin ty - compilePrim (_ :*: TyBag _ :*: _ :->: _) PrimMerge = return $ CConst OMerge compilePrim (_ :*: TySet _ :*: _ :->: _) PrimMerge = return $ CConst OMerge -compilePrim ty PrimMerge = compilePrimErr PrimMerge ty - +compilePrim ty PrimMerge = compilePrimErr PrimMerge ty compilePrim _ PrimIsPrime = return $ CConst OIsPrime compilePrim _ PrimFactor = return $ CConst OFactor compilePrim _ PrimFrac = return $ CConst OFrac @@ -418,14 +426,14 @@ compileUOp :: UOp -> Core compileUOp _ op = CConst (coreUOps ! op) - where - -- Just look up the corresponding core operator. - coreUOps = - M.fromList - [ Neg ==> ONeg, - Fact ==> OFact, - Not ==> ONotProp - ] + where + -- Just look up the corresponding core operator. + coreUOps = + M.fromList + [ Neg ==> ONeg + , Fact ==> OFact + , Not ==> ONotProp + ] -- | Compile a binary operator. This function needs to know the types -- of the arguments and result since some operators are overloaded @@ -463,13 +471,13 @@ compileBOp :: Type -> Type -> Type -> BOp -> Core -- addition and multiplication. compileBOp (TyGraph _) (TyGraph _) (TyGraph _) op | op `elem` [Add, Mul] = - CConst (regularOps ! op) - where - regularOps = - M.fromList - [ Add ==> OOverlay, - Mul ==> OConnect - ] + CConst (regularOps ! op) + where + regularOps = + M.fromList + [ Add ==> OOverlay + , Mul ==> OConnect + ] -- The Cartesian product operator just compiles to library function calls. compileBOp (TySet _) _ _ CartProd = @@ -478,32 +486,31 @@ compileBOp (TyBag _) _ _ CartProd = CVar (Named Stdlib "container" .- string2Name "bagCP") compileBOp (TyList _) _ _ CartProd = CVar (Named Stdlib "list" .- string2Name "listCP") - -- Some regular arithmetic operations that just translate straightforwardly. compileBOp _ _ _ op | op `M.member` regularOps = CConst (regularOps ! op) - where - regularOps = - M.fromList - [ Add ==> OAdd, - Mul ==> OMul, - Div ==> ODiv, - Exp ==> OExp, - Mod ==> OMod, - Divides ==> ODivides, - Choose ==> OMultinom, - Eq ==> OEq, - Lt ==> OLt, - And ==> OAnd, - Or ==> OOr, - Impl ==> OImpl - ] + where + regularOps = + M.fromList + [ Add ==> OAdd + , Mul ==> OMul + , Div ==> ODiv + , Exp ==> OExp + , Mod ==> OMod + , Divides ==> ODivides + , Choose ==> OMultinom + , Eq ==> OEq + , Lt ==> OLt + , And ==> OAnd + , Or ==> OOr + , Impl ==> OImpl + ] -- ShouldEq needs to know the type at which the comparison is -- occurring, so values can be correctly pretty-printed if the test -- fails. compileBOp ty _ _ ShouldEq = CConst (OShouldEq ty) -compileBOp ty _ _ ShouldLt = CConst (OShouldLt ty) +compileBOp ty _ _ ShouldLt = CConst (OShouldLt ty) compileBOp _ty (TyList _) _ Elem = CConst OListElem compileBOp _ty _ _ Elem = CConst OBagElem compileBOp ty1 ty2 resTy op = diff --git a/src/Disco/Context.hs b/src/Disco/Context.hs index 1bbfd412..16efc4f6 100644 --- a/src/Disco/Context.hs +++ b/src/Disco/Context.hs @@ -1,6 +1,11 @@ {-# LANGUAGE DeriveTraversable #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + +-- SPDX-License-Identifier: BSD-3-Clause + -- | -- Module : Disco.Context -- Copyright : disco team and contributors @@ -9,81 +14,80 @@ -- A *context* is a mapping from names to other things (such as types -- or values). This module defines a generic type of contexts which -- is used in many different places throughout the disco codebase. --- ------------------------------------------------------------------------------ - --- SPDX-License-Identifier: BSD-3-Clause - -module Disco.Context - ( -- * Context type - Ctx - - -- * Construction - , emptyCtx - , singleCtx - , fromList - , ctxForModule - , localCtx - - -- * Insertion - , insert - , extend - , extends - - -- * Query - , null - , lookup, lookup' - , lookupNonLocal, lookupNonLocal' - , lookupAll, lookupAll' - - -- * Conversion - , names - , elems - , assocs - , keysSet - - -- * Traversal - , coerceKeys - , restrictKeys - - -- * Combination - , joinCtx - , joinCtxs - - -- * Filter - , filter - - ) where - -import Control.Monad ((<=<)) -import Data.Bifunctor (first, second) -import Data.Coerce -import Data.Map (Map) -import qualified Data.Map as M -import Data.Map.Merge.Lazy as MM -import Data.Set (Set) -import qualified Data.Set as S -import Prelude hiding (filter, lookup, null) - -import Unbound.Generics.LocallyNameless (Name) - -import Polysemy -import Polysemy.Reader - -import Disco.Names (ModuleName, - NameProvenance (..), - QName (..)) +module Disco.Context ( + -- * Context type + Ctx, + + -- * Construction + emptyCtx, + singleCtx, + fromList, + ctxForModule, + localCtx, + + -- * Insertion + insert, + extend, + extends, + + -- * Query + null, + lookup, + lookup', + lookupNonLocal, + lookupNonLocal', + lookupAll, + lookupAll', + + -- * Conversion + names, + elems, + assocs, + keysSet, + + -- * Traversal + coerceKeys, + restrictKeys, + + -- * Combination + joinCtx, + joinCtxs, + + -- * Filter + filter, +) where + +import Control.Monad ((<=<)) +import Data.Bifunctor (first, second) +import Data.Coerce +import Data.Map (Map) +import qualified Data.Map as M +import Data.Map.Merge.Lazy as MM +import Data.Set (Set) +import qualified Data.Set as S +import Prelude hiding (filter, lookup, null) + +import Unbound.Generics.LocallyNameless (Name) + +import Polysemy +import Polysemy.Reader + +import Disco.Names ( + ModuleName, + NameProvenance (..), + QName (..), + ) -- | A context maps qualified names to things. In particular a @Ctx a -- b@ maps qualified names for @a@s to values of type @b@. -newtype Ctx a b = Ctx { getCtx :: M.Map NameProvenance (M.Map (Name a) b) } +newtype Ctx a b = Ctx {getCtx :: M.Map NameProvenance (M.Map (Name a) b)} deriving (Eq, Show, Functor, Foldable, Traversable) - -- Note that we implement a context as a nested map from - -- NameProvenance to Name to b, rather than as a Map QName b. They - -- are isomorphic, but this way it is easier to do name resolution, - -- because given an (unqualified) Name, we can look it up in each - -- inner map corresponding to modules that are in scope. +-- Note that we implement a context as a nested map from +-- NameProvenance to Name to b, rather than as a Map QName b. They +-- are isomorphic, but this way it is easier to do name resolution, +-- because given an (unqualified) Name, we can look it up in each +-- inner map corresponding to modules that are in scope. instance Semigroup (Ctx a b) where (<>) = joinCtx @@ -159,8 +163,8 @@ lookupNonLocal n = lookupNonLocal' n <$> ask -- | Look up all the non-local bindings of a name in a context. lookupNonLocal' :: Name a -> Ctx a b -> [(ModuleName, b)] lookupNonLocal' n = nonLocal . lookupAll' n - where - nonLocal bs = [(m,b) | (QName (QualifiedName m) _, b) <- bs] + where + nonLocal bs = [(m, b) | (QName (QualifiedName m) _, b) <- bs] -- | Look up all the bindings of an (unqualified) name in an ambient context. lookupAll :: Member (Reader (Ctx a b)) r => Name a -> Sem r [(QName a, b)] @@ -186,9 +190,9 @@ elems = concatMap M.elems . M.elems . getCtx -- context. assocs :: Ctx a b -> [(QName a, b)] assocs = concatMap (uncurry modAssocs) . M.assocs . getCtx - where - modAssocs :: NameProvenance -> Map (Name a) b -> [(QName a, b)] - modAssocs p = map (first (QName p)) . M.assocs + where + modAssocs :: NameProvenance -> Map (Name a) b -> [(QName a, b)] + modAssocs p = map (first (QName p)) . M.assocs -- | Return a set of all qualified names in the context. keysSet :: Ctx a b -> Set (QName a) @@ -205,9 +209,9 @@ coerceKeys = Ctx . M.map (M.mapKeys coerce) . getCtx -- | Restrict a context to only the keys in the given set. restrictKeys :: Ctx a b -> Set (QName a) -> Ctx a b restrictKeys ctx xs = Ctx . restrict m . getCtx $ ctx - where - restrict = MM.merge MM.dropMissing MM.dropMissing (MM.zipWithMatched (\_ ns m' -> M.restrictKeys m' ns)) - m = M.fromListWith S.union . map (\(QName p n) -> (p, S.singleton n)) . S.toList $ xs + where + restrict = MM.merge MM.dropMissing MM.dropMissing (MM.zipWithMatched (\_ ns m' -> M.restrictKeys m' ns)) + m = M.fromListWith S.union . map (\(QName p n) -> (p, S.singleton n)) . S.toList $ xs ------------------------------------------------------------ -- Combination @@ -217,7 +221,7 @@ restrictKeys ctx xs = Ctx . restrict m . getCtx $ ctx -- exists in both contexts, the result will use the value from the -- first context, and throw away the value from the second.). joinCtx :: Ctx a b -> Ctx a b -> Ctx a b -joinCtx a b = joinCtxs [a,b] +joinCtx a b = joinCtxs [a, b] -- | Join a list of contexts (left-biased). joinCtxs :: [Ctx a b] -> Ctx a b diff --git a/src/Disco/Data.hs b/src/Disco/Data.hs index c27e12e1..a879f7c8 100644 --- a/src/Disco/Data.hs +++ b/src/Disco/Data.hs @@ -1,25 +1,25 @@ -{-# OPTIONS_GHC -Wno-orphans #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE StandaloneDeriving #-} +{-# OPTIONS_GHC -Wno-orphans #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Data -- Copyright : disco team and contributors -- Maintainer : byorgey@gmail.com -- -- Some orphan 'Data' instances. --- ------------------------------------------------------------------------------ - module Disco.Data where -import Unbound.Generics.LocallyNameless.Bind -import Unbound.Generics.LocallyNameless.Embed -import Unbound.Generics.LocallyNameless.Name +import Unbound.Generics.LocallyNameless.Bind +import Unbound.Generics.LocallyNameless.Embed +import Unbound.Generics.LocallyNameless.Name -import Data.Data (Data) -import Unbound.Generics.LocallyNameless.Rebind +import Data.Data (Data) +import Unbound.Generics.LocallyNameless.Rebind ------------------------------------------------------------ -- Some orphan instances @@ -29,4 +29,3 @@ deriving instance (Data a, Data b) => Data (Bind a b) deriving instance Data t => Data (Embed t) deriving instance (Data a, Data b) => Data (Rebind a b) deriving instance Data a => Data (Name a) - diff --git a/src/Disco/Desugar.hs b/src/Disco/Desugar.hs index b7d174c9..5653d6b1 100644 --- a/src/Disco/Desugar.hs +++ b/src/Disco/Desugar.hs @@ -1,4 +1,7 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Desugar -- Copyright : disco team and contributors @@ -8,43 +11,49 @@ -- -- Desugaring the typechecked surface language to a (still typed) -- simpler language. --- ------------------------------------------------------------------------------ - -module Disco.Desugar - ( -- * Running desugaring computations - runDesugar - - -- * Programs, terms, and properties - , desugarDefn, desugarTerm, desugarProperty - - -- * Case expressions and patterns - , desugarBranch, desugarGuards - ) - where - -import Control.Monad.Cont -import Data.Bool (bool) -import Data.Coerce -import Data.Maybe (fromMaybe, isJust) - -import Disco.AST.Desugared -import Disco.AST.Surface -import Disco.AST.Typed -import Disco.Module -import Disco.Names -import Disco.Syntax.Operators -import Disco.Syntax.Prims -import Disco.Typecheck (containerTy) -import Disco.Types - -import Disco.Effects.Fresh -import Polysemy (Member, Sem, run) -import Unbound.Generics.LocallyNameless (Bind, Name, bind, - embed, name2String, - string2Name, unembed, - unrebind) -import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind) +module Disco.Desugar ( + -- * Running desugaring computations + runDesugar, + + -- * Programs, terms, and properties + desugarDefn, + desugarTerm, + desugarProperty, + + -- * Case expressions and patterns + desugarBranch, + desugarGuards, +) +where + +import Control.Monad.Cont +import Data.Bool (bool) +import Data.Coerce +import Data.Maybe (fromMaybe, isJust) + +import Disco.AST.Desugared +import Disco.AST.Surface +import Disco.AST.Typed +import Disco.Module +import Disco.Names +import Disco.Syntax.Operators +import Disco.Syntax.Prims +import Disco.Typecheck (containerTy) +import Disco.Types + +import Disco.Effects.Fresh +import Polysemy (Member, Sem, run) +import Unbound.Generics.LocallyNameless ( + Bind, + Name, + bind, + embed, + name2String, + string2Name, + unembed, + unrebind, + ) +import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind) ------------------------------------------------------------ -- Running desugaring computations @@ -53,10 +62,11 @@ import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind) -- | Run a desugaring computation. runDesugar :: Sem '[Fresh] a -> a runDesugar = run . runFresh1 - -- Using runFresh1 is a bit of a hack; that way we won't - -- ever pick a name with #0 (which is what is generated by default - -- by string2Name), hence won't conflict with any existing free - -- variables which came from the parser. + +-- Using runFresh1 is a bit of a hack; that way we won't +-- ever pick a name with #0 (which is what is generated by default +-- by string2Name), hence won't conflict with any existing free +-- variables which came from the parser. ------------------------------------------------------------ -- ATerm DSL @@ -71,20 +81,20 @@ atVar ty x = ATVar ty (QName LocalName x) tapp :: ATerm -> ATerm -> ATerm tapp t1 t2 = ATApp resTy t1 t2 - where - resTy = case getType t1 of - (_ :->: r) -> r - ty -> error $ "Impossible! Got non-function type " ++ show ty ++ " in tapp" + where + resTy = case getType t1 of + (_ :->: r) -> r + ty -> error $ "Impossible! Got non-function type " ++ show ty ++ " in tapp" mkBin :: Type -> BOp -> ATerm -> ATerm -> ATerm -mkBin resTy bop t1 t2 - = tapp (ATPrim (getType t1 :*: getType t2 :->: resTy) (PrimBOp bop)) (mkPair t1 t2) +mkBin resTy bop t1 t2 = + tapp (ATPrim (getType t1 :*: getType t2 :->: resTy) (PrimBOp bop)) (mkPair t1 t2) mkUn :: Type -> UOp -> ATerm -> ATerm mkUn resTy uop t = tapp (ATPrim (getType t :->: resTy) (PrimUOp uop)) t mkPair :: ATerm -> ATerm -> ATerm -mkPair t1 t2 = mkTup [t1,t2] +mkPair t1 t2 = mkTup [t1, t2] mkTup :: [ATerm] -> ATerm mkTup ts = ATTup (foldr1 (:*:) (map getType ts)) ts @@ -151,14 +161,14 @@ dtVar ty x = DTVar ty (QName LocalName x) dtapp :: DTerm -> DTerm -> DTerm dtapp t1 t2 = DTApp resTy t1 t2 - where - resTy = case getType t1 of - (_ :->: r) -> r - ty -> error $ "Impossible! Got non-function type " ++ show ty ++ " in dtapp" + where + resTy = case getType t1 of + (_ :->: r) -> r + ty -> error $ "Impossible! Got non-function type " ++ show ty ++ " in dtapp" dtbin :: Type -> Prim -> DTerm -> DTerm -> DTerm -dtbin resTy p dt1 dt2 - = dtapp (DTPrim (getType dt1 :*: getType dt2 :->: resTy) p) (mkDTPair dt1 dt2) +dtbin resTy p dt1 dt2 = + dtapp (DTPrim (getType dt1 :*: getType dt2 :->: resTy) p) (mkDTPair dt1 dt2) mkDTPair :: DTerm -> DTerm -> DTerm mkDTPair dt1 dt2 = DTPair (getType dt1 :*: getType dt2) dt1 dt2 @@ -181,15 +191,14 @@ desugarDefn (Defn _ patTys bodyTy def) = -- with their corresponding patterns. Definitions are abstractions -- (which happen to be named), and source-level lambdas are also -- abstractions (which happen to have only one clause). - desugarAbs :: Member Fresh r => Quantifier -> Type -> [Clause] -> Sem r DTerm -- Special case for compiling a single lambda with no pattern matching directly to a lambda desugarAbs Lam ty [cl@(unsafeUnbind -> ([APVar _ _], _))] = do (ps, at) <- unbind cl d <- desugarTerm at return $ DTAbs Lam ty (bind (getVar (head ps)) d) - where - getVar (APVar _ x) = coerce x + where + getVar (APVar _ x) = coerce x -- General case desugarAbs quant overallTy body = do clausePairs <- unbindClauses body @@ -205,36 +214,37 @@ desugarAbs quant overallTy body = do let branches = zipWith (mkBranch (zip args patTys)) bodies pats dcase <- desugarTerm $ ATCase bodyTy branches return $ mkAbs quant overallTy patTys (coerce args) dcase - - where - mkBranch :: [(Name ATerm, Type)] -> ATerm -> [APattern] -> ABranch - mkBranch xs b ps = bind (mkGuards xs ps) b - - mkGuards :: [(Name ATerm, Type)] -> [APattern] -> Telescope AGuard - mkGuards xs ps = toTelescope $ zipWith AGPat (map (\(x,ty) -> embed (atVar ty x)) xs) ps - - -- To make searches fairer, we lift up directly nested abstractions - -- with the same quantifier when there's only a single clause. That - -- way, we generate a chain of abstractions followed by a case, instead - -- of a bunch of alternating abstractions and cases. - unbindClauses :: Member Fresh r => [Clause] -> Sem r [([APattern], ATerm)] - unbindClauses [c] | quant `elem` [All, Ex] = do - (ps, t) <- liftClause c - return [(ps, addDbgInfo ps t)] - unbindClauses cs = mapM unbind cs - - liftClause :: Member Fresh r => Bind [APattern] ATerm -> Sem r ([APattern], ATerm) - liftClause c = unbind c >>= \case + where + mkBranch :: [(Name ATerm, Type)] -> ATerm -> [APattern] -> ABranch + mkBranch xs b ps = bind (mkGuards xs ps) b + + mkGuards :: [(Name ATerm, Type)] -> [APattern] -> Telescope AGuard + mkGuards xs ps = toTelescope $ zipWith AGPat (map (\(x, ty) -> embed (atVar ty x)) xs) ps + + -- To make searches fairer, we lift up directly nested abstractions + -- with the same quantifier when there's only a single clause. That + -- way, we generate a chain of abstractions followed by a case, instead + -- of a bunch of alternating abstractions and cases. + unbindClauses :: Member Fresh r => [Clause] -> Sem r [([APattern], ATerm)] + unbindClauses [c] | quant `elem` [All, Ex] = do + (ps, t) <- liftClause c + return [(ps, addDbgInfo ps t)] + unbindClauses cs = mapM unbind cs + + liftClause :: Member Fresh r => Bind [APattern] ATerm -> Sem r ([APattern], ATerm) + liftClause c = + unbind c >>= \case (ps, ATAbs q _ c') | q == quant -> do (ps', b) <- liftClause c' return (ps ++ ps', b) (ps, b) -> return (ps, b) - -- Wrap a term in a test frame to report the values of all variables - -- bound in the patterns. - addDbgInfo :: [APattern] -> ATerm -> ATerm - addDbgInfo ps t = ATTest (map withName $ concatMap varsBound ps) t - where withName (n, ty) = (name2String n, ty, n) + -- Wrap a term in a test frame to report the values of all variables + -- bound in the patterns. + addDbgInfo :: [APattern] -> ATerm -> ATerm + addDbgInfo ps t = ATTest (map withName $ concatMap varsBound ps) t + where + withName (n, ty) = (name2String n, ty, n) ------------------------------------------------------------ -- Term desugaring @@ -245,11 +255,14 @@ desugarAbs quant overallTy body = do desugarCList2B :: Member Fresh r => Prim -> Type -> Type -> Type -> Sem r DTerm desugarCList2B p ty cts b = do c <- fresh (string2Name "c") - body <- desugarTerm $ - tapp (ATPrim (TyBag cts :->: TyBag b) p) - (tapp (ATPrim (TyList cts :->: TyBag cts) PrimBag) - (atVar (TyList cts) c) - ) + body <- + desugarTerm $ + tapp + (ATPrim (TyBag cts :->: TyBag b) p) + ( tapp + (ATPrim (TyList cts :->: TyBag cts) PrimBag) + (atVar (TyList cts) c) + ) return $ mkLambda ty [c] body -- | Desugar a typechecked term. @@ -261,43 +274,33 @@ desugarTerm (ATPrim (ty1 :*: ty2 :->: resTy) (PrimBOp bop)) | bopDesugars ty1 ty2 resTy bop = desugarPrimBOp ty1 ty2 resTy bop desugarTerm (ATPrim ty@(TyList cts :->: TyBag b) PrimC2B) = desugarCList2B PrimC2B ty cts b desugarTerm (ATPrim ty@(TyList cts :->: TyBag b) PrimUC2B) = desugarCList2B PrimUC2B ty cts b - -desugarTerm (ATPrim ty x) = return $ DTPrim ty x -desugarTerm ATUnit = return DTUnit -desugarTerm (ATBool ty b) = return $ DTBool ty b -desugarTerm (ATChar c) = return $ DTChar c -desugarTerm (ATString cs) = +desugarTerm (ATPrim ty x) = return $ DTPrim ty x +desugarTerm ATUnit = return DTUnit +desugarTerm (ATBool ty b) = return $ DTBool ty b +desugarTerm (ATChar c) = return $ DTChar c +desugarTerm (ATString cs) = desugarContainer (TyList TyC) ListContainer (map (\c -> (ATChar c, Nothing)) cs) Nothing -desugarTerm (ATAbs q ty lam) = desugarAbs q ty [lam] - +desugarTerm (ATAbs q ty lam) = desugarAbs q ty [lam] -- Special cases for fully applied operators desugarTerm (ATApp resTy (ATPrim _ (PrimUOp uop)) t) | uopDesugars (getType t) resTy uop = desugarUnApp resTy uop t -desugarTerm (ATApp resTy (ATPrim _ (PrimBOp bop)) (ATTup _ [t1,t2])) +desugarTerm (ATApp resTy (ATPrim _ (PrimBOp bop)) (ATTup _ [t1, t2])) | bopDesugars (getType t1) (getType t2) resTy bop = desugarBinApp resTy bop t1 t2 - -desugarTerm (ATApp ty t1 t2) = +desugarTerm (ATApp ty t1 t2) = DTApp ty <$> desugarTerm t1 <*> desugarTerm t2 -desugarTerm (ATTup ty ts) = desugarTuples ty ts -desugarTerm (ATNat ty n) = return $ DTNat ty n -desugarTerm (ATRat r) = return $ DTRat r - -desugarTerm (ATTyOp ty op t) = return $ DTTyOp ty op t - -desugarTerm (ATChain _ t1 links) = desugarTerm $ expandChain t1 links - +desugarTerm (ATTup ty ts) = desugarTuples ty ts +desugarTerm (ATNat ty n) = return $ DTNat ty n +desugarTerm (ATRat r) = return $ DTRat r +desugarTerm (ATTyOp ty op t) = return $ DTTyOp ty op t +desugarTerm (ATChain _ t1 links) = desugarTerm $ expandChain t1 links desugarTerm (ATContainer ty c es mell) = desugarContainer ty c es mell - desugarTerm (ATContainerComp _ ctr bqt) = do (qs, t) <- unbind bqt desugarComp ctr t qs - desugarTerm (ATLet _ t) = do (bs, t2) <- unbind t desugarLet (fromTelescope bs) t2 - desugarTerm (ATCase ty bs) = DTCase ty <$> mapM desugarBranch bs - desugarTerm (ATTest info t) = DTTest (coerce info) <$> desugarTerm t -- | Desugar a property by wrapping its corresponding term in a test @@ -314,7 +317,7 @@ desugarProperty p = DTTest [] <$> desugarTerm p uopDesugars :: Type -> Type -> UOp -> Bool -- uopDesugars _ (TyFin _) Neg = True uopDesugars TyProp TyProp Not = False -uopDesugars _ _ uop = uop == Not +uopDesugars _ _ uop = uop == Not desugarPrimUOp :: Member Fresh r => Type -> Type -> UOp -> Sem r DTerm desugarPrimUOp argTy resTy op = do @@ -325,20 +328,33 @@ desugarPrimUOp argTy resTy op = do -- | Test whether a given binary operator is one that needs to be -- desugared, given the two types of the arguments and the type of the result. bopDesugars :: Type -> Type -> Type -> BOp -> Bool -bopDesugars _ TyN _ Choose = True +bopDesugars _ TyN _ Choose = True -- bopDesugars _ _ (TyFin _) bop | bop `elem` [Add, Mul] = True -- And, Or, Impl for Props don't desugar because they are primitive -- Prop constructors. On the other hand, logical operations on Bool -- can desugar in terms of more primitive conditional expressions. bopDesugars _ _ TyProp bop | bop `elem` [And, Or, Impl] = False -bopDesugars _ _ _ bop = bop `elem` - [ And, Or, Impl, Iff - , Neq, Gt, Leq, Geq, Min, Max - , IDiv - , Sub, SSub - , Inter, Diff, Union, Subset - ] +bopDesugars _ _ _ bop = + bop + `elem` [ And + , Or + , Impl + , Iff + , Neq + , Gt + , Leq + , Geq + , Min + , Max + , IDiv + , Sub + , SSub + , Inter + , Diff + , Union + , Subset + ] -- | Desugar a primitive binary operator at the given type. desugarPrimBOp :: Member Fresh r => Type -> Type -> Type -> BOp -> Sem r DTerm @@ -348,17 +364,18 @@ desugarPrimBOp ty1 ty2 resTy op = do y <- fresh (string2Name "arg2") let argsTy = ty1 :*: ty2 body <- desugarBinApp resTy op (atVar ty1 x) (atVar ty2 y) - return $ mkLambda (argsTy :->: resTy) [p] $ - DTCase resTy - [ bind - (toTelescope [DGPat (embed (dtVar argsTy (coerce p))) (DPPair argsTy (coerce x) (coerce y))]) - body - ] + return $ + mkLambda (argsTy :->: resTy) [p] $ + DTCase + resTy + [ bind + (toTelescope [DGPat (embed (dtVar argsTy (coerce p))) (DPPair argsTy (coerce x) (coerce y))]) + body + ] -- | Desugar a saturated application of a unary operator. -- The first argument is the type of the result. desugarUnApp :: Member Fresh r => Type -> UOp -> ATerm -> Sem r DTerm - -- Desugar negation on TyFin to a negation on TyZ followed by a mod. -- See the comments below re: Add and Mul on TyFin. -- desugarUnApp (TyFin n) Neg t = @@ -366,83 +383,84 @@ desugarUnApp :: Member Fresh r => Type -> UOp -> ATerm -> Sem r DTerm -- XXX This should be turned into a standard library definition. -- not t ==> {? false if t, true otherwise ?} -desugarUnApp _ Not t = desugarTerm $ - ATCase TyBool - [ fls <==. [AGBool (embed t)] - , tru <==. [] - ] - +desugarUnApp _ Not t = + desugarTerm $ + ATCase + TyBool + [ fls <==. [AGBool (embed t)] + , tru <==. [] + ] desugarUnApp ty uop t = error $ "Impossible! desugarUnApp " ++ show ty ++ " " ++ show uop ++ " " ++ show t -- | Desugar a saturated application of a binary operator. -- The first argument is the type of the result. desugarBinApp :: Member Fresh r => Type -> BOp -> ATerm -> ATerm -> Sem r DTerm - -- Implies, and, or should all be turned into a standard library -- definition. This will require first (1) adding support for -- modules/a standard library, including (2) the ability to define -- infix operators. -- t1 and t2 ==> {? t2 if t1, false otherwise ?} -desugarBinApp _ And t1 t2 = desugarTerm $ - ATCase TyBool - [ t2 <==. [tif t1] - , fls <==. [] - ] - +desugarBinApp _ And t1 t2 = + desugarTerm $ + ATCase + TyBool + [ t2 <==. [tif t1] + , fls <==. [] + ] -- (t1 implies t2) ==> (not t1 or t2) desugarBinApp _ Impl t1 t2 = desugarTerm $ tnot t1 ||. t2 - -- (t1 iff t2) ==> (t1 == t2) desugarBinApp _ Iff t1 t2 = desugarTerm $ t1 ==. t2 - -- t1 or t2 ==> {? true if t1, t2 otherwise ?}) -desugarBinApp _ Or t1 t2 = desugarTerm $ - ATCase TyBool - [ tru <==. [tif t1] - , t2 <==. [] - ] - +desugarBinApp _ Or t1 t2 = + desugarTerm $ + ATCase + TyBool + [ tru <==. [tif t1] + , t2 <==. [] + ] desugarBinApp _ Neq t1 t2 = desugarTerm $ tnot (t1 ==. t2) -desugarBinApp _ Gt t1 t2 = desugarTerm $ t2 <. t1 +desugarBinApp _ Gt t1 t2 = desugarTerm $ t2 <. t1 desugarBinApp _ Leq t1 t2 = desugarTerm $ tnot (t2 <. t1) desugarBinApp _ Geq t1 t2 = desugarTerm $ tnot (t1 <. t2) - -- XXX sharing! -desugarBinApp ty Min t1 t2 = desugarTerm $ - ATCase ty - [ t1 <==. [tif (t1 <. t2)] - , t2 <==. [] - ] - -desugarBinApp ty Max t1 t2 = desugarTerm $ - ATCase ty - [ t1 <==. [tif (t2 <. t1)] - , t2 <==. [] - ] - +desugarBinApp ty Min t1 t2 = + desugarTerm $ + ATCase + ty + [ t1 <==. [tif (t1 <. t2)] + , t2 <==. [] + ] +desugarBinApp ty Max t1 t2 = + desugarTerm $ + ATCase + ty + [ t1 <==. [tif (t2 <. t1)] + , t2 <==. [] + ] -- t1 // t2 ==> floor (t1 / t2) -desugarBinApp resTy IDiv t1 t2 = desugarTerm $ - ATApp resTy (ATPrim (getType t1 :->: resTy) PrimFloor) (mkBin (getType t1) Div t1 t2) - +desugarBinApp resTy IDiv t1 t2 = + desugarTerm $ + ATApp resTy (ATPrim (getType t1 :->: resTy) PrimFloor) (mkBin (getType t1) Div t1 t2) -- Desugar normal binomial coefficient (n choose k) to a multinomial -- coefficient with a singleton list, (n choose [k]). -- Note this will only be called when (getType t2 == TyN); see bopDesugars. -desugarBinApp _ Choose t1 t2 - = desugarTerm $ mkBin TyN Choose t1 (ctrSingleton ListContainer t2) - -desugarBinApp ty Sub t1 t2 = desugarTerm $ mkBin ty Add t1 (mkUn ty Neg t2) -desugarBinApp ty SSub t1 t2 = desugarTerm $ - -- t1 -. t2 ==> {? 0 if t1 < t2, t1 - t2 otherwise ?} - ATCase ty - [ ATNat ty 0 <==. [tif (t1 <. t2)] - , mkBin ty Sub t1 t2 <==. [] +desugarBinApp _ Choose t1 t2 = + desugarTerm $ mkBin TyN Choose t1 (ctrSingleton ListContainer t2) +desugarBinApp ty Sub t1 t2 = desugarTerm $ mkBin ty Add t1 (mkUn ty Neg t2) +desugarBinApp ty SSub t1 t2 = + desugarTerm $ + -- t1 -. t2 ==> {? 0 if t1 < t2, t1 - t2 otherwise ?} + ATCase + ty + [ ATNat ty 0 <==. [tif (t1 <. t2)] + , mkBin ty Sub t1 t2 <==. [] -- NOTE, the above is slightly bogus since the whole point of SSub is -- because we can't subtract naturals. However, this will -- immediately desugar to a DTerm. When we write a linting -- typechecker for DTerms we should allow subtraction on TyN! - ] - + ] -- Addition and multiplication on TyFin just desugar to the operation -- followed by a call to mod. -- desugarBinApp (TyFin n) op t1 t2 @@ -450,57 +468,60 @@ desugarBinApp ty SSub t1 t2 = desugarTerm $ -- mkBin (TyFin n) Mod -- (mkBin TyN op t1 t2) -- (ATNat TyN n) - -- Note the typing of this is a bit funny: t1 and t2 presumably - -- have type (TyFin n), and now we are saying that applying 'op' - -- to them results in TyN, then applying 'mod' results in a TyFin - -- n again. Using TyN as the intermediate result is necessary so - -- we don't fall into an infinite desugaring loop, and intuitively - -- makes sense because the idea is that we first do the operation - -- as a normal operation in "natural land" and then do a mod. - -- - -- We will have to think carefully about how the linting - -- typechecker for DTerms should treat TyN and TyFin. Probably - -- something like this will work: TyFin is a subtype of TyN, and - -- TyN can be turned into TyFin with mod. (We don't want such - -- typing rules in the surface disco language itself because - -- implicit coercions from TyFin -> N don't commute with - -- operations like addition and multiplication, e.g. 3+3 yields 1 - -- if we add them in Z5 and then coerce to Nat, but 6 if we first - -- coerce both and then add. +-- Note the typing of this is a bit funny: t1 and t2 presumably +-- have type (TyFin n), and now we are saying that applying 'op' +-- to them results in TyN, then applying 'mod' results in a TyFin +-- n again. Using TyN as the intermediate result is necessary so +-- we don't fall into an infinite desugaring loop, and intuitively +-- makes sense because the idea is that we first do the operation +-- as a normal operation in "natural land" and then do a mod. +-- +-- We will have to think carefully about how the linting +-- typechecker for DTerms should treat TyN and TyFin. Probably +-- something like this will work: TyFin is a subtype of TyN, and +-- TyN can be turned into TyFin with mod. (We don't want such +-- typing rules in the surface disco language itself because +-- implicit coercions from TyFin -> N don't commute with +-- operations like addition and multiplication, e.g. 3+3 yields 1 +-- if we add them in Z5 and then coerce to Nat, but 6 if we first +-- coerce both and then add. -- Intersection, difference, and union all desugar to an application -- of 'merge' with an appropriate combining operation. desugarBinApp ty op t1 t2 | op `elem` [Inter, Diff, Union] = - desugarTerm $ - tapps (ATPrim ((TyN :*: TyN :->: TyN) :*: ty :*: ty :->: ty) PrimMerge) - [ ATPrim (TyN :*: TyN :->: TyN) (mergeOp ty op) - , t1 - , t2 - ] - where - mergeOp _ Inter = PrimBOp Min - mergeOp _ Diff = PrimBOp SSub - mergeOp (TySet _) Union = PrimBOp Max - mergeOp (TyBag _) Union = PrimBOp Add - mergeOp _ _ = error $ "Impossible! mergeOp " ++ show ty ++ " " ++ show op + desugarTerm $ + tapps + (ATPrim ((TyN :*: TyN :->: TyN) :*: ty :*: ty :->: ty) PrimMerge) + [ ATPrim (TyN :*: TyN :->: TyN) (mergeOp ty op) + , t1 + , t2 + ] + where + mergeOp _ Inter = PrimBOp Min + mergeOp _ Diff = PrimBOp SSub + mergeOp (TySet _) Union = PrimBOp Max + mergeOp (TyBag _) Union = PrimBOp Add + mergeOp _ _ = error $ "Impossible! mergeOp " ++ show ty ++ " " ++ show op -- A ⊆ B <==> (A ⊔ B = B) -- where ⊔ denotes 'merge max'. -- Note it is NOT union, since this doesn't work for bags. -- e.g. bag [1] union bag [1,2] = bag [1,1,2] /= bag [1,2]. -desugarBinApp _ Subset t1 t2 = desugarTerm $ - tapps (ATPrim (ty :*: ty :->: TyBool) (PrimBOp Eq)) - [ tapps (ATPrim ((TyN :*: TyN :->: TyN) :*: ty :*: ty :->: ty) PrimMerge) - [ ATPrim (TyN :*: TyN :->: TyN) (PrimBOp Max) - , t1 - , t2 - ] - , t2 -- XXX sharing - ] - where - ty = getType t1 - +desugarBinApp _ Subset t1 t2 = + desugarTerm $ + tapps + (ATPrim (ty :*: ty :->: TyBool) (PrimBOp Eq)) + [ tapps + (ATPrim ((TyN :*: TyN :->: TyN) :*: ty :*: ty :->: ty) PrimMerge) + [ ATPrim (TyN :*: TyN :->: TyN) (PrimBOp Max) + , t1 + , t2 + ] + , t2 -- XXX sharing + ] + where + ty = getType t1 desugarBinApp ty bop t1 t2 = error $ "Impossible! desugarBinApp " ++ show ty ++ " " ++ show bop ++ " " ++ show t1 ++ " " ++ show t2 ------------------------------------------------------------ @@ -514,36 +535,37 @@ desugarComp ctr t qs = expandComp ctr t qs >>= desugarTerm -- | Expand a container comprehension into an equivalent ATerm. expandComp :: Member Fresh r => Container -> ATerm -> Telescope AQual -> Sem r ATerm - -- [ t | ] = [ t ] expandComp ctr t TelEmpty = return $ ctrSingleton ctr t - -- [ t | q, qs ] = ... -expandComp ctr t (TelCons (unrebind -> (q,qs))) - = case q of - -- [ t | x in l, qs ] = join (map (\x -> [t | qs]) l) - AQBind x (unembed -> lst) -> do - tqs <- expandComp ctr t qs - let c = containerTy ctr - tTy = getType t - xTy = case getType lst of - TyContainer _ e -> e - _ -> error "Impossible! Not a container in expandComp" - joinTy = c (c tTy) :->: c tTy - mapTy = (xTy :->: c tTy) :*: c xTy :->: c (c tTy) - return $ tapp (ATPrim joinTy PrimJoin) $ +expandComp ctr t (TelCons (unrebind -> (q, qs))) = + case q of + -- [ t | x in l, qs ] = join (map (\x -> [t | qs]) l) + AQBind x (unembed -> lst) -> do + tqs <- expandComp ctr t qs + let c = containerTy ctr + tTy = getType t + xTy = case getType lst of + TyContainer _ e -> e + _ -> error "Impossible! Not a container in expandComp" + joinTy = c (c tTy) :->: c tTy + mapTy = (xTy :->: c tTy) :*: c xTy :->: c (c tTy) + return $ + tapp (ATPrim joinTy PrimJoin) $ tapp (ATPrim mapTy PrimEach) - (mkPair - (ATAbs Lam (xTy :->: c tTy) (bind [APVar xTy x] tqs)) - lst + ( mkPair + (ATAbs Lam (xTy :->: c tTy) (bind [APVar xTy x] tqs)) + lst ) - -- [ t | g, qs ] = if g then [ t | qs ] else [] - AQGuard (unembed -> g) -> do - tqs <- expandComp ctr t qs - return $ ATCase (containerTy ctr (getType t)) - [ tqs <==. [tif g] + -- [ t | g, qs ] = if g then [ t | qs ] else [] + AQGuard (unembed -> g) -> do + tqs <- expandComp ctr t qs + return $ + ATCase + (containerTy ctr (getType t)) + [ tqs <==. [tif g] , ctrNil ctr (getType t) <==. [] ] @@ -559,8 +581,8 @@ desugarLet :: Member Fresh r => [ABinding] -> ATerm -> Sem r DTerm desugarLet [] t = desugarTerm t desugarLet ((ABinding _ x (unembed -> t1)) : ls) t = dtapp - <$> (DTAbs Lam (getType t1 :->: getType t) - <$> (bind (coerce x) <$> desugarLet ls t) + <$> ( DTAbs Lam (getType t1 :->: getType t) + <$> (bind (coerce x) <$> desugarLet ls t) ) <*> desugarTerm t1 @@ -575,26 +597,26 @@ desugarLet ((ABinding _ x (unembed -> t1)) : ls) t = -- @\x. \y. \z. q@ mkLambda :: Type -> [Name ATerm] -> DTerm -> DTerm mkLambda funty args c = go funty args - where - go _ [] = c - go ty@(_ :->: ty2) (x:xs) = DTAbs Lam ty (bind (coerce x) (go ty2 xs)) - go ty as = error $ "Impossible! mkLambda.go " ++ show ty ++ " " ++ show as + where + go _ [] = c + go ty@(_ :->: ty2) (x : xs) = DTAbs Lam ty (bind (coerce x) (go ty2 xs)) + go ty as = error $ "Impossible! mkLambda.go " ++ show ty ++ " " ++ show as mkQuant :: Quantifier -> [Type] -> [Name ATerm] -> DTerm -> DTerm mkQuant q argtys args c = foldr quantify c (zip args argtys) where - quantify (x, ty) body = DTAbs q ty (bind (coerce x) body) + quantify (x, ty) body = DTAbs q ty (bind (coerce x) body) mkAbs :: Quantifier -> Type -> [Type] -> [Name ATerm] -> DTerm -> DTerm mkAbs Lam funty _ args c = mkLambda funty args c -mkAbs q _ argtys args c = mkQuant q argtys args c +mkAbs q _ argtys args c = mkQuant q argtys args c -- | Desugar a tuple to nested pairs, /e.g./ @(a,b,c,d) ==> (a,(b,(c,d)))@.a desugarTuples :: Member Fresh r => Type -> [ATerm] -> Sem r DTerm -desugarTuples _ [t] = desugarTerm t -desugarTuples ty@(_ :*: ty2) (t:ts) = DTPair ty <$> desugarTerm t <*> desugarTuples ty2 ts -desugarTuples ty ats - = error $ "Impossible! desugarTuples " ++ show ty ++ " " ++ show ats +desugarTuples _ [t] = desugarTerm t +desugarTuples ty@(_ :*: ty2) (t : ts) = DTPair ty <$> desugarTerm t <*> desugarTuples ty2 ts +desugarTuples ty ats = + error $ "Impossible! desugarTuples " ++ show ty ++ " " ++ show ats -- | Expand a chain of comparisons into a sequence of binary -- comparisons combined with @and@. Note we only expand it into @@ -606,7 +628,9 @@ expandChain :: ATerm -> [ALink] -> ATerm expandChain _ [] = error "Can't happen! expandChain _ []" expandChain t1 [ATLink op t2] = mkBin TyBool op t1 t2 expandChain t1 (ATLink op t2 : links) = - mkBin TyBool And + mkBin + TyBool + And (mkBin TyBool op t1 t2) (expandChain t2 links) @@ -615,7 +639,7 @@ desugarBranch :: Member Fresh r => ABranch -> Sem r DBranch desugarBranch b = do (ags, at) <- unbind b dgs <- desugarGuards ags - d <- desugarTerm at + d <- desugarTerm at return $ bind dgs d -- | Desugar the list of guards in one branch of a case expression. @@ -623,215 +647,208 @@ desugarBranch b = do -- turned into pattern guards which match against @true@. desugarGuards :: Member Fresh r => Telescope AGuard -> Sem r (Telescope DGuard) desugarGuards = fmap (toTelescope . concat) . mapM desugarGuard . fromTelescope - where - desugarGuard :: Member Fresh r => AGuard -> Sem r [DGuard] - - -- A Boolean guard is desugared to a pattern-match on @true = right(unit)@. - desugarGuard (AGBool (unembed -> at)) = do - dt <- desugarTerm at - desugarMatch dt (APInj TyBool R APUnit) - - -- 'let x = t' is desugared to 'when t is x'. - desugarGuard (AGLet (ABinding _ x (unembed -> at))) = do - dt <- desugarTerm at - varMatch dt (coerce x) - - -- Desugaring 'when t is p' is the most complex case; we have to - -- break down the pattern and match it incrementally. - desugarGuard (AGPat (unembed -> at) p) = do - dt <- desugarTerm at - desugarMatch dt p - - -- Desugar a guard of the form 'when dt is p'. An entire match is - -- the right unit to desugar --- as opposed to, say, writing a - -- function to desugar a pattern --- since a match may desugar to - -- multiple matches, and on recursive calls we need to know what - -- term/variable should be bound to the pattern. - -- - -- A match may desugar to multiple matches for two reasons: - -- - -- 1. Nested patterns 'explode' into a 'telescope' matching one - -- constructor at a time, for example, 'when t is (x,y,3)' - -- becomes 'when t is (x,x0) when x0 is (y,x1) when x1 is 3'. - -- This makes the order of matching explicit and enables lazy - -- matching without requiring special support from the - -- interpreter other than WHNF reduction. - -- - -- 2. Matches against arithmetic patterns desugar to a - -- combination of matching, computation, and boolean checks. - -- For example, 'when t is (y+1)' becomes 'when t is x0 if x0 >= - -- 1 let y = x0-1'. - desugarMatch :: Member Fresh r => DTerm -> APattern -> Sem r [DGuard] - desugarMatch dt (APVar ty x) = mkMatch dt (DPVar ty (coerce x)) - desugarMatch _ (APWild _) = return [] - desugarMatch dt APUnit = mkMatch dt DPUnit - desugarMatch dt (APBool b) = desugarMatch dt (APInj TyBool (bool L R b) APUnit) - desugarMatch dt (APNat ty n) = desugarMatch (dtbin TyBool (PrimBOp Eq) dt (DTNat ty n)) (APBool True) - desugarMatch dt (APChar c) = desugarMatch (dtbin TyBool (PrimBOp Eq) dt (DTChar c)) (APBool True) - desugarMatch dt (APString s) = desugarMatch dt (APList (TyList TyC) (map APChar s)) - desugarMatch dt (APTup tupTy pat) = desugarTuplePats tupTy dt pat - where - desugarTuplePats :: Member Fresh r => Type -> DTerm -> [APattern] -> Sem r [DGuard] - desugarTuplePats _ _ [] = error "Impossible! desugarTuplePats []" - desugarTuplePats _ t [p] = desugarMatch t p - desugarTuplePats ty@(_ :*: ty2) t (p:ps) = do - (x1,gs1) <- varForPat p - (x2,gs2) <- case ps of - [APVar _ px2] -> return (coerce px2, []) - _ -> do - x <- fresh (string2Name "x") - (x,) <$> desugarTuplePats ty2 (dtVar ty2 x) ps - fmap concat . sequence $ - [ mkMatch t $ DPPair ty x1 x2 - , return gs1 - , return gs2 - ] - desugarTuplePats ty _ _ - = error $ "Impossible! desugarTuplePats with non-pair type " ++ show ty - - desugarMatch dt (APInj ty s p) = do - (x,gs) <- varForPat p - fmap concat . sequence $ - [ mkMatch dt $ DPInj ty s x - , return gs - ] - - desugarMatch dt (APCons ty p1 p2) = do - y <- fresh (string2Name "y") - (x1, gs1) <- varForPat p1 - (x2, gs2) <- varForPat p2 - - let eltTy = getType p1 - unrolledTy = eltTy :*: ty + where + desugarGuard :: Member Fresh r => AGuard -> Sem r [DGuard] + + -- A Boolean guard is desugared to a pattern-match on @true = right(unit)@. + desugarGuard (AGBool (unembed -> at)) = do + dt <- desugarTerm at + desugarMatch dt (APInj TyBool R APUnit) + + -- 'let x = t' is desugared to 'when t is x'. + desugarGuard (AGLet (ABinding _ x (unembed -> at))) = do + dt <- desugarTerm at + varMatch dt (coerce x) + + -- Desugaring 'when t is p' is the most complex case; we have to + -- break down the pattern and match it incrementally. + desugarGuard (AGPat (unembed -> at) p) = do + dt <- desugarTerm at + desugarMatch dt p + + -- Desugar a guard of the form 'when dt is p'. An entire match is + -- the right unit to desugar --- as opposed to, say, writing a + -- function to desugar a pattern --- since a match may desugar to + -- multiple matches, and on recursive calls we need to know what + -- term/variable should be bound to the pattern. + -- + -- A match may desugar to multiple matches for two reasons: + -- + -- 1. Nested patterns 'explode' into a 'telescope' matching one + -- constructor at a time, for example, 'when t is (x,y,3)' + -- becomes 'when t is (x,x0) when x0 is (y,x1) when x1 is 3'. + -- This makes the order of matching explicit and enables lazy + -- matching without requiring special support from the + -- interpreter other than WHNF reduction. + -- + -- 2. Matches against arithmetic patterns desugar to a + -- combination of matching, computation, and boolean checks. + -- For example, 'when t is (y+1)' becomes 'when t is x0 if x0 >= + -- 1 let y = x0-1'. + desugarMatch :: Member Fresh r => DTerm -> APattern -> Sem r [DGuard] + desugarMatch dt (APVar ty x) = mkMatch dt (DPVar ty (coerce x)) + desugarMatch _ (APWild _) = return [] + desugarMatch dt APUnit = mkMatch dt DPUnit + desugarMatch dt (APBool b) = desugarMatch dt (APInj TyBool (bool L R b) APUnit) + desugarMatch dt (APNat ty n) = desugarMatch (dtbin TyBool (PrimBOp Eq) dt (DTNat ty n)) (APBool True) + desugarMatch dt (APChar c) = desugarMatch (dtbin TyBool (PrimBOp Eq) dt (DTChar c)) (APBool True) + desugarMatch dt (APString s) = desugarMatch dt (APList (TyList TyC) (map APChar s)) + desugarMatch dt (APTup tupTy pat) = desugarTuplePats tupTy dt pat + where + desugarTuplePats :: Member Fresh r => Type -> DTerm -> [APattern] -> Sem r [DGuard] + desugarTuplePats _ _ [] = error "Impossible! desugarTuplePats []" + desugarTuplePats _ t [p] = desugarMatch t p + desugarTuplePats ty@(_ :*: ty2) t (p : ps) = do + (x1, gs1) <- varForPat p + (x2, gs2) <- case ps of + [APVar _ px2] -> return (coerce px2, []) + _ -> do + x <- fresh (string2Name "x") + (x,) <$> desugarTuplePats ty2 (dtVar ty2 x) ps fmap concat . sequence $ - [ mkMatch dt (DPInj ty R y) - , mkMatch (dtVar unrolledTy y) (DPPair unrolledTy x1 x2) + [ mkMatch t $ DPPair ty x1 x2 , return gs1 , return gs2 ] - - desugarMatch dt (APList ty []) = desugarMatch dt (APInj ty L APUnit) - desugarMatch dt (APList ty ps) = - desugarMatch dt $ foldr (APCons ty) (APList ty []) ps - - -- when dt is (p + t) ==> when dt is x0; let v = t; [if x0 >= v]; when x0-v is p - desugarMatch dt (APAdd ty _ p t) = arithBinMatch posRestrict (-.) dt ty p t - where - posRestrict plusty - | plusty `elem` [TyN, TyF] = Just (>=.) - | otherwise = Nothing - - -- when dt is (p * t) ==> when dt is x0; let v = t; [if v divides x0]; when x0 / v is p - desugarMatch dt (APMul ty _ p t) = arithBinMatch intRestrict (/.) dt ty p t - where - intRestrict plusty - | plusty `elem` [TyN, TyZ] = Just (flip (|.)) - | otherwise = Nothing - - -- when dt is (p - t) ==> when dt is x0; let v = t; when x0 + v is p - desugarMatch dt (APSub ty p t) = arithBinMatch (const Nothing) (+.) dt ty p t - - -- when dt is (p/q) ==> when $frac(dt) is (p, q) - desugarMatch dt (APFrac _ p q) - = desugarMatch - (dtapp (DTPrim (TyQ :->: TyZ :*: TyN) PrimFrac) dt) - (APTup (TyZ :*: TyN) [p, q]) - - -- when dt is (-p) ==> when dt is x0; if x0 < 0; when -x0 is p - desugarMatch dt (APNeg ty p) = do - - -- when dt is x0 - (x0, g1) <- varFor dt - - -- if x0 < 0 - g2 <- desugarGuard $ AGBool (embed (atVar ty (coerce x0) <. ATNat ty 0)) - - -- when -x0 is p - neg <- desugarTerm $ mkUn ty Neg (atVar ty (coerce x0)) - g3 <- desugarMatch neg p - - return (g1 ++ g2 ++ g3) - - mkMatch :: Member Fresh r => DTerm -> DPattern -> Sem r [DGuard] - mkMatch dt dp = return [DGPat (embed dt) dp] - - varMatch :: Member Fresh r => DTerm -> Name DTerm -> Sem r [DGuard] - varMatch dt x = mkMatch dt (DPVar (getType dt) x) - - varFor :: Member Fresh r => DTerm -> Sem r (Name DTerm, [DGuard]) - varFor (DTVar _ (QName _ x)) = return (x, []) -- XXX return a name + provenance?? - varFor dt = do - x <- fresh (string2Name "x") - g <- varMatch dt x - return (x, g) - - varForPat :: Member Fresh r => APattern -> Sem r (Name DTerm, [DGuard]) - varForPat (APVar _ x) = return (coerce x, []) - varForPat p = do - x <- fresh (string2Name "px") -- changing this from x fixed a bug and I don't know why =( - (x,) <$> desugarMatch (dtVar (getType p) x) p - - arithBinMatch - :: Member Fresh r - => (Type -> Maybe (ATerm -> ATerm -> ATerm)) - -> (ATerm -> ATerm -> ATerm) - -> DTerm -> Type -> APattern -> ATerm -> Sem r [DGuard] - arithBinMatch restrict inverse dt ty p t = do - (x0, g1) <- varFor dt - - -- let v = t - t' <- desugarTerm t - (v, g2) <- varFor t' - - g3 <- case restrict ty of - Nothing -> return [] - - -- if x0 `cmp` v - Just cmp -> - desugarGuard $ - AGBool (embed (atVar ty (coerce x0) `cmp` atVar (getType t) (coerce v))) - - -- when x0 `inverse` v is p - inv <- desugarTerm (atVar ty (coerce x0) `inverse` atVar (getType t) (coerce v)) - g4 <- desugarMatch inv p - - return (g1 ++ g2 ++ g3 ++ g4) + desugarTuplePats ty _ _ = + error $ "Impossible! desugarTuplePats with non-pair type " ++ show ty + desugarMatch dt (APInj ty s p) = do + (x, gs) <- varForPat p + fmap concat . sequence $ + [ mkMatch dt $ DPInj ty s x + , return gs + ] + desugarMatch dt (APCons ty p1 p2) = do + y <- fresh (string2Name "y") + (x1, gs1) <- varForPat p1 + (x2, gs2) <- varForPat p2 + + let eltTy = getType p1 + unrolledTy = eltTy :*: ty + fmap concat . sequence $ + [ mkMatch dt (DPInj ty R y) + , mkMatch (dtVar unrolledTy y) (DPPair unrolledTy x1 x2) + , return gs1 + , return gs2 + ] + desugarMatch dt (APList ty []) = desugarMatch dt (APInj ty L APUnit) + desugarMatch dt (APList ty ps) = + desugarMatch dt $ foldr (APCons ty) (APList ty []) ps + -- when dt is (p + t) ==> when dt is x0; let v = t; [if x0 >= v]; when x0-v is p + desugarMatch dt (APAdd ty _ p t) = arithBinMatch posRestrict (-.) dt ty p t + where + posRestrict plusty + | plusty `elem` [TyN, TyF] = Just (>=.) + | otherwise = Nothing + + -- when dt is (p * t) ==> when dt is x0; let v = t; [if v divides x0]; when x0 / v is p + desugarMatch dt (APMul ty _ p t) = arithBinMatch intRestrict (/.) dt ty p t + where + intRestrict plusty + | plusty `elem` [TyN, TyZ] = Just (flip (|.)) + | otherwise = Nothing + + -- when dt is (p - t) ==> when dt is x0; let v = t; when x0 + v is p + desugarMatch dt (APSub ty p t) = arithBinMatch (const Nothing) (+.) dt ty p t + -- when dt is (p/q) ==> when $frac(dt) is (p, q) + desugarMatch dt (APFrac _ p q) = + desugarMatch + (dtapp (DTPrim (TyQ :->: TyZ :*: TyN) PrimFrac) dt) + (APTup (TyZ :*: TyN) [p, q]) + -- when dt is (-p) ==> when dt is x0; if x0 < 0; when -x0 is p + desugarMatch dt (APNeg ty p) = do + -- when dt is x0 + (x0, g1) <- varFor dt + + -- if x0 < 0 + g2 <- desugarGuard $ AGBool (embed (atVar ty (coerce x0) <. ATNat ty 0)) + + -- when -x0 is p + neg <- desugarTerm $ mkUn ty Neg (atVar ty (coerce x0)) + g3 <- desugarMatch neg p + + return (g1 ++ g2 ++ g3) + + mkMatch :: Member Fresh r => DTerm -> DPattern -> Sem r [DGuard] + mkMatch dt dp = return [DGPat (embed dt) dp] + + varMatch :: Member Fresh r => DTerm -> Name DTerm -> Sem r [DGuard] + varMatch dt x = mkMatch dt (DPVar (getType dt) x) + + varFor :: Member Fresh r => DTerm -> Sem r (Name DTerm, [DGuard]) + varFor (DTVar _ (QName _ x)) = return (x, []) -- XXX return a name + provenance?? + varFor dt = do + x <- fresh (string2Name "x") + g <- varMatch dt x + return (x, g) + + varForPat :: Member Fresh r => APattern -> Sem r (Name DTerm, [DGuard]) + varForPat (APVar _ x) = return (coerce x, []) + varForPat p = do + x <- fresh (string2Name "px") -- changing this from x fixed a bug and I don't know why =( + (x,) <$> desugarMatch (dtVar (getType p) x) p + + arithBinMatch :: + Member Fresh r => + (Type -> Maybe (ATerm -> ATerm -> ATerm)) -> + (ATerm -> ATerm -> ATerm) -> + DTerm -> + Type -> + APattern -> + ATerm -> + Sem r [DGuard] + arithBinMatch restrict inverse dt ty p t = do + (x0, g1) <- varFor dt + + -- let v = t + t' <- desugarTerm t + (v, g2) <- varFor t' + + g3 <- case restrict ty of + Nothing -> return [] + -- if x0 `cmp` v + Just cmp -> + desugarGuard $ + AGBool (embed (atVar ty (coerce x0) `cmp` atVar (getType t) (coerce v))) + + -- when x0 `inverse` v is p + inv <- desugarTerm (atVar ty (coerce x0) `inverse` atVar (getType t) (coerce v)) + g4 <- desugarMatch inv p + + return (g1 ++ g2 ++ g3 ++ g4) -- | Desugar a container literal such as @[1,2,3]@ or @{1,2,3}@. desugarContainer :: Member Fresh r => Type -> Container -> [(ATerm, Maybe ATerm)] -> Maybe (Ellipsis ATerm) -> Sem r DTerm - -- Literal list containers desugar to nested applications of cons. desugarContainer ty ListContainer es Nothing = foldr (dtbin ty (PrimBOp Cons)) (DTNil ty) <$> mapM (desugarTerm . fst) es - -- A list container with an ellipsis @[x, y, z .. e]@ desugars to an -- application of the primitive 'until' function. desugarContainer ty@(TyList _) ListContainer es (Just (Until t)) = dtbin ty PrimUntil <$> desugarTerm t <*> desugarContainer ty ListContainer es Nothing - -- If desugaring a bag and there are any counts specified, desugar to -- an application of bagFromCounts to a bag of pairs (with a literal -- value of 1 filled in for missing counts as needed). desugarContainer (TyBag eltTy) BagContainer es mell | any (isJust . snd) es = - dtapp (DTPrim (TySet (eltTy :*: TyN) :->: TyBag eltTy) PrimC2B) - <$> desugarContainer (TyBag (eltTy :*: TyN)) BagContainer counts mell - - where - -- turn e.g. x # 3, y into (x, 3), (y, 1) - counts = [ (ATTup (eltTy :*: TyN) [t, fromMaybe (ATNat TyN 1) n], Nothing) - | (t, n) <- es - ] + dtapp (DTPrim (TySet (eltTy :*: TyN) :->: TyBag eltTy) PrimC2B) + <$> desugarContainer (TyBag (eltTy :*: TyN)) BagContainer counts mell + where + -- turn e.g. x # 3, y into (x, 3), (y, 1) + counts = + [ (ATTup (eltTy :*: TyN) [t, fromMaybe (ATNat TyN 1) n], Nothing) + | (t, n) <- es + ] -- Other containers desugar to an application of the appropriate -- container conversion function to the corresponding desugared list. desugarContainer ty _ es mell = dtapp (DTPrim (TyList eltTy :->: ty) conv) <$> desugarContainer (TyList eltTy) ListContainer es mell - where - (conv, eltTy) = case ty of - TyBag e -> (PrimBag, e) - TySet e -> (PrimSet, e) - _ -> error $ "Impossible! Non-container type " ++ show ty ++ " in desugarContainer" + where + (conv, eltTy) = case ty of + TyBag e -> (PrimBag, e) + TySet e -> (PrimSet, e) + _ -> error $ "Impossible! Non-container type " ++ show ty ++ " in desugarContainer" diff --git a/src/Disco/Doc.hs b/src/Disco/Doc.hs index 1d120916..5e02b245 100644 --- a/src/Disco/Doc.hs +++ b/src/Disco/Doc.hs @@ -1,152 +1,157 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + +-- SPDX-License-Identifier: BSD-3-Clause + -- | -- Module : Disco.Doc -- Copyright : disco team and contributors -- Maintainer : byorgey@gmail.com -- -- Built-in documentation. --- ------------------------------------------------------------------------------ - --- SPDX-License-Identifier: BSD-3-Clause +module Disco.Doc ( + primDoc, + primReference, + otherDoc, + otherReference, +) where -module Disco.Doc - ( primDoc, primReference, otherDoc, otherReference - ) where +import Data.Map (Map) +import qualified Data.Map as M -import Data.Map (Map) -import qualified Data.Map as M - -import Disco.Syntax.Operators -import Disco.Syntax.Prims -import Disco.Util ((==>)) +import Disco.Syntax.Operators +import Disco.Syntax.Prims +import Disco.Util ((==>)) -- | A map from some primitives to a short descriptive string, -- to be shown by the :doc command. primDoc :: Map Prim String -primDoc = M.fromList - [ PrimUOp Neg ==> "Arithmetic negation." - , PrimBOp Add ==> "The sum of two numbers, types, or graphs." - , PrimBOp Sub ==> "The difference of two numbers." - , PrimBOp SSub ==> "The difference of two numbers, with a lower bound of 0." - , PrimBOp Mul ==> "The product of two numbers, types, or graphs." - , PrimBOp Div ==> "Divide two numbers." - , PrimBOp IDiv ==> "The integer quotient of two numbers, rounded down." - , PrimBOp Mod ==> "a mod b is the remainder when a is divided by b." - , PrimBOp Exp ==> "Exponentiation. a ^ b is a raised to the b power." - , PrimUOp Fact ==> "n! computes the factorial of n, that is, 1 * 2 * ... * n." - , PrimFloor ==> "floor(x) is the largest integer which is <= x." - , PrimCeil ==> "ceiling(x) is the smallest integer which is >= x." - , PrimAbs ==> "abs(x) is the absolute value of x. Also written |x|." - , PrimUOp Not ==> "Logical negation: not(true) = false and not(false) = true." - , PrimBOp And ==> "Logical conjunction (and): true /\\ true = true; otherwise x /\\ y = false." - , PrimBOp Or ==> "Logical disjunction (or): false \\/ false = false; otherwise x \\/ y = true." - , PrimBOp Impl ==> "Logical implication (implies): true -> false = false; otherwise x -> y = true." - , PrimBOp Iff ==> "Biconditional (if and only if)." - , PrimBOp Eq ==> "Equality test. x == y is true if x and y are equal." - , PrimBOp Neq ==> "Inequality test. x /= y is true if x and y are unequal." - , PrimBOp Lt ==> "Less-than test. x < y is true if x is less than (but not equal to) y." - , PrimBOp Gt ==> "Greater-than test. x > y is true if x is greater than (but not equal to) y." - , PrimBOp Leq ==> "Less-than-or-equal test. x <= y is true if x is less than or equal to y." - , PrimBOp Geq ==> "Greater-than-or-equal test. x >= y is true if x is greater than or equal to y." - - , PrimBOp CartProd ==> "Cartesian product, i.e. the collection of all pairs. Also works on bags and sets." - , PrimPower ==> "Power set, i.e. the set of all subsets. Also works on bags." - , PrimBOp Union ==> "Union of two sets (or bags)." - , PrimBOp Inter ==> "Intersection of two sets (or bags)." - , PrimBOp Diff ==> "Difference of two sets (or bags)." - ] +primDoc = + M.fromList + [ PrimUOp Neg ==> "Arithmetic negation." + , PrimBOp Add ==> "The sum of two numbers, types, or graphs." + , PrimBOp Sub ==> "The difference of two numbers." + , PrimBOp SSub ==> "The difference of two numbers, with a lower bound of 0." + , PrimBOp Mul ==> "The product of two numbers, types, or graphs." + , PrimBOp Div ==> "Divide two numbers." + , PrimBOp IDiv ==> "The integer quotient of two numbers, rounded down." + , PrimBOp Mod ==> "a mod b is the remainder when a is divided by b." + , PrimBOp Exp ==> "Exponentiation. a ^ b is a raised to the b power." + , PrimUOp Fact ==> "n! computes the factorial of n, that is, 1 * 2 * ... * n." + , PrimFloor ==> "floor(x) is the largest integer which is <= x." + , PrimCeil ==> "ceiling(x) is the smallest integer which is >= x." + , PrimAbs ==> "abs(x) is the absolute value of x. Also written |x|." + , PrimUOp Not ==> "Logical negation: not(true) = false and not(false) = true." + , PrimBOp And ==> "Logical conjunction (and): true /\\ true = true; otherwise x /\\ y = false." + , PrimBOp Or ==> "Logical disjunction (or): false \\/ false = false; otherwise x \\/ y = true." + , PrimBOp Impl ==> "Logical implication (implies): true -> false = false; otherwise x -> y = true." + , PrimBOp Iff ==> "Biconditional (if and only if)." + , PrimBOp Eq ==> "Equality test. x == y is true if x and y are equal." + , PrimBOp Neq ==> "Inequality test. x /= y is true if x and y are unequal." + , PrimBOp Lt ==> "Less-than test. x < y is true if x is less than (but not equal to) y." + , PrimBOp Gt ==> "Greater-than test. x > y is true if x is greater than (but not equal to) y." + , PrimBOp Leq ==> "Less-than-or-equal test. x <= y is true if x is less than or equal to y." + , PrimBOp Geq ==> "Greater-than-or-equal test. x >= y is true if x is greater than or equal to y." + , PrimBOp CartProd ==> "Cartesian product, i.e. the collection of all pairs. Also works on bags and sets." + , PrimPower ==> "Power set, i.e. the set of all subsets. Also works on bags." + , PrimBOp Union ==> "Union of two sets (or bags)." + , PrimBOp Inter ==> "Intersection of two sets (or bags)." + , PrimBOp Diff ==> "Difference of two sets (or bags)." + ] -- | A map from some primitives to their corresponding page in the -- Disco language reference -- (https://disco-lang.readthedocs.io/en/latest/reference/index.html). primReference :: Map Prim String -primReference = M.fromList - [ PrimBOp Add ==> "addition" - , PrimBOp Sub ==> "subtraction" - , PrimBOp SSub ==> "subtraction" - , PrimBOp Mul ==> "multiplication" - , PrimBOp Div ==> "division" - , PrimBOp IDiv ==> "integerdiv" - , PrimBOp Mod ==> "mod" - , PrimBOp Exp ==> "exponentiation" - , PrimUOp Fact ==> "factorial" - , PrimFloor ==> "round" - , PrimCeil ==> "round" - , PrimAbs ==> "abs" - , PrimUOp Not ==> "logic-ops" - , PrimBOp And ==> "logic-ops" - , PrimBOp Or ==> "logic-ops" - , PrimBOp Impl ==> "logic-ops" - , PrimBOp Iff ==> "logic-ops" - , PrimBOp CartProd ==> "cp" - , PrimPower ==> "power" - , PrimBOp Union ==> "set-ops" - , PrimBOp Inter ==> "set-ops" - , PrimBOp Diff ==> "set-ops" - , PrimBOp Eq ==> "compare" - , PrimBOp Neq ==> "compare" - , PrimBOp Lt ==> "compare" - , PrimBOp Gt ==> "compare" - , PrimBOp Leq ==> "compare" - , PrimBOp Geq ==> "compare" - - ] +primReference = + M.fromList + [ PrimBOp Add ==> "addition" + , PrimBOp Sub ==> "subtraction" + , PrimBOp SSub ==> "subtraction" + , PrimBOp Mul ==> "multiplication" + , PrimBOp Div ==> "division" + , PrimBOp IDiv ==> "integerdiv" + , PrimBOp Mod ==> "mod" + , PrimBOp Exp ==> "exponentiation" + , PrimUOp Fact ==> "factorial" + , PrimFloor ==> "round" + , PrimCeil ==> "round" + , PrimAbs ==> "abs" + , PrimUOp Not ==> "logic-ops" + , PrimBOp And ==> "logic-ops" + , PrimBOp Or ==> "logic-ops" + , PrimBOp Impl ==> "logic-ops" + , PrimBOp Iff ==> "logic-ops" + , PrimBOp CartProd ==> "cp" + , PrimPower ==> "power" + , PrimBOp Union ==> "set-ops" + , PrimBOp Inter ==> "set-ops" + , PrimBOp Diff ==> "set-ops" + , PrimBOp Eq ==> "compare" + , PrimBOp Neq ==> "compare" + , PrimBOp Lt ==> "compare" + , PrimBOp Gt ==> "compare" + , PrimBOp Leq ==> "compare" + , PrimBOp Geq ==> "compare" + ] otherDoc :: Map String String -otherDoc = M.fromList - [ "N" ==> docN - , "ℕ" ==> docN - , "Nat" ==> docN - , "Natural" ==> docN - , "Z" ==> docZ - , "ℤ" ==> docZ - , "Int" ==> docZ - , "Integer" ==> docZ - , "F" ==> docF - , "𝔽" ==> docF - , "Frac" ==> docF - , "Fractional" ==> docF - , "Q" ==> docQ - , "ℚ" ==> docQ - , "Rational" ==> docQ - , "Bool" ==> docB - , "Boolean" ==> docB - , "Unit" ==> "The unit type, i.e. a type with only a single value." - , "Prop" ==> "The type of propositions." - , "Set" ==> "The type of finite sets." - , "|~|" ==> "Absolute value, or the size of a collection." - , "{?" ==> "{? ... ?} is a case expression, for choosing a result based on conditions." - ] - where - docN = "The type of natural numbers: 0, 1, 2, ..." - docZ = "The type of integers: ..., -2, -1, 0, 1, 2, ..." - docF = "The type of fractional numbers p/q >= 0." - docQ = "The type of rational numbers p/q." - docB = "The type of Booleans (true or false)." +otherDoc = + M.fromList + [ "N" ==> docN + , "ℕ" ==> docN + , "Nat" ==> docN + , "Natural" ==> docN + , "Z" ==> docZ + , "ℤ" ==> docZ + , "Int" ==> docZ + , "Integer" ==> docZ + , "F" ==> docF + , "𝔽" ==> docF + , "Frac" ==> docF + , "Fractional" ==> docF + , "Q" ==> docQ + , "ℚ" ==> docQ + , "Rational" ==> docQ + , "Bool" ==> docB + , "Boolean" ==> docB + , "Unit" ==> "The unit type, i.e. a type with only a single value." + , "Prop" ==> "The type of propositions." + , "Set" ==> "The type of finite sets." + , "|~|" ==> "Absolute value, or the size of a collection." + , "{?" ==> "{? ... ?} is a case expression, for choosing a result based on conditions." + ] + where + docN = "The type of natural numbers: 0, 1, 2, ..." + docZ = "The type of integers: ..., -2, -1, 0, 1, 2, ..." + docF = "The type of fractional numbers p/q >= 0." + docQ = "The type of rational numbers p/q." + docB = "The type of Booleans (true or false)." otherReference :: Map String String -otherReference = M.fromList - [ "N" ==> "natural" - , "ℕ" ==> "natural" - , "Nat" ==> "natural" - , "Natural" ==> "natural" - , "Z" ==> "integer" - , "ℤ" ==> "integer" - , "Int" ==> "integer" - , "Integer" ==> "integer" - , "F" ==> "fraction" - , "𝔽" ==> "fraction" - , "Frac" ==> "fraction" - , "Fractional" ==> "fraction" - , "Q" ==> "rational" - , "ℚ" ==> "rational" - , "Rational" ==> "rational" - , "Bool" ==> "bool" - , "Boolean" ==> "bool" - , "Unit" ==> "unit" - , "Prop" ==> "prop" - , "Set" ==> "set" - , "|~|" ==> "size" - , "{?" ==> "case" - ] +otherReference = + M.fromList + [ "N" ==> "natural" + , "ℕ" ==> "natural" + , "Nat" ==> "natural" + , "Natural" ==> "natural" + , "Z" ==> "integer" + , "ℤ" ==> "integer" + , "Int" ==> "integer" + , "Integer" ==> "integer" + , "F" ==> "fraction" + , "𝔽" ==> "fraction" + , "Frac" ==> "fraction" + , "Fractional" ==> "fraction" + , "Q" ==> "rational" + , "ℚ" ==> "rational" + , "Rational" ==> "rational" + , "Bool" ==> "bool" + , "Boolean" ==> "bool" + , "Unit" ==> "unit" + , "Prop" ==> "prop" + , "Set" ==> "set" + , "|~|" ==> "size" + , "{?" ==> "case" + ] diff --git a/src/Disco/Effects/Counter.hs b/src/Disco/Effects/Counter.hs index 3e7d15ae..716efdc9 100644 --- a/src/Disco/Effects/Counter.hs +++ b/src/Disco/Effects/Counter.hs @@ -1,7 +1,10 @@ -{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE BlockArguments #-} {-# LANGUAGE TemplateHaskell #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Effects.Counter -- Copyright : disco team and contributors @@ -10,30 +13,26 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Polysemy effect for integer counter. --- ------------------------------------------------------------------------------ - module Disco.Effects.Counter where -import Polysemy -import Polysemy.State +import Polysemy +import Polysemy.State data Counter m a where - -- | Return the next integer in sequence. - Next :: Counter m Integer + Next :: Counter m Integer makeSem ''Counter -- | Dispatch a counter effect, starting the counter from the given -- Integer. runCounter' :: Integer -> Sem (Counter ': r) a -> Sem r a -runCounter' i - = evalState i - . reinterpret \case +runCounter' i = + evalState i + . reinterpret \case Next -> do n <- get - put (n+1) + put (n + 1) return n -- | Dispatch a counter effect, starting the counter from zero. diff --git a/src/Disco/Effects/Fresh.hs b/src/Disco/Effects/Fresh.hs index b13e3b02..f33b58ef 100644 --- a/src/Disco/Effects/Fresh.hs +++ b/src/Disco/Effects/Fresh.hs @@ -1,8 +1,11 @@ -{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE BlockArguments #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TemplateHaskell #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Effects.Fresh -- Copyright : disco team and contributors @@ -12,50 +15,47 @@ -- -- Polysemy effect for fresh name generation, compatible with the -- unbound-generics library. --- ------------------------------------------------------------------------------ - module Disco.Effects.Fresh where -import Disco.Effects.Counter -import Disco.Names (QName, localName) -import Polysemy -import Polysemy.ConstraintAbsorber -import qualified Unbound.Generics.LocallyNameless as U -import Unbound.Generics.LocallyNameless.Name +import Disco.Effects.Counter +import Disco.Names (QName, localName) +import Polysemy +import Polysemy.ConstraintAbsorber +import qualified Unbound.Generics.LocallyNameless as U +import Unbound.Generics.LocallyNameless.Name -- | Fresh name generation effect, supporting raw generation of fresh -- names, and opening binders with automatic freshening. Simply -- increments a global counter every time 'fresh' is called and -- makes a variable with that numeric suffix. data Fresh m a where - Fresh :: Name x -> Fresh m (Name x) + Fresh :: Name x -> Fresh m (Name x) makeSem ''Fresh -- | Dispatch the fresh name generation effect, starting at a given -- integer. runFresh' :: Integer -> Sem (Fresh ': r) a -> Sem r a -runFresh' i - = runCounter' i - . reinterpret \case +runFresh' i = + runCounter' i + . reinterpret \case Fresh x -> case x of - Fn s _ -> Fn s <$> next - nm@Bn{} -> return nm - - -- Above code copied from - -- https://hackage.haskell.org/package/unbound-generics-0.4.1/docs/src/Unbound.Generics.LocallyNameless.Fresh.html ; - -- see instance Monad m => Fresh (FreshMT m) . - - -- It turns out to make things much simpler to reimplement the - -- Fresh effect ourselves in terms of a state effect, since then - -- we can immediately dispatch it. The alternative would be to - -- implement it in terms of (Embed U.FreshM), but then we are - -- stuck with that constraint. Given the constraint-absorbing - -- machinery below, just impementing the 'fresh' effect itself - -- means we can then reuse other things from unbound-generics that - -- depend on a Fresh constraint, such as the 'unbind' function - -- below. + Fn s _ -> Fn s <$> next + nm@Bn {} -> return nm + +-- Above code copied from +-- https://hackage.haskell.org/package/unbound-generics-0.4.1/docs/src/Unbound.Generics.LocallyNameless.Fresh.html ; +-- see instance Monad m => Fresh (FreshMT m) . + +-- It turns out to make things much simpler to reimplement the +-- Fresh effect ourselves in terms of a state effect, since then +-- we can immediately dispatch it. The alternative would be to +-- implement it in terms of (Embed U.FreshM), but then we are +-- stuck with that constraint. Given the constraint-absorbing +-- machinery below, just impementing the 'fresh' effect itself +-- means we can then reuse other things from unbound-generics that +-- depend on a Fresh constraint, such as the 'unbind' function +-- below. -- | Run a computation requiring fresh name generation, beginning with -- 0 for the initial freshly generated name. @@ -92,7 +92,7 @@ absorbFresh :: Member Fresh r => (U.Fresh (Sem r) => Sem r a) -> Sem r a absorbFresh = absorbWithSem @U.Fresh @Action (FreshDict fresh) (Sub Dict) {-# INLINEABLE absorbFresh #-} -newtype FreshDict m = FreshDict { fresh_ :: forall x. Name x -> m (Name x) } +newtype FreshDict m = FreshDict {fresh_ :: forall x. Name x -> m (Name x)} -- | Wrapper for a monadic action with phantom type parameter for reflection. -- Locally defined so that the instance we are going to build with reflection @@ -100,8 +100,11 @@ newtype FreshDict m = FreshDict { fresh_ :: forall x. Name x -> m (Name x) } newtype Action m s' a = Action (m a) deriving (Functor, Applicative, Monad) -instance ( Monad m - , Reifies s' (FreshDict m) - ) => U.Fresh (Action m s') where +instance + ( Monad m + , Reifies s' (FreshDict m) + ) => + U.Fresh (Action m s') + where fresh x = Action $ fresh_ (reflect $ Proxy @s') x {-# INLINEABLE fresh #-} diff --git a/src/Disco/Effects/Input.hs b/src/Disco/Effects/Input.hs index c86d097b..8af5b417 100644 --- a/src/Disco/Effects/Input.hs +++ b/src/Disco/Effects/Input.hs @@ -1,4 +1,7 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Effects.Input -- Copyright : disco team and contributors @@ -7,20 +10,16 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Utility functions for input effect. --- ------------------------------------------------------------------------------ - -module Disco.Effects.Input - ( module Polysemy.Input - , inputToState - ) - where +module Disco.Effects.Input ( + module Polysemy.Input, + inputToState, +) +where -import Polysemy -import Polysemy.Input -import Polysemy.State +import Polysemy +import Polysemy.Input +import Polysemy.State -- | Run an input effect in terms of an ambient state effect. inputToState :: forall s r a. Member (State s) r => Sem (Input s ': r) a -> Sem r a -inputToState = interpret (\case { Input -> get @s }) - +inputToState = interpret (\case Input -> get @s) diff --git a/src/Disco/Effects/LFresh.hs b/src/Disco/Effects/LFresh.hs index 396ab8b4..a2f92b0b 100644 --- a/src/Disco/Effects/LFresh.hs +++ b/src/Disco/Effects/LFresh.hs @@ -1,8 +1,11 @@ -{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE BlockArguments #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TemplateHaskell #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Effects.LFresh -- Copyright : disco team and contributors @@ -12,24 +15,21 @@ -- -- Polysemy effect for local fresh name generation, compatible with -- the unbound-generics library. --- ------------------------------------------------------------------------------ - module Disco.Effects.LFresh where -import Data.Set (Set) -import qualified Data.Set as S -import Data.Typeable (Typeable) -import Polysemy -import Polysemy.ConstraintAbsorber -import Polysemy.Reader -import qualified Unbound.Generics.LocallyNameless as U -import Unbound.Generics.LocallyNameless.Name +import Data.Set (Set) +import qualified Data.Set as S +import Data.Typeable (Typeable) +import Polysemy +import Polysemy.ConstraintAbsorber +import Polysemy.Reader +import qualified Unbound.Generics.LocallyNameless as U +import Unbound.Generics.LocallyNameless.Name -- | Local fresh name generation effect. data LFresh m a where - Lfresh :: Typeable a => Name a -> LFresh m (Name a) - Avoid :: [AnyName] -> m a -> LFresh m a + Lfresh :: Typeable a => Name a -> LFresh m (Name a) + Avoid :: [AnyName] -> m a -> LFresh m a GetAvoids :: LFresh m (Set AnyName) makeSem ''LFresh @@ -40,44 +40,48 @@ runLFresh :: Sem (LFresh ': r) a -> Sem r a runLFresh = runReader S.empty . runLFresh' runLFresh' :: Sem (LFresh ': r) a -> Sem (Reader (Set AnyName) ': r) a -runLFresh' - = reinterpretH @_ @(Reader (Set AnyName)) \case - Lfresh nm -> do - let s = name2String nm - used <- ask - pureT $ head (filter (\x -> not (S.member (AnyName x) used)) - (map (makeName s) [0..])) - Avoid names m -> do - m' <- runT m - raise (subsume (runLFresh' (local (S.union (S.fromList names)) m'))) - GetAvoids -> ask >>= pureT - - -- Much of the above code copied from - -- https://hackage.haskell.org/package/unbound-generics-0.4.1/docs/src/Unbound.Generics.LocallyNameless.LFresh.html - -- (see instance Monad m => LFresh (LFreshMT m)) - - -- It turns out to make things much simpler to reimplement the - -- LFresh effect ourselves in terms of a reader effect, since then - -- we can immediately dispatch it as above. The alternative would - -- be to implement it in terms of (Final U.LFreshM) (see the - -- commented code at the bottom of this file), but then we are stuck - -- with that constraint. Given the constraint-absorbing machinery - -- below, just impementing the 'LFresh' effect itself means we can - -- then reuse other things from unbound-generics that depend on a - -- Fresh constraint, such as the 'lunbind' function below. - - -- NOTE: originally, there was a single function runLFresh which - -- called reinterpretH and then immediately dispatched the Reader - -- (Set AnyName) effect. However, since runLFresh is recursive, - -- this means that the recursive calls were running with a - -- completely *separate* Reader effect that started over from the - -- empty set! This meant that LFresh basically never changed any - -- names, leading to all sorts of name clashes and crashes. - -- - -- Instead, we need to organize things as above: runLFresh' is - -- recursive, and keeps the Reader effect (using 'subsume' to squash - -- the duplicated Reader effects together). Then a top-level - -- runLFresh function finally runs the Reader effect. +runLFresh' = + reinterpretH @_ @(Reader (Set AnyName)) \case + Lfresh nm -> do + let s = name2String nm + used <- ask + pureT $ + head + ( filter + (\x -> not (S.member (AnyName x) used)) + (map (makeName s) [0 ..]) + ) + Avoid names m -> do + m' <- runT m + raise (subsume (runLFresh' (local (S.union (S.fromList names)) m'))) + GetAvoids -> ask >>= pureT + +-- Much of the above code copied from +-- https://hackage.haskell.org/package/unbound-generics-0.4.1/docs/src/Unbound.Generics.LocallyNameless.LFresh.html +-- (see instance Monad m => LFresh (LFreshMT m)) + +-- It turns out to make things much simpler to reimplement the +-- LFresh effect ourselves in terms of a reader effect, since then +-- we can immediately dispatch it as above. The alternative would +-- be to implement it in terms of (Final U.LFreshM) (see the +-- commented code at the bottom of this file), but then we are stuck +-- with that constraint. Given the constraint-absorbing machinery +-- below, just impementing the 'LFresh' effect itself means we can +-- then reuse other things from unbound-generics that depend on a +-- Fresh constraint, such as the 'lunbind' function below. + +-- NOTE: originally, there was a single function runLFresh which +-- called reinterpretH and then immediately dispatched the Reader +-- (Set AnyName) effect. However, since runLFresh is recursive, +-- this means that the recursive calls were running with a +-- completely *separate* Reader effect that started over from the +-- empty set! This meant that LFresh basically never changed any +-- names, leading to all sorts of name clashes and crashes. +-- +-- Instead, we need to organize things as above: runLFresh' is +-- recursive, and keeps the Reader effect (using 'subsume' to squash +-- the duplicated Reader effects together). Then a top-level +-- runLFresh function finally runs the Reader effect. -------------------------------------------------- -- Other functions @@ -86,9 +90,11 @@ runLFresh' -- variables, and providing the opened pattern and term to the -- provided continuation. The bound variables are also added to the -- set of in-scope variables within in the continuation. -lunbind - :: (Member LFresh r, U.Alpha p, U.Alpha t) - => U.Bind p t -> ((p,t) -> Sem r c) -> Sem r c +lunbind :: + (Member LFresh r, U.Alpha p, U.Alpha t) => + U.Bind p t -> + ((p, t) -> Sem r c) -> + Sem r c lunbind b k = absorbLFresh (U.lunbind b k) ------------------------------------------------------------ @@ -101,8 +107,8 @@ absorbLFresh = absorbWithSem @U.LFresh @Action (LFreshDict lfresh avoid getAvoid {-# INLINEABLE absorbLFresh #-} data LFreshDict m = LFreshDict - { lfresh_ :: forall a. Typeable a => Name a -> m (Name a) - , avoid_ :: forall a. [AnyName] -> m a -> m a + { lfresh_ :: forall a. Typeable a => Name a -> m (Name a) + , avoid_ :: forall a. [AnyName] -> m a -> m a , getAvoids_ :: m (Set AnyName) } @@ -112,9 +118,12 @@ data LFreshDict m = LFreshDict newtype Action m s' a = Action (m a) deriving (Functor, Applicative, Monad) -instance ( Monad m - , Reifies s' (LFreshDict m) - ) => U.LFresh (Action m s') where +instance + ( Monad m + , Reifies s' (LFreshDict m) + ) => + U.LFresh (Action m s') + where lfresh x = Action $ lfresh_ (reflect $ Proxy @s') x {-# INLINEABLE lfresh #-} avoid xs (Action m) = Action $ avoid_ (reflect $ Proxy @s') xs m diff --git a/src/Disco/Effects/Random.hs b/src/Disco/Effects/Random.hs index a1713845..37762b41 100644 --- a/src/Disco/Effects/Random.hs +++ b/src/Disco/Effects/Random.hs @@ -1,4 +1,7 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Effects.Random -- Copyright : disco team and contributors @@ -7,22 +10,19 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Utility functions for random effect. --- ------------------------------------------------------------------------------ - -module Disco.Effects.Random - ( module Polysemy.Random - , runGen - ) - where +module Disco.Effects.Random ( + module Polysemy.Random, + runGen, +) +where -import Polysemy -import Polysemy.Random +import Polysemy +import Polysemy.Random import qualified System.Random.SplitMix as SM -import qualified Test.QuickCheck.Gen as QC +import qualified Test.QuickCheck.Gen as QC import qualified Test.QuickCheck.Random as QCR -import Data.Word (Word64) +import Data.Word (Word64) -- | Run a QuickCheck generator using a 'Random' effect. runGen :: Member Random r => QC.Gen a -> Sem r a diff --git a/src/Disco/Effects/State.hs b/src/Disco/Effects/State.hs index f9f043fb..bfb43c1e 100644 --- a/src/Disco/Effects/State.hs +++ b/src/Disco/Effects/State.hs @@ -1,6 +1,9 @@ {-# LANGUAGE BlockArguments #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Effects.State -- Copyright : disco team and contributors @@ -9,25 +12,24 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Utility functions for state effect. --- ------------------------------------------------------------------------------ - -module Disco.Effects.State - ( module Polysemy.State - , zoom - , use - ,(%=),(.=)) - where +module Disco.Effects.State ( + module Polysemy.State, + zoom, + use, + (%=), + (.=), +) +where -import Control.Lens (Getter, Lens', view, (%~), (.~)) +import Control.Lens (Getter, Lens', view, (%~), (.~)) -import Polysemy -import Polysemy.State +import Polysemy +import Polysemy.State -- | Use a lens to zoom into a component of a state. zoom :: forall s a r c. Member (State s) r => Lens' s a -> Sem (State a ': r) c -> Sem r c zoom l = interpret \case - Get -> view l <$> get + Get -> view l <$> get Put a -> modify (l .~ a) use :: Member (State s) r => Getter s a -> Sem r a diff --git a/src/Disco/Effects/Store.hs b/src/Disco/Effects/Store.hs index 0578c699..8bdb91ed 100644 --- a/src/Disco/Effects/Store.hs +++ b/src/Disco/Effects/Store.hs @@ -1,7 +1,10 @@ -{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE BlockArguments #-} {-# LANGUAGE TemplateHaskell #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Effects.Store -- Copyright : disco team and contributors @@ -10,44 +13,40 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Polysemy effect for a memory store with integer keys. --- ------------------------------------------------------------------------------ - module Disco.Effects.Store where -import qualified Data.IntMap.Lazy as IntMap -import Data.IntSet (IntSet) -import qualified Data.IntSet as IntSet +import qualified Data.IntMap.Lazy as IntMap +import Data.IntSet (IntSet) +import qualified Data.IntSet as IntSet -import Disco.Effects.Counter -import Polysemy -import Polysemy.State +import Disco.Effects.Counter +import Polysemy +import Polysemy.State data Store v m a where - - ClearStore :: Store v m () - New :: v -> Store v m Int + ClearStore :: Store v m () + New :: v -> Store v m Int LookupStore :: Int -> Store v m (Maybe v) InsertStore :: Int -> v -> Store v m () - MapStore :: (v -> v) -> Store v m () + MapStore :: (v -> v) -> Store v m () AssocsStore :: Store v m [(Int, v)] - KeepKeys :: IntSet -> Store v m () + KeepKeys :: IntSet -> Store v m () makeSem ''Store -- | Dispatch a store effect. runStore :: forall v r a. Sem (Store v ': r) a -> Sem r a -runStore - = runCounter - . evalState @(IntMap.IntMap v) IntMap.empty - . reinterpret2 \case - ClearStore -> put IntMap.empty - New v -> do +runStore = + runCounter + . evalState @(IntMap.IntMap v) IntMap.empty + . reinterpret2 \case + ClearStore -> put IntMap.empty + New v -> do loc <- fromIntegral <$> next modify $ IntMap.insert loc v return loc - LookupStore k -> gets (IntMap.lookup k) + LookupStore k -> gets (IntMap.lookup k) InsertStore k v -> modify (IntMap.insert k v) - MapStore f -> modify (IntMap.map f) - AssocsStore -> gets IntMap.assocs - KeepKeys ks -> modify (\m -> IntMap.withoutKeys m (IntMap.keysSet m `IntSet.difference` ks)) + MapStore f -> modify (IntMap.map f) + AssocsStore -> gets IntMap.assocs + KeepKeys ks -> modify (\m -> IntMap.withoutKeys m (IntMap.keysSet m `IntSet.difference` ks)) diff --git a/src/Disco/Enumerate.hs b/src/Disco/Enumerate.hs index cdde761e..ed096b58 100644 --- a/src/Disco/Enumerate.hs +++ b/src/Disco/Enumerate.hs @@ -1,6 +1,9 @@ {-# LANGUAGE NondecreasingIndentation #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Enumerate -- Copyright : disco team and contributors @@ -9,41 +12,38 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Enumerate values inhabiting Disco types. --- ------------------------------------------------------------------------------ - -module Disco.Enumerate - ( - ValueEnumeration - -- * Base types - , enumVoid - , enumUnit - , enumBool - , enumN - , enumZ - , enumF - , enumQ - , enumC - - -- * Containers - , enumSet - -- , enumBag - , enumList - - -- * Any type - , enumType - , enumTypes - - -- * Lifted functions that return lists - , enumerateType - , enumerateTypes - ) - where +module Disco.Enumerate ( + ValueEnumeration, + + -- * Base types + enumVoid, + enumUnit, + enumBool, + enumN, + enumZ, + enumF, + enumQ, + enumC, + + -- * Containers + enumSet, + -- , enumBag + enumList, + + -- * Any type + enumType, + enumTypes, + + -- * Lifted functions that return lists + enumerateType, + enumerateTypes, +) +where import qualified Data.Enumeration.Invertible as E -import Disco.AST.Generic (Side (..)) -import Disco.Types -import Disco.Value +import Disco.AST.Generic (Side (..)) +import Disco.Types +import Disco.Value type ValueEnumeration = E.IEnumeration Value @@ -58,16 +58,16 @@ enumUnit = E.singleton VUnit -- | Enumerate the values of type @Bool@ as @[false, true]@. enumBool :: ValueEnumeration enumBool = E.mapE toV fromV $ E.finiteList [L, R] - where - toV i = VInj i VUnit - fromV (VInj i VUnit) = i - fromV _ = error "enumBool.fromV: value isn't a bool" + where + toV i = VInj i VUnit + fromV (VInj i VUnit) = i + fromV _ = error "enumBool.fromV: value isn't a bool" -- | Unsafely extract the numeric value of a @Value@ -- (assumed to be a VNum). valToRat :: Value -> Rational valToRat (VNum _ r) = r -valToRat _ = error "valToRat: value isn't a number" +valToRat _ = error "valToRat: value isn't a number" ratToVal :: Rational -> Value ratToVal = VNum mempty @@ -93,9 +93,9 @@ enumQ = E.mapE ratToVal valToRat E.rat -- | Enumerate all Unicode characters. enumC :: ValueEnumeration enumC = E.mapE toV fromV (E.boundedEnum @Char) - where - toV = ratToVal . fromIntegral . fromEnum - fromV = toEnum . floor . valToRat + where + toV = ratToVal . fromIntegral . fromEnum + fromV = toEnum . floor . valToRat -- | Enumerate all *finite* sets over a certain element type, given an -- enumeration of the elements. If we think of each finite set as a @@ -103,21 +103,21 @@ enumC = E.mapE toV fromV (E.boundedEnum @Char) -- members, the sets are enumerated in order of the binary strings. enumSet :: ValueEnumeration -> ValueEnumeration enumSet e = E.mapE toV fromV (E.finiteSubsetOf e) - where - toV = VBag . map (,1) - fromV (VBag vs) = map fst vs - fromV _ = error "enumSet.fromV: value isn't a set" + where + toV = VBag . map (,1) + fromV (VBag vs) = map fst vs + fromV _ = error "enumSet.fromV: value isn't a set" -- | Enumerate all *finite* lists over a certain element type, given -- an enumeration of the elements. It is very difficult to describe -- the order in which the lists are generated. enumList :: ValueEnumeration -> ValueEnumeration enumList e = E.mapE toV fromV (E.listOf e) - where - toV = foldr VCons VNil - fromV (VCons h t) = h : fromV t - fromV VNil = [] - fromV _ = error "enumList.fromV: value isn't a list" + where + toV = foldr VCons VNil + fromV (VCons h t) = h : fromV t + fromV VNil = [] + fromV _ = error "enumList.fromV: value isn't a list" -- | Enumerate all functions from a finite domain, given enumerations -- for the domain and codomain. @@ -127,60 +127,60 @@ enumFunction xs ys = (E.Finite 0, _) -> E.singleton (VFun $ \_ -> error "enumFunction: void function called") (_, E.Finite 0) -> E.void (_, E.Finite 1) -> E.singleton (VFun $ \_ -> E.select ys 0) - _ -> E.mapE toV fromV (E.functionOf xs ys) + _ -> E.mapE toV fromV (E.functionOf xs ys) + where + -- XXX TODO: better error message on functions with an infinite domain - -- XXX TODO: better error message on functions with an infinite domain - where - toV = VFun - fromV (VFun f) = f - fromV _ = error "enumFunction.fromV: value isn't a VFun" + toV = VFun + fromV (VFun f) = f + fromV _ = error "enumFunction.fromV: value isn't a VFun" -- | Enumerate all values of a product type, given enumerations of the -- two component types. Uses a fair interleaving for infinite -- component types. enumProd :: ValueEnumeration -> ValueEnumeration -> ValueEnumeration enumProd xs ys = E.mapE toV fromV $ (E.><) xs ys - where - toV (x, y) = VPair x y - fromV (VPair x y) = (x, y) - fromV _ = error "enumProd.fromV: value isn't a pair" + where + toV (x, y) = VPair x y + fromV (VPair x y) = (x, y) + fromV _ = error "enumProd.fromV: value isn't a pair" -- | Enumerate all values of a sum type, given enumerations of the two -- component types. enumSum :: ValueEnumeration -> ValueEnumeration -> ValueEnumeration enumSum xs ys = E.mapE toV fromV $ (E.<+>) xs ys - where - toV (Left x) = VInj L x - toV (Right y) = VInj R y - fromV (VInj L x) = Left x - fromV (VInj R y) = Right y - fromV _ = error "enumSum.fromV: value isn't a sum" + where + toV (Left x) = VInj L x + toV (Right y) = VInj R y + fromV (VInj L x) = Left x + fromV (VInj R y) = Right y + fromV _ = error "enumSum.fromV: value isn't a sum" -- | Enumerate the values of a given type. enumType :: Type -> ValueEnumeration -enumType TyVoid = enumVoid -enumType TyUnit = enumUnit -enumType TyBool = enumBool -enumType TyN = enumN -enumType TyZ = enumZ -enumType TyF = enumF -enumType TyQ = enumQ -enumType TyC = enumC -enumType (TySet t) = enumSet (enumType t) +enumType TyVoid = enumVoid +enumType TyUnit = enumUnit +enumType TyBool = enumBool +enumType TyN = enumN +enumType TyZ = enumZ +enumType TyF = enumF +enumType TyQ = enumQ +enumType TyC = enumC +enumType (TySet t) = enumSet (enumType t) enumType (TyList t) = enumList (enumType t) -enumType (a :*: b) = enumProd (enumType a) (enumType b) -enumType (a :+: b) = enumSum (enumType a) (enumType b) +enumType (a :*: b) = enumProd (enumType a) (enumType b) +enumType (a :+: b) = enumSum (enumType a) (enumType b) enumType (a :->: b) = enumFunction (enumType a) (enumType b) -enumType ty = error $ "enumType: can't enumerate " ++ show ty +enumType ty = error $ "enumType: can't enumerate " ++ show ty -- | Enumerate a finite product of types. enumTypes :: [Type] -> E.IEnumeration [Value] -enumTypes [] = E.singleton [] -enumTypes (t:ts) = E.mapE toL fromL $ (E.><) (enumType t) (enumTypes ts) - where - toL (x, xs) = x:xs - fromL (x:xs) = (x, xs) - fromL [] = error "enumTypes.fromL: empty list not in enumeration range" +enumTypes [] = E.singleton [] +enumTypes (t : ts) = E.mapE toL fromL $ (E.><) (enumType t) (enumTypes ts) + where + toL (x, xs) = x : xs + fromL (x : xs) = (x, xs) + fromL [] = error "enumTypes.fromL: empty list not in enumeration range" -- | Produce an actual list of the values of a type. enumerateType :: Type -> [Value] diff --git a/src/Disco/Error.hs b/src/Disco/Error.hs index bb353e4c..6282d0eb 100644 --- a/src/Disco/Error.hs +++ b/src/Disco/Error.hs @@ -1,7 +1,10 @@ -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE StandaloneDeriving #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Error -- Copyright : disco team and contributors @@ -11,81 +14,68 @@ -- -- Type for collecting all potential Disco errors at the top level, -- and a type for runtime errors. --- ------------------------------------------------------------------------------ - -module Disco.Error (DiscoError(..), EvalError(..), panic, outputDiscoErrors) where - -import Prelude hiding ((<>)) - -import Text.Megaparsec (ParseErrorBundle, - errorBundlePretty) -import Unbound.Generics.LocallyNameless (Name) - -import Disco.Effects.LFresh -import Polysemy -import Polysemy.Error -import Polysemy.Output -import Polysemy.Reader - -import Disco.Messages -import Disco.Names (ModuleName, QName) -import Disco.Parser (DiscoParseError) -import Disco.Pretty -import Disco.Typecheck.Solve -import Disco.Typecheck.Util (LocTCError (..), - TCError (..)) -import Disco.Types -import Disco.Types.Qualifiers +module Disco.Error (DiscoError (..), EvalError (..), panic, outputDiscoErrors) where + +import Prelude hiding ((<>)) + +import Text.Megaparsec ( + ParseErrorBundle, + errorBundlePretty, + ) +import Unbound.Generics.LocallyNameless (Name) + +import Disco.Effects.LFresh +import Polysemy +import Polysemy.Error +import Polysemy.Output +import Polysemy.Reader + +import Disco.Messages +import Disco.Names (ModuleName, QName) +import Disco.Parser (DiscoParseError) +import Disco.Pretty +import Disco.Typecheck.Solve +import Disco.Typecheck.Util ( + LocTCError (..), + TCError (..), + ) +import Disco.Types +import Disco.Types.Qualifiers -- | Top-level error type for Disco. data DiscoError where - -- | Module not found. ModuleNotFound :: String -> DiscoError - -- | Cyclic import encountered. CyclicImport :: [ModuleName] -> DiscoError - -- | Error encountered during typechecking. TypeCheckErr :: LocTCError -> DiscoError - -- | Error encountered during parsing. ParseErr :: ParseErrorBundle String DiscoParseError -> DiscoError - -- | Error encountered at runtime. EvalErr :: EvalError -> DiscoError - -- | Something that shouldn't happen; indicates the presence of a -- bug. - Panic :: String -> DiscoError - - deriving Show + Panic :: String -> DiscoError + deriving (Show) -- | Errors that can be generated at runtime. data EvalError where - -- | An unbound name was encountered. - UnboundError :: QName core -> EvalError - + UnboundError :: QName core -> EvalError -- | An unbound name that really shouldn't happen, coming from some -- kind of internal name generation scheme. - UnboundPanic :: Name core -> EvalError - + UnboundPanic :: Name core -> EvalError -- | Division by zero. - DivByZero :: EvalError - + DivByZero :: EvalError -- | Overflow, e.g. (2^66)! - Overflow :: EvalError - + Overflow :: EvalError -- | Non-exhaustive case analysis. - NonExhaustive :: EvalError - + NonExhaustive :: EvalError -- | Infinite loop detected via black hole. - InfiniteLoop :: EvalError - + InfiniteLoop :: EvalError -- | User-generated crash. - Crash :: String -> EvalError + Crash :: String -> EvalError deriving instance Show EvalError @@ -99,17 +89,17 @@ outputDiscoErrors m = do instance Pretty DiscoError where pretty = \case - ModuleNotFound m -> "Error: couldn't find a module named '" <> text m <> "'." - CyclicImport ms -> cyclicImportError ms + ModuleNotFound m -> "Error: couldn't find a module named '" <> text m <> "'." + CyclicImport ms -> cyclicImportError ms TypeCheckErr (LocTCError Nothing te) -> prettyTCError te TypeCheckErr (LocTCError (Just n) te) -> vcat [ "While checking " <> pretty' n <> ":" , nest 2 $ prettyTCError te ] - ParseErr pe -> text (errorBundlePretty pe) - EvalErr ee -> prettyEvalError ee - Panic s -> + ParseErr pe -> text (errorBundlePretty pe) + EvalErr ee -> prettyEvalError ee + Panic s -> vcat [ "Bug! " <> text s , "Please report this as a bug at https://github.com/disco-lang/disco/issues/ ." @@ -121,9 +111,10 @@ rtd page = "https://disco-lang.readthedocs.io/en/latest/reference/" <> text page issue :: Int -> Sem r Doc issue n = "See https://github.com/disco-lang/disco/issues/" <> text (show n) -cyclicImportError - :: Members '[Reader PA, LFresh] r - => [ModuleName] -> Sem r Doc +cyclicImportError :: + Members '[Reader PA, LFresh] r => + [ModuleName] -> + Sem r Doc cyclicImportError ms = vcat [ "Error: module imports form a cycle:" @@ -132,16 +123,15 @@ cyclicImportError ms = prettyEvalError :: Members '[Reader PA, LFresh] r => EvalError -> Sem r Doc prettyEvalError = \case - UnboundPanic x -> - ("Bug! No variable found named" <+> pretty' x <> ".") - $+$ - "Please report this as a bug at https://github.com/disco-lang/disco/issues/ ." - UnboundError x -> "Error: encountered undefined name" <+> pretty' x <> ". Maybe you haven't defined it yet?" - DivByZero -> "Error: division by zero." - Overflow -> "Error: that number would not even fit in the universe!" - NonExhaustive -> "Error: value did not match any of the branches in a case expression." - InfiniteLoop -> "Error: infinite loop detected!" - Crash s -> "User crash:" <+> text s + UnboundPanic x -> + ("Bug! No variable found named" <+> pretty' x <> ".") + $+$ "Please report this as a bug at https://github.com/disco-lang/disco/issues/ ." + UnboundError x -> "Error: encountered undefined name" <+> pretty' x <> ". Maybe you haven't defined it yet?" + DivByZero -> "Error: division by zero." + Overflow -> "Error: that number would not even fit in the universe!" + NonExhaustive -> "Error: value did not match any of the branches in a case expression." + InfiniteLoop -> "Error: infinite loop detected!" + Crash s -> "User crash:" <+> text s -- [X] Step 1: nice error messages, make sure all are tested -- [ ] Step 2: link to wiki/website with more info on errors! @@ -150,175 +140,176 @@ prettyEvalError = \case -- [ ] Step 5: save parse locations, display with errors prettyTCError :: Members '[Reader PA, LFresh] r => TCError -> Sem r Doc prettyTCError = \case - -- XXX include some potential misspellings along with Unbound -- see https://github.com/disco-lang/disco/issues/180 - Unbound x -> vcat - [ "Error: there is nothing named" <+> pretty' x <> "." - , rtd "unbound" - ] - - Ambiguous x ms -> vcat - [ "Error: the name" <+> pretty' x <+> "is ambiguous. It could refer to:" - , nest 2 (vcat . map (\m -> pretty' m <> "." <> pretty' x) $ ms) - , rtd "ambiguous" - ] - - NoType x -> vcat - [ "Error: the definition of" <+> pretty' x <+> "must have an accompanying type signature." - , "Try writing something like '" <> pretty' x <+> ": Int' (or whatever the type of" - <+> pretty' x <+> "should be) first." - , rtd "missingtype" - ] - - NotCon c t ty -> vcat - [ "Error: the expression" - , nest 2 $ pretty' t - , "must have both a" <+> conWord c <+> "type and also the incompatible type" - , nest 2 $ pretty' ty <> "." - , rtd "notcon" - ] - - EmptyCase -> vcat - [ "Error: empty case expressions {? ?} are not allowed." - , rtd "empty-case" - ] - - PatternType c pat ty -> vcat - [ "Error: the pattern" - , nest 2 $ pretty' pat - , "is supposed to have type" - , nest 2 $ pretty' ty <> "," - , "but instead it has a" <+> conWord c <+> "type." - , rtd "pattern-type" - ] - - DuplicateDecls x -> vcat - [ "Error: duplicate type signature for" <+> pretty' x <> "." - , rtd "dup-sig" - ] - - DuplicateDefns x -> vcat - [ "Error: duplicate definition for" <+> pretty' x <> "." - , rtd "dup-def" - ] - - DuplicateTyDefns s -> vcat - [ "Error: duplicate definition for type" <+> text s <> "." - , rtd "dup-tydef" - ] - + Unbound x -> + vcat + [ "Error: there is nothing named" <+> pretty' x <> "." + , rtd "unbound" + ] + Ambiguous x ms -> + vcat + [ "Error: the name" <+> pretty' x <+> "is ambiguous. It could refer to:" + , nest 2 (vcat . map (\m -> pretty' m <> "." <> pretty' x) $ ms) + , rtd "ambiguous" + ] + NoType x -> + vcat + [ "Error: the definition of" <+> pretty' x <+> "must have an accompanying type signature." + , "Try writing something like '" + <> pretty' x + <+> ": Int' (or whatever the type of" + <+> pretty' x + <+> "should be) first." + , rtd "missingtype" + ] + NotCon c t ty -> + vcat + [ "Error: the expression" + , nest 2 $ pretty' t + , "must have both a" <+> conWord c <+> "type and also the incompatible type" + , nest 2 $ pretty' ty <> "." + , rtd "notcon" + ] + EmptyCase -> + vcat + [ "Error: empty case expressions {? ?} are not allowed." + , rtd "empty-case" + ] + PatternType c pat ty -> + vcat + [ "Error: the pattern" + , nest 2 $ pretty' pat + , "is supposed to have type" + , nest 2 $ pretty' ty <> "," + , "but instead it has a" <+> conWord c <+> "type." + , rtd "pattern-type" + ] + DuplicateDecls x -> + vcat + [ "Error: duplicate type signature for" <+> pretty' x <> "." + , rtd "dup-sig" + ] + DuplicateDefns x -> + vcat + [ "Error: duplicate definition for" <+> pretty' x <> "." + , rtd "dup-def" + ] + DuplicateTyDefns s -> + vcat + [ "Error: duplicate definition for type" <+> text s <> "." + , rtd "dup-tydef" + ] -- XXX include all types involved in the cycle. - CyclicTyDef s -> vcat - [ "Error: cyclic type definition for" <+> text s <> "." - , rtd "cyc-ty" - ] - + CyclicTyDef s -> + vcat + [ "Error: cyclic type definition for" <+> text s <> "." + , rtd "cyc-ty" + ] -- XXX lots more info! & Split into several different errors. - NumPatterns -> vcat - [ "Error: number of arguments does not match." - , rtd "num-args" - ] - - NonlinearPattern p x -> vcat - [ "Error: pattern" <+> pretty' p <+> "contains duplicate variable" <+> pretty' x <> "." - , rtd "nonlinear" - ] - + NumPatterns -> + vcat + [ "Error: number of arguments does not match." + , rtd "num-args" + ] + NonlinearPattern p x -> + vcat + [ "Error: pattern" <+> pretty' p <+> "contains duplicate variable" <+> pretty' x <> "." + , rtd "nonlinear" + ] NoSearch ty -> vcat - [ "Error: the type" - , nest 2 $ pretty' ty - , "is not searchable (i.e. it cannot be used in a forall)." - , rtd "no-search" - ] - + [ "Error: the type" + , nest 2 $ pretty' ty + , "is not searchable (i.e. it cannot be used in a forall)." + , rtd "no-search" + ] Unsolvable solveErr -> prettySolveError solveErr - -- XXX maybe include close edit-distance alternatives? - NotTyDef s -> vcat - [ "Error: there is no built-in or user-defined type named '" <> text s <> "'." - , rtd "no-tydef" - ] - - NoTWild -> vcat - [ "Error: wildcards (_) are not allowed in expressions." - , rtd "wildcard-expr" - ] - + NotTyDef s -> + vcat + [ "Error: there is no built-in or user-defined type named '" <> text s <> "'." + , rtd "no-tydef" + ] + NoTWild -> + vcat + [ "Error: wildcards (_) are not allowed in expressions." + , rtd "wildcard-expr" + ] -- XXX say how many are expected, how many there were, what the actual arguments were? -- XXX distinguish between built-in and user-supplied type constructors in the error -- message? - NotEnoughArgs con -> vcat - [ "Error: not enough arguments for the type '" <> pretty' con <> "'." - , rtd "num-args-type" - ] - - TooManyArgs con -> vcat - [ "Error: too many arguments for the type '" <> pretty' con <> "'." - , rtd "num-args-type" - ] - + NotEnoughArgs con -> + vcat + [ "Error: not enough arguments for the type '" <> pretty' con <> "'." + , rtd "num-args-type" + ] + TooManyArgs con -> + vcat + [ "Error: too many arguments for the type '" <> pretty' con <> "'." + , rtd "num-args-type" + ] -- XXX Mention the definition in which it was found, suggest adding the variable -- as a parameter - UnboundTyVar v -> vcat - [ "Error: Unknown type variable '" <> pretty' v <> "'." - , rtd "unbound-tyvar" - ] - - NoPolyRec s ss tys -> vcat - [ "Error: in the definition of " <> text s <> parens (intercalate "," (map text ss)) <> ": recursive occurrences of" <+> text s <+> "may only have type variables as arguments." - , nest 2 ( - text s <> parens (intercalate "," (map pretty' tys)) <+> "does not follow this rule." - ) - , rtd "no-poly-rec" - ] - + UnboundTyVar v -> + vcat + [ "Error: Unknown type variable '" <> pretty' v <> "'." + , rtd "unbound-tyvar" + ] + NoPolyRec s ss tys -> + vcat + [ "Error: in the definition of " <> text s <> parens (intercalate "," (map text ss)) <> ": recursive occurrences of" <+> text s <+> "may only have type variables as arguments." + , nest + 2 + ( text s <> parens (intercalate "," (map pretty' tys)) <+> "does not follow this rule." + ) + , rtd "no-poly-rec" + ] NoError -> empty conWord :: Con -> Sem r Doc conWord = \case - CArr -> "function" - CProd -> "pair" - CSum -> "sum" - CSet -> "set" - CBag -> "bag" - CList -> "list" + CArr -> "function" + CProd -> "pair" + CSum -> "sum" + CSet -> "set" + CBag -> "bag" + CList -> "list" CContainer _ -> "container" - CMap -> "map" - CGraph -> "graph" - CUser s -> text s + CMap -> "map" + CGraph -> "graph" + CUser s -> text s prettySolveError :: Members '[Reader PA, LFresh] r => SolveError -> Sem r Doc prettySolveError = \case - -- XXX say which types! - NoWeakUnifier -> vcat - [ "Error: the shape of two types does not match." - , rtd "shape-mismatch" - ] - + NoWeakUnifier -> + vcat + [ "Error: the shape of two types does not match." + , rtd "shape-mismatch" + ] -- XXX say more! XXX HIGHEST PRIORITY! - NoUnify -> vcat - [ "Error: typechecking failed." - , rtd "typecheck-fail" - ] - - UnqualBase q b -> vcat - [ "Error: values of type" <+> pretty' b <+> qualPhrase False q <> "." - , rtd "not-qual" - ] - - Unqual q ty -> vcat - [ "Error: values of type" <+> pretty' ty <+> qualPhrase False q <> "." - , rtd "not-qual" - ] - - QualSkolem q a -> vcat - [ "Error: type variable" <+> pretty' a <+> "represents any type, so we cannot assume values of that type" - , nest 2 (qualPhrase True q) <> "." - , rtd "qual-skolem" - ] + NoUnify -> + vcat + [ "Error: typechecking failed." + , rtd "typecheck-fail" + ] + UnqualBase q b -> + vcat + [ "Error: values of type" <+> pretty' b <+> qualPhrase False q <> "." + , rtd "not-qual" + ] + Unqual q ty -> + vcat + [ "Error: values of type" <+> pretty' ty <+> qualPhrase False q <> "." + , rtd "not-qual" + ] + QualSkolem q a -> + vcat + [ "Error: type variable" <+> pretty' a <+> "represents any type, so we cannot assume values of that type" + , nest 2 (qualPhrase True q) <> "." + , rtd "qual-skolem" + ] qualPhrase :: Bool -> Qualifier -> Sem r Doc qualPhrase b q @@ -327,12 +318,11 @@ qualPhrase b q qualAction :: Qualifier -> Sem r Doc qualAction = \case - QNum -> "added and multiplied" - QSub -> "subtracted" - QDiv -> "divided" - QCmp -> "compared" - QEnum -> "enumerated" - QBool -> "boolean" - QBasic -> "basic" + QNum -> "added and multiplied" + QSub -> "subtracted" + QDiv -> "divided" + QCmp -> "compared" + QEnum -> "enumerated" + QBool -> "boolean" + QBasic -> "basic" QSimple -> "simple" - diff --git a/src/Disco/Eval.hs b/src/Disco/Eval.hs index d1da9006..b246dbd0 100644 --- a/src/Disco/Eval.hs +++ b/src/Disco/Eval.hs @@ -1,8 +1,11 @@ -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeSynonymInstances #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Eval -- Copyright : disco team and contributors @@ -11,86 +14,91 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Top-level evaluation utilities. ------------------------------------------------------------------------------ - -module Disco.Eval - ( - -- * Effects - - EvalEffects - , DiscoEffects - - -- * Top-level info record and associated lenses - - , DiscoConfig, initDiscoConfig, debugMode - , TopInfo - , replModInfo, topEnv, topModMap, lastFile, discoConfig - - -- * Running things - - , runDisco - , runTCM - , inputTopEnv - , parseDiscoModule - , typecheckTop - - -- * Loading modules - - , loadDiscoModule - , loadParsedDiscoModule - , loadFile - , addToREPLModule - , setREPLModule - , loadDefsFrom - , loadDef - - ) - where - -import Control.Arrow ((&&&)) -import Control.Exception (SomeException, handle) -import Control.Lens (makeLenses, toListOf, view, (%~), - (.~), (<>~), (^.)) -import Control.Monad (unless, void, when) -import Control.Monad.IO.Class (liftIO) -import Data.Bifunctor -import Data.Map (Map) -import qualified Data.Map as M -import qualified Data.Set as S -import Prelude -import System.FilePath ((-<.>)) +module Disco.Eval ( + -- * Effects + EvalEffects, + DiscoEffects, + + -- * Top-level info record and associated lenses + DiscoConfig, + initDiscoConfig, + debugMode, + TopInfo, + replModInfo, + topEnv, + topModMap, + lastFile, + discoConfig, + + -- * Running things + runDisco, + runTCM, + inputTopEnv, + parseDiscoModule, + typecheckTop, + + -- * Loading modules + loadDiscoModule, + loadParsedDiscoModule, + loadFile, + addToREPLModule, + setREPLModule, + loadDefsFrom, + loadDef, +) +where + +import Control.Arrow ((&&&)) +import Control.Exception (SomeException, handle) +import Control.Lens ( + makeLenses, + toListOf, + view, + (%~), + (.~), + (<>~), + (^.), + ) +import Control.Monad (unless, void, when) +import Control.Monad.IO.Class (liftIO) +import Data.Bifunctor +import Data.Map (Map) +import qualified Data.Map as M +import qualified Data.Set as S +import System.FilePath ((-<.>)) +import Prelude import qualified System.Console.Haskeline as H -import Disco.Effects.Fresh -import Disco.Effects.Input -import Disco.Effects.LFresh -import Disco.Effects.State -import Polysemy -import Polysemy.Embed -import Polysemy.Error -import Polysemy.Fail -import Polysemy.Output -import Polysemy.Random -import Polysemy.Reader - -import Disco.AST.Core -import Disco.AST.Surface -import Disco.Compile (compileDefns) -import Disco.Context as Ctx -import Disco.Error -import Disco.Extensions -import Disco.Interpret.CESK -import Disco.Messages -import Disco.Module -import Disco.Names -import Disco.Parser -import Disco.Pretty hiding ((<>)) -import qualified Disco.Pretty as Pretty -import Disco.Typecheck (checkModule) -import Disco.Typecheck.Util -import Disco.Types -import Disco.Value +import Disco.Effects.Fresh +import Disco.Effects.Input +import Disco.Effects.LFresh +import Disco.Effects.State +import Polysemy +import Polysemy.Embed +import Polysemy.Error +import Polysemy.Fail +import Polysemy.Output +import Polysemy.Random +import Polysemy.Reader + +import Disco.AST.Core +import Disco.AST.Surface +import Disco.Compile (compileDefns) +import Disco.Context as Ctx +import Disco.Error +import Disco.Extensions +import Disco.Interpret.CESK +import Disco.Messages +import Disco.Module +import Disco.Names +import Disco.Parser +import Disco.Pretty hiding ((<>)) +import qualified Disco.Pretty as Pretty +import Disco.Typecheck (checkModule) +import Disco.Typecheck.Util +import Disco.Types +import Disco.Value ------------------------------------------------------------ -- Configuation options @@ -103,9 +111,10 @@ data DiscoConfig = DiscoConfig makeLenses ''DiscoConfig initDiscoConfig :: DiscoConfig -initDiscoConfig = DiscoConfig - { _debugMode = False - } +initDiscoConfig = + DiscoConfig + { _debugMode = False + } ------------------------------------------------------------ -- Top level info record @@ -114,32 +123,29 @@ initDiscoConfig = DiscoConfig -- | A record of information about the current top-level environment. data TopInfo = TopInfo { _replModInfo :: ModuleInfo - -- ^ Info about the top-level module collecting stuff entered at - -- the REPL. - - , _topEnv :: Env - -- ^ Top-level environment mapping names to values. Set by - -- 'loadDefs'. - - , _topModMap :: Map ModuleName ModuleInfo - -- ^ Mapping from loaded module names to their 'ModuleInfo' - -- records. - - , _lastFile :: Maybe FilePath - -- ^ The most recent file which was :loaded by the user. - + -- ^ Info about the top-level module collecting stuff entered at + -- the REPL. + , _topEnv :: Env + -- ^ Top-level environment mapping names to values. Set by + -- 'loadDefs'. + , _topModMap :: Map ModuleName ModuleInfo + -- ^ Mapping from loaded module names to their 'ModuleInfo' + -- records. + , _lastFile :: Maybe FilePath + -- ^ The most recent file which was :loaded by the user. , _discoConfig :: DiscoConfig } -- | The initial (empty) record of top-level info. initTopInfo :: DiscoConfig -> TopInfo -initTopInfo cfg = TopInfo - { _replModInfo = emptyModuleInfo - , _topEnv = emptyCtx - , _topModMap = M.empty - , _lastFile = Nothing - , _discoConfig = cfg - } +initTopInfo cfg = + TopInfo + { _replModInfo = emptyModuleInfo + , _topEnv = emptyCtx + , _topModMap = M.empty + , _lastFile = Nothing + , _discoConfig = cfg + } makeLenses ''TopInfo @@ -162,8 +168,9 @@ type TopEffects = '[Error DiscoError, State TopInfo, Output Message, Embed IO, F -- | Effects needed for evaluation. type EvalEffects = [Error EvalError, Random, LFresh, Output Message, State Mem] - -- XXX write about order. - -- memory, counter etc. should not be reset by errors. + +-- XXX write about order. +-- memory, counter etc. should not be reset by errors. -- | All effects needed for the top level + evaluation. type DiscoEffects = AppendEffects EvalEffects TopEffects @@ -188,21 +195,21 @@ runDisco cfg m = . runFinal @(H.InputT IO) . embedToFinal . runEmbedded @_ @(H.InputT IO) liftIO - . runOutputSem (handleMsg msgFilter) -- Handle Output Message via printing to console - . stateToIO (initTopInfo cfg) -- Run State TopInfo via an IORef - . inputToState -- Dispatch Input TopInfo effect via State effect - . runState emptyMem -- Start with empty memory - . outputDiscoErrors -- Output any top-level errors - . runLFresh -- Generate locally fresh names - . runRandomIO -- Generate randomness via IO - . mapError EvalErr -- Embed runtime errors into top-level error type - . failToError Panic -- Turn pattern-match failures into a Panic error - . runReader emptyCtx -- Keep track of current Env + . runOutputSem (handleMsg msgFilter) -- Handle Output Message via printing to console + . stateToIO (initTopInfo cfg) -- Run State TopInfo via an IORef + . inputToState -- Dispatch Input TopInfo effect via State effect + . runState emptyMem -- Start with empty memory + . outputDiscoErrors -- Output any top-level errors + . runLFresh -- Generate locally fresh names + . runRandomIO -- Generate randomness via IO + . mapError EvalErr -- Embed runtime errors into top-level error type + . failToError Panic -- Turn pattern-match failures into a Panic error + . runReader emptyCtx -- Keep track of current Env $ m - where - msgFilter - | cfg ^. debugMode = const True - | otherwise = (/= Debug) . view messageType + where + msgFilter + | cfg ^. debugMode = const True + | otherwise = (/= Debug) . view messageType ------------------------------------------------------------ -- Environment utilities @@ -252,7 +259,7 @@ runTCM tyCtx tyDefCtx = -- | A variant of 'runTCM' that requires only a 'TCError' instead -- of a 'LocTCError'. -runTCM' :: +runTCM' :: Member (Error DiscoError) r => TyCtx -> TyDefCtx -> @@ -266,12 +273,12 @@ runTCM' tyCtx tyDefCtx = -- | Run a typechecking computation in the context of the top-level -- REPL module, re-throwing a wrapped error if it fails. -typecheckTop - :: Members '[Input TopInfo, Error DiscoError] r - => Sem (Reader TyCtx ': Reader TyDefCtx ': Fresh ': Error TCError ': r) a - -> Sem r a +typecheckTop :: + Members '[Input TopInfo, Error DiscoError] r => + Sem (Reader TyCtx ': Reader TyDefCtx ': Fresh ': Error TCError ': r) a -> + Sem r a typecheckTop tcm = do - tyctx <- inputs (view (replModInfo . miTys)) + tyctx <- inputs (view (replModInfo . miTys)) imptyctx <- inputs (toListOf (replModInfo . miImports . traverse . miTys)) tydefs <- inputs (view (replModInfo . miTydefs)) imptydefs <- inputs (toListOf (replModInfo . miImports . traverse . miTydefs)) @@ -288,9 +295,12 @@ typecheckTop tcm = do -- -- The 'Resolver' argument specifies where to look for imported -- modules. -loadDiscoModule - :: Members '[State TopInfo, Output Message, Random, State Mem, Error DiscoError, Embed IO] r - => Bool -> Resolver -> FilePath -> Sem r ModuleInfo +loadDiscoModule :: + Members '[State TopInfo, Output Message, Random, State Mem, Error DiscoError, Embed IO] r => + Bool -> + Resolver -> + FilePath -> + Sem r ModuleInfo loadDiscoModule quiet resolver = loadDiscoModule' quiet resolver [] @@ -299,9 +309,13 @@ loadDiscoModule quiet resolver = -- a context that includes the current top-level context (unlike a -- module loaded from disk). Used for e.g. blocks/modules entered -- at the REPL prompt. -loadParsedDiscoModule - :: Members '[State TopInfo, Output Message, Random, State Mem, Error DiscoError, Embed IO] r - => Bool -> Resolver -> ModuleName -> Module -> Sem r ModuleInfo +loadParsedDiscoModule :: + Members '[State TopInfo, Output Message, Random, State Mem, Error DiscoError, Embed IO] r => + Bool -> + Resolver -> + ModuleName -> + Module -> + Sem r ModuleInfo loadParsedDiscoModule quiet resolver = loadParsedDiscoModule' quiet REPL resolver [] @@ -309,15 +323,19 @@ loadParsedDiscoModule quiet resolver = -- Map from module names to 'ModuleInfo' records, to avoid loading -- any imported module more than once. Resolve the module, load and -- parse it, then call 'loadParsedDiscoModule''. -loadDiscoModule' - :: Members '[State TopInfo, Output Message, Random, State Mem, Error DiscoError, Embed IO] r - => Bool -> Resolver -> [ModuleName] -> FilePath - -> Sem r ModuleInfo -loadDiscoModule' quiet resolver inProcess modPath = do - (resolvedPath, prov) <- resolveModule resolver modPath - >>= maybe (throw $ ModuleNotFound modPath) return +loadDiscoModule' :: + Members '[State TopInfo, Output Message, Random, State Mem, Error DiscoError, Embed IO] r => + Bool -> + Resolver -> + [ModuleName] -> + FilePath -> + Sem r ModuleInfo +loadDiscoModule' quiet resolver inProcess modPath = do + (resolvedPath, prov) <- + resolveModule resolver modPath + >>= maybe (throw $ ModuleNotFound modPath) return let name = Named prov modPath - when (name `elem` inProcess) (throw $ CyclicImport (name:inProcess)) + when (name `elem` inProcess) (throw $ CyclicImport (name : inProcess)) modMap <- use @TopInfo topModMap case M.lookup name modMap of Just mi -> return mi @@ -337,34 +355,39 @@ stdLib = ["list", "container"] -- it in the context of the top-level type context iff the -- 'LoadingMode' parameter is 'REPL'. Recursively load all its -- imports, then typecheck it. -loadParsedDiscoModule' - :: Members '[State TopInfo, Output Message, Random, State Mem, Error DiscoError, Embed IO] r - => Bool -> LoadingMode -> Resolver -> [ModuleName] -> ModuleName -> Module -> Sem r ModuleInfo +loadParsedDiscoModule' :: + Members '[State TopInfo, Output Message, Random, State Mem, Error DiscoError, Embed IO] r => + Bool -> + LoadingMode -> + Resolver -> + [ModuleName] -> + ModuleName -> + Module -> + Sem r ModuleInfo loadParsedDiscoModule' quiet mode resolver inProcess name cm@(Module _ mns _ _ _) = do - -- Recursively load any modules imported by this one, plus standard -- library modules (unless NoStdLib is enabled), and build a map with the results. mis <- mapM (loadDiscoModule' quiet (withStdlib resolver) inProcess) mns stdmis <- case NoStdLib `S.member` modExts cm of - True -> return [] + True -> return [] False -> mapM (loadDiscoModule' True FromStdlib inProcess) stdLib let modImps = M.fromList (map (view miName &&& id) (mis ++ stdmis)) -- Get context and type definitions from the REPL, in case we are in REPL mode. topImports <- use (replModInfo . miImports) - topTyCtx <- use (replModInfo . miTys) + topTyCtx <- use (replModInfo . miTys) topTyDefns <- use (replModInfo . miTydefs) -- Choose the contexts to use based on mode: if we are loading a -- standalone module, we should start it in an empty context. If we -- are loading something entered at the REPL, we need to include any -- existing top-level REPL context. - let importMap = case mode of { Standalone -> modImps; REPL -> topImports <> modImps } - tyctx = case mode of { Standalone -> emptyCtx ; REPL -> topTyCtx } - tydefns = case mode of { Standalone -> M.empty ; REPL -> topTyDefns } + let importMap = case mode of Standalone -> modImps; REPL -> topImports <> modImps + tyctx = case mode of Standalone -> emptyCtx; REPL -> topTyCtx + tydefns = case mode of Standalone -> M.empty; REPL -> topTyDefns -- Typecheck (and resolve names in) the module. - m <- runTCM tyctx tydefns $ checkModule name importMap cm + m <- runTCM tyctx tydefns $ checkModule name importMap cm -- Evaluate all the module definitions and add them to the topEnv. mapError EvalErr $ loadDefsFrom m @@ -379,23 +402,25 @@ loadFile :: Members '[Output Message, Embed IO] r => FilePath -> Sem r (Maybe St loadFile file = do res <- liftIO $ handle @SomeException (return . Left) (Right <$> readFile file) case res of - Left _ -> info ("File not found:" <+> text file) >> return Nothing + Left _ -> info ("File not found:" <+> text file) >> return Nothing Right s -> return (Just s) -- | Add things from the given module to the set of currently loaded -- things. -addToREPLModule - :: Members '[Error DiscoError, State TopInfo, Random, State Mem, Output Message] r - => ModuleInfo -> Sem r () +addToREPLModule :: + Members '[Error DiscoError, State TopInfo, Random, State Mem, Output Message] r => + ModuleInfo -> + Sem r () addToREPLModule mi = modify @TopInfo (replModInfo <>~ mi) -- | Set the given 'ModuleInfo' record as the currently loaded -- module. This also includes updating the top-level state with new -- term definitions, documentation, types, and type definitions. -- Replaces any previously loaded module. -setREPLModule - :: Members '[State TopInfo, Random, Error EvalError, State Mem, Output Message] r - => ModuleInfo -> Sem r () +setREPLModule :: + Members '[State TopInfo, Random, Error EvalError, State Mem, Output Message] r => + ModuleInfo -> + Sem r () setREPLModule mi = do modify @TopInfo $ replModInfo .~ mi @@ -408,7 +433,6 @@ loadDefsFrom :: ModuleInfo -> Sem r () loadDefsFrom mi = do - -- Note that the compiled definitions we get back from compileDefns -- are topologically sorted by mutually recursive group. Each -- definition needs to be evaluated in an environment containing the @@ -418,7 +442,9 @@ loadDefsFrom mi = do loadDef :: Members '[State TopInfo, Random, Error EvalError, State Mem] r => - QName Core -> Core -> Sem r () + QName Core -> + Core -> + Sem r () loadDef x body = do v <- inputToState @TopInfo . inputTopEnv $ eval body modify @TopInfo $ topEnv %~ Ctx.insert x v diff --git a/src/Disco/Extensions.hs b/src/Disco/Extensions.hs index 1b450e74..80c33424 100644 --- a/src/Disco/Extensions.hs +++ b/src/Disco/Extensions.hs @@ -6,17 +6,17 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Optional extensions to the disco language. -module Disco.Extensions - ( Ext (..), - ExtSet, - defaultExts, - allExts, - allExtsList, - addExtension, - ) +module Disco.Extensions ( + Ext (..), + ExtSet, + defaultExts, + allExts, + allExtsList, + addExtension, +) where -import Data.Set (Set) +import Data.Set (Set) import qualified Data.Set as S type ExtSet = Set Ext diff --git a/src/Disco/Interactive/CmdLine.hs b/src/Disco/Interactive/CmdLine.hs index 666880cc..18842482 100644 --- a/src/Disco/Interactive/CmdLine.hs +++ b/src/Disco/Interactive/CmdLine.hs @@ -1,4 +1,7 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Interactive.CmdLine -- Copyright : disco team and contributors @@ -7,53 +10,52 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Definition of the command-line REPL interface for Disco. --- ------------------------------------------------------------------------------ - -module Disco.Interactive.CmdLine - ( -- * Command-line options record - - DiscoOpts(..) - - -- * optparse-applicative command line parsers - , discoOpts, discoInfo - - -- * main - - , discoMain - - ) where - -import Data.Version (showVersion) -import Paths_disco (version) - -import Control.Lens hiding (use) -import Control.Monad (unless, when) -import qualified Control.Monad.Catch as CMC -import Control.Monad.IO.Class (MonadIO (..)) -import Data.Foldable (forM_) -import Data.List (isPrefixOf) -import Data.Maybe (isJust) -import System.Exit (exitFailure, - exitSuccess) - -import qualified Options.Applicative as O -import System.Console.Haskeline as H - -import Disco.AST.Surface (emptyModule) -import Disco.Error -import Disco.Eval -import Disco.Interactive.Commands -import Disco.Messages -import Disco.Module (Resolver (FromStdlib), - miExts) -import Disco.Names (ModuleName (REPLModule)) -import Disco.Pretty - -import Disco.Effects.State -import Polysemy -import Polysemy.ConstraintAbsorber.MonadCatch -import Polysemy.Error +module Disco.Interactive.CmdLine ( + -- * Command-line options record + DiscoOpts (..), + + -- * optparse-applicative command line parsers + discoOpts, + discoInfo, + + -- * main + discoMain, +) where + +import Data.Version (showVersion) +import Paths_disco (version) + +import Control.Lens hiding (use) +import Control.Monad (unless, when) +import qualified Control.Monad.Catch as CMC +import Control.Monad.IO.Class (MonadIO (..)) +import Data.Foldable (forM_) +import Data.List (isPrefixOf) +import Data.Maybe (isJust) +import System.Exit ( + exitFailure, + exitSuccess, + ) + +import qualified Options.Applicative as O +import System.Console.Haskeline as H + +import Disco.AST.Surface (emptyModule) +import Disco.Error +import Disco.Eval +import Disco.Interactive.Commands +import Disco.Messages +import Disco.Module ( + Resolver (FromStdlib), + miExts, + ) +import Disco.Names (ModuleName (REPLModule)) +import Disco.Pretty + +import Disco.Effects.State +import Polysemy +import Polysemy.ConstraintAbsorber.MonadCatch +import Polysemy.Error ------------------------------------------------------------ -- Command-line options parser @@ -61,63 +63,75 @@ import Polysemy.Error -- | Command-line options for disco. data DiscoOpts = DiscoOpts - { onlyVersion :: Bool -- ^ Should we just print the version? - , evaluate :: Maybe String -- ^ A single expression to evaluate - , cmdFile :: Maybe String -- ^ Execute the commands in a given file - , checkFile :: Maybe String -- ^ Check a file and then exit - , debugFlag :: Bool + { onlyVersion :: Bool + -- ^ Should we just print the version? + , evaluate :: Maybe String + -- ^ A single expression to evaluate + , cmdFile :: Maybe String + -- ^ Execute the commands in a given file + , checkFile :: Maybe String + -- ^ Check a file and then exit + , debugFlag :: Bool } discoOpts :: O.Parser DiscoOpts -discoOpts = DiscoOpts - <$> O.switch ( - mconcat - [ O.long "version" - , O.short 'v' - , O.help "show current version" - ] - ) - - <*> O.optional ( - O.strOption (mconcat - [ O.long "evaluate" - , O.short 'e' - , O.help "evaluate an expression" - , O.metavar "TERM" - ]) +discoOpts = + DiscoOpts + <$> O.switch + ( mconcat + [ O.long "version" + , O.short 'v' + , O.help "show current version" + ] ) - <*> O.optional ( - O.strOption (mconcat - [ O.long "file" - , O.short 'f' - , O.help "execute the commands in a file" - , O.metavar "FILE" - ]) + <*> O.optional + ( O.strOption + ( mconcat + [ O.long "evaluate" + , O.short 'e' + , O.help "evaluate an expression" + , O.metavar "TERM" + ] + ) ) - <*> O.optional ( - O.strOption (mconcat - [ O.long "check" - , O.help "check a file without starting the interactive REPL" - , O.metavar "FILE" - ]) + <*> O.optional + ( O.strOption + ( mconcat + [ O.long "file" + , O.short 'f' + , O.help "execute the commands in a file" + , O.metavar "FILE" + ] + ) + ) + <*> O.optional + ( O.strOption + ( mconcat + [ O.long "check" + , O.help "check a file without starting the interactive REPL" + , O.metavar "FILE" + ] + ) + ) + <*> O.switch + ( mconcat + [ O.long "debug" + , O.help "print debugging information" + , O.short 'd' + ] ) - <*> O.switch ( - mconcat - [ O.long "debug" - , O.help "print debugging information" - , O.short 'd' - ] - ) discoVersion :: String discoVersion = showVersion version discoInfo :: O.ParserInfo DiscoOpts -discoInfo = O.info (O.helper <*> discoOpts) $ mconcat - [ O.fullDesc - , O.progDesc "Command-line interface for Disco, a programming language for discrete mathematics." - , O.header $ "disco " ++ discoVersion - ] +discoInfo = + O.info (O.helper <*> discoOpts) $ + mconcat + [ O.fullDesc + , O.progDesc "Command-line interface for Disco, a programming language for discrete mathematics." + , O.header $ "disco " ++ discoVersion + ] optsToCfg :: DiscoOpts -> DiscoConfig optsToCfg opts = initDiscoConfig & debugMode .~ debugFlag opts @@ -140,7 +154,6 @@ discoMain = do let batch = any isJust [evaluate opts, cmdFile opts, checkFile opts] unless batch $ putStr banner runDisco (optsToCfg opts) $ do - -- Load an empty module just to force standard libraries to be loaded first _ <- loadParsedDiscoModule True FromStdlib REPLModule emptyModule @@ -148,65 +161,63 @@ discoMain = do Just file -> do res <- handleLoad file liftIO $ if res then exitSuccess else exitFailure - Nothing -> return () + Nothing -> return () case cmdFile opts of Just file -> do mcmds <- loadFile file case mcmds of - Nothing -> return () + Nothing -> return () Just cmds -> mapM_ handleCMD (lines cmds) - Nothing -> return () + Nothing -> return () forM_ (evaluate opts) handleCMD unless batch $ do loop - - where - - -- These types used to involve InputT Disco, but we now use Final - -- (InputT IO) in the list of effects. see - -- https://github.com/polysemy-research/polysemy/issues/395 for - -- inspiration. - - ctrlC :: MonadIO m => m a -> SomeException -> m a - ctrlC act e = do - liftIO $ print e - act - - withCtrlC :: (MonadIO m, CMC.MonadCatch m) => m a -> m a -> m a - withCtrlC resume act = CMC.catch act (ctrlC resume) - - loop :: Members DiscoEffects r => Sem r () - loop = do - minput <- embedFinal $ withCtrlC (return $ Just "") (getInputLine "Disco> ") - case minput of - Nothing -> return () - Just input - | ":q" `isPrefixOf` input && input `isPrefixOf` ":quit" -> do - liftIO $ putStrLn "Goodbye!" - return () - | ":{" `isPrefixOf` input -> do - multiLineLoop [] - loop - | otherwise -> do - mapError @_ @DiscoError (Panic . show) $ - absorbMonadCatch $ + where + -- These types used to involve InputT Disco, but we now use Final + -- (InputT IO) in the list of effects. see + -- https://github.com/polysemy-research/polysemy/issues/395 for + -- inspiration. + + ctrlC :: MonadIO m => m a -> SomeException -> m a + ctrlC act e = do + liftIO $ print e + act + + withCtrlC :: (MonadIO m, CMC.MonadCatch m) => m a -> m a -> m a + withCtrlC resume act = CMC.catch act (ctrlC resume) + + loop :: Members DiscoEffects r => Sem r () + loop = do + minput <- embedFinal $ withCtrlC (return $ Just "") (getInputLine "Disco> ") + case minput of + Nothing -> return () + Just input + | ":q" `isPrefixOf` input && input `isPrefixOf` ":quit" -> do + liftIO $ putStrLn "Goodbye!" + return () + | ":{" `isPrefixOf` input -> do + multiLineLoop [] + loop + | otherwise -> do + mapError @_ @DiscoError (Panic . show) $ + absorbMonadCatch $ withCtrlC (return ()) $ - handleCMD input - loop - - multiLineLoop :: Members DiscoEffects r => [String] -> Sem r () - multiLineLoop ls = do - minput <- embedFinal $ withCtrlC (return Nothing) (getInputLine "Disco| ") - case minput of - Nothing -> return () - Just input - | ":}" `isPrefixOf` input -> do - mapError @_ @DiscoError (Panic . show) $ - absorbMonadCatch $ + handleCMD input + loop + + multiLineLoop :: Members DiscoEffects r => [String] -> Sem r () + multiLineLoop ls = do + minput <- embedFinal $ withCtrlC (return Nothing) (getInputLine "Disco| ") + case minput of + Nothing -> return () + Just input + | ":}" `isPrefixOf` input -> do + mapError @_ @DiscoError (Panic . show) $ + absorbMonadCatch $ withCtrlC (return ()) $ - handleCMD (unlines (reverse ls)) - | otherwise -> do - multiLineLoop (input:ls) + handleCMD (unlines (reverse ls)) + | otherwise -> do + multiLineLoop (input : ls) -- | Parse and run the command corresponding to some REPL input. handleCMD :: Members DiscoEffects r => String -> Sem r () @@ -214,7 +225,8 @@ handleCMD "" = return () handleCMD s = do exts <- use @TopInfo (replModInfo . miExts) case parseLine discoCommands exts s of - Left m -> info (text m) + Left m -> info (text m) Right l -> catch @DiscoError (dispatch discoCommands l) (info . pretty') - -- The above has to be catch, not outputErrors, because - -- the latter won't resume afterwards. + +-- The above has to be catch, not outputErrors, because +-- the latter won't resume afterwards. diff --git a/src/Disco/Interactive/Commands.hs b/src/Disco/Interactive/Commands.hs index 576fcaf9..acf10c38 100644 --- a/src/Disco/Interactive/Commands.hs +++ b/src/Disco/Interactive/Commands.hs @@ -1,8 +1,11 @@ -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE StandaloneDeriving #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Interactive.Commands -- Copyright : disco team and contributors @@ -12,69 +15,86 @@ -- -- Defining and dispatching all commands/functionality available at -- the REPL prompt. ------------------------------------------------------------------------------ - -module Disco.Interactive.Commands - ( dispatch, - discoCommands, - handleLoad, - loadFile, - parseLine - ) where - -import Control.Arrow ((&&&)) -import Control.Lens (to, view, (%~), (.~), (?~), - (^.)) -import Control.Monad.Except -import Data.Char (isSpace) -import Data.Coerce -import Data.List (find, isPrefixOf, sortBy) -import Data.Map ((!)) -import qualified Data.Map as M -import Data.Typeable -import Prelude as P -import System.FilePath (splitFileName) - -import Text.Megaparsec hiding (State, runParser) -import qualified Text.Megaparsec.Char as C -import Unbound.Generics.LocallyNameless (Name, name2String, - string2Name) - -import Disco.Effects.Input -import Disco.Effects.LFresh -import Disco.Effects.State -import Polysemy -import Polysemy.Error hiding (try) -import Polysemy.Output -import Polysemy.Reader - -import Data.Maybe (mapMaybe, maybeToList) -import Disco.AST.Surface -import Disco.AST.Typed -import Disco.Compile -import Disco.Context as Ctx -import Disco.Desugar -import Disco.Doc -import Disco.Error -import Disco.Eval -import Disco.Extensions -import Disco.Interpret.CESK -import Disco.Messages -import Disco.Module -import Disco.Names -import Disco.Parser (Parser, ident, reservedOp, - runParser, sc, symbol, term, - wholeModule, withExts) -import Disco.Pretty hiding (empty, (<>)) -import qualified Disco.Pretty as Pretty -import Disco.Property (prettyTestResult) -import Disco.Syntax.Operators -import Disco.Syntax.Prims (Prim (PrimBOp, PrimUOp), - toPrim) -import Disco.Typecheck -import Disco.Typecheck.Erase -import Disco.Types (pattern TyString, toPolyType) -import Disco.Value +module Disco.Interactive.Commands ( + dispatch, + discoCommands, + handleLoad, + loadFile, + parseLine, +) where + +import Control.Arrow ((&&&)) +import Control.Lens ( + to, + view, + (%~), + (.~), + (?~), + (^.), + ) +import Control.Monad.Except +import Data.Char (isSpace) +import Data.Coerce +import Data.List (find, isPrefixOf, sortBy) +import Data.Map ((!)) +import qualified Data.Map as M +import Data.Typeable +import System.FilePath (splitFileName) +import Prelude as P + +import Text.Megaparsec hiding (State, runParser) +import qualified Text.Megaparsec.Char as C +import Unbound.Generics.LocallyNameless ( + Name, + name2String, + string2Name, + ) + +import Disco.Effects.Input +import Disco.Effects.LFresh +import Disco.Effects.State +import Polysemy +import Polysemy.Error hiding (try) +import Polysemy.Output +import Polysemy.Reader + +import Data.Maybe (mapMaybe, maybeToList) +import Disco.AST.Surface +import Disco.AST.Typed +import Disco.Compile +import Disco.Context as Ctx +import Disco.Desugar +import Disco.Doc +import Disco.Error +import Disco.Eval +import Disco.Extensions +import Disco.Interpret.CESK +import Disco.Messages +import Disco.Module +import Disco.Names +import Disco.Parser ( + Parser, + ident, + reservedOp, + runParser, + sc, + symbol, + term, + wholeModule, + withExts, + ) +import Disco.Pretty hiding (empty, (<>)) +import qualified Disco.Pretty as Pretty +import Disco.Property (prettyTestResult) +import Disco.Syntax.Operators +import Disco.Syntax.Prims ( + Prim (PrimBOp, PrimUOp), + toPrim, + ) +import Disco.Typecheck +import Disco.Typecheck.Erase +import Disco.Types (toPolyType, pattern TyString) +import Disco.Value ------------------------------------------------------------ -- REPL expression type @@ -83,24 +103,24 @@ import Disco.Value -- | Data type to represent things typed at the Disco REPL. Each -- constructor has a singleton type to facilitate dispatch. data REPLExpr :: CmdTag -> * where - TypeCheck :: Term -> REPLExpr 'CTypeCheck -- Typecheck a term - Eval :: Module -> REPLExpr 'CEval -- Evaluate a block - TestProp :: Term -> REPLExpr 'CTestProp -- Run a property test - ShowDefn :: Name Term -> REPLExpr 'CShowDefn -- Show a variable's definition - Parse :: Term -> REPLExpr 'CParse -- Show the parsed AST - Pretty :: Term -> REPLExpr 'CPretty -- Pretty-print a term - Print :: Term -> REPLExpr 'CPrint -- Print a string - Ann :: Term -> REPLExpr 'CAnn -- Show type-annotated term - Desugar :: Term -> REPLExpr 'CDesugar -- Show a desugared term - Compile :: Term -> REPLExpr 'CCompile -- Show a compiled term - Load :: FilePath -> REPLExpr 'CLoad -- Load a file. - Reload :: REPLExpr 'CReload -- Reloads the most recently - -- loaded file. - Doc :: DocInput -> REPLExpr 'CDoc -- Show documentation. - Nop :: REPLExpr 'CNop -- No-op, e.g. if the user - -- just enters a comment - Help :: REPLExpr 'CHelp -- Show help - Names :: REPLExpr 'CNames -- Show bound names + TypeCheck :: Term -> REPLExpr 'CTypeCheck -- Typecheck a term + Eval :: Module -> REPLExpr 'CEval -- Evaluate a block + TestProp :: Term -> REPLExpr 'CTestProp -- Run a property test + ShowDefn :: Name Term -> REPLExpr 'CShowDefn -- Show a variable's definition + Parse :: Term -> REPLExpr 'CParse -- Show the parsed AST + Pretty :: Term -> REPLExpr 'CPretty -- Pretty-print a term + Print :: Term -> REPLExpr 'CPrint -- Print a string + Ann :: Term -> REPLExpr 'CAnn -- Show type-annotated term + Desugar :: Term -> REPLExpr 'CDesugar -- Show a desugared term + Compile :: Term -> REPLExpr 'CCompile -- Show a compiled term + Load :: FilePath -> REPLExpr 'CLoad -- Load a file. + Reload :: REPLExpr 'CReload -- Reloads the most recently + -- loaded file. + Doc :: DocInput -> REPLExpr 'CDoc -- Show documentation. + Nop :: REPLExpr 'CNop -- No-op, e.g. if the user + -- just enters a comment + Help :: REPLExpr 'CHelp -- Show help + Names :: REPLExpr 'CNames -- Show bound names deriving instance Show (REPLExpr c) @@ -153,23 +173,23 @@ data CmdTag -- | Data type to represent all the information about a single REPL -- command. data REPLCommand (c :: CmdTag) = REPLCommand - { -- | Name of the command - name :: String, - -- | Help text showing how to use the command, e.g. ":ann " - helpcmd :: String, - -- | Short free-form text explaining the command. - -- We could also consider adding long help text as well. - shortHelp :: String, - -- | Is the command for users or devs? - category :: REPLCommandCategory, - -- | Is it a built-in command or colon command? - cmdtype :: REPLCommandType, - -- | The action to execute, - -- given the input to the - -- command. - action :: REPLExpr c -> (forall r. Members DiscoEffects r => Sem r ()), - -- | Parser for the command argument(s). - parser :: Parser (REPLExpr c) + { name :: String + -- ^ Name of the command + , helpcmd :: String + -- ^ Help text showing how to use the command, e.g. ":ann " + , shortHelp :: String + -- ^ Short free-form text explaining the command. + -- We could also consider adding long help text as well. + , category :: REPLCommandCategory + -- ^ Is the command for users or devs? + , cmdtype :: REPLCommandType + -- ^ Is it a built-in command or colon command? + , action :: REPLExpr c -> (forall r. Members DiscoEffects r => Sem r ()) + -- ^ The action to execute, + -- given the input to the + -- command. + , parser :: Parser (REPLExpr c) + -- ^ Parser for the command argument(s). } -- | An existential wrapper around any REPL command info record. @@ -201,22 +221,22 @@ dispatch (SomeCmd c : cs) r@(SomeREPL e) = case gcast e of -- to the first matching command. discoCommands :: REPLCommands discoCommands = - [ SomeCmd annCmd, - SomeCmd compileCmd, - SomeCmd desugarCmd, - SomeCmd docCmd, - SomeCmd evalCmd, - SomeCmd helpCmd, - SomeCmd loadCmd, - SomeCmd namesCmd, - SomeCmd nopCmd, - SomeCmd parseCmd, - SomeCmd prettyCmd, - SomeCmd printCmd, - SomeCmd reloadCmd, - SomeCmd showDefnCmd, - SomeCmd typeCheckCmd, - SomeCmd testPropCmd + [ SomeCmd annCmd + , SomeCmd compileCmd + , SomeCmd desugarCmd + , SomeCmd docCmd + , SomeCmd evalCmd + , SomeCmd helpCmd + , SomeCmd loadCmd + , SomeCmd namesCmd + , SomeCmd nopCmd + , SomeCmd parseCmd + , SomeCmd prettyCmd + , SomeCmd printCmd + , SomeCmd reloadCmd + , SomeCmd showDefnCmd + , SomeCmd typeCheckCmd + , SomeCmd testPropCmd ] ------------------------------------------------------------ @@ -237,12 +257,12 @@ commandParser allCommands = -- colon, return a parser for its arguments. parseCommandArgs :: REPLCommands -> String -> Parser SomeREPLExpr parseCommandArgs allCommands cmd = maybe badCmd snd $ find ((cmd `isPrefixOf`) . fst) parsers - where - badCmd = fail $ "Command \":" ++ cmd ++ "\" is unrecognized." + where + badCmd = fail $ "Command \":" ++ cmd ++ "\" is unrecognized." - parsers = - map (\(SomeCmd rc) -> (name rc, SomeREPL <$> parser rc)) $ - byCmdType ColonCmd allCommands + parsers = + map (\(SomeCmd rc) -> (name rc, SomeREPL <$> parser rc)) $ + byCmdType ColonCmd allCommands -- | Parse a file name. fileParser :: Parser FilePath @@ -260,7 +280,7 @@ lineParser allCommands = parseLine :: REPLCommands -> ExtSet -> String -> Either String SomeREPLExpr parseLine allCommands exts s = case runParser (withExts exts (lineParser allCommands)) "" s of - Left e -> Left $ errorBundlePretty e + Left e -> Left $ errorBundlePretty e Right l -> Right l -------------------------------------------------------------------------------- @@ -273,13 +293,13 @@ parseLine allCommands exts s = annCmd :: REPLCommand 'CAnn annCmd = REPLCommand - { name = "ann", - helpcmd = ":ann", - shortHelp = "Show type-annotated typechecked term", - category = Dev, - cmdtype = ColonCmd, - action = inputToState @TopInfo . handleAnn, - parser = Ann <$> term + { name = "ann" + , helpcmd = ":ann" + , shortHelp = "Show type-annotated typechecked term" + , category = Dev + , cmdtype = ColonCmd + , action = inputToState @TopInfo . handleAnn + , parser = Ann <$> term } handleAnn :: @@ -296,13 +316,13 @@ handleAnn (Ann t) = do compileCmd :: REPLCommand 'CCompile compileCmd = REPLCommand - { name = "compile", - helpcmd = ":compile", - shortHelp = "Show a compiled term", - category = Dev, - cmdtype = ColonCmd, - action = inputToState @TopInfo . handleCompile, - parser = Compile <$> term + { name = "compile" + , helpcmd = ":compile" + , shortHelp = "Show a compiled term" + , category = Dev + , cmdtype = ColonCmd + , action = inputToState @TopInfo . handleCompile + , parser = Compile <$> term } handleCompile :: @@ -319,13 +339,13 @@ handleCompile (Compile t) = do desugarCmd :: REPLCommand 'CDesugar desugarCmd = REPLCommand - { name = "desugar", - helpcmd = ":desugar", - shortHelp = "Show a desugared term", - category = Dev, - cmdtype = ColonCmd, - action = inputToState @TopInfo . handleDesugar, - parser = Desugar <$> term + { name = "desugar" + , helpcmd = ":desugar" + , shortHelp = "Show a desugared term" + , category = Dev + , cmdtype = ColonCmd + , action = inputToState @TopInfo . handleDesugar + , parser = Desugar <$> term } handleDesugar :: @@ -342,13 +362,13 @@ handleDesugar (Desugar t) = do docCmd :: REPLCommand 'CDoc docCmd = REPLCommand - { name = "doc", - helpcmd = ":doc ", - shortHelp = "Show documentation", - category = User, - cmdtype = ColonCmd, - action = inputToState @TopInfo . handleDoc, - parser = Doc <$> parseDoc + { name = "doc" + , helpcmd = ":doc " + , shortHelp = "Show documentation" + , category = User + , cmdtype = ColonCmd + , action = inputToState @TopInfo . handleDoc + , parser = Doc <$> parseDoc } -- An input to the :doc command can be either a term, a primitive @@ -358,51 +378,48 @@ data DocInput = DocTerm Term | DocPrim Prim | DocOther String parseDoc :: Parser DocInput parseDoc = - (DocTerm <$> try term) - <|> (DocPrim <$> try (parseNakedOpPrim "operator")) - <|> (DocOther <$> (sc *> many (anySingleBut ' '))) + (DocTerm <$> try term) + <|> (DocPrim <$> try (parseNakedOpPrim "operator")) + <|> (DocOther <$> (sc *> many (anySingleBut ' '))) handleDoc :: Members '[Error DiscoError, Input TopInfo, LFresh, Output Message] r => REPLExpr 'CDoc -> Sem r () handleDoc (Doc (DocTerm (TBool _))) = handleDocBool -handleDoc (Doc (DocTerm TUnit)) = handleDocUnit -handleDoc (Doc (DocTerm TWild)) = handleDocWild +handleDoc (Doc (DocTerm TUnit)) = handleDocUnit +handleDoc (Doc (DocTerm TWild)) = handleDocWild handleDoc (Doc (DocTerm (TPrim p))) = handleDocPrim p -handleDoc (Doc (DocTerm (TVar x))) = handleDocVar x -handleDoc (Doc (DocTerm _)) = +handleDoc (Doc (DocTerm (TVar x))) = handleDocVar x +handleDoc (Doc (DocTerm _)) = err "Can't display documentation for an expression. Try asking about a function, operator, or type name." -handleDoc (Doc (DocPrim p)) = handleDocPrim p -handleDoc (Doc (DocOther s)) = handleDocOther s +handleDoc (Doc (DocPrim p)) = handleDocPrim p +handleDoc (Doc (DocOther s)) = handleDocOther s handleDocBool :: Members '[Output Message] r => Sem r () handleDocBool = info $ "true and false (also written True and False) are the two possible values of type Boolean." - $+$ - mkReference "bool" + $+$ mkReference "bool" handleDocUnit :: Members '[Output Message] r => Sem r () handleDocUnit = info $ "The unit value, i.e. the single value of type Unit." - $+$ - mkReference "unit" + $+$ mkReference "unit" handleDocWild :: Members '[Output Message] r => Sem r () handleDocWild = info $ "A wildcard pattern." - $+$ - mkReference "wild-pattern" + $+$ mkReference "wild-pattern" handleDocVar :: Members '[Error DiscoError, Input TopInfo, LFresh, Output Message] r => Name Term -> Sem r () handleDocVar x = do - replCtx <- inputs @TopInfo (view (replModInfo . miTys)) + replCtx <- inputs @TopInfo (view (replModInfo . miTys)) replTydefs <- inputs @TopInfo (view (replModInfo . miTydefs)) replDocs <- inputs @TopInfo (view (replModInfo . miDocs)) @@ -420,24 +437,23 @@ handleDocVar x = do ([], Nothing) -> -- Maybe the variable name entered by the user is actually a prim. case toPrim (name2String x) of - (prim:_) -> handleDocPrim prim - _ -> err $ "No documentation found for '" <> pretty' x <> "'." + (prim : _) -> handleDocPrim prim + _ -> err $ "No documentation found for '" <> pretty' x <> "'." (binds, def) -> mapM_ (showDoc docs) (map Left binds ++ map Right (maybeToList def)) - - where - showDoc docMap (Left (qn, ty)) = info $ + where + showDoc docMap (Left (qn, ty)) = + info $ hsep [pretty' x, ":", pretty' ty] - $+$ - case Ctx.lookup' qn docMap of - Just (DocString ss : _) -> vcat (text "" : map text ss ++ [text ""]) - _ -> Pretty.empty - showDoc docMap (Right tdBody) = info $ + $+$ case Ctx.lookup' qn docMap of + Just (DocString ss : _) -> vcat (text "" : map text ss ++ [text ""]) + _ -> Pretty.empty + showDoc docMap (Right tdBody) = + info $ pretty' (name2String x, tdBody) - $+$ - case Ctx.lookupAll' x docMap of - ((_, DocString ss : _) : _) -> vcat (text "" : map text ss ++ [text ""]) - _ -> Pretty.empty + $+$ case Ctx.lookupAll' x docMap of + ((_, DocString ss : _) : _) -> vcat (text "" : map text ss ++ [text ""]) + _ -> Pretty.empty handleDocPrim :: Members '[Error DiscoError, Input TopInfo, LFresh, Output Message] r => @@ -445,37 +461,39 @@ handleDocPrim :: Sem r () handleDocPrim prim = do handleTypeCheck (TypeCheck (TPrim prim)) - info $ vcat - [ case prim of - PrimUOp u -> describeAlts (f == Post) (f == Pre) syns - where + info $ + vcat + [ case prim of + PrimUOp u -> describeAlts (f == Post) (f == Pre) syns + where OpInfo (UOpF f _) syns _ = uopMap ! u - PrimBOp b -> describeAlts True True (opSyns $ bopMap ! b) - _ -> Pretty.empty - , case prim of - PrimUOp u -> describePrec (uPrec u) - PrimBOp b -> describePrec (bPrec b) <> describeFixity (assoc b) - _ -> Pretty.empty - ] + PrimBOp b -> describeAlts True True (opSyns $ bopMap ! b) + _ -> Pretty.empty + , case prim of + PrimUOp u -> describePrec (uPrec u) + PrimBOp b -> describePrec (bPrec b) <> describeFixity (assoc b) + _ -> Pretty.empty + ] case (M.lookup prim primDoc, M.lookup prim primReference) of (Nothing, Nothing) -> return () - (Nothing, Just p) -> info $ mkReference p - (Just d, mp) -> + (Nothing, Just p) -> info $ mkReference p + (Just d, mp) -> info $ "" $+$ text d $+$ "" $+$ maybe Pretty.empty (\p -> mkReference p $+$ "") mp - where - describePrec p = "precedence level" <+> text (show p) - describeFixity In = Pretty.empty - describeFixity InL = ", left associative" - describeFixity InR = ", right associative" - describeAlts _ _ [] = Pretty.empty - describeAlts _ _ [_] = Pretty.empty - describeAlts pre post (_:alts) = "Alternative syntax:" <+> intercalate "," (map showOp alts) - where - showOp op = hcat - [ if pre then "~" else Pretty.empty - , text op - , if post then "~" else Pretty.empty] - + where + describePrec p = "precedence level" <+> text (show p) + describeFixity In = Pretty.empty + describeFixity InL = ", left associative" + describeFixity InR = ", right associative" + describeAlts _ _ [] = Pretty.empty + describeAlts _ _ [_] = Pretty.empty + describeAlts pre post (_ : alts) = "Alternative syntax:" <+> intercalate "," (map showOp alts) + where + showOp op = + hcat + [ if pre then "~" else Pretty.empty + , text op + , if post then "~" else Pretty.empty + ] mkReference :: String -> Sem r Doc mkReference p = @@ -488,32 +506,35 @@ handleDocOther :: handleDocOther s = case (M.lookup s otherDoc, M.lookup s otherReference) of (Nothing, Nothing) -> info $ "No documentation found for '" <> text s <> "'." - (Nothing, Just p) -> info $ mkReference p - (Just d, mp) -> + (Nothing, Just p) -> info $ mkReference p + (Just d, mp) -> info $ text d $+$ "" $+$ maybe Pretty.empty (\p -> mkReference p $+$ "") mp ------------------------------------------------------------ -- eval evalCmd :: REPLCommand 'CEval -evalCmd = REPLCommand - { name = "eval" - , helpcmd = "" - , shortHelp = "Evaluate a block of code" - , category = User - , cmdtype = BuiltIn - , action = handleEval - , parser = Eval <$> wholeModule REPL - } +evalCmd = + REPLCommand + { name = "eval" + , helpcmd = "" + , shortHelp = "Evaluate a block of code" + , category = User + , cmdtype = BuiltIn + , action = handleEval + , parser = Eval <$> wholeModule REPL + } -handleEval - :: Members (Error DiscoError ': State TopInfo ': Output Message ': Embed IO ': EvalEffects) r - => REPLExpr 'CEval -> Sem r () +handleEval :: + Members (Error DiscoError ': State TopInfo ': Output Message ': Embed IO ': EvalEffects) r => + REPLExpr 'CEval -> + Sem r () handleEval (Eval m) = do mi <- inputToState @TopInfo $ loadParsedDiscoModule False FromCwdOrStdlib REPLModule m addToREPLModule mi forM_ (mi ^. miTerms) (mapError EvalErr . evalTerm True . fst) - -- garbageCollect? + +-- garbageCollect? -- First argument = should the value be printed? evalTerm :: Members (Error EvalError ': State TopInfo ': Output Message ': EvalEffects) r => Bool -> ATerm -> Sem r Value @@ -525,11 +546,11 @@ evalTerm pr at = do when pr $ info $ runInputConst tydefs $ prettyValue' ty v modify @TopInfo $ - (replModInfo . miTys %~ Ctx.insert (QName (QualifiedName REPLModule) (string2Name "it")) (toPolyType ty)) . - (topEnv %~ Ctx.insert (QName (QualifiedName REPLModule) (string2Name "it")) v) + (replModInfo . miTys %~ Ctx.insert (QName (QualifiedName REPLModule) (string2Name "it")) (toPolyType ty)) + . (topEnv %~ Ctx.insert (QName (QualifiedName REPLModule) (string2Name "it")) v) return v - where - ty = getType at + where + ty = getType at ------------------------------------------------------------ -- :help @@ -537,33 +558,33 @@ evalTerm pr at = do helpCmd :: REPLCommand 'CHelp helpCmd = REPLCommand - { name = "help", - helpcmd = ":help", - shortHelp = "Show help", - category = User, - cmdtype = ColonCmd, - action = handleHelp, - parser = return Help + { name = "help" + , helpcmd = ":help" + , shortHelp = "Show help" + , category = User + , cmdtype = ColonCmd + , action = handleHelp + , parser = return Help } handleHelp :: Member (Output Message) r => REPLExpr 'CHelp -> Sem r () handleHelp Help = info $ vcat - [ "Commands available from the prompt:" - , text "" - , vcat (map (\(SomeCmd c) -> showCmd c) $ sortedList discoCommands) - , text "" - ] - where - maxlen = longestCmd discoCommands - sortedList cmds = - sortBy (\(SomeCmd x) (SomeCmd y) -> compare (name x) (name y)) $ filteredCommands cmds - showCmd c = text (padRight (helpcmd c) maxlen ++ " " ++ shortHelp c) - longestCmd cmds = maximum $ map (\(SomeCmd c) -> length $ helpcmd c) cmds - padRight s maxsize = take maxsize (s ++ repeat ' ') - -- don't show dev-only commands by default - filteredCommands = P.filter (\(SomeCmd c) -> category c == User) + [ "Commands available from the prompt:" + , text "" + , vcat (map (\(SomeCmd c) -> showCmd c) $ sortedList discoCommands) + , text "" + ] + where + maxlen = longestCmd discoCommands + sortedList cmds = + sortBy (\(SomeCmd x) (SomeCmd y) -> compare (name x) (name y)) $ filteredCommands cmds + showCmd c = text (padRight (helpcmd c) maxlen ++ " " ++ shortHelp c) + longestCmd cmds = maximum $ map (\(SomeCmd c) -> length $ helpcmd c) cmds + padRight s maxsize = take maxsize (s ++ repeat ' ') + -- don't show dev-only commands by default + filteredCommands = P.filter (\(SomeCmd c) -> category c == User) ------------------------------------------------------------ -- :load @@ -571,13 +592,13 @@ handleHelp Help = loadCmd :: REPLCommand 'CLoad loadCmd = REPLCommand - { name = "load", - helpcmd = ":load ", - shortHelp = "Load a file", - category = User, - cmdtype = ColonCmd, - action = handleLoadWrapper, - parser = Load <$> fileParser + { name = "load" + , helpcmd = ":load " + , shortHelp = "Load a file" + , category = User + , cmdtype = ColonCmd + , action = handleLoadWrapper + , parser = Load <$> fileParser } -- | Parses, typechecks, and loads a module by first recursively loading any imported @@ -622,30 +643,29 @@ handleLoad fp = do runAllTests :: Members (Output Message ': Input TopInfo ': EvalEffects) r => [QName Term] -> Ctx ATerm [AProperty] -> Sem r Bool -- (Ctx ATerm [TestResult]) runAllTests declNames aprops | Ctx.null aprops = return True - | otherwise = do + | otherwise = do info "Running tests..." -- Use the order the names were defined in the module and <$> mapM (uncurry runTests) (mapMaybe (\n -> (n,) <$> Ctx.lookup' (coerce n) aprops) declNames) - - where - numSamples :: Int - numSamples = 50 -- XXX make this configurable somehow - - runTests :: Members (Output Message ': Input TopInfo ': EvalEffects) r => QName Term -> [AProperty] -> Sem r Bool - runTests (QName _ n) props = do - results <- inputTopEnv $ traverse (sequenceA . (id &&& runTest numSamples)) props - let failures = P.filter (not . testIsOk . snd) results - hdr = pretty' n <> ":" - - case P.null failures of - True -> info $ nest 2 $ hdr <+> "OK" - False -> do - tydefs <- inputs @TopInfo (view (replModInfo . to allTydefs)) - let prettyFailures = - runInputConst tydefs . runReader initPA . runLFresh $ - bulletList "-" $ map (uncurry prettyTestResult) failures - info $ nest 2 $ hdr $+$ prettyFailures - return (P.null failures) + where + numSamples :: Int + numSamples = 50 -- XXX make this configurable somehow + runTests :: Members (Output Message ': Input TopInfo ': EvalEffects) r => QName Term -> [AProperty] -> Sem r Bool + runTests (QName _ n) props = do + results <- inputTopEnv $ traverse (sequenceA . (id &&& runTest numSamples)) props + let failures = P.filter (not . testIsOk . snd) results + hdr = pretty' n <> ":" + + case P.null failures of + True -> info $ nest 2 $ hdr <+> "OK" + False -> do + tydefs <- inputs @TopInfo (view (replModInfo . to allTydefs)) + let prettyFailures = + runInputConst tydefs . runReader initPA . runLFresh $ + bulletList "-" $ + map (uncurry prettyTestResult) failures + info $ nest 2 $ hdr $+$ prettyFailures + return (P.null failures) ------------------------------------------------------------ -- :names @@ -653,13 +673,13 @@ runAllTests declNames aprops namesCmd :: REPLCommand 'CNames namesCmd = REPLCommand - { name = "names", - helpcmd = ":names", - shortHelp = "Show all names in current scope", - category = User, - cmdtype = ColonCmd, - action = inputToState . handleNames, - parser = return Names + { name = "names" + , helpcmd = ":names" + , shortHelp = "Show all names in current scope" + , category = User + , cmdtype = ColonCmd + , action = inputToState . handleNames + , parser = return Names } -- | Show names and types for each item in the top-level context. @@ -672,10 +692,9 @@ handleNames Names = do ctx <- inputs @TopInfo (view (replModInfo . miTys)) info $ vcat (map pretty' (M.assocs tyDef)) - $+$ - vcat (map showFn (Ctx.assocs ctx)) - where - showFn (QName _ x, ty) = hsep [pretty' x, text ":", pretty' ty] + $+$ vcat (map showFn (Ctx.assocs ctx)) + where + showFn (QName _ x, ty) = hsep [pretty' x, text ":", pretty' ty] ------------------------------------------------------------ -- nop @@ -683,13 +702,13 @@ handleNames Names = do nopCmd :: REPLCommand 'CNop nopCmd = REPLCommand - { name = "nop", - helpcmd = "", - shortHelp = "No-op, e.g. if the user just enters a comment", - category = Dev, - cmdtype = BuiltIn, - action = handleNop, - parser = Nop <$ (sc <* eof) + { name = "nop" + , helpcmd = "" + , shortHelp = "No-op, e.g. if the user just enters a comment" + , category = Dev + , cmdtype = BuiltIn + , action = handleNop + , parser = Nop <$ (sc <* eof) } handleNop :: REPLExpr 'CNop -> Sem r () @@ -701,13 +720,13 @@ handleNop Nop = pure () parseCmd :: REPLCommand 'CParse parseCmd = REPLCommand - { name = "parse", - helpcmd = ":parse ", - shortHelp = "Show the parsed AST", - category = Dev, - cmdtype = ColonCmd, - action = handleParse, - parser = Parse <$> term + { name = "parse" + , helpcmd = ":parse " + , shortHelp = "Show the parsed AST" + , category = Dev + , cmdtype = ColonCmd + , action = handleParse + , parser = Parse <$> term } handleParse :: Member (Output Message) r => REPLExpr 'CParse -> Sem r () @@ -719,13 +738,13 @@ handleParse (Parse t) = info (text (show t)) prettyCmd :: REPLCommand 'CPretty prettyCmd = REPLCommand - { name = "pretty", - helpcmd = ":pretty ", - shortHelp = "Pretty-print a term", - category = Dev, - cmdtype = ColonCmd, - action = handlePretty, - parser = Pretty <$> term + { name = "pretty" + , helpcmd = ":pretty " + , shortHelp = "Pretty-print a term" + , category = Dev + , cmdtype = ColonCmd + , action = handlePretty + , parser = Pretty <$> term } handlePretty :: Members '[LFresh, Output Message] r => REPLExpr 'CPretty -> Sem r () @@ -737,13 +756,13 @@ handlePretty (Pretty t) = info $ pretty' t printCmd :: REPLCommand 'CPrint printCmd = REPLCommand - { name = "print", - helpcmd = ":print ", - shortHelp = "Print a string without the double quotes, interpreting special characters", - category = User, - cmdtype = ColonCmd, - action = handlePrint, - parser = Print <$> term + { name = "print" + , helpcmd = ":print " + , shortHelp = "Print a string without the double quotes, interpreting special characters" + , category = User + , cmdtype = ColonCmd + , action = handlePrint + , parser = Print <$> term } handlePrint :: Members (Error DiscoError ': State TopInfo ': Output Message ': EvalEffects) r => REPLExpr 'CPrint -> Sem r () @@ -758,13 +777,13 @@ handlePrint (Print t) = do reloadCmd :: REPLCommand 'CReload reloadCmd = REPLCommand - { name = "reload", - helpcmd = ":reload", - shortHelp = "Reloads the most recently loaded file", - category = User, - cmdtype = ColonCmd, - action = handleReload, - parser = return Reload + { name = "reload" + , helpcmd = ":reload" + , shortHelp = "Reloads the most recently loaded file" + , category = User + , cmdtype = ColonCmd + , action = handleReload + , parser = return Reload } handleReload :: @@ -775,7 +794,7 @@ handleReload Reload = do file <- use lastFile case file of Nothing -> info "No file to reload." - Just f -> void (handleLoad f) + Just f -> void (handleLoad f) ------------------------------------------------------------ -- :defn @@ -783,13 +802,13 @@ handleReload Reload = do showDefnCmd :: REPLCommand 'CShowDefn showDefnCmd = REPLCommand - { name = "defn", - helpcmd = ":defn ", - shortHelp = "Show a variable's definition", - category = User, - cmdtype = ColonCmd, - action = inputToState @TopInfo . handleShowDefn, - parser = ShowDefn <$> (sc *> ident) + { name = "defn" + , helpcmd = ":defn " + , shortHelp = "Show a variable's definition" + , category = User + , cmdtype = ColonCmd + , action = inputToState @TopInfo . handleShowDefn + , parser = ShowDefn <$> (sc *> ident) } handleShowDefn :: @@ -798,7 +817,7 @@ handleShowDefn :: Sem r () handleShowDefn (ShowDefn x) = do let name2s = name2String x - defns <- inputs @TopInfo (view (replModInfo . miTermdefs)) + defns <- inputs @TopInfo (view (replModInfo . miTermdefs)) tyDefns <- inputs @TopInfo (view (replModInfo . miTydefs)) let xdefs = Ctx.lookupAll' (coerce x) defns @@ -808,7 +827,7 @@ handleShowDefn (ShowDefn x) = do let ds = map (pretty' . snd) xdefs ++ maybe [] (pure . pretty' . (name2s,)) mtydef case ds of [] -> text "No definition for" <+> pretty' x - _ -> vcat ds + _ -> vcat ds ------------------------------------------------------------ -- :test @@ -816,13 +835,13 @@ handleShowDefn (ShowDefn x) = do testPropCmd :: REPLCommand 'CTestProp testPropCmd = REPLCommand - { name = "test", - helpcmd = ":test ", - shortHelp = "Test a property using random examples", - category = User, - cmdtype = ColonCmd, - action = handleTest, - parser = TestProp <$> term + { name = "test" + , helpcmd = ":test " + , shortHelp = "Test a property using random examples" + , category = User + , cmdtype = ColonCmd + , action = handleTest + , parser = TestProp <$> term } handleTest :: @@ -842,13 +861,13 @@ handleTest (TestProp t) = do typeCheckCmd :: REPLCommand 'CTypeCheck typeCheckCmd = REPLCommand - { name = "type", - helpcmd = ":type ", - shortHelp = "Typecheck a term", - category = Dev, - cmdtype = ColonCmd, - action = inputToState @TopInfo . handleTypeCheck, - parser = parseTypeCheck + { name = "type" + , helpcmd = ":type " + , shortHelp = "Typecheck a term" + , category = Dev + , cmdtype = ColonCmd + , action = inputToState @TopInfo . handleTypeCheck + , parser = parseTypeCheck } handleTypeCheck :: @@ -875,7 +894,7 @@ parseNakedOp = TPrim <$> parseNakedOpPrim parseNakedOpPrim :: Parser Prim parseNakedOpPrim = sc *> choice (map mkOpParser (concat opTable)) - where - mkOpParser :: OpInfo -> Parser Prim - mkOpParser (OpInfo (UOpF _ op) syns _) = choice (map ((PrimUOp op <$) . reservedOp) syns) - mkOpParser (OpInfo (BOpF _ op) syns _) = choice (map ((PrimBOp op <$) . reservedOp) syns) + where + mkOpParser :: OpInfo -> Parser Prim + mkOpParser (OpInfo (UOpF _ op) syns _) = choice (map ((PrimUOp op <$) . reservedOp) syns) + mkOpParser (OpInfo (BOpF _ op) syns _) = choice (map ((PrimBOp op <$) . reservedOp) syns) diff --git a/src/Disco/Interpret/CESK.hs b/src/Disco/Interpret/CESK.hs index a9efb480..3b08ed14 100644 --- a/src/Disco/Interpret/CESK.hs +++ b/src/Disco/Interpret/CESK.hs @@ -2,6 +2,9 @@ {-# OPTIONS_GHC -fmax-pmcheck-models=200 #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Interpret.CESK -- Copyright : disco team and contributors @@ -10,54 +13,57 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- CESK machine interpreter for Disco. ------------------------------------------------------------------------------ - -module Disco.Interpret.CESK - ( CESK, - runCESK, - step, - eval, - runTest, - ) +module Disco.Interpret.CESK ( + CESK, + runCESK, + step, + eval, + runTest, +) where -import Unbound.Generics.LocallyNameless (Bind, Name) - -import Algebra.Graph -import qualified Algebra.Graph.AdjacencyMap as AdjMap -import Control.Arrow ((***), (>>>)) -import Control.Monad ((>=>)) -import Data.Bifunctor (first, second) -import Data.List (find) -import qualified Data.Map as M -import Data.Maybe (isJust) -import Data.Ratio -import Disco.AST.Core -import Disco.AST.Generic (Ellipsis (..), Side (..), - selectSide) -import Disco.AST.Typed (AProperty) -import Disco.Compile -import Disco.Context as Ctx -import Disco.Enumerate -import Disco.Error -import Disco.Names -import Disco.Property -import Disco.Types hiding (V) -import Disco.Value -import Math.Combinatorics.Exact.Binomial (choose) -import Math.Combinatorics.Exact.Factorial (factorial) -import Math.NumberTheory.Primes (factorise, unPrime) -import Math.NumberTheory.Primes.Testing (isPrime) -import Math.OEIS (catalogNums, - extendSequence, - lookupSequence) - -import Disco.Effects.Fresh -import Disco.Effects.Input -import Disco.Effects.Random -import Polysemy -import Polysemy.Error -import Polysemy.State +import Unbound.Generics.LocallyNameless (Bind, Name) + +import Algebra.Graph +import qualified Algebra.Graph.AdjacencyMap as AdjMap +import Control.Arrow ((***), (>>>)) +import Control.Monad ((>=>)) +import Data.Bifunctor (first, second) +import Data.List (find) +import qualified Data.Map as M +import Data.Maybe (isJust) +import Data.Ratio +import Disco.AST.Core +import Disco.AST.Generic ( + Ellipsis (..), + Side (..), + selectSide, + ) +import Disco.AST.Typed (AProperty) +import Disco.Compile +import Disco.Context as Ctx +import Disco.Enumerate +import Disco.Error +import Disco.Names +import Disco.Property +import Disco.Types hiding (V) +import Disco.Value +import Math.Combinatorics.Exact.Binomial (choose) +import Math.Combinatorics.Exact.Factorial (factorial) +import Math.NumberTheory.Primes (factorise, unPrime) +import Math.NumberTheory.Primes.Testing (isPrime) +import Math.OEIS ( + catalogNums, + extendSequence, + lookupSequence, + ) + +import Disco.Effects.Fresh +import Disco.Effects.Input +import Disco.Effects.Random +import Polysemy +import Polysemy.Error +import Polysemy.State ------------------------------------------------------------ -- Utilities @@ -141,22 +147,22 @@ data CESK -- | Is the CESK machine in a final state? isFinal :: CESK -> Maybe (Either EvalError Value) -isFinal (Up e []) = Just (Left e) +isFinal (Up e []) = Just (Left e) isFinal (Out v []) = Just (Right v) -isFinal _ = Nothing +isFinal _ = Nothing -- | Run a CESK machine to completion. runCESK :: Members '[Fresh, Random, State Mem] r => CESK -> Sem r (Either EvalError Value) runCESK cesk = case isFinal cesk of Just res -> return res - Nothing -> step cesk >>= runCESK + Nothing -> step cesk >>= runCESK -- | Advance the CESK machine by one step. step :: Members '[Fresh, Random, State Mem] r => CESK -> Sem r CESK step cesk = case cesk of (In (CVar x) e k) -> case Ctx.lookup' x e of Nothing -> return $ Up (UnboundError x) k - Just v -> return $ Out v k + Just v -> return $ Out v k (In (CNum d r) _ k) -> return $ Out (VNum d r) k (In (CConst OMatchErr) _ k) -> return $ Up NonExhaustive k (In (CConst OEmptyGraph) _ k) -> return $ Out (VGraph empty) k @@ -177,7 +183,6 @@ step cesk = case cesk of return $ Out (foldr (VPair . VRef) VUnit locs) k (In (CForce c) e k) -> return $ In c e (FForce : k) (In (CTest vars c) e k) -> return $ In c e (FTest (TestVars vars) e : k) - (Out v (FInj s : k)) -> return $ Out (VInj s v) k (Out (VInj L v) (FCase e b1 _ : k)) -> do (x, c1) <- unbind b1 @@ -200,7 +205,6 @@ step cesk = case cesk of (Out (VClo e (x : xs) b) (FArgV v : k)) -> return $ Out (VClo (Ctx.insert (localName x) v e) xs b) k (Out (VConst op) (FArgV v : k)) -> appConst k op v (Out (VFun f) (FArgV v : k)) -> return $ Out (f v) k - (Out (VRef n) (FForce : k)) -> do cell <- lkup n case cell of @@ -213,18 +217,15 @@ step cesk = case cesk of (Out v (FUpdate n : k)) -> do set n (V v) return $ Out v k - - (Up err (f@FTest{} : k)) -> + (Up err (f@FTest {} : k)) -> return $ Out (VProp (VPDone (TestResult False (TestRuntimeError err) emptyTestEnv))) (f : k) (Up err (_ : ks)) -> return $ Up err ks - (Out v (FTest vs e : k)) -> do let result = ensureProp v - res = getTestEnv vs e + res = getTestEnv vs e case res of Left err -> return $ Up err k Right e' -> return $ Out (VProp $ extendPropEnv e' result) k - _ -> error "Impossible! Bad CESK machine state" ------------------------------------------------------------ @@ -233,15 +234,18 @@ step cesk = case cesk of arity2 :: (Value -> Value -> a) -> Value -> a arity2 f (VPair x y) = f x y -arity2 _f _v = error "arity2 on a non-pair!" +arity2 _f _v = error "arity2 on a non-pair!" arity3 :: (Value -> Value -> Value -> a) -> Value -> a arity3 f (VPair x (VPair y z)) = f x y z -arity3 _f _v = error "arity3 on a non-triple!" +arity3 _f _v = error "arity3 on a non-triple!" -appConst - :: Members '[Random, State Mem] r - => Cont -> Op -> Value -> Sem r CESK +appConst :: + Members '[Random, State Mem] r => + Cont -> + Op -> + Value -> + Sem r CESK appConst k = \case -------------------------------------------------- -- Basics @@ -259,73 +263,73 @@ appConst k = \case OAbs -> numOp1 abs >=> out OMul -> numOp2 (*) >=> out ODiv -> numOp2' divOp >>> outWithErr - where - divOp :: Member (Error EvalError) r => Rational -> Rational -> Sem r Value - divOp _ 0 = throw DivByZero - divOp m n = return $ ratv (m / n) + where + divOp :: Member (Error EvalError) r => Rational -> Rational -> Sem r Value + divOp _ 0 = throw DivByZero + divOp m n = return $ ratv (m / n) OExp -> numOp2 (\m n -> m ^^ numerator n) >=> out OMod -> numOp2' modOp >>> outWithErr - where - modOp :: Member (Error EvalError) r => Rational -> Rational -> Sem r Value - modOp m n - | n == 0 = throw DivByZero - | otherwise = return $ intv (numerator m `mod` numerator n) + where + modOp :: Member (Error EvalError) r => Rational -> Rational -> Sem r Value + modOp m n + | n == 0 = throw DivByZero + | otherwise = return $ intv (numerator m `mod` numerator n) ODivides -> numOp2' (\m n -> return (enumv $ divides m n)) >=> out - where - divides 0 0 = True - divides 0 _ = False - divides x y = denominator (y / x) == 1 + where + divides 0 0 = True + divides 0 _ = False + divides x y = denominator (y / x) == 1 -------------------------------------------------- -- Number theory OIsPrime -> intOp1 (enumv . isPrime) >=> out OFactor -> intOp1' primFactor >>> outWithErr - where - -- Semantics of the @$factor@ prim: turn a natural number into its - -- bag of prime factors. Crash if called on 0, which does not have - -- a prime factorization. - primFactor :: Member (Error EvalError) r => Integer -> Sem r Value - primFactor 0 = throw (Crash "0 has no prime factorization!") - primFactor n = return . VBag $ map ((intv . unPrime) *** fromIntegral) (factorise n) + where + -- Semantics of the @$factor@ prim: turn a natural number into its + -- bag of prime factors. Crash if called on 0, which does not have + -- a prime factorization. + primFactor :: Member (Error EvalError) r => Integer -> Sem r Value + primFactor 0 = throw (Crash "0 has no prime factorization!") + primFactor n = return . VBag $ map ((intv . unPrime) *** fromIntegral) (factorise n) OFrac -> numOp1' (return . primFrac) >=> out - where - -- Semantics of the @$frac@ prim: turn a rational number into a pair - -- of its numerator and denominator. - primFrac :: Rational -> Value - primFrac r = VPair (intv (numerator r)) (intv (denominator r)) + where + -- Semantics of the @$frac@ prim: turn a rational number into a pair + -- of its numerator and denominator. + primFrac :: Rational -> Value + primFrac r = VPair (intv (numerator r)) (intv (denominator r)) -------------------------------------------------- -- Combinatorics OMultinom -> arity2 multinomOp >=> out - where - multinomOp :: Value -> Value -> Sem r Value - multinomOp (vint -> n0) (vlist vint -> ks0) = return . intv $ multinomial n0 ks0 - where - multinomial :: Integer -> [Integer] -> Integer - multinomial _ [] = 1 - multinomial n (k' : ks) - | k' > n = 0 - | otherwise = choose n k' * multinomial (n - k') ks + where + multinomOp :: Value -> Value -> Sem r Value + multinomOp (vint -> n0) (vlist vint -> ks0) = return . intv $ multinomial n0 ks0 + where + multinomial :: Integer -> [Integer] -> Integer + multinomial _ [] = 1 + multinomial n (k' : ks) + | k' > n = 0 + | otherwise = choose n k' * multinomial (n - k') ks OFact -> numOp1' factOp >>> outWithErr - where - factOp :: Member (Error EvalError) r => Rational -> Sem r Value - factOp (numerator -> n) - | n > fromIntegral (maxBound :: Int) = throw Overflow - | otherwise = return . intv $ factorial (fromIntegral n) + where + factOp :: Member (Error EvalError) r => Rational -> Sem r Value + factOp (numerator -> n) + | n > fromIntegral (maxBound :: Int) = throw Overflow + | otherwise = return . intv $ factorial (fromIntegral n) OEnum -> out . enumOp - where - enumOp :: Value -> Value - enumOp (VType ty) = listv id (enumerateType ty) - enumOp v = error $ "Impossible! enumOp on non-type " ++ show v + where + enumOp :: Value -> Value + enumOp (VType ty) = listv id (enumerateType ty) + enumOp v = error $ "Impossible! enumOp on non-type " ++ show v OCount -> out . countOp - where - countOp :: Value -> Value - countOp (VType ty) = case countType ty of - Just num -> VInj R (intv num) - Nothing -> VNil - countOp v = error $ "Impossible! countOp on non-type " ++ show v + where + countOp :: Value -> Value + countOp (VType ty) = case countType ty of + Just num -> VInj R (intv num) + Nothing -> VNil + countOp v = error $ "Impossible! countOp on non-type " ++ show v -------------------------------------------------- -- Sequences @@ -342,40 +346,37 @@ appConst k = \case -- Container operations OPower -> withBag OPower $ out . VBag . sortNCount . map (first VBag) . choices - where - choices :: [(Value, Integer)] -> [([(Value, Integer)], Integer)] - choices [] = [([], 1)] - choices ((x, n) : xs) = xs' ++ concatMap (\k' -> map (cons n (x, k')) xs') [1 .. n] - where - xs' = choices xs - cons n (x, k') (zs, m) = ((x, k') : zs, choose n k' * m) - OBagElem -> arity2 $ \x -> withBag OBagElem $ - out . enumv . isJust . find (valEq x) . map fst + where + choices :: [(Value, Integer)] -> [([(Value, Integer)], Integer)] + choices [] = [([], 1)] + choices ((x, n) : xs) = xs' ++ concatMap (\k' -> map (cons n (x, k')) xs') [1 .. n] + where + xs' = choices xs + cons n (x, k') (zs, m) = ((x, k') : zs, choose n k' * m) + OBagElem -> arity2 $ \x -> + withBag OBagElem $ + out . enumv . isJust . find (valEq x) . map fst OListElem -> arity2 $ \x -> out . enumv . isJust . find (valEq x) . vlist id - - OEachSet -> arity2 $ \f -> withBag OEachSet $ - outWithErr . fmap (VBag . countValues) . mapM (evalApp f . (:[]) . fst) - - OEachBag -> arity2 $ \f -> withBag OEachBag $ - outWithErr . fmap (VBag . sortNCount) . mapM (\(x,n) -> (,n) <$> evalApp f [x]) - + OEachSet -> arity2 $ \f -> + withBag OEachSet $ + outWithErr . fmap (VBag . countValues) . mapM (evalApp f . (: []) . fst) + OEachBag -> arity2 $ \f -> + withBag OEachBag $ + outWithErr . fmap (VBag . sortNCount) . mapM (\(x, n) -> (,n) <$> evalApp f [x]) OFilterBag -> arity2 $ \f -> withBag OFilterBag $ \xs -> outWithErr $ do - bs <- mapM (evalApp f . (:[]) . fst) xs + bs <- mapM (evalApp f . (: []) . fst) xs return . VBag . map snd . Prelude.filter (isTrue . fst) $ zip bs xs - where - isTrue (VInj R VUnit) = True - isTrue _ = False - + where + isTrue (VInj R VUnit) = True + isTrue _ = False OMerge -> arity3 $ \f bxs bys -> case (bxs, bys) of (VBag xs, VBag ys) -> outWithErr (VBag <$> mergeM f xs ys) (VBag _, _) -> error $ "Impossible! OMerge on non-VBag " ++ show bys - _ -> error $ "Impossible! OMerge on non-VBag " ++ show bxs - + _ -> error $ "Impossible! OMerge on non-VBag " ++ show bxs OBagUnions -> withBag OBagUnions $ \cts -> - out . VBag $ sortNCount [(x, m*n) | (VBag xs, n) <- cts, (x,m) <- xs] - + out . VBag $ sortNCount [(x, m * n) | (VBag xs, n) <- cts, (x, m) <- xs] -------------------------------------------------- -- Container conversions @@ -392,51 +393,52 @@ appConst k = \case -- Disco> :desugar let x = 3 in ⟅ 'a' # (2 + x), 'b', 'b' ⟆ -- (λx. bagFromCounts(bag(('a', 2 + x) :: ('b', 1) :: ('b', 1) :: [])))(3) - OCountsToBag -> withBag OCountsToBag $ - out . VBag . sortNCount . map (second (uncurry (*)) . assoc . first (vpair id vint)) - where - assoc ((a, b), c) = (a, (b, c)) - - OUnsafeCountsToBag -> withBag OUnsafeCountsToBag $ - out . VBag . map (second (uncurry (*)) . assoc . first (vpair id vint)) - where - assoc ((a, b), c) = (a, (b, c)) + OCountsToBag -> + withBag OCountsToBag $ + out . VBag . sortNCount . map (second (uncurry (*)) . assoc . first (vpair id vint)) + where + assoc ((a, b), c) = (a, (b, c)) + OUnsafeCountsToBag -> + withBag OUnsafeCountsToBag $ + out . VBag . map (second (uncurry (*)) . assoc . first (vpair id vint)) + where + assoc ((a, b), c) = (a, (b, c)) -------------------------------------------------- -- Maps - OMapToSet -> withMap OMapToSet $ - out . VBag . map (\(k',v) -> (VPair (fromSimpleValue k') v, 1)) . M.assocs - - OSetToMap -> withBag OSetToMap $ - out . VMap . M.fromList . map (convertAssoc . fst) - where - convertAssoc (VPair k' v) = (toSimpleValue k', v) - convertAssoc v = error $ "Impossible! convertAssoc on non-VPair " ++ show v - - OInsert -> arity3 $ \k' v -> withMap OInsert $ - out . VMap . M.insert (toSimpleValue k') v - - OLookup -> arity2 $ \k' -> withMap OLookup $ - out . toMaybe . M.lookup (toSimpleValue k') - where - toMaybe = maybe (VInj L VUnit) (VInj R) + OMapToSet -> + withMap OMapToSet $ + out . VBag . map (\(k', v) -> (VPair (fromSimpleValue k') v, 1)) . M.assocs + OSetToMap -> + withBag OSetToMap $ + out . VMap . M.fromList . map (convertAssoc . fst) + where + convertAssoc (VPair k' v) = (toSimpleValue k', v) + convertAssoc v = error $ "Impossible! convertAssoc on non-VPair " ++ show v + OInsert -> arity3 $ \k' v -> + withMap OInsert $ + out . VMap . M.insert (toSimpleValue k') v + OLookup -> arity2 $ \k' -> + withMap OLookup $ + out . toMaybe . M.lookup (toSimpleValue k') + where + toMaybe = maybe (VInj L VUnit) (VInj R) -------------------------------------------------- -- Graph operations - OVertex -> out . VGraph . Vertex . toSimpleValue + OVertex -> out . VGraph . Vertex . toSimpleValue OOverlay -> arity2 $ withGraph2 OOverlay $ \g1 g2 -> out $ VGraph (Overlay g1 g2) OConnect -> arity2 $ withGraph2 OConnect $ \g1 g2 -> out $ VGraph (Connect g1 g2) OSummary -> withGraph OSummary $ out . graphSummary - -------------------------------------------------- -- Propositions - OForall tys -> out . (\v -> VProp (VPSearch SMForall tys v emptyTestEnv )) - OExists tys -> out . (\v -> VProp (VPSearch SMExists tys v emptyTestEnv )) + OForall tys -> out . (\v -> VProp (VPSearch SMForall tys v emptyTestEnv)) + OExists tys -> out . (\v -> VProp (VPSearch SMExists tys v emptyTestEnv)) OHolds -> testProperty Exhaustive >=> resultToBool >>> outWithErr ONotProp -> out . VProp . notProp . ensureProp OShouldEq ty -> arity2 $ \v1 v2 -> @@ -449,34 +451,33 @@ appConst k = \case out $ VProp (VPBin LOr (ensureProp p1) (ensureProp p2)) OImpl -> arity2 $ \p1 p2 -> out $ VProp (VPBin LImpl (ensureProp p1) (ensureProp p2)) - c -> error $ "Unimplemented: appConst " ++ show c - where - outWithErr :: Sem (Error EvalError ': r) Value -> Sem r CESK - outWithErr m = either (`Up` k) (`Out` k) <$> runError m - out v = return $ Out v k - up e = return $ Up e k - - withBag :: Op -> ([(Value,Integer)] -> Sem r a) -> Value -> Sem r a - withBag op f = \case - VBag xs -> f xs - v -> error $ "Impossible! " ++ show op ++ " on non-VBag " ++ show v - - withMap :: Op -> (M.Map SimpleValue Value -> Sem r a) -> Value -> Sem r a - withMap op f = \case - VMap m -> f m - v -> error $ "Impossible! " ++ show op ++ " on non-VMap " ++ show v - - withGraph :: Op -> (Graph SimpleValue -> Sem r a) -> Value -> Sem r a - withGraph op f = \case - VGraph g -> f g - v -> error $ "Impossible! " ++ show op ++ " on non-VGraph " ++ show v - - withGraph2 :: Op -> (Graph SimpleValue -> Graph SimpleValue -> Sem r a) -> Value -> Value -> Sem r a - withGraph2 op f v1 v2 = case (v1, v2) of - (VGraph g1, VGraph g2) -> f g1 g2 - (_, VGraph _) -> error $ "Impossible! " ++ show op ++ " on non-VGraph " ++ show v1 - _ -> error $ "Impossible! " ++ show op ++ " on non-VGraph " ++ show v2 + where + outWithErr :: Sem (Error EvalError ': r) Value -> Sem r CESK + outWithErr m = either (`Up` k) (`Out` k) <$> runError m + out v = return $ Out v k + up e = return $ Up e k + + withBag :: Op -> ([(Value, Integer)] -> Sem r a) -> Value -> Sem r a + withBag op f = \case + VBag xs -> f xs + v -> error $ "Impossible! " ++ show op ++ " on non-VBag " ++ show v + + withMap :: Op -> (M.Map SimpleValue Value -> Sem r a) -> Value -> Sem r a + withMap op f = \case + VMap m -> f m + v -> error $ "Impossible! " ++ show op ++ " on non-VMap " ++ show v + + withGraph :: Op -> (Graph SimpleValue -> Sem r a) -> Value -> Sem r a + withGraph op f = \case + VGraph g -> f g + v -> error $ "Impossible! " ++ show op ++ " on non-VGraph " ++ show v + + withGraph2 :: Op -> (Graph SimpleValue -> Graph SimpleValue -> Sem r a) -> Value -> Value -> Sem r a + withGraph2 op f v1 v2 = case (v1, v2) of + (VGraph g1, VGraph g2) -> f g1 g2 + (_, VGraph _) -> error $ "Impossible! " ++ show op ++ " on non-VGraph " ++ show v1 + _ -> error $ "Impossible! " ++ show op ++ " on non-VGraph " ++ show v2 -------------------------------------------------- -- Arithmetic @@ -492,7 +493,7 @@ numOp1 f = numOp1' $ return . ratv . f numOp1' :: (Rational -> Sem r Value) -> Value -> Sem r Value numOp1' f (VNum _ m) = f m -numOp1' _ v = error $ "Impossible! numOp1' on non-VNum " ++ show v +numOp1' _ v = error $ "Impossible! numOp1' on non-VNum " ++ show v numOp2 :: (Rational -> Rational -> Rational) -> Value -> Sem r Value numOp2 (#) = numOp2' $ \m n -> return (ratv (m # n)) @@ -504,9 +505,9 @@ numOp2' (#) = res <- n1 # n2 case res of VNum _ r -> return $ VNum (d1 <> d2) r - _ -> return res - (VNum{}, _) -> error $ "Impossible! numOp2' on non-VNum " ++ show v2 - _ -> error $ "Impossible! numOp2' on non-VNum " ++ show v1 + _ -> return res + (VNum {}, _) -> error $ "Impossible! numOp2' on non-VNum " ++ show v2 + _ -> error $ "Impossible! numOp2' on non-VNum " ++ show v1 -- | Perform a square root operation. If the program typechecks, -- then the argument and output will really be Natural. @@ -559,15 +560,15 @@ compareBags :: [(Value, Integer)] -> [(Value, Integer)] -> Ordering compareBags [] [] = EQ compareBags [] _ = LT compareBags _ [] = GT -compareBags ((x, xn) : xs) ((y, yn) : ys) - = valCmp x y <> compare xn yn <> compareBags xs ys +compareBags ((x, xn) : xs) ((y, yn) : ys) = + valCmp x y <> compare xn yn <> compareBags xs ys compareMaps :: [(SimpleValue, Value)] -> [(SimpleValue, Value)] -> Ordering compareMaps [] [] = EQ compareMaps [] _ = LT compareMaps _ [] = GT -compareMaps ((k1, v1) : as1) ((k2, v2) : as2) - = valCmp (fromSimpleValue k1) (fromSimpleValue k2) <> valCmp v1 v2 <> compareMaps as1 as2 +compareMaps ((k1, v1) : as1) ((k2, v2) : as2) = + valCmp (fromSimpleValue k1) (fromSimpleValue k2) <> valCmp v1 v2 <> compareMaps as1 as2 ------------------------------------------------------------ -- Polynomial sequences [a,b,c,d .. e] @@ -585,16 +586,16 @@ enumEllipsis xs (Until y) | d > 0 = takeWhile (<= y) nums | d < 0 = takeWhile (>= y) nums | otherwise = nums - where - d = constdiff xs - nums = babbage xs + where + d = constdiff xs + nums = babbage xs -- | Extend a sequence infinitely by interpolating it as a polynomial -- sequence, via forward differences. Essentially the same -- algorithm used by Babbage's famous Difference Engine. babbage :: Num a => [a] -> [a] -babbage [] = [] -babbage [x] = repeat x +babbage [] = [] +babbage [x] = repeat x babbage (x : xs) = scanl (+) x (babbage (diff (x : xs))) -- | Compute the forward difference of the given sequence, that is, @@ -620,12 +621,12 @@ constdiff (x : xs) -- otherwise 'right "https://oeis.org/"' oeisLookup :: Value -> Value oeisLookup (vlist vint -> ns) = maybe VNil parseResult (lookupSequence ns) - where - parseResult r = VInj R (listv charv ("https://oeis.org/" ++ seqNum r)) - seqNum = getCatalogNum . catalogNums + where + parseResult r = VInj R (listv charv ("https://oeis.org/" ++ seqNum r)) + seqNum = getCatalogNum . catalogNums - getCatalogNum [] = error "No catalog info" - getCatalogNum (n : _) = n + getCatalogNum [] = error "No catalog info" + getCatalogNum (n : _) = n -- | Extends a Disco integer list with data from a known OEIS -- sequence. Returns a list of integers upon success, otherwise the @@ -653,8 +654,8 @@ sortNCount :: [(Value, Integer)] -> [(Value, Integer)] sortNCount [] = [] sortNCount [x] = [x] sortNCount xs = merge (+) (sortNCount firstHalf) (sortNCount secondHalf) - where - (firstHalf, secondHalf) = splitAt (length xs `div` 2) xs + where + (firstHalf, secondHalf) = splitAt (length xs `div` 2) xs -- | Generic function for merging two sorted, count-annotated lists of -- type @[(a,Integer)]@ a la merge sort, using the given comparison @@ -668,18 +669,18 @@ merge :: [(Value, Integer)] -> [(Value, Integer)] merge g = go - where - go [] [] = [] - go [] ((y, n) : ys) = mergeCons y 0 n (go [] ys) - go ((x, n) : xs) [] = mergeCons x n 0 (go xs []) - go ((x, n1) : xs) ((y, n2) : ys) = case valCmp x y of - LT -> mergeCons x n1 0 (go xs ((y, n2) : ys)) - EQ -> mergeCons x n1 n2 (go xs ys) - GT -> mergeCons y 0 n2 (go ((x, n1) : xs) ys) - - mergeCons a m1 m2 zs = case g m1 m2 of - 0 -> zs - n -> (a, n) : zs + where + go [] [] = [] + go [] ((y, n) : ys) = mergeCons y 0 n (go [] ys) + go ((x, n) : xs) [] = mergeCons x n 0 (go xs []) + go ((x, n1) : xs) ((y, n2) : ys) = case valCmp x y of + LT -> mergeCons x n1 0 (go xs ((y, n2) : ys)) + EQ -> mergeCons x n1 n2 (go xs ys) + GT -> mergeCons y 0 n2 (go ((x, n1) : xs) ys) + + mergeCons a m1 m2 zs = case g m1 m2 of + 0 -> zs + n -> (a, n) : zs mergeM :: Members '[Random, Error EvalError, State Mem] r => @@ -688,21 +689,21 @@ mergeM :: [(Value, Integer)] -> Sem r [(Value, Integer)] mergeM g = go - where - go [] [] = return [] - go [] ((y, n) : ys) = mergeCons y 0 n =<< go [] ys - go ((x, n) : xs) [] = mergeCons x n 0 =<< go xs [] - go ((x, n1) : xs) ((y, n2) : ys) = case valCmp x y of - LT -> mergeCons x n1 0 =<< go xs ((y, n2) : ys) - EQ -> mergeCons x n1 n2 =<< go xs ys - GT -> mergeCons y 0 n2 =<< go ((x, n1) : xs) ys - - mergeCons a m1 m2 zs = do - nm <- evalApp g [VPair (intv m1) (intv m2)] - return $ case nm of - VNum _ 0 -> zs - VNum _ n -> (a, numerator n) : zs - v -> error $ "Impossible! merge function in mergeM returned non-VNum " ++ show v + where + go [] [] = return [] + go [] ((y, n) : ys) = mergeCons y 0 n =<< go [] ys + go ((x, n) : xs) [] = mergeCons x n 0 =<< go xs [] + go ((x, n1) : xs) ((y, n2) : ys) = case valCmp x y of + LT -> mergeCons x n1 0 =<< go xs ((y, n2) : ys) + EQ -> mergeCons x n1 n2 =<< go xs ys + GT -> mergeCons y 0 n2 =<< go ((x, n1) : xs) ys + + mergeCons a m1 m2 zs = do + nm <- evalApp g [VPair (intv m1) (intv m2)] + return $ case nm of + VNum _ 0 -> zs + VNum _ n -> (a, numerator n) : zs + v -> error $ "Impossible! merge function in mergeM returned non-VNum " ++ show v ------------------------------------------------------------ -- Graphs @@ -710,14 +711,14 @@ mergeM g = go graphSummary :: Graph SimpleValue -> Value graphSummary = toDiscoAdjMap . reifyGraph - where - reifyGraph :: Graph SimpleValue -> [(SimpleValue, [SimpleValue])] - reifyGraph = - AdjMap.adjacencyList . foldg AdjMap.empty AdjMap.vertex AdjMap.overlay AdjMap.connect + where + reifyGraph :: Graph SimpleValue -> [(SimpleValue, [SimpleValue])] + reifyGraph = + AdjMap.adjacencyList . foldg AdjMap.empty AdjMap.vertex AdjMap.overlay AdjMap.connect - toDiscoAdjMap :: [(SimpleValue, [SimpleValue])] -> Value - toDiscoAdjMap = - VMap . M.fromList . map (second (VBag . countValues . map fromSimpleValue)) + toDiscoAdjMap :: [(SimpleValue, [SimpleValue])] -> Value + toDiscoAdjMap = + VMap . M.fromList . map (second (VBag . countValues . map fromSimpleValue)) ------------------------------------------------------------ -- Propositions / tests @@ -725,76 +726,87 @@ graphSummary = toDiscoAdjMap . reifyGraph resultToBool :: Member (Error EvalError) r => TestResult -> Sem r Value resultToBool (TestResult _ (TestRuntimeError e) _) = throw e -resultToBool (TestResult b _ _) = return $ enumv b +resultToBool (TestResult b _ _) = return $ enumv b notProp :: ValProp -> ValProp -notProp (VPDone r) = VPDone (invertPropResult r) +notProp (VPDone r) = VPDone (invertPropResult r) notProp (VPSearch sm tys p e) = VPSearch (invertMotive sm) tys p e -notProp (VPBin LAnd vp1 vp2) = VPBin LOr (notProp vp1) (notProp vp2) -notProp (VPBin LOr vp1 vp2) = VPBin LAnd (notProp vp1) (notProp vp2) +notProp (VPBin LAnd vp1 vp2) = VPBin LOr (notProp vp1) (notProp vp2) +notProp (VPBin LOr vp1 vp2) = VPBin LAnd (notProp vp1) (notProp vp2) notProp (VPBin LImpl vp1 vp2) = VPBin LAnd vp1 (notProp vp2) -- | Convert a @Value@ to a @ValProp@, embedding booleans if necessary. ensureProp :: Value -> ValProp -ensureProp (VProp p) = p +ensureProp (VProp p) = p ensureProp (VInj L _) = VPDone (TestResult False TestBool emptyTestEnv) ensureProp (VInj R _) = VPDone (TestResult True TestBool emptyTestEnv) -ensureProp _ = error "ensureProp: non-prop value" +ensureProp _ = error "ensureProp: non-prop value" combineTestResultBool :: LOp -> TestResult -> TestResult -> Bool combineTestResultBool op (TestResult b1 _ _) (TestResult b2 _ _) = interpLOp op b1 b2 -testProperty - :: Members '[Random, State Mem] r - => SearchType -> Value -> Sem r TestResult +testProperty :: + Members '[Random, State Mem] r => + SearchType -> + Value -> + Sem r TestResult testProperty initialSt = checkProp . ensureProp - where - checkProp - :: Members '[Random, State Mem] r - => ValProp -> Sem r TestResult - checkProp (VPDone r) = return r - checkProp (VPBin op vp1 vp2) = do - tr1 <- checkProp vp1 - tr2 <- checkProp vp2 - return $ TestResult (combineTestResultBool op tr1 tr2) (TestBin op tr1 tr2) emptyTestEnv - checkProp (VPSearch sm tys f e) = - extendResultEnv e <$> (generateSamples initialSt vals >>= go) - where - vals = enumTypes tys - (SearchMotive (whenFound, wantsSuccess)) = sm - - go - :: Members '[Random, State Mem] r - => ([[Value]], SearchType) -> Sem r TestResult - go ([], st) = return $ TestResult (not whenFound) (TestNotFound st) emptyTestEnv - go (x:xs, st) = do - mprop <- runError (ensureProp <$> evalApp f x) - case mprop of - Left err -> return $ TestResult False (TestRuntimeError err) emptyTestEnv - Right (VPDone r) -> continue st xs r - Right prop -> checkProp prop >>= continue st xs - - continue - :: Members '[Random, State Mem] r - => SearchType -> [[Value]] -> TestResult -> Sem r TestResult - continue st xs r@(TestResult _ _ e') - | testIsError r = return r - | testIsOk r == wantsSuccess = - return $ TestResult whenFound (TestFound r) e' - | otherwise = go (xs, st) - -evalApp - :: Members '[Random, Error EvalError, State Mem] r - => Value -> [Value] -> Sem r Value + where + checkProp :: + Members '[Random, State Mem] r => + ValProp -> + Sem r TestResult + checkProp (VPDone r) = return r + checkProp (VPBin op vp1 vp2) = do + tr1 <- checkProp vp1 + tr2 <- checkProp vp2 + return $ TestResult (combineTestResultBool op tr1 tr2) (TestBin op tr1 tr2) emptyTestEnv + checkProp (VPSearch sm tys f e) = + extendResultEnv e <$> (generateSamples initialSt vals >>= go) + where + vals = enumTypes tys + (SearchMotive (whenFound, wantsSuccess)) = sm + + go :: + Members '[Random, State Mem] r => + ([[Value]], SearchType) -> + Sem r TestResult + go ([], st) = return $ TestResult (not whenFound) (TestNotFound st) emptyTestEnv + go (x : xs, st) = do + mprop <- runError (ensureProp <$> evalApp f x) + case mprop of + Left err -> return $ TestResult False (TestRuntimeError err) emptyTestEnv + Right (VPDone r) -> continue st xs r + Right prop -> checkProp prop >>= continue st xs + + continue :: + Members '[Random, State Mem] r => + SearchType -> + [[Value]] -> + TestResult -> + Sem r TestResult + continue st xs r@(TestResult _ _ e') + | testIsError r = return r + | testIsOk r == wantsSuccess = + return $ TestResult whenFound (TestFound r) e' + | otherwise = go (xs, st) + +evalApp :: + Members '[Random, Error EvalError, State Mem] r => + Value -> + [Value] -> + Sem r Value evalApp f xs = runFresh (runCESK (Out f (map FArgV xs))) >>= either throw return -runTest - :: Members '[Random, Error EvalError, Input Env, State Mem] r - => Int -> AProperty -> Sem r TestResult +runTest :: + Members '[Random, Error EvalError, Input Env, State Mem] r => + Int -> + AProperty -> + Sem r TestResult runTest n p = testProperty (Randomized n' n') =<< eval (compileProperty p) - where - n' = fromIntegral (n `div` 2) + where + n' = fromIntegral (n `div` 2) ------------------------------------------------------------ -- Top-level evaluation diff --git a/src/Disco/Messages.hs b/src/Disco/Messages.hs index 6c6fd950..4e5ea677 100644 --- a/src/Disco/Messages.hs +++ b/src/Disco/Messages.hs @@ -1,7 +1,10 @@ -{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE TemplateHaskell #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Messages -- Copyright : disco team and contributors @@ -11,27 +14,24 @@ -- -- Message logging framework (e.g. for errors, warnings, etc.) for -- disco. --- ------------------------------------------------------------------------------ - module Disco.Messages where -import Control.Lens -import Control.Monad (when) -import Polysemy -import Polysemy.Output +import Control.Lens +import Control.Monad (when) +import Polysemy +import Polysemy.Output -import Disco.Pretty (Doc, Pretty, pretty', renderDoc') +import Disco.Pretty (Doc, Pretty, pretty', renderDoc') data MessageType - = Info - | Warning - | ErrMsg - | Debug - deriving (Show, Read, Eq, Ord, Enum, Bounded) + = Info + | Warning + | ErrMsg + | Debug + deriving (Show, Read, Eq, Ord, Enum, Bounded) data Message = Message {_messageType :: MessageType, _message :: Doc} - deriving (Show) + deriving (Show) makeLenses ''Message diff --git a/src/Disco/Module.hs b/src/Disco/Module.hs index ef8aaad3..8e6c8ced 100644 --- a/src/Disco/Module.hs +++ b/src/Disco/Module.hs @@ -1,10 +1,13 @@ -{-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveDataTypeable #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE UndecidableInstances #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Module -- Copyright : (c) 2019 disco team (see LICENSE) @@ -14,43 +17,52 @@ -- -- The 'ModuleInfo' record representing a disco module, and functions -- to resolve the location of a module on disk. ------------------------------------------------------------------------------ - module Disco.Module where -import Data.Data (Data) -import GHC.Generics (Generic) - -import Control.Lens (Getting, foldOf, - makeLenses, view) -import Control.Monad (filterM) -import Control.Monad.IO.Class (MonadIO (..)) -import Data.Bifunctor (first) -import Data.Map (Map) -import qualified Data.Map as M -import Data.Maybe (listToMaybe) -import qualified Data.Set as S -import System.Directory (doesFileExist) -import System.FilePath (replaceExtension, - ()) - -import Unbound.Generics.LocallyNameless (Alpha, Bind, Name, - Subst, bind) -import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind) - -import Polysemy - -import Disco.AST.Surface -import Disco.AST.Typed -import Disco.Context -import Disco.Extensions -import Disco.Names -import Disco.Pretty hiding ((<>)) -import Disco.Typecheck.Erase (erase, erasePattern) -import Disco.Typecheck.Util (TyCtx) -import Disco.Types - -import Paths_disco +import Data.Data (Data) +import GHC.Generics (Generic) + +import Control.Lens ( + Getting, + foldOf, + makeLenses, + view, + ) +import Control.Monad (filterM) +import Control.Monad.IO.Class (MonadIO (..)) +import Data.Bifunctor (first) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Maybe (listToMaybe) +import qualified Data.Set as S +import System.Directory (doesFileExist) +import System.FilePath ( + replaceExtension, + (), + ) + +import Unbound.Generics.LocallyNameless ( + Alpha, + Bind, + Name, + Subst, + bind, + ) +import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind) + +import Polysemy + +import Disco.AST.Surface +import Disco.AST.Typed +import Disco.Context +import Disco.Extensions +import Disco.Names +import Disco.Pretty hiding ((<>)) +import Disco.Typecheck.Erase (erase, erasePattern) +import Disco.Typecheck.Util (TyCtx) +import Disco.Types + +import Paths_disco ------------------------------------------------------------ -- ModuleInfo and related types @@ -76,10 +88,10 @@ data Defn = Defn (Name ATerm) [Type] Type [Clause] deriving (Show, Generic, Alpha, Data, Subst Type) instance Pretty Defn where - pretty (Defn x patTys ty clauses) = vcat $ - prettyTyDecl x (foldr (:->:) ty patTys) - : - map (pretty . (x,) . eraseClause) clauses + pretty (Defn x patTys ty clauses) = + vcat $ + prettyTyDecl x (foldr (:->:) ty patTys) + : map (pretty . (x,) . eraseClause) clauses -- | A clause in a definition consists of a list of patterns (the LHS -- of the =) and a term (the RHS). For example, given the concrete @@ -89,31 +101,31 @@ type Clause = Bind [APattern] ATerm eraseClause :: Clause -> Bind [Pattern] Term eraseClause b = bind (map erasePattern ps) (erase t) - where (ps, t) = unsafeUnbind b + where + (ps, t) = unsafeUnbind b -- | Type checking a module yields a value of type ModuleInfo which contains -- mapping from terms to their relavent documenation, a mapping from terms to -- properties, and a mapping from terms to their types. data ModuleInfo = ModuleInfo - { _miName :: ModuleName - , _miImports :: Map ModuleName ModuleInfo - - -- List of names declared by the module, in the order they occur - , _miNames :: [QName Term] - , _miDocs :: Ctx Term Docs - , _miProps :: Ctx ATerm [AProperty] - , _miTys :: TyCtx - , _miTydefs :: TyDefCtx + { _miName :: ModuleName + , _miImports :: Map ModuleName ModuleInfo + , -- List of names declared by the module, in the order they occur + _miNames :: [QName Term] + , _miDocs :: Ctx Term Docs + , _miProps :: Ctx ATerm [AProperty] + , _miTys :: TyCtx + , _miTydefs :: TyDefCtx , _miTermdefs :: Ctx ATerm Defn - , _miTerms :: [(ATerm, PolyType)] - , _miExts :: ExtSet + , _miTerms :: [(ATerm, PolyType)] + , _miExts :: ExtSet } deriving (Show) makeLenses ''ModuleInfo instance Semigroup ModuleInfo where - -- | Two ModuleInfos + -- \| Two ModuleInfos -- are merged by joining their doc, type, type definition, and term -- contexts. The property context of the new module is the one -- obtained from the second module. The name of the new module is @@ -121,8 +133,8 @@ instance Semigroup ModuleInfo where -- earlier ones. Note that this function should really only be used -- for the special top-level REPL module. ModuleInfo n1 is1 ns1 d1 _ ty1 tyd1 tm1 tms1 es1 - <> ModuleInfo _ is2 ns2 d2 p2 ty2 tyd2 tm2 tms2 es2 - = ModuleInfo + <> ModuleInfo _ is2 ns2 d2 p2 ty2 tyd2 tm2 tms2 es2 = + ModuleInfo n1 (is1 <> is2) (ns1 <> ns2) @@ -178,7 +190,7 @@ data Resolver -- `:load`ed module). withStdlib :: Resolver -> Resolver withStdlib (FromDir fp) = FromDirOrStdlib fp -withStdlib r = r +withStdlib r = r -- | Given a module resolution mode and a raw module name, relavent -- directories are searched for the file containing the provided @@ -189,9 +201,9 @@ resolveModule resolver modname = do datadir <- liftIO getDataDir let searchPath = case resolver of - FromStdlib -> [(datadir, Stdlib)] - FromDir dir -> [(dir, Dir dir)] - FromCwdOrStdlib -> [(datadir, Stdlib), (".", Dir ".")] + FromStdlib -> [(datadir, Stdlib)] + FromDir dir -> [(dir, Dir dir)] + FromCwdOrStdlib -> [(datadir, Stdlib), (".", Dir ".")] FromDirOrStdlib dir -> [(datadir, Stdlib), (dir, Dir dir)] let fps = map (first ( replaceExtension modname "disco")) searchPath fexists <- liftIO $ filterM (doesFileExist . fst) fps diff --git a/src/Disco/Names.hs b/src/Disco/Names.hs index 10abec11..8836bcd0 100644 --- a/src/Disco/Names.hs +++ b/src/Disco/Names.hs @@ -1,39 +1,48 @@ -{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveDataTypeable #-} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE OverloadedStrings #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + +-- SPDX-License-Identifier: BSD-3-Clause + -- | -- Module : Disco.Names -- Copyright : disco team and contributors -- Maintainer : byorgey@gmail.com -- -- Names for modules and identifiers. --- ------------------------------------------------------------------------------ - --- SPDX-License-Identifier: BSD-3-Clause - -module Disco.Names - ( -- * Modules and their provenance - ModuleProvenance(..), ModuleName(..) - -- * Names and their provenance - , NameProvenance(..), QName(..), isFree, localName, (.-) - -- * Name-related utilities - , fvQ, substQ, substsQ - ) where - -import Control.Lens (Traversal', filtered) -import Data.Data (Data) -import Data.Data.Lens (template) -import Data.Typeable (Typeable) -import GHC.Generics (Generic) -import Prelude hiding ((<>)) -import System.FilePath (dropExtension) -import Unbound.Generics.LocallyNameless - -import Disco.Pretty -import Disco.Types +module Disco.Names ( + -- * Modules and their provenance + ModuleProvenance (..), + ModuleName (..), + + -- * Names and their provenance + NameProvenance (..), + QName (..), + isFree, + localName, + (.-), + + -- * Name-related utilities + fvQ, + substQ, + substsQ, +) where + +import Control.Lens (Traversal', filtered) +import Data.Data (Data) +import Data.Data.Lens (template) +import Data.Typeable (Typeable) +import GHC.Generics (Generic) +import System.FilePath (dropExtension) +import Unbound.Generics.LocallyNameless +import Prelude hiding ((<>)) + +import Disco.Pretty +import Disco.Types ------------------------------------------------------------ -- Modules @@ -41,16 +50,19 @@ import Disco.Types -- | Where did a module come from? data ModuleProvenance - = Dir FilePath -- ^ From a particular directory (relative to cwd) - | Stdlib -- ^ From the standard library + = -- | From a particular directory (relative to cwd) + Dir FilePath + | -- | From the standard library + Stdlib deriving (Eq, Ord, Show, Generic, Data, Alpha, Subst Type) -- | The name of a module. data ModuleName - = REPLModule -- ^ The special top-level "module" consisting of - -- what has been entered at the REPL. - | Named ModuleProvenance String - -- ^ A named module, with its name and provenance. + = -- | The special top-level "module" consisting of + -- what has been entered at the REPL. + REPLModule + | -- | A named module, with its name and provenance. + Named ModuleProvenance String deriving (Eq, Ord, Show, Generic, Data, Alpha, Subst Type) ------------------------------------------------------------ @@ -59,19 +71,21 @@ data ModuleName -- | Where did a name come from? data NameProvenance - = LocalName -- ^ The name is locally bound - | QualifiedName ModuleName -- ^ The name is exported by the given module + = -- | The name is locally bound + LocalName + | -- | The name is exported by the given module + QualifiedName ModuleName deriving (Eq, Ord, Show, Generic, Data, Alpha, Subst Type) -- | A @QName@, or qualified name, is a 'Name' paired with its -- 'NameProvenance'. -data QName a = QName { qnameProvenance :: NameProvenance, qname :: Name a } +data QName a = QName {qnameProvenance :: NameProvenance, qname :: Name a} deriving (Eq, Ord, Show, Generic, Data, Alpha, Subst Type) -- | Does this name correspond to a free variable? isFree :: QName a -> Bool isFree (QName (QualifiedName _) _) = True -isFree (QName LocalName n) = isFreeName n +isFree (QName LocalName n) = isFreeName n -- | Create a locally bound qualified name. localName :: Name a -> QName a @@ -88,7 +102,7 @@ m .- x = QName (QualifiedName m) x -- | The @unbound-generics@ library gives us free variables for free. -- But when dealing with typed and desugared ASTs, we want all the -- free 'QName's instead of just 'Name's. -fvQ :: (Data t, Typeable e) => Traversal' t (QName e) +fvQ :: (Data t, Typeable e) => Traversal' t (QName e) fvQ = template . filtered isFree substQ :: Subst b a => QName b -> b -> a -> a @@ -102,10 +116,10 @@ substsQ = undefined ------------------------------------------------------------ instance Pretty ModuleName where - pretty REPLModule = "REPL" + pretty REPLModule = "REPL" pretty (Named (Dir _) s) = text (dropExtension s) - pretty (Named Stdlib s) = text (dropExtension s) + pretty (Named Stdlib s) = text (dropExtension s) instance Pretty (QName a) where - pretty (QName LocalName x) = pretty x + pretty (QName LocalName x) = pretty x pretty (QName (QualifiedName mn) x) = pretty mn <> "." <> pretty x diff --git a/src/Disco/Parser.hs b/src/Disco/Parser.hs index 5435259a..96d2b2d6 100644 --- a/src/Disco/Parser.hs +++ b/src/Disco/Parser.hs @@ -1,6 +1,9 @@ {-# LANGUAGE TemplateHaskell #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Parser -- Copyright : disco team and contributors @@ -10,78 +13,124 @@ -- -- Parser to convert concrete Disco syntax into an (untyped, surface -- language) AST. --- ------------------------------------------------------------------------------ - -module Disco.Parser - ( -- * Parser type and utilities - DiscoParseError(..), Parser, runParser, withExts, indented, thenIndented - - -- * Lexer - - -- ** Basic lexemes - , sc, lexeme, symbol, reservedOp - , natural, reserved, reservedWords, ident - - -- ** Punctuation - , parens, braces, angles, brackets - , semi, comma, colon, dot, pipe - , lambda - - -- * Disco parser - - -- ** Modules - , wholeModule, parseModule, parseExtName, parseTopLevel, parseDecl - , parseImport, parseModuleName - - -- ** Terms - , term, parseTerm, parseTerm', parseExpr, parseAtom - , parseContainer, parseEllipsis, parseContainerComp, parseQual - , parseLet, parseTypeOp - - -- ** Case and patterns - , parseCase, parseBranch, parseGuards, parseGuard - , parsePattern, parseAtomicPattern - - -- ** Types - , parseType, parseAtomicType - , parsePolyTy - ) - where - -import Unbound.Generics.LocallyNameless (Name, bind, embed, - fvAny, name2String, - string2Name) -import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind) - -import Control.Monad.Combinators.Expr -import qualified Text.Megaparsec as MP -import Text.Megaparsec hiding (State, - runParser) -import Text.Megaparsec.Char -import qualified Text.Megaparsec.Char.Lexer as L - -import Control.Lens (makeLenses, toListOf, - use, (%=), (%~), (&), - (.=)) -import Control.Monad.State -import Data.Char (isAlpha, isDigit) -import Data.Foldable (asum) -import Data.List (find, intercalate) -import qualified Data.Map as M -import Data.Maybe (fromMaybe, isNothing) -import Data.Ratio -import Data.Set (Set) -import qualified Data.Set as S - -import Disco.AST.Surface -import Disco.Extensions -import Disco.Module -import Disco.Pretty (prettyStr) -import Disco.Syntax.Operators -import Disco.Syntax.Prims -import Disco.Types -import Polysemy (run) +module Disco.Parser ( + -- * Parser type and utilities + DiscoParseError (..), + Parser, + runParser, + withExts, + indented, + thenIndented, + + -- * Lexer + + -- ** Basic lexemes + sc, + lexeme, + symbol, + reservedOp, + natural, + reserved, + reservedWords, + ident, + + -- ** Punctuation + parens, + braces, + angles, + brackets, + semi, + comma, + colon, + dot, + pipe, + lambda, + + -- * Disco parser + + -- ** Modules + wholeModule, + parseModule, + parseExtName, + parseTopLevel, + parseDecl, + parseImport, + parseModuleName, + + -- ** Terms + term, + parseTerm, + parseTerm', + parseExpr, + parseAtom, + parseContainer, + parseEllipsis, + parseContainerComp, + parseQual, + parseLet, + parseTypeOp, + + -- ** Case and patterns + parseCase, + parseBranch, + parseGuards, + parseGuard, + parsePattern, + parseAtomicPattern, + + -- ** Types + parseType, + parseAtomicType, + parsePolyTy, +) +where + +import Unbound.Generics.LocallyNameless ( + Name, + bind, + embed, + fvAny, + name2String, + string2Name, + ) +import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind) + +import Control.Monad.Combinators.Expr +import Text.Megaparsec hiding ( + State, + runParser, + ) +import qualified Text.Megaparsec as MP +import Text.Megaparsec.Char +import qualified Text.Megaparsec.Char.Lexer as L + +import Control.Lens ( + makeLenses, + toListOf, + use, + (%=), + (%~), + (&), + (.=), + ) +import Control.Monad.State +import Data.Char (isAlpha, isDigit) +import Data.Foldable (asum) +import Data.List (find, intercalate) +import qualified Data.Map as M +import Data.Maybe (fromMaybe, isNothing) +import Data.Ratio +import Data.Set (Set) +import qualified Data.Set as S + +import Disco.AST.Surface +import Disco.Extensions +import Disco.Module +import Disco.Pretty (prettyStr) +import Disco.Syntax.Operators +import Disco.Syntax.Prims +import Disco.Types +import Polysemy (run) ------------------------------------------------------------ -- Lexer @@ -91,17 +140,25 @@ import Polysemy (run) -- | Currently required indent level. data IndentMode where - NoIndent :: IndentMode -- ^ Don't require indent. - ThenIndent :: IndentMode -- ^ Parse one token without - -- indent, then switch to @Indent@. - Indent :: IndentMode -- ^ Require everything to be indented at - -- least one space. + NoIndent :: + -- | Don't require indent. + IndentMode + ThenIndent :: + -- | Parse one token without + -- indent, then switch to @Indent@. + IndentMode + Indent :: + -- | Require everything to be indented at + -- least one space. + IndentMode -- | Extra custom state for the parser. data ParserState = ParserState - { _indentMode :: IndentMode -- ^ Currently required level of indentation. - , _enabledExts :: Set Ext -- ^ Set of enabled language extensions - -- (some of which may affect parsing). + { _indentMode :: IndentMode + -- ^ Currently required level of indentation. + , _enabledExts :: Set Ext + -- ^ Set of enabled language extensions + -- (some of which may affect parsing). } makeLenses ''ParserState @@ -128,14 +185,14 @@ data DiscoParseError deriving (Show, Eq, Ord) instance ShowErrorComponent DiscoParseError where - showErrorComponent (ReservedVarName x) = "keyword \"" ++ x ++ "\" cannot be used as a variable name" + showErrorComponent (ReservedVarName x) = "keyword \"" ++ x ++ "\" cannot be used as a variable name" showErrorComponent (InvalidPattern (OT t)) = "Invalid pattern: " ++ run (prettyStr t) - showErrorComponent MissingAscr = "Variables introduced by ∀ or ∃ must have a type" - showErrorComponent MultiArgLambda = "Anonymous functions (lambdas) can only have a single argument.\nInstead of \\x, y. ... you can write \\x. \\y. ...\nhttps://disco-lang.readthedocs.io/en/latest/reference/anonymous-func.html" + showErrorComponent MissingAscr = "Variables introduced by ∀ or ∃ must have a type" + showErrorComponent MultiArgLambda = "Anonymous functions (lambdas) can only have a single argument.\nInstead of \\x, y. ... you can write \\x. \\y. ...\nhttps://disco-lang.readthedocs.io/en/latest/reference/anonymous-func.html" errorComponentLen (ReservedVarName x) = length x - errorComponentLen (InvalidPattern _) = 1 - errorComponentLen MissingAscr = 1 - errorComponentLen MultiArgLambda = 1 + errorComponentLen (InvalidPattern _) = 1 + errorComponentLen MissingAscr = 1 + errorComponentLen MultiArgLambda = 1 -- | A parser is a megaparsec parser of strings, with an extra layer -- of state to keep track of the current indentation level and @@ -175,8 +232,8 @@ requireIndent p = do a <- p indentMode .= Indent return a - Indent -> L.indentGuard sc GT pos1 >> p - NoIndent -> p + Indent -> L.indentGuard sc GT pos1 >> p + NoIndent -> p -- | Locally set the enabled extensions within a subparser. withExts :: Set Ext -> Parser a -> Parser a @@ -205,8 +262,8 @@ ensureEnabled e = do -- | Generically consume whitespace, including comments. sc :: Parser () sc = L.space space1 lineComment empty {- no block comments in disco -} - where - lineComment = L.skipLineComment "--" + where + lineComment = L.skipLineComment "--" -- | Parse a lexeme, that is, a parser followed by consuming -- whitespace. @@ -226,25 +283,25 @@ opChar :: [Char] opChar = "~!@#$%^&*-+=|<>?/\\." parens, braces, angles, brackets, bagdelims, fbrack, cbrack :: Parser a -> Parser a -parens = between (symbol "(") (symbol ")") -braces = between (symbol "{") (symbol "}") -angles = between (symbol "<") (symbol ">") -brackets = between (symbol "[") (symbol "]") +parens = between (symbol "(") (symbol ")") +braces = between (symbol "{") (symbol "}") +angles = between (symbol "<") (symbol ">") +brackets = between (symbol "[") (symbol "]") bagdelims = between (symbol "⟅") (symbol "⟆") -fbrack = between (symbol "⌊") (symbol "⌋") -cbrack = between (symbol "⌈") (symbol "⌉") +fbrack = between (symbol "⌊") (symbol "⌋") +cbrack = between (symbol "⌈") (symbol "⌉") semi, comma, colon, dot, pipe, hash :: Parser String -semi = symbol ";" -comma = symbol "," -colon = symbol ":" -dot = symbol "." -pipe = symbol "|" -hash = symbol "#" +semi = symbol ";" +comma = symbol "," +colon = symbol ":" +dot = symbol "." +pipe = symbol "|" +hash = symbol "#" -- | A literal ellipsis of two or more dots, @..@ ellipsis :: Parser String -ellipsis = label "ellipsis (..)" $ concat <$> ((:) <$> dot <*> some dot) +ellipsis = label "ellipsis (..)" $ concat <$> ((:) <$> dot <*> some dot) -- | The symbol that starts an anonymous function (either a backslash -- or a Greek λ). @@ -279,35 +336,39 @@ natural = lexeme L.decimal "natural number" -- See https://github.com/disco-lang/disco/issues/245 and Note -- [Trailing period]. decimal :: Parser Rational -decimal = lexeme (readDecimal <$> some digit <* char '.' - <*> fractionalPart - ) - where - digit = satisfy isDigit - fractionalPart = - -- either some digits optionally followed by bracketed digits... - (,) <$> some digit <*> optional (brackets (some digit)) - -- ...or just bracketed digits. +decimal = + lexeme + ( readDecimal + <$> some digit + <* char '.' + <*> fractionalPart + ) + where + digit = satisfy isDigit + fractionalPart = + -- either some digits optionally followed by bracketed digits... + (,) <$> some digit <*> optional (brackets (some digit)) + -- ...or just bracketed digits. <|> ([],) <$> (Just <$> brackets (some digit)) - readDecimal a (b, mrep) - = read a % 1 -- integer part + readDecimal a (b, mrep) = + read a % 1 -- integer part - -- next part is just b/10^n - + (if null b then 0 else read b) % (10^length b) + -- next part is just b/10^n + + (if null b then 0 else read b) % (10 ^ length b) + -- repeating part + + readRep (length b) mrep - -- repeating part - + readRep (length b) mrep + readRep _ Nothing = 0 + readRep offset (Just rep) = read rep % (10 ^ offset * (10 ^ length rep - 1)) - readRep _ Nothing = 0 - readRep offset (Just rep) = read rep % (10^offset * (10^length rep - 1)) - -- If s = 0.[rep] then 10^(length rep) * s = rep.[rep], so - -- 10^(length rep) * s - s = rep, so - -- - -- s = rep/(10^(length rep) - 1). - -- - -- We also have to divide by 10^(length b) to shift it over - -- past any non-repeating prefix. +-- If s = 0.[rep] then 10^(length rep) * s = rep.[rep], so +-- 10^(length rep) * s - s = rep, so +-- +-- s = rep/(10^(length rep) - 1). +-- +-- We also have to divide by 10^(length b) to shift it over +-- past any non-repeating prefix. -- ~~~~ Note [Trailing period] -- @@ -337,18 +398,71 @@ reserved w = (lexeme . try) $ string w *> notFollowedBy alphaNumChar -- | The list of all reserved words. reservedWords :: [String] reservedWords = - [ "unit", "true", "false", "True", "False", "let", "in", "is" - , "if", "when" - , "otherwise", "and", "or", "mod", "choose", "implies", "iff" - , "min", "max" - , "union", "∪", "intersect", "∩", "subset", "⊆", "elem", "∈" - , "enumerate", "count", "divides" - , "Void", "Unit", "Bool", "Boolean", "Proposition", "Prop", "Char" - , "Nat", "Natural", "Int", "Integer", "Frac", "Fractional", "Rational", "Fin" - , "List", "Bag", "Set", "Graph", "Map" - , "N", "Z", "F", "Q", "ℕ", "ℤ", "𝔽", "ℚ" - , "∀", "forall", "∃", "exists", "type" - , "import", "using" + [ "unit" + , "true" + , "false" + , "True" + , "False" + , "let" + , "in" + , "is" + , "if" + , "when" + , "otherwise" + , "and" + , "or" + , "mod" + , "choose" + , "implies" + , "iff" + , "min" + , "max" + , "union" + , "∪" + , "intersect" + , "∩" + , "subset" + , "⊆" + , "elem" + , "∈" + , "enumerate" + , "count" + , "divides" + , "Void" + , "Unit" + , "Bool" + , "Boolean" + , "Proposition" + , "Prop" + , "Char" + , "Nat" + , "Natural" + , "Int" + , "Integer" + , "Frac" + , "Fractional" + , "Rational" + , "Fin" + , "List" + , "Bag" + , "Set" + , "Graph" + , "Map" + , "N" + , "Z" + , "F" + , "Q" + , "ℕ" + , "ℤ" + , "𝔽" + , "ℚ" + , "∀" + , "forall" + , "∃" + , "exists" + , "type" + , "import" + , "using" ] -- | Parse an identifier, i.e. any non-reserved string beginning with @@ -356,15 +470,15 @@ reservedWords = -- underscores, and apostrophes. identifier :: Parser Char -> Parser String identifier begin = (lexeme . try) (p >>= check) "variable name" - where - p = (:) <$> begin <*> many identChar - identChar = alphaNumChar <|> oneOf "_'" - check x - | x `elem` reservedWords = do - -- back up to beginning of bad token to report correct position - updateParserState (\s -> s { stateOffset = stateOffset s - length x }) - customFailure $ ReservedVarName x - | otherwise = return x + where + p = (:) <$> begin <*> many identChar + identChar = alphaNumChar <|> oneOf "_'" + check x + | x `elem` reservedWords = do + -- back up to beginning of bad token to report correct position + updateParserState (\s -> s {stateOffset = stateOffset s - length x}) + customFailure $ ReservedVarName x + | otherwise = return x -- | Parse an 'identifier' and turn it into a 'Name'. ident :: Parser (Name Term) @@ -376,7 +490,7 @@ ident = string2Name <$> identifier letterChar -- | Results from parsing a block of top-level things. data TLResults = TLResults { _tlDecls :: [Decl] - , _tlDocs :: [(Name Term, [DocThing])] + , _tlDocs :: [(Name Term, [DocThing])] , _tlTerms :: [Term] } @@ -397,117 +511,124 @@ wholeModule = between sc eof . parseModule -- REPL, and replace them when parsing a standalone module. parseModule :: LoadingMode -> Parser Module parseModule mode = do - exts <- S.fromList <$> many parseExtension + exts <- S.fromList <$> many parseExtension let extFun = case mode of Standalone -> withExts - REPL -> withAdditionalExts + REPL -> withAdditionalExts extFun exts $ do - imports <- many parseImport + imports <- many parseImport topLevel <- many parseTopLevel let theMod = mkModule exts imports topLevel return theMod - where - groupTLs :: [DocThing] -> [TopLevel] -> TLResults - groupTLs _ [] = emptyTLResults - groupTLs revDocs (TLDoc doc : rest) - = groupTLs (doc : revDocs) rest - groupTLs revDocs (TLDecl decl@(DType (TypeDecl x _)) : rest) - = groupTLs [] rest - & tlDecls %~ (decl :) - & tlDocs %~ ((x, reverse revDocs) :) - groupTLs revDocs (TLDecl decl@(DTyDef (TypeDefn x _ _)) : rest) - = groupTLs [] rest - & tlDecls %~ (decl :) - & tlDocs %~ ((string2Name x, reverse revDocs) :) - groupTLs _ (TLDecl defn : rest) - = groupTLs [] rest - & tlDecls %~ (defn :) - groupTLs _ (TLExpr t : rest) - = groupTLs [] rest & tlTerms %~ (t:) - - defnGroups :: [Decl] -> [Decl] - defnGroups [] = [] - defnGroups (d@DType{} : ds) = d : defnGroups ds - defnGroups (d@DTyDef{} : ds) = d : defnGroups ds - defnGroups (DDefn (TermDefn x bs) : ds) = DDefn (TermDefn x (bs ++ concatMap (\(TermDefn _ cs) -> cs) grp)) : defnGroups rest - where - (grp, rest) = matchDefn ds - matchDefn :: [Decl] -> ([TermDefn], [Decl]) - matchDefn (DDefn t@(TermDefn x' _) : ds2) | x == x' = (t:ts, ds2') - where - (ts, ds2') = matchDefn ds2 - matchDefn ds2 = ([], ds2) - - mkModule exts imps tls = Module exts imps (defnGroups decls) docs terms - where - TLResults decls docs terms = groupTLs [] tls + where + groupTLs :: [DocThing] -> [TopLevel] -> TLResults + groupTLs _ [] = emptyTLResults + groupTLs revDocs (TLDoc doc : rest) = + groupTLs (doc : revDocs) rest + groupTLs revDocs (TLDecl decl@(DType (TypeDecl x _)) : rest) = + groupTLs [] rest + & tlDecls %~ (decl :) + & tlDocs %~ ((x, reverse revDocs) :) + groupTLs revDocs (TLDecl decl@(DTyDef (TypeDefn x _ _)) : rest) = + groupTLs [] rest + & tlDecls %~ (decl :) + & tlDocs %~ ((string2Name x, reverse revDocs) :) + groupTLs _ (TLDecl defn : rest) = + groupTLs [] rest + & tlDecls %~ (defn :) + groupTLs _ (TLExpr t : rest) = + groupTLs [] rest & tlTerms %~ (t :) + + defnGroups :: [Decl] -> [Decl] + defnGroups [] = [] + defnGroups (d@DType {} : ds) = d : defnGroups ds + defnGroups (d@DTyDef {} : ds) = d : defnGroups ds + defnGroups (DDefn (TermDefn x bs) : ds) = DDefn (TermDefn x (bs ++ concatMap (\(TermDefn _ cs) -> cs) grp)) : defnGroups rest + where + (grp, rest) = matchDefn ds + matchDefn :: [Decl] -> ([TermDefn], [Decl]) + matchDefn (DDefn t@(TermDefn x' _) : ds2) | x == x' = (t : ts, ds2') + where + (ts, ds2') = matchDefn ds2 + matchDefn ds2 = ([], ds2) + + mkModule exts imps tls = Module exts imps (defnGroups decls) docs terms + where + TLResults decls docs terms = groupTLs [] tls -- | Parse an extension. parseExtension :: Parser Ext -parseExtension = L.nonIndented sc $ - reserved "using" *> parseExtName +parseExtension = + L.nonIndented sc $ + reserved "using" *> parseExtName -- | Parse the name of a language extension (case-insensitive). parseExtName :: Parser Ext parseExtName = choice (map parseOneExt allExtsList) "language extension name" - where - parseOneExt ext = ext <$ lexeme (string' (show ext) :: Parser String) + where + parseOneExt ext = ext <$ lexeme (string' (show ext) :: Parser String) -- | Parse an import, of the form @import @. parseImport :: Parser String -parseImport = L.nonIndented sc $ - reserved "import" *> parseModuleName +parseImport = + L.nonIndented sc $ + reserved "import" *> parseModuleName -- | Parse the name of a module. parseModuleName :: Parser String -parseModuleName = lexeme $ - intercalate "/" <$> (some (alphaNumChar <|> oneOf "_-") `sepBy` char '/') <* optional (string ".disco") +parseModuleName = + lexeme $ + intercalate "/" <$> (some (alphaNumChar <|> oneOf "_-") `sepBy` char '/') <* optional (string ".disco") -- | Parse a top level item (either documentation or a declaration), -- which must start at the left margin. parseTopLevel :: Parser TopLevel -parseTopLevel = L.nonIndented sc $ - TLDoc <$> parseDocThing - <|> TLDecl <$> parseDecl -- See Note [Parsing definitions and top-level expressions] - <|> TLExpr <$> thenIndented parseTerm - - -- ~~~~ Note [Parsing definitions and top-level expressions] - -- - -- The beginning of a definition might look the same as an - -- expression. e.g. is f(x,y) the start of a definition of f, or an - -- expression with a function call? We used to therefore wrap - -- 'parseDecl' in 'try'. The problem is that if a definition has a - -- syntax error on the RHS, it would fail, backtrack, then try - -- parsing a top-level expression and fail when it got to the = - -- sign, giving an uninformative parse error message. - -- See https://github.com/disco-lang/disco/issues/346. - -- - -- The solution is that we now do more careful backtracking within - -- parseDecl itself: when parsing a definition, we only backtrack if - -- we don't get a complete LHS + '=' sign; once we start parsing the - -- RHS of a definition we no longer backtrack, since it can't - -- possibly be a valid top-level expression. +parseTopLevel = + L.nonIndented sc $ + TLDoc <$> parseDocThing + <|> TLDecl <$> parseDecl -- See Note [Parsing definitions and top-level expressions] + <|> TLExpr <$> thenIndented parseTerm + +-- ~~~~ Note [Parsing definitions and top-level expressions] +-- +-- The beginning of a definition might look the same as an +-- expression. e.g. is f(x,y) the start of a definition of f, or an +-- expression with a function call? We used to therefore wrap +-- 'parseDecl' in 'try'. The problem is that if a definition has a +-- syntax error on the RHS, it would fail, backtrack, then try +-- parsing a top-level expression and fail when it got to the = +-- sign, giving an uninformative parse error message. +-- See https://github.com/disco-lang/disco/issues/346. +-- +-- The solution is that we now do more careful backtracking within +-- parseDecl itself: when parsing a definition, we only backtrack if +-- we don't get a complete LHS + '=' sign; once we start parsing the +-- RHS of a definition we no longer backtrack, since it can't +-- possibly be a valid top-level expression. -- | Parse a documentation item: either a group of lines beginning -- with @|||@ (text documentation), or a group beginning with @!!!@ -- (checked examples/properties). parseDocThing :: Parser DocThing -parseDocThing - = DocString <$> some parseDocString - <|> DocProperty <$> parseProperty +parseDocThing = + DocString <$> some parseDocString + <|> DocProperty <$> parseProperty -- | Parse one line of documentation beginning with @|||@. parseDocString :: Parser String -parseDocString = label "documentation" $ L.nonIndented sc $ - string "|||" - *> takeWhileP Nothing (`elem` " \t") - *> takeWhileP Nothing (`notElem` "\r\n") <* sc - - -- Note we use string "|||" rather than symbol "|||" because we - -- don't want it to consume whitespace afterwards (in particular a - -- line with ||| by itself would cause symbol "|||" to consume the - -- newline). +parseDocString = + label "documentation" $ + L.nonIndented sc $ + string "|||" + *> takeWhileP Nothing (`elem` " \t") + *> takeWhileP Nothing (`notElem` "\r\n") + <* sc + +-- Note we use string "|||" rather than symbol "|||" because we +-- don't want it to consume whitespace afterwards (in particular a +-- line with ||| by itself would cause symbol "|||" to consume the +-- newline). -- | Parse a top-level property/unit test, which is just @!!!@ -- followed by an arbitrary term. @@ -523,20 +644,21 @@ parseDecl = try (DType <$> parseTyDecl) <|> DDefn <$> parseDefn <|> DTyDef <$> p -- | Parse a top-level type declaration of the form @x : ty@. parseTyDecl :: Parser TypeDecl -parseTyDecl = label "type declaration" $ - TypeDecl <$> ident <*> indented (colon *> parsePolyTy) +parseTyDecl = + label "type declaration" $ + TypeDecl <$> ident <*> indented (colon *> parsePolyTy) -- | Parse a definition of the form @x pat1 .. patn = t@. parseDefn :: Parser TermDefn -parseDefn = label "definition" $ - (\(x, ps) body -> TermDefn x [bind ps body]) - - -- Only backtrack if we don't get a complete 'LHS ='. Once we see - -- an = sign, commit to parsing a definition, because it can't be a - -- valid standalone expression anymore. If the RHS fails, we don't - -- want to backtrack, we just want to display the parse error. - <$> try ((,) <$> ident <*> indented (many parseAtomicPattern) <* reservedOp "=") - <*> indented parseTerm +parseDefn = + label "definition" $ + (\(x, ps) body -> TermDefn x [bind ps body]) + -- Only backtrack if we don't get a complete 'LHS ='. Once we see + -- an = sign, commit to parsing a definition, because it can't be a + -- valid standalone expression anymore. If the RHS fails, we don't + -- want to backtrack, we just want to display the parse error. + <$> try ((,) <$> ident <*> indented (many parseAtomicPattern) <* reservedOp "=") + <*> indented parseTerm -- | Parse the definition of a user-defined algebraic data type. parseTyDefn :: Parser TypeDefn @@ -556,50 +678,52 @@ term = between sc eof parseTerm -- | Parse a term, consisting of a @parseTerm'@ optionally -- followed by an ascription. parseTerm :: Parser Term -parseTerm = -- trace "parseTerm" $ +parseTerm = + -- trace "parseTerm" $ ascribe <$> parseTerm' <*> optional (label "type annotation" $ colon *> parsePolyTy) - where - ascribe t Nothing = t - ascribe t (Just ty) = TAscr t ty + where + ascribe t Nothing = t + ascribe t (Just ty) = TAscr t ty -- | Parse a non-atomic, non-ascribed term. parseTerm' :: Parser Term -parseTerm' = label "expression" $ - parseQuantified - <|> parseLet - <|> parseExpr - <|> parseAtom +parseTerm' = + label "expression" $ + parseQuantified + <|> parseLet + <|> parseExpr + <|> parseAtom -- | Parse an atomic term. parseAtom :: Parser Term -parseAtom = label "expression" $ - parseUnit - <|> TBool True <$ (reserved "true" <|> reserved "True") - <|> TBool False <$ (reserved "false" <|> reserved "False") - <|> TChar <$> lexeme (between (char '\'') (char '\'') L.charLiteral) - <|> TString <$> lexeme (char '"' >> manyTill L.charLiteral (char '"')) - <|> TWild <$ try parseWild - <|> TPrim <$> try parseStandaloneOp - - -- Note primitives are NOT reserved words, so they are just parsed - -- as identifiers. This means that it is possible to shadow a - -- primitive in a local context, as it should be. Vars are turned - -- into prims at scope-checking time: if a var is not in scope but - -- there is a prim of that name then it becomes a TPrim. See the - -- 'typecheck Infer (TVar x)' case in Disco.Typecheck. - <|> TVar <$> ident - <|> TPrim <$> (ensureEnabled Primitives *> parsePrim) - <|> TRat <$> try decimal - <|> TNat <$> natural - <|> parseTypeOp - <|> TApp (TPrim PrimFloor) . TParens <$> fbrack parseTerm - <|> TApp (TPrim PrimCeil) . TParens <$> cbrack parseTerm - <|> parseCase - <|> try parseAbs - <|> bagdelims (parseContainer BagContainer) - <|> braces (parseContainer SetContainer) - <|> brackets (parseContainer ListContainer) - <|> tuple <$> parens (parseTerm `sepBy1` comma) +parseAtom = + label "expression" $ + parseUnit + <|> TBool True <$ (reserved "true" <|> reserved "True") + <|> TBool False <$ (reserved "false" <|> reserved "False") + <|> TChar <$> lexeme (between (char '\'') (char '\'') L.charLiteral) + <|> TString <$> lexeme (char '"' >> manyTill L.charLiteral (char '"')) + <|> TWild <$ try parseWild + <|> TPrim <$> try parseStandaloneOp + -- Note primitives are NOT reserved words, so they are just parsed + -- as identifiers. This means that it is possible to shadow a + -- primitive in a local context, as it should be. Vars are turned + -- into prims at scope-checking time: if a var is not in scope but + -- there is a prim of that name then it becomes a TPrim. See the + -- 'typecheck Infer (TVar x)' case in Disco.Typecheck. + <|> TVar <$> ident + <|> TPrim <$> (ensureEnabled Primitives *> parsePrim) + <|> TRat <$> try decimal + <|> TNat <$> natural + <|> parseTypeOp + <|> TApp (TPrim PrimFloor) . TParens <$> fbrack parseTerm + <|> TApp (TPrim PrimCeil) . TParens <$> cbrack parseTerm + <|> parseCase + <|> try parseAbs + <|> bagdelims (parseContainer BagContainer) + <|> braces (parseContainer SetContainer) + <|> brackets (parseContainer ListContainer) + <|> tuple <$> parens (parseTerm `sepBy1` comma) parseAbs :: Parser Term parseAbs = TApp (TPrim PrimAbs) <$> (pipe *> parseTerm <* pipe) @@ -610,42 +734,43 @@ parseUnit = TUnit <$ (reserved "unit" <|> void (symbol "■")) -- | Parse a wildcard, which is an underscore that isn't the start of -- an identifier. parseWild :: Parser () -parseWild = (lexeme . try . void) $ - string "_" <* notFollowedBy (alphaNumChar <|> oneOf "_'") +parseWild = + (lexeme . try . void) $ + string "_" <* notFollowedBy (alphaNumChar <|> oneOf "_'") -- | Parse a standalone operator name with tildes indicating argument -- slots, e.g. ~+~ for the addition operator. parseStandaloneOp :: Parser Prim parseStandaloneOp = asum $ concatMap mkStandaloneOpParsers (concat opTable) - where - mkStandaloneOpParsers :: OpInfo -> [Parser Prim] - mkStandaloneOpParsers (OpInfo (UOpF Pre uop) syns _) - = map (\syn -> PrimUOp uop <$ try (lexeme (string syn >> char '~'))) syns - mkStandaloneOpParsers (OpInfo (UOpF Post uop) syns _) - = map (\syn -> PrimUOp uop <$ try (lexeme (char '~' >> string syn))) syns - mkStandaloneOpParsers (OpInfo (BOpF _ bop) syns _) - = map (\syn -> PrimBOp bop <$ try (lexeme (char '~' >> string syn >> char '~'))) syns - - -- XXX TODO: improve the above so it first tries to parse a ~, - -- then parses any postfix or infix thing; or else it looks for - -- a prefix thing followed by a ~. This will get rid of the - -- need for 'try' and also potentially improve error messages. - -- The below may come in useful. - - -- flatOpTable = concat opTable - - -- prefixOps = [ (uop, syns) | (OpInfo (UOpF Pre uop) syns _) <- flatOpTable ] - -- postfixOps = [ (uop, syns) | (OpInfo (UOpF Post uop) syns _) <- flatOpTable ] - -- infixOps = [ (bop, syns) | (OpInfo (BOpF _ bop) syns _) <- flatOpTable ] + where + mkStandaloneOpParsers :: OpInfo -> [Parser Prim] + mkStandaloneOpParsers (OpInfo (UOpF Pre uop) syns _) = + map (\syn -> PrimUOp uop <$ try (lexeme (string syn >> char '~'))) syns + mkStandaloneOpParsers (OpInfo (UOpF Post uop) syns _) = + map (\syn -> PrimUOp uop <$ try (lexeme (char '~' >> string syn))) syns + mkStandaloneOpParsers (OpInfo (BOpF _ bop) syns _) = + map (\syn -> PrimBOp bop <$ try (lexeme (char '~' >> string syn >> char '~'))) syns + +-- XXX TODO: improve the above so it first tries to parse a ~, +-- then parses any postfix or infix thing; or else it looks for +-- a prefix thing followed by a ~. This will get rid of the +-- need for 'try' and also potentially improve error messages. +-- The below may come in useful. + +-- flatOpTable = concat opTable + +-- prefixOps = [ (uop, syns) | (OpInfo (UOpF Pre uop) syns _) <- flatOpTable ] +-- postfixOps = [ (uop, syns) | (OpInfo (UOpF Post uop) syns _) <- flatOpTable ] +-- infixOps = [ (bop, syns) | (OpInfo (BOpF _ bop) syns _) <- flatOpTable ] -- | Parse a primitive name starting with a $. parsePrim :: Parser Prim parsePrim = do void (char '$') x <- identifier letterChar - case find ((==x) . primSyntax) primTable of + case find ((== x) . primSyntax) primTable of Just (PrimInfo p _ _) -> return p - Nothing -> fail ("Unrecognized primitive $" ++ x) + Nothing -> fail ("Unrecognized primitive $" ++ x) -- | Parse a container, like a literal list, set, bag, or a -- comprehension (not including the square or curly brackets). @@ -668,45 +793,44 @@ parsePrim = do -- -- ::= [ ',' ] '..' [ ',' ] -- @ - parseContainer :: Container -> Parser Term parseContainer c = nonEmptyContainer <|> return (TContainer c [] Nothing) + where -- Careful to do this without backtracking, since backtracking can -- lead to bad performance in certain pathological cases (for -- example, a very deeply nested list). - where - -- Any non-empty container starts with a term, followed by some - -- remainder (which could either be the rest of a literal - -- container, or a container comprehension). If there is no - -- remainder just return a "singleton" container (which could - -- include a trailing ellipsis + final term). - nonEmptyContainer = parseRepTerm >>= containerRemainder - - parseRepTerm = do - t <- parseTerm - n <- optional $ do - guard (c == BagContainer) - void hash - parseTerm - return (t, n) - - -- The remainder of a container after the first term starts with - -- either a pipe (for a comprehension) or a comma (for a literal - -- container). - containerRemainder :: (Term, Maybe Term) -> Parser Term - containerRemainder (t,n) = - (guard (isNothing n) *> parseContainerComp c t) <|> parseLitContainerRemainder t n - - parseLitContainerRemainder :: Term -> Maybe Term -> Parser Term - parseLitContainerRemainder t n = do - -- Wrapping the (',' term) production in 'try' is important: if - -- it consumes a comma but then fails when parsing a term, we - -- want to be able to backtrack so we can potentially parse an - -- ellipsis beginning with a comma. - ts <- many (try (comma *> parseRepTerm)) - e <- optional parseEllipsis - return $ TContainer c ((t,n):ts) e + -- Any non-empty container starts with a term, followed by some + -- remainder (which could either be the rest of a literal + -- container, or a container comprehension). If there is no + -- remainder just return a "singleton" container (which could + -- include a trailing ellipsis + final term). + nonEmptyContainer = parseRepTerm >>= containerRemainder + + parseRepTerm = do + t <- parseTerm + n <- optional $ do + guard (c == BagContainer) + void hash + parseTerm + return (t, n) + + -- The remainder of a container after the first term starts with + -- either a pipe (for a comprehension) or a comma (for a literal + -- container). + containerRemainder :: (Term, Maybe Term) -> Parser Term + containerRemainder (t, n) = + (guard (isNothing n) *> parseContainerComp c t) <|> parseLitContainerRemainder t n + + parseLitContainerRemainder :: Term -> Maybe Term -> Parser Term + parseLitContainerRemainder t n = do + -- Wrapping the (',' term) production in 'try' is important: if + -- it consumes a comma but then fails when parsing a term, we + -- want to be able to backtrack so we can potentially parse an + -- ellipsis beginning with a comma. + ts <- many (try (comma *> parseRepTerm)) + e <- optional parseEllipsis + return $ TContainer c ((t, n) : ts) e -- | Parse an ellipsis at the end of a literal list, of the form -- @.. t@. Any number > 1 of dots may be used, just for fun. @@ -727,12 +851,14 @@ parseContainerComp c t = do -- a guard @t@. parseQual :: Parser Qual parseQual = parseSelection <|> parseQualGuard - where - parseSelection = label "membership expression (x in ...)" $ + where + parseSelection = + label "membership expression (x in ...)" $ QBind <$> try (ident <* selector) <*> (embed <$> parseTerm) - selector = reservedOp "<-" <|> reserved "in" + selector = reservedOp "<-" <|> reserved "in" - parseQualGuard = label "boolean expression" $ + parseQualGuard = + label "boolean expression" $ QGuard . embed <$> parseTerm -- | Turn a parenthesized list of zero or more terms into the @@ -742,53 +868,57 @@ parseQual = parseSelection <|> parseQualGuard -- terms @(t1,t2,...)@ are a tuple. tuple :: [Term] -> Term tuple [x] = TParens x -tuple t = TTup t +tuple t = TTup t -- | Parse a quantified abstraction (λ, ∀, ∃). parseQuantified :: Parser Term parseQuantified = do q <- parseQuantifier TAbs q <$> (bind <$> parseArgs (q /= Lam) <*> (dot *> parseTerm)) - where - parseArgs notLam = (parsePattern notLam `sepBy1` comma) >>= checkMulti - -- ∀ and ∃ can have multiple bindings separated by commas, - -- like ∀ x:N, y:N. ... but we don't allow this for λ. - where - checkMulti :: [Pattern] -> Parser [Pattern] - checkMulti ps - | notLam = return ps - | otherwise = case ps of - [p] -> return [p] - _ -> customFailure MultiArgLambda + where + parseArgs notLam = (parsePattern notLam `sepBy1` comma) >>= checkMulti + where + -- ∀ and ∃ can have multiple bindings separated by commas, + -- like ∀ x:N, y:N. ... but we don't allow this for λ. + + checkMulti :: [Pattern] -> Parser [Pattern] + checkMulti ps + | notLam = return ps + | otherwise = case ps of + [p] -> return [p] + _ -> customFailure MultiArgLambda -- | Parse a quantifier symbol (lambda, forall, or exists). parseQuantifier :: Parser Quantifier parseQuantifier = - Lam <$ lambda - <|> All <$ forall - <|> Ex <$ exists + Lam <$ lambda + <|> All <$ forall + <|> Ex <$ exists -- | Parse a let expression (@let x1 = t1, x2 = t2, ... in t@). parseLet :: Parser Term parseLet = - TLet <$> - (reserved "let" *> - (bind - <$> (toTelescope <$> (parseBinding `sepBy` comma)) - <*> (reserved "in" *> parseTerm))) + TLet + <$> ( reserved "let" + *> ( bind + <$> (toTelescope <$> (parseBinding `sepBy` comma)) + <*> (reserved "in" *> parseTerm) + ) + ) -- | Parse a single binding (@x [ : ty ] = t@). parseBinding :: Parser Binding parseBinding = do - x <- ident + x <- ident mty <- optional (colon *> parsePolyTy) - t <- symbol "=" *> (embed <$> parseTerm) + t <- symbol "=" *> (embed <$> parseTerm) return $ Binding (embed <$> mty) x t -- | Parse a case expression. parseCase :: Parser Term -parseCase = between (symbol "{?") (symbol "?}") $ - TCase <$> parseBranch `sepBy` comma +parseCase = + between (symbol "{?") (symbol "?}") $ + TCase <$> parseBranch `sepBy` comma -- | Parse one branch of a case expression. parseBranch :: Parser Branch @@ -802,15 +932,15 @@ parseGuards = (TelEmpty <$ reserved "otherwise") <|> (toTelescope <$> many parse -- | Parse a single guard (@if@, @if ... is ...@, or @let@) parseGuard :: Parser Guard parseGuard = parseGCond <|> parseGLet - where - guardWord = reserved "if" <|> reserved "when" - parseGCond = do - guardWord - t <- parseTerm - parseGPat t <|> parseGBool t - parseGPat t = GPat (embed t) <$> (reserved "is" *> parsePattern False) - parseGBool t = pure $ GBool (embed t) - parseGLet = GLet <$> (reserved "let" *> parseBinding) + where + guardWord = reserved "if" <|> reserved "when" + parseGCond = do + guardWord + t <- parseTerm + parseGPat t <|> parseGBool t + parseGPat t = GPat (embed t) <$> (reserved "is" *> parsePattern False) + parseGBool t = pure $ GBool (embed t) + parseGLet = GLet <$> (reserved "let" *> parseBinding) -- | Parse an atomic pattern, by parsing a term and then attempting to -- convert it to a pattern. @@ -819,7 +949,7 @@ parseAtomicPattern = label "pattern" $ do t <- parseAtom case termToPattern t of Nothing -> customFailure $ InvalidPattern (OT t) - Just p -> return $ maybe p (PNonlinear p) (findDuplicatePVar p) + Just p -> return $ maybe p (PNonlinear p) (findDuplicatePVar p) -- | Parse a pattern, by parsing a term and then attempting to convert -- it to a pattern. The Bool parameter says whether to require @@ -837,261 +967,250 @@ parsePattern requireAscr = label "pattern" $ do -- a tuple with each component recursively having ascriptions? -- This is required for patterns bound by ∀ and ∃ quantifiers. hasAscr :: Pattern -> Bool -hasAscr PAscr{} = True +hasAscr PAscr {} = True hasAscr (PTup ps) = all hasAscr ps -hasAscr _ = False +hasAscr _ = False -- | Lazy monadic variant of find. findM :: Monad m => (a -> m (Maybe b)) -> [a] -> m (Maybe b) findM _ [] = return Nothing -findM p (a:as) = do +findM p (a : as) = do b <- p a case b of Just x -> return $ Just x - _ -> findM p as + _ -> findM p as -- | Does a pattern have the same variable repeated more than once? findDuplicatePVar :: Pattern -> Maybe (Name Term) findDuplicatePVar = flip evalState S.empty . go - where - go :: Pattern -> State (Set String) (Maybe (Name Term)) - go (PVar x) = do - let xName = name2String x - seen <- gets (S.member xName) - if seen - then return (Just x) - else do - modify (S.insert xName) - return Nothing - go (PAscr p _) = go p - go (PTup ps) = findM go ps - go (PInj _ p) = go p - go (PCons p1 p2) = findM go [p1,p2] - go (PList ps) = findM go ps - go (PAdd _ p _) = go p - go (PMul _ p _) = go p - go (PSub p _) = go p - go (PNeg p) = go p - go (PFrac p1 p2) = findM go [p1,p2] - go _ = return Nothing + where + go :: Pattern -> State (Set String) (Maybe (Name Term)) + go (PVar x) = do + let xName = name2String x + seen <- gets (S.member xName) + if seen + then return (Just x) + else do + modify (S.insert xName) + return Nothing + go (PAscr p _) = go p + go (PTup ps) = findM go ps + go (PInj _ p) = go p + go (PCons p1 p2) = findM go [p1, p2] + go (PList ps) = findM go ps + go (PAdd _ p _) = go p + go (PMul _ p _) = go p + go (PSub p _) = go p + go (PNeg p) = go p + go (PFrac p1 p2) = findM go [p1, p2] + go _ = return Nothing -- | Attempt converting a term to a pattern. termToPattern :: Term -> Maybe Pattern -termToPattern TWild = Just PWild -termToPattern (TVar x) = Just $ PVar x +termToPattern TWild = Just PWild +termToPattern (TVar x) = Just $ PVar x termToPattern (TParens t) = termToPattern t -termToPattern TUnit = Just PUnit -termToPattern (TBool b) = Just $ PBool b -termToPattern (TNat n) = Just $ PNat n -termToPattern (TChar c) = Just $ PChar c +termToPattern TUnit = Just PUnit +termToPattern (TBool b) = Just $ PBool b +termToPattern (TNat n) = Just $ PNat n +termToPattern (TChar c) = Just $ PChar c termToPattern (TString s) = Just $ PString s -termToPattern (TTup ts) = PTup <$> mapM termToPattern ts +termToPattern (TTup ts) = PTup <$> mapM termToPattern ts termToPattern (TApp (TVar i) t) - | i == string2Name "left" = PInj L <$> termToPattern t + | i == string2Name "left" = PInj L <$> termToPattern t | i == string2Name "right" = PInj R <$> termToPattern t -- termToPattern (TInj s t) = PInj s <$> termToPattern t termToPattern (TAscr t s) = case s of Forall (unsafeUnbind -> ([], s')) -> PAscr <$> termToPattern t <*> pure s' - _ -> Nothing - -termToPattern (TBin Cons t1 t2) - = PCons <$> termToPattern t1 <*> termToPattern t2 - -termToPattern (TBin Add t1 t2) - = case (termToPattern t1, termToPattern t2) of - (Just p, _) - | length (toListOf fvAny p) == 1 - && null (toListOf fvAny t2) - -> Just $ PAdd L p t2 - (_, Just p) - | length (toListOf fvAny p) == 1 - && null (toListOf fvAny t1) - -> Just $ PAdd R p t1 - _ -> Nothing - -- If t1 is a pattern binding one variable, and t2 has no fvs, - -- this can be a PAdd L. Also vice versa for PAdd R. - -termToPattern (TBin Mul t1 t2) - = case (termToPattern t1, termToPattern t2) of - (Just p, _) - | length (toListOf fvAny p) == 1 - && null (toListOf fvAny t2) - -> Just $ PMul L p t2 - (_, Just p) - | length (toListOf fvAny p) == 1 - && null (toListOf fvAny t1) - -> Just $ PMul R p t1 - _ -> Nothing - -- If t1 is a pattern binding one variable, and t2 has no fvs, - -- this can be a PMul L. Also vice versa for PMul R. - -termToPattern (TBin Sub t1 t2) - = case termToPattern t1 of - Just p - | length (toListOf fvAny p) == 1 - && null (toListOf fvAny t2) - -> Just $ PSub p t2 - _ -> Nothing - -- If t1 is a pattern binding one variable, and t2 has no fvs, - -- this can be a PSub. - - -- For now we don't handle the case of t - p, since it seems - -- less useful (and desugaring it would require extra code since - -- subtraction is not commutative). - -termToPattern (TBin Div t1 t2) - = PFrac <$> termToPattern t1 <*> termToPattern t2 - + _ -> Nothing +termToPattern (TBin Cons t1 t2) = + PCons <$> termToPattern t1 <*> termToPattern t2 +termToPattern (TBin Add t1 t2) = + case (termToPattern t1, termToPattern t2) of + (Just p, _) + | length (toListOf fvAny p) == 1 + && null (toListOf fvAny t2) -> + Just $ PAdd L p t2 + (_, Just p) + | length (toListOf fvAny p) == 1 + && null (toListOf fvAny t1) -> + Just $ PAdd R p t1 + _ -> Nothing +-- If t1 is a pattern binding one variable, and t2 has no fvs, +-- this can be a PAdd L. Also vice versa for PAdd R. + +termToPattern (TBin Mul t1 t2) = + case (termToPattern t1, termToPattern t2) of + (Just p, _) + | length (toListOf fvAny p) == 1 + && null (toListOf fvAny t2) -> + Just $ PMul L p t2 + (_, Just p) + | length (toListOf fvAny p) == 1 + && null (toListOf fvAny t1) -> + Just $ PMul R p t1 + _ -> Nothing +-- If t1 is a pattern binding one variable, and t2 has no fvs, +-- this can be a PMul L. Also vice versa for PMul R. + +termToPattern (TBin Sub t1 t2) = + case termToPattern t1 of + Just p + | length (toListOf fvAny p) == 1 + && null (toListOf fvAny t2) -> + Just $ PSub p t2 + _ -> Nothing +-- If t1 is a pattern binding one variable, and t2 has no fvs, +-- this can be a PSub. + +-- For now we don't handle the case of t - p, since it seems +-- less useful (and desugaring it would require extra code since +-- subtraction is not commutative). + +termToPattern (TBin Div t1 t2) = + PFrac <$> termToPattern t1 <*> termToPattern t2 termToPattern (TUn Neg t) = PNeg <$> termToPattern t - -termToPattern (TContainer ListContainer ts Nothing) - = PList <$> mapM (termToPattern . fst) ts - -termToPattern _ = Nothing +termToPattern (TContainer ListContainer ts Nothing) = + PList <$> mapM (termToPattern . fst) ts +termToPattern _ = Nothing -- | Parse an expression built out of unary and binary operators. parseExpr :: Parser Term parseExpr = fixJuxtMul . fixChains <$> (makeExprParser parseAtom table "expression") - where - table - -- Special case for function application, with highest - -- precedence. Note that we parse all juxtaposition as - -- function application first; we later go through and turn - -- some into multiplication (fixing up the precedence - -- appropriately) based on a syntactic analysis. - = [ InfixL (TApp <$ string "") ] - - -- get all other operators from the opTable + where + table = + -- Special case for function application, with highest + -- precedence. Note that we parse all juxtaposition as + -- function application first; we later go through and turn + -- some into multiplication (fixing up the precedence + -- appropriately) based on a syntactic analysis. + [InfixL (TApp <$ string "")] + -- get all other operators from the opTable : (map . concatMap) mkOpParser opTable - mkOpParser :: OpInfo -> [Operator Parser Term] - mkOpParser (OpInfo op syns _) = concatMap (withOpFixity op) syns - - -- Only parse unary operators consisting of operator symbols. - -- Alphabetic unary operators (i.e. 'not') will be parsed as - -- applications of variable names, since if they are parsed here - -- they will incorrectly parse even when they are a prefix of a - -- variable name. - withOpFixity (UOpF fx op) syn - | any isAlpha syn = [] - | otherwise = [ufxParser fx ((reservedOp syn "operator") >> return (TUn op))] - - withOpFixity (BOpF fx op) syn - = [bfxParser fx ((reservedOp syn "operator") >> return (TBin op))] - - ufxParser Pre = Prefix - ufxParser Post = Postfix - - bfxParser InL = InfixL - bfxParser InR = InfixR - bfxParser In = InfixN - - isChainable op = op `elem` [Eq, Neq, Lt, Gt, Leq, Geq, Divides] - - -- Comparison chains like 3 < x < 5 first get parsed as 3 < (x < - -- 5), which does not make sense. This function looks for such - -- nested comparison operators and turns them into a TChain. - fixChains (TUn op t) = TUn op (fixChains t) - fixChains (TBin op t1 (TBin op' t21 t22)) - | isChainable op && isChainable op' = TChain t1 (TLink op t21 : getLinks op' t22) - fixChains (TBin op t1 t2) = TBin op (fixChains t1) (fixChains t2) - fixChains (TApp t1 t2) = TApp (fixChains t1) (fixChains t2) - - -- Only recurse as long as we see TUn, TBin, or TApp which could - -- have been generated by the expression parser. If we see - -- anything else we can stop. - fixChains e = e - - getLinks op (TBin op' t1 t2) - | isChainable op' = TLink op t1 : getLinks op' t2 - getLinks op e = [TLink op (fixChains e)] - - -- Find juxtapositions (parsed as function application) which - -- syntactically have either a literal Nat or a parenthesized - -- expression containing an operator as the LHS, and turn them - -- into multiplications. Then fix up the parse tree by rotating - -- newly created multiplications up until their precedence is - -- higher than the thing above them. - - fixJuxtMul :: Term -> Term - - -- Just recurse through TUn or TBin and fix precedence on the way back up. - fixJuxtMul (TUn op t) = fixPrec $ TUn op (fixJuxtMul t) - fixJuxtMul (TBin op t1 t2) = fixPrec $ TBin op (fixJuxtMul t1) (fixJuxtMul t2) - - -- Possibly turn a TApp into a multiplication, if the LHS looks - -- like a multiplicative term. However, we must be sure to - -- *first* recursively fix the subterms (particularly the - -- left-hand one) *before* doing this analysis. See - -- . - fixJuxtMul (TApp t1 t2) - | isMultiplicativeTerm t1' = fixPrec $ TBin Mul t1' t2' - | otherwise = fixPrec $ TApp t1' t2' - where - t1' = fixJuxtMul t1 - t2' = fixJuxtMul t2 - - -- Otherwise we can stop recursing, since anything other than TUn, - -- TBin, or TApp could not have been produced by the expression - -- parser. - fixJuxtMul t = t - - -- A multiplicative term is one that looks like either a natural - -- number literal, or a unary or binary operation (optionally - -- parenthesized). For example, 3, (-2), and (x + 5) are all - -- multiplicative terms, so 3x, (-2)x, and (x + 5)x all get parsed - -- as multiplication. On the other hand, (x y) is always parsed - -- as function application, even if x and y both turn out to have - -- numeric types; a variable like x does not count as a - -- multiplicative term. Likewise, (x y) z is parsed as function - -- application, since (x y) is not a multiplicative term: it is - -- parenthezised, but contains a TApp rather than a TBin or TUn. - isMultiplicativeTerm :: Term -> Bool - isMultiplicativeTerm (TNat _) = True - isMultiplicativeTerm TUn{} = True - isMultiplicativeTerm TBin{} = True - isMultiplicativeTerm (TParens t) = isMultiplicativeTerm t - isMultiplicativeTerm _ = False - - -- Fix precedence by bubbling up any new TBin terms whose - -- precedence is less than that of the operator above them. We - -- don't worry at all about fixing associativity, just precedence. - - fixPrec :: Term -> Term - - -- e.g. 2y! --> (2@y)! --> fixup --> 2 * (y!) - fixPrec (TUn uop (TBin bop t1 t2)) - | bPrec bop < uPrec uop = case uopMap M.! uop of - OpInfo (UOpF Pre _) _ _ -> TBin bop (TUn uop t1) t2 - OpInfo (UOpF Post _) _ _ -> TBin bop t1 (TUn uop t2) - _ -> error "Impossible! In fixPrec, uopMap contained OpInfo (BOpF ...)" - - fixPrec (TBin bop1 (TBin bop2 t1 t2) t3) - | bPrec bop2 < bPrec bop1 = TBin bop2 t1 (fixPrec $ TBin bop1 t2 t3) - - -- e.g. x^2y --> x^(2@y) --> x^(2*y) --> (x^2) * y - fixPrec (TBin bop1 t1 (TBin bop2 t2 t3)) - | bPrec bop2 < bPrec bop1 = TBin bop2 (fixPrec $ TBin bop1 t1 t2) t3 - - fixPrec t = t + mkOpParser :: OpInfo -> [Operator Parser Term] + mkOpParser (OpInfo op syns _) = concatMap (withOpFixity op) syns + + -- Only parse unary operators consisting of operator symbols. + -- Alphabetic unary operators (i.e. 'not') will be parsed as + -- applications of variable names, since if they are parsed here + -- they will incorrectly parse even when they are a prefix of a + -- variable name. + withOpFixity (UOpF fx op) syn + | any isAlpha syn = [] + | otherwise = [ufxParser fx ((reservedOp syn "operator") >> return (TUn op))] + withOpFixity (BOpF fx op) syn = + [bfxParser fx ((reservedOp syn "operator") >> return (TBin op))] + + ufxParser Pre = Prefix + ufxParser Post = Postfix + + bfxParser InL = InfixL + bfxParser InR = InfixR + bfxParser In = InfixN + + isChainable op = op `elem` [Eq, Neq, Lt, Gt, Leq, Geq, Divides] + + -- Comparison chains like 3 < x < 5 first get parsed as 3 < (x < + -- 5), which does not make sense. This function looks for such + -- nested comparison operators and turns them into a TChain. + fixChains (TUn op t) = TUn op (fixChains t) + fixChains (TBin op t1 (TBin op' t21 t22)) + | isChainable op && isChainable op' = TChain t1 (TLink op t21 : getLinks op' t22) + fixChains (TBin op t1 t2) = TBin op (fixChains t1) (fixChains t2) + fixChains (TApp t1 t2) = TApp (fixChains t1) (fixChains t2) + -- Only recurse as long as we see TUn, TBin, or TApp which could + -- have been generated by the expression parser. If we see + -- anything else we can stop. + fixChains e = e + + getLinks op (TBin op' t1 t2) + | isChainable op' = TLink op t1 : getLinks op' t2 + getLinks op e = [TLink op (fixChains e)] + + -- Find juxtapositions (parsed as function application) which + -- syntactically have either a literal Nat or a parenthesized + -- expression containing an operator as the LHS, and turn them + -- into multiplications. Then fix up the parse tree by rotating + -- newly created multiplications up until their precedence is + -- higher than the thing above them. + + fixJuxtMul :: Term -> Term + + -- Just recurse through TUn or TBin and fix precedence on the way back up. + fixJuxtMul (TUn op t) = fixPrec $ TUn op (fixJuxtMul t) + fixJuxtMul (TBin op t1 t2) = fixPrec $ TBin op (fixJuxtMul t1) (fixJuxtMul t2) + -- Possibly turn a TApp into a multiplication, if the LHS looks + -- like a multiplicative term. However, we must be sure to + -- \*first* recursively fix the subterms (particularly the + -- left-hand one) *before* doing this analysis. See + -- . + fixJuxtMul (TApp t1 t2) + | isMultiplicativeTerm t1' = fixPrec $ TBin Mul t1' t2' + | otherwise = fixPrec $ TApp t1' t2' + where + t1' = fixJuxtMul t1 + t2' = fixJuxtMul t2 + + -- Otherwise we can stop recursing, since anything other than TUn, + -- TBin, or TApp could not have been produced by the expression + -- parser. + fixJuxtMul t = t + + -- A multiplicative term is one that looks like either a natural + -- number literal, or a unary or binary operation (optionally + -- parenthesized). For example, 3, (-2), and (x + 5) are all + -- multiplicative terms, so 3x, (-2)x, and (x + 5)x all get parsed + -- as multiplication. On the other hand, (x y) is always parsed + -- as function application, even if x and y both turn out to have + -- numeric types; a variable like x does not count as a + -- multiplicative term. Likewise, (x y) z is parsed as function + -- application, since (x y) is not a multiplicative term: it is + -- parenthezised, but contains a TApp rather than a TBin or TUn. + isMultiplicativeTerm :: Term -> Bool + isMultiplicativeTerm (TNat _) = True + isMultiplicativeTerm TUn {} = True + isMultiplicativeTerm TBin {} = True + isMultiplicativeTerm (TParens t) = isMultiplicativeTerm t + isMultiplicativeTerm _ = False + + -- Fix precedence by bubbling up any new TBin terms whose + -- precedence is less than that of the operator above them. We + -- don't worry at all about fixing associativity, just precedence. + + fixPrec :: Term -> Term + + -- e.g. 2y! --> (2@y)! --> fixup --> 2 * (y!) + fixPrec (TUn uop (TBin bop t1 t2)) + | bPrec bop < uPrec uop = case uopMap M.! uop of + OpInfo (UOpF Pre _) _ _ -> TBin bop (TUn uop t1) t2 + OpInfo (UOpF Post _) _ _ -> TBin bop t1 (TUn uop t2) + _ -> error "Impossible! In fixPrec, uopMap contained OpInfo (BOpF ...)" + fixPrec (TBin bop1 (TBin bop2 t1 t2) t3) + | bPrec bop2 < bPrec bop1 = TBin bop2 t1 (fixPrec $ TBin bop1 t2 t3) + -- e.g. x^2y --> x^(2@y) --> x^(2*y) --> (x^2) * y + fixPrec (TBin bop1 t1 (TBin bop2 t2 t3)) + | bPrec bop2 < bPrec bop1 = TBin bop2 (fixPrec $ TBin bop1 t1 t2) t3 + fixPrec t = t -- | Parse an atomic type. parseAtomicType :: Parser Type -parseAtomicType = label "type" $ - TyVoid <$ reserved "Void" - <|> TyUnit <$ reserved "Unit" - <|> TyBool <$ (reserved "Boolean" <|> reserved "Bool") - <|> TyProp <$ (reserved "Proposition" <|> reserved "Prop") - <|> TyC <$ reserved "Char" - -- <|> try parseTyFin - <|> TyN <$ (reserved "Natural" <|> reserved "Nat" <|> reserved "N" <|> reserved "ℕ") - <|> TyZ <$ (reserved "Integer" <|> reserved "Int" <|> reserved "Z" <|> reserved "ℤ") - <|> TyF <$ (reserved "Fractional" <|> reserved "Frac" <|> reserved "F" <|> reserved "𝔽") - <|> TyQ <$ (reserved "Rational" <|> reserved "Q" <|> reserved "ℚ") - <|> TyCon <$> parseCon <*> (fromMaybe [] <$> optional (parens (parseType `sepBy1` comma))) - <|> TyVar <$> parseTyVar - <|> parens parseType +parseAtomicType = + label "type" $ + TyVoid <$ reserved "Void" + <|> TyUnit <$ reserved "Unit" + <|> TyBool <$ (reserved "Boolean" <|> reserved "Bool") + <|> TyProp <$ (reserved "Proposition" <|> reserved "Prop") + <|> TyC <$ reserved "Char" + -- <|> try parseTyFin + <|> TyN <$ (reserved "Natural" <|> reserved "Nat" <|> reserved "N" <|> reserved "ℕ") + <|> TyZ <$ (reserved "Integer" <|> reserved "Int" <|> reserved "Z" <|> reserved "ℤ") + <|> TyF <$ (reserved "Fractional" <|> reserved "Frac" <|> reserved "F" <|> reserved "𝔽") + <|> TyQ <$ (reserved "Rational" <|> reserved "Q" <|> reserved "ℚ") + <|> TyCon <$> parseCon <*> (fromMaybe [] <$> optional (parens (parseType `sepBy1` comma))) + <|> TyVar <$> parseTyVar + <|> parens parseType -- parseTyFin :: Parser Type -- parseTyFin = TyFin <$> (reserved "Fin" *> natural) @@ -1099,15 +1218,15 @@ parseAtomicType = label "type" $ parseCon :: Parser Con parseCon = - CList <$ reserved "List" - <|> CBag <$ reserved "Bag" - <|> CSet <$ reserved "Set" - <|> CGraph <$ reserved "Graph" - <|> CMap <$ reserved "Map" - <|> CUser <$> parseTyDef + CList <$ reserved "List" + <|> CBag <$ reserved "Bag" + <|> CSet <$ reserved "Set" + <|> CGraph <$ reserved "Graph" + <|> CMap <$ reserved "Map" + <|> CUser <$> parseTyDef parseTyDef :: Parser String -parseTyDef = identifier upperChar +parseTyDef = identifier upperChar parseTyVarName :: Parser String parseTyVarName = identifier lowerChar @@ -1121,23 +1240,28 @@ parsePolyTy = closeType <$> parseType -- | Parse a type expression built out of binary operators. parseType :: Parser Type parseType = makeExprParser parseAtomicType table - where - table = [ [ infixR "*" (:*:) - , infixR "×" (:*:) ] - , [ infixR "+" (:+:) - , infixR "⊎" (:+:) - ] - , [ infixR "->" (:->:) - , infixR "→" (:->:) - ] - ] - - infixR name fun = InfixR (reservedOp name >> return fun) + where + table = + [ + [ infixR "*" (:*:) + , infixR "×" (:*:) + ] + , + [ infixR "+" (:+:) + , infixR "⊎" (:+:) + ] + , + [ infixR "->" (:->:) + , infixR "→" (:->:) + ] + ] + + infixR name fun = InfixR (reservedOp name >> return fun) parseTyOp :: Parser TyOp parseTyOp = - Enumerate <$ reserved "enumerate" - <|> Count <$ reserved "count" + Enumerate <$ reserved "enumerate" + <|> Count <$ reserved "count" parseTypeOp :: Parser Term parseTypeOp = TTyOp <$> parseTyOp <*> parseAtomicType diff --git a/src/Disco/Pretty.hs b/src/Disco/Pretty.hs index f3639f1a..981b869e 100644 --- a/src/Disco/Pretty.hs +++ b/src/Disco/Pretty.hs @@ -1,8 +1,13 @@ -{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE NoMonomorphismRestriction #-} -{-# LANGUAGE OverloadedStrings #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + +-- TODO: the calls to 'error' should be replaced with logging/error capabilities. + -- | -- Module : Disco.Pretty -- Copyright : disco team and contributors @@ -11,40 +16,35 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Various pretty-printing facilities for disco. --- ------------------------------------------------------------------------------ - --- TODO: the calls to 'error' should be replaced with logging/error capabilities. - -module Disco.Pretty - ( module Disco.Pretty.DSL - , module Disco.Pretty - , module Disco.Pretty.Prec - , Doc - ) - where +module Disco.Pretty ( + module Disco.Pretty.DSL, + module Disco.Pretty, + module Disco.Pretty.Prec, + Doc, +) +where -import Prelude hiding ((<>)) +import Prelude hiding ((<>)) -import Data.Bifunctor -import Data.Char (isAlpha) -import Data.Map (Map) -import qualified Data.Map as M -import Data.Ratio -import Data.Set (Set) -import qualified Data.Set as S +import Data.Bifunctor +import Data.Char (isAlpha) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Ratio +import Data.Set (Set) +import qualified Data.Set as S -import Disco.Effects.LFresh -import Polysemy +import Disco.Effects.LFresh +import Polysemy -import Polysemy.Reader +import Polysemy.Reader -import Text.PrettyPrint (Doc) -import Unbound.Generics.LocallyNameless (Name) +import Text.PrettyPrint (Doc) +import Unbound.Generics.LocallyNameless (Name) -import Disco.Pretty.DSL -import Disco.Pretty.Prec -import Disco.Syntax.Operators +import Disco.Pretty.DSL +import Disco.Pretty.Prec +import Disco.Syntax.Operators ------------------------------------------------------------ -- Utilities for handling precedence and associativity @@ -103,7 +103,7 @@ instance Pretty a => Pretty [a] where instance (Pretty k, Pretty v) => Pretty (Map k v) where pretty m = do - let es = map (\(k,v) -> pretty k <+> "->" <+> pretty v) (M.assocs m) + let es = map (\(k, v) -> pretty k <+> "->" <+> pretty v) (M.assocs m) ds <- setPA initPA $ punctuate "," es braces (hsep ds) @@ -119,13 +119,13 @@ instance Pretty (Name a) where instance Pretty TyOp where pretty = \case Enumerate -> text "enumerate" - Count -> text "count" + Count -> text "count" -- | Pretty-print a unary operator, by looking up its concrete syntax -- in the 'uopMap'. instance Pretty UOp where pretty op = case M.lookup op uopMap of - Just (OpInfo _ (syn:_) _) -> + Just (OpInfo _ (syn : _) _) -> text $ syn ++ (if all isAlpha syn then " " else "") _ -> error $ "UOp " ++ show op ++ " not in uopMap!" @@ -133,8 +133,8 @@ instance Pretty UOp where -- in the 'bopMap'. instance Pretty BOp where pretty op = case M.lookup op bopMap of - Just (OpInfo _ (syn:_) _) -> text syn - _ -> error $ "BOp " ++ show op ++ " not in bopMap!" + Just (OpInfo _ (syn : _) _) -> text syn + _ -> error $ "BOp " ++ show op ++ " not in bopMap!" -------------------------------------------------- -- Pretty-printing decimals @@ -144,19 +144,19 @@ instance Pretty BOp where -- in square brackets. prettyDecimal :: Rational -> String prettyDecimal r = printedDecimal + where + (n, d) = properFraction r :: (Integer, Rational) + (expan, len) = digitalExpansion 10 (numerator d) (denominator d) + printedDecimal + | length first102 > 101 || length first102 == 101 && last first102 /= 0 = + show n ++ "." ++ concatMap show (take 100 expan) ++ "..." + | rep == [0] = + show n ++ "." ++ (if null pre then "0" else concatMap show pre) + | otherwise = + show n ++ "." ++ concatMap show pre ++ "[" ++ concatMap show rep ++ "]" where - (n,d) = properFraction r :: (Integer, Rational) - (expan, len) = digitalExpansion 10 (numerator d) (denominator d) - printedDecimal - | length first102 > 101 || length first102 == 101 && last first102 /= 0 - = show n ++ "." ++ concatMap show (take 100 expan) ++ "..." - | rep == [0] - = show n ++ "." ++ (if null pre then "0" else concatMap show pre) - | otherwise - = show n ++ "." ++ concatMap show pre ++ "[" ++ concatMap show rep ++ "]" - where - (pre, rep) = splitAt len expan - first102 = take 102 expan + (pre, rep) = splitAt len expan + first102 = take 102 expan -- Given a list, find the indices of the list giving the first and -- second occurrence of the first element to repeat, or Nothing if @@ -166,9 +166,9 @@ findRep = findRep' M.empty 0 findRep' :: Ord a => M.Map a Int -> Int -> [a] -> ([a], Int) findRep' _ _ [] = error "Impossible. Empty list in findRep'" -findRep' prevs ix (x:xs) +findRep' prevs ix (x : xs) | x `M.member` prevs = ([], prevs M.! x) - | otherwise = first (x:) $ findRep' (M.insert x ix prevs) (ix+1) xs + | otherwise = first (x :) $ findRep' (M.insert x ix prevs) (ix + 1) xs -- | @digitalExpansion b n d@ takes the numerator and denominator of a -- fraction n/d between 0 and 1, and returns a pair of (1) a list of @@ -185,7 +185,7 @@ findRep' prevs ix (x:xs) -- looking for the first time that the remainder repeats. digitalExpansion :: Integer -> Integer -> Integer -> ([Integer], Int) digitalExpansion b n d = digits - where - longDivStep (_, r) = (b*r) `divMod` d - res = tail $ iterate longDivStep (0,n) - digits = first (map fst) (findRep res) + where + longDivStep (_, r) = (b * r) `divMod` d + res = tail $ iterate longDivStep (0, n) + digits = first (map fst) (findRep res) diff --git a/src/Disco/Pretty/DSL.hs b/src/Disco/Pretty/DSL.hs index b4c35c72..8032f150 100644 --- a/src/Disco/Pretty/DSL.hs +++ b/src/Disco/Pretty/DSL.hs @@ -1,4 +1,7 @@ ----------------------------------------------------------------------------- +----------------------------------------------------------------------------- +{-# OPTIONS_GHC -fno-warn-orphans #-} + -- | -- Module : Disco.Pretty.DSL -- Copyright : disco team and contributors @@ -7,24 +10,19 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Adapter DSL on top of Text.PrettyPrint for Applicative pretty-printing. --- ------------------------------------------------------------------------------ - -{-# OPTIONS_GHC -fno-warn-orphans #-} - module Disco.Pretty.DSL where -import Control.Applicative hiding (empty) -import Data.String (IsString (..)) -import Prelude hiding ((<>)) +import Control.Applicative hiding (empty) +import Data.String (IsString (..)) +import Prelude hiding ((<>)) -import Polysemy -import Polysemy.Reader +import Polysemy +import Polysemy.Reader -import Text.PrettyPrint (Doc) -import qualified Text.PrettyPrint as PP +import Text.PrettyPrint (Doc) +import qualified Text.PrettyPrint as PP -import Disco.Pretty.Prec +import Disco.Pretty.Prec instance IsString (Sem r Doc) where fromString = text @@ -36,16 +34,16 @@ instance IsString (Sem r Doc) where -- operates over a generic functor/monad. vcat :: Applicative f => [f Doc] -> f Doc -vcat ds = PP.vcat <$> sequenceA ds +vcat ds = PP.vcat <$> sequenceA ds hcat :: Applicative f => [f Doc] -> f Doc -hcat ds = PP.hcat <$> sequenceA ds +hcat ds = PP.hcat <$> sequenceA ds hsep :: Applicative f => [f Doc] -> f Doc -hsep ds = PP.hsep <$> sequenceA ds +hsep ds = PP.hsep <$> sequenceA ds parens :: Functor f => f Doc -> f Doc -parens = fmap PP.parens +parens = fmap PP.parens brackets :: Functor f => f Doc -> f Doc brackets = fmap PP.brackets @@ -63,10 +61,10 @@ doubleQuotes :: Functor f => f Doc -> f Doc doubleQuotes = fmap PP.doubleQuotes text :: Applicative m => String -> m Doc -text = pure . PP.text +text = pure . PP.text integer :: Applicative m => Integer -> m Doc -integer = pure . PP.integer +integer = pure . PP.integer nest :: Functor f => Int -> f Doc -> f Doc nest n d = PP.nest n <$> d @@ -75,13 +73,13 @@ hang :: Applicative f => f Doc -> Int -> f Doc -> f Doc hang d1 n d2 = PP.hang <$> d1 <*> pure n <*> d2 empty :: Applicative m => m Doc -empty = pure PP.empty +empty = pure PP.empty (<+>) :: Applicative f => f Doc -> f Doc -> f Doc (<+>) = liftA2 (PP.<+>) (<>) :: Applicative f => f Doc -> f Doc -> f Doc -(<>) = liftA2 (PP.<>) +(<>) = liftA2 (PP.<>) ($+$) :: Applicative f => f Doc -> f Doc -> f Doc ($+$) = liftA2 (PP.$+$) diff --git a/src/Disco/Pretty/Prec.hs b/src/Disco/Pretty/Prec.hs index b11f8d2f..0df41afe 100644 --- a/src/Disco/Pretty/Prec.hs +++ b/src/Disco/Pretty/Prec.hs @@ -1,4 +1,7 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Pretty.Prec -- Copyright : disco team and contributors @@ -7,12 +10,9 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Precedence and associativity for pretty-printing. --- ------------------------------------------------------------------------------ - module Disco.Pretty.Prec where -import Disco.Syntax.Operators +import Disco.Syntax.Operators -- Types for storing precedence + associativity together diff --git a/src/Disco/Property.hs b/src/Disco/Property.hs index 769f77f6..bdbd1092 100644 --- a/src/Disco/Property.hs +++ b/src/Disco/Property.hs @@ -1,5 +1,7 @@ - ----------------------------------------------------------------------------- +----------------------------------------------------------------------------- +{-# LANGUAGE OverloadedStrings #-} + -- | -- Module : Disco.Property -- Copyright : disco team and contributors @@ -8,42 +10,37 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Properties of disco functions. --- ------------------------------------------------------------------------------ - -{-# LANGUAGE OverloadedStrings #-} - -module Disco.Property - ( +module Disco.Property ( -- * Generation - generateSamples + generateSamples, -- * Utility - , invertMotive, invertPropResult + invertMotive, + invertPropResult, -- * Pretty-printing - , prettyTestResult - ) - where + prettyTestResult, +) +where -import Prelude hiding ((<>)) +import Prelude hiding ((<>)) -import Data.Char (toLower) +import Data.Char (toLower) import qualified Data.Enumeration.Invertible as E -import Disco.Effects.Random -import Polysemy +import Disco.Effects.Random +import Polysemy -import Disco.AST.Typed -import Disco.Effects.Input -import Disco.Effects.LFresh -import Disco.Error -import Disco.Pretty -import Disco.Syntax.Prims -import Disco.Typecheck.Erase (eraseProperty) -import Disco.Types (TyDefCtx) -import Disco.Value -import Polysemy.Reader +import Disco.AST.Typed +import Disco.Effects.Input +import Disco.Effects.LFresh +import Disco.Error +import Disco.Pretty +import Disco.Syntax.Prims +import Disco.Typecheck.Erase (eraseProperty) +import Disco.Types (TyDefCtx) +import Disco.Value +import Polysemy.Reader -- | Toggles which outcome (finding or not finding the thing being -- searched for) qualifies as success, without changing the thing @@ -56,25 +53,25 @@ invertMotive (SearchMotive (a, b)) = SearchMotive (not a, b) invertPropResult :: TestResult -> TestResult invertPropResult res@(TestResult b r env) | TestRuntimeError _ <- r = res - | otherwise = TestResult (not b) r env + | otherwise = TestResult (not b) r env randomLarge :: Member Random r => [Integer] -> Sem r [Integer] -randomLarge [] = return [] -randomLarge [_] = return [] +randomLarge [] = return [] +randomLarge [_] = return [] randomLarge (x : y : xs) = (:) <$> randomR (x, y) <*> randomLarge (y : xs) -- | Select samples from an enumeration according to a search type. Also returns -- a 'SearchType' describing the results, which may be 'Exhaustive' if the -- enumeration is no larger than the number of samples requested. generateSamples :: Member Random r => SearchType -> E.IEnumeration a -> Sem r ([a], SearchType) -generateSamples Exhaustive e = return (E.enumerate e, Exhaustive) +generateSamples Exhaustive e = return (E.enumerate e, Exhaustive) generateSamples (Randomized n m) e | E.Finite k <- E.card e, k <= n + m = return (E.enumerate e, Exhaustive) - | otherwise = do - let small = [0 .. n] - rs <- randomLarge [100, 1000, 10000, 100000, 1000000] - let samples = map (E.select e) $ small ++ rs - return (samples, Randomized n m) + | otherwise = do + let small = [0 .. n] + rs <- randomLarge [100, 1000, 10000, 100000, 1000000] + let samples = map (E.select e) $ small ++ rs + return (samples, Randomized n m) -- XXX do shrinking for randomly generated test cases? @@ -83,12 +80,15 @@ generateSamples (Randomized n m) e ------------------------------------------------------------ prettyResultCertainty :: Members '[LFresh, Reader PA] r => TestReason -> AProperty -> String -> Sem r Doc -prettyResultCertainty r prop res - = (if resultIsCertain r then "Certainly" else "Possibly") <+> text res <> ":" <+> pretty (eraseProperty prop) - -prettyTestReason - :: Members '[Input TyDefCtx, LFresh, Reader PA] r - => Bool -> AProperty -> TestReason -> Sem r Doc +prettyResultCertainty r prop res = + (if resultIsCertain r then "Certainly" else "Possibly") <+> text res <> ":" <+> pretty (eraseProperty prop) + +prettyTestReason :: + Members '[Input TyDefCtx, LFresh, Reader PA] r => + Bool -> + AProperty -> + TestReason -> + Sem r Doc prettyTestReason _ _ TestBool = empty prettyTestReason b prop (TestFound (TestResult _ tr env)) | b = prettyTestEnv "Found example:" env @@ -100,28 +100,29 @@ prettyTestReason b _ (TestNotFound (Randomized n m)) | b = "Checked" <+> text (show (n + m)) <+> "possibilities without finding a counterexample." | not b = "No example was found; checked" <+> text (show (n + m)) <+> "possibilities." prettyTestReason _ _ (TestEqual t a1 a2) = - bulletList "-" - [ "Left side: " <> prettyValue t a1 - , "Right side: " <> prettyValue t a2 - ] + bulletList + "-" + [ "Left side: " <> prettyValue t a1 + , "Right side: " <> prettyValue t a2 + ] prettyTestReason _ _ (TestLt t a1 a2) = - bulletList "-" - [ "Left side: " <> prettyValue t a1 - , "Right side: " <> prettyValue t a2 - ] + bulletList + "-" + [ "Left side: " <> prettyValue t a1 + , "Right side: " <> prettyValue t a2 + ] prettyTestReason _ _ (TestRuntimeError ee) = "Test failed with an error:" - $+$ - nest 2 (pretty (EvalErr ee)) - -- $+$ - -- prettyTestEnv "Example inputs that caused the error:" env - -- See #364 + $+$ nest 2 (pretty (EvalErr ee)) +-- \$+$ +-- prettyTestEnv "Example inputs that caused the error:" env +-- See #364 prettyTestReason b (ATApp _ (ATPrim _ (PrimBOp _)) (ATTup _ [p1, p2])) (TestBin _ tr1 tr2) = - bulletList "-" - [ "Left side:" $+$ nest 2 (prettyTestResult' b p1 tr1) - , "Right side:" $+$ nest 2 (prettyTestResult' b p2 tr2) - ] - + bulletList + "-" + [ "Left side:" $+$ nest 2 (prettyTestResult' b p1 tr1) + , "Right side:" $+$ nest 2 (prettyTestResult' b p2 tr2) + ] -- See Note [prettyTestReason fallback] prettyTestReason _ _ _ = empty @@ -145,25 +146,31 @@ prettyTestReason _ _ _ = empty -- of the test result. So we just give up and decline to print a -- reason. -prettyTestResult' - :: Members '[Input TyDefCtx, LFresh, Reader PA] r - => Bool -> AProperty -> TestResult -> Sem r Doc +prettyTestResult' :: + Members '[Input TyDefCtx, LFresh, Reader PA] r => + Bool -> + AProperty -> + TestResult -> + Sem r Doc prettyTestResult' _ prop (TestResult bool tr _) = prettyResultCertainty tr prop (map toLower (show bool)) - $+$ - prettyTestReason bool prop tr + $+$ prettyTestReason bool prop tr -prettyTestResult - :: Members '[Input TyDefCtx, LFresh, Reader PA] r - => AProperty -> TestResult -> Sem r Doc +prettyTestResult :: + Members '[Input TyDefCtx, LFresh, Reader PA] r => + AProperty -> + TestResult -> + Sem r Doc prettyTestResult prop (TestResult b r env) = prettyTestResult' b prop (TestResult b r env) -prettyTestEnv - :: Members '[Input TyDefCtx, LFresh, Reader PA] r - => String -> TestEnv -> Sem r Doc +prettyTestEnv :: + Members '[Input TyDefCtx, LFresh, Reader PA] r => + String -> + TestEnv -> + Sem r Doc prettyTestEnv _ (TestEnv []) = empty prettyTestEnv s (TestEnv vs) = text s $+$ nest 2 (vcat (map prettyBind vs)) - where - maxNameLen = maximum . map (\(n, _, _) -> length n) $ vs - prettyBind (x, ty, v) = - text x <> text (replicate (maxNameLen - length x) ' ') <+> "=" <+> prettyValue ty v + where + maxNameLen = maximum . map (\(n, _, _) -> length n) $ vs + prettyBind (x, ty, v) = + text x <> text (replicate (maxNameLen - length x) ' ') <+> "=" <+> prettyValue ty v diff --git a/src/Disco/Report.hs b/src/Disco/Report.hs index 51a1a51c..cd24929a 100644 --- a/src/Disco/Report.hs +++ b/src/Disco/Report.hs @@ -1,14 +1,5 @@ - ----------------------------------------------------------------------------- --- | --- Module : Disco.Report --- Copyright : disco team and contributors --- Maintainer : byorgey@gmail.com --- --- SPDX-License-Identifier: BSD-3-Clause --- --- XXX --- + ----------------------------------------------------------------------------- -- The benefit of having our own deeply-embedded type for pretty @@ -18,16 +9,24 @@ -- interface of the pretty-printing library currently being used, so -- that a lot of code could just be kept unchanged. +-- | +-- Module : Disco.Report +-- Copyright : disco team and contributors +-- Maintainer : byorgey@gmail.com +-- +-- SPDX-License-Identifier: BSD-3-Clause +-- +-- XXX module Disco.Report where -import Data.List (intersperse) +import Data.List (intersperse) data Report - = RTxt String - | RSeq [Report] - | RVSeq [Report] - | RList [Report] - | RNest Report + = RTxt String + | RSeq [Report] + | RVSeq [Report] + | RList [Report] + | RNest Report deriving (Show) text :: String -> Report @@ -52,4 +51,3 @@ nest :: Report -> Report nest = RNest ------------------------------------------------------------ - diff --git a/src/Disco/Subst.hs b/src/Disco/Subst.hs index d7005800..23e221cd 100644 --- a/src/Disco/Subst.hs +++ b/src/Disco/Subst.hs @@ -1,6 +1,11 @@ {-# LANGUAGE OverloadedStrings #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + +-- SPDX-License-Identifier: BSD-3-Clause + -- | -- Module : Disco.Subst -- Copyright : disco team and contributors @@ -8,41 +13,39 @@ -- -- The "Disco.Subst" module defines a generic type of substitutions -- that map variable names to values. --- ------------------------------------------------------------------------------ - --- SPDX-License-Identifier: BSD-3-Clause - -module Disco.Subst - ( -- * Substitutions - - Substitution(..), dom - - -- ** Constructing/destructing substitutions - - , idS, (|->), fromList, toList - - -- ** Substitution operations +module Disco.Subst ( + -- * Substitutions + Substitution (..), + dom, - , (@@), compose, applySubst, lookup + -- ** Constructing/destructing substitutions + idS, + (|->), + fromList, + toList, - ) - where + -- ** Substitution operations + (@@), + compose, + applySubst, + lookup, +) +where -import Prelude hiding (lookup) +import Prelude hiding (lookup) -import Unbound.Generics.LocallyNameless (Name, Subst, substs) +import Unbound.Generics.LocallyNameless (Name, Subst, substs) -import Data.Coerce +import Data.Coerce -import Data.Map (Map) -import qualified Data.Map as M -import Data.Set (Set) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Set (Set) -import Disco.Effects.LFresh -import Disco.Pretty -import Polysemy -import Polysemy.Reader +import Disco.Effects.LFresh +import Disco.Pretty +import Polysemy +import Polysemy.Reader -- | A value of type @Substitution a@ is a substitution which maps some set of -- names (the /domain/, see 'dom') to values of type @a@. @@ -57,7 +60,7 @@ import Polysemy.Reader -- See also "Disco.Types", which defines 'S' as an alias for -- substitutions on types (the most common kind in the disco -- codebase). -newtype Substitution a = Substitution { getSubst :: Map (Name a) a } +newtype Substitution a = Substitution {getSubst :: Map (Name a) a} deriving (Eq, Ord, Show) instance Functor Substitution where diff --git a/src/Disco/Syntax/Operators.hs b/src/Disco/Syntax/Operators.hs index 7a17c595..6d0c90ca 100644 --- a/src/Disco/Syntax/Operators.hs +++ b/src/Disco/Syntax/Operators.hs @@ -1,6 +1,12 @@ -{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveDataTypeable #-} + +----------------------------------------------------------------------------- + ----------------------------------------------------------------------------- + +-- SPDX-License-Identifier: BSD-3-Clause + -- | -- Module : Disco.Syntax.Operators -- Copyright : disco team and contributors @@ -8,78 +14,121 @@ -- -- Unary and binary operators along with information like precedence, -- fixity, and concrete syntax. --- ------------------------------------------------------------------------------ - --- SPDX-License-Identifier: BSD-3-Clause +module Disco.Syntax.Operators ( + -- * Operators + UOp (..), + BOp (..), + TyOp (..), -module Disco.Syntax.Operators - ( -- * Operators - UOp(..), BOp(..), TyOp(..) + -- * Operator info + UFixity (..), + BFixity (..), + OpFixity (..), + OpInfo (..), - -- * Operator info - , UFixity(..), BFixity(..), OpFixity(..), OpInfo(..) + -- * Operator tables and lookup + opTable, + uopMap, + bopMap, + uPrec, + bPrec, + assoc, + funPrec, +) where - -- * Operator tables and lookup - , opTable, uopMap, bopMap - , uPrec, bPrec, assoc, funPrec +import Data.Data (Data) +import GHC.Generics (Generic) +import Unbound.Generics.LocallyNameless - ) where - -import Data.Data (Data) -import GHC.Generics (Generic) -import Unbound.Generics.LocallyNameless - -import Data.Map (Map, (!)) -import qualified Data.Map as M +import Data.Map (Map, (!)) +import qualified Data.Map as M ------------------------------------------------------------ -- Operators ------------------------------------------------------------ -- | Unary operators. -data UOp = Neg -- ^ Arithmetic negation (@-@) - | Not -- ^ Logical negation (@not@) - | Fact -- ^ Factorial (@!@) +data UOp + = -- | Arithmetic negation (@-@) + Neg + | -- | Logical negation (@not@) + Not + | -- | Factorial (@!@) + Fact deriving (Show, Read, Eq, Ord, Generic, Data, Alpha, Subst t) -- | Binary operators. -data BOp = Add -- ^ Addition (@+@) - | Sub -- ^ Subtraction (@-@) - | SSub -- ^ Saturating Subtraction (@.-@ / @∸@) - | Mul -- ^ Multiplication (@*@) - | Div -- ^ Division (@/@) - | Exp -- ^ Exponentiation (@^@) - | IDiv -- ^ Integer division (@//@) - | Eq -- ^ Equality test (@==@) - | Neq -- ^ Not-equal (@/=@) - | Lt -- ^ Less than (@<@) - | Gt -- ^ Greater than (@>@) - | Leq -- ^ Less than or equal (@<=@) - | Geq -- ^ Greater than or equal (@>=@) - | Min -- ^ Minimum (@min@) - | Max -- ^ Maximum (@max@) - | And -- ^ Logical and (@&&@ / @and@) - | Or -- ^ Logical or (@||@ / @or@) - | Impl -- ^ Logical implies (@->@ / @implies@) - | Iff -- ^ Logical biconditional (@<->@ / @iff@) - | Mod -- ^ Modulo (@mod@) - | Divides -- ^ Divisibility test (@|@) - | Choose -- ^ Binomial and multinomial coefficients (@choose@) - | Cons -- ^ List cons (@::@) - | CartProd -- ^ Cartesian product of sets (@**@ / @⨯@) - | Union -- ^ Union of two sets (@union@ / @∪@) - | Inter -- ^ Intersection of two sets (@intersect@ / @∩@) - | Diff -- ^ Difference between two sets (@\@) - | Elem -- ^ Element test (@∈@) - | Subset -- ^ Subset test (@⊆@) - | ShouldEq -- ^ Equality assertion (@=!=@) - | ShouldLt -- ^ Less than assertion (@!<@) +data BOp + = -- | Addition (@+@) + Add + | -- | Subtraction (@-@) + Sub + | -- | Saturating Subtraction (@.-@ / @∸@) + SSub + | -- | Multiplication (@*@) + Mul + | -- | Division (@/@) + Div + | -- | Exponentiation (@^@) + Exp + | -- | Integer division (@//@) + IDiv + | -- | Equality test (@==@) + Eq + | -- | Not-equal (@/=@) + Neq + | -- | Less than (@<@) + Lt + | -- | Greater than (@>@) + Gt + | -- | Less than or equal (@<=@) + Leq + | -- | Greater than or equal (@>=@) + Geq + | -- | Minimum (@min@) + Min + | -- | Maximum (@max@) + Max + | -- | Logical and (@&&@ / @and@) + And + | -- | Logical or (@||@ / @or@) + Or + | -- | Logical implies (@->@ / @implies@) + Impl + | -- | Logical biconditional (@<->@ / @iff@) + Iff + | -- | Modulo (@mod@) + Mod + | -- | Divisibility test (@|@) + Divides + | -- | Binomial and multinomial coefficients (@choose@) + Choose + | -- | List cons (@::@) + Cons + | -- | Cartesian product of sets (@**@ / @⨯@) + CartProd + | -- | Union of two sets (@union@ / @∪@) + Union + | -- | Intersection of two sets (@intersect@ / @∩@) + Inter + | -- | Difference between two sets (@\@) + Diff + | -- | Element test (@∈@) + Elem + | -- | Subset test (@⊆@) + Subset + | -- | Equality assertion (@=!=@) + ShouldEq + | -- | Less than assertion (@!<@) + ShouldLt deriving (Show, Read, Eq, Ord, Generic, Data, Alpha, Subst t) -- | Type operators. -data TyOp = Enumerate -- ^ List all values of a type - | Count -- ^ Count how many values there are of a type +data TyOp + = -- | List all values of a type + Enumerate + | -- | Count how many values there are of a type + Count deriving (Show, Eq, Ord, Generic, Data, Alpha, Subst t) ------------------------------------------------------------ @@ -88,34 +137,38 @@ data TyOp = Enumerate -- ^ List all values of a type -- | Fixities of unary operators (either pre- or postfix). data UFixity - = Pre -- ^ Unary prefix. - | Post -- ^ Unary postfix. + = -- | Unary prefix. + Pre + | -- | Unary postfix. + Post deriving (Eq, Ord, Enum, Bounded, Show, Generic) -- | Fixity/associativity of infix binary operators (either left, -- right, or non-associative). data BFixity - = InL -- ^ Left-associative infix. - | InR -- ^ Right-associative infix. - | In -- ^ Infix. + = -- | Left-associative infix. + InL + | -- | Right-associative infix. + InR + | -- | Infix. + In deriving (Eq, Ord, Enum, Bounded, Show, Generic) -- | Operators together with their fixity. -data OpFixity = - UOpF UFixity UOp +data OpFixity + = UOpF UFixity UOp | BOpF BFixity BOp deriving (Eq, Show, Generic) -- | An @OpInfo@ record contains information about an operator, such -- as the operator itself, its fixity, a list of concrete syntax -- representations, and a numeric precedence level. -data OpInfo = - OpInfo +data OpInfo = OpInfo { opFixity :: OpFixity - , opSyns :: [String] - , opPrec :: Int + , opSyns :: [String] + , opPrec :: Int } - deriving Show + deriving (Show) ------------------------------------------------------------ -- Operator table @@ -128,75 +181,92 @@ data OpInfo = opTable :: [[OpInfo]] opTable = assignPrecLevels - [ [ uopInfo Pre Not ["not", "¬"] - ] - , [ uopInfo Post Fact ["!"] - ] - , [ bopInfo InR Exp ["^"] - ] - , [ uopInfo Pre Neg ["-"] - ] - , [ bopInfo In Choose ["choose"] - ] - , [ bopInfo InR CartProd ["><", "⨯"] - ] - , [ bopInfo InL Union ["union", "∪"] - , bopInfo InL Inter ["intersect", "∩"] - , bopInfo InL Diff ["\\"] - ] - , [ bopInfo InL Min ["min"] - , bopInfo InL Max ["max"] - ] - , [ bopInfo InL Mul ["*"] - , bopInfo InL Div ["/"] - , bopInfo InL Mod ["mod", "%"] - , bopInfo InL IDiv ["//"] - ] - , [ bopInfo InL Add ["+"] - , bopInfo InL Sub ["-"] - , bopInfo InL SSub [".-", "∸"] - ] - , [ bopInfo InR Cons ["::"] - ] - , [ bopInfo InR Eq ["=="] - , bopInfo InR ShouldEq ["=!="] - , bopInfo InR ShouldLt ["!<"] - , bopInfo InR Neq ["/=", "≠", "!="] - , bopInfo InR Lt ["<"] - , bopInfo InR Gt [">"] - , bopInfo InR Leq ["<=", "≤", "=<"] - , bopInfo InR Geq [">=", "≥", "=>"] - , bopInfo InR Divides ["divides"] - , bopInfo InL Subset ["subset", "⊆"] - , bopInfo InL Elem ["elem", "∈"] - ] - , [ bopInfo InR And ["/\\", "and", "∧", "&&"] - ] - , [ bopInfo InR Or ["\\/", "or", "∨", "||"] - ] - , [ bopInfo InR Impl ["->", "==>", "→", "implies"] - , bopInfo InR Iff ["<->", "<==>", "↔", "iff"] + [ + [ uopInfo Pre Not ["not", "¬"] + ] + , + [ uopInfo Post Fact ["!"] + ] + , + [ bopInfo InR Exp ["^"] + ] + , + [ uopInfo Pre Neg ["-"] + ] + , + [ bopInfo In Choose ["choose"] + ] + , + [ bopInfo InR CartProd ["><", "⨯"] + ] + , + [ bopInfo InL Union ["union", "∪"] + , bopInfo InL Inter ["intersect", "∩"] + , bopInfo InL Diff ["\\"] + ] + , + [ bopInfo InL Min ["min"] + , bopInfo InL Max ["max"] + ] + , + [ bopInfo InL Mul ["*"] + , bopInfo InL Div ["/"] + , bopInfo InL Mod ["mod", "%"] + , bopInfo InL IDiv ["//"] + ] + , + [ bopInfo InL Add ["+"] + , bopInfo InL Sub ["-"] + , bopInfo InL SSub [".-", "∸"] + ] + , + [ bopInfo InR Cons ["::"] + ] + , + [ bopInfo InR Eq ["=="] + , bopInfo InR ShouldEq ["=!="] + , bopInfo InR ShouldLt ["!<"] + , bopInfo InR Neq ["/=", "≠", "!="] + , bopInfo InR Lt ["<"] + , bopInfo InR Gt [">"] + , bopInfo InR Leq ["<=", "≤", "=<"] + , bopInfo InR Geq [">=", "≥", "=>"] + , bopInfo InR Divides ["divides"] + , bopInfo InL Subset ["subset", "⊆"] + , bopInfo InL Elem ["elem", "∈"] + ] + , + [ bopInfo InR And ["/\\", "and", "∧", "&&"] + ] + , + [ bopInfo InR Or ["\\/", "or", "∨", "||"] + ] + , + [ bopInfo InR Impl ["->", "==>", "→", "implies"] + , bopInfo InR Iff ["<->", "<==>", "↔", "iff"] + ] ] - ] - where - uopInfo fx op syns = OpInfo (UOpF fx op) syns (-1) - bopInfo fx op syns = OpInfo (BOpF fx op) syns (-1) + where + uopInfo fx op syns = OpInfo (UOpF fx op) syns (-1) + bopInfo fx op syns = OpInfo (BOpF fx op) syns (-1) - -- Start at precedence level 2 so we can give level 1 to ascription, and level 0 - -- to the ambient context + parentheses etc. - assignPrecLevels table = zipWith assignPrecs (reverse [2 .. length table+1]) table - assignPrec p op = op { opPrec = p } - assignPrecs p = map (assignPrec p) + -- Start at precedence level 2 so we can give level 1 to ascription, and level 0 + -- to the ambient context + parentheses etc. + assignPrecLevels table = zipWith assignPrecs (reverse [2 .. length table + 1]) table + assignPrec p op = op {opPrec = p} + assignPrecs p = map (assignPrec p) -- | A map from all unary operators to their associated 'OpInfo' records. uopMap :: Map UOp OpInfo -uopMap = M.fromList $ - [ (op, info) | opLevel <- opTable, info@(OpInfo (UOpF _ op) _ _) <- opLevel ] +uopMap = + M.fromList $ + [(op, info) | opLevel <- opTable, info@(OpInfo (UOpF _ op) _ _) <- opLevel] -- | A map from all binary operators to their associatied 'OpInfo' records. bopMap :: Map BOp OpInfo -bopMap = M.fromList $ - [ (op, info) | opLevel <- opTable, info@(OpInfo (BOpF _ op) _ _) <- opLevel ] +bopMap = + M.fromList $ + [(op, info) | opLevel <- opTable, info@(OpInfo (BOpF _ op) _ _) <- opLevel] -- | A convenient function for looking up the precedence of a unary operator. uPrec :: UOp -> Int @@ -211,9 +281,9 @@ assoc :: BOp -> BFixity assoc op = case M.lookup op bopMap of Just (OpInfo (BOpF fx _) _ _) -> fx - _ -> error $ "BOp " ++ show op ++ " not in bopMap!" + _ -> error $ "BOp " ++ show op ++ " not in bopMap!" -- | The precedence level of function application (higher than any -- other precedence level). funPrec :: Int -funPrec = length opTable+1 +funPrec = length opTable + 1 diff --git a/src/Disco/Syntax/Prims.hs b/src/Disco/Syntax/Prims.hs index c455a3f6..29e46d6a 100644 --- a/src/Disco/Syntax/Prims.hs +++ b/src/Disco/Syntax/Prims.hs @@ -1,6 +1,12 @@ -{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveDataTypeable #-} + +----------------------------------------------------------------------------- + ----------------------------------------------------------------------------- + +-- SPDX-License-Identifier: BSD-3-Clause + -- | -- Module : Disco.Syntax.Prims -- Copyright : disco team and contributors @@ -8,24 +14,22 @@ -- -- Concrete syntax for the prims (i.e. built-in constants) supported -- by the language. --- ------------------------------------------------------------------------------ - --- SPDX-License-Identifier: BSD-3-Clause - -module Disco.Syntax.Prims - ( Prim(..) - , PrimInfo(..), primTable, toPrim, primMap - ) where +module Disco.Syntax.Prims ( + Prim (..), + PrimInfo (..), + primTable, + toPrim, + primMap, +) where -import GHC.Generics (Generic) -import Unbound.Generics.LocallyNameless +import GHC.Generics (Generic) +import Unbound.Generics.LocallyNameless -import Data.Map (Map) -import qualified Data.Map as M +import Data.Map (Map) +import qualified Data.Map as M -import Data.Data (Data) -import Disco.Syntax.Operators +import Data.Data (Data) +import Disco.Syntax.Operators ------------------------------------------------------------ -- Prims @@ -33,57 +37,120 @@ import Disco.Syntax.Operators -- | Primitives, /i.e./ built-in constants. data Prim where - PrimUOp :: UOp -> Prim -- ^ Unary operator - PrimBOp :: BOp -> Prim -- ^ Binary operator - - PrimLeft :: Prim -- ^ Left injection into a sum type. - PrimRight :: Prim -- ^ Right injection into a sum type. - - PrimSqrt :: Prim -- ^ Integer square root (@sqrt@) - PrimFloor :: Prim -- ^ Floor of fractional type (@floor@) - PrimCeil :: Prim -- ^ Ceiling of fractional type (@ceiling@) - PrimAbs :: Prim -- ^ Absolute value (@abs@) - - PrimPower :: Prim -- ^ Power set (XXX or bag?) - - PrimList :: Prim -- ^ Container -> list conversion - PrimBag :: Prim -- ^ Container -> bag conversion - PrimSet :: Prim -- ^ Container -> set conversion - - PrimB2C :: Prim -- ^ bag -> set of counts conversion - PrimC2B :: Prim -- ^ set of counts -> bag conversion - PrimUC2B :: Prim -- ^ unsafe set of counts -> bag conversion - -- that assumes all distinct - PrimMapToSet :: Prim -- ^ Map k v -> Set (k × v) - PrimSetToMap :: Prim -- ^ Set (k × v) -> Map k v - - PrimSummary :: Prim -- ^ Get Adjacency list of Graph - PrimVertex :: Prim -- ^ Construct a graph Vertex - PrimEmptyGraph :: Prim -- ^ Empty graph - PrimOverlay :: Prim -- ^ Overlay two Graphs - PrimConnect :: Prim -- ^ Connect Graph to another with directed edges - - PrimInsert :: Prim -- ^ Insert into map - PrimLookup :: Prim -- ^ Get value associated with key in map - - PrimEach :: Prim -- ^ Each operation for containers - PrimReduce :: Prim -- ^ Reduce operation for containers - PrimFilter :: Prim -- ^ Filter operation for containers - PrimJoin :: Prim -- ^ Monadic join for containers - PrimMerge :: Prim -- ^ Generic merge operation for bags/sets - - PrimIsPrime :: Prim -- ^ Efficient primality test - PrimFactor :: Prim -- ^ Factorization - PrimFrac :: Prim -- ^ Turn a rational into a pair (num, denom) - - PrimCrash :: Prim -- ^ Crash - - PrimUntil :: Prim -- ^ @[x, y, z .. e]@ - - PrimHolds :: Prim -- ^ Test whether a proposition holds - - PrimLookupSeq :: Prim -- ^ Lookup OEIS sequence - PrimExtendSeq :: Prim -- ^ Extend OEIS sequence + PrimUOp :: + UOp -> + -- | Unary operator + Prim + PrimBOp :: + BOp -> + -- | Binary operator + Prim + PrimLeft :: + -- | Left injection into a sum type. + Prim + PrimRight :: + -- | Right injection into a sum type. + Prim + PrimSqrt :: + -- | Integer square root (@sqrt@) + Prim + PrimFloor :: + -- | Floor of fractional type (@floor@) + Prim + PrimCeil :: + -- | Ceiling of fractional type (@ceiling@) + Prim + PrimAbs :: + -- | Absolute value (@abs@) + Prim + PrimPower :: + -- | Power set (XXX or bag?) + Prim + PrimList :: + -- | Container -> list conversion + Prim + PrimBag :: + -- | Container -> bag conversion + Prim + PrimSet :: + -- | Container -> set conversion + Prim + PrimB2C :: + -- | bag -> set of counts conversion + Prim + PrimC2B :: + -- | set of counts -> bag conversion + Prim + PrimUC2B :: + -- | unsafe set of counts -> bag conversion + -- that assumes all distinct + Prim + PrimMapToSet :: + -- | Map k v -> Set (k × v) + Prim + PrimSetToMap :: + -- | Set (k × v) -> Map k v + Prim + PrimSummary :: + -- | Get Adjacency list of Graph + Prim + PrimVertex :: + -- | Construct a graph Vertex + Prim + PrimEmptyGraph :: + -- | Empty graph + Prim + PrimOverlay :: + -- | Overlay two Graphs + Prim + PrimConnect :: + -- | Connect Graph to another with directed edges + Prim + PrimInsert :: + -- | Insert into map + Prim + PrimLookup :: + -- | Get value associated with key in map + Prim + PrimEach :: + -- | Each operation for containers + Prim + PrimReduce :: + -- | Reduce operation for containers + Prim + PrimFilter :: + -- | Filter operation for containers + Prim + PrimJoin :: + -- | Monadic join for containers + Prim + PrimMerge :: + -- | Generic merge operation for bags/sets + Prim + PrimIsPrime :: + -- | Efficient primality test + Prim + PrimFactor :: + -- | Factorization + Prim + PrimFrac :: + -- | Turn a rational into a pair (num, denom) + Prim + PrimCrash :: + -- | Crash + Prim + PrimUntil :: + -- | @[x, y, z .. e]@ + Prim + PrimHolds :: + -- | Test whether a proposition holds + Prim + PrimLookupSeq :: + -- | Lookup OEIS sequence + Prim + PrimExtendSeq :: + -- | Extend OEIS sequence + Prim deriving (Show, Read, Eq, Ord, Generic, Alpha, Subst t, Data) ------------------------------------------------------------ @@ -96,84 +163,72 @@ data Prim where -- the basic language. Unexposed prims can only be referenced by -- enabling the Primitives language extension and prefixing their -- name by @$@. -data PrimInfo = - PrimInfo - { thePrim :: Prim - , primSyntax :: String +data PrimInfo = PrimInfo + { thePrim :: Prim + , primSyntax :: String , primExposed :: Bool - -- Is the prim available in the normal syntax of the language? - -- - -- primExposed = True means that the bare primSyntax can be used - -- in the surface syntax, and the prim will be pretty-printed as - -- the primSyntax. - -- - -- primExposed = False means that the only way to enter it is to - -- enable the Primitives language extension and write a $ - -- followed by the primSyntax. The prim will be pretty-printed with a $ - -- prefix. - -- - -- In no case is a prim a reserved word. + -- Is the prim available in the normal syntax of the language? + -- + -- primExposed = True means that the bare primSyntax can be used + -- in the surface syntax, and the prim will be pretty-printed as + -- the primSyntax. + -- + -- primExposed = False means that the only way to enter it is to + -- enable the Primitives language extension and write a $ + -- followed by the primSyntax. The prim will be pretty-printed with a $ + -- prefix. + -- + -- In no case is a prim a reserved word. } -- | A table containing a 'PrimInfo' record for every non-operator -- 'Prim' recognized by the language. primTable :: [PrimInfo] primTable = - [ PrimInfo PrimLeft "left" True - , PrimInfo PrimRight "right" True - - , PrimInfo (PrimUOp Not) "not" True - , PrimInfo PrimSqrt "sqrt" True - , PrimInfo PrimFloor "floor" True - , PrimInfo PrimCeil "ceiling" True - , PrimInfo PrimAbs "abs" True - - , PrimInfo PrimPower "power" True - - , PrimInfo PrimList "list" True - , PrimInfo PrimBag "bag" True - , PrimInfo PrimSet "set" True - - , PrimInfo PrimB2C "bagCounts" True - , PrimInfo PrimC2B "bagFromCounts" True - , PrimInfo PrimUC2B "unsafeBagFromCounts" False - , PrimInfo PrimMapToSet "mapToSet" True - , PrimInfo PrimSetToMap "map" True - - , PrimInfo PrimSummary "summary" True - , PrimInfo PrimVertex "vertex" True - , PrimInfo PrimEmptyGraph "emptyGraph" True - , PrimInfo PrimOverlay "overlay" True - , PrimInfo PrimConnect "connect" True - - , PrimInfo PrimInsert "insert" True - , PrimInfo PrimLookup "lookup" True - - , PrimInfo PrimEach "each" True - , PrimInfo PrimReduce "reduce" True - , PrimInfo PrimFilter "filter" True - , PrimInfo PrimJoin "join" False - , PrimInfo PrimMerge "merge" False - - , PrimInfo PrimIsPrime "isPrime" False - , PrimInfo PrimFactor "factor" False - , PrimInfo PrimFrac "frac" False - - , PrimInfo PrimCrash "crash" False - - , PrimInfo PrimUntil "until" False - - , PrimInfo PrimHolds "holds" True - + [ PrimInfo PrimLeft "left" True + , PrimInfo PrimRight "right" True + , PrimInfo (PrimUOp Not) "not" True + , PrimInfo PrimSqrt "sqrt" True + , PrimInfo PrimFloor "floor" True + , PrimInfo PrimCeil "ceiling" True + , PrimInfo PrimAbs "abs" True + , PrimInfo PrimPower "power" True + , PrimInfo PrimList "list" True + , PrimInfo PrimBag "bag" True + , PrimInfo PrimSet "set" True + , PrimInfo PrimB2C "bagCounts" True + , PrimInfo PrimC2B "bagFromCounts" True + , PrimInfo PrimUC2B "unsafeBagFromCounts" False + , PrimInfo PrimMapToSet "mapToSet" True + , PrimInfo PrimSetToMap "map" True + , PrimInfo PrimSummary "summary" True + , PrimInfo PrimVertex "vertex" True + , PrimInfo PrimEmptyGraph "emptyGraph" True + , PrimInfo PrimOverlay "overlay" True + , PrimInfo PrimConnect "connect" True + , PrimInfo PrimInsert "insert" True + , PrimInfo PrimLookup "lookup" True + , PrimInfo PrimEach "each" True + , PrimInfo PrimReduce "reduce" True + , PrimInfo PrimFilter "filter" True + , PrimInfo PrimJoin "join" False + , PrimInfo PrimMerge "merge" False + , PrimInfo PrimIsPrime "isPrime" False + , PrimInfo PrimFactor "factor" False + , PrimInfo PrimFrac "frac" False + , PrimInfo PrimCrash "crash" False + , PrimInfo PrimUntil "until" False + , PrimInfo PrimHolds "holds" True , PrimInfo PrimLookupSeq "lookupSequence" False , PrimInfo PrimExtendSeq "extendSequence" False ] -- | Find any exposed prims with the given name. toPrim :: String -> [Prim] -toPrim x = [ p | PrimInfo p syn True <- primTable, syn == x ] +toPrim x = [p | PrimInfo p syn True <- primTable, syn == x] -- | A convenient map from each 'Prim' to its info record. primMap :: Map Prim PrimInfo -primMap = M.fromList $ - [ (p, pinfo) | pinfo@(PrimInfo p _ _) <- primTable ] +primMap = + M.fromList $ + [(p, pinfo) | pinfo@(PrimInfo p _ _) <- primTable] diff --git a/src/Disco/Typecheck.hs b/src/Disco/Typecheck.hs index 0a30955f..a67f25c4 100644 --- a/src/Disco/Typecheck.hs +++ b/src/Disco/Typecheck.hs @@ -1,8 +1,11 @@ -{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NondecreasingIndentation #-} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE OverloadedStrings #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Typecheck -- Copyright : disco team and contributors @@ -12,55 +15,58 @@ -- -- Typecheck the Disco surface language and transform it into a -- type-annotated AST. --- ------------------------------------------------------------------------------ - module Disco.Typecheck where -import Control.Arrow ((&&&)) -import Control.Lens ((^..)) -import Control.Monad.Except -import Control.Monad.Trans.Maybe -import Data.Bifunctor (first) -import Data.Coerce -import qualified Data.Foldable as F -import Data.List (group, sort) -import Data.Map (Map) -import qualified Data.Map as M -import Data.Maybe (isJust) -import Data.Set (Set) -import qualified Data.Set as S -import Prelude as P hiding (lookup) - -import Unbound.Generics.LocallyNameless (Alpha, Bind, Name, - bind, embed, - name2String, - string2Name, substs, - unembed) -import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind) - -import Disco.Effects.Fresh -import Polysemy hiding (embed) -import Polysemy.Error -import Polysemy.Output -import Polysemy.Reader -import Polysemy.Writer - -import Disco.AST.Surface -import Disco.AST.Typed -import Disco.Context hiding (filter) -import qualified Disco.Context as Ctx -import Disco.Messages -import Disco.Module -import Disco.Names -import Disco.Subst (applySubst) -import qualified Disco.Subst as Subst -import Disco.Syntax.Operators -import Disco.Syntax.Prims -import Disco.Typecheck.Constraints -import Disco.Typecheck.Util -import Disco.Types -import Disco.Types.Rules +import Control.Arrow ((&&&)) +import Control.Lens ((^..)) +import Control.Monad.Except +import Control.Monad.Trans.Maybe +import Data.Bifunctor (first) +import Data.Coerce +import qualified Data.Foldable as F +import Data.List (group, sort) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Maybe (isJust) +import Data.Set (Set) +import qualified Data.Set as S +import Prelude as P hiding (lookup) + +import Unbound.Generics.LocallyNameless ( + Alpha, + Bind, + Name, + bind, + embed, + name2String, + string2Name, + substs, + unembed, + ) +import Unbound.Generics.LocallyNameless.Unsafe (unsafeUnbind) + +import Disco.Effects.Fresh +import Polysemy hiding (embed) +import Polysemy.Error +import Polysemy.Output +import Polysemy.Reader +import Polysemy.Writer + +import Disco.AST.Surface +import Disco.AST.Typed +import Disco.Context hiding (filter) +import qualified Disco.Context as Ctx +import Disco.Messages +import Disco.Module +import Disco.Names +import Disco.Subst (applySubst) +import qualified Disco.Subst as Subst +import Disco.Syntax.Operators +import Disco.Syntax.Prims +import Disco.Typecheck.Constraints +import Disco.Typecheck.Util +import Disco.Types +import Disco.Types.Rules ------------------------------------------------------------ -- Container utilities @@ -71,8 +77,8 @@ containerTy c ty = TyCon (containerToCon c) [ty] containerToCon :: Container -> Con containerToCon ListContainer = CList -containerToCon BagContainer = CBag -containerToCon SetContainer = CSet +containerToCon BagContainer = CBag +containerToCon SetContainer = CSet ------------------------------------------------------------ -- Telescopes @@ -82,19 +88,21 @@ containerToCon SetContainer = CSet -- each item along with a context of variables it binds; each such -- context is then added to the overall context when inferring -- subsequent items in the telescope. -inferTelescope - :: (Alpha b, Alpha tyb, Member (Reader TyCtx) r) - => (b -> Sem r (tyb, TyCtx)) -> Telescope b -> Sem r (Telescope tyb, TyCtx) +inferTelescope :: + (Alpha b, Alpha tyb, Member (Reader TyCtx) r) => + (b -> Sem r (tyb, TyCtx)) -> + Telescope b -> + Sem r (Telescope tyb, TyCtx) inferTelescope inferOne tel = do (tel1, ctx) <- go (fromTelescope tel) return (toTelescope tel1, ctx) - where - go [] = return ([], emptyCtx) - go (b:bs) = do - (tyb, ctx) <- inferOne b - extends ctx $ do + where + go [] = return ([], emptyCtx) + go (b : bs) = do + (tyb, ctx) <- inferOne b + extends ctx $ do (tybs, ctx') <- go bs - return (tyb:tybs, ctx <> ctx') + return (tyb : tybs, ctx <> ctx') ------------------------------------------------------------ -- Modules @@ -105,9 +113,12 @@ inferTelescope inferOne tel = do -- on success. This function does not handle imports at all; any -- imports should already be checked and passed in as the second -- argument. -checkModule - :: Members '[Output Message, Reader TyCtx, Reader TyDefCtx, Error LocTCError, Fresh] r - => ModuleName -> Map ModuleName ModuleInfo -> Module -> Sem r ModuleInfo +checkModule :: + Members '[Output Message, Reader TyCtx, Reader TyDefCtx, Error LocTCError, Fresh] r => + ModuleName -> + Map ModuleName ModuleInfo -> + Module -> + Sem r ModuleInfo checkModule name imports (Module es _ m docs terms) = do let (typeDecls, defns, tydefs) = partitionDecls m importTyCtx = mconcat (imports ^.. traverse . miTys) @@ -116,7 +127,7 @@ checkModule name imports (Module es _ m docs terms) = do importTyDefnCtx = M.unions (imports ^.. traverse . miTydefs) tyDefnCtx <- mapError noLoc $ makeTyDefnCtx tydefs withTyDefns (tyDefnCtx `M.union` importTyDefnCtx) $ do - tyCtx <- mapError noLoc $ makeTyCtx name typeDecls + tyCtx <- mapError noLoc $ makeTyCtx name typeDecls extends importTyCtx $ extends tyCtx $ do mapM_ (checkTyDefn name) tydefs adefns <- mapM (checkDefn name) defns @@ -124,16 +135,17 @@ checkModule name imports (Module es _ m docs terms) = do docCtx = ctxForModule name docs dups = filterDups . map getDefnName $ adefns case dups of - (x:_) -> throw $ noLoc $ DuplicateDefns (coerce x) + (x : _) -> throw $ noLoc $ DuplicateDefns (coerce x) [] -> do - aprops <- mapError noLoc $ checkProperties docCtx -- XXX location? - aterms <- mapError noLoc $ mapM inferTop terms -- XXX location? + aprops <- mapError noLoc $ checkProperties docCtx -- XXX location? + aterms <- mapError noLoc $ mapM inferTop terms -- XXX location? return $ ModuleInfo name imports (map ((name .-) . getDeclName) typeDecls) docCtx aprops tyCtx tyDefnCtx defnCtx aterms es - where getDefnName :: Defn -> Name ATerm - getDefnName (Defn n _ _ _) = n + where + getDefnName :: Defn -> Name ATerm + getDefnName (Defn n _ _ _) = n - getDeclName :: TypeDecl -> Name Term - getDeclName (TypeDecl n _) = n + getDeclName :: TypeDecl -> Name Term + getDeclName (TypeDecl n _) = n -------------------------------------------------- -- Type definitions @@ -148,17 +160,16 @@ makeTyDefnCtx tydefs = do newNames = map (\(TypeDefn x _ _) -> x) tydefs dups = filterDups $ newNames ++ oldNames - let convert (TypeDefn x args body) - = (x, TyDefBody args (flip substs body . zip (map string2Name args))) + let convert (TypeDefn x args body) = + (x, TyDefBody args (flip substs body . zip (map string2Name args))) case dups of - (x:_) -> throw (DuplicateTyDefns x) - [] -> return . M.fromList $ map convert tydefs + (x : _) -> throw (DuplicateTyDefns x) + [] -> return . M.fromList $ map convert tydefs -- | Check the validity of a type definition. checkTyDefn :: Members '[Reader TyDefCtx, Error LocTCError] r => ModuleName -> TypeDefn -> Sem r () checkTyDefn name defn@(TypeDefn x args body) = mapError (LocTCError (Just (name .- string2Name x))) $ do - -- First, make sure the body is a valid type, i.e. everything inside -- it is well-kinded. checkTypeValid body @@ -195,20 +206,19 @@ checkCyclicTy (TyUser name args) set = do False -> do ty <- lookupTyDefn name args checkCyclicTy ty (S.insert name set) - checkCyclicTy _ set = return set -- | Ensure that a type definition does not use any unbound type -- variables or undefined types. checkUnboundVars :: Members '[Reader TyDefCtx, Error TCError] r => TypeDefn -> Sem r () checkUnboundVars (TypeDefn _ args body) = go body - where - go (TyAtom (AVar (U x))) - | name2String x `elem` args = return () - | otherwise = throw $ UnboundTyVar x - go (TyAtom _) = return () - go (TyUser name tys) = lookupTyDefn name tys >> mapM_ go tys - go (TyCon _ tys) = mapM_ go tys + where + go (TyAtom (AVar (U x))) + | name2String x `elem` args = return () + | otherwise = throw $ UnboundTyVar x + go (TyAtom _) = return () + go (TyUser name tys) = lookupTyDefn name tys >> mapM_ go tys + go (TyCon _ tys) = mapM_ go tys -- | Check for polymorphic recursion: starting from a user-defined -- type, keep expanding its definition recursively, ensuring that @@ -216,20 +226,20 @@ checkUnboundVars (TypeDefn _ args body) = go body -- as arguments. checkPolyRec :: Member (Error TCError) r => TypeDefn -> Sem r () checkPolyRec (TypeDefn name args body) = go body - where - go (TyCon (CUser x) tys) - | x == name && not (all isTyVar tys) = + where + go (TyCon (CUser x) tys) + | x == name && not (all isTyVar tys) = throw $ NoPolyRec name args tys - | otherwise = return () - go (TyCon _ tys) = mapM_ go tys - go _ = return () + | otherwise = return () + go (TyCon _ tys) = mapM_ go tys + go _ = return () -- | Keep only the duplicate elements from a list. -- -- >>> filterDups [1,3,2,1,1,4,2] -- [1,2] filterDups :: Ord a => [a] -> [a] -filterDups = map head . filter ((>1) . length) . group . sort +filterDups = map head . filter ((> 1) . length) . group . sort -------------------------------------------------- -- Type declarations @@ -242,12 +252,12 @@ makeTyCtx :: Members '[Reader TyDefCtx, Error TCError] r => ModuleName -> [TypeD makeTyCtx name decls = do let dups = filterDups . map (\(TypeDecl x _) -> x) $ decls case dups of - (x:_) -> throw (DuplicateDecls x) - [] -> do + (x : _) -> throw (DuplicateDecls x) + [] -> do checkCtx declCtx return declCtx - where - declCtx = ctxForModule name $ map (\(TypeDecl x ty) -> (x,ty)) decls + where + declCtx = ctxForModule name $ map (\(TypeDecl x ty) -> (x, ty)) decls -- | Check that all the types in a context are valid. checkCtx :: Members '[Reader TyDefCtx, Error TCError] r => TyCtx -> Sem r () @@ -257,18 +267,19 @@ checkCtx = mapM_ checkPolyTyValid . Ctx.elems -- Top-level definitions -- | Type check a top-level definition in the given module. -checkDefn - :: Members '[Reader TyCtx, Reader TyDefCtx, Error LocTCError, Fresh, Output Message] r - => ModuleName -> TermDefn -> Sem r Defn +checkDefn :: + Members '[Reader TyCtx, Reader TyDefCtx, Error LocTCError, Fresh, Output Message] r => + ModuleName -> + TermDefn -> + Sem r Defn checkDefn name (TermDefn x clauses) = mapError (LocTCError (Just (name .- x))) $ do - -- Check that all clauses have the same number of patterns checkNumPats clauses -- Get the declared type signature of x Forall sig <- lookup (name .- x) >>= maybe (throw $ NoType x) return - -- If x isn't in the context, it's because no type was declared for it, so - -- throw an error. + -- If x isn't in the context, it's because no type was declared for it, so + -- throw an error. (nms, ty) <- unbind sig -- Try to decompose the type into a chain of arrows like pty1 -> @@ -281,59 +292,65 @@ checkDefn name (TermDefn x clauses) = mapError (LocTCError (Just (name .- x))) $ return (aclauses, ty) return $ applySubst theta (Defn (coerce x) patTys bodyTy acs) - where - numPats = length . fst . unsafeUnbind - - checkNumPats [] = return () -- This can't happen, but meh - checkNumPats [_] = return () - checkNumPats (c:cs) - | all ((==0) . numPats) (c:cs) = throw (DuplicateDefns x) - | not (all ((== numPats c) . numPats) cs) = throw NumPatterns - -- XXX more info, this error actually means # of - -- patterns don't match across different clauses - | otherwise = return () - - -- | Check a clause of a definition against a list of pattern types and a body type. - checkClause - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => [Type] -> Type -> Bind [Pattern] Term -> Sem r Clause - checkClause patTys bodyTy clause = do - (pats, body) <- unbind clause - - -- At this point we know that every clause has the same number of patterns, - -- which is the same as the length of the list patTys. So we can just use - -- zipWithM to check all the patterns. - (ctxs, aps) <- unzip <$> zipWithM checkPattern pats patTys - at <- extends (mconcat ctxs) $ check body bodyTy - return $ bind aps at - - -- Decompose a type that must be of the form t1 -> t2 -> ... -> tn -> t{n+1}. - decomposeDefnTy :: Members '[Reader TyDefCtx, Error TCError] r => Int -> Type -> Sem r ([Type], Type) - decomposeDefnTy 0 ty = return ([], ty) - decomposeDefnTy n (TyUser tyName args) = lookupTyDefn tyName args >>= decomposeDefnTy n - decomposeDefnTy n (ty1 :->: ty2) = first (ty1:) <$> decomposeDefnTy (n-1) ty2 - decomposeDefnTy _n _ty = throw NumPatterns - -- XXX include more info. More argument patterns than arrows in the type. + where + numPats = length . fst . unsafeUnbind + + checkNumPats [] = return () -- This can't happen, but meh + checkNumPats [_] = return () + checkNumPats (c : cs) + | all ((== 0) . numPats) (c : cs) = throw (DuplicateDefns x) + | not (all ((== numPats c) . numPats) cs) = throw NumPatterns + -- XXX more info, this error actually means # of + -- patterns don't match across different clauses + | otherwise = return () + + -- \| Check a clause of a definition against a list of pattern types and a body type. + checkClause :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + [Type] -> + Type -> + Bind [Pattern] Term -> + Sem r Clause + checkClause patTys bodyTy clause = do + (pats, body) <- unbind clause + + -- At this point we know that every clause has the same number of patterns, + -- which is the same as the length of the list patTys. So we can just use + -- zipWithM to check all the patterns. + (ctxs, aps) <- unzip <$> zipWithM checkPattern pats patTys + at <- extends (mconcat ctxs) $ check body bodyTy + return $ bind aps at + + -- Decompose a type that must be of the form t1 -> t2 -> ... -> tn -> t{n+1}. + decomposeDefnTy :: Members '[Reader TyDefCtx, Error TCError] r => Int -> Type -> Sem r ([Type], Type) + decomposeDefnTy 0 ty = return ([], ty) + decomposeDefnTy n (TyUser tyName args) = lookupTyDefn tyName args >>= decomposeDefnTy n + decomposeDefnTy n (ty1 :->: ty2) = first (ty1 :) <$> decomposeDefnTy (n - 1) ty2 + decomposeDefnTy _n _ty = throw NumPatterns + +-- XXX include more info. More argument patterns than arrows in the type. -------------------------------------------------- -- Properties -- | Given a context mapping names to documentation, extract the -- properties attached to each name and typecheck them. -checkProperties - :: Members '[Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh, Output Message] r - => Ctx Term Docs -> Sem r (Ctx ATerm [AProperty]) +checkProperties :: + Members '[Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh, Output Message] r => + Ctx Term Docs -> + Sem r (Ctx ATerm [AProperty]) checkProperties docs = Ctx.coerceKeys . Ctx.filter (not . P.null) <$> (traverse . traverse) checkProperty properties - where - properties :: Ctx Term [Property] - properties = fmap (\ds -> [p | DocProperty p <- ds]) docs + where + properties :: Ctx Term [Property] + properties = fmap (\ds -> [p | DocProperty p <- ds]) docs -- | Check the types of the terms embedded in a property. -checkProperty - :: Members '[Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh, Output Message] r - => Property -> Sem r AProperty +checkProperty :: + Members '[Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh, Output Message] r => + Property -> + Sem r AProperty checkProperty prop = do (at, theta) <- solve $ check prop TyProp -- XXX do we need to default container variables here? @@ -357,24 +374,25 @@ checkPolyTyValid (Forall b) = do -- applied. But we do need to check that every type is applied to -- the correct number of arguments. checkTypeValid :: Members '[Reader TyDefCtx, Error TCError] r => Type -> Sem r () -checkTypeValid (TyAtom _) = return () +checkTypeValid (TyAtom _) = return () checkTypeValid (TyCon c tys) = do k <- conArity c - if | n < k -> throw (NotEnoughArgs c) - | n > k -> throw (TooManyArgs c) - | otherwise -> mapM_ checkTypeValid tys - where - n = length tys + if + | n < k -> throw (NotEnoughArgs c) + | n > k -> throw (TooManyArgs c) + | otherwise -> mapM_ checkTypeValid tys + where + n = length tys conArity :: Members '[Reader TyDefCtx, Error TCError] r => Con -> Sem r Int conArity (CContainer _) = return 1 conArity CGraph = return 1 -conArity (CUser name) = do +conArity (CUser name) = do d <- ask @TyDefCtx case M.lookup name d of - Nothing -> throw (NotTyDef name) + Nothing -> throw (NotTyDef name) Just (TyDefBody as _) -> return (length as) -conArity _ = return 2 -- (->, *, +, map) +conArity _ = return 2 -- (->, *, +, map) -------------------------------------------------- -- Checking modes @@ -384,28 +402,32 @@ conArity _ = return 2 -- (->, *, +, map) -- are trying to synthesize a valid type for a term; checking mode -- means we are trying to show that a term has a given type. data Mode = Infer | Check Type - deriving Show + deriving (Show) -- | Check that a term has the given type. Either throws an error, or -- returns the term annotated with types for all subterms. -- -- This function is provided for convenience; it simply calls -- 'typecheck' with an appropriate 'Mode'. -check - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Term -> Type -> Sem r ATerm +check :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Term -> + Type -> + Sem r ATerm check t ty = typecheck (Check ty) t -- | Check that a term has the given polymorphic type. -checkPolyTy - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Term -> PolyType -> Sem r ATerm +checkPolyTy :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Term -> + PolyType -> + Sem r ATerm checkPolyTy t (Forall sig) = do (as, tau) <- unbind sig (at, cst) <- withConstraint $ check t tau case as of [] -> constraint cst - _ -> constraint $ CAll (bind as cst) + _ -> constraint $ CAll (bind as cst) return at -- | Infer the type of a term. If it succeeds, it returns the term @@ -413,19 +435,20 @@ checkPolyTy t (Forall sig) = do -- -- This function is provided for convenience; it simply calls -- 'typecheck' with an appropriate 'Mode'. -infer - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Term -> Sem r ATerm +infer :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Term -> + Sem r ATerm infer = typecheck Infer -- | Top-level type inference algorithm: infer a (polymorphic) type -- for a term by running type inference, solving the resulting -- constraints, and quantifying over any remaining type variables. -inferTop - :: Members '[Output Message, Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh] r - => Term -> Sem r (ATerm, PolyType) +inferTop :: + Members '[Output Message, Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh] r => + Term -> + Sem r (ATerm, PolyType) inferTop t = do - -- Run inference on the term and try to solve the resulting -- constraints. (at, theta) <- solve $ infer t @@ -433,7 +456,7 @@ inferTop t = do debug "Final annotated term (before substitution and container monomorphizing):" debugPretty at - -- Apply the resulting substitution. + -- Apply the resulting substitution. let at' = applySubst theta at -- Find any remaining container variables. @@ -449,9 +472,11 @@ inferTop t = do -- | Top-level type checking algorithm: check that a term has a given -- polymorphic type by running type checking and solving the -- resulting constraints. -checkTop - :: Members '[Output Message, Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh] r - => Term -> PolyType -> Sem r ATerm +checkTop :: + Members '[Output Message, Reader TyCtx, Reader TyDefCtx, Error TCError, Fresh] r => + Term -> + PolyType -> + Sem r ATerm checkTop t ty = do (at, theta) <- solve $ checkPolyTy t ty return $ applySubst theta at @@ -465,10 +490,11 @@ checkTop t ty = do -- takes a 'Mode'. This cuts down on code duplication in many -- cases, and allows all the checking and inference code related to -- a given AST node to be placed together. -typecheck - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Mode -> Term -> Sem r ATerm - +typecheck :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Mode -> + Term -> + Sem r ATerm -- ~~~~ Note [Pattern coverage] -- In several places we have clauses like -- @@ -497,14 +523,12 @@ typecheck -- This case has to be first, so in all other cases we know the type -- will not be a TyUser. typecheck (Check (TyUser name args)) t = lookupTyDefn name args >>= check t - -------------------------------------------------- -- Parens -- Recurse through parens; they are not represented explicitly in the -- resulting ATerm. typecheck mode (TParens t) = typecheck mode t - -------------------------------------------------- -- Variables @@ -512,37 +536,36 @@ typecheck mode (TParens t) = typecheck mode t -- checking case; checking the type of a variable will fall through to -- this case. typecheck Infer (TVar x) = do - -- Pick the first method that succeeds; if none do, throw an unbound -- variable error. mt <- runMaybeT . F.asum . map MaybeT $ [tryLocal, tryModule, tryPrim] maybe (throw (Unbound x)) return mt - where - -- 1. See if the variable name is bound locally. - tryLocal = do - mty <- Ctx.lookup (localName x) - case mty of - Just (Forall sig) -> do - (_, ty) <- unbind sig - return . Just $ ATVar ty (localName (coerce x)) - Nothing -> return Nothing - - -- 2. See if the variable name is bound in some in-scope module, - -- throwing an ambiguity error if it is bound in multiple modules. - tryModule = do - bs <- Ctx.lookupNonLocal x - case bs of - [(m,Forall sig)] -> do - (_, ty) <- unbind sig - return . Just $ ATVar ty (m .- coerce x) - [] -> return Nothing - _ -> throw $ Ambiguous x (map fst bs) - - -- 3. See if we should convert it to a primitive. - tryPrim = - case toPrim (name2String x) of - (prim:_) -> Just <$> typecheck Infer (TPrim prim) - _ -> return Nothing + where + -- 1. See if the variable name is bound locally. + tryLocal = do + mty <- Ctx.lookup (localName x) + case mty of + Just (Forall sig) -> do + (_, ty) <- unbind sig + return . Just $ ATVar ty (localName (coerce x)) + Nothing -> return Nothing + + -- 2. See if the variable name is bound in some in-scope module, + -- throwing an ambiguity error if it is bound in multiple modules. + tryModule = do + bs <- Ctx.lookupNonLocal x + case bs of + [(m, Forall sig)] -> do + (_, ty) <- unbind sig + return . Just $ ATVar ty (m .- coerce x) + [] -> return Nothing + _ -> throw $ Ambiguous x (map fst bs) + + -- 3. See if we should convert it to a primitive. + tryPrim = + case toPrim (name2String x) of + (prim : _) -> Just <$> typecheck Infer (TPrim prim) + _ -> return Nothing -------------------------------------------------- -- Primitives @@ -550,396 +573,371 @@ typecheck Infer (TVar x) = do typecheck Infer (TPrim prim) = do ty <- inferPrim prim return $ ATPrim ty prim - - where - inferPrim :: Members '[Writer Constraint, Fresh] r => Prim -> Sem r Type - - ---------------------------------------- - -- Left/right - - inferPrim PrimLeft = do - a <- freshTy - b <- freshTy - return $ a :->: (a :+: b) - - inferPrim PrimRight = do - a <- freshTy - b <- freshTy - return $ b :->: (a :+: b) - - ---------------------------------------- - -- Logic - - inferPrim (PrimBOp op) | op `elem` [And, Or, Impl, Iff] = do - a <- freshTy - constraint $ CQual (bopQual op) a - return $ a :*: a :->: a - - -- See Note [Pattern coverage] ----------------------------- - inferPrim (PrimBOp And) = error "inferPrim And should be unreachable" - inferPrim (PrimBOp Or) = error "inferPrim Or should be unreachable" - inferPrim (PrimBOp Impl) = error "inferPrim Impl should be unreachable" - inferPrim (PrimBOp Iff) = error "inferPrim Iff should be unreachable" - ------------------------------------------------------------ - - inferPrim (PrimUOp Not) = do - a <- freshTy - constraint $ CQual QBool a - return $ a :->: a - - ---------------------------------------- - -- Container conversion - - inferPrim conv | conv `elem` [PrimList, PrimBag, PrimSet] = do - c <- freshAtom -- make a unification variable for the container type - a <- freshTy -- make a unification variable for the element type - - -- converting to a set or bag requires being able to sort the elements - when (conv /= PrimList) $ constraint $ CQual QCmp a - - return $ TyContainer c a :->: primCtrCon conv a - - where - primCtrCon PrimList = TyList - primCtrCon PrimBag = TyBag - primCtrCon _ = TySet - - -- See Note [Pattern coverage] ----------------------------- - inferPrim PrimList = error "inferPrim PrimList should be unreachable" - inferPrim PrimBag = error "inferPrim PrimBag should be unreachable" - inferPrim PrimSet = error "inferPrim PrimSet should be unreachable" - ------------------------------------------------------------ - - inferPrim PrimB2C = do - a <- freshTy - return $ TyBag a :->: TySet (a :*: TyN) - - inferPrim PrimC2B = do - a <- freshTy - c <- freshAtom - constraint $ CQual QCmp a - return $ TyContainer c (a :*: TyN) :->: TyBag a - - inferPrim PrimUC2B = do - a <- freshTy - c <- freshAtom - return $ TyContainer c (a :*: TyN) :->: TyBag a - - inferPrim PrimMapToSet = do - k <- freshTy - v <- freshTy - constraint $ CQual QSimple k - return $ TyMap k v :->: TySet (k :*: v) - - inferPrim PrimSetToMap = do - k <- freshTy - v <- freshTy - constraint $ CQual QSimple k - return $ TySet (k :*: v) :->: TyMap k v - - inferPrim PrimSummary = do - a <- freshTy - constraint $ CQual QSimple a - return $ TyGraph a :->: TyMap a (TySet a) - - inferPrim PrimVertex = do - a <- freshTy - constraint $ CQual QSimple a - return $ a :->: TyGraph a - - inferPrim PrimEmptyGraph = do - a <- freshTy - constraint $ CQual QSimple a - return $ TyGraph a - - inferPrim PrimOverlay = do - a <- freshTy - constraint $ CQual QSimple a - return $ TyGraph a :*: TyGraph a :->: TyGraph a - - inferPrim PrimConnect = do - a <- freshTy - constraint $ CQual QSimple a - return $ TyGraph a :*: TyGraph a :->: TyGraph a - - inferPrim PrimInsert = do - a <- freshTy - b <- freshTy - constraint $ CQual QSimple a - return $ a :*: b :*: TyMap a b :->: TyMap a b - - inferPrim PrimLookup = do - a <- freshTy - b <- freshTy - constraint $ CQual QSimple a - return $ a :*: TyMap a b :->: (TyUnit :+: b) - ---------------------------------------- - -- Container primitives - - inferPrim (PrimBOp Cons) = do - a <- freshTy - return $ a :*: TyList a :->: TyList a - - -- XXX see https://github.com/disco-lang/disco/issues/160 - -- each : (a -> b) × c a -> c b - inferPrim PrimEach = do - c <- freshAtom - a <- freshTy - b <- freshTy - return $ (a :->: b) :*: TyContainer c a :->: TyContainer c b - - -- XXX should eventually be (a * a -> a) * c a -> a, - -- with a check that the function has the right properties. - -- reduce : (a * a -> a) * a * c a -> a - inferPrim PrimReduce = do - c <- freshAtom - a <- freshTy - return $ (a :*: a :->: a) :*: a :*: TyContainer c a :->: a - - -- filter : (a -> Bool) × c a -> c a - inferPrim PrimFilter = do - c <- freshAtom - a <- freshTy - return $ (a :->: TyBool) :*: TyContainer c a :->: TyContainer c a - - -- join : c (c a) -> c a - inferPrim PrimJoin = do - c <- freshAtom - a <- freshTy - return $ TyContainer c (TyContainer c a) :->: TyContainer c a - - -- merge : (N × N -> N) × c a × c a -> c a (c = bag or set) - inferPrim PrimMerge = do - c <- freshAtom - a <- freshTy - constraint $ COr + where + inferPrim :: Members '[Writer Constraint, Fresh] r => Prim -> Sem r Type + + ---------------------------------------- + -- Left/right + + inferPrim PrimLeft = do + a <- freshTy + b <- freshTy + return $ a :->: (a :+: b) + inferPrim PrimRight = do + a <- freshTy + b <- freshTy + return $ b :->: (a :+: b) + + ---------------------------------------- + -- Logic + + inferPrim (PrimBOp op) | op `elem` [And, Or, Impl, Iff] = do + a <- freshTy + constraint $ CQual (bopQual op) a + return $ a :*: a :->: a + + -- See Note [Pattern coverage] ----------------------------- + inferPrim (PrimBOp And) = error "inferPrim And should be unreachable" + inferPrim (PrimBOp Or) = error "inferPrim Or should be unreachable" + inferPrim (PrimBOp Impl) = error "inferPrim Impl should be unreachable" + inferPrim (PrimBOp Iff) = error "inferPrim Iff should be unreachable" + ------------------------------------------------------------ + + inferPrim (PrimUOp Not) = do + a <- freshTy + constraint $ CQual QBool a + return $ a :->: a + + ---------------------------------------- + -- Container conversion + + inferPrim conv | conv `elem` [PrimList, PrimBag, PrimSet] = do + c <- freshAtom -- make a unification variable for the container type + a <- freshTy -- make a unification variable for the element type + + -- converting to a set or bag requires being able to sort the elements + when (conv /= PrimList) $ constraint $ CQual QCmp a + + return $ TyContainer c a :->: primCtrCon conv a + where + primCtrCon PrimList = TyList + primCtrCon PrimBag = TyBag + primCtrCon _ = TySet + + -- See Note [Pattern coverage] ----------------------------- + inferPrim PrimList = error "inferPrim PrimList should be unreachable" + inferPrim PrimBag = error "inferPrim PrimBag should be unreachable" + inferPrim PrimSet = error "inferPrim PrimSet should be unreachable" + ------------------------------------------------------------ + + inferPrim PrimB2C = do + a <- freshTy + return $ TyBag a :->: TySet (a :*: TyN) + inferPrim PrimC2B = do + a <- freshTy + c <- freshAtom + constraint $ CQual QCmp a + return $ TyContainer c (a :*: TyN) :->: TyBag a + inferPrim PrimUC2B = do + a <- freshTy + c <- freshAtom + return $ TyContainer c (a :*: TyN) :->: TyBag a + inferPrim PrimMapToSet = do + k <- freshTy + v <- freshTy + constraint $ CQual QSimple k + return $ TyMap k v :->: TySet (k :*: v) + inferPrim PrimSetToMap = do + k <- freshTy + v <- freshTy + constraint $ CQual QSimple k + return $ TySet (k :*: v) :->: TyMap k v + inferPrim PrimSummary = do + a <- freshTy + constraint $ CQual QSimple a + return $ TyGraph a :->: TyMap a (TySet a) + inferPrim PrimVertex = do + a <- freshTy + constraint $ CQual QSimple a + return $ a :->: TyGraph a + inferPrim PrimEmptyGraph = do + a <- freshTy + constraint $ CQual QSimple a + return $ TyGraph a + inferPrim PrimOverlay = do + a <- freshTy + constraint $ CQual QSimple a + return $ TyGraph a :*: TyGraph a :->: TyGraph a + inferPrim PrimConnect = do + a <- freshTy + constraint $ CQual QSimple a + return $ TyGraph a :*: TyGraph a :->: TyGraph a + inferPrim PrimInsert = do + a <- freshTy + b <- freshTy + constraint $ CQual QSimple a + return $ a :*: b :*: TyMap a b :->: TyMap a b + inferPrim PrimLookup = do + a <- freshTy + b <- freshTy + constraint $ CQual QSimple a + return $ a :*: TyMap a b :->: (TyUnit :+: b) + ---------------------------------------- + -- Container primitives + + inferPrim (PrimBOp Cons) = do + a <- freshTy + return $ a :*: TyList a :->: TyList a + + -- XXX see https://github.com/disco-lang/disco/issues/160 + -- each : (a -> b) × c a -> c b + inferPrim PrimEach = do + c <- freshAtom + a <- freshTy + b <- freshTy + return $ (a :->: b) :*: TyContainer c a :->: TyContainer c b + + -- XXX should eventually be (a * a -> a) * c a -> a, + -- with a check that the function has the right properties. + -- reduce : (a * a -> a) * a * c a -> a + inferPrim PrimReduce = do + c <- freshAtom + a <- freshTy + return $ (a :*: a :->: a) :*: a :*: TyContainer c a :->: a + + -- filter : (a -> Bool) × c a -> c a + inferPrim PrimFilter = do + c <- freshAtom + a <- freshTy + return $ (a :->: TyBool) :*: TyContainer c a :->: TyContainer c a + + -- join : c (c a) -> c a + inferPrim PrimJoin = do + c <- freshAtom + a <- freshTy + return $ TyContainer c (TyContainer c a) :->: TyContainer c a + + -- merge : (N × N -> N) × c a × c a -> c a (c = bag or set) + inferPrim PrimMerge = do + c <- freshAtom + a <- freshTy + constraint $ + COr [ CEq (TyAtom (ABase CtrBag)) (TyAtom c) , CEq (TyAtom (ABase CtrSet)) (TyAtom c) ] - let ca = TyContainer c a - return $ (TyN :*: TyN :->: TyN) :*: ca :*: ca :->: ca - - inferPrim (PrimBOp CartProd) = do - a <- freshTy - b <- freshTy - c <- freshAtom - return $ TyContainer c a :*: TyContainer c b :->: TyContainer c (a :*: b) - - inferPrim (PrimBOp setOp) | setOp `elem` [Union, Inter, Diff, Subset] = do - a <- freshTy - c <- freshAtom - constraint $ COr + let ca = TyContainer c a + return $ (TyN :*: TyN :->: TyN) :*: ca :*: ca :->: ca + inferPrim (PrimBOp CartProd) = do + a <- freshTy + b <- freshTy + c <- freshAtom + return $ TyContainer c a :*: TyContainer c b :->: TyContainer c (a :*: b) + inferPrim (PrimBOp setOp) | setOp `elem` [Union, Inter, Diff, Subset] = do + a <- freshTy + c <- freshAtom + constraint $ + COr [ CEq (TyAtom (ABase CtrBag)) (TyAtom c) , CEq (TyAtom (ABase CtrSet)) (TyAtom c) ] - let ca = TyContainer c a - let resTy = case setOp of {Subset -> TyBool; _ -> ca} - return $ ca :*: ca :->: resTy - - -- See Note [Pattern coverage] ----------------------------- - inferPrim (PrimBOp Union) = error "inferPrim Union should be unreachable" - inferPrim (PrimBOp Inter) = error "inferPrim Inter should be unreachable" - inferPrim (PrimBOp Diff) = error "inferPrim Diff should be unreachable" - inferPrim (PrimBOp Subset) = error "inferPrim Subset should be unreachable" - ------------------------------------------------------------ - - inferPrim (PrimBOp Elem) = do - a <- freshTy - c <- freshAtom - - constraint $ CQual QCmp a - - return $ a :*: TyContainer c a :->: TyBool - - ---------------------------------------- - -- Arithmetic - - inferPrim (PrimBOp IDiv) = do - a <- freshTy - resTy <- cInt a - return $ a :*: a :->: resTy - - inferPrim (PrimBOp Mod) = do - a <- freshTy - constraint $ CSub a TyZ - return $ a :*: a :->: a - - inferPrim (PrimBOp op) | op `elem` [Add, Mul, Sub, Div, SSub] = do - a <- freshTy - constraint $ CQual (bopQual op) a - return $ a :*: a :->: a - - -- See Note [Pattern coverage] ----------------------------- - inferPrim (PrimBOp Add ) = error "inferPrim Add should be unreachable" - inferPrim (PrimBOp Mul ) = error "inferPrim Mul should be unreachable" - inferPrim (PrimBOp Sub ) = error "inferPrim Sub should be unreachable" - inferPrim (PrimBOp Div ) = error "inferPrim Div should be unreachable" - inferPrim (PrimBOp SSub) = error "inferPrim SSub should be unreachable" - ------------------------------------------------------------ - - inferPrim (PrimUOp Neg) = do - a <- freshTy - constraint $ CQual QSub a - return $ a :->: a - - inferPrim (PrimBOp Exp) = do - a <- freshTy - b <- freshTy - resTy <- cExp a b - return $ a :*: b :->: resTy - - ---------------------------------------- - -- Number theory - - inferPrim PrimIsPrime = return $ TyN :->: TyBool - inferPrim PrimFactor = return $ TyN :->: TyBag TyN - - inferPrim PrimFrac = return $ TyQ :->: (TyZ :*: TyN) - - inferPrim (PrimBOp Divides) = do - a <- freshTy - constraint $ CQual QNum a - return $ a :*: a :->: TyBool - - ---------------------------------------- - -- Choose - - -- For now, a simple typing rule for multinomial coefficients that - -- requires everything to be Nat. However, they can be extended to - -- handle negative or fractional arguments. - inferPrim (PrimBOp Choose) = do - b <- freshTy - - -- b can be either Nat (a binomial coefficient) - -- or a list of Nat (a multinomial coefficient). - constraint $ COr [CEq b TyN, CEq b (TyList TyN)] - return $ TyN :*: b :->: TyN - - ---------------------------------------- - -- Ellipses - - -- Actually 'until' supports more types than this, e.g. Q instead - -- of N, but this is good enough. This case is here just for - -- completeness---in case someone enables primitives and uses it - -- directly---but typically 'until' is generated only during - -- desugaring of a container with ellipsis, after typechecking, in - -- which case it can be assigned a more appropriate type directly. - - inferPrim PrimUntil = return $ TyN :*: TyList TyN :->: TyList TyN - - ---------------------------------------- - -- Crash - - inferPrim PrimCrash = do - a <- freshTy - return $ TyString :->: a - - ---------------------------------------- - -- Propositions - - -- 'holds' converts a Prop into a Bool (but might not terminate). - inferPrim PrimHolds = return $ TyProp :->: TyBool - - -- An equality assertion =!= is just like a comparison ==, except - -- the result is a Prop. - inferPrim (PrimBOp ShouldEq) = do - ty <- freshTy - constraint $ CQual QCmp ty - return $ ty :*: ty :->: TyProp - - inferPrim (PrimBOp ShouldLt) = do - ty <- freshTy - constraint $ CQual QCmp ty - return $ ty :*: ty :->: TyProp - - ---------------------------------------- - -- Comparisons - - -- Infer the type of a comparison. A comparison always has type - -- Bool, but we have to make sure the subterms have compatible - -- types. We also generate a QCmp qualifier, for two reasons: - -- one, we need to know whether e.g. a comparison was done at a - -- certain type, so we can decide whether the type is allowed to - -- be completely polymorphic or not. Also, comparison of Props is - -- not allowed. - inferPrim (PrimBOp op) | op `elem` [Eq, Neq, Lt, Gt, Leq, Geq] = do - ty <- freshTy - constraint $ CQual QCmp ty - return $ ty :*: ty :->: TyBool - - -- See Note [Pattern coverage] ----------------------------- - inferPrim (PrimBOp Eq) = error "inferPrim Eq should be unreachable" - inferPrim (PrimBOp Neq) = error "inferPrim Neq should be unreachable" - inferPrim (PrimBOp Lt) = error "inferPrim Lt should be unreachable" - inferPrim (PrimBOp Gt) = error "inferPrim Gt should be unreachable" - inferPrim (PrimBOp Leq) = error "inferPrim Leq should be unreachable" - inferPrim (PrimBOp Geq) = error "inferPrim Geq should be unreachable" - ------------------------------------------------------------ - - inferPrim (PrimBOp op) | op `elem` [Min, Max] = do - ty <- freshTy - constraint $ CQual QCmp ty - return $ ty :*: ty :->: ty - - -- See Note [Pattern coverage] ----------------------------- - inferPrim (PrimBOp Min) = error "inferPrim Min should be unreachable" - inferPrim (PrimBOp Max) = error "inferPrim Max should be unreachable" - ------------------------------------------------------------ - - ---------------------------------------- - -- Special arithmetic functions: fact, sqrt, floor, ceil, abs - - inferPrim (PrimUOp Fact) = return $ TyN :->: TyN - inferPrim PrimSqrt = return $ TyN :->: TyN - - inferPrim p | p `elem` [PrimFloor, PrimCeil] = do - argTy <- freshTy - resTy <- cInt argTy - return $ argTy :->: resTy - - -- See Note [Pattern coverage] ----------------------------- - inferPrim PrimFloor = error "inferPrim Floor should be unreachable" - inferPrim PrimCeil = error "inferPrim Ceil should be unreachable" - ------------------------------------------------------------ - - inferPrim PrimAbs = do - argTy <- freshTy - resTy <- freshTy - cAbs argTy resTy `cOr` cSize argTy resTy - return $ argTy :->: resTy - - ---------------------------------------- - -- power set/bag - - inferPrim PrimPower = do - a <- freshTy - c <- freshAtom - - constraint $ CQual QCmp a - constraint $ COr + let ca = TyContainer c a + let resTy = case setOp of Subset -> TyBool; _ -> ca + return $ ca :*: ca :->: resTy + + -- See Note [Pattern coverage] ----------------------------- + inferPrim (PrimBOp Union) = error "inferPrim Union should be unreachable" + inferPrim (PrimBOp Inter) = error "inferPrim Inter should be unreachable" + inferPrim (PrimBOp Diff) = error "inferPrim Diff should be unreachable" + inferPrim (PrimBOp Subset) = error "inferPrim Subset should be unreachable" + ------------------------------------------------------------ + + inferPrim (PrimBOp Elem) = do + a <- freshTy + c <- freshAtom + + constraint $ CQual QCmp a + + return $ a :*: TyContainer c a :->: TyBool + + ---------------------------------------- + -- Arithmetic + + inferPrim (PrimBOp IDiv) = do + a <- freshTy + resTy <- cInt a + return $ a :*: a :->: resTy + inferPrim (PrimBOp Mod) = do + a <- freshTy + constraint $ CSub a TyZ + return $ a :*: a :->: a + inferPrim (PrimBOp op) | op `elem` [Add, Mul, Sub, Div, SSub] = do + a <- freshTy + constraint $ CQual (bopQual op) a + return $ a :*: a :->: a + + -- See Note [Pattern coverage] ----------------------------- + inferPrim (PrimBOp Add) = error "inferPrim Add should be unreachable" + inferPrim (PrimBOp Mul) = error "inferPrim Mul should be unreachable" + inferPrim (PrimBOp Sub) = error "inferPrim Sub should be unreachable" + inferPrim (PrimBOp Div) = error "inferPrim Div should be unreachable" + inferPrim (PrimBOp SSub) = error "inferPrim SSub should be unreachable" + ------------------------------------------------------------ + + inferPrim (PrimUOp Neg) = do + a <- freshTy + constraint $ CQual QSub a + return $ a :->: a + inferPrim (PrimBOp Exp) = do + a <- freshTy + b <- freshTy + resTy <- cExp a b + return $ a :*: b :->: resTy + + ---------------------------------------- + -- Number theory + + inferPrim PrimIsPrime = return $ TyN :->: TyBool + inferPrim PrimFactor = return $ TyN :->: TyBag TyN + inferPrim PrimFrac = return $ TyQ :->: (TyZ :*: TyN) + inferPrim (PrimBOp Divides) = do + a <- freshTy + constraint $ CQual QNum a + return $ a :*: a :->: TyBool + + ---------------------------------------- + -- Choose + + -- For now, a simple typing rule for multinomial coefficients that + -- requires everything to be Nat. However, they can be extended to + -- handle negative or fractional arguments. + inferPrim (PrimBOp Choose) = do + b <- freshTy + + -- b can be either Nat (a binomial coefficient) + -- or a list of Nat (a multinomial coefficient). + constraint $ COr [CEq b TyN, CEq b (TyList TyN)] + return $ TyN :*: b :->: TyN + + ---------------------------------------- + -- Ellipses + + -- Actually 'until' supports more types than this, e.g. Q instead + -- of N, but this is good enough. This case is here just for + -- completeness---in case someone enables primitives and uses it + -- directly---but typically 'until' is generated only during + -- desugaring of a container with ellipsis, after typechecking, in + -- which case it can be assigned a more appropriate type directly. + + inferPrim PrimUntil = return $ TyN :*: TyList TyN :->: TyList TyN + ---------------------------------------- + -- Crash + + inferPrim PrimCrash = do + a <- freshTy + return $ TyString :->: a + + ---------------------------------------- + -- Propositions + + -- 'holds' converts a Prop into a Bool (but might not terminate). + inferPrim PrimHolds = return $ TyProp :->: TyBool + -- An equality assertion =!= is just like a comparison ==, except + -- the result is a Prop. + inferPrim (PrimBOp ShouldEq) = do + ty <- freshTy + constraint $ CQual QCmp ty + return $ ty :*: ty :->: TyProp + inferPrim (PrimBOp ShouldLt) = do + ty <- freshTy + constraint $ CQual QCmp ty + return $ ty :*: ty :->: TyProp + + ---------------------------------------- + -- Comparisons + + -- Infer the type of a comparison. A comparison always has type + -- Bool, but we have to make sure the subterms have compatible + -- types. We also generate a QCmp qualifier, for two reasons: + -- one, we need to know whether e.g. a comparison was done at a + -- certain type, so we can decide whether the type is allowed to + -- be completely polymorphic or not. Also, comparison of Props is + -- not allowed. + inferPrim (PrimBOp op) | op `elem` [Eq, Neq, Lt, Gt, Leq, Geq] = do + ty <- freshTy + constraint $ CQual QCmp ty + return $ ty :*: ty :->: TyBool + + -- See Note [Pattern coverage] ----------------------------- + inferPrim (PrimBOp Eq) = error "inferPrim Eq should be unreachable" + inferPrim (PrimBOp Neq) = error "inferPrim Neq should be unreachable" + inferPrim (PrimBOp Lt) = error "inferPrim Lt should be unreachable" + inferPrim (PrimBOp Gt) = error "inferPrim Gt should be unreachable" + inferPrim (PrimBOp Leq) = error "inferPrim Leq should be unreachable" + inferPrim (PrimBOp Geq) = error "inferPrim Geq should be unreachable" + ------------------------------------------------------------ + + inferPrim (PrimBOp op) | op `elem` [Min, Max] = do + ty <- freshTy + constraint $ CQual QCmp ty + return $ ty :*: ty :->: ty + + -- See Note [Pattern coverage] ----------------------------- + inferPrim (PrimBOp Min) = error "inferPrim Min should be unreachable" + inferPrim (PrimBOp Max) = error "inferPrim Max should be unreachable" + ------------------------------------------------------------ + + ---------------------------------------- + -- Special arithmetic functions: fact, sqrt, floor, ceil, abs + + inferPrim (PrimUOp Fact) = return $ TyN :->: TyN + inferPrim PrimSqrt = return $ TyN :->: TyN + inferPrim p | p `elem` [PrimFloor, PrimCeil] = do + argTy <- freshTy + resTy <- cInt argTy + return $ argTy :->: resTy + + -- See Note [Pattern coverage] ----------------------------- + inferPrim PrimFloor = error "inferPrim Floor should be unreachable" + inferPrim PrimCeil = error "inferPrim Ceil should be unreachable" + ------------------------------------------------------------ + + inferPrim PrimAbs = do + argTy <- freshTy + resTy <- freshTy + cAbs argTy resTy `cOr` cSize argTy resTy + return $ argTy :->: resTy + + ---------------------------------------- + -- power set/bag + + inferPrim PrimPower = do + a <- freshTy + c <- freshAtom + + constraint $ CQual QCmp a + constraint $ + COr [ CEq (TyAtom (ABase CtrSet)) (TyAtom c) , CEq (TyAtom (ABase CtrBag)) (TyAtom c) ] - return $ TyContainer c a :->: TyContainer c (TyContainer c a) - - inferPrim PrimLookupSeq = return $ TyList TyN :->: (TyUnit :+: TyString) - inferPrim PrimExtendSeq = return $ TyList TyN :->: TyList TyN + return $ TyContainer c a :->: TyContainer c (TyContainer c a) + inferPrim PrimLookupSeq = return $ TyList TyN :->: (TyUnit :+: TyString) + inferPrim PrimExtendSeq = return $ TyList TyN :->: TyList TyN -------------------------------------------------- -- Base types -- A few trivial cases for base types. -typecheck Infer TUnit = return ATUnit -typecheck Infer (TBool b) = return $ ATBool TyBool b -typecheck Infer (TChar c) = return $ ATChar c -typecheck Infer (TString cs) = return $ ATString cs +typecheck Infer TUnit = return ATUnit +typecheck Infer (TBool b) = return $ ATBool TyBool b +typecheck Infer (TChar c) = return $ ATChar c +typecheck Infer (TString cs) = return $ ATString cs -- typecheck (Check (TyFin n)) (TNat x) = return $ ATNat (TyFin n) x -typecheck Infer (TNat n) = return $ ATNat TyN n -typecheck Infer (TRat r) = return $ ATRat r - -typecheck _ TWild = throw NoTWild - +typecheck Infer (TNat n) = return $ ATNat TyN n +typecheck Infer (TRat r) = return $ ATRat r +typecheck _ TWild = throw NoTWild -------------------------------------------------- -- Abstractions (lambdas and quantifiers) @@ -976,50 +974,48 @@ typecheck (Check checkTy) tm@(TAbs Lam body) = do -- types for all the arguments. extends ctx $ ATAbs Lam checkTy <$> (bind (coerce typedArgs) <$> check t resTy) - - where - - -- Given the patterns and their optional type annotations in the - -- head of a lambda (e.g. @x (y:Z) (f : N -> N) -> ...@), and the - -- type at which we are checking the lambda, ensure that: - -- - -- - The type is of the form @ty1 -> ty2 -> ... -> resTy@ and - -- there are enough @ty1@, @ty2@, ... to match all the arguments. - -- - Each pattern successfully checks at its corresponding type. - -- - -- If it succeeds, return a context binding variables to their - -- types (as determined by the patterns and the input types) which - -- we can use to extend when checking the body, a list of the typed - -- patterns, and the result type of the function. - checkArgs - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => [Pattern] -> Type -> Term -> Sem r (TyCtx, [APattern], Type) - - -- If we're all out of arguments, the remaining checking type is the - -- result, and there are no variables to bind in the context. - checkArgs [] ty _ = return (emptyCtx, [], ty) - - -- Take the next pattern and its annotation; the checking type must - -- be a function type ty1 -> ty2. - checkArgs (p : args) ty term = do - - -- Ensure that ty is a function type - (ty1, ty2) <- ensureConstr2 CArr ty (Left term) - - -- Check the argument pattern against the function domain. - (pCtx, pTyped) <- checkPattern p ty1 - - -- Check the rest of the arguments under the type ty2, returning a - -- context with the rest of the arguments and the final result type. - (ctx, typedArgs, resTy) <- checkArgs args ty2 term - - -- Pass the result type through, and put the pattern-bound variables - -- in the returned context. - return (pCtx <> ctx, pTyped : typedArgs, resTy) + where + -- Given the patterns and their optional type annotations in the + -- head of a lambda (e.g. @x (y:Z) (f : N -> N) -> ...@), and the + -- type at which we are checking the lambda, ensure that: + -- + -- - The type is of the form @ty1 -> ty2 -> ... -> resTy@ and + -- there are enough @ty1@, @ty2@, ... to match all the arguments. + -- - Each pattern successfully checks at its corresponding type. + -- + -- If it succeeds, return a context binding variables to their + -- types (as determined by the patterns and the input types) which + -- we can use to extend when checking the body, a list of the typed + -- patterns, and the result type of the function. + checkArgs :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + [Pattern] -> + Type -> + Term -> + Sem r (TyCtx, [APattern], Type) + + -- If we're all out of arguments, the remaining checking type is the + -- result, and there are no variables to bind in the context. + checkArgs [] ty _ = return (emptyCtx, [], ty) + -- Take the next pattern and its annotation; the checking type must + -- be a function type ty1 -> ty2. + checkArgs (p : args) ty term = do + -- Ensure that ty is a function type + (ty1, ty2) <- ensureConstr2 CArr ty (Left term) + + -- Check the argument pattern against the function domain. + (pCtx, pTyped) <- checkPattern p ty1 + + -- Check the rest of the arguments under the type ty2, returning a + -- context with the rest of the arguments and the final result type. + (ctx, typedArgs, resTy) <- checkArgs args ty2 term + + -- Pass the result type through, and put the pattern-bound variables + -- in the returned context. + return (pCtx <> ctx, pTyped : typedArgs, resTy) -- In inference mode, we handle lambdas as well as quantifiers (∀, ∃). -typecheck Infer (TAbs q lam) = do - +typecheck Infer (TAbs q lam) = do -- Open it and get the argument patterns with any type annotations. (args, t) <- unbind lam @@ -1037,7 +1033,8 @@ typecheck Infer (TAbs q lam) = do -- concrete type from annotations inside tuples. forM_ (map getType typedPats) $ \ty -> unless (isSearchable ty) $ - throw $ NoSearch ty + throw $ + NoSearch ty -- Extend the context with the given arguments, and then do -- something appropriate depending on the quantifier. @@ -1051,19 +1048,21 @@ typecheck Infer (TAbs q lam) = do -- For other quantifiers, check that the body has type Prop, -- and return Prop. - _ -> do -- ∀, ∃ + _ -> do + -- ∀, ∃ at <- check t TyProp return $ ATAbs q TyProp (bind typedPats at) - where - getAscrOrFresh - :: Members '[Reader TyDefCtx, Error TCError, Fresh] r - => Pattern -> Sem r Type - getAscrOrFresh (PAscr _ ty) = checkTypeValid ty >> pure ty - getAscrOrFresh _ = freshTy - - -- mkFunTy [ty1, ..., tyn] out = ty1 -> (ty2 -> ... (tyn -> out)) - mkFunTy :: [Type] -> Type -> Type - mkFunTy tys out = foldr (:->:) out tys + where + getAscrOrFresh :: + Members '[Reader TyDefCtx, Error TCError, Fresh] r => + Pattern -> + Sem r Type + getAscrOrFresh (PAscr _ ty) = checkTypeValid ty >> pure ty + getAscrOrFresh _ = freshTy + + -- mkFunTy [ty1, ..., tyn] out = ty1 -> (ty2 -> ... (tyn -> out)) + mkFunTy :: [Type] -> Type -> Type + mkFunTy tys out = foldr (:->:) out tys -------------------------------------------------- -- Application @@ -1071,7 +1070,7 @@ typecheck Infer (TAbs q lam) = do -- Infer the type of a function application by inferring the function -- type and then checking the argument type. We don't need a checking -- case because checking mode doesn't help. -typecheck Infer (TApp t t') = do +typecheck Infer (TApp t t') = do at <- infer t let ty = getType at (ty1, ty2) <- ensureConstr2 CArr ty (Left t) @@ -1082,34 +1081,37 @@ typecheck Infer (TApp t t') = do -- Check/infer the type of a tuple. typecheck mode1 (TTup tup) = uncurry ATTup <$> typecheckTuple mode1 tup - where - typecheckTuple - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Mode -> [Term] -> Sem r (Type, [ATerm]) - typecheckTuple _ [] = error "Impossible! typecheckTuple []" - typecheckTuple mode [t] = (getType &&& (:[])) <$> typecheck mode t - typecheckTuple mode (t:ts) = do - (m,ms) <- ensureConstrMode2 CProd mode (Left $ TTup (t:ts)) - at <- typecheck m t - (ty, ats) <- typecheckTuple ms ts - return (getType at :*: ty, at : ats) + where + typecheckTuple :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Mode -> + [Term] -> + Sem r (Type, [ATerm]) + typecheckTuple _ [] = error "Impossible! typecheckTuple []" + typecheckTuple mode [t] = (getType &&& (: [])) <$> typecheck mode t + typecheckTuple mode (t : ts) = do + (m, ms) <- ensureConstrMode2 CProd mode (Left $ TTup (t : ts)) + at <- typecheck m t + (ty, ats) <- typecheckTuple ms ts + return (getType at :*: ty, at : ats) ---------------------------------------- -- Comparison chain typecheck Infer (TChain t ls) = ATChain TyBool <$> infer t <*> inferChain t ls - - where - inferChain - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Term -> [Link] -> Sem r [ALink] - inferChain _ [] = return [] - inferChain t1 (TLink op t2 : links) = do - at2 <- infer t2 - _ <- check (TBin op t1 t2) TyBool - atl <- inferChain t2 links - return $ ATLink op at2 : atl + where + inferChain :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Term -> + [Link] -> + Sem r [ALink] + inferChain _ [] = return [] + inferChain t1 (TLink op t2 : links) = do + at2 <- infer t2 + _ <- check (TBin op t1 t2) TyBool + atl <- inferChain t2 links + return $ ATLink op at2 : atl ---------------------------------------- -- Type operations @@ -1117,8 +1119,7 @@ typecheck Infer (TChain t ls) = typecheck Infer (TTyOp Enumerate t) = do checkTypeValid t return $ ATTyOp (TyList t) Enumerate t - -typecheck Infer (TTyOp Count t) = do +typecheck Infer (TTyOp Count t) = do checkTypeValid t return $ ATTyOp (TyUnit :+: TyN) Count t @@ -1126,14 +1127,14 @@ typecheck Infer (TTyOp Count t) = do -- Containers -- Literal containers, including ellipses -typecheck mode t@(TContainer c xs ell) = do +typecheck mode t@(TContainer c xs ell) = do eltMode <- ensureConstrMode1 (containerToCon c) mode (Left t) - axns <- mapM (\(x,n) -> (,) <$> typecheck eltMode x <*> traverse (`check` TyN) n) xs - aell <- typecheckEllipsis eltMode ell + axns <- mapM (\(x, n) -> (,) <$> typecheck eltMode x <*> traverse (`check` TyN) n) xs + aell <- typecheckEllipsis eltMode ell resTy <- case mode of Infer -> do - let tys = [ getType at | Just (Until at) <- [aell] ] ++ map (getType . fst) axns - tyv <- freshTy + let tys = [getType at | Just (Until at) <- [aell]] ++ map (getType . fst) axns + tyv <- freshTy constraints $ map (`CSub` tyv) tys return $ containerTy c tyv Check ty -> return ty @@ -1143,13 +1144,14 @@ typecheck mode t@(TContainer c xs ell) = do when (c /= ListContainer && not (P.null xs)) $ constraint $ CQual QCmp eltTy when (isJust ell) $ constraint $ CQual QEnum eltTy return $ ATContainer resTy c axns aell - - where - typecheckEllipsis - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Mode -> Maybe (Ellipsis Term) -> Sem r (Maybe (Ellipsis ATerm)) - typecheckEllipsis _ Nothing = return Nothing - typecheckEllipsis m (Just (Until tm)) = Just . Until <$> typecheck m tm + where + typecheckEllipsis :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Mode -> + Maybe (Ellipsis Term) -> + Sem r (Maybe (Ellipsis ATerm)) + typecheckEllipsis _ Nothing = return Nothing + typecheckEllipsis m (Just (Until tm)) = Just . Until <$> typecheck m tm -- ~~~~ Note [Container literal constraints] -- @@ -1191,27 +1193,26 @@ typecheck mode t@(TContainer c xs ell) = do -- Container comprehensions typecheck mode tcc@(TContainerComp c bqt) = do eltMode <- ensureConstrMode1 (containerToCon c) mode (Left tcc) - (qs, t) <- unbind bqt + (qs, t) <- unbind bqt (aqs, cx) <- inferTelescope inferQual qs extends cx $ do at <- typecheck eltMode t let resTy = case mode of - Infer -> containerTy c (getType at) + Infer -> containerTy c (getType at) Check ty -> ty return $ ATContainerComp resTy c (bind aqs at) - - where - inferQual - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Qual -> Sem r (AQual, TyCtx) - inferQual (QBind x (unembed -> t)) = do - at <- infer t - ty <- ensureConstr1 (containerToCon c) (getType at) (Left t) - return (AQBind (coerce x) (embed at), singleCtx (localName x) (toPolyType ty)) - - inferQual (QGuard (unembed -> t)) = do - at <- check t TyBool - return (AQGuard (embed at), emptyCtx) + where + inferQual :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Qual -> + Sem r (AQual, TyCtx) + inferQual (QBind x (unembed -> t)) = do + at <- infer t + ty <- ensureConstr1 (containerToCon c) (getType at) (Left t) + return (AQBind (coerce x) (embed at), singleCtx (localName x) (toPolyType ty)) + inferQual (QGuard (unembed -> t)) = do + at <- check t TyBool + return (AQGuard (embed at), emptyCtx) -------------------------------------------------- -- Let @@ -1227,69 +1228,69 @@ typecheck mode (TLet l) = do extends ctx $ do at2 <- typecheck mode t2 return $ ATLet (getType at2) (bind as at2) - - where - - -- Infer the type of a binding (@x [: ty] = t@), returning a - -- type-annotated binding along with a (singleton) context for the - -- bound variable. The optional type annotation on the variable - -- determines whether we use inference or checking mode for the - -- body. - inferBinding - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Binding -> Sem r (ABinding, TyCtx) - inferBinding (Binding mty x (unembed -> t)) = do - at <- case mty of - Just (unembed -> ty) -> checkPolyTy t ty - Nothing -> infer t - return (ABinding mty (coerce x) (embed at), singleCtx (localName x) (toPolyType $ getType at)) + where + -- Infer the type of a binding (@x [: ty] = t@), returning a + -- type-annotated binding along with a (singleton) context for the + -- bound variable. The optional type annotation on the variable + -- determines whether we use inference or checking mode for the + -- body. + inferBinding :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Binding -> + Sem r (ABinding, TyCtx) + inferBinding (Binding mty x (unembed -> t)) = do + at <- case mty of + Just (unembed -> ty) -> checkPolyTy t ty + Nothing -> infer t + return (ABinding mty (coerce x) (embed at), singleCtx (localName x) (toPolyType $ getType at)) -------------------------------------------------- -- Case -- Check/infer a case expression. -typecheck _ (TCase []) = throw EmptyCase +typecheck _ (TCase []) = throw EmptyCase typecheck mode (TCase bs) = do bs' <- mapM typecheckBranch bs resTy <- case mode of Check ty -> return ty - Infer -> do + Infer -> do x <- freshTy constraints $ map ((`CSub` x) . getType) bs' return x return $ ATCase resTy bs' - - where - typecheckBranch - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Branch -> Sem r ABranch - typecheckBranch b = do - (gs, t) <- unbind b - (ags, ctx) <- inferTelescope inferGuard gs - extends ctx $ - bind ags <$> typecheck mode t - - -- Infer the type of a guard, returning the type-annotated guard - -- along with a context of types for any variables bound by the - -- guard. - inferGuard - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Guard -> Sem r (AGuard, TyCtx) - inferGuard (GBool (unembed -> t)) = do - at <- check t TyBool - return (AGBool (embed at), emptyCtx) - inferGuard (GPat (unembed -> t) p) = do - at <- infer t - (ctx, apt) <- checkPattern p (getType at) - return (AGPat (embed at) apt, ctx) - inferGuard (GLet (Binding mty x (unembed -> t))) = do - at <- case mty of - Just (unembed -> ty) -> checkPolyTy t ty - Nothing -> infer t - return - ( AGLet (ABinding mty (coerce x) (embed at)) - , singleCtx (localName x) (toPolyType (getType at)) - ) + where + typecheckBranch :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Branch -> + Sem r ABranch + typecheckBranch b = do + (gs, t) <- unbind b + (ags, ctx) <- inferTelescope inferGuard gs + extends ctx $ + bind ags <$> typecheck mode t + + -- Infer the type of a guard, returning the type-annotated guard + -- along with a context of types for any variables bound by the + -- guard. + inferGuard :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Guard -> + Sem r (AGuard, TyCtx) + inferGuard (GBool (unembed -> t)) = do + at <- check t TyBool + return (AGBool (embed at), emptyCtx) + inferGuard (GPat (unembed -> t) p) = do + at <- infer t + (ctx, apt) <- checkPattern p (getType at) + return (AGPat (embed at) apt, ctx) + inferGuard (GLet (Binding mty x (unembed -> t))) = do + at <- case mty of + Just (unembed -> ty) -> checkPolyTy t ty + Nothing -> infer t + return + ( AGLet (ABinding mty (coerce x) (embed at)) + , singleCtx (localName x) (toPolyType (getType at)) + ) -------------------------------------------------- -- Type ascription @@ -1297,7 +1298,6 @@ typecheck mode (TCase bs) = do -- Ascriptions are what let us flip from inference mode into -- checking mode. typecheck Infer (TAscr t ty) = checkPolyTyValid ty >> checkPolyTy t ty - -------------------------------------------------- -- Inference fallback @@ -1316,18 +1316,15 @@ typecheck (Check ty) t = do -- | Check that a pattern has the given type, and return a context of -- pattern variables bound in the pattern along with their types. -checkPattern - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Pattern -> Type -> Sem r (TyCtx, APattern) - +checkPattern :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Pattern -> + Type -> + Sem r (TyCtx, APattern) checkPattern (PNonlinear p x) _ = throw $ NonlinearPattern p x - checkPattern p (TyUser name args) = lookupTyDefn name args >>= checkPattern p - checkPattern (PVar x) ty = return (singleCtx (localName x) (toPolyType ty), APVar ty (coerce x)) - -checkPattern PWild ty = return (emptyCtx, APWild ty) - +checkPattern PWild ty = return (emptyCtx, APWild ty) checkPattern (PAscr p ty1) ty2 = do -- We have a pattern that promises to match ty1 and someone is asking -- us if it can also match ty2. So we just have to ensure what we're @@ -1335,47 +1332,43 @@ checkPattern (PAscr p ty1) ty2 = do constraint $ CSub ty2 ty1 -- ... and then make sure the pattern can actually match what it promised to. checkPattern p ty1 - checkPattern PUnit ty = do ensureEq ty TyUnit return (emptyCtx, APUnit) - checkPattern (PBool b) ty = do ensureEq ty TyBool return (emptyCtx, APBool b) - checkPattern (PChar c) ty = do ensureEq ty TyC return (emptyCtx, APChar c) - checkPattern (PString s) ty = do ensureEq ty TyString return (emptyCtx, APString s) - checkPattern (PTup tup) tupTy = do listCtxtAps <- checkTuplePat tup tupTy let (ctxs, aps) = unzip listCtxtAps return (mconcat ctxs, APTup (foldr1 (:*:) (map getType aps)) aps) - - where - checkTuplePat - :: Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => [Pattern] -> Type -> Sem r [(TyCtx, APattern)] - checkTuplePat [] _ = error "Impossible! checkTuplePat []" - checkTuplePat [p] ty = do -- (:[]) <$> check t ty - (ctx, apt) <- checkPattern p ty - return [(ctx, apt)] - checkTuplePat (p:ps) ty = do - (ty1, ty2) <- ensureConstr2 CProd ty (Right $ PTup (p:ps)) - (ctx, apt) <- checkPattern p ty1 - rest <- checkTuplePat ps ty2 - return ((ctx, apt) : rest) - -checkPattern p@(PInj L pat) ty = do + where + checkTuplePat :: + Members '[Reader TyCtx, Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + [Pattern] -> + Type -> + Sem r [(TyCtx, APattern)] + checkTuplePat [] _ = error "Impossible! checkTuplePat []" + checkTuplePat [p] ty = do + -- (:[]) <$> check t ty + (ctx, apt) <- checkPattern p ty + return [(ctx, apt)] + checkTuplePat (p : ps) ty = do + (ty1, ty2) <- ensureConstr2 CProd ty (Right $ PTup (p : ps)) + (ctx, apt) <- checkPattern p ty1 + rest <- checkTuplePat ps ty2 + return ((ctx, apt) : rest) +checkPattern p@(PInj L pat) ty = do (ty1, ty2) <- ensureConstr2 CSum ty (Right p) (ctx, apt) <- checkPattern pat ty1 return (ctx, APInj (ty1 :+: ty2) L apt) -checkPattern p@(PInj R pat) ty = do +checkPattern p@(PInj R pat) ty = do (ty1, ty2) <- ensureConstr2 CSum ty (Right p) (ctx, apt) <- checkPattern pat ty2 return (ctx, APInj (ty1 :+: ty2) R apt) @@ -1396,46 +1389,39 @@ checkPattern p@(PInj R pat) ty = do -- false -- checkPattern (PNat n) (TyFin m) = return (emptyCtx, APNat (TyFin m) n) -checkPattern (PNat n) ty = do +checkPattern (PNat n) ty = do constraint $ CSub TyN ty return (emptyCtx, APNat ty n) - checkPattern p@(PCons p1 p2) ty = do tyl <- ensureConstr1 CList ty (Right p) (ctx1, ap1) <- checkPattern p1 tyl (ctx2, ap2) <- checkPattern p2 (TyList tyl) return (ctx1 <> ctx2, APCons (TyList tyl) ap1 ap2) - checkPattern p@(PList ps) ty = do tyl <- ensureConstr1 CList ty (Right p) listCtxtAps <- mapM (`checkPattern` tyl) ps let (ctxs, aps) = unzip listCtxtAps return (mconcat ctxs, APList (TyList tyl) aps) - checkPattern (PAdd s p t) ty = do constraint $ CQual QNum ty (ctx, apt) <- checkPattern p ty at <- check t ty return (ctx, APAdd ty s apt at) - checkPattern (PMul s p t) ty = do constraint $ CQual QNum ty (ctx, apt) <- checkPattern p ty at <- check t ty return (ctx, APMul ty s apt at) - checkPattern (PSub p t) ty = do constraint $ CQual QNum ty (ctx, apt) <- checkPattern p ty at <- check t ty return (ctx, APSub ty apt at) - checkPattern (PNeg p) ty = do constraint $ CQual QSub ty tyInner <- cPos ty (ctx, apt) <- checkPattern p tyInner return (ctx, APNeg ty apt) - checkPattern (PFrac p q) ty = do constraint $ CQual QDiv ty tyP <- cInt ty @@ -1469,12 +1455,11 @@ cSize argTy resTy = do -- appropriate constraints. cPos :: Members '[Writer Constraint, Fresh] r => Type -> Sem r Type cPos ty = do - constraint $ CQual QNum ty -- The input type has to be numeric. + constraint $ CQual QNum ty -- The input type has to be numeric. case ty of -- If the input type is a concrete base type, we can just -- compute the correct output type. TyAtom (ABase b) -> return $ TyAtom (ABase (pos b)) - -- Otherwise, generate a fresh type variable for the output type -- along with some constraints. _ -> do @@ -1482,16 +1467,17 @@ cPos ty = do -- Valid types for absolute value are Z -> N, Q -> F, or T -> T -- (e.g. Z5 -> Z5). - constraint $ COr - [ cAnd [CSub ty TyZ, CSub TyN res] - , cAnd [CSub ty TyQ, CSub TyF res] - , CEq ty res - ] + constraint $ + COr + [ cAnd [CSub ty TyZ, CSub TyN res] + , cAnd [CSub ty TyQ, CSub TyF res] + , CEq ty res + ] return res - where - pos Z = N - pos Q = F - pos t = t + where + pos Z = N + pos Q = F + pos t = t -- | Given an input type @ty@, return a type which represents the -- output type of the floor or ceiling functions, and generate @@ -1503,7 +1489,6 @@ cInt ty = do -- If the input type is a concrete base type, we can just -- compute the correct output type. TyAtom (ABase b) -> return $ TyAtom (ABase (int b)) - -- Otherwise, generate a fresh type variable for the output type -- along with some constraints. _ -> do @@ -1511,17 +1496,17 @@ cInt ty = do -- Valid types for absolute value are F -> N, Q -> Z, or T -> T -- (e.g. Z5 -> Z5). - constraint $ COr - [ cAnd [CSub ty TyF, CSub TyN res] - , cAnd [CSub ty TyQ, CSub TyZ res] - , CEq ty res - ] + constraint $ + COr + [ cAnd [CSub ty TyF, CSub TyN res] + , cAnd [CSub ty TyQ, CSub TyZ res] + , CEq ty res + ] return res - - where - int F = N - int Q = Z - int t = t + where + int F = N + int Q = Z + int t = t -- | Given input types to the exponentiation operator, return a type -- which represents the output type, and generate appropriate @@ -1535,7 +1520,6 @@ cExp ty1 TyN = do -- a function to find a supertype of a given type that satisfies QDiv. cExp ty1 ty2 = do - -- Create a fresh type variable to represent the result type. The -- base type has to be a subtype. resTy <- freshTy @@ -1544,10 +1528,11 @@ cExp ty1 ty2 = do -- Either the exponent type is N, in which case the result type has -- to support multiplication, or else the exponent is Z, in which -- case the result type also has to support division. - constraint $ COr - [ cAnd [CQual QNum resTy, CEq ty2 TyN] - , cAnd [CQual QDiv resTy, CEq ty2 TyZ] - ] + constraint $ + COr + [ cAnd [CQual QNum resTy, CEq ty2 TyN] + , cAnd [CQual QDiv resTy, CEq ty2 TyZ] + ] return resTy ------------------------------------------------------------ @@ -1571,107 +1556,143 @@ getEltTy c ty = do -- type variable's outermost constructor matches the provided -- constructor, and a list of fresh type variables is returned whose -- count matches the arity of the provided constructor. -ensureConstr - :: forall r. Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Con -> Type -> Either Term Pattern -> Sem r [Type] +ensureConstr :: + forall r. + Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Con -> + Type -> + Either Term Pattern -> + Sem r [Type] ensureConstr c ty targ = matchConTy c ty - where - matchConTy :: Con -> Type -> Sem r [Type] - - -- expand type definitions lazily - matchConTy c1 (TyUser name args) = lookupTyDefn name args >>= matchConTy c1 - - matchConTy c1 (TyCon c2 tys) = do - matchCon c1 c2 - return tys - - matchConTy c1 tyv@(TyAtom (AVar (U _))) = do - tyvs <- mapM (const freshTy) (arity c1) - constraint $ CEq tyv (TyCon c1 tyvs) - return tyvs - - matchConTy _ _ = matchError - - -- | Check whether two constructors match, which could include - -- unifying container variables if we are matching two container - -- types; otherwise, simply ensure that the constructors are - -- equal. Throw a 'matchError' if they do not match. - matchCon :: Con -> Con -> Sem r () - matchCon c1 c2 | c1 == c2 = return () - matchCon (CContainer v@(AVar (U _))) (CContainer ctr2) = - constraint $ CEq (TyAtom v) (TyAtom ctr2) - matchCon (CContainer ctr1) (CContainer v@(AVar (U _))) = - constraint $ CEq (TyAtom ctr1) (TyAtom v) - matchCon _ _ = matchError - - matchError :: Sem r a - matchError = case targ of - Left term -> throw (NotCon c term ty) - Right pat -> throw (PatternType c pat ty) + where + matchConTy :: Con -> Type -> Sem r [Type] + + -- expand type definitions lazily + matchConTy c1 (TyUser name args) = lookupTyDefn name args >>= matchConTy c1 + matchConTy c1 (TyCon c2 tys) = do + matchCon c1 c2 + return tys + matchConTy c1 tyv@(TyAtom (AVar (U _))) = do + tyvs <- mapM (const freshTy) (arity c1) + constraint $ CEq tyv (TyCon c1 tyvs) + return tyvs + matchConTy _ _ = matchError + + -- \| Check whether two constructors match, which could include + -- unifying container variables if we are matching two container + -- types; otherwise, simply ensure that the constructors are + -- equal. Throw a 'matchError' if they do not match. + matchCon :: Con -> Con -> Sem r () + matchCon c1 c2 | c1 == c2 = return () + matchCon (CContainer v@(AVar (U _))) (CContainer ctr2) = + constraint $ CEq (TyAtom v) (TyAtom ctr2) + matchCon (CContainer ctr1) (CContainer v@(AVar (U _))) = + constraint $ CEq (TyAtom ctr1) (TyAtom v) + matchCon _ _ = matchError + + matchError :: Sem r a + matchError = case targ of + Left term -> throw (NotCon c term ty) + Right pat -> throw (PatternType c pat ty) -- | A variant of ensureConstr that expects to get exactly one -- argument type out, and throws an error if we get any other -- number. -ensureConstr1 - :: Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Con -> Type -> Either Term Pattern -> Sem r Type +ensureConstr1 :: + Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Con -> + Type -> + Either Term Pattern -> + Sem r Type ensureConstr1 c ty targ = do tys <- ensureConstr c ty targ case tys of [ty1] -> return ty1 - _ -> error $ - "Impossible! Wrong number of arg types in ensureConstr1 " ++ show c ++ " " - ++ show ty ++ ": " ++ show tys + _ -> + error $ + "Impossible! Wrong number of arg types in ensureConstr1 " + ++ show c + ++ " " + ++ show ty + ++ ": " + ++ show tys -- | A variant of ensureConstr that expects to get exactly two -- argument types out, and throws an error if we get any other -- number. -ensureConstr2 - :: Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Con -> Type -> Either Term Pattern -> Sem r (Type, Type) -ensureConstr2 c ty targ = do +ensureConstr2 :: + Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Con -> + Type -> + Either Term Pattern -> + Sem r (Type, Type) +ensureConstr2 c ty targ = do tys <- ensureConstr c ty targ case tys of [ty1, ty2] -> return (ty1, ty2) - _ -> error $ - "Impossible! Wrong number of arg types in ensureConstr2 " ++ show c ++ " " - ++ show ty ++ ": " ++ show tys + _ -> + error $ + "Impossible! Wrong number of arg types in ensureConstr2 " + ++ show c + ++ " " + ++ show ty + ++ ": " + ++ show tys -- | A variant of 'ensureConstr' that works on 'Mode's instead of -- 'Type's. Behaves similarly to 'ensureConstr' if the 'Mode' is -- 'Check'; otherwise it generates an appropriate number of copies -- of 'Infer'. -ensureConstrMode - :: Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Con -> Mode -> Either Term Pattern -> Sem r [Mode] -ensureConstrMode c Infer _ = return $ map (const Infer) (arity c) +ensureConstrMode :: + Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Con -> + Mode -> + Either Term Pattern -> + Sem r [Mode] +ensureConstrMode c Infer _ = return $ map (const Infer) (arity c) ensureConstrMode c (Check ty) tp = map Check <$> ensureConstr c ty tp -- | A variant of 'ensureConstrMode' that expects to get a single -- 'Mode' and throws an error if it encounters any other number. -ensureConstrMode1 - :: Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Con -> Mode -> Either Term Pattern -> Sem r Mode +ensureConstrMode1 :: + Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Con -> + Mode -> + Either Term Pattern -> + Sem r Mode ensureConstrMode1 c m targ = do ms <- ensureConstrMode c m targ case ms of [m1] -> return m1 - _ -> error $ - "Impossible! Wrong number of arg types in ensureConstrMode1 " ++ show c ++ " " - ++ show m ++ ": " ++ show ms + _ -> + error $ + "Impossible! Wrong number of arg types in ensureConstrMode1 " + ++ show c + ++ " " + ++ show m + ++ ": " + ++ show ms -- | A variant of 'ensureConstrMode' that expects to get two 'Mode's -- and throws an error if it encounters any other number. -ensureConstrMode2 - :: Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r - => Con -> Mode -> Either Term Pattern -> Sem r (Mode, Mode) +ensureConstrMode2 :: + Members '[Reader TyDefCtx, Writer Constraint, Error TCError, Fresh] r => + Con -> + Mode -> + Either Term Pattern -> + Sem r (Mode, Mode) ensureConstrMode2 c m targ = do ms <- ensureConstrMode c m targ case ms of [m1, m2] -> return (m1, m2) - _ -> error $ - "Impossible! Wrong number of arg types in ensureConstrMode2 " ++ show c ++ " " - ++ show m ++ ": " ++ show ms + _ -> + error $ + "Impossible! Wrong number of arg types in ensureConstrMode2 " + ++ show c + ++ " " + ++ show m + ++ ": " + ++ show ms -- | Ensure that two types are equal: -- 1. Do nothing if they are literally equal @@ -1679,4 +1700,4 @@ ensureConstrMode2 c m targ = do ensureEq :: Member (Writer Constraint) r => Type -> Type -> Sem r () ensureEq ty1 ty2 | ty1 == ty2 = return () - | otherwise = constraint $ CEq ty1 ty2 + | otherwise = constraint $ CEq ty1 ty2 diff --git a/src/Disco/Typecheck/Constraints.hs b/src/Disco/Typecheck/Constraints.hs index 7495731e..ea669e98 100644 --- a/src/Disco/Typecheck/Constraints.hs +++ b/src/Disco/Typecheck/Constraints.hs @@ -1,7 +1,10 @@ -{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE OverloadedStrings #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Typecheck.Constraints -- Copyright : disco team and contributors @@ -10,73 +13,69 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Constraints generated by type inference & checking. --- ------------------------------------------------------------------------------ - -module Disco.Typecheck.Constraints - ( Constraint(..) - , cAnd - ) - where +module Disco.Typecheck.Constraints ( + Constraint (..), + cAnd, +) +where -import qualified Data.List.NonEmpty as NE -import Data.Semigroup -import GHC.Generics (Generic) -import Unbound.Generics.LocallyNameless hiding (lunbind) +import qualified Data.List.NonEmpty as NE +import Data.Semigroup +import GHC.Generics (Generic) +import Unbound.Generics.LocallyNameless hiding (lunbind) -import Disco.Effects.LFresh +import Disco.Effects.LFresh -import Disco.Pretty hiding ((<>)) -import Disco.Syntax.Operators (BFixity (In, InL, InR)) -import Disco.Types -import Disco.Types.Rules +import Disco.Pretty hiding ((<>)) +import Disco.Syntax.Operators (BFixity (In, InL, InR)) +import Disco.Types +import Disco.Types.Rules -- | Constraints are generated as a result of type inference and checking. -- These constraints are accumulated during the inference and checking phase -- and are subsequently solved by the constraint solver. data Constraint where - CSub :: Type -> Type -> Constraint - CEq :: Type -> Type -> Constraint - CQual :: Qualifier -> Type -> Constraint - CAnd :: [Constraint] -> Constraint - CTrue :: Constraint - COr :: [Constraint] -> Constraint - CAll :: Bind [Name Type] Constraint -> Constraint - + CSub :: Type -> Type -> Constraint + CEq :: Type -> Type -> Constraint + CQual :: Qualifier -> Type -> Constraint + CAnd :: [Constraint] -> Constraint + CTrue :: Constraint + COr :: [Constraint] -> Constraint + CAll :: Bind [Name Type] Constraint -> Constraint deriving (Show, Generic, Alpha, Subst Type) instance Pretty Constraint where pretty = \case - CSub ty1 ty2 -> withPA (PA 4 In) $ lt (pretty ty1) <+> "<:" <+> rt (pretty ty2) - CEq ty1 ty2 -> withPA (PA 4 In) $ lt (pretty ty1) <+> "=" <+> rt (pretty ty2) - CQual q ty -> withPA (PA 10 InL) $ lt (pretty q) <+> rt (pretty ty) - CAnd [c] -> pretty c - -- Use rt for both, since we don't need to print parens for /\ at all - CAnd (c:cs) -> withPA (PA 3 InR) $ rt (pretty c) <+> "/\\" <+> rt (pretty (CAnd cs)) - CAnd [] -> "True" - CTrue -> "True" - COr [c] -> pretty c - COr (c:cs) -> withPA (PA 2 InR) $ lt (pretty c) <+> "\\/" <+> rt (pretty (COr cs)) - COr [] -> "False" - CAll b -> lunbind b $ \(xs, c) -> + CSub ty1 ty2 -> withPA (PA 4 In) $ lt (pretty ty1) <+> "<:" <+> rt (pretty ty2) + CEq ty1 ty2 -> withPA (PA 4 In) $ lt (pretty ty1) <+> "=" <+> rt (pretty ty2) + CQual q ty -> withPA (PA 10 InL) $ lt (pretty q) <+> rt (pretty ty) + CAnd [c] -> pretty c + -- Use rt for both, since we don't need to print parens for /\ at all + CAnd (c : cs) -> withPA (PA 3 InR) $ rt (pretty c) <+> "/\\" <+> rt (pretty (CAnd cs)) + CAnd [] -> "True" + CTrue -> "True" + COr [c] -> pretty c + COr (c : cs) -> withPA (PA 2 InR) $ lt (pretty c) <+> "\\/" <+> rt (pretty (COr cs)) + COr [] -> "False" + CAll b -> lunbind b $ \(xs, c) -> "∀" <+> intercalate "," (map pretty xs) <> "." <+> pretty c -- A helper function for creating a single constraint from a list of constraints. cAnd :: [Constraint] -> Constraint cAnd cs = case filter nontrivial cs of - [] -> CTrue + [] -> CTrue [c] -> c cs' -> CAnd cs' - where - nontrivial CTrue = False - nontrivial _ = True + where + nontrivial CTrue = False + nontrivial _ = True instance Semigroup Constraint where - c1 <> c2 = cAnd [c1,c2] - sconcat = cAnd . NE.toList - stimes = stimesIdempotent + c1 <> c2 = cAnd [c1, c2] + sconcat = cAnd . NE.toList + stimes = stimesIdempotent instance Monoid Constraint where - mempty = CTrue + mempty = CTrue mappend = (<>) mconcat = cAnd diff --git a/src/Disco/Typecheck/Erase.hs b/src/Disco/Typecheck/Erase.hs index 3794889c..93815ae4 100644 --- a/src/Disco/Typecheck/Erase.hs +++ b/src/Disco/Typecheck/Erase.hs @@ -1,4 +1,7 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Typecheck.Erase -- Copyright : (c) 2016 disco team (see LICENSE) @@ -7,82 +10,83 @@ -- -- Typecheck the Disco surface language and transform it into a -- type-annotated AST. --- ------------------------------------------------------------------------------ - module Disco.Typecheck.Erase where -import Unbound.Generics.LocallyNameless -import Unbound.Generics.LocallyNameless.Unsafe +import Unbound.Generics.LocallyNameless +import Unbound.Generics.LocallyNameless.Unsafe -import Control.Arrow ((***)) -import Data.Coerce +import Control.Arrow ((***)) +import Data.Coerce -import Disco.AST.Desugared -import Disco.AST.Surface -import Disco.AST.Typed -import Disco.Names (QName (..)) +import Disco.AST.Desugared +import Disco.AST.Surface +import Disco.AST.Typed +import Disco.Names (QName (..)) -- | Erase all the type annotations from a term. erase :: ATerm -> Term erase (ATVar _ (QName _ x)) = TVar (coerce x) -erase (ATPrim _ x) = TPrim x -erase (ATLet _ bs) = TLet $ bind (mapTelescope eraseBinding tel) (erase at) - where (tel,at) = unsafeUnbind bs -erase ATUnit = TUnit -erase (ATBool _ b) = TBool b -erase (ATChar c) = TChar c -erase (ATString s) = TString s -erase (ATNat _ i) = TNat i -erase (ATRat r) = TRat r -erase (ATAbs q _ b) = TAbs q $ bind (map erasePattern x) (erase at) - where (x,at) = unsafeUnbind b -erase (ATApp _ t1 t2) = TApp (erase t1) (erase t2) -erase (ATTup _ ats) = TTup (map erase ats) -erase (ATCase _ brs) = TCase (map eraseBranch brs) -erase (ATChain _ at lnks) = TChain (erase at) (map eraseLink lnks) -erase (ATTyOp _ op ty) = TTyOp op ty -erase (ATContainer _ c ats aell) = TContainer c (map (erase *** fmap erase) ats) ((fmap . fmap) erase aell) -erase (ATContainerComp _ c b) = TContainerComp c $ bind (mapTelescope eraseQual tel) (erase at) - where (tel,at) = unsafeUnbind b -erase (ATTest _ x) = erase x +erase (ATPrim _ x) = TPrim x +erase (ATLet _ bs) = TLet $ bind (mapTelescope eraseBinding tel) (erase at) + where + (tel, at) = unsafeUnbind bs +erase ATUnit = TUnit +erase (ATBool _ b) = TBool b +erase (ATChar c) = TChar c +erase (ATString s) = TString s +erase (ATNat _ i) = TNat i +erase (ATRat r) = TRat r +erase (ATAbs q _ b) = TAbs q $ bind (map erasePattern x) (erase at) + where + (x, at) = unsafeUnbind b +erase (ATApp _ t1 t2) = TApp (erase t1) (erase t2) +erase (ATTup _ ats) = TTup (map erase ats) +erase (ATCase _ brs) = TCase (map eraseBranch brs) +erase (ATChain _ at lnks) = TChain (erase at) (map eraseLink lnks) +erase (ATTyOp _ op ty) = TTyOp op ty +erase (ATContainer _ c ats aell) = TContainer c (map (erase *** fmap erase) ats) ((fmap . fmap) erase aell) +erase (ATContainerComp _ c b) = TContainerComp c $ bind (mapTelescope eraseQual tel) (erase at) + where + (tel, at) = unsafeUnbind b +erase (ATTest _ x) = erase x eraseBinding :: ABinding -> Binding eraseBinding (ABinding mty x (unembed -> at)) = Binding mty (coerce x) (embed (erase at)) erasePattern :: APattern -> Pattern -erasePattern (APVar _ n) = PVar (coerce n) -erasePattern (APWild _) = PWild -erasePattern APUnit = PUnit -erasePattern (APBool b) = PBool b -erasePattern (APChar c) = PChar c -erasePattern (APString s) = PString s -erasePattern (APTup _ alp) = PTup $ map erasePattern alp -erasePattern (APInj _ s apt) = PInj s (erasePattern apt) -erasePattern (APNat _ n) = PNat n +erasePattern (APVar _ n) = PVar (coerce n) +erasePattern (APWild _) = PWild +erasePattern APUnit = PUnit +erasePattern (APBool b) = PBool b +erasePattern (APChar c) = PChar c +erasePattern (APString s) = PString s +erasePattern (APTup _ alp) = PTup $ map erasePattern alp +erasePattern (APInj _ s apt) = PInj s (erasePattern apt) +erasePattern (APNat _ n) = PNat n erasePattern (APCons _ ap1 ap2) = PCons (erasePattern ap1) (erasePattern ap2) -erasePattern (APList _ alp) = PList $ map erasePattern alp -erasePattern (APAdd _ s p t) = PAdd s (erasePattern p) (erase t) -erasePattern (APMul _ s p t) = PMul s (erasePattern p) (erase t) -erasePattern (APSub _ p t) = PSub (erasePattern p) (erase t) -erasePattern (APNeg _ p) = PNeg (erasePattern p) -erasePattern (APFrac _ p1 p2) = PFrac (erasePattern p1) (erasePattern p2) +erasePattern (APList _ alp) = PList $ map erasePattern alp +erasePattern (APAdd _ s p t) = PAdd s (erasePattern p) (erase t) +erasePattern (APMul _ s p t) = PMul s (erasePattern p) (erase t) +erasePattern (APSub _ p t) = PSub (erasePattern p) (erase t) +erasePattern (APNeg _ p) = PNeg (erasePattern p) +erasePattern (APFrac _ p1 p2) = PFrac (erasePattern p1) (erasePattern p2) eraseBranch :: ABranch -> Branch eraseBranch b = bind (mapTelescope eraseGuard tel) (erase at) - where (tel,at) = unsafeUnbind b + where + (tel, at) = unsafeUnbind b eraseGuard :: AGuard -> Guard -eraseGuard (AGBool (unembed -> at)) = GBool (embed (erase at)) +eraseGuard (AGBool (unembed -> at)) = GBool (embed (erase at)) eraseGuard (AGPat (unembed -> at) p) = GPat (embed (erase at)) (erasePattern p) -eraseGuard (AGLet b) = GLet (eraseBinding b) +eraseGuard (AGLet b) = GLet (eraseBinding b) eraseLink :: ALink -> Link eraseLink (ATLink bop at) = TLink bop (erase at) eraseQual :: AQual -> Qual eraseQual (AQBind x (unembed -> at)) = QBind (coerce x) (embed (erase at)) -eraseQual (AQGuard (unembed -> at)) = QGuard (embed (erase at)) +eraseQual (AQGuard (unembed -> at)) = QGuard (embed (erase at)) eraseProperty :: AProperty -> Property eraseProperty = erase @@ -92,32 +96,33 @@ eraseProperty = erase eraseDTerm :: DTerm -> Term eraseDTerm (DTVar _ (QName _ x)) = TVar (coerce x) -eraseDTerm (DTPrim _ x) = TPrim x -eraseDTerm DTUnit = TUnit -eraseDTerm (DTBool _ b) = TBool b -eraseDTerm (DTChar c) = TChar c -eraseDTerm (DTNat _ n) = TNat n -eraseDTerm (DTRat r) = TRat r -eraseDTerm (DTAbs q _ b) = TAbs q $ bind [PVar . coerce $ x] (eraseDTerm dt) - where (x, dt) = unsafeUnbind b -eraseDTerm (DTApp _ d1 d2) = TApp (eraseDTerm d1) (eraseDTerm d2) +eraseDTerm (DTPrim _ x) = TPrim x +eraseDTerm DTUnit = TUnit +eraseDTerm (DTBool _ b) = TBool b +eraseDTerm (DTChar c) = TChar c +eraseDTerm (DTNat _ n) = TNat n +eraseDTerm (DTRat r) = TRat r +eraseDTerm (DTAbs q _ b) = TAbs q $ bind [PVar . coerce $ x] (eraseDTerm dt) + where + (x, dt) = unsafeUnbind b +eraseDTerm (DTApp _ d1 d2) = TApp (eraseDTerm d1) (eraseDTerm d2) eraseDTerm (DTPair _ d1 d2) = TTup [eraseDTerm d1, eraseDTerm d2] -eraseDTerm (DTCase _ bs) = TCase (map eraseDBranch bs) +eraseDTerm (DTCase _ bs) = TCase (map eraseDBranch bs) eraseDTerm (DTTyOp _ op ty) = TTyOp op ty -eraseDTerm (DTNil _) = TList [] Nothing -eraseDTerm (DTTest _ x) = eraseDTerm x +eraseDTerm (DTNil _) = TList [] Nothing +eraseDTerm (DTTest _ x) = eraseDTerm x eraseDBranch :: DBranch -> Branch eraseDBranch b = bind (mapTelescope eraseDGuard tel) (eraseDTerm d) - where - (tel, d) = unsafeUnbind b + where + (tel, d) = unsafeUnbind b eraseDGuard :: DGuard -> Guard eraseDGuard (DGPat (unembed -> d) p) = GPat (embed (eraseDTerm d)) (eraseDPattern p) eraseDPattern :: DPattern -> Pattern -eraseDPattern (DPVar _ x) = PVar (coerce x) -eraseDPattern (DPWild _) = PWild -eraseDPattern DPUnit = PUnit -eraseDPattern (DPPair _ x1 x2) = PTup (map (PVar . coerce) [x1,x2]) -eraseDPattern (DPInj _ s x) = PInj s (PVar (coerce x)) +eraseDPattern (DPVar _ x) = PVar (coerce x) +eraseDPattern (DPWild _) = PWild +eraseDPattern DPUnit = PUnit +eraseDPattern (DPPair _ x1 x2) = PTup (map (PVar . coerce) [x1, x2]) +eraseDPattern (DPInj _ s x) = PInj s (PVar (coerce x)) diff --git a/src/Disco/Typecheck/Graph.hs b/src/Disco/Typecheck/Graph.hs index 20960e86..7b385f22 100644 --- a/src/Disco/Typecheck/Graph.hs +++ b/src/Disco/Typecheck/Graph.hs @@ -1,6 +1,9 @@ {-# LANGUAGE OverloadedStrings #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Typecheck.Graph -- Copyright : disco team and contributors @@ -11,65 +14,67 @@ -- A thin layer on top of graphs from the @fgl@ package, which -- allows dealing with vertices by label instead of by integer -- @Node@ values. ------------------------------------------------------------------------------ - module Disco.Typecheck.Graph where -import Prelude hiding (map, (<>)) -import qualified Prelude as P +import Prelude hiding (map, (<>)) +import qualified Prelude as P -import Control.Arrow ((&&&)) -import Data.Map (Map) -import qualified Data.Map as M -import Data.Maybe (fromJust, isJust, mapMaybe) -import Data.Set (Set) -import qualified Data.Set as S -import Data.Tuple (swap) +import Control.Arrow ((&&&)) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Maybe (fromJust, isJust, mapMaybe) +import Data.Set (Set) +import qualified Data.Set as S +import Data.Tuple (swap) -import qualified Data.Graph.Inductive.Graph as G -import Data.Graph.Inductive.PatriciaTree (Gr) -import qualified Data.Graph.Inductive.Query.DFS as G (components, - condensation, topsort') +import qualified Data.Graph.Inductive.Graph as G +import Data.Graph.Inductive.PatriciaTree (Gr) +import qualified Data.Graph.Inductive.Query.DFS as G ( + components, + condensation, + topsort', + ) -import Disco.Pretty -import Disco.Util ((!)) +import Disco.Pretty +import Disco.Util ((!)) -- | Directed graphs, with vertices labelled by @a@ and unlabelled -- edges. data Graph a = G (Gr a ()) (Map a G.Node) (Map G.Node a) - deriving Show + deriving (Show) instance Pretty a => Pretty (Graph a) where pretty (G g _ _) = parens (prettyVertices <> ", " <> prettyEdges) + where -- (V = {(0, x), (1, N)}, E = {0 -> 1, 2 -> 3}) - where - vs = G.labNodes g - es = G.labEdges g - prettyVertex (n,a) = parens (text (show n) <> ", " <> pretty a) - prettyVertices = "V = " <> braces (intercalate "," (P.map prettyVertex vs)) - prettyEdge (v1,v2,_) = text (show v1) <+> "->" <+> text (show v2) - prettyEdges = "E = " <> braces (intercalate "," (P.map prettyEdge es)) + vs = G.labNodes g + es = G.labEdges g + + prettyVertex (n, a) = parens (text (show n) <> ", " <> pretty a) + prettyVertices = "V = " <> braces (intercalate "," (P.map prettyVertex vs)) + prettyEdge (v1, v2, _) = text (show v1) <+> "->" <+> text (show v2) + prettyEdges = "E = " <> braces (intercalate "," (P.map prettyEdge es)) -- | Create a graph with the given set of vertices and directed edges. -- If any edges refer to vertices that are not in the given vertex -- set, they will simply be dropped. -mkGraph :: (Show a, Ord a) => Set a -> Set (a,a) -> Graph a +mkGraph :: (Show a, Ord a) => Set a -> Set (a, a) -> Graph a mkGraph vs es = G (G.mkGraph vs' es') a2n n2a - where - vs' = zip [0..] (S.toList vs) - n2a = M.fromList vs' - a2n = M.fromList . P.map swap $ vs' - es' = mapMaybe mkEdge (S.toList es) - mkEdge (a1,a2) = (,,) <$> M.lookup a1 a2n <*> M.lookup a2 a2n <*> pure () + where + vs' = zip [0 ..] (S.toList vs) + n2a = M.fromList vs' + a2n = M.fromList . P.map swap $ vs' + es' = mapMaybe mkEdge (S.toList es) + mkEdge (a1, a2) = (,,) <$> M.lookup a1 a2n <*> M.lookup a2 a2n <*> pure () -- | Return the set of vertices (nodes) of a graph. nodes :: Graph a -> Set a nodes (G _ m _) = M.keysSet m -- | Return the set of directed edges of a graph. -edges :: Ord a => Graph a -> Set (a,a) -edges (G g _ m) = S.fromList $ P.map (\(n1,n2,()) -> (m ! n1, m ! n2)) (G.labEdges g) +edges :: Ord a => Graph a -> Set (a, a) +edges (G g _ m) = S.fromList $ P.map (\(n1, n2, ()) -> (m ! n1, m ! n2)) (G.labEdges g) -- | Map a function over all the vertices of a graph. @Graph@ is not -- a @Functor@ instance because of the @Ord@ constraint on @b@. @@ -79,8 +84,8 @@ map f (G g m1 m2) = G (G.nmap f g) (M.mapKeys f m1) (M.map f m2) -- | Delete a vertex. delete :: (Show a, Ord a) => a -> Graph a -> Graph a delete a (G g a2n n2a) = G (G.delNode n g) (M.delete a a2n) (M.delete n n2a) - where - n = a2n ! a + where + n = a2n ! a -- | The @condensation@ of a graph is the graph of its strongly -- connected components, /i.e./ each strongly connected component is @@ -90,11 +95,11 @@ delete a (G g a2n n2a) = G (G.delNode n g) (M.delete a a2n) (M.delete n n2a) -- component A to any vertex in component B in the original graph. condensation :: Ord a => Graph a -> Graph (Set a) condensation (G g _ n2a) = G g' as2n n2as - where - g' = G.nmap (S.fromList . P.map (n2a !)) (G.condensation g) - vs' = G.labNodes g' - n2as = M.fromList vs' - as2n = M.fromList . P.map swap $ vs' + where + g' = G.nmap (S.fromList . P.map (n2a !)) (G.condensation g) + vs' = G.labNodes g' + n2as = M.fromList vs' + as2n = M.fromList . P.map swap $ vs' -- | Get a list of the weakly connected components of a graph, -- providing the set of vertices in each. Equivalently, return the @@ -116,7 +121,7 @@ topsort (G g _a2n _n2a) = G.topsort' g sequenceGraph :: Ord a => Graph (Maybe a) -> Maybe (Graph a) sequenceGraph g = case all isJust (nodes g) of False -> Nothing - True -> Just $ map fromJust g + True -> Just $ map fromJust g -- | Get a list of all the /successors/ of a given node in the graph, -- /i.e./ all the nodes reachable from the given node by a directed @@ -140,16 +145,16 @@ pre (G g a2n n2a) = P.map (n2a !) . G.pre g . (a2n !) -- but much more efficient. cessors :: (Show a, Ord a) => Graph a -> (Map a (Set a), Map a (Set a)) cessors g@(G gg _ _) = (succs, preds) - where - as = G.topsort' gg - succs = foldr collectSuccs M.empty as -- build successors map - collectSuccs a m = M.insert a succsSet m - where - ss = suc g a - succsSet = S.fromList ss `S.union` S.unions (P.map (m !) ss) - - preds = foldr collectPreds M.empty (reverse as) -- build predecessors map - collectPreds a m = M.insert a predsSet m - where - ss = pre g a - predsSet = S.fromList ss `S.union` S.unions (P.map (m !) ss) + where + as = G.topsort' gg + succs = foldr collectSuccs M.empty as -- build successors map + collectSuccs a m = M.insert a succsSet m + where + ss = suc g a + succsSet = S.fromList ss `S.union` S.unions (P.map (m !) ss) + + preds = foldr collectPreds M.empty (reverse as) -- build predecessors map + collectPreds a m = M.insert a predsSet m + where + ss = pre g a + predsSet = S.fromList ss `S.union` S.unions (P.map (m !) ss) diff --git a/src/Disco/Typecheck/Solve.hs b/src/Disco/Typecheck/Solve.hs index 35e3f4a2..5de75a3e 100644 --- a/src/Disco/Typecheck/Solve.hs +++ b/src/Disco/Typecheck/Solve.hs @@ -1,8 +1,11 @@ -{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TemplateHaskell #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Typecheck.Solve -- Copyright : disco team and contributors @@ -12,51 +15,62 @@ -- -- Constraint solver for the constraints generated during type -- checking/inference. ------------------------------------------------------------------------------ - module Disco.Typecheck.Solve where -import Unbound.Generics.LocallyNameless (Alpha, Name, Subst, fv, - name2Integer, string2Name, - substs) - -import Data.Coerce -import GHC.Generics (Generic) - -import Control.Arrow ((&&&), (***)) -import Control.Lens hiding (use, (%=), (.=)) -import Control.Monad (unless, zipWithM) -import Data.Bifunctor (first, second) -import Data.Either (partitionEithers) -import Data.List (find, foldl', intersect, - partition) -import Data.Map (Map, (!)) -import qualified Data.Map as M -import Data.Maybe (fromJust, fromMaybe, - mapMaybe) -import Data.Monoid (First (..)) -import Data.Set (Set) -import qualified Data.Set as S -import Data.Tuple - -import Disco.Effects.Fresh -import Disco.Effects.State -import Polysemy -import Polysemy.Error -import Polysemy.Input -import Polysemy.Output - -import Disco.Messages -import Disco.Pretty hiding ((<>)) -import Disco.Subst -import qualified Disco.Subst as Subst -import Disco.Typecheck.Constraints -import Disco.Typecheck.Graph (Graph) -import qualified Disco.Typecheck.Graph as G -import Disco.Typecheck.Unify -import Disco.Types -import Disco.Types.Qualifiers -import Disco.Types.Rules +import Unbound.Generics.LocallyNameless ( + Alpha, + Name, + Subst, + fv, + name2Integer, + string2Name, + substs, + ) + +import Data.Coerce +import GHC.Generics (Generic) + +import Control.Arrow ((&&&), (***)) +import Control.Lens hiding (use, (%=), (.=)) +import Control.Monad (unless, zipWithM) +import Data.Bifunctor (first, second) +import Data.Either (partitionEithers) +import Data.List ( + find, + foldl', + intersect, + partition, + ) +import Data.Map (Map, (!)) +import qualified Data.Map as M +import Data.Maybe ( + fromJust, + fromMaybe, + mapMaybe, + ) +import Data.Monoid (First (..)) +import Data.Set (Set) +import qualified Data.Set as S +import Data.Tuple + +import Disco.Effects.Fresh +import Disco.Effects.State +import Polysemy +import Polysemy.Error +import Polysemy.Input +import Polysemy.Output + +import Disco.Messages +import Disco.Pretty hiding ((<>)) +import Disco.Subst +import qualified Disco.Subst as Subst +import Disco.Typecheck.Constraints +import Disco.Typecheck.Graph (Graph) +import qualified Disco.Typecheck.Graph as G +import Disco.Typecheck.Unify +import Disco.Types +import Disco.Types.Qualifiers +import Disco.Types.Rules -------------------------------------------------- -- Solver errors @@ -65,11 +79,11 @@ import Disco.Types.Rules -- process. data SolveError where NoWeakUnifier :: SolveError - NoUnify :: SolveError - UnqualBase :: Qualifier -> BaseTy -> SolveError - Unqual :: Qualifier -> Type -> SolveError - QualSkolem :: Qualifier -> Name Type -> SolveError - deriving Show + NoUnify :: SolveError + UnqualBase :: Qualifier -> BaseTy -> SolveError + Unqual :: Qualifier -> Type -> SolveError + QualSkolem :: Qualifier -> Name Type -> SolveError + deriving (Show) instance Semigroup SolveError where e <> _ = e @@ -87,16 +101,16 @@ filterErrors :: Member (Error e) r => [Sem r a] -> Sem r [a] filterErrors ms = do es <- mapM try ms case partitionEithers es of - (e:_, []) -> throw e - (_, as) -> return as + (e : _, []) -> throw e + (_, as) -> return as -- | A variant of 'asum' which picks the first action that succeeds, -- or re-throws the error of the last one if none of them -- do. Precondition: the list must not be empty. asum' :: Member (Error e) r => [Sem r a] -> Sem r a -asum' [] = error "Impossible: asum' []" -asum' [m] = m -asum' (m:ms) = m `catch` (\_ -> asum' ms) +asum' [] = error "Impossible: asum' []" +asum' [m] = m +asum' (m : ms) = m `catch` (\_ -> asum' ms) -------------------------------------------------- -- Simple constraints @@ -122,8 +136,10 @@ instance Pretty SimpleConstraint where -- | Information about a particular type variable. More information -- may be added in the future (e.g. polarity). data TyVarInfo = TVI - { _tyVarIlk :: First Ilk -- ^ The ilk (unification or skolem) of the variable, if known - , _tyVarSort :: Sort -- ^ The sort (set of qualifiers) of the type variable. + { _tyVarIlk :: First Ilk + -- ^ The ilk (unification or skolem) of the variable, if known + , _tyVarSort :: Sort + -- ^ The sort (set of qualifiers) of the type variable. } deriving (Show) @@ -144,7 +160,7 @@ instance Semigroup TyVarInfo where -- | A 'TyVarInfoMap' records what we know about each type variable; -- it is a mapping from type variable names to 'TyVarInfo' records. -newtype TyVarInfoMap = VM { unVM :: Map (Name Type) TyVarInfo } +newtype TyVarInfoMap = VM {unVM :: Map (Name Type) TyVarInfo} deriving (Show) instance Pretty TyVarInfoMap where @@ -154,7 +170,8 @@ instance Pretty TyVarInfoMap where -- underlying 'Map'. onVM :: (Map (Name Type) TyVarInfo -> Map (Name Type) TyVarInfo) -> - TyVarInfoMap -> TyVarInfoMap + TyVarInfoMap -> + TyVarInfoMap onVM f (VM m) = VM (f m) -- | Look up a given variable name in a 'TyVarInfoMap'. @@ -178,7 +195,7 @@ instance Semigroup TyVarInfoMap where VM sm1 <> VM sm2 = VM (M.unionWith (<>) sm1 sm2) instance Monoid TyVarInfoMap where - mempty = VM M.empty + mempty = VM M.empty mappend = (<>) -- | Get the sort of a particular variable recorded in a @@ -209,18 +226,19 @@ extendSort x s = onVM (at x . _Just . tyVarSort %~ (`S.union` s)) -- expansion of recursive types can stop when encountering a -- previously seen constraint. data SimplifyState = SS - { _ssVarMap :: TyVarInfoMap + { _ssVarMap :: TyVarInfoMap , _ssConstraints :: [SimpleConstraint] - , _ssSubst :: S - , _ssSeen :: Set SimpleConstraint + , _ssSubst :: S + , _ssSeen :: Set SimpleConstraint } makeLenses ''SimplifyState lkup :: (Ord k, Show k, Show (Map k a)) => String -> Map k a -> k -> a lkup messg m k = fromMaybe (error errMsg) (M.lookup k m) - where - errMsg = unlines + where + errMsg = + unlines [ "Key lookup error:" , " Key = " ++ show k , " Map = " ++ show m @@ -230,11 +248,11 @@ lkup messg m k = fromMaybe (error errMsg) (M.lookup k m) -------------------------------------------------- -- Top-level solver algorithm -solveConstraint - :: Members '[Fresh, Error SolveError, Output Message, Input TyDefCtx] r - => Constraint -> Sem r S +solveConstraint :: + Members '[Fresh, Error SolveError, Output Message, Input TyDefCtx] r => + Constraint -> + Sem r S solveConstraint c = do - -- Step 1. Open foralls (instantiating with skolem variables) and -- collect wanted qualifiers; also expand disjunctions. Result in a -- list of possible constraint sets; each one consists of equational @@ -252,11 +270,12 @@ solveConstraint c = do -- a solution. asum' (map (uncurry solveConstraintChoice) qcList) -solveConstraintChoice - :: Members '[Fresh, Error SolveError, Output Message, Input TyDefCtx] r - => TyVarInfoMap -> [SimpleConstraint] -> Sem r S +solveConstraintChoice :: + Members '[Fresh, Error SolveError, Output Message, Input TyDefCtx] r => + TyVarInfoMap -> + [SimpleConstraint] -> + Sem r S solveConstraintChoice quals cs = do - debugPretty quals debug $ vcat (map pretty' cs) @@ -265,8 +284,8 @@ solveConstraintChoice quals cs = do -- Step 2. Check for weak unification to ensure termination. (a la -- Traytel et al). - let toEqn (t1 :<: t2) = (t1,t2) - toEqn (t1 :=: t2) = (t1,t2) + let toEqn (t1 :<: t2) = (t1, t2) + toEqn (t1 :=: t2) = (t1, t2) _ <- note NoWeakUnifier $ weakUnify tyDefns (map toEqn cs) -- Step 3. Simplify constraints, resulting in a set of atomic @@ -280,7 +299,7 @@ solveConstraintChoice quals cs = do debug "Done running simplifier. Results:" debugPretty vm - debug $ vcat $ map (pretty' . (\(x,y) -> TyAtom x :<: TyAtom y)) atoms + debug $ vcat $ map (pretty' . (\(x, y) -> TyAtom x :<: TyAtom y)) atoms debugPretty theta_simp -- Step 4. Turn the atomic constraints into a directed constraint @@ -294,7 +313,7 @@ solveConstraintChoice quals cs = do -- extract them and include them in the constraint graph as isolated -- vertices let mkAVar (v, First (Just Skolem)) = AVar (S v) - mkAVar (v, _ ) = AVar (U v) + mkAVar (v, _) = AVar (U v) vars = S.fromList . map (mkAVar . second (view tyVarIlk)) . M.assocs . unVM $ vm g = mkConstraintGraph vars atoms @@ -315,7 +334,6 @@ solveConstraintChoice quals cs = do -- We don't need to ensure that theta_skolem respects sorts since -- checkSkolems will only unify skolem vars with unsorted variables. - -- Step 6. Eliminate cycles from the graph, turning each strongly -- connected component into a single node, unifying all the atoms in -- each component. @@ -329,11 +347,11 @@ solveConstraintChoice quals cs = do debugPretty theta_cyc -- Check that the resulting substitution respects sorts... - let sortOK (x, TyAtom (ABase ty)) = hasSort ty (getSort vm x) + let sortOK (x, TyAtom (ABase ty)) = hasSort ty (getSort vm x) sortOK (_, TyAtom (AVar (U _))) = True - sortOK p = error $ "Impossible! sortOK " ++ show p - unless (all sortOK (Subst.toList theta_cyc)) - $ throw NoUnify + sortOK p = error $ "Impossible! sortOK " ++ show p + unless (all sortOK (Subst.toList theta_cyc)) $ + throw NoUnify -- ... and update the sort map if we unified any type variables. -- Just make sure that if theta_cyc maps x |-> y, then y picks up @@ -343,10 +361,10 @@ solveConstraintChoice quals cs = do debugPretty vm let vm' = foldr updateVarMap vm (Subst.toList theta_cyc) - where - updateVarMap :: (Name Type, Type) -> TyVarInfoMap -> TyVarInfoMap - updateVarMap (x, TyAtom (AVar (U y))) vmm = extendSort y (getSort vmm x) vmm - updateVarMap _ vmm = vmm + where + updateVarMap :: (Name Type, Type) -> TyVarInfoMap -> TyVarInfoMap + updateVarMap (x, TyAtom (AVar (U y))) vmm = extendSort y (getSort vmm x) vmm + updateVarMap _ vmm = vmm debug "Updated sort map:" debugPretty vm @@ -366,11 +384,10 @@ solveConstraintChoice quals cs = do -- predecessor base types in the graph; then unify all the type -- variables in any remaining weakly connected components. - debug "------------------------------" debug "Solving for type variables..." - theta_sol <- solveGraph vm' g''' + theta_sol <- solveGraph vm' g''' debugPretty theta_sol debug "------------------------------" @@ -381,75 +398,79 @@ solveConstraintChoice quals cs = do return theta_final - -------------------------------------------------- -- Step 1. Constraint decomposition. -decomposeConstraint - :: Members '[Fresh, Error SolveError, Input TyDefCtx] r - => Constraint -> Sem r [(TyVarInfoMap, [SimpleConstraint])] +decomposeConstraint :: + Members '[Fresh, Error SolveError, Input TyDefCtx] r => + Constraint -> + Sem r [(TyVarInfoMap, [SimpleConstraint])] decomposeConstraint (CSub t1 t2) = return [(mempty, [t1 :<: t2])] -decomposeConstraint (CEq t1 t2) = return [(mempty, [t1 :=: t2])] -decomposeConstraint (CQual q ty) = (:[]) . (, []) <$> decomposeQual ty q -decomposeConstraint (CAnd cs) = map mconcat . sequence <$> mapM decomposeConstraint cs -decomposeConstraint CTrue = return [mempty] -decomposeConstraint (CAll ty) = do +decomposeConstraint (CEq t1 t2) = return [(mempty, [t1 :=: t2])] +decomposeConstraint (CQual q ty) = (: []) . (,[]) <$> decomposeQual ty q +decomposeConstraint (CAnd cs) = map mconcat . sequence <$> mapM decomposeConstraint cs +decomposeConstraint CTrue = return [mempty] +decomposeConstraint (CAll ty) = do (vars, c) <- unbind ty let c' = substs (mkSkolems vars) c (map . first . addSkolems) vars <$> decomposeConstraint c' - - where - mkSkolems :: [Name Type] -> [(Name Type, Type)] - mkSkolems = map (id &&& TySkolem) - -decomposeConstraint (COr cs) = concat <$> filterErrors (map decomposeConstraint cs) - -decomposeQual - :: Members '[Fresh, Error SolveError, Input TyDefCtx] r - => Type -> Qualifier -> Sem r TyVarInfoMap + where + mkSkolems :: [Name Type] -> [(Name Type, Type)] + mkSkolems = map (id &&& TySkolem) +decomposeConstraint (COr cs) = concat <$> filterErrors (map decomposeConstraint cs) + +decomposeQual :: + Members '[Fresh, Error SolveError, Input TyDefCtx] r => + Type -> + Qualifier -> + Sem r TyVarInfoMap decomposeQual = go S.empty - where - go :: Members '[Fresh, Error SolveError, Input TyDefCtx] r - => Set (String, [Type], Qualifier) -> Type -> Qualifier -> Sem r TyVarInfoMap - - -- For a type atom, call out to checkQual. - go _ (TyAtom a) q = checkQual q a - - -- Coinductively check user-defined types for a qualifier. Keep - -- track of a set of user-defined types and qualifiers we have - -- seen. Every time we encounter a new one, add it to the set and - -- recurse on its unfolding. If we ever encounter one we have - -- already seen, we can assume by coinduction that the qualifier - -- is satisfied. - go seen (TyCon (CUser t) tys) q = do - case (t, tys, q) `S.member` seen of - True -> return mempty - False -> do - tyDefns <- input @TyDefCtx - case M.lookup t tyDefns of - Nothing -> error $ show t ++ " not in ty defn map!!" - Just (TyDefBody _ body) -> do - let ty' = body tys - go (S.insert (t, tys, q) seen) ty' q - - -- If we have a container type where the container is still a variable, - -- just replace it with List for the purposes of generating constraints--- - -- all containers (lists, bags, sets) have the same qualifier rules. - go seen (TyCon (CContainer (AVar _)) tys) q = go seen (TyCon CList tys) q - - -- Otherwise, decompose a type constructor according to the qualRules. - go seen ty@(TyCon c tys) q = case qualRules c q of - Nothing -> throw $ Unqual q ty - Just qs -> mconcat <$> zipWithM (maybe (return mempty) . go seen) tys qs - -checkQual - :: Members '[Fresh, Error SolveError] r - => Qualifier -> Atom -> Sem r TyVarInfoMap + where + go :: + Members '[Fresh, Error SolveError, Input TyDefCtx] r => + Set (String, [Type], Qualifier) -> + Type -> + Qualifier -> + Sem r TyVarInfoMap + + -- For a type atom, call out to checkQual. + go _ (TyAtom a) q = checkQual q a + -- Coinductively check user-defined types for a qualifier. Keep + -- track of a set of user-defined types and qualifiers we have + -- seen. Every time we encounter a new one, add it to the set and + -- recurse on its unfolding. If we ever encounter one we have + -- already seen, we can assume by coinduction that the qualifier + -- is satisfied. + go seen (TyCon (CUser t) tys) q = do + case (t, tys, q) `S.member` seen of + True -> return mempty + False -> do + tyDefns <- input @TyDefCtx + case M.lookup t tyDefns of + Nothing -> error $ show t ++ " not in ty defn map!!" + Just (TyDefBody _ body) -> do + let ty' = body tys + go (S.insert (t, tys, q) seen) ty' q + + -- If we have a container type where the container is still a variable, + -- just replace it with List for the purposes of generating constraints--- + -- all containers (lists, bags, sets) have the same qualifier rules. + go seen (TyCon (CContainer (AVar _)) tys) q = go seen (TyCon CList tys) q + -- Otherwise, decompose a type constructor according to the qualRules. + go seen ty@(TyCon c tys) q = case qualRules c q of + Nothing -> throw $ Unqual q ty + Just qs -> mconcat <$> zipWithM (maybe (return mempty) . go seen) tys qs + +checkQual :: + Members '[Fresh, Error SolveError] r => + Qualifier -> + Atom -> + Sem r TyVarInfoMap checkQual q (AVar (U v)) = return . VM . M.singleton v $ mkTVI Unification (S.singleton q) checkQual q (AVar (S v)) = throw $ QualSkolem q v -checkQual q (ABase bty) = +checkQual q (ABase bty) = case hasQual bty q of - True -> return mempty + True -> return mempty False -> throw $ UnqualBase q bty -------------------------------------------------- @@ -465,250 +486,249 @@ checkQual q (ABase bty) = -- After this step, the remaining constraints will all be atomic -- constraints, that is, only of the form (v1 <: v2), (v <: b), or -- (b <: v), where v is a type variable and b is a base type. - -simplify - :: Members '[Error SolveError, Output Message, Input TyDefCtx] r - => TyVarInfoMap -> [SimpleConstraint] -> Sem r (TyVarInfoMap, [(Atom, Atom)], S) -simplify origVM cs - = (\(SS vm' cs' s' _) -> (vm', map extractAtoms cs', s')) - -- contFreshMT :: Monad m => FreshMT m a -> Integer -> m a - -- "Run a FreshMT computation given a starting index for fresh name generation." - <$> runFresh' n (execState (SS origVM cs idS S.empty) simplify') - where - - fvNums :: Alpha a => [a] -> [Integer] - fvNums = map (name2Integer :: Name Type -> Integer) . toListOf fv - - -- Find first unused integer in constraint free vars and sort map - -- domain, and use it to start the fresh var generation, so we don't - -- generate any "fresh" names that interfere with existing names - n1 = maximum0 . fvNums $ cs - n = succ . maximum . (n1:) . fvNums . M.keys . unVM $ origVM - - maximum0 [] = 0 - maximum0 xs = maximum xs - - -- Extract the type atoms from an atomic constraint. - extractAtoms :: SimpleConstraint -> (Atom, Atom) - extractAtoms (TyAtom a1 :<: TyAtom a2) = (a1, a2) - extractAtoms c = error $ "Impossible: simplify left non-atomic or non-subtype constraint " ++ show c - - -- Iterate picking one simplifiable constraint and simplifying it - -- until none are left. - simplify' - :: Members '[State SimplifyState, Fresh, Error SolveError, Output Message, Input TyDefCtx] r - => Sem r () - simplify' = do - -- q <- gets fst - -- debug (pretty q) - -- debug "" - - mc <- pickSimplifiable - case mc of - Nothing -> return () - Just s -> do - - debug $ "Simplifying:" <+> pretty' s - - simplifyOne s - simplify' - - -- Pick out one simplifiable constraint, removing it from the list - -- of constraints in the state. Return Nothing if no more - -- constraints can be simplified. - pickSimplifiable - :: Members '[State SimplifyState, Fresh, Error SolveError] r - => Sem r (Maybe SimpleConstraint) - pickSimplifiable = do - curCs <- use ssConstraints - case pick simplifiable curCs of - Nothing -> return Nothing - Just (a,as) -> do - ssConstraints .= as - return (Just a) - - -- Pick the first element from a list satisfying the given - -- predicate, returning the element and the list with the element - -- removed. - pick :: (a -> Bool) -> [a] -> Maybe (a,[a]) - pick _ [] = Nothing - pick p (a:as) - | p a = Just (a,as) - | otherwise = second (a:) <$> pick p as - - -- Check if a constraint can be simplified. An equality - -- constraint can always be "simplified" via unification. A - -- subtyping constraint can be simplified if either it involves a - -- type constructor (in which case we can decompose it), or if it - -- involves two base types (in which case it can be removed if the - -- relationship holds). - simplifiable :: SimpleConstraint -> Bool - simplifiable (_ :=: _) = True - simplifiable (TyCon {} :<: TyCon {}) = True - simplifiable (TyVar {} :<: TyCon {}) = True - simplifiable (TyCon {} :<: TyVar {}) = True - simplifiable (TyCon (CUser _) _ :<: _) = True - simplifiable (_ :<: TyCon (CUser _) _) = True - simplifiable (TyAtom (ABase _) :<: TyAtom (ABase _)) = True - - simplifiable _ = False - - -- Simplify the given simplifiable constraint. If the constraint - -- has already been seen before (due to expansion of a recursive - -- type), just throw it away and stop. - simplifyOne - :: Members '[State SimplifyState, Fresh, Error SolveError, Input TyDefCtx] r - => SimpleConstraint -> Sem r () - simplifyOne c = do - seen <- use ssSeen - case c `S.member` seen of - True -> return () - False -> do - ssSeen %= S.insert c - simplifyOne' c - - simplifyOne' - :: Members '[State SimplifyState, Fresh, Error SolveError, Input TyDefCtx] r - => SimpleConstraint -> Sem r () - - -- If we have an equality constraint, run unification on it. The - -- resulting substitution is applied to the remaining constraints - -- as well as prepended to the current substitution. - - simplifyOne' (ty1 :=: ty2) = do - tyDefns <- input @TyDefCtx - case unify tyDefns [(ty1, ty2)] of - Nothing -> throw NoUnify - Just s' -> extendSubst s' - - -- If we see a constraint of the form (T <: a), where T is a - -- user-defined type and a is a type variable, then just turn it - -- into an equality (T = a). This is sound but probably not - -- complete. The alternative seems quite complicated, possibly - -- even undecidable. See https://github.com/disco-lang/disco/issues/207 . - simplifyOne' (ty1@(TyCon (CUser _) _) :<: ty2@TyVar{}) - = simplifyOne' (ty1 :=: ty2) - - -- Otherwise, expand the user-defined type and continue. - simplifyOne' (TyCon (CUser t) ts :<: ty2) = do - tyDefns <- input @TyDefCtx - case M.lookup t tyDefns of - Nothing -> error $ show t ++ " not in ty defn map!" - Just (TyDefBody _ body) -> - ssConstraints %= ((body ts :<: ty2) :) - - -- Turn a <: T into a = T. See comment above. - simplifyOne' (ty1@TyVar{} :<: ty2@(TyCon (CUser _) _)) - = simplifyOne' (ty1 :=: ty2) - - simplifyOne' (ty1 :<: TyCon (CUser t) ts) = do - tyDefns <- input @TyDefCtx - case M.lookup t tyDefns of - Nothing -> error $ show t ++ " not in ty defn map!" - Just (TyDefBody _ body) -> - ssConstraints %= ((ty1 :<: body ts) :) - - -- Given a subtyping constraint between two type constructors, - -- decompose it if the constructors are the same (or fail if they - -- aren't), taking into account the variance of each argument to - -- the constructor. Container types are a special case; - -- recursively generate a subtyping constraint for their - -- constructors as well. - simplifyOne' (TyCon c1@(CContainer ctr1) tys1 :<: TyCon (CContainer ctr2) tys2) = - ssConstraints %= - (( (TyAtom ctr1 :<: TyAtom ctr2) - : zipWith3 variance (arity c1) tys1 tys2 +simplify :: + Members '[Error SolveError, Output Message, Input TyDefCtx] r => + TyVarInfoMap -> + [SimpleConstraint] -> + Sem r (TyVarInfoMap, [(Atom, Atom)], S) +simplify origVM cs = + (\(SS vm' cs' s' _) -> (vm', map extractAtoms cs', s')) + -- contFreshMT :: Monad m => FreshMT m a -> Integer -> m a + -- "Run a FreshMT computation given a starting index for fresh name generation." + <$> runFresh' n (execState (SS origVM cs idS S.empty) simplify') + where + fvNums :: Alpha a => [a] -> [Integer] + fvNums = map (name2Integer :: Name Type -> Integer) . toListOf fv + + -- Find first unused integer in constraint free vars and sort map + -- domain, and use it to start the fresh var generation, so we don't + -- generate any "fresh" names that interfere with existing names + n1 = maximum0 . fvNums $ cs + n = succ . maximum . (n1 :) . fvNums . M.keys . unVM $ origVM + + maximum0 [] = 0 + maximum0 xs = maximum xs + + -- Extract the type atoms from an atomic constraint. + extractAtoms :: SimpleConstraint -> (Atom, Atom) + extractAtoms (TyAtom a1 :<: TyAtom a2) = (a1, a2) + extractAtoms c = error $ "Impossible: simplify left non-atomic or non-subtype constraint " ++ show c + + -- Iterate picking one simplifiable constraint and simplifying it + -- until none are left. + simplify' :: + Members '[State SimplifyState, Fresh, Error SolveError, Output Message, Input TyDefCtx] r => + Sem r () + simplify' = do + -- q <- gets fst + -- debug (pretty q) + -- debug "" + + mc <- pickSimplifiable + case mc of + Nothing -> return () + Just s -> do + debug $ "Simplifying:" <+> pretty' s + + simplifyOne s + simplify' + + -- Pick out one simplifiable constraint, removing it from the list + -- of constraints in the state. Return Nothing if no more + -- constraints can be simplified. + pickSimplifiable :: + Members '[State SimplifyState, Fresh, Error SolveError] r => + Sem r (Maybe SimpleConstraint) + pickSimplifiable = do + curCs <- use ssConstraints + case pick simplifiable curCs of + Nothing -> return Nothing + Just (a, as) -> do + ssConstraints .= as + return (Just a) + + -- Pick the first element from a list satisfying the given + -- predicate, returning the element and the list with the element + -- removed. + pick :: (a -> Bool) -> [a] -> Maybe (a, [a]) + pick _ [] = Nothing + pick p (a : as) + | p a = Just (a, as) + | otherwise = second (a :) <$> pick p as + + -- Check if a constraint can be simplified. An equality + -- constraint can always be "simplified" via unification. A + -- subtyping constraint can be simplified if either it involves a + -- type constructor (in which case we can decompose it), or if it + -- involves two base types (in which case it can be removed if the + -- relationship holds). + simplifiable :: SimpleConstraint -> Bool + simplifiable (_ :=: _) = True + simplifiable (TyCon {} :<: TyCon {}) = True + simplifiable (TyVar {} :<: TyCon {}) = True + simplifiable (TyCon {} :<: TyVar {}) = True + simplifiable (TyCon (CUser _) _ :<: _) = True + simplifiable (_ :<: TyCon (CUser _) _) = True + simplifiable (TyAtom (ABase _) :<: TyAtom (ABase _)) = True + simplifiable _ = False + + -- Simplify the given simplifiable constraint. If the constraint + -- has already been seen before (due to expansion of a recursive + -- type), just throw it away and stop. + simplifyOne :: + Members '[State SimplifyState, Fresh, Error SolveError, Input TyDefCtx] r => + SimpleConstraint -> + Sem r () + simplifyOne c = do + seen <- use ssSeen + case c `S.member` seen of + True -> return () + False -> do + ssSeen %= S.insert c + simplifyOne' c + + simplifyOne' :: + Members '[State SimplifyState, Fresh, Error SolveError, Input TyDefCtx] r => + SimpleConstraint -> + Sem r () + + -- If we have an equality constraint, run unification on it. The + -- resulting substitution is applied to the remaining constraints + -- as well as prepended to the current substitution. + + simplifyOne' (ty1 :=: ty2) = do + tyDefns <- input @TyDefCtx + case unify tyDefns [(ty1, ty2)] of + Nothing -> throw NoUnify + Just s' -> extendSubst s' + + -- If we see a constraint of the form (T <: a), where T is a + -- user-defined type and a is a type variable, then just turn it + -- into an equality (T = a). This is sound but probably not + -- complete. The alternative seems quite complicated, possibly + -- even undecidable. See https://github.com/disco-lang/disco/issues/207 . + simplifyOne' (ty1@(TyCon (CUser _) _) :<: ty2@TyVar {}) = + simplifyOne' (ty1 :=: ty2) + -- Otherwise, expand the user-defined type and continue. + simplifyOne' (TyCon (CUser t) ts :<: ty2) = do + tyDefns <- input @TyDefCtx + case M.lookup t tyDefns of + Nothing -> error $ show t ++ " not in ty defn map!" + Just (TyDefBody _ body) -> + ssConstraints %= ((body ts :<: ty2) :) + + -- Turn a <: T into a = T. See comment above. + simplifyOne' (ty1@TyVar {} :<: ty2@(TyCon (CUser _) _)) = + simplifyOne' (ty1 :=: ty2) + simplifyOne' (ty1 :<: TyCon (CUser t) ts) = do + tyDefns <- input @TyDefCtx + case M.lookup t tyDefns of + Nothing -> error $ show t ++ " not in ty defn map!" + Just (TyDefBody _ body) -> + ssConstraints %= ((ty1 :<: body ts) :) + + -- Given a subtyping constraint between two type constructors, + -- decompose it if the constructors are the same (or fail if they + -- aren't), taking into account the variance of each argument to + -- the constructor. Container types are a special case; + -- recursively generate a subtyping constraint for their + -- constructors as well. + simplifyOne' (TyCon c1@(CContainer ctr1) tys1 :<: TyCon (CContainer ctr2) tys2) = + ssConstraints + %= ( ( (TyAtom ctr1 :<: TyAtom ctr2) + : zipWith3 variance (arity c1) tys1 tys2 + ) + ++ ) - ++) - - simplifyOne' (TyCon c1 tys1 :<: TyCon c2 tys2) - | c1 /= c2 = throw NoUnify - | otherwise = - ssConstraints %= (zipWith3 variance (arity c1) tys1 tys2 ++) - - -- Given a subtyping constraint between a variable and a type - -- constructor, expand the variable into the same constructor - -- applied to fresh type variables. - simplifyOne' con@(TyVar a :<: TyCon c _) = expandStruct a c con - simplifyOne' con@(TyCon c _ :<: TyVar a ) = expandStruct a c con - - -- Given a subtyping constraint between two base types, just check - -- whether the first is indeed a subtype of the second. (Note - -- that we only pattern match here on type atoms, which could - -- include variables, but this will only ever get called if - -- 'simplifiable' was true, which checks that both are base - -- types.) - simplifyOne' (TyAtom (ABase b1) :<: TyAtom (ABase b2)) = do - case isSubB b1 b2 of - True -> return () - False -> throw NoUnify - - simplifyOne' (s :<: t) = - error $ "Impossible! simplifyOne' " ++ show s ++ " :<: " ++ show t - - expandStruct - :: Members '[State SimplifyState, Fresh, Error SolveError, Input TyDefCtx] r - => Name Type -> Con -> SimpleConstraint -> Sem r () - expandStruct a c con = do - as <- mapM (const (TyVar <$> fresh (string2Name "a"))) (arity c) - let s' = a |-> TyCon c as - ssConstraints %= (con:) - extendSubst s' - - -- 1. compose s' with current subst - -- 2. apply s' to constraints - -- 3. apply s' to qualifier map and decompose - extendSubst - :: Members '[State SimplifyState, Fresh, Error SolveError, Input TyDefCtx] r - => S -> Sem r () - extendSubst s' = do - ssSubst %= (s'@@) - ssConstraints %= applySubst s' - substVarMap s' - - substVarMap - :: Members '[State SimplifyState, Fresh, Error SolveError, Input TyDefCtx] r - => S -> Sem r () - substVarMap s' = do - vm <- use ssVarMap - - -- 1. Get quals for each var in domain of s' and match them with - -- the types being substituted for those vars. - - let tySorts :: [(Type, Sort)] - tySorts = map (second (view tyVarSort)) . mapMaybe (traverse (`lookupVM` vm) . swap) $ Subst.toList s' - - tyQualList :: [(Type, Qualifier)] - tyQualList = concatMap (sequenceA . second S.toList) tySorts - - -- 2. Decompose the resulting qualifier constraints - - vm' <- mconcat <$> mapM (uncurry decomposeQual) tyQualList - - -- 3. delete domain of s' from vm and merge in decomposed quals. - - ssVarMap .= vm' <> foldl' (flip deleteVM) vm (dom s') - - -- The above works even when unifying two variables. Say we have - -- the TyVarInfoMap - -- - -- a |-> {add} - -- b |-> {sub} - -- - -- and we get back theta = [a |-> b]. The domain of theta - -- consists solely of a, so we look up a in the TyVarInfoMap and get - -- {add}. We therefore generate the constraint 'add (theta a)' - -- = 'add b' which can't be decomposed at all, and hence yields - -- the TyVarInfoMap {b |-> {add}}. We then delete a from the - -- original TyVarInfoMap and merge the result with the new TyVarInfoMap, - -- yielding {b |-> {sub,add}}. - - - -- Create a subtyping constraint based on the variance of a type - -- constructor argument position: in the usual order for - -- covariant, and reversed for contravariant. - variance Co ty1 ty2 = ty1 :<: ty2 - variance Contra ty1 ty2 = ty2 :<: ty1 + simplifyOne' (TyCon c1 tys1 :<: TyCon c2 tys2) + | c1 /= c2 = throw NoUnify + | otherwise = + ssConstraints %= (zipWith3 variance (arity c1) tys1 tys2 ++) + -- Given a subtyping constraint between a variable and a type + -- constructor, expand the variable into the same constructor + -- applied to fresh type variables. + simplifyOne' con@(TyVar a :<: TyCon c _) = expandStruct a c con + simplifyOne' con@(TyCon c _ :<: TyVar a) = expandStruct a c con + -- Given a subtyping constraint between two base types, just check + -- whether the first is indeed a subtype of the second. (Note + -- that we only pattern match here on type atoms, which could + -- include variables, but this will only ever get called if + -- 'simplifiable' was true, which checks that both are base + -- types.) + simplifyOne' (TyAtom (ABase b1) :<: TyAtom (ABase b2)) = do + case isSubB b1 b2 of + True -> return () + False -> throw NoUnify + simplifyOne' (s :<: t) = + error $ "Impossible! simplifyOne' " ++ show s ++ " :<: " ++ show t + + expandStruct :: + Members '[State SimplifyState, Fresh, Error SolveError, Input TyDefCtx] r => + Name Type -> + Con -> + SimpleConstraint -> + Sem r () + expandStruct a c con = do + as <- mapM (const (TyVar <$> fresh (string2Name "a"))) (arity c) + let s' = a |-> TyCon c as + ssConstraints %= (con :) + extendSubst s' + + -- 1. compose s' with current subst + -- 2. apply s' to constraints + -- 3. apply s' to qualifier map and decompose + extendSubst :: + Members '[State SimplifyState, Fresh, Error SolveError, Input TyDefCtx] r => + S -> + Sem r () + extendSubst s' = do + ssSubst %= (s' @@) + ssConstraints %= applySubst s' + substVarMap s' + + substVarMap :: + Members '[State SimplifyState, Fresh, Error SolveError, Input TyDefCtx] r => + S -> + Sem r () + substVarMap s' = do + vm <- use ssVarMap + + -- 1. Get quals for each var in domain of s' and match them with + -- the types being substituted for those vars. + + let tySorts :: [(Type, Sort)] + tySorts = map (second (view tyVarSort)) . mapMaybe (traverse (`lookupVM` vm) . swap) $ Subst.toList s' + + tyQualList :: [(Type, Qualifier)] + tyQualList = concatMap (sequenceA . second S.toList) tySorts + + -- 2. Decompose the resulting qualifier constraints + + vm' <- mconcat <$> mapM (uncurry decomposeQual) tyQualList + + -- 3. delete domain of s' from vm and merge in decomposed quals. + + ssVarMap .= vm' <> foldl' (flip deleteVM) vm (dom s') + + -- The above works even when unifying two variables. Say we have + -- the TyVarInfoMap + -- + -- a |-> {add} + -- b |-> {sub} + -- + -- and we get back theta = [a |-> b]. The domain of theta + -- consists solely of a, so we look up a in the TyVarInfoMap and get + -- {add}. We therefore generate the constraint 'add (theta a)' + -- = 'add b' which can't be decomposed at all, and hence yields + -- the TyVarInfoMap {b |-> {add}}. We then delete a from the + -- original TyVarInfoMap and merge the result with the new TyVarInfoMap, + -- yielding {b |-> {sub,add}}. + + -- Create a subtyping constraint based on the variance of a type + -- constructor argument position: in the usual order for + -- covariant, and reversed for contravariant. + variance Co ty1 ty2 = ty1 :<: ty2 + variance Contra ty1 ty2 = ty2 :<: ty1 -------------------------------------------------- -- Step 4: Build constraint graph @@ -718,8 +738,8 @@ simplify origVM cs -- corresponding constraint graph. mkConstraintGraph :: (Show a, Ord a) => Set a -> [(a, a)] -> Graph a mkConstraintGraph as cs = G.mkGraph nodes (S.fromList cs) - where - nodes = as `S.union` S.fromList (cs ^.. traverse . each) + where + nodes = as `S.union` S.fromList (cs ^.. traverse . each) -------------------------------------------------- -- Step 5: Check skolems @@ -730,18 +750,24 @@ mkConstraintGraph as cs = G.mkGraph nodes (S.fromList cs) -- If there are any WCCs with a single skolem, no base types, and -- only unsorted variables, just unify them all with the skolem and -- remove those components. -checkSkolems - :: Members '[Error SolveError, Output Message, Input TyDefCtx] r - => TyVarInfoMap -> Graph Atom -> Sem r (Graph UAtom, S) +checkSkolems :: + Members '[Error SolveError, Output Message, Input TyDefCtx] r => + TyVarInfoMap -> + Graph Atom -> + Sem r (Graph UAtom, S) checkSkolems vm graph = do let skolemWCCs :: [Set Atom] skolemWCCs = filter (any isSkolem) $ G.wcc graph - ok wcc = S.size (S.filter isSkolem wcc) <= 1 - && all (\case { ABase _ -> False - ; AVar (S _) -> True - ; AVar (U v) -> maybe True (S.null . view tyVarSort) (lookupVM v vm) }) - wcc + ok wcc = + S.size (S.filter isSkolem wcc) <= 1 + && all + ( \case + ABase _ -> False + AVar (S _) -> True + AVar (U v) -> maybe True (S.null . view tyVarSort) (lookupVM v vm) + ) + wcc (good, bad) = partition ok skolemWCCs @@ -751,28 +777,30 @@ checkSkolems vm graph = do -- (1) delete them from the graph -- (2) unify them all with the skolem unifyWCCs graph idS good - - where - noSkolems :: Atom -> UAtom - noSkolems (ABase b) = UB b - noSkolems (AVar (U v)) = UV v - noSkolems (AVar (S v)) = error $ "Skolem " ++ show v ++ " remaining in noSkolems" - - unifyWCCs - :: Members '[Error SolveError, Output Message, Input TyDefCtx] r - => Graph Atom -> S -> [Set Atom] -> Sem r (Graph UAtom, S) - unifyWCCs g s [] = return (G.map noSkolems g, s) - unifyWCCs g s (u:us) = do - debug $ "Unifying" <+> pretty' (u:us) <> "..." - - tyDefns <- input @TyDefCtx - - let g' = foldl' (flip G.delete) g u - - ms' = unifyAtoms tyDefns (S.toList u) - case ms' of - Nothing -> throw NoUnify - Just s' -> unifyWCCs g' (atomToTypeSubst s' @@ s) us + where + noSkolems :: Atom -> UAtom + noSkolems (ABase b) = UB b + noSkolems (AVar (U v)) = UV v + noSkolems (AVar (S v)) = error $ "Skolem " ++ show v ++ " remaining in noSkolems" + + unifyWCCs :: + Members '[Error SolveError, Output Message, Input TyDefCtx] r => + Graph Atom -> + S -> + [Set Atom] -> + Sem r (Graph UAtom, S) + unifyWCCs g s [] = return (G.map noSkolems g, s) + unifyWCCs g s (u : us) = do + debug $ "Unifying" <+> pretty' (u : us) <> "..." + + tyDefns <- input @TyDefCtx + + let g' = foldl' (flip G.delete) g u + + ms' = unifyAtoms tyDefns (S.toList u) + case ms' of + Nothing -> throw NoUnify + Just s' -> unifyWCCs g' (atomToTypeSubst s' @@ s) us -------------------------------------------------- -- Step 6: Eliminate cycles @@ -789,40 +817,44 @@ checkSkolems vm graph = do -- Of course, this step can fail if the types in a SCC are not -- unifiable. If it succeeds, it returns the collapsed graph (which -- is now guaranteed to be acyclic, i.e. a DAG) and a substitution. -elimCycles - :: Members '[Error SolveError] r - => TyDefCtx -> Graph UAtom -> Sem r (Graph UAtom, S) +elimCycles :: + Members '[Error SolveError] r => + TyDefCtx -> + Graph UAtom -> + Sem r (Graph UAtom, S) elimCycles tyDefns = elimCyclesGen uatomToTypeSubst (unifyUAtoms tyDefns) -elimCyclesGen - :: forall a b r. (Subst a a, Ord a, Members '[Error SolveError] r) - => (Substitution a -> Substitution b) -> ([a] -> Maybe (Substitution a)) - -> Graph a -> Sem r (Graph a, Substitution b) -elimCyclesGen genSubst genUnify g - = note NoUnify - $ (G.map fst &&& (genSubst . compose . S.map snd . G.nodes)) <$> g' - where - - g' :: Maybe (Graph (a, Substitution a)) - g' = G.sequenceGraph $ G.map unifySCC (G.condensation g) - - unifySCC :: Set a -> Maybe (a, Substitution a) - unifySCC uatoms = case S.toList uatoms of - [] -> error "Impossible! unifySCC on the empty set" - as@(a:_) -> (flip applySubst a &&& id) <$> genUnify as +elimCyclesGen :: + forall a b r. + (Subst a a, Ord a, Members '[Error SolveError] r) => + (Substitution a -> Substitution b) -> + ([a] -> Maybe (Substitution a)) -> + Graph a -> + Sem r (Graph a, Substitution b) +elimCyclesGen genSubst genUnify g = + note NoUnify $ + (G.map fst &&& (genSubst . compose . S.map snd . G.nodes)) <$> g' + where + g' :: Maybe (Graph (a, Substitution a)) + g' = G.sequenceGraph $ G.map unifySCC (G.condensation g) + + unifySCC :: Set a -> Maybe (a, Substitution a) + unifySCC uatoms = case S.toList uatoms of + [] -> error "Impossible! unifySCC on the empty set" + as@(a : _) -> (flip applySubst a &&& id) <$> genUnify as ------------------------------------------------------------ -- Step 6a: check base type edges ------------------------------------------------------------ isBaseEdge :: (UAtom, UAtom) -> Either (BaseTy, BaseTy) (UAtom, UAtom) -isBaseEdge (UB b1, UB b2) = Left (b1,b2) -isBaseEdge e = Right e +isBaseEdge (UB b1, UB b2) = Left (b1, b2) +isBaseEdge e = Right e checkBaseEdge :: Members '[Error SolveError] r => (BaseTy, BaseTy) -> Sem r () checkBaseEdge (b1, b2) | isSubB b1 b2 = return () - | otherwise = throw NoUnify + | otherwise = throw NoUnify checkBaseEdges :: Members '[Error SolveError] r => Graph UAtom -> Sem r (Graph UAtom) checkBaseEdges g = do @@ -839,22 +871,22 @@ checkBaseEdges g = do -- successors, but not both). data Rels = Rels { baseRels :: Set BaseTy - , varRels :: Set (Name Type) + , varRels :: Set (Name Type) } deriving (Show, Eq) -- | A RelMap associates each variable to its sets of base type and -- variable predecessors and successors in the constraint graph. -newtype RelMap = RelMap { unRelMap :: Map (Name Type, Dir) Rels} +newtype RelMap = RelMap {unRelMap :: Map (Name Type, Dir) Rels} instance Pretty RelMap where pretty (RelMap rm) = vcat (map prettyVar byVar) - where - vars = S.map fst (M.keysSet rm) - byVar = map (\x -> (rm!(x,SubTy), x, rm!(x,SuperTy))) (S.toList vars) + where + vars = S.map fst (M.keysSet rm) + byVar = map (\x -> (rm ! (x, SubTy), x, rm ! (x, SuperTy))) (S.toList vars) - prettyVar (subs, x, sups) = hsep [prettyRel subs, "<:", pretty x, "<:", prettyRel sups] - prettyRel rs = pretty (baseRels rs) <> ", " <> pretty (varRels rs) + prettyVar (subs, x, sups) = hsep [prettyRel subs, "<:", pretty x, "<:", prettyRel sups] + prettyRel rs = pretty (baseRels rs) <> ", " <> pretty (varRels rs) -- | Modify a @RelMap@ to record the fact that we have solved for a -- type variable. In particular, delete the variable from the @@ -862,57 +894,57 @@ instance Pretty RelMap where -- other variable to remove this variable and add the base type we -- chose for it. substRel :: Name Type -> BaseTy -> RelMap -> RelMap -substRel x ty - = RelMap - . M.delete (x,SuperTy) - . M.delete (x,SubTy) - . M.map (\r@(Rels b v) -> if x `S.member` v then Rels (S.insert ty b) (S.delete x v) else r) - . unRelMap +substRel x ty = + RelMap + . M.delete (x, SuperTy) + . M.delete (x, SubTy) + . M.map (\r@(Rels b v) -> if x `S.member` v then Rels (S.insert ty b) (S.delete x v) else r) + . unRelMap -- | Essentially dirtypesBySort vm rm dir t s x finds all the -- dir-types (sub- or super-) of t which have sort s, relative to -- the variables in x. This is \overbar{T}_S^X (resp. \underbar...) -- from Traytel et al. dirtypesBySort :: TyVarInfoMap -> RelMap -> Dir -> BaseTy -> Sort -> Set (Name Type) -> [BaseTy] -dirtypesBySort vm (RelMap relMap) dir t s x - - -- Keep only those supertypes t' of t - = keep (dirtypes dir t) $ \t' -> - -- which have the right sort, and such that - hasSort t' s && - +dirtypesBySort vm (RelMap relMap) dir t s x = + -- Keep only those supertypes t' of t + keep (dirtypes dir t) $ \t' -> + -- which have the right sort, and such that + hasSort t' s + && -- for all variables beta \in x, - forAll x (\beta -> - - -- there is at least one type t'' which is a subtype of t' - -- which would be a valid solution for beta, that is, - exists (dirtypes (other dir) t') $ \t'' -> - - -- t'' has the sort beta is supposed to have, and - hasSort t'' (getSort vm beta) && - - -- t'' is a supertype of every base type predecessor of beta. - forAll (baseRels (lkup "dirtypesBySort, beta rel" relMap (beta, other dir))) - (isDirB dir t'')) - - -- The above comments are written assuming dir = Super; of course, - -- if dir = Sub then just swap "super" and "sub" everywhere. - - where - forAll, exists :: Foldable t => t a -> (a -> Bool) -> Bool - forAll = flip all - exists = flip any - keep = flip filter + forAll + x + ( \beta -> + -- there is at least one type t'' which is a subtype of t' + -- which would be a valid solution for beta, that is, + exists (dirtypes (other dir) t') $ \t'' -> + -- t'' has the sort beta is supposed to have, and + hasSort t'' (getSort vm beta) + && + -- t'' is a supertype of every base type predecessor of beta. + forAll + (baseRels (lkup "dirtypesBySort, beta rel" relMap (beta, other dir))) + (isDirB dir t'') + ) + where + -- The above comments are written assuming dir = Super; of course, + -- if dir = Sub then just swap "super" and "sub" everywhere. + + forAll, exists :: Foldable t => t a -> (a -> Bool) -> Bool + forAll = flip all + exists = flip any + keep = flip filter -- | Sort-aware infimum or supremum. limBySort :: TyVarInfoMap -> RelMap -> Dir -> [BaseTy] -> Sort -> Set (Name Type) -> Maybe BaseTy -limBySort vm rm dir ts s x - = (\is -> find (\lim -> all (\u -> isDirB dir u lim) is) is) - . isects - . map (\t -> dirtypesBySort vm rm dir t s x) - $ ts - where - isects = foldr1 intersect +limBySort vm rm dir ts s x = + (\is -> find (\lim -> all (\u -> isDirB dir u lim) is) is) + . isects + . map (\t -> dirtypesBySort vm rm dir t s x) + $ ts + where + isects = foldr1 intersect lubBySort, glbBySort :: TyVarInfoMap -> RelMap -> [BaseTy] -> Sort -> Set (Name Type) -> Maybe BaseTy lubBySort vm rm = limBySort vm rm SuperTy @@ -931,77 +963,85 @@ glbBySort vm rm = limBySort vm rm SubTy -- complete algorithm. We choose to assign it the sup of its -- predecessors in this case, since it seems nice to default to -- "simpler" types lower down in the subtyping chain. -solveGraph - :: Members '[Fresh, Error SolveError, Output Message] r - => TyVarInfoMap -> Graph UAtom -> Sem r S +solveGraph :: + Members '[Fresh, Error SolveError, Output Message] r => + TyVarInfoMap -> + Graph UAtom -> + Sem r S solveGraph vm g = atomToTypeSubst . unifyWCC <$> go topRelMap - where - unifyWCC :: Substitution BaseTy -> Substitution Atom - unifyWCC s = compose (map mkEquateSubst wccVarGroups) @@ fmap ABase s - where - wccVarGroups :: [Set (Name Type)] - wccVarGroups = map (S.map getVar) . filter (all uisVar) . applySubst s $ G.wcc g - getVar (UV v) = v - getVar (UB b) = error - $ "Impossible! Base type " ++ show b ++ " in solveGraph.getVar" - - mkEquateSubst :: Set (Name Type) -> Substitution Atom - mkEquateSubst = mkEquations . S.toList - - mkEquations (a:as) = Subst.fromList . map (\v -> (coerce v, AVar (U a))) $ as - mkEquations [] = error "Impossible! Empty set of names in mkEquateSubst" - - -- After picking concrete base types for all the type - -- variables we can, the only thing possibly remaining in - -- the graph are components containing only type variables - -- and no base types. It is sound, and simplifies the - -- generated types considerably, to simply unify any type - -- variables which are related by subtyping constraints. - -- That is, we collect all the type variables in each - -- weakly connected component and unify them. - -- - -- As an example where this final step makes a difference, - -- consider a term like @\x. (\y.y) x@. A fresh type - -- variable is generated for the type of @x@, and another - -- for the type of @y@; the application of @(\y.y)@ to @x@ - -- induces a subtyping constraint between the two type - -- variables. The most general type would be something - -- like @forall a b. (a <: b) => a -> b@, but we want to - -- avoid generating unnecessary subtyping constraints (the - -- type system might not even support subtyping qualifiers - -- like this). Instead, we unify the two type variables - -- and the resulting type is @forall a. a -> a@. - - -- Get the successor and predecessor sets for all the type variables. - topRelMap :: RelMap - topRelMap - = RelMap - . M.map (uncurry Rels . (S.fromAscList *** S.fromAscList) - . partitionEithers . map uatomToEither . S.toList) + where + unifyWCC :: Substitution BaseTy -> Substitution Atom + unifyWCC s = compose (map mkEquateSubst wccVarGroups) @@ fmap ABase s + where + wccVarGroups :: [Set (Name Type)] + wccVarGroups = map (S.map getVar) . filter (all uisVar) . applySubst s $ G.wcc g + getVar (UV v) = v + getVar (UB b) = + error $ + "Impossible! Base type " ++ show b ++ " in solveGraph.getVar" + + mkEquateSubst :: Set (Name Type) -> Substitution Atom + mkEquateSubst = mkEquations . S.toList + + mkEquations (a : as) = Subst.fromList . map (\v -> (coerce v, AVar (U a))) $ as + mkEquations [] = error "Impossible! Empty set of names in mkEquateSubst" + + -- After picking concrete base types for all the type + -- variables we can, the only thing possibly remaining in + -- the graph are components containing only type variables + -- and no base types. It is sound, and simplifies the + -- generated types considerably, to simply unify any type + -- variables which are related by subtyping constraints. + -- That is, we collect all the type variables in each + -- weakly connected component and unify them. + -- + -- As an example where this final step makes a difference, + -- consider a term like @\x. (\y.y) x@. A fresh type + -- variable is generated for the type of @x@, and another + -- for the type of @y@; the application of @(\y.y)@ to @x@ + -- induces a subtyping constraint between the two type + -- variables. The most general type would be something + -- like @forall a b. (a <: b) => a -> b@, but we want to + -- avoid generating unnecessary subtyping constraints (the + -- type system might not even support subtyping qualifiers + -- like this). Instead, we unify the two type variables + -- and the resulting type is @forall a. a -> a@. + + -- Get the successor and predecessor sets for all the type variables. + topRelMap :: RelMap + topRelMap = + RelMap + . M.map + ( uncurry Rels + . (S.fromAscList *** S.fromAscList) + . partitionEithers + . map uatomToEither + . S.toList + ) $ M.mapKeys (,SuperTy) subMap `M.union` M.mapKeys (,SubTy) superMap - subMap, superMap :: Map (Name Type) (Set UAtom) - (subMap, superMap) = (onlyVars *** onlyVars) $ G.cessors g - - onlyVars :: Map UAtom (Set UAtom) -> Map (Name Type) (Set UAtom) - onlyVars = M.mapKeys fromVar . M.filterWithKey (\a _ -> uisVar a) - where - fromVar (UV x) = x - fromVar _ = error "Impossible! UB but uisVar." - - go - :: Members '[Fresh, Error SolveError, Output Message] r - => RelMap -> Sem r (Substitution BaseTy) - go relMap@(RelMap rm) = debugPretty relMap >> case as of - + subMap, superMap :: Map (Name Type) (Set UAtom) + (subMap, superMap) = (onlyVars *** onlyVars) $ G.cessors g + + onlyVars :: Map UAtom (Set UAtom) -> Map (Name Type) (Set UAtom) + onlyVars = M.mapKeys fromVar . M.filterWithKey (\a _ -> uisVar a) + where + fromVar (UV x) = x + fromVar _ = error "Impossible! UB but uisVar." + + go :: + Members '[Fresh, Error SolveError, Output Message] r => + RelMap -> + Sem r (Substitution BaseTy) + go relMap@(RelMap rm) = + debugPretty relMap >> case as of -- No variables left that have base type constraints. - [] -> return idS - + [] -> return idS -- Solve one variable at a time. See below. - (a:_) -> do + (a : _) -> do debug $ "Solving for" <+> pretty' a case solveVar a of - Nothing -> do + Nothing -> do debug $ "Couldn't solve for" <+> pretty' a throw NoUnify @@ -1019,103 +1059,124 @@ solveGraph vm g = atomToTypeSubst . unifyWCC <$> go topRelMap Just s -> do debugPretty s (@@ s) <$> go (substRel a (fromJust $ Subst.lookup (coerce a) s) relMap) - - where - -- NOTE we can't solve a bunch in parallel! Might end up - -- assigning them conflicting solutions if some depend on - -- others. For example, consider the situation - -- - -- Z - -- | - -- a3 - -- / \ - -- a1 N + where + -- NOTE we can't solve a bunch in parallel! Might end up + -- assigning them conflicting solutions if some depend on + -- others. For example, consider the situation + -- + -- Z + -- | + -- a3 + -- / \ + -- a1 N + -- + -- If we try to solve in parallel we will end up assigning a1 + -- -> Z (since it only has base types as an upper bound) and + -- a3 -> N (since it has both upper and lower bounds, and by + -- default we pick the lower bound), but this is wrong since + -- we should have a1 < a3. + -- + -- If instead we solve them one at a time, we could e.g. first + -- solve a1 -> Z, and then we would find a3 -> Z as well. + -- Alternately, if we first solve a3 -> N then we will have a1 + -- -> N as well. Both are acceptable. + -- + -- In fact, this exact graph comes from (^x.x+1) which was + -- erroneously being inferred to have type Z -> N when I first + -- wrote the code. + + -- Get only the variables we can solve on this pass, which + -- have base types in their predecessor or successor set. If + -- there are no such variables, then start picking any + -- remaining variables with a sort and pick types for them + -- (disco doesn't have qualified polymorphism so we can't just + -- leave them). + asBase = + map fst + . filter (not . S.null . baseRels . lkup "solveGraph.go.as" rm) + $ M.keys rm + as = case asBase of + [] -> filter ((/= topSort) . getSort vm) . map fst $ M.keys rm + _ -> asBase + + -- Solve for a variable, failing if it has no solution, otherwise returning + -- a substitution for it. + solveVar :: Name Type -> Maybe (Substitution BaseTy) + solveVar v = + case ((v, SuperTy), (v, SubTy)) & over both (S.toList . baseRels . lkup "solveGraph.solveVar" rm) of + -- No sub- or supertypes; the only way this can happen is + -- if it has a nontrivial sort. -- - -- If we try to solve in parallel we will end up assigning a1 - -- -> Z (since it only has base types as an upper bound) and - -- a3 -> N (since it has both upper and lower bounds, and by - -- default we pick the lower bound), but this is wrong since - -- we should have a1 < a3. + -- Traytel et al. don't seem to have a rule saying what to + -- do in this case (see Fig. 16 on p. 16 of their long + -- version). We used to just pick a type that inhabits + -- the sort, but this is wrong; see + -- https://github.com/disco-lang/disco/issues/192. -- - -- If instead we solve them one at a time, we could e.g. first - -- solve a1 -> Z, and then we would find a3 -> Z as well. - -- Alternately, if we first solve a3 -> N then we will have a1 - -- -> N as well. Both are acceptable. + -- If the sort is 'bool', we'll pick the Boolean base + -- type, since there are no other sorts which could cause + -- a conflict as in #192. -- - -- In fact, this exact graph comes from (^x.x+1) which was - -- erroneously being inferred to have type Z -> N when I first - -- wrote the code. - - -- Get only the variables we can solve on this pass, which - -- have base types in their predecessor or successor set. If - -- there are no such variables, then start picking any - -- remaining variables with a sort and pick types for them - -- (disco doesn't have qualified polymorphism so we can't just - -- leave them). - asBase - = map fst - . filter (not . S.null . baseRels . lkup "solveGraph.go.as" rm) - $ M.keys rm - as = case asBase of - [] -> filter ((/= topSort) . getSort vm) . map fst $ M.keys rm - _ -> asBase - - -- Solve for a variable, failing if it has no solution, otherwise returning - -- a substitution for it. - solveVar :: Name Type -> Maybe (Substitution BaseTy) - solveVar v = - case ((v,SuperTy), (v,SubTy)) & over both (S.toList . baseRels . lkup "solveGraph.solveVar" rm) of - -- No sub- or supertypes; the only way this can happen is - -- if it has a nontrivial sort. - -- - -- Traytel et al. don't seem to have a rule saying what to - -- do in this case (see Fig. 16 on p. 16 of their long - -- version). We used to just pick a type that inhabits - -- the sort, but this is wrong; see - -- https://github.com/disco-lang/disco/issues/192. - -- - -- If the sort is 'bool', we'll pick the Boolean base - -- type, since there are no other sorts which could cause - -- a conflict as in #192. - -- - -- Otherwise, we assume that any situation in which we - -- have no base sub- or supertypes but we do have - -- nontrivial sorts means that we are dealing with numeric - -- types; so we can just call N a base subtype and go from - -- there. - - ([], []) -> - if getSort vm v == S.fromList [QBool] - then Just (coerce v |-> B) - else - -- Debug.trace (show v ++ " has no sub- or supertypes. Assuming N as a subtype.") - (coerce v |->) <$> lubBySort vm relMap [N] (getSort vm v) - (varRels (lkup "solveVar none, rels" rm (v,SubTy))) - - -- Only supertypes. Just assign a to their inf, if one exists. - (bsupers, []) -> - -- Debug.trace (show v ++ " has only supertypes (" ++ show bsupers ++ ")") $ - (coerce v |->) <$> glbBySort vm relMap bsupers (getSort vm v) - (varRels (lkup "solveVar bsupers, rels" rm (v,SuperTy))) - - -- Only subtypes. Just assign a to their sup. - ([], bsubs) -> - -- Debug.trace (show v ++ " has only subtypes (" ++ show bsubs ++ ")") $ - -- Debug.trace ("sortmap: " ++ show vm) $ - -- Debug.trace ("relmap: " ++ show relMap) $ - -- Debug.trace ("sort for " ++ show v ++ ": " ++ show (getSort vm v)) $ - -- Debug.trace ("relvars: " ++ show (varRels (relMap ! (v,SubTy)))) $ - (coerce v |->) <$> lubBySort vm relMap bsubs (getSort vm v) - (varRels (lkup "solveVar bsubs, rels" rm (v,SubTy))) - - -- Both successors and predecessors. Both must have a - -- valid bound, and the bounds must not overlap. Assign a - -- to the sup of its predecessors. - (bsupers, bsubs) -> do - ub <- glbBySort vm relMap bsupers (getSort vm v) - (varRels (rm ! (v,SuperTy))) - lb <- lubBySort vm relMap bsubs (getSort vm v) - (varRels (rm ! (v,SubTy))) - case isSubB lb ub of - True -> Just (coerce v |-> lb) - False -> Nothing + -- Otherwise, we assume that any situation in which we + -- have no base sub- or supertypes but we do have + -- nontrivial sorts means that we are dealing with numeric + -- types; so we can just call N a base subtype and go from + -- there. + + ([], []) -> + if getSort vm v == S.fromList [QBool] + then Just (coerce v |-> B) + else -- Debug.trace (show v ++ " has no sub- or supertypes. Assuming N as a subtype.") + + (coerce v |->) + <$> lubBySort + vm + relMap + [N] + (getSort vm v) + (varRels (lkup "solveVar none, rels" rm (v, SubTy))) + -- Only supertypes. Just assign a to their inf, if one exists. + (bsupers, []) -> + -- Debug.trace (show v ++ " has only supertypes (" ++ show bsupers ++ ")") $ + (coerce v |->) + <$> glbBySort + vm + relMap + bsupers + (getSort vm v) + (varRels (lkup "solveVar bsupers, rels" rm (v, SuperTy))) + -- Only subtypes. Just assign a to their sup. + ([], bsubs) -> + -- Debug.trace (show v ++ " has only subtypes (" ++ show bsubs ++ ")") $ + -- Debug.trace ("sortmap: " ++ show vm) $ + -- Debug.trace ("relmap: " ++ show relMap) $ + -- Debug.trace ("sort for " ++ show v ++ ": " ++ show (getSort vm v)) $ + -- Debug.trace ("relvars: " ++ show (varRels (relMap ! (v,SubTy)))) $ + (coerce v |->) + <$> lubBySort + vm + relMap + bsubs + (getSort vm v) + (varRels (lkup "solveVar bsubs, rels" rm (v, SubTy))) + -- Both successors and predecessors. Both must have a + -- valid bound, and the bounds must not overlap. Assign a + -- to the sup of its predecessors. + (bsupers, bsubs) -> do + ub <- + glbBySort + vm + relMap + bsupers + (getSort vm v) + (varRels (rm ! (v, SuperTy))) + lb <- + lubBySort + vm + relMap + bsubs + (getSort vm v) + (varRels (rm ! (v, SubTy))) + case isSubB lb ub of + True -> Just (coerce v |-> lb) + False -> Nothing diff --git a/src/Disco/Typecheck/Unify.hs b/src/Disco/Typecheck/Unify.hs index 9abbf7dc..b8e10a23 100644 --- a/src/Disco/Typecheck/Unify.hs +++ b/src/Disco/Typecheck/Unify.hs @@ -1,4 +1,7 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Typecheck.Unify -- Copyright : disco team and contributors @@ -7,21 +10,18 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Unification. --- ------------------------------------------------------------------------------ - module Disco.Typecheck.Unify where -import Unbound.Generics.LocallyNameless (Name, fv) +import Unbound.Generics.LocallyNameless (Name, fv) -import Control.Lens (anyOf) -import Control.Monad.State -import qualified Data.Map as M -import Data.Set (Set) -import qualified Data.Set as S +import Control.Lens (anyOf) +import Control.Monad.State +import qualified Data.Map as M +import Data.Set (Set) +import qualified Data.Set as S -import Disco.Subst -import Disco.Types +import Disco.Subst +import Disco.Types -- XXX todo: might be better if unification took sorts into account -- directly. As it is, however, I think it works properly; @@ -51,92 +51,89 @@ weakUnify = unify' (\_ _ -> True) -- | Given a list of equations between types, return a substitution -- which makes all the equations satisfied (or fail if it is not -- possible), up to the given comparison on base types. -unify' :: (BaseTy -> BaseTy -> Bool) -> TyDefCtx - -> [(Type, Type)] -> Maybe S +unify' :: + (BaseTy -> BaseTy -> Bool) -> + TyDefCtx -> + [(Type, Type)] -> + Maybe S unify' baseEq tyDefns eqs = evalStateT (go eqs) S.empty - where - go :: [(Type, Type)] -> StateT (Set (Type,Type)) Maybe S - go [] = return idS - go (e:es) = do - u <- unifyOne e - case u of - Left sub -> (@@ sub) <$> go (applySubst sub es) - Right newEs -> go (newEs ++ es) - - unifyOne :: (Type, Type) -> StateT (Set (Type,Type)) Maybe (Either S [(Type, Type)]) - unifyOne pair = do - seen <- get - case pair `S.member` seen of - True -> return $ Left idS - False -> unifyOne' pair - - unifyOne' :: (Type, Type) -> StateT (Set (Type,Type)) Maybe (Either S [(Type, Type)]) - - unifyOne' (ty1, ty2) - | ty1 == ty2 = return $ Left idS - - unifyOne' (TyVar x, ty2) - | occurs x ty2 = mzero - | otherwise = return $ Left (x |-> ty2) - unifyOne' (ty1, x@(TyVar _)) - = unifyOne (x, ty1) - - -- At this point we know ty2 isn't the same skolem nor a unification variable. - -- Skolems don't unify with anything. - unifyOne' (TySkolem _, _) = mzero - unifyOne' (_, TySkolem _) = mzero - - -- Unify two container types: unify the container descriptors as - -- well as the type arguments - unifyOne' p@(TyCon (CContainer ctr1) tys1, TyCon (CContainer ctr2) tys2) = do - modify (S.insert p) - return $ Right ((TyAtom ctr1, TyAtom ctr2) : zip tys1 tys2) - - -- If one of the types to be unified is a user-defined type, - -- unfold its definition before continuing the matching - unifyOne' p@(TyCon (CUser t) tys1, ty2) = do - modify (S.insert p) - case M.lookup t tyDefns of - Nothing -> mzero - Just (TyDefBody _ body) -> return $ Right [(body tys1, ty2)] - - unifyOne' p@(ty1, TyCon (CUser t) tys2) = do - modify (S.insert p) - case M.lookup t tyDefns of - Nothing -> mzero - Just (TyDefBody _ body) -> return $ Right [(ty1, body tys2)] - - -- Unify any other pair of type constructor applications: the type - -- constructors must match exactly - unifyOne' p@(TyCon c1 tys1, TyCon c2 tys2) - | c1 == c2 = do - modify (S.insert p) - return $ Right (zip tys1 tys2) - | otherwise = mzero - unifyOne' (TyAtom (ABase b1), TyAtom (ABase b2)) - | baseEq b1 b2 = return $ Left idS - | otherwise = mzero - unifyOne' _ = mzero -- Atom = Cons - + where + go :: [(Type, Type)] -> StateT (Set (Type, Type)) Maybe S + go [] = return idS + go (e : es) = do + u <- unifyOne e + case u of + Left sub -> (@@ sub) <$> go (applySubst sub es) + Right newEs -> go (newEs ++ es) + + unifyOne :: (Type, Type) -> StateT (Set (Type, Type)) Maybe (Either S [(Type, Type)]) + unifyOne pair = do + seen <- get + case pair `S.member` seen of + True -> return $ Left idS + False -> unifyOne' pair + + unifyOne' :: (Type, Type) -> StateT (Set (Type, Type)) Maybe (Either S [(Type, Type)]) + + unifyOne' (ty1, ty2) + | ty1 == ty2 = return $ Left idS + unifyOne' (TyVar x, ty2) + | occurs x ty2 = mzero + | otherwise = return $ Left (x |-> ty2) + unifyOne' (ty1, x@(TyVar _)) = + unifyOne (x, ty1) + -- At this point we know ty2 isn't the same skolem nor a unification variable. + -- Skolems don't unify with anything. + unifyOne' (TySkolem _, _) = mzero + unifyOne' (_, TySkolem _) = mzero + -- Unify two container types: unify the container descriptors as + -- well as the type arguments + unifyOne' p@(TyCon (CContainer ctr1) tys1, TyCon (CContainer ctr2) tys2) = do + modify (S.insert p) + return $ Right ((TyAtom ctr1, TyAtom ctr2) : zip tys1 tys2) + + -- If one of the types to be unified is a user-defined type, + -- unfold its definition before continuing the matching + unifyOne' p@(TyCon (CUser t) tys1, ty2) = do + modify (S.insert p) + case M.lookup t tyDefns of + Nothing -> mzero + Just (TyDefBody _ body) -> return $ Right [(body tys1, ty2)] + unifyOne' p@(ty1, TyCon (CUser t) tys2) = do + modify (S.insert p) + case M.lookup t tyDefns of + Nothing -> mzero + Just (TyDefBody _ body) -> return $ Right [(ty1, body tys2)] + + -- Unify any other pair of type constructor applications: the type + -- constructors must match exactly + unifyOne' p@(TyCon c1 tys1, TyCon c2 tys2) + | c1 == c2 = do + modify (S.insert p) + return $ Right (zip tys1 tys2) + | otherwise = mzero + unifyOne' (TyAtom (ABase b1), TyAtom (ABase b2)) + | baseEq b1 b2 = return $ Left idS + | otherwise = mzero + unifyOne' _ = mzero -- Atom = Cons equate :: TyDefCtx -> [Type] -> Maybe S equate tyDefns tys = unify tyDefns eqns - where - eqns = zip tys (tail tys) + where + eqns = zip tys (tail tys) occurs :: Name Type -> Type -> Bool -occurs x = anyOf fv (==x) - +occurs x = anyOf fv (== x) unifyAtoms :: TyDefCtx -> [Atom] -> Maybe (Substitution Atom) unifyAtoms tyDefns = fmap (fmap fromTyAtom) . equate tyDefns . map TyAtom - where - fromTyAtom (TyAtom a) = a - fromTyAtom _ = error "fromTyAtom on non-TyAtom!" + where + fromTyAtom (TyAtom a) = a + fromTyAtom _ = error "fromTyAtom on non-TyAtom!" unifyUAtoms :: TyDefCtx -> [UAtom] -> Maybe (Substitution UAtom) unifyUAtoms tyDefns = fmap (fmap fromTyAtom) . equate tyDefns . map (TyAtom . uatomToAtom) - where - fromTyAtom (TyAtom (ABase b)) = UB b - fromTyAtom (TyAtom (AVar (U v))) = UV v - fromTyAtom _ = error "fromTyAtom on wrong thing!" + where + fromTyAtom (TyAtom (ABase b)) = UB b + fromTyAtom (TyAtom (AVar (U v))) = UV v + fromTyAtom _ = error "fromTyAtom on wrong thing!" diff --git a/src/Disco/Typecheck/Util.hs b/src/Disco/Typecheck/Util.hs index 468b26a2..1c70dfa0 100644 --- a/src/Disco/Typecheck/Util.hs +++ b/src/Disco/Typecheck/Util.hs @@ -1,5 +1,7 @@ +----------------------------------------------------------------------------- ----------------------------------------------------------------------------- + -- | -- Module : Disco.Typecheck.Util -- Copyright : (c) 2016 disco team (see LICENSE) @@ -8,30 +10,27 @@ -- -- Definition of type contexts, type errors, and various utilities -- used during type checking. --- ------------------------------------------------------------------------------ - module Disco.Typecheck.Util where -import Disco.Effects.Fresh -import Polysemy -import Polysemy.Error -import Polysemy.Output -import Polysemy.Reader -import Polysemy.Writer -import Unbound.Generics.LocallyNameless (Name, bind, string2Name) - -import qualified Data.Map as M -import Data.Tuple (swap) -import Prelude hiding (lookup) - -import Disco.AST.Surface -import Disco.Context -import Disco.Messages -import Disco.Names (ModuleName, QName) -import Disco.Typecheck.Constraints -import Disco.Typecheck.Solve -import Disco.Types +import Disco.Effects.Fresh +import Polysemy +import Polysemy.Error +import Polysemy.Output +import Polysemy.Reader +import Polysemy.Writer +import Unbound.Generics.LocallyNameless (Name, bind, string2Name) + +import qualified Data.Map as M +import Data.Tuple (swap) +import Prelude hiding (lookup) + +import Disco.AST.Surface +import Disco.Context +import Disco.Messages +import Disco.Names (ModuleName, QName) +import Disco.Typecheck.Constraints +import Disco.Typecheck.Solve +import Disco.Types ------------------------------------------------------------ -- Contexts @@ -47,7 +46,7 @@ type TyCtx = Ctx Term PolyType -- | A typechecking error, wrapped up together with the name of the -- thing that was being checked when the error occurred. data LocTCError = LocTCError (Maybe (QName Term)) TCError - deriving Show + deriving (Show) -- | Wrap a @TCError@ into a @LocTCError@ with no explicit provenance -- information. @@ -56,39 +55,60 @@ noLoc = LocTCError Nothing -- | Potential typechecking errors. data TCError - = Unbound (Name Term) -- ^ Encountered an unbound variable - | Ambiguous (Name Term) [ModuleName] -- ^ Encountered an ambiguous name. - | NoType (Name Term) -- ^ No type is specified for a definition - | NotCon Con Term Type -- ^ The type of the term should have an - -- outermost constructor matching Con, but - -- it has type 'Type' instead - | EmptyCase -- ^ Case analyses cannot be empty. - | PatternType Con Pattern Type -- ^ The given pattern should have the type, but it doesn't. - -- instead it has a kind of type given by the Con. - | DuplicateDecls (Name Term) -- ^ Duplicate declarations. - | DuplicateDefns (Name Term) -- ^ Duplicate definitions. - | DuplicateTyDefns String -- ^ Duplicate type definitions. - | CyclicTyDef String -- ^ Cyclic type definition. - | NumPatterns -- ^ # of patterns does not match type in definition - | NonlinearPattern Pattern (Name Term) -- ^ Duplicate variable in a pattern - | NoSearch Type -- ^ Type can't be quantified over. - | Unsolvable SolveError -- ^ The constraint solver couldn't find a solution. - | NotTyDef String -- ^ An undefined type name was used. - | NoTWild -- ^ Wildcards are not allowed in terms. - | NotEnoughArgs Con -- ^ Not enough arguments provided to type constructor. - | TooManyArgs Con -- ^ Too many arguments provided to type constructor. - | UnboundTyVar (Name Type) -- ^ Unbound type variable - | NoPolyRec String [String] [Type] -- ^ Polymorphic recursion is not allowed - | NoError -- ^ Not an error. The identity of the - -- @Monoid TCError@ instance. - deriving Show + = -- | Encountered an unbound variable + Unbound (Name Term) + | -- | Encountered an ambiguous name. + Ambiguous (Name Term) [ModuleName] + | -- | No type is specified for a definition + NoType (Name Term) + | -- | The type of the term should have an + -- outermost constructor matching Con, but + -- it has type 'Type' instead + NotCon Con Term Type + | -- | Case analyses cannot be empty. + EmptyCase + | -- | The given pattern should have the type, but it doesn't. + -- instead it has a kind of type given by the Con. + PatternType Con Pattern Type + | -- | Duplicate declarations. + DuplicateDecls (Name Term) + | -- | Duplicate definitions. + DuplicateDefns (Name Term) + | -- | Duplicate type definitions. + DuplicateTyDefns String + | -- | Cyclic type definition. + CyclicTyDef String + | -- | # of patterns does not match type in definition + NumPatterns + | -- | Duplicate variable in a pattern + NonlinearPattern Pattern (Name Term) + | -- | Type can't be quantified over. + NoSearch Type + | -- | The constraint solver couldn't find a solution. + Unsolvable SolveError + | -- | An undefined type name was used. + NotTyDef String + | -- | Wildcards are not allowed in terms. + NoTWild + | -- | Not enough arguments provided to type constructor. + NotEnoughArgs Con + | -- | Too many arguments provided to type constructor. + TooManyArgs Con + | -- | Unbound type variable + UnboundTyVar (Name Type) + | -- | Polymorphic recursion is not allowed + NoPolyRec String [String] [Type] + | -- | Not an error. The identity of the + -- @Monoid TCError@ instance. + NoError + deriving (Show) instance Semigroup TCError where _ <> r = r -- | 'TCError' is a monoid where we simply discard the first error. instance Monoid TCError where - mempty = NoError + mempty = NoError mappend = (<>) ------------------------------------------------------------ @@ -127,14 +147,15 @@ withConstraint = fmap swap . runWriter -- | Run a computation and solve its generated constraint, returning -- the resulting substitution (or failing with an error). Note that -- this locally dispatches the constraint writer effect. -solve - :: Members '[Reader TyDefCtx, Error TCError, Output Message] r - => Sem (Writer Constraint ': r) a -> Sem r (a, S) +solve :: + Members '[Reader TyDefCtx, Error TCError, Output Message] r => + Sem (Writer Constraint ': r) a -> + Sem r (a, S) solve m = do (a, c) <- withConstraint m res <- runSolve . inputToReader . solveConstraint $ c case res of - Left e -> throw (Unsolvable e) + Left e -> throw (Unsolvable e) Right s -> return (a, s) ------------------------------------------------------------ @@ -144,12 +165,14 @@ solve m = do -- | Look up the definition of a named type. Throw a 'NotTyDef' error -- if it is not found. lookupTyDefn :: - Members '[Reader TyDefCtx, Error TCError] r - => String -> [Type] -> Sem r Type + Members '[Reader TyDefCtx, Error TCError] r => + String -> + [Type] -> + Sem r Type lookupTyDefn x args = do d <- ask @TyDefCtx case M.lookup x d of - Nothing -> throw (NotTyDef x) + Nothing -> throw (NotTyDef x) Just (TyDefBody _ body) -> return $ body args -- | Run a subcomputation with an extended type definition context. diff --git a/src/Disco/Types.hs b/src/Disco/Types.hs index d846f140..a2af6843 100644 --- a/src/Disco/Types.hs +++ b/src/Disco/Types.hs @@ -1,12 +1,16 @@ -{-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveDataTypeable #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE UndecidableInstances #-} - {-# OPTIONS_GHC -fno-warn-orphans #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + +-- SPDX-License-Identifier: BSD-3-Clause + -- | -- Module : Disco.Types -- Copyright : disco team and contributors @@ -14,104 +18,111 @@ -- -- The "Disco.Types" module defines the set of types used in the disco -- language type system, along with various utility functions. --- ------------------------------------------------------------------------------ - --- SPDX-License-Identifier: BSD-3-Clause - -module Disco.Types - ( - -- * Disco language types - -- ** Atomic types - - BaseTy(..), isCtr, Var(..), Ilk(..), pattern U, pattern S - , Atom(..) - , isVar, isBase, isSkolem - , UAtom(..), uisVar, uatomToAtom, uatomToEither - - -- ** Type constructors - - , Con(..) - , pattern CList, pattern CBag, pattern CSet - - -- ** Types - - , Type(..) - - , pattern TyVar - , pattern TySkolem - , pattern TyVoid - , pattern TyUnit - , pattern TyBool - , pattern TyProp - , pattern TyN - , pattern TyZ - , pattern TyF - , pattern TyQ - , pattern TyC - -- , pattern TyFin - , pattern (:->:) - , pattern (:*:) - , pattern (:+:) - , pattern TyList - , pattern TyBag - , pattern TySet - , pattern TyGraph - , pattern TyMap - , pattern TyContainer - , pattern TyUser - , pattern TyString - - -- ** Quantified types - - , PolyType(..) - , toPolyType, closeType - - -- * Type predicates - - , isNumTy, isEmptyTy, isFiniteTy, isSearchable - - -- * Type substitutions - - , Substitution, atomToTypeSubst, uatomToTypeSubst - - -- * Strictness - , Strictness(..), strictness - - -- * Utilities - , isTyVar - , containerVars - , countType - , unpair - , S - , TyDefBody(..) - , TyDefCtx - - -- * HasType class - , HasType(..) - ) - where - -import Data.Coerce -import Data.Data (Data) -import Disco.Data () -import GHC.Generics (Generic) -import Unbound.Generics.LocallyNameless hiding (lunbind) - -import Control.Lens (toListOf) -import Data.List (nub) -import Data.Map (Map) -import qualified Data.Map as M -import Data.Set (Set) -import qualified Data.Set as S -import Data.Void -import Math.Combinatorics.Exact.Binomial (choose) - -import Disco.Effects.LFresh - -import Disco.Pretty hiding ((<>)) -import Disco.Subst (Substitution) -import Disco.Types.Qualifiers +module Disco.Types ( + -- * Disco language types + + -- ** Atomic types + BaseTy (..), + isCtr, + Var (..), + Ilk (..), + pattern U, + pattern S, + Atom (..), + isVar, + isBase, + isSkolem, + UAtom (..), + uisVar, + uatomToAtom, + uatomToEither, + + -- ** Type constructors + Con (..), + pattern CList, + pattern CBag, + pattern CSet, + + -- ** Types + Type (..), + pattern TyVar, + pattern TySkolem, + pattern TyVoid, + pattern TyUnit, + pattern TyBool, + pattern TyProp, + pattern TyN, + pattern TyZ, + pattern TyF, + pattern TyQ, + pattern TyC, + -- , pattern TyFin + pattern (:->:), + pattern (:*:), + pattern (:+:), + pattern TyList, + pattern TyBag, + pattern TySet, + pattern TyGraph, + pattern TyMap, + pattern TyContainer, + pattern TyUser, + pattern TyString, + + -- ** Quantified types + PolyType (..), + toPolyType, + closeType, + + -- * Type predicates + isNumTy, + isEmptyTy, + isFiniteTy, + isSearchable, + + -- * Type substitutions + Substitution, + atomToTypeSubst, + uatomToTypeSubst, + + -- * Strictness + Strictness (..), + strictness, + + -- * Utilities + isTyVar, + containerVars, + countType, + unpair, + S, + TyDefBody (..), + TyDefCtx, + + -- * HasType class + HasType (..), +) +where + +import Data.Coerce +import Data.Data (Data) +import Disco.Data () +import GHC.Generics (Generic) +import Unbound.Generics.LocallyNameless hiding (lunbind) + +import Control.Lens (toListOf) +import Data.List (nub) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Set (Set) +import qualified Data.Set as S +import Data.Void +import Math.Combinatorics.Exact.Binomial (choose) + +import Disco.Effects.LFresh + +import Disco.Pretty hiding ((<>)) +import Disco.Subst (Substitution) +import Disco.Types.Qualifiers -------------------------------------------------- -- Disco types @@ -123,34 +134,24 @@ import Disco.Types.Qualifiers -- | Base types are the built-in types which form the basis of the -- disco type system, out of which more complex types can be built. data BaseTy where - -- | The void type, with no inhabitants. Void :: BaseTy - -- | The unit type, with one inhabitant. Unit :: BaseTy - -- | Booleans. - B :: BaseTy - + B :: BaseTy -- | Propositions. - P :: BaseTy - + P :: BaseTy -- | Natural numbers. - N :: BaseTy - + N :: BaseTy -- | Integers. - Z :: BaseTy - + Z :: BaseTy -- | Fractionals (i.e. nonnegative rationals). - F :: BaseTy - + F :: BaseTy -- | Rationals. - Q :: BaseTy - + Q :: BaseTy -- | Unicode characters. - C :: BaseTy - + C :: BaseTy -- Finite types. The single argument is a natural number defining -- the exact number of inhabitants. -- Fin :: Integer -> BaseTy @@ -161,29 +162,26 @@ data BaseTy where -- particular this allows us to reuse all the existing constraint -- solving machinery for container subtyping. CtrSet :: BaseTy - -- | Bag container type. CtrBag :: BaseTy - -- | List container type. CtrList :: BaseTy - deriving (Show, Eq, Ord, Generic, Data, Alpha, Subst BaseTy, Subst Atom, Subst UAtom, Subst Type) instance Pretty BaseTy where pretty = \case - Void -> text "Void" - Unit -> text "Unit" - B -> text "Bool" - P -> text "Prop" - N -> text "ℕ" - Z -> text "ℤ" - Q -> text "ℚ" - F -> text "𝔽" - C -> text "Char" + Void -> text "Void" + Unit -> text "Unit" + B -> text "Bool" + P -> text "Prop" + N -> text "ℕ" + Z -> text "ℤ" + Q -> text "ℚ" + F -> text "𝔽" + C -> text "Char" CtrList -> text "List" - CtrBag -> text "Bag" - CtrSet -> text "Set" + CtrBag -> text "Bag" + CtrSet -> text "Set" -- | Test whether a 'BaseTy' is a container (set, bag, or list). isCtr :: BaseTy -> Bool @@ -218,7 +216,7 @@ data Ilk = Skolem | Unification instance Pretty Ilk where pretty = \case - Skolem -> "S" + Skolem -> "S" Unification -> "U" -- | 'Var' represents /type variables/, that is, variables which stand @@ -247,24 +245,24 @@ pattern S v = V Skolem v -- simplification step, we want to be able to work with collections -- of constraints that are guaranteed to contain only atomic types. data Atom where - AVar :: Var -> Atom + AVar :: Var -> Atom ABase :: BaseTy -> Atom deriving (Show, Eq, Ord, Generic, Data, Alpha, Subst Type) instance Subst Atom Atom where isvar (AVar (U x)) = Just (SubstName (coerce x)) - isvar _ = Nothing + isvar _ = Nothing instance Pretty Atom where pretty = \case AVar (U v) -> pretty v AVar (S v) -> text "$" <> pretty v - ABase b -> pretty b + ABase b -> pretty b -- | Is this atomic type a variable? isVar :: Atom -> Bool isVar (AVar _) = True -isVar _ = False +isVar _ = False -- | Is this atomic type a base type? isBase :: Atom -> Bool @@ -273,7 +271,7 @@ isBase = not . isVar -- | Is this atomic type a skolem variable? isSkolem :: Atom -> Bool isSkolem (AVar (S _)) = True -isSkolem _ = False +isSkolem _ = False -- | /Unifiable/ atomic types are the same as atomic types but without -- skolem variables. Hence, a unifiable atomic type is either a base @@ -287,13 +285,13 @@ isSkolem _ = False -- these things, we can just focus on base types and unification -- variables. data UAtom where - UB :: BaseTy -> UAtom + UB :: BaseTy -> UAtom UV :: Name Type -> UAtom deriving (Show, Eq, Ord, Generic, Alpha, Subst BaseTy) instance Subst UAtom UAtom where isvar (UV x) = Just (SubstName (coerce x)) - isvar _ = Nothing + isvar _ = Nothing instance Pretty UAtom where pretty (UB b) = pretty b @@ -302,7 +300,7 @@ instance Pretty UAtom where -- | Is this unifiable atomic type a (unification) variable? uisVar :: UAtom -> Bool uisVar (UV _) = True -uisVar _ = False +uisVar _ = False -- | Convert a unifiable atomic type into a regular atomic type. uatomToAtom :: UAtom -> Atom @@ -322,12 +320,11 @@ uatomToEither (UV v) = Right v -- argument types. data Con where -- | Function type constructor, @T1 -> T2@. - CArr :: Con + CArr :: Con -- | Product type constructor, @T1 * T2@. CProd :: Con -- | Sum type constructor, @T1 + T2@. - CSum :: Con - + CSum :: Con -- | Container type (list, bag, or set) constructor. Note this -- looks like it could contain any 'Atom', but it will only ever -- contain either a type variable or a 'CtrList', 'CtrBag', or @@ -335,29 +332,24 @@ data Con where -- -- See also 'CList', 'CBag', and 'CSet'. CContainer :: Atom -> Con - - -- | Key value maps, Map k v CMap :: Con - -- | Graph constructor, Graph a CGraph :: Con - -- | The name of a user defined algebraic datatype. CUser :: String -> Con - deriving (Show, Eq, Ord, Generic, Data, Alpha) instance Pretty Con where pretty = \case - CMap -> text "Map" - CGraph -> text "Graph" - CUser s -> text s - CList -> text "List" - CBag -> text "Bag" - CSet -> text "Set" + CMap -> text "Map" + CGraph -> text "Graph" + CUser s -> text s + CList -> text "List" + CBag -> text "Bag" + CSet -> text "Set" CContainer v -> pretty v - c -> error $ "Impossible: got Con " ++ show c ++ " in pretty @Con" + c -> error $ "Impossible: got Con " ++ show c ++ " in pretty @Con" -- | 'CList' is provided for convenience; it represents a list type -- constructor (/i.e./ @List a@). @@ -395,49 +387,49 @@ pattern CSet = CContainer (ABase CtrSet) -- pattern-match on types when convenient. For example, using these -- synonyms the foregoing example can be written @TyN :->: TyVar a@. data Type where - -- | Atomic types (variables and base types), /e.g./ @N@, @Bool@, /etc./ TyAtom :: Atom -> Type - -- | Application of a type constructor to type arguments, /e.g./ @N -- -> Bool@ is the application of the arrow type constructor to the -- arguments @N@ and @Bool@. - TyCon :: Con -> [Type] -> Type - + TyCon :: Con -> [Type] -> Type deriving (Show, Eq, Ord, Generic, Data, Alpha) instance Pretty Type where - pretty (TyAtom a) = pretty a - pretty (ty1 :->: ty2) = withPA tarrPA $ - lt (pretty ty1) <+> text "→" <+> rt (pretty ty2) - pretty (ty1 :*: ty2) = withPA tmulPA $ - lt (pretty ty1) <+> text "×" <+> rt (pretty ty2) - pretty (ty1 :+: ty2) = withPA taddPA $ - lt (pretty ty1) <+> text "+" <+> rt (pretty ty2) - pretty (TyCon c []) = pretty c - pretty (TyCon c tys) = do + pretty (TyAtom a) = pretty a + pretty (ty1 :->: ty2) = + withPA tarrPA $ + lt (pretty ty1) <+> text "→" <+> rt (pretty ty2) + pretty (ty1 :*: ty2) = + withPA tmulPA $ + lt (pretty ty1) <+> text "×" <+> rt (pretty ty2) + pretty (ty1 :+: ty2) = + withPA taddPA $ + lt (pretty ty1) <+> text "+" <+> rt (pretty ty2) + pretty (TyCon c []) = pretty c + pretty (TyCon c tys) = do ds <- setPA initPA $ punctuate (text ",") (map pretty tys) pretty c <> parens (hsep ds) instance Subst Type Qualifier instance Subst Type Rational where subst _ _ = id - substs _ = id + substs _ = id instance Subst Type Void where subst _ _ = id - substs _ = id + substs _ = id instance Subst Type Con where - isCoerceVar (CContainer (AVar (U x))) - = Just (SubstCoerce x substCtrTy) - where - substCtrTy (TyAtom a) = Just (CContainer a) - substCtrTy _ = Nothing - isCoerceVar _ = Nothing + isCoerceVar (CContainer (AVar (U x))) = + Just (SubstCoerce x substCtrTy) + where + substCtrTy (TyAtom a) = Just (CContainer a) + substCtrTy _ = Nothing + isCoerceVar _ = Nothing instance Subst Type Type where isvar (TyAtom (AVar (U x))) = Just (SubstName x) - isvar _ = Nothing + isvar _ = Nothing -pattern TyVar :: Name Type -> Type +pattern TyVar :: Name Type -> Type pattern TyVar v = TyAtom (AVar (U v)) pattern TySkolem :: Name Type -> Type @@ -470,7 +462,6 @@ pattern TyQ = TyAtom (ABase Q) pattern TyC :: Type pattern TyC = TyAtom (ABase C) - -- pattern TyFin :: Integer -> Type -- pattern TyFin n = TyAtom (ABase (Fin n)) @@ -492,7 +483,7 @@ pattern (:+:) ty1 ty2 = TyCon CSum [ty1, ty2] pattern TyList :: Type -> Type pattern TyList elTy = TyCon CList [elTy] -pattern TyBag :: Type -> Type +pattern TyBag :: Type -> Type pattern TyBag elTy = TyCon CBag [elTy] pattern TySet :: Type -> Type @@ -515,21 +506,40 @@ pattern TyString :: Type pattern TyString = TyList TyC {-# COMPLETE - TyVar, TySkolem, TyVoid, TyUnit, TyBool, TyProp, TyN, TyZ, TyF, TyQ, TyC, - (:->:), (:*:), (:+:), TyList, TyBag, TySet, TyGraph, TyMap, TyUser #-} + TyVar + , TySkolem + , TyVoid + , TyUnit + , TyBool + , TyProp + , TyN + , TyZ + , TyF + , TyQ + , TyC + , (:->:) + , (:*:) + , (:+:) + , TyList + , TyBag + , TySet + , TyGraph + , TyMap + , TyUser + #-} -- | Is this a type variable? isTyVar :: Type -> Bool isTyVar (TyAtom (AVar _)) = True -isTyVar _ = False +isTyVar _ = False -- orphans instance (Ord a, Subst t a) => Subst t (Set a) where subst x t = S.map (subst x t) - substs s = S.map (substs s) + substs s = S.map (substs s) instance (Ord k, Subst t a) => Subst t (Map k a) where subst x t = M.map (subst x t) - substs s = M.map (substs s) + substs s = M.map (substs s) -- | The definition of a user-defined type contains: -- @@ -555,14 +565,13 @@ type TyDefCtx = M.Map String TyDefBody -- | Pretty-print a type definition. instance Pretty (String, TyDefBody) where - - pretty (tyName, TyDefBody ps body) - = "type" <+> (text tyName <> prettyArgs ps) <+> text "=" <+> pretty (body (map (TyVar . string2Name) ps)) - where - prettyArgs [] = empty - prettyArgs _ = do - ds <- punctuate (text ",") (map text ps) - parens (hsep ds) + pretty (tyName, TyDefBody ps body) = + "type" <+> (text tyName <> prettyArgs ps) <+> text "=" <+> pretty (body (map (TyVar . string2Name) ps)) + where + prettyArgs [] = empty + prettyArgs _ = do + ds <- punctuate (text ",") (map text ps) + parens (hsep ds) --------------------------------- -- Universally quantified types @@ -597,52 +606,50 @@ closeType ty = Forall (bind (nub $ toListOf fv ty) ty) -- | Compute the number of inhabitants of a type. @Nothing@ means the -- type is countably infinite. countType :: Type -> Maybe Integer -countType TyVoid = Just 0 -countType TyUnit = Just 1 -countType TyBool = Just 2 +countType TyVoid = Just 0 +countType TyUnit = Just 1 +countType TyBool = Just 2 -- countType (TyFin n) = Just n -countType TyC = Just (17 * 2^(16 :: Integer)) +countType TyC = Just (17 * 2 ^ (16 :: Integer)) countType (ty1 :+: ty2) = (+) <$> countType ty1 <*> countType ty2 countType (ty1 :*: ty2) - | isEmptyTy ty1 = Just 0 - | isEmptyTy ty2 = Just 0 - | otherwise = (*) <$> countType ty1 <*> countType ty2 + | isEmptyTy ty1 = Just 0 + | isEmptyTy ty2 = Just 0 + | otherwise = (*) <$> countType ty1 <*> countType ty2 countType (ty1 :->: ty2) = case (countType ty1, countType ty2) of (Just 0, _) -> Just 1 (_, Just 0) -> Just 0 (_, Just 1) -> Just 1 - (c1, c2) -> (^) <$> c2 <*> c1 + (c1, c2) -> (^) <$> c2 <*> c1 countType (TyList ty) - | isEmptyTy ty = Just 1 - | otherwise = Nothing + | isEmptyTy ty = Just 1 + | otherwise = Nothing countType (TyBag ty) - | isEmptyTy ty = Just 1 - | otherwise = Nothing -countType (TySet ty) = (2^) <$> countType ty - - -- t = number of elements in vertex type. - -- n = number of vertices in the graph. - -- For each n in [0..t], we can choose which n values to use for the - -- vertices; then for each ordered pair of vertices (u,v) - -- (including the possibility that u = v), we choose whether or - -- not there is a directed edge u -> v. - -- - -- https://oeis.org/A135748 - -countType (TyGraph ty) = - (\t -> sum $ map (\n -> (t `choose` n) * 2^(n*n)) [0 .. t]) <$> - countType ty + | isEmptyTy ty = Just 1 + | otherwise = Nothing +countType (TySet ty) = (2 ^) <$> countType ty +-- t = number of elements in vertex type. +-- n = number of vertices in the graph. +-- For each n in [0..t], we can choose which n values to use for the +-- vertices; then for each ordered pair of vertices (u,v) +-- (including the possibility that u = v), we choose whether or +-- not there is a directed edge u -> v. +-- +-- https://oeis.org/A135748 +countType (TyGraph ty) = + (\t -> sum $ map (\n -> (t `choose` n) * 2 ^ (n * n)) [0 .. t]) + <$> countType ty countType (TyMap tyKey tyValue) - | isEmptyTy tyKey = Just 1 -- If we can't have any keys or values, - | isEmptyTy tyValue = Just 1 -- only option is empty map - | otherwise = (\k v -> (v+1) ^ k) <$> countType tyKey <*> countType tyValue - -- (v+1)^k since for each key, we can choose among v values to associate with it, - -- or we can choose to not have the key in the map. + | isEmptyTy tyKey = Just 1 -- If we can't have any keys or values, + | isEmptyTy tyValue = Just 1 -- only option is empty map + | otherwise = (\k v -> (v + 1) ^ k) <$> countType tyKey <*> countType tyValue +-- (v+1)^k since for each key, we can choose among v values to associate with it, +-- or we can choose to not have the key in the map. -- All other types are infinite. (TyN, TyZ, TyQ, TyF) -countType _ = Nothing +countType _ = Nothing -------------------------------------------------- -- Type predicates @@ -651,19 +658,19 @@ countType _ = Nothing -- | Check whether a type is a numeric type (@N@, @Z@, @F@, @Q@, or @Zn@). isNumTy :: Type -> Bool -- isNumTy (TyFin _) = True -isNumTy ty = ty `elem` [TyN, TyZ, TyF, TyQ] +isNumTy ty = ty `elem` [TyN, TyZ, TyF, TyQ] -- | Decide whether a type is empty, /i.e./ uninhabited. isEmptyTy :: Type -> Bool isEmptyTy ty | Just 0 <- countType ty = True - | otherwise = False + | otherwise = False -- | Decide whether a type is finite. isFiniteTy :: Type -> Bool isFiniteTy ty | Just _ <- countType ty = True - | otherwise = False + | otherwise = False -- XXX coinductively check whether user-defined types are searchable -- e.g. L = Unit + N * L ought to be searchable. @@ -671,16 +678,16 @@ isFiniteTy ty -- | Decide whether a type is searchable, i.e. effectively enumerable. isSearchable :: Type -> Bool -isSearchable TyProp = False +isSearchable TyProp = False isSearchable ty - | isNumTy ty = True - | isFiniteTy ty = True -isSearchable (TyList ty) = isSearchable ty -isSearchable (TySet ty) = isSearchable ty -isSearchable (ty1 :+: ty2) = isSearchable ty1 && isSearchable ty2 -isSearchable (ty1 :*: ty2) = isSearchable ty1 && isSearchable ty2 + | isNumTy ty = True + | isFiniteTy ty = True +isSearchable (TyList ty) = isSearchable ty +isSearchable (TySet ty) = isSearchable ty +isSearchable (ty1 :+: ty2) = isSearchable ty1 && isSearchable ty2 +isSearchable (ty1 :*: ty2) = isSearchable ty1 && isSearchable ty2 isSearchable (ty1 :->: ty2) = isFiniteTy ty1 && isSearchable ty2 -isSearchable _ = False +isSearchable _ = False -------------------------------------------------- -- Strictness @@ -695,7 +702,7 @@ data Strictness = Strict | Lazy strictness :: Type -> Strictness strictness ty | isNumTy ty = Strict - | otherwise = Lazy + | otherwise = Lazy -------------------------------------------------- -- Utilities @@ -705,7 +712,7 @@ strictness ty -- types. unpair :: Type -> [Type] unpair (ty1 :*: ty2) = ty1 : unpair ty2 -unpair ty = [ty] +unpair ty = [ty] -- | Define @S@ as a substitution on types (the most common kind) -- for convenience. @@ -722,8 +729,8 @@ uatomToTypeSubst = atomToTypeSubst . fmap uatomToAtom -- | Return a set of all the free container variables in a type. containerVars :: Type -> Set (Name Type) -containerVars (TyCon (CContainer (AVar (U x))) tys) - = x `S.insert` foldMap containerVars tys +containerVars (TyCon (CContainer (AVar (U x))) tys) = + x `S.insert` foldMap containerVars tys containerVars (TyCon _ tys) = foldMap containerVars tys containerVars _ = S.empty @@ -733,7 +740,6 @@ containerVars _ = S.empty -- | A type class for things whose type can be extracted or set. class HasType t where - -- | Get the type of a thing. getType :: t -> Type @@ -741,40 +747,3 @@ class HasType t where -- implementation is for 'setType' to do nothing. setType :: Type -> t -> t setType _ = id - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/src/Disco/Types/Qualifiers.hs b/src/Disco/Types/Qualifiers.hs index 0f4a1d07..7e98068f 100644 --- a/src/Disco/Types/Qualifiers.hs +++ b/src/Disco/Types/Qualifiers.hs @@ -1,28 +1,28 @@ -{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE OverloadedStrings #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + +-- SPDX-License-Identifier: BSD-3-Clause + -- | -- Module : Disco.Types.Qualifiers -- Copyright : disco team and contributors -- Maintainer : byorgey@gmail.com -- -- Type qualifiers and sorts. --- ------------------------------------------------------------------------------ - --- SPDX-License-Identifier: BSD-3-Clause - module Disco.Types.Qualifiers where -import GHC.Generics -import Unbound.Generics.LocallyNameless +import GHC.Generics +import Unbound.Generics.LocallyNameless -import Data.Set (Set) -import qualified Data.Set as S +import Data.Set (Set) +import qualified Data.Set as S -import Disco.Pretty -import Disco.Syntax.Operators +import Disco.Pretty +import Disco.Syntax.Operators ------------------------------------------------------------ -- Qualifiers @@ -44,25 +44,33 @@ import Disco.Syntax.Operators -- These qualifiers can appear in a 'CQual' constraint; see -- "Disco.Typecheck.Constraint". data Qualifier - = QNum -- ^ Numeric, i.e. a semiring supporting + and * - | QSub -- ^ Subtractive, i.e. supports - - | QDiv -- ^ Divisive, i.e. supports / - | QCmp -- ^ Comparable, i.e. supports decidable ordering/comparison (see Note [QCmp]) - | QEnum -- ^ Enumerable, i.e. supports ellipsis notation [x .. y] - | QBool -- ^ Boolean, i.e. supports and, or, not (Bool or Prop) - | QBasic -- ^ Things that do not involve Prop. - | QSimple -- ^ Things for which we can derive a *Haskell* Ord instance + = -- | Numeric, i.e. a semiring supporting + and * + QNum + | -- | Subtractive, i.e. supports - + QSub + | -- | Divisive, i.e. supports / + QDiv + | -- | Comparable, i.e. supports decidable ordering/comparison (see Note [QCmp]) + QCmp + | -- | Enumerable, i.e. supports ellipsis notation [x .. y] + QEnum + | -- | Boolean, i.e. supports and, or, not (Bool or Prop) + QBool + | -- | Things that do not involve Prop. + QBasic + | -- | Things for which we can derive a *Haskell* Ord instance + QSimple deriving (Show, Eq, Ord, Generic, Alpha) instance Pretty Qualifier where pretty = \case - QNum -> "num" - QSub -> "sub" - QDiv -> "div" - QCmp -> "cmp" - QEnum -> "enum" - QBool -> "bool" - QBasic -> "basic" + QNum -> "num" + QSub -> "sub" + QDiv -> "div" + QCmp -> "cmp" + QEnum -> "enum" + QBool -> "bool" + QBasic -> "basic" QSimple -> "simple" -- ~~~~ Note [QCmp] @@ -71,7 +79,9 @@ instance Pretty Qualifier where -- comparisons at runtime any more, if we disallow functions from -- being QCmp. With the switch to eager semantics + disallowing -- function comparison, it's now the case that QCmp should mean --- *decidable* (terminating) comparison. + +-- * decidable* (terminating) comparison. + -- -- It used to be the case that every type in disco supported -- (semi-decidable) linear ordering, so in one sense the QCmp @@ -93,16 +103,16 @@ instance Pretty Qualifier where -- | A helper function that returns the appropriate qualifier for a -- binary arithmetic operation. bopQual :: BOp -> Qualifier -bopQual Add = QNum -bopQual Mul = QNum -bopQual Div = QDiv -bopQual Sub = QSub +bopQual Add = QNum +bopQual Mul = QNum +bopQual Div = QDiv +bopQual Sub = QSub bopQual SSub = QNum -bopQual And = QBool -bopQual Or = QBool +bopQual And = QBool +bopQual Or = QBool bopQual Impl = QBool -bopQual Iff = QBool -bopQual _ = error "No qualifier for binary operation" +bopQual Iff = QBool +bopQual _ = error "No qualifier for binary operation" ------------------------------------------------------------ -- Sorts diff --git a/src/Disco/Types/Rules.hs b/src/Disco/Types/Rules.hs index 97e59f43..27768dcc 100644 --- a/src/Disco/Types/Rules.hs +++ b/src/Disco/Types/Rules.hs @@ -1,4 +1,9 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + +-- SPDX-License-Identifier: BSD-3-Clause + -- | -- Module : Disco.Types.Rules -- Copyright : disco team and contributors @@ -6,45 +11,46 @@ -- -- "Disco.Types.Rules" defines some generic rules about arity, -- subtyping, and sorts for disco base types. --- ------------------------------------------------------------------------------ - --- SPDX-License-Identifier: BSD-3-Clause - -module Disco.Types.Rules - ( -- * Arity - - Variance(..), arity - - -- * Qualifiers - , Qualifier(..), bopQual - - -- * Sorts - , Sort, topSort - - -- * Subtyping rules - - , Dir(..), other - - , isSubA, isSubB, isDirB - , supertypes, subtypes, dirtypes - - -- * Qualifier and sort rules - - , hasQual, hasSort - , qualRules, sortRules - , pickSortBaseTy - ) - where - -import Control.Monad ((>=>)) -import Data.List (foldl') -import Data.Map (Map) -import qualified Data.Map as M -import qualified Data.Set as S - -import Disco.Types -import Disco.Types.Qualifiers +module Disco.Types.Rules ( + -- * Arity + Variance (..), + arity, + + -- * Qualifiers + Qualifier (..), + bopQual, + + -- * Sorts + Sort, + topSort, + + -- * Subtyping rules + Dir (..), + other, + isSubA, + isSubB, + isDirB, + supertypes, + subtypes, + dirtypes, + + -- * Qualifier and sort rules + hasQual, + hasSort, + qualRules, + sortRules, + pickSortBaseTy, +) +where + +import Control.Monad ((>=>)) +import Data.List (foldl') +import Data.Map (Map) +import qualified Data.Map as M +import qualified Data.Set as S + +import Disco.Types +import Disco.Types.Qualifiers ------------------------------------------------------------ -- Arity @@ -65,15 +71,16 @@ data Variance = Co | Contra -- That is, @S1 -> T1 <: S2 -> T2@ (@<:@ means "is a subtype of") if -- and only if @S2 <: S1@ and @T1 <: T2@. arity :: Con -> [Variance] -arity CArr = [Contra, Co] -arity CProd = [Co, Co] -arity CSum = [Co, Co] +arity CArr = [Contra, Co] +arity CProd = [Co, Co] +arity CSum = [Co, Co] arity (CContainer _) = [Co] -arity CMap = [Contra, Co] -arity CGraph = [Co] -arity (CUser _) = error "Impossible! arity CUser" - -- CUsers should always be replaced by their definitions before arity - -- is called. +arity CMap = [Contra, Co] +arity CGraph = [Co] +arity (CUser _) = error "Impossible! arity CUser" + +-- CUsers should always be replaced by their definitions before arity +-- is called. ------------------------------------------------------------ -- Subtyping rules @@ -86,7 +93,7 @@ data Dir = SubTy | SuperTy -- | Swap directions. other :: Dir -> Dir -other SubTy = SuperTy +other SubTy = SuperTy other SuperTy = SubTy -------------------------------------------------- @@ -96,45 +103,45 @@ other SuperTy = SubTy -- @True@ if either they are equal, or if they are base types and -- 'isSubB' returns true. isSubA :: Atom -> Atom -> Bool -isSubA a1 a2 | a1 == a2 = True +isSubA a1 a2 | a1 == a2 = True isSubA (ABase t1) (ABase t2) = isSubB t1 t2 -isSubA _ _ = False +isSubA _ _ = False -- | Check whether one base type is a subtype of another. isSubB :: BaseTy -> BaseTy -> Bool isSubB b1 b2 | b1 == b2 = True -isSubB N Z = True -isSubB N F = True -isSubB N Q = True -isSubB Z Q = True -isSubB F Q = True -isSubB B P = True -isSubB _ _ = False +isSubB N Z = True +isSubB N F = True +isSubB N Q = True +isSubB Z Q = True +isSubB F Q = True +isSubB B P = True +isSubB _ _ = False -- | Check whether one base type is a sub- or supertype of another. isDirB :: Dir -> BaseTy -> BaseTy -> Bool -isDirB SubTy b1 b2 = isSubB b1 b2 +isDirB SubTy b1 b2 = isSubB b1 b2 isDirB SuperTy b1 b2 = isSubB b2 b1 -- | List all the supertypes of a given base type. supertypes :: BaseTy -> [BaseTy] -supertypes N = [N, Z, F, Q] -supertypes Z = [Z, Q] -supertypes F = [F, Q] -supertypes B = [B, P] +supertypes N = [N, Z, F, Q] +supertypes Z = [Z, Q] +supertypes F = [F, Q] +supertypes B = [B, P] supertypes ty = [ty] -- | List all the subtypes of a given base type. subtypes :: BaseTy -> [BaseTy] -subtypes Q = [Q, F, Z, N] -subtypes F = [F, N] -subtypes Z = [Z, N] -subtypes P = [P, B] +subtypes Q = [Q, F, Z, N] +subtypes F = [F, N] +subtypes Z = [Z, N] +subtypes P = [P, B] subtypes ty = [ty] -- | List all the sub- or supertypes of a given base type. dirtypes :: Dir -> BaseTy -> [BaseTy] -dirtypes SubTy = subtypes +dirtypes SubTy = subtypes dirtypes SuperTy = supertypes ------------------------------------------------------------ @@ -143,19 +150,19 @@ dirtypes SuperTy = supertypes -- | Check whether a given base type satisfies a qualifier. hasQual :: BaseTy -> Qualifier -> Bool -hasQual P QCmp = False -- can't compare Props -hasQual _ QCmp = True -hasQual P QBasic = False -hasQual _ QBasic = True -hasQual P QSimple = False -hasQual _ QSimple = True +hasQual P QCmp = False -- can't compare Props +hasQual _ QCmp = True +hasQual P QBasic = False +hasQual _ QBasic = True +hasQual P QSimple = False +hasQual _ QSimple = True -- hasQual (Fin _) q | q `elem` [QNum, QSub, QEnum] = True -- hasQual (Fin n) QDiv = isPrime n -hasQual b QNum = b `elem` [N, Z, F, Q] -hasQual b QSub = b `elem` [Z, Q] -hasQual b QDiv = b `elem` [F, Q] -hasQual b QEnum = b `elem` [N, Z, F, Q, C] -hasQual b QBool = b `elem` [B, P] +hasQual b QNum = b `elem` [N, Z, F, Q] +hasQual b QSub = b `elem` [Z, Q] +hasQual b QDiv = b `elem` [F, Q] +hasQual b QEnum = b `elem` [N, Z, F, Q, C] +hasQual b QBool = b `elem` [B, P] -- | Check whether a base type has a certain sort, which simply -- amounts to whether it satisfies every qualifier in the sort. @@ -177,57 +184,65 @@ hasSort = all . hasQual -- set of qualifiers (i.e. a general sort) on a type argument. In -- that case one would just have to encode 'sortRules' directly. qualRulesMap :: Map Con (Map Qualifier [Maybe Qualifier]) -qualRulesMap = M.fromList - [ CProd ==> M.fromList - [ QCmp ==> [Just QCmp, Just QCmp], - QSimple ==> [Just QSimple, Just QSimple] - ] - , CSum ==> M.fromList - [ QCmp ==> [Just QCmp, Just QCmp], - QSimple ==> [Just QSimple, Just QSimple] - ] - , CList ==> M.fromList - [ QCmp ==> [Just QCmp], - QSimple ==> [Just QSimple] +qualRulesMap = + M.fromList + [ CProd + ==> M.fromList + [ QCmp ==> [Just QCmp, Just QCmp] + , QSimple ==> [Just QSimple, Just QSimple] + ] + , CSum + ==> M.fromList + [ QCmp ==> [Just QCmp, Just QCmp] + , QSimple ==> [Just QSimple, Just QSimple] + ] + , CList + ==> M.fromList + [ QCmp ==> [Just QCmp] + , QSimple ==> [Just QSimple] + ] + , CBag + ==> M.fromList + [ QCmp ==> [Just QCmp] + , QSimple ==> [Just QSimple] + ] + , CSet + ==> M.fromList + [ QCmp ==> [Just QCmp] + , QSimple ==> [Just QSimple] + ] + , CGraph + ==> M.fromList + [ QCmp ==> [Just QCmp] + , QNum ==> [Nothing] + ] + , CMap + ==> M.fromList + [ QCmp ==> [Just QCmp, Just QCmp] + ] ] - , CBag ==> M.fromList - [ QCmp ==> [Just QCmp], - QSimple ==> [Just QSimple] - ] - , CSet ==> M.fromList - [ QCmp ==> [Just QCmp], - QSimple ==> [Just QSimple] - ] - , CGraph ==> M.fromList - [ QCmp ==> [Just QCmp], - QNum ==> [Nothing] - ] - , CMap ==> M.fromList - [ QCmp ==> [Just QCmp, Just QCmp] - ] - ] - where - (==>) :: a -> b -> (a,b) - (==>) = (,) - - -- We could (theoretically) make graphs and maps also be simple values if we require the map's values are also simple. - - -- Eventually we can easily imagine adding an opt-in mode where - -- numeric operations can be used on pairs and functions, then the - -- qualRules would become dependent on what language extension/mode - -- was chosen. For example we could have rules like - -- - -- [ CArr ==> M.fromList - -- [ QNum ==> [Nothing, Just QNum] -- (a -> b) can be +, * iff b can - -- , QSub ==> [Nothing, Just QSub] -- ditto for subtraction - -- , QDiv ==> [Nothing, Just QDiv] -- and division - -- ] - -- , CProd ==> M.fromList - -- [ QNum ==> [Just QNum, Just QNum] -- (a,b) can be +, * iff a and b can - -- , QSub ==> [Just QSub, Just QSub] -- etc. - -- , QDiv ==> [Just QDiv, Just QDiv] - -- ] - -- ] + where + (==>) :: a -> b -> (a, b) + (==>) = (,) + +-- We could (theoretically) make graphs and maps also be simple values if we require the map's values are also simple. + +-- Eventually we can easily imagine adding an opt-in mode where +-- numeric operations can be used on pairs and functions, then the +-- qualRules would become dependent on what language extension/mode +-- was chosen. For example we could have rules like +-- +-- [ CArr ==> M.fromList +-- [ QNum ==> [Nothing, Just QNum] -- (a -> b) can be +, * iff b can +-- , QSub ==> [Nothing, Just QSub] -- ditto for subtraction +-- , QDiv ==> [Nothing, Just QDiv] -- and division +-- ] +-- , CProd ==> M.fromList +-- [ QNum ==> [Just QNum, Just QNum] -- (a,b) can be +, * iff a and b can +-- , QSub ==> [Just QSub, Just QSub] -- etc. +-- , QDiv ==> [Just QDiv, Just QDiv] +-- ] +-- ] -- | Given a constructor T and a qualifier we want to hold of a type T -- t1 t2 ..., return a list of qualifiers that need to hold of t1, @@ -236,7 +251,7 @@ qualRules :: Con -> Qualifier -> Maybe [Maybe Qualifier] -- T t1 t2 ... is basic (contains no Prop) iff t1, t2 ... all are. qualRules c QBasic = Just (map (const (Just QBasic)) (arity c)) -- Otherwise, just look up in the qualRulesMap. -qualRules c q = (M.lookup c >=> M.lookup q) qualRulesMap +qualRules c q = (M.lookup c >=> M.lookup q) qualRulesMap -- | @sortRules T s = [s1, ..., sn]@ means that sort @s@ holds of -- type @(T t1 ... tn)@ if and only if @s1 t1 /\ ... /\ sn tn@. @@ -259,12 +274,12 @@ sortRules c s = do -- | Pick a base type (generally the "simplest") that satisfies a given sort. pickSortBaseTy :: Sort -> BaseTy pickSortBaseTy s - | QDiv `S.member` s && QSub `S.member` s = Q - | QDiv `S.member` s = F - | QSub `S.member` s = Z - | QNum `S.member` s = N - | QCmp `S.member` s = N - | QEnum `S.member` s = N - | QBool `S.member` s = B + | QDiv `S.member` s && QSub `S.member` s = Q + | QDiv `S.member` s = F + | QSub `S.member` s = Z + | QNum `S.member` s = N + | QCmp `S.member` s = N + | QEnum `S.member` s = N + | QBool `S.member` s = B | QSimple `S.member` s = N - | otherwise = Unit + | otherwise = Unit diff --git a/src/Disco/Util.hs b/src/Disco/Util.hs index df6fea02..97681a8a 100644 --- a/src/Disco/Util.hs +++ b/src/Disco/Util.hs @@ -1,4 +1,7 @@ ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Util -- Copyright : disco team and contributors @@ -7,9 +10,6 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Miscellaneous utilities. --- ------------------------------------------------------------------------------ - module Disco.Util where import qualified Data.Map as M @@ -18,7 +18,7 @@ infixr 1 ==> -- | A synonym for pairing which makes convenient syntax for -- constructing literal maps via M.fromList. -(==>) :: a -> b -> (a,b) +(==>) :: a -> b -> (a, b) (==>) = (,) for :: [a] -> (a -> b) -> [b] @@ -27,4 +27,4 @@ for = flip map (!) :: (Show k, Ord k) => M.Map k v -> k -> v m ! k = case M.lookup k m of Nothing -> error $ "key " ++ show k ++ " is not an element in the map" - Just v -> v + Just v -> v diff --git a/src/Disco/Value.hs b/src/Disco/Value.hs index ab6f9ef7..21893e02 100644 --- a/src/Disco/Value.hs +++ b/src/Disco/Value.hs @@ -1,10 +1,13 @@ -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} ----------------------------------------------------------------------------- + +----------------------------------------------------------------------------- + -- | -- Module : Disco.Value -- Copyright : disco team and contributors @@ -13,75 +16,94 @@ -- SPDX-License-Identifier: BSD-3-Clause -- -- Disco runtime values and environments. --- ------------------------------------------------------------------------------ - -module Disco.Value - ( -- * Values - - Value(.., VNil, VCons, VFun) - , SimpleValue(..) - , toSimpleValue, fromSimpleValue - - -- ** Conversion - - , ratv, vrat - , intv, vint - , charv, vchar - , enumv - , pairv, vpair - , listv, vlist - - -- * Props & testing - , ValProp(..), TestResult(..), TestReason_(..), TestReason - , SearchType(..), SearchMotive(.., SMExists, SMForall) - , TestVars(..), TestEnv(..), emptyTestEnv, getTestEnv, extendPropEnv, extendResultEnv - , testIsOk, testIsError, testReason, testEnv, resultIsCertain - - , LOp(..), interpLOp +module Disco.Value ( + -- * Values + Value (.., VNil, VCons, VFun), + SimpleValue (..), + toSimpleValue, + fromSimpleValue, + + -- ** Conversion + ratv, + vrat, + intv, + vint, + charv, + vchar, + enumv, + pairv, + vpair, + listv, + vlist, + + -- * Props & testing + ValProp (..), + TestResult (..), + TestReason_ (..), + TestReason, + SearchType (..), + SearchMotive (.., SMExists, SMForall), + TestVars (..), + TestEnv (..), + emptyTestEnv, + getTestEnv, + extendPropEnv, + extendResultEnv, + testIsOk, + testIsError, + testReason, + testEnv, + resultIsCertain, + LOp (..), + interpLOp, -- * Environments - - , Env + Env, -- * Memory - , Cell(..), Mem, emptyMem, allocate, allocateRec, lkup, set + Cell (..), + Mem, + emptyMem, + allocate, + allocateRec, + lkup, + set, -- * Pretty-printing - - , prettyValue', prettyValue - ) where - -import Prelude hiding ((<>)) -import qualified Prelude as P - -import Control.Monad (forM) -import Data.Bifunctor (first) -import Data.Char (chr, ord, toLower) -import Data.IntMap (IntMap) -import qualified Data.IntMap as IM -import Data.List (foldl') -import Data.Map (Map) -import qualified Data.Map as M -import Data.Ratio - -import Algebra.Graph (Graph, foldg) - -import Disco.AST.Core -import Disco.AST.Generic (Side (..)) -import Disco.Context as Ctx -import Disco.Error -import Disco.Names -import Disco.Pretty -import Disco.Syntax.Operators (BOp (Add, Mul)) -import Disco.Types - -import Disco.Effects.LFresh -import Polysemy -import Polysemy.Input -import Polysemy.Reader -import Polysemy.State -import Unbound.Generics.LocallyNameless (Name) + prettyValue', + prettyValue, +) where + +import Prelude hiding ((<>)) +import qualified Prelude as P + +import Control.Monad (forM) +import Data.Bifunctor (first) +import Data.Char (chr, ord, toLower) +import Data.IntMap (IntMap) +import qualified Data.IntMap as IM +import Data.List (foldl') +import Data.Map (Map) +import qualified Data.Map as M +import Data.Ratio + +import Algebra.Graph (Graph, foldg) + +import Disco.AST.Core +import Disco.AST.Generic (Side (..)) +import Disco.Context as Ctx +import Disco.Error +import Disco.Names +import Disco.Pretty +import Disco.Syntax.Operators (BOp (Add, Mul)) +import Disco.Types + +import Disco.Effects.LFresh +import Polysemy +import Polysemy.Input +import Polysemy.Reader +import Polysemy.State +import Unbound.Generics.LocallyNameless (Name) ------------------------------------------------------------ -- Value type @@ -90,37 +112,28 @@ import Unbound.Generics.LocallyNameless (Name) -- | Different types of values which can result from the evaluation -- process. data Value where - -- | A numeric value, which also carries a flag saying how -- fractional values should be diplayed. - VNum :: RationalDisplay -> Rational -> Value - + VNum :: RationalDisplay -> Rational -> Value -- | A built-in function constant. - VConst :: Op -> Value - + VConst :: Op -> Value -- | An injection into a sum type. - VInj :: Side -> Value -> Value - + VInj :: Side -> Value -> Value -- | The unit value. - VUnit :: Value - + VUnit :: Value -- | A pair of values. - VPair :: Value -> Value -> Value - + VPair :: Value -> Value -> Value -- | A closure, i.e. a function body together with its -- environment. - VClo :: Env -> [Name Core] -> Core -> Value - + VClo :: Env -> [Name Core] -> Core -> Value -- | A disco type can be a value. For now, there are only a very -- limited number of places this could ever show up (in -- particular, as an argument to @enumerate@ or @count@). - VType :: Type -> Value - + VType :: Type -> Value -- | A reference, i.e. a pointer to a memory cell. This is used to -- implement (optional, user-requested) laziness as well as -- recursion. - VRef :: Int -> Value - + VRef :: Int -> Value -- | A literal function value. @VFun@ is only used when -- enumerating function values in order to decide comparisons at -- higher-order function types. For example, in order to @@ -131,30 +144,25 @@ data Value where -- We assume that all @VFun@ values are /strict/, that is, their -- arguments should be fully evaluated to RNF before being -- passed to the function. - VFun_ :: ValFun -> Value - + VFun_ :: ValFun -> Value -- | A proposition. - VProp :: ValProp -> Value - + VProp :: ValProp -> Value -- | A literal bag, containing a finite list of (perhaps only -- partially evaluated) values, each paired with a count. This is -- also used to represent sets (with the invariant that all counts -- are equal to 1). VBag :: [(Value, Integer)] -> Value - -- | A graph, stored using an algebraic repesentation. VGraph :: Graph SimpleValue -> Value - -- | A map from keys to values. Differs from functions because we can -- actually construct the set of entries, while functions only have this -- property when the key type is finite. VMap :: Map SimpleValue Value -> Value - - deriving Show + deriving (Show) -- | Convenient pattern for the empty list. pattern VNil :: Value -pattern VNil = VInj L VUnit +pattern VNil = VInj L VUnit -- | Convenient pattern for list cons. pattern VCons :: Value -> Value -> Value @@ -170,31 +178,31 @@ pattern VCons h t = VInj R (VPair h t) -- only reason for actually doing this would be constructing graphs -- of graphs or maps of maps, or the like. data SimpleValue where - SNum :: RationalDisplay -> Rational -> SimpleValue - SUnit :: SimpleValue - SInj :: Side -> SimpleValue -> SimpleValue - SPair :: SimpleValue -> SimpleValue -> SimpleValue - SBag :: [(SimpleValue, Integer)] -> SimpleValue - SType :: Type -> SimpleValue + SNum :: RationalDisplay -> Rational -> SimpleValue + SUnit :: SimpleValue + SInj :: Side -> SimpleValue -> SimpleValue + SPair :: SimpleValue -> SimpleValue -> SimpleValue + SBag :: [(SimpleValue, Integer)] -> SimpleValue + SType :: Type -> SimpleValue deriving (Show, Eq, Ord) toSimpleValue :: Value -> SimpleValue toSimpleValue = \case - VNum d n -> SNum d n - VUnit -> SUnit - VInj s v1 -> SInj s (toSimpleValue v1) + VNum d n -> SNum d n + VUnit -> SUnit + VInj s v1 -> SInj s (toSimpleValue v1) VPair v1 v2 -> SPair (toSimpleValue v1) (toSimpleValue v2) - VBag bs -> SBag (map (first toSimpleValue) bs) - VType t -> SType t - t -> error $ "A non-simple value was passed as simple: " ++ show t + VBag bs -> SBag (map (first toSimpleValue) bs) + VType t -> SType t + t -> error $ "A non-simple value was passed as simple: " ++ show t fromSimpleValue :: SimpleValue -> Value -fromSimpleValue (SNum d n) = VNum d n -fromSimpleValue SUnit = VUnit -fromSimpleValue (SInj s v) = VInj s (fromSimpleValue v) +fromSimpleValue (SNum d n) = VNum d n +fromSimpleValue SUnit = VUnit +fromSimpleValue (SInj s v) = VInj s (fromSimpleValue v) fromSimpleValue (SPair v1 v2) = VPair (fromSimpleValue v1) (fromSimpleValue v2) -fromSimpleValue (SBag bs) = VBag $ map (first fromSimpleValue) bs -fromSimpleValue (SType t) = VType t +fromSimpleValue (SBag bs) = VBag $ map (first fromSimpleValue) bs +fromSimpleValue (SType t) = VType t -- | A @ValFun@ is just a Haskell function @Value -> Value@. It is a -- @newtype@ just so we can have a custom @Show@ instance for it and @@ -220,7 +228,7 @@ ratv = VNum mempty vrat :: Value -> Rational vrat (VNum _ r) = r -vrat v = error $ "vrat " ++ show v +vrat v = error $ "vrat " ++ show v -- | A convenience function for creating a default @VNum@ value with a -- default (@Fractional@) flag. @@ -229,7 +237,7 @@ intv = ratv . (% 1) vint :: Value -> Integer vint (VNum _ n) = numerator n -vint v = error $ "vint " ++ show v +vint v = error $ "vint " ++ show v vchar :: Value -> Char vchar = chr . fromIntegral . vint @@ -242,34 +250,33 @@ charv = intv . fromIntegral . ord enumv :: Enum e => e -> Value enumv e = VInj (toEnum $ fromEnum e) VUnit -pairv :: (a -> Value) -> (b -> Value) -> (a,b) -> Value -pairv av bv (a,b) = VPair (av a) (bv b) +pairv :: (a -> Value) -> (b -> Value) -> (a, b) -> Value +pairv av bv (a, b) = VPair (av a) (bv b) -vpair :: (Value -> a) -> (Value -> b) -> Value -> (a,b) +vpair :: (Value -> a) -> (Value -> b) -> Value -> (a, b) vpair va vb (VPair a b) = (va a, vb b) -vpair _ _ v = error $ "vpair " ++ show v +vpair _ _ v = error $ "vpair " ++ show v listv :: (a -> Value) -> [a] -> Value -listv _ [] = VNil -listv eltv (a:as) = VCons (eltv a) (listv eltv as) +listv _ [] = VNil +listv eltv (a : as) = VCons (eltv a) (listv eltv as) vlist :: (Value -> a) -> Value -> [a] -vlist _ VNil = [] +vlist _ VNil = [] vlist velt (VCons v vs) = velt v : vlist velt vs -vlist _ v = error $ "vlist " ++ show v - +vlist _ v = error $ "vlist " ++ show v ------------------------------------------------------------ -- Propositions ------------------------------------------------------------ data SearchType - = Exhaustive - -- ^ All possibilities were checked. - | Randomized Integer Integer - -- ^ A number of small cases were checked exhaustively and + = -- | All possibilities were checked. + Exhaustive + | -- | A number of small cases were checked exhaustively and -- then a number of additional cases were checked at random. - deriving Show + Randomized Integer Integer + deriving (Show) -- | The answer (success or failure) we're searching for, and -- the result (success or failure) we return when we find it. @@ -278,7 +285,7 @@ data SearchType -- @(True, True)@ corresponds to "exists". The other values -- arise from negations. newtype SearchMotive = SearchMotive (Bool, Bool) - deriving Show + deriving (Show) pattern SMForall :: SearchMotive pattern SMForall = SearchMotive (False, False) @@ -302,7 +309,7 @@ getTestEnv :: TestVars -> Env -> Either EvalError TestEnv getTestEnv (TestVars tvs) e = fmap TestEnv . forM tvs $ \(s, ty, name) -> do let value = Ctx.lookup' (localName name) e case value of - Just v -> return (s, ty, v) + Just v -> return (s, ty, v) Nothing -> Left (UnboundPanic name) -- | Binary logical operators. @@ -312,43 +319,43 @@ interpLOp :: LOp -> Bool -> Bool -> Bool interpLOp LAnd = (&&) interpLOp LOr = (||) interpLOp LImpl = (==>) - where - True ==> False = False - _ ==> _ = True + where + True ==> False = False + _ ==> _ = True -- | The possible outcomes of a property test, parametrized over -- the type of values. A @TestReason@ explains why a proposition -- succeeded or failed. data TestReason_ a - = TestBool - -- ^ The prop evaluated to a boolean. - | TestEqual Type a a - -- ^ The test was an equality test. Records the values being + = -- | The prop evaluated to a boolean. + TestBool + | -- | The test was an equality test. Records the values being -- compared and also their type (which is needed for printing). - | TestLt Type a a - -- ^ The test was a less than test. Records the values being + TestEqual Type a a + | -- | The test was a less than test. Records the values being -- compared and also their type (which is needed for printing). - | TestNotFound SearchType - -- ^ The search didn't find any examples/counterexamples. - | TestFound TestResult - -- ^ The search found an example/counterexample. - | TestBin LOp TestResult TestResult - -- ^ A binary logical operator was used to combine the given two results. - | TestRuntimeError EvalError - -- ^ The prop failed at runtime. This is always a failure, no + TestLt Type a a + | -- | The search didn't find any examples/counterexamples. + TestNotFound SearchType + | -- | The search found an example/counterexample. + TestFound TestResult + | -- | A binary logical operator was used to combine the given two results. + TestBin LOp TestResult TestResult + | -- | The prop failed at runtime. This is always a failure, no -- matter which quantifiers or negations it's under. + TestRuntimeError EvalError deriving (Show, Functor, Foldable, Traversable) type TestReason = TestReason_ Value -- | The possible outcomes of a proposition. data TestResult = TestResult Bool TestReason TestEnv - deriving Show + deriving (Show) -- | Whether the property test resulted in a runtime error. testIsError :: TestResult -> Bool testIsError (TestResult _ (TestRuntimeError _) _) = True -testIsError _ = False +testIsError _ = False -- | Whether the property test resulted in success. testIsOk :: TestResult -> Bool @@ -365,38 +372,38 @@ testIsCertain :: TestResult -> Bool testIsCertain (TestResult _ r _) = resultIsCertain r resultIsCertain :: TestReason -> Bool -resultIsCertain TestBool = True -resultIsCertain TestEqual {} = True -resultIsCertain TestLt {} = True -resultIsCertain (TestNotFound Exhaustive) = True +resultIsCertain TestBool = True +resultIsCertain TestEqual {} = True +resultIsCertain TestLt {} = True +resultIsCertain (TestNotFound Exhaustive) = True resultIsCertain (TestNotFound (Randomized _ _)) = False -resultIsCertain (TestFound r) = testIsCertain r -resultIsCertain (TestRuntimeError _) = True +resultIsCertain (TestFound r) = testIsCertain r +resultIsCertain (TestRuntimeError _) = True resultIsCertain (TestBin op tr1 tr2) - | c1 && c2 = True - | c1 && ((op == LOr) == ok1) = True + | c1 && c2 = True + | c1 && ((op == LOr) == ok1) = True | c2 && ((op /= LAnd) == ok2) = True - | otherwise = False - where - c1 = testIsCertain tr1 - c2 = testIsCertain tr2 - ok1 = testIsOk tr1 - ok2 = testIsOk tr2 + | otherwise = False + where + c1 = testIsCertain tr1 + c2 = testIsCertain tr2 + ok1 = testIsOk tr1 + ok2 = testIsOk tr2 -- | A @ValProp@ is the normal form of a Disco value of type @Prop@. data ValProp - = VPDone TestResult - -- ^ A prop that has already either succeeded or failed. - | VPSearch SearchMotive [Type] Value TestEnv - -- ^ A pending search. - | VPBin LOp ValProp ValProp - -- ^ A binary logical operator combining two prop values. - deriving Show + = -- | A prop that has already either succeeded or failed. + VPDone TestResult + | -- | A pending search. + VPSearch SearchMotive [Type] Value TestEnv + | -- | A binary logical operator combining two prop values. + VPBin LOp ValProp ValProp + deriving (Show) extendPropEnv :: TestEnv -> ValProp -> ValProp extendPropEnv g (VPDone (TestResult b r e)) = VPDone (TestResult b r (g P.<> e)) -extendPropEnv g (VPSearch sm tys v e) = VPSearch sm tys v (g P.<> e) -extendPropEnv g (VPBin op vp1 vp2) = VPBin op (extendPropEnv g vp1) (extendPropEnv g vp2) +extendPropEnv g (VPSearch sm tys v e) = VPSearch sm tys v (g P.<> e) +extendPropEnv g (VPBin op vp1 vp2) = VPBin op (extendPropEnv g vp1) (extendPropEnv g vp2) extendResultEnv :: TestEnv -> TestResult -> TestResult extendResultEnv g (TestResult b r e) = TestResult b r (g P.<> e) @@ -406,15 +413,16 @@ extendResultEnv g (TestResult b r e) = TestResult b r (g P.<> e) ------------------------------------------------------------ -- | An environment is a mapping from names to values. -type Env = Ctx Core Value +type Env = Ctx Core Value ------------------------------------------------------------ -- Memory ------------------------------------------------------------ -- | 'Mem' represents a memory, containing 'Cell's -data Mem = Mem { next :: Int, mu :: IntMap Cell } deriving Show -data Cell = Blackhole | E Env Core | V Value deriving Show +data Mem = Mem {next :: Int, mu :: IntMap Cell} deriving (Show) + +data Cell = Blackhole | E Env Core | V Value deriving (Show) emptyMem :: Mem emptyMem = Mem 0 IM.empty @@ -425,7 +433,7 @@ emptyMem = Mem 0 IM.empty allocate :: Members '[State Mem] r => Env -> Core -> Sem r Int allocate e t = do Mem n m <- get - put $ Mem (n+1) (IM.insert n (E e t) m) + put $ Mem (n + 1) (IM.insert n (E e t) m) return n -- | Allocate new memory cells for a group of mutually recursive @@ -434,11 +442,11 @@ allocateRec :: Members '[State Mem] r => Env -> [(QName Core, Core)] -> Sem r [I allocateRec e bs = do Mem n m <- get let newRefs = zip [n ..] bs - e' = foldl' (flip (\(i,(x,_)) -> Ctx.insert x (VRef i))) e newRefs - m' = foldl' (flip (\(i,(_,c)) -> IM.insert i (E e' c))) m newRefs + e' = foldl' (flip (\(i, (x, _)) -> Ctx.insert x (VRef i))) e newRefs + m' = foldl' (flip (\(i, (_, c)) -> IM.insert i (E e' c))) m newRefs n' = n + length bs put $ Mem n' m' - return [n .. n'-1] + return [n .. n' - 1] -- | Look up the cell at a given index. lkup :: Members '[State Mem] r => Int -> Sem r (Maybe Cell) @@ -456,56 +464,47 @@ prettyValue' :: Member (Input TyDefCtx) r => Type -> Value -> Sem r Doc prettyValue' ty v = runLFresh . runReader initPA $ prettyValue ty v prettyValue :: Members '[Input TyDefCtx, LFresh, Reader PA] r => Type -> Value -> Sem r Doc - -- Lazily expand any user-defined types prettyValue (TyUser x args) v = do tydefs <- input - let (TyDefBody _ body) = tydefs M.! x -- This can't fail if typechecking succeeded + let (TyDefBody _ body) = tydefs M.! x -- This can't fail if typechecking succeeded prettyValue (body args) v - -prettyValue _ VUnit = "■" -prettyValue TyProp _ = prettyPlaceholder TyProp -prettyValue TyBool (VInj s _) = text $ map toLower (show (s == R)) +prettyValue _ VUnit = "■" +prettyValue TyProp _ = prettyPlaceholder TyProp +prettyValue TyBool (VInj s _) = text $ map toLower (show (s == R)) prettyValue TyBool v = error $ "Non-VInj passed with Bool type to prettyValue: " ++ show v -prettyValue TyC (vchar -> c) = text (show c) +prettyValue TyC (vchar -> c) = text (show c) prettyValue (TyList TyC) (vlist vchar -> cs) = doubleQuotes . text . concatMap prettyChar $ cs - where - prettyChar = drop 1 . reverse . drop 1 . reverse . show . (:[]) -prettyValue (TyList ty) (vlist id -> xs) = do + where + prettyChar = drop 1 . reverse . drop 1 . reverse . show . (: []) +prettyValue (TyList ty) (vlist id -> xs) = do ds <- punctuate (text ",") (map (prettyValue ty) xs) brackets (hsep ds) - -prettyValue ty@(_ :*: _) v = parens (prettyTuple ty v) - -prettyValue (ty1 :+: _) (VInj L v) = "left" <> prettyVP ty1 v -prettyValue (_ :+: ty2) (VInj R v) = "right" <> prettyVP ty2 v +prettyValue ty@(_ :*: _) v = parens (prettyTuple ty v) +prettyValue (ty1 :+: _) (VInj L v) = "left" <> prettyVP ty1 v +prettyValue (_ :+: ty2) (VInj R v) = "right" <> prettyVP ty2 v prettyValue (_ :+: _) v = error $ "Non-VInj passed with sum type to prettyValue: " ++ show v - prettyValue _ (VNum d r) - | denominator r == 1 = text $ show (numerator r) - | otherwise = text $ case d of + | denominator r == 1 = text $ show (numerator r) + | otherwise = text $ case d of Fraction -> show (numerator r) ++ "/" ++ show (denominator r) - Decimal -> prettyDecimal r - -prettyValue ty@(_ :->: _) _ = prettyPlaceholder ty - -prettyValue (TySet ty) (VBag xs) = braces $ prettySequence ty "," (map fst xs) + Decimal -> prettyDecimal r +prettyValue ty@(_ :->: _) _ = prettyPlaceholder ty +prettyValue (TySet ty) (VBag xs) = braces $ prettySequence ty "," (map fst xs) prettyValue (TySet _) v = error $ "Non-VBag passed with Set type to prettyValue: " ++ show v -prettyValue (TyBag ty) (VBag xs) = prettyBag ty xs +prettyValue (TyBag ty) (VBag xs) = prettyBag ty xs prettyValue (TyBag _) v = error $ "Non-VBag passed with Bag type to prettyValue: " ++ show v - -prettyValue (TyMap tyK tyV) (VMap m) = +prettyValue (TyMap tyK tyV) (VMap m) = "map" <> parens (braces (prettySequence (tyK :*: tyV) "," (assocsToValues m))) - where - assocsToValues = map (\(k,v) -> VPair (fromSimpleValue k) v) . M.assocs + where + assocsToValues = map (\(k, v) -> VPair (fromSimpleValue k) v) . M.assocs prettyValue (TyMap _ _) v = error $ "Non-map value with map type passed to prettyValue: " ++ show v - -prettyValue (TyGraph ty) (VGraph g) = +prettyValue (TyGraph ty) (VGraph g) = foldg "emptyGraph" (("vertex" <>) . prettyVP ty . fromSimpleValue) @@ -514,36 +513,34 @@ prettyValue (TyGraph ty) (VGraph g) = g prettyValue (TyGraph _) v = error $ "Non-graph value with graph type passed to prettyValue: " ++ show v - -prettyValue ty@TyAtom{} v = +prettyValue ty@TyAtom {} v = error $ "Invalid atomic type passed to prettyValue: " ++ show ty ++ " " ++ show v - -prettyValue ty@TyCon{} v = +prettyValue ty@TyCon {} v = error $ "Invalid type constructor passed to prettyValue: " ++ show ty ++ " " ++ show v -- | Pretty-print a value with guaranteed parentheses. Do nothing for -- tuples; add an extra set of parens for other values. prettyVP :: Members '[Input TyDefCtx, LFresh, Reader PA] r => Type -> Value -> Sem r Doc prettyVP ty@(_ :*: _) = prettyValue ty -prettyVP ty = parens . prettyValue ty +prettyVP ty = parens . prettyValue ty prettyPlaceholder :: Members '[Reader PA, LFresh] r => Type -> Sem r Doc prettyPlaceholder ty = "<" <> pretty ty <> ">" prettyTuple :: Members '[Input TyDefCtx, LFresh, Reader PA] r => Type -> Value -> Sem r Doc prettyTuple (ty1 :*: ty2) (VPair v1 v2) = prettyValue ty1 v1 <> "," <+> prettyTuple ty2 v2 -prettyTuple ty v = prettyValue ty v +prettyTuple ty v = prettyValue ty v -- | 'prettySequence' pretty-prints a lists of values separated by a delimiter. prettySequence :: Members '[Input TyDefCtx, LFresh, Reader PA] r => Type -> Doc -> [Value] -> Sem r Doc prettySequence ty del vs = hsep =<< punctuate (return del) (map (prettyValue ty) vs) -- | Pretty-print a literal bag value. -prettyBag :: Members '[Input TyDefCtx, LFresh, Reader PA] r => Type -> [(Value,Integer)] -> Sem r Doc +prettyBag :: Members '[Input TyDefCtx, LFresh, Reader PA] r => Type -> [(Value, Integer)] -> Sem r Doc prettyBag _ [] = bag empty prettyBag ty vs - | all ((==1) . snd) vs = bag $ prettySequence ty "," (map fst vs) - | otherwise = bag $ hsep =<< punctuate (return ",") (map prettyCount vs) - where - prettyCount (v,1) = prettyValue ty v - prettyCount (v,n) = prettyValue ty v <+> "#" <+> text (show n) + | all ((== 1) . snd) vs = bag $ prettySequence ty "," (map fst vs) + | otherwise = bag $ hsep =<< punctuate (return ",") (map prettyCount vs) + where + prettyCount (v, 1) = prettyValue ty v + prettyCount (v, n) = prettyValue ty v <+> "#" <+> text (show n)