Skip to content

Commit

Permalink
Modernize the typechecker monad and make it abstract
Browse files Browse the repository at this point in the history
Summary:
This is a no-op diff that just reorganizes the code a bit.

The typechecker monad is made abstract. This allows for removinging the specialized `collect` function in favour an `Applicative` instance. Similarly, `orT` and `emptyT` are replaced by an `Alternative` instance.

I've also removed `mapT` and `mapE` as they were just `traverse`.

Reviewed By: simonmar

Differential Revision: D53135425

fbshipit-source-id: a050c73729ef51ad56ea85126e8dbf0cc15a9583
  • Loading branch information
Josef Svenningsson authored and facebook-github-bot committed Jan 29, 2024
1 parent a80c3d5 commit c39a45c
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 103 deletions.
98 changes: 42 additions & 56 deletions compiler/Thrift/Compiler/Typechecker.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
LICENSE file in the root directory of this source tree.
-}

{-# LANGUAGE TypeOperators, NamedFieldPuns #-}
{-# LANGUAGE TypeOperators, NamedFieldPuns, ApplicativeDo #-}
module Thrift.Compiler.Typechecker
( typecheck
, typecheckConst, eqOrAlias
Expand All @@ -15,14 +15,12 @@ module Thrift.Compiler.Typechecker
) where

import Prelude hiding (Enum)
import Data.Either ( rights )
import Data.List hiding (uncons)
import Data.Maybe
import Data.Some
import Data.Text.Encoding hiding (Some)
import Data.Type.Equality
import Control.Monad
import Control.Monad.Trans.Reader
import Data.Graph
import Data.Text (Text)
import GHC.TypeLits hiding (TypeError)
Expand Down Expand Up @@ -122,9 +120,9 @@ typecheckModule opts@Options{..} progs tf@ThriftFile{..} = do
imap = mkEnumInt opts dEnums
(smap, umap, cmap, servMap)
<- (,,,) <$> mkSchemaMap (thriftName, opts) importMap tmap dStructs
`collect` mkUnionMap (thriftName, opts) importMap tmap dUnions
`collect` mkConstMap (thriftName, opts) importMap tmap thriftDeclsNew
`collect` mkServiceMap (thriftName, opts) importMap dServs
<*> mkUnionMap (thriftName, opts) importMap tmap dUnions
<*> mkConstMap (thriftName, opts) importMap tmap thriftDeclsNew
<*> mkServiceMap (thriftName, opts) importMap dServs

-- Build the Env
let env = Env { typeMap = tmap
Expand All @@ -139,8 +137,8 @@ typecheckModule opts@Options{..} progs tf@ThriftFile{..} = do
, envName = thriftName -- for weird 'mkThriftName' case
}
-- Typecheck the rest of the things
headers <- runTypechecker env $ mapT typecheckHeader thriftHeaders
decls <- runTypechecker env $ mapT resolveDecl thriftDeclsNew
headers <- runTypechecker env $ traverse typecheckHeader thriftHeaders
decls <- runTypechecker env $ traverse resolveDecl thriftDeclsNew
let prog = Program
{ progName = thriftName
, progHSName = renamedModule
Expand Down Expand Up @@ -264,7 +262,7 @@ type ModuleMap = Map.Map FilePath (ThriftFile SpliceFile Loc)
-- Topologically sort the modules so that we can typecheck them. We will start
-- at the leaves and work our way up to the original module
sortModules :: ModuleMap -> Either [TypeError l] [ThriftFile SpliceFile Loc]
sortModules moduleMap = mapE getVertex sccs
sortModules moduleMap = traverse getVertex sccs
where
sccs = stronglyConnComp graph
graph = map mkVertex (Map.toList moduleMap)
Expand Down Expand Up @@ -550,7 +548,7 @@ resolveFields
-> [Parsed (Field u)]
-> TC l [Field u 'Resolved l Loc]
resolveFields sname as fs =
(mapT (resolveField as sname) =<< fields) <*
(traverse (resolveField as sname) =<< fields) <*
-- Check for duplicate field ids
foldM checkId Set.empty fs
where
Expand Down Expand Up @@ -651,7 +649,7 @@ resolveUnion u@Union{..} = do
where
resolveAlts :: Bool -> [Parsed UnionAlt] -> TC l [UnionAlt 'Resolved l Loc]
resolveAlts optsLenient alts =
mapT resolveAlt alts
traverse resolveAlt alts
-- Check for duplicate field ids
<* foldM checkId Set.empty alts
<* checkEmpty
Expand Down Expand Up @@ -742,8 +740,8 @@ resolveService :: Typecheckable l => Parsed Service -> Typechecked l Service
resolveService s@Service{..} = do
(super, stmts, sAnns)
<- (,,) <$> sequence (resolveSuper <$> serviceSuper)
`collect` mapT resolveStmt serviceStmts
`collect` resolveStructuredAnns serviceSAnns
<*> traverse resolveStmt serviceStmts
<*> resolveStructuredAnns serviceSAnns
Env{..} <- ask
pure Service
{ serviceResolvedName = renameService options s
Expand All @@ -770,8 +768,8 @@ resolveInteraction :: Typecheckable l => Parsed Interaction -> Typechecked l Int
resolveInteraction Interaction{..} = do
(super, funs, sAnns)
<- (,,) <$> sequence (resolveSuper <$> interactionSuper)
`collect` mapT resolveFunction interactionFunctions
`collect` resolveStructuredAnns interactionSAnns
<*> traverse resolveFunction interactionFunctions
<*> resolveStructuredAnns interactionSAnns
Env{..} <- ask
pure Interaction
{ interactionResolvedName = interactionName
Expand All @@ -791,10 +789,10 @@ resolveFunction f@Function{..} = do
(rtype, ftype, args, excepts, sAnns)
<- (,,,,) <$>
sequence (resolveFunctionTypeTy funType)
`collect` resolveFunctionType funName annNoPriorities funType
`collect` resolveFields funName annNoPriorities funArgs
`collect` resolveFields funName annNoPriorities funExceptions
`collect` resolveStructuredAnns funSAnns
<*> resolveFunctionType funName annNoPriorities funType
<*> resolveFields funName annNoPriorities funArgs
<*> resolveFields funName annNoPriorities funExceptions
<*> resolveStructuredAnns funSAnns
Env{..} <- ask
pure $ Function
{ funResolvedName = renameFunction options f
Expand Down Expand Up @@ -997,19 +995,6 @@ getEnumType opts@Options{..} enum@Enum{..} = case enumFlavourTag opts enum of

-- Typecheck Constants ---------------------------------------------------------

-- | See T45688659 for the weird tale of how badly-typed keys are getting
-- ignored. Try to be similar here, in weird mode, by dropping errors.
mapTWeird :: (a -> TC l b) -> [a] -> TC l [b]
mapTWeird f xs = do
Options{optsLenient} <- asks options
if not optsLenient then
mapT f xs
else
ReaderT $ \env -> Right (mapEWeird (runTypechecker env . f) xs)
where
mapEWeird :: (a -> Either [e] b) -> [a] -> [b]
mapEWeird ff = rights . map ff

typecheckConst
:: Typecheckable l
=> Type l t
Expand Down Expand Up @@ -1042,31 +1027,32 @@ typecheckConst TBytes (UntypedConst _ (StringConst s _)) =

-- Recursive Types
typecheckConst (TList u) (UntypedConst _ ListConst{..}) =
Literal . List <$> mapT (typecheckConst u . leElem) lvElems
Literal . List <$> traverse (typecheckConst u . leElem) lvElems
typecheckConst ty@(TList _) c@(UntypedConst l MapConst{mvElems=[]}) = do
Options{optsLenient} <- asks options
if optsLenient then
return $ Literal $ List [] -- weird files use the wrong empty brackets, sigh
else
typeError (lLocation l) (LiteralMismatch ty c)
typecheckConst (TSet u) (UntypedConst _ ListConst{..}) =
Literal . Set <$> mapT (typecheckConst u . leElem) lvElems
Literal . Set <$> traverse (typecheckConst u . leElem) lvElems
typecheckConst ty@(TSet _) c@(UntypedConst l MapConst{mvElems=[]}) = do
Options{optsLenient} <- asks options
if optsLenient then
return $ Literal $ Set [] -- weird files use the wrong empty brackets, sigh
else
typeError (lLocation l) (LiteralMismatch ty c)
typecheckConst (THashSet u) (UntypedConst _ ListConst{..}) =
Literal . HashSet <$> mapT (typecheckConst u . leElem) lvElems
Literal . HashSet <$> traverse (typecheckConst u . leElem) lvElems
typecheckConst ty@(THashSet _) c@(UntypedConst l MapConst{mvElems=[]}) = do
Options{optsLenient} <- asks options
if optsLenient then
return $ Literal $ HashSet [] -- weird files use the wrong empty brackets
else
typeError (lLocation l) (LiteralMismatch ty c)
typecheckConst (TMap kt vt) (UntypedConst _ MapConst{..}) =
Literal . Map <$> mapTWeird tcConsts mvElems
typecheckConst (TMap kt vt) (UntypedConst _ MapConst{..}) = do
Options{optsLenient} <- asks options
Literal . Map <$> traverseWeird optsLenient tcConsts mvElems
where
tcConsts ListElem{leElem=MapPair{..}} = (,)
<$> typecheckConst kt mpKey
Expand All @@ -1078,7 +1064,7 @@ typecheckConst ty@(TMap _ _) c@(UntypedConst l ListConst{lvElems=[]}) = do
else
typeError (lLocation l) (LiteralMismatch ty c)
typecheckConst (THashMap kt vt) (UntypedConst _ MapConst{..}) =
Literal . HashMap <$> mapT tcConsts mvElems
Literal . HashMap <$> traverse tcConsts mvElems
where
tcConsts ListElem{leElem=MapPair{..}} = (,)
<$> typecheckConst kt mpKey
Expand Down Expand Up @@ -1115,8 +1101,8 @@ typecheckConst newt@(TNewtype name ty _loc) val@(UntypedConst Located{..} c) =
-- however IdConsts can also be enums, so we have to check for this case
-- too
IdConst ident ->
typecheckPseudoEnum newt lLocation name ident `orT`
typecheckIdent newt lLocation ident `orT`
typecheckPseudoEnum newt lLocation name ident <|>
typecheckIdent newt lLocation ident <|>
(liftNew =<< typecheckConst ty val)
_ -> liftNew =<< typecheckConst ty val
where
Expand Down Expand Up @@ -1181,18 +1167,18 @@ typecheckConst
Nothing -> typeError lLocation $ TypeMismatch tyTop tyAnn
_ -> typeError lLocation $ InvalidUnion n $ length svElems

-- Identifiers (typecheckIdentNum is not first parameter to the orT)
-- Identifiers (typecheckIdentNum is not first parameter to <|>)
-- These identified are permitted to be enums in lenient mode
typecheckConst ty (UntypedConst Located{..} (IdConst ident)) = do
Env{..} <- ask
typecheckIdent ty lLocation ident
`orT` typecheckIdentNum ty lLocation ident
`orT` case ty of
<|> typecheckIdentNum ty lLocation ident
<|> case ty of
I8 | optsLenient options -> typecheckEnumAsInt lLocation ident
I16 | optsLenient options -> typecheckEnumAsInt lLocation ident
I32 | optsLenient options -> typecheckEnumAsInt lLocation ident
I64 | optsLenient options -> typecheckEnumAsInt lLocation ident
_ -> emptyT
_ -> empty

-- Special Types
typecheckConst (TSpecial ty) val = typecheckSpecialConst ty val
Expand All @@ -1218,7 +1204,7 @@ typecheckEnumAsInt loc ident = do
else Just $ Text.drop 1 v
where (_nm, v) = Text.breakOn "." text
case enumValue of
Nothing -> emptyT
Nothing -> empty
Just k -> do
i <- lookupEnumInt k loc
literal $ fromIntegral i
Expand Down Expand Up @@ -1259,12 +1245,12 @@ typecheckPseudoEnum ty loc n@Name{..} ident = do
case thistrueTy of
Some trueTy -> case trueTy `eqOrAlias` ty of
Just Refl -> pure $ Identifier renamed ty locDefined
Nothing -> emptyT
_ -> emptyT
Nothing -> empty
_ -> empty

-- | Handle weird case of enum to int casting, dispatch when 'ty' is int like.
--
-- Do not use as first parameter to 'orT'
-- Do not use as first parameter to '<|>'
typecheckIdentNum
:: (Typecheckable l)
=> Type l t
Expand All @@ -1276,22 +1262,22 @@ typecheckIdentNum ty loc ident = case ty of
I16 -> typecheckEnumInt loc ident
I32 -> typecheckEnumInt loc ident
I64 -> typecheckEnumInt loc ident
_ -> emptyT
_ -> empty

-- | This is used for int-like constants that might, werdly, get their value
-- from an enum value (implicit casting is somewhay popular). Use the
-- usual typecheckIdent before attempting the enum lookup (made up rule that
-- seems to work).
--
-- Do not use as first parameter to 'orT'
-- Do not use as first parameter to '<|>'
typecheckEnumInt
:: (Typecheckable l, Num t)
=> Loc
-> Text
-> TC l (TypedConst l t)
typecheckEnumInt loc ident = do
Env{..} <- ask
if not (optsLenient options) then emptyT else do
if not (optsLenient options) then empty else do
name <- mkThriftName ident
i <- lookupEnumInt name loc
literal $ fromIntegral i
Expand Down Expand Up @@ -1407,7 +1393,7 @@ getAliasedType ty = ty
mkFieldMap
:: [ListElem MapPair Loc]
-> TC l (Map.Map Text (UntypedConst Loc))
mkFieldMap = fmap Map.fromList . mapT getName
mkFieldMap = fmap Map.fromList . traverse getName
where
getName ListElem{ leElem = MapPair{..} } = case mpKey of
UntypedConst _ (StringConst s _) -> pure (s, mpVal)
Expand All @@ -1416,7 +1402,7 @@ mkFieldMap = fmap Map.fromList . mapT getName
mkStructFieldMap
:: [ListElem StructPair Loc]
-> TC l (Map.Map Text (UntypedConst Loc))
mkStructFieldMap = fmap Map.fromList . mapT getName
mkStructFieldMap = fmap Map.fromList . traverse getName
where
getName ListElem{ leElem = StructPair{..} } =
pure (spKey, spVal)
Expand Down Expand Up @@ -1562,7 +1548,7 @@ mkTypemap (thriftName, opts@Options{..}) imap =
-- Topologically sort the Decls based on dependencies
-- This function will fail if there is a cycle
sortDecls :: [Parsed Decl] -> Either [TypeError l] [Parsed Decl]
sortDecls decls = mapE getVertex sccs
sortDecls decls = traverse getVertex sccs
where
sccs = stronglyConnComp graph
graph = mapMaybe mkVertex decls
Expand Down Expand Up @@ -1691,7 +1677,7 @@ mkSchema Struct{..} = buildSchema structMembers
buildSchema (field@Field{..} : fields) = do
(rty, tschema) <- (,)
<$> resolveAnnotatedType fieldType
`collect` buildSchema fields
<*> buildSchema fields
opts@Options{..} <- options <$> ask
let renamed =
Text.unpack $ renameField opts (getAnns structAnns) structName field
Expand Down Expand Up @@ -1813,7 +1799,7 @@ mkServiceMap (thriftName, opts@Options{..}) imap =

sortServices
:: [Parsed Service] -> Either [TypeError l] [Parsed Service]
sortServices services = mapE getVertex sccs
sortServices services = traverse getVertex sccs
where
sccs = stronglyConnComp graph
graph = map mkVertex services
Expand Down
Loading

0 comments on commit c39a45c

Please sign in to comment.