From c39a45ca17a83ded024c59eab02eb7387ea1ace1 Mon Sep 17 00:00:00 2001 From: Josef Svenningsson Date: Mon, 29 Jan 2024 10:12:53 -0800 Subject: [PATCH] Modernize the typechecker monad and make it abstract Summary: This is a no-op diff that just reorganizes the code a bit. The typechecker monad is made abstract. This allows for removing the specialized `collect` function in favour of an `Applicative` instance. Similarly, `orT` and `emptyT` are replaced by an `Alternative` instance. I've also removed `mapT` and `mapE` as they were just `traverse`. Reviewed By: simonmar Differential Revision: D53135425 fbshipit-source-id: a050c73729ef51ad56ea85126e8dbf0cc15a9583 --- compiler/Thrift/Compiler/Typechecker.hs | 98 ++++++++----------- compiler/Thrift/Compiler/Typechecker/Monad.hs | 86 ++++++++-------- .../Thrift/Compiler/Plugins/Haskell.hs | 3 +- 3 files changed, 84 insertions(+), 103 deletions(-) diff --git a/compiler/Thrift/Compiler/Typechecker.hs b/compiler/Thrift/Compiler/Typechecker.hs index 934edea3..551058d2 100644 --- a/compiler/Thrift/Compiler/Typechecker.hs +++ b/compiler/Thrift/Compiler/Typechecker.hs @@ -6,7 +6,7 @@ LICENSE file in the root directory of this source tree. 
-} -{-# LANGUAGE TypeOperators, NamedFieldPuns #-} +{-# LANGUAGE TypeOperators, NamedFieldPuns, ApplicativeDo #-} module Thrift.Compiler.Typechecker ( typecheck , typecheckConst, eqOrAlias @@ -15,14 +15,12 @@ module Thrift.Compiler.Typechecker ) where import Prelude hiding (Enum) -import Data.Either ( rights ) import Data.List hiding (uncons) import Data.Maybe import Data.Some import Data.Text.Encoding hiding (Some) import Data.Type.Equality import Control.Monad -import Control.Monad.Trans.Reader import Data.Graph import Data.Text (Text) import GHC.TypeLits hiding (TypeError) @@ -122,9 +120,9 @@ typecheckModule opts@Options{..} progs tf@ThriftFile{..} = do imap = mkEnumInt opts dEnums (smap, umap, cmap, servMap) <- (,,,) <$> mkSchemaMap (thriftName, opts) importMap tmap dStructs - `collect` mkUnionMap (thriftName, opts) importMap tmap dUnions - `collect` mkConstMap (thriftName, opts) importMap tmap thriftDeclsNew - `collect` mkServiceMap (thriftName, opts) importMap dServs + <*> mkUnionMap (thriftName, opts) importMap tmap dUnions + <*> mkConstMap (thriftName, opts) importMap tmap thriftDeclsNew + <*> mkServiceMap (thriftName, opts) importMap dServs -- Build the Env let env = Env { typeMap = tmap @@ -139,8 +137,8 @@ typecheckModule opts@Options{..} progs tf@ThriftFile{..} = do , envName = thriftName -- for weird 'mkThriftName' case } -- Typecheck the rest of the things - headers <- runTypechecker env $ mapT typecheckHeader thriftHeaders - decls <- runTypechecker env $ mapT resolveDecl thriftDeclsNew + headers <- runTypechecker env $ traverse typecheckHeader thriftHeaders + decls <- runTypechecker env $ traverse resolveDecl thriftDeclsNew let prog = Program { progName = thriftName , progHSName = renamedModule @@ -264,7 +262,7 @@ type ModuleMap = Map.Map FilePath (ThriftFile SpliceFile Loc) -- Topologically sort the modules so that we can typecheck them. 
We will start -- at the leaves and work our way up to the original module sortModules :: ModuleMap -> Either [TypeError l] [ThriftFile SpliceFile Loc] -sortModules moduleMap = mapE getVertex sccs +sortModules moduleMap = traverse getVertex sccs where sccs = stronglyConnComp graph graph = map mkVertex (Map.toList moduleMap) @@ -550,7 +548,7 @@ resolveFields -> [Parsed (Field u)] -> TC l [Field u 'Resolved l Loc] resolveFields sname as fs = - (mapT (resolveField as sname) =<< fields) <* + (traverse (resolveField as sname) =<< fields) <* -- Check for duplicate field ids foldM checkId Set.empty fs where @@ -651,7 +649,7 @@ resolveUnion u@Union{..} = do where resolveAlts :: Bool -> [Parsed UnionAlt] -> TC l [UnionAlt 'Resolved l Loc] resolveAlts optsLenient alts = - mapT resolveAlt alts + traverse resolveAlt alts -- Check for duplicate field ids <* foldM checkId Set.empty alts <* checkEmpty @@ -742,8 +740,8 @@ resolveService :: Typecheckable l => Parsed Service -> Typechecked l Service resolveService s@Service{..} = do (super, stmts, sAnns) <- (,,) <$> sequence (resolveSuper <$> serviceSuper) - `collect` mapT resolveStmt serviceStmts - `collect` resolveStructuredAnns serviceSAnns + <*> traverse resolveStmt serviceStmts + <*> resolveStructuredAnns serviceSAnns Env{..} <- ask pure Service { serviceResolvedName = renameService options s @@ -770,8 +768,8 @@ resolveInteraction :: Typecheckable l => Parsed Interaction -> Typechecked l Int resolveInteraction Interaction{..} = do (super, funs, sAnns) <- (,,) <$> sequence (resolveSuper <$> interactionSuper) - `collect` mapT resolveFunction interactionFunctions - `collect` resolveStructuredAnns interactionSAnns + <*> traverse resolveFunction interactionFunctions + <*> resolveStructuredAnns interactionSAnns Env{..} <- ask pure Interaction { interactionResolvedName = interactionName @@ -791,10 +789,10 @@ resolveFunction f@Function{..} = do (rtype, ftype, args, excepts, sAnns) <- (,,,,) <$> sequence (resolveFunctionTypeTy funType) - 
`collect` resolveFunctionType funName annNoPriorities funType - `collect` resolveFields funName annNoPriorities funArgs - `collect` resolveFields funName annNoPriorities funExceptions - `collect` resolveStructuredAnns funSAnns + <*> resolveFunctionType funName annNoPriorities funType + <*> resolveFields funName annNoPriorities funArgs + <*> resolveFields funName annNoPriorities funExceptions + <*> resolveStructuredAnns funSAnns Env{..} <- ask pure $ Function { funResolvedName = renameFunction options f @@ -997,19 +995,6 @@ getEnumType opts@Options{..} enum@Enum{..} = case enumFlavourTag opts enum of -- Typecheck Constants --------------------------------------------------------- --- | See T45688659 for the weird tale of how badly-typed keys are getting --- ignored. Try to be similar here, in weird mode, by dropping errors. -mapTWeird :: (a -> TC l b) -> [a] -> TC l [b] -mapTWeird f xs = do - Options{optsLenient} <- asks options - if not optsLenient then - mapT f xs - else - ReaderT $ \env -> Right (mapEWeird (runTypechecker env . f) xs) - where - mapEWeird :: (a -> Either [e] b) -> [a] -> [b] - mapEWeird ff = rights . map ff - typecheckConst :: Typecheckable l => Type l t @@ -1042,7 +1027,7 @@ typecheckConst TBytes (UntypedConst _ (StringConst s _)) = -- Recursive Types typecheckConst (TList u) (UntypedConst _ ListConst{..}) = - Literal . List <$> mapT (typecheckConst u . leElem) lvElems + Literal . List <$> traverse (typecheckConst u . leElem) lvElems typecheckConst ty@(TList _) c@(UntypedConst l MapConst{mvElems=[]}) = do Options{optsLenient} <- asks options if optsLenient then @@ -1050,7 +1035,7 @@ typecheckConst ty@(TList _) c@(UntypedConst l MapConst{mvElems=[]}) = do else typeError (lLocation l) (LiteralMismatch ty c) typecheckConst (TSet u) (UntypedConst _ ListConst{..}) = - Literal . Set <$> mapT (typecheckConst u . leElem) lvElems + Literal . Set <$> traverse (typecheckConst u . 
leElem) lvElems typecheckConst ty@(TSet _) c@(UntypedConst l MapConst{mvElems=[]}) = do Options{optsLenient} <- asks options if optsLenient then @@ -1058,15 +1043,16 @@ typecheckConst ty@(TSet _) c@(UntypedConst l MapConst{mvElems=[]}) = do else typeError (lLocation l) (LiteralMismatch ty c) typecheckConst (THashSet u) (UntypedConst _ ListConst{..}) = - Literal . HashSet <$> mapT (typecheckConst u . leElem) lvElems + Literal . HashSet <$> traverse (typecheckConst u . leElem) lvElems typecheckConst ty@(THashSet _) c@(UntypedConst l MapConst{mvElems=[]}) = do Options{optsLenient} <- asks options if optsLenient then return $ Literal $ HashSet [] -- weird files use the wrong empty brackets else typeError (lLocation l) (LiteralMismatch ty c) -typecheckConst (TMap kt vt) (UntypedConst _ MapConst{..}) = - Literal . Map <$> mapTWeird tcConsts mvElems +typecheckConst (TMap kt vt) (UntypedConst _ MapConst{..}) = do + Options{optsLenient} <- asks options + Literal . Map <$> traverseWeird optsLenient tcConsts mvElems where tcConsts ListElem{leElem=MapPair{..}} = (,) <$> typecheckConst kt mpKey @@ -1078,7 +1064,7 @@ typecheckConst ty@(TMap _ _) c@(UntypedConst l ListConst{lvElems=[]}) = do else typeError (lLocation l) (LiteralMismatch ty c) typecheckConst (THashMap kt vt) (UntypedConst _ MapConst{..}) = - Literal . HashMap <$> mapT tcConsts mvElems + Literal . 
HashMap <$> traverse tcConsts mvElems where tcConsts ListElem{leElem=MapPair{..}} = (,) <$> typecheckConst kt mpKey @@ -1115,8 +1101,8 @@ typecheckConst newt@(TNewtype name ty _loc) val@(UntypedConst Located{..} c) = -- however IdConsts can also be enums, so we have to check for this case -- too IdConst ident -> - typecheckPseudoEnum newt lLocation name ident `orT` - typecheckIdent newt lLocation ident `orT` + typecheckPseudoEnum newt lLocation name ident <|> + typecheckIdent newt lLocation ident <|> (liftNew =<< typecheckConst ty val) _ -> liftNew =<< typecheckConst ty val where @@ -1181,18 +1167,18 @@ typecheckConst Nothing -> typeError lLocation $ TypeMismatch tyTop tyAnn _ -> typeError lLocation $ InvalidUnion n $ length svElems --- Identifiers (typecheckIdentNum is not first parameter to the orT) +-- Identifiers (typecheckIdentNum is not first parameter to <|>) -- These identified are permitted to be enums in lenient mode typecheckConst ty (UntypedConst Located{..} (IdConst ident)) = do Env{..} <- ask typecheckIdent ty lLocation ident - `orT` typecheckIdentNum ty lLocation ident - `orT` case ty of + <|> typecheckIdentNum ty lLocation ident + <|> case ty of I8 | optsLenient options -> typecheckEnumAsInt lLocation ident I16 | optsLenient options -> typecheckEnumAsInt lLocation ident I32 | optsLenient options -> typecheckEnumAsInt lLocation ident I64 | optsLenient options -> typecheckEnumAsInt lLocation ident - _ -> emptyT + _ -> empty -- Special Types typecheckConst (TSpecial ty) val = typecheckSpecialConst ty val @@ -1218,7 +1204,7 @@ typecheckEnumAsInt loc ident = do else Just $ Text.drop 1 v where (_nm, v) = Text.breakOn "." 
text case enumValue of - Nothing -> emptyT + Nothing -> empty Just k -> do i <- lookupEnumInt k loc literal $ fromIntegral i @@ -1259,12 +1245,12 @@ typecheckPseudoEnum ty loc n@Name{..} ident = do case thistrueTy of Some trueTy -> case trueTy `eqOrAlias` ty of Just Refl -> pure $ Identifier renamed ty locDefined - Nothing -> emptyT - _ -> emptyT + Nothing -> empty + _ -> empty -- | Handle weird case of enum to int casting, dispatch when 'ty' is int like. -- --- Do not use as first parameter to 'orT' +-- Do not use as first parameter to '<|>' typecheckIdentNum :: (Typecheckable l) => Type l t @@ -1276,14 +1262,14 @@ typecheckIdentNum ty loc ident = case ty of I16 -> typecheckEnumInt loc ident I32 -> typecheckEnumInt loc ident I64 -> typecheckEnumInt loc ident - _ -> emptyT + _ -> empty -- | This is used for int-like constants that might, werdly, get their value -- from an enum value (implicit casting is somewhay popular). Use the -- usual typecheckIdent before attempting the enum lookup (made up rule that -- seems to work). -- --- Do not use as first parameter to 'orT' +-- Do not use as first parameter to '<|>' typecheckEnumInt :: (Typecheckable l, Num t) => Loc @@ -1291,7 +1277,7 @@ typecheckEnumInt -> TC l (TypedConst l t) typecheckEnumInt loc ident = do Env{..} <- ask - if not (optsLenient options) then emptyT else do + if not (optsLenient options) then empty else do name <- mkThriftName ident i <- lookupEnumInt name loc literal $ fromIntegral i @@ -1407,7 +1393,7 @@ getAliasedType ty = ty mkFieldMap :: [ListElem MapPair Loc] -> TC l (Map.Map Text (UntypedConst Loc)) -mkFieldMap = fmap Map.fromList . mapT getName +mkFieldMap = fmap Map.fromList . traverse getName where getName ListElem{ leElem = MapPair{..} } = case mpKey of UntypedConst _ (StringConst s _) -> pure (s, mpVal) @@ -1416,7 +1402,7 @@ mkFieldMap = fmap Map.fromList . 
mapT getName mkStructFieldMap :: [ListElem StructPair Loc] -> TC l (Map.Map Text (UntypedConst Loc)) -mkStructFieldMap = fmap Map.fromList . mapT getName +mkStructFieldMap = fmap Map.fromList . traverse getName where getName ListElem{ leElem = StructPair{..} } = pure (spKey, spVal) @@ -1562,7 +1548,7 @@ mkTypemap (thriftName, opts@Options{..}) imap = -- Topologically sort the Decls based on dependencies -- This function will fail if there is a cycle sortDecls :: [Parsed Decl] -> Either [TypeError l] [Parsed Decl] -sortDecls decls = mapE getVertex sccs +sortDecls decls = traverse getVertex sccs where sccs = stronglyConnComp graph graph = mapMaybe mkVertex decls @@ -1691,7 +1677,7 @@ mkSchema Struct{..} = buildSchema structMembers buildSchema (field@Field{..} : fields) = do (rty, tschema) <- (,) <$> resolveAnnotatedType fieldType - `collect` buildSchema fields + <*> buildSchema fields opts@Options{..} <- options <$> ask let renamed = Text.unpack $ renameField opts (getAnns structAnns) structName field @@ -1813,7 +1799,7 @@ mkServiceMap (thriftName, opts@Options{..}) imap = sortServices :: [Parsed Service] -> Either [TypeError l] [Parsed Service] -sortServices services = mapE getVertex sccs +sortServices services = traverse getVertex sccs where sccs = stronglyConnComp graph graph = map mkVertex services diff --git a/compiler/Thrift/Compiler/Typechecker/Monad.hs b/compiler/Thrift/Compiler/Typechecker/Monad.hs index 190e4623..9ecc3382 100644 --- a/compiler/Thrift/Compiler/Typechecker/Monad.hs +++ b/compiler/Thrift/Compiler/Typechecker/Monad.hs @@ -10,7 +10,8 @@ {-# LANGUAGE PolyKinds #-} module Thrift.Compiler.Typechecker.Monad ( TC, Typechecked - , typeError, runTypechecker, mapT, mapE, emptyT, orT, ErrCollectable(..) + , typeError, runTypechecker, ask, asks, traverseWeird + , Alternative(..) , lookupType, lookupSchema, lookupUnion, lookupEnum, lookupConst , lookupService, lookupEnumInt , TypeError(..), ErrorMsg(..), AnnotationPlacement(..) 
@@ -18,11 +19,12 @@ module Thrift.Compiler.Typechecker.Monad ) where import Prelude hiding (Enum) +import Control.Applicative +import Data.Either ( rights ) import Data.Int (Int32) import Data.Some import Data.Text (Text) -import Control.Monad.Trans.Class -import Control.Monad.Trans.Reader +import Control.Monad.Reader import qualified Data.Map.Strict as Map import qualified Data.Set as Set @@ -33,38 +35,36 @@ import Thrift.Compiler.Types -- The typechecking monad is a reader transformed either monad. This allows us -- to keep track of state and errors -type TC l = ReaderT (Env l) (Either [TypeError l]) +newtype TC l a = TC (ReaderT (Env l) (Either [TypeError l]) a) + deriving (Functor,Monad, MonadReader (Env l)) + +instance Applicative (TC l) where + pure a = TC (pure a) + -- | Special Applicative instance which allows us to collect + -- type errors from *both* parts of the computation in the + -- event that they both fail. + TC (ReaderT rf) <*> TC (ReaderT rx) = + TC $ ReaderT $ \env -> rf env `collect` rx env + where + collect (Left e1) (Left e2) = Left $ e1 <> e2 + collect (Left e) (Right _) = Left e + collect (Right _) (Left e) = Left e + collect (Right f) (Right x) = Right $ f x type Typechecked (l :: *) (t :: Status -> * -> * -> *) = TC l (t 'Resolved l Loc) typeError :: Loc -> ErrorMsg l -> TC l a -typeError loc msg = lift $ Left [TypeError loc msg] +typeError loc msg = TC $ lift $ Left [TypeError loc msg] runTypechecker :: Env l -> TC l a -> Either [TypeError l] a -runTypechecker = flip runReaderT - --- | map a typechecking computation over a list and collect all of the type --- errors. This differs from mapM and mapA because those will terminate with --- the first type error -mapT :: (a -> TC l b) -> [a] -> TC l [b] -mapT f xs = ReaderT $ \env -> mapE (runTypechecker env . 
f) xs - -mapE :: (a -> Either [e] b) -> [a] -> Either [e] [b] -mapE f = foldr (\x xs -> (:) <$> f x `collect` xs) (Right []) - --- | This is the identity for 'orT', and morally @empty@ from Alternative --- --- This produces an empty list of errors, so should never be used as the --- first parater of 'orT'. -emptyT :: TC l a -emptyT = lift (Left []) - --- | Try to run the first typechecker computation, but use the second if it --- fails. If both fail then use both errors. See 'emptyT' for identity. --- Note: this is associative and morally @<|>@ from Alternative -orT :: TC l a -> TC l a -> TC l a -orT t1 t2 = do +runTypechecker env (TC t) = runReaderT t env + +instance Alternative (TC l) where + -- This produces an empty list of errors, so should never be used as the + -- first parameter of '<|>'. + empty = TC $ lift (Left []) + t1 <|> t2 = TC $ do env <- ask lift $ case runTypechecker env t1 of Left err1 -> case runTypechecker env t2 of @@ -72,21 +72,17 @@ orT t1 t2 = do x2@Right{} -> x2 x1@Right{} -> x1 --- | Similar to <*>, but doesn't obey the monad law that ap == (<*>). This --- allows us to collect type errors from *both* parts of the computation in the --- event that they both fail. (<*>) will only take the first failure it finds. -class ErrCollectable m where - collect :: m (a -> b) -> m a -> m b -infixl 4 `collect` - -instance Monoid e => ErrCollectable (Either e) where - collect (Left e1) (Left e2) = Left $ e1 <> e2 - collect (Left e) (Right _) = Left e - collect (Right _) (Left e) = Left e - collect (Right f) (Right x) = Right $ f x - -instance ErrCollectable m => ErrCollectable (ReaderT a m) where - collect (ReaderT f) (ReaderT x) = ReaderT $ \env -> f env `collect` x env +-- | See T45688659 for the weird tale of how badly-typed keys are getting +-- ignored. Try to be similar here, in weird mode, by dropping errors. 
+traverseWeird :: Bool -> (a -> TC l b) -> [a] -> TC l [b] +traverseWeird optsLenient f xs = do + if not optsLenient then + traverse f xs + else + TC $ ReaderT $ \env -> Right (mapEWeird (runTypechecker env . f) xs) + where + mapEWeird :: (a -> Either [e] b) -> [a] -> [b] + mapEWeird ff = rights . map ff -- Lookup Functions ------------------------------------------------------------ @@ -129,7 +125,7 @@ envLookup -> Loc -> TC l u envLookup getMap mkError name loc = do - env <- ask + env <- TC ask case doLookup env name of Nothing -> typeError loc $ mkError name Just u -> pure u @@ -145,7 +141,7 @@ envCtxLookup -> Loc -> TC l u envCtxLookup getMap mkError name loc = do - env <- ask + env <- TC ask case doLookup env name of Nothing -> typeError loc $ mkError name Just u -> pure u diff --git a/compiler/plugins/Thrift/Compiler/Plugins/Haskell.hs b/compiler/plugins/Thrift/Compiler/Plugins/Haskell.hs index 2325e2a5..e3532bff 100644 --- a/compiler/plugins/Thrift/Compiler/Plugins/Haskell.hs +++ b/compiler/plugins/Thrift/Compiler/Plugins/Haskell.hs @@ -15,7 +15,6 @@ module Thrift.Compiler.Plugins.Haskell , toCamel ) where -import Control.Monad.Trans.Reader import Data.ByteString (ByteString) import qualified Data.Foldable as Foldable import qualified Data.Map as Map @@ -165,7 +164,7 @@ instance Typecheckable Haskell where typecheckSpecialConst HsByteString (UntypedConst _ (StringConst s _)) = pure $ Literal $ Text.encodeUtf8 s typecheckSpecialConst (HsVector _ u) (UntypedConst _ ListConst{..}) = - Literal . List <$> mapT (typecheckConst u . leElem) lvElems + Literal . List <$> traverse (typecheckConst u . leElem) lvElems typecheckSpecialConst ty val@(UntypedConst Located{..} _) = typeError lLocation $ LiteralMismatch (TSpecial ty) val