Skip to content

Commit

Permalink
New :table command (#376)
Browse files Browse the repository at this point in the history
New `:table` REPL command which can print formatted tables for lists, tuples, or functions.
  • Loading branch information
byorgey authored Mar 8, 2024
1 parent 474384f commit 5a694e3
Show file tree
Hide file tree
Showing 13 changed files with 435 additions and 59 deletions.
13 changes: 10 additions & 3 deletions disco.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,12 @@ extra-source-files: stack.yaml, repl/*.hs
test/syntax-tuples/input
test/syntax-types/expected
test/syntax-types/input
test/table-error/expected
test/table-error/input
test/table-function/expected
test/table-function/input
test/table-list/expected
test/table-list/input
test/types-192/expected
test/types-192/input
test/types-306/expected
Expand Down Expand Up @@ -462,7 +468,7 @@ library
filepath,
directory,
mtl >=2.2 && <2.4,
megaparsec >= 6.1.1 && < 9.6,
megaparsec >= 6.1.1 && < 9.7,
parser-combinators >= 1.0.0 && < 1.4,
prettyprinter >=1.7 && < 1.8,
split >= 0.2 && < 0.3,
Expand All @@ -488,7 +494,8 @@ library
optparse-applicative >= 0.12 && < 0.19,
-- oeis2 < 1.1,
algebraic-graphs >= 0.5 && < 0.8,
pretty-show >= 1.10 && < 1.11
pretty-show >= 1.10 && < 1.11,
boxes >= 0.1.5 && < 0.2

hs-source-dirs: src
default-language: Haskell2010
Expand All @@ -505,7 +512,7 @@ executable disco
haskeline >=0.8 && <0.9,
mtl >=2.2 && <2.4,
transformers >= 0.4 && < 0.7,
megaparsec >= 6.1.1 && < 9.6,
megaparsec >= 6.1.1 && < 9.7,
containers >= 0.5 && < 0.7,
unbound-generics >= 0.3 && < 0.4.3,
lens >= 4.14 && < 5.3,
Expand Down
213 changes: 179 additions & 34 deletions src/Disco/Interactive/Commands.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE StandaloneDeriving #-}

-- |
Expand Down Expand Up @@ -29,38 +28,24 @@ import Control.Lens (
(^.),
)
import Control.Monad.Except
import Data.Bifunctor (second)
import Data.Char (isSpace)
import Data.Coerce
import Data.List (find, isPrefixOf, sortBy)
import Data.List (find, isPrefixOf, sortBy, transpose)
import Data.Map ((!))
import qualified Data.Map as M
import Data.Typeable
import System.FilePath (splitFileName)
import Prelude as P

import Text.Megaparsec hiding (State, runParser)
import qualified Text.Megaparsec.Char as C
import Unbound.Generics.LocallyNameless (
Name,
name2String,
string2Name,
)

import Disco.Effects.Input
import Disco.Effects.LFresh
import Disco.Effects.State
import Polysemy
import Polysemy.Error hiding (try)
import Polysemy.Output
import Polysemy.Reader

import Data.Maybe (mapMaybe, maybeToList)
import Data.Typeable
import Disco.AST.Surface
import Disco.AST.Typed
import Disco.Compile
import Disco.Context as Ctx
import Disco.Desugar
import Disco.Doc
import Disco.Effects.Input
import Disco.Effects.LFresh
import Disco.Effects.State
import Disco.Enumerate (enumerateType)
import Disco.Error
import Disco.Eval
import Disco.Extensions
Expand Down Expand Up @@ -89,8 +74,23 @@ import Disco.Syntax.Prims (
)
import Disco.Typecheck
import Disco.Typecheck.Erase
import Disco.Types (toPolyType, pattern TyString)
import Disco.Types
import Disco.Util (maximum0)
import Disco.Value
import Polysemy
import Polysemy.Error hiding (try)
import Polysemy.Output
import Polysemy.Reader
import System.FilePath (splitFileName)
import Text.Megaparsec hiding (State, runParser)
import qualified Text.Megaparsec.Char as C
import qualified Text.PrettyPrint.Boxes as B
import Unbound.Generics.LocallyNameless (
Name,
name2String,
string2Name,
)
import Prelude as P

------------------------------------------------------------
-- REPL expression type
Expand All @@ -106,6 +106,7 @@ data REPLExpr :: CmdTag -> * where
Parse :: Term -> REPLExpr 'CParse -- Show the parsed AST
Pretty :: Term -> REPLExpr 'CPretty -- Pretty-print a term
Print :: Term -> REPLExpr 'CPrint -- Print a string
Table :: Term -> REPLExpr 'CTable -- Print a table
Ann :: Term -> REPLExpr 'CAnn -- Show type-annotated term
Desugar :: Term -> REPLExpr 'CDesugar -- Show a desugared term
Compile :: Term -> REPLExpr 'CCompile -- Show a compiled term
Expand Down Expand Up @@ -150,6 +151,7 @@ data CmdTag
| CParse
| CPretty
| CPrint
| CTable
| CAnn
| CDesugar
| CCompile
Expand Down Expand Up @@ -229,6 +231,7 @@ discoCommands =
, SomeCmd parseCmd
, SomeCmd prettyCmd
, SomeCmd printCmd
, SomeCmd tableCmd
, SomeCmd reloadCmd
, SomeCmd showDefnCmd
, SomeCmd typeCheckCmd
Expand Down Expand Up @@ -538,7 +541,7 @@ handleEval ::
REPLExpr 'CEval ->
Sem r ()
handleEval (Eval m) = do
mi <- inputToState @TopInfo $ loadParsedDiscoModule False FromCwdOrStdlib REPLModule m
mi <- loadParsedDiscoModule False FromCwdOrStdlib REPLModule m
addToREPLModule mi
forM_ (mi ^. miTerms) (mapError EvalErr . evalTerm True . fst)

Expand Down Expand Up @@ -589,7 +592,7 @@ handleHelp Help =
sortedList cmds =
sortBy (\(SomeCmd x) (SomeCmd y) -> compare (name x) (name y)) $ filteredCommands cmds
showCmd c = text (padRight (helpcmd c) maxlen ++ " " ++ shortHelp c)
longestCmd cmds = maximum $ map (\(SomeCmd c) -> length $ helpcmd c) cmds
longestCmd cmds = maximum0 $ map (\(SomeCmd c) -> length $ helpcmd c) cmds
padRight s maxsize = take maxsize (s ++ repeat ' ')
-- don't show dev-only commands by default
filteredCommands = P.filter (\(SomeCmd c) -> category c == User)
Expand Down Expand Up @@ -781,6 +784,148 @@ handlePrint (Print t) = do
v <- mapError EvalErr . evalTerm False $ at
info $ text (vlist vchar v)

------------------------------------------------------------
-- :table

tableCmd :: REPLCommand 'CTable
tableCmd =
REPLCommand
{ name = "table"
, helpcmd = ":table <expr>"
, shortHelp = "Print a formatted table for a list or function"
, category = User
, cmdtype = ColonCmd
, action = handleTable
, parser = Table <$> parseTermOrOp
}

handleTable :: Members (Error DiscoError ': State TopInfo ': Output (Message ()) ': EvalEffects) r => REPLExpr 'CTable -> Sem r ()
handleTable (Table t) = do
(at, ty) <- inputToState . typecheckTop $ inferTop t
v <- mapError EvalErr . evalTerm False $ at

tydefs <- use @TopInfo (replModInfo . to allTydefs)
info $ runInputConst tydefs $ formatTableFor ty v >>= text

-- | The max number of rows to show in the output of :table.
maxFunTableRows :: Int
maxFunTableRows = 25

-- | Uncurry a type, turning a type of the form A -> B -> ... -> Y ->
-- Z into the pair of types (A * B * ... * Y * Unit, Z). Note we do
-- not optimize away the Unit at the end of the chain, since this
-- needs to be an isomorphism. Otherwise we would not be able to
-- distinguish between e.g. Z and Unit -> Z.
uncurryTy :: Type -> (Type, Type)
uncurryTy (tyA :->: tyB) = (tyA :*: tyAs, tyRes)
where
(tyAs, tyRes) = uncurryTy tyB
uncurryTy ty = (TyUnit, ty)

-- | Evaluate the application of a curried function to an uncurried
-- input.
evalCurried :: Members EvalEffects r => Type -> Value -> Type -> Value -> Sem r Value
evalCurried (_ :->: tyB) f (_ :*: tyY) v = do
let (v1, v2) = vpair id id v
f' <- evalApp f [v1]
evalCurried tyB f' tyY v2
evalCurried _ v _ _ = return v

formatTableFor ::
Members (LFresh ': Input TyDefCtx ': EvalEffects) r =>
PolyType ->
Value ->
Sem r String
formatTableFor (Forall bnd) v = lunbind bnd $ \(vars, ty) ->
case vars of
[] -> case ty of
TyList ety -> do
byRows <- mapM (formatCols TopLevel ety) . vlist id $ v
return $ renderTable byRows
(_ :->: _) -> do
let (tyInputs, tyRes) = uncurryTy ty
vs = take (maxFunTableRows + 1) $ enumerateType tyInputs
(tyInputs', stripV) = stripFinalUnit tyInputs
results <- mapM (evalCurried ty v tyInputs) vs
byRows <-
mapM
(formatCols TopLevel (tyInputs' :*: tyRes))
(zipWith (curry (pairv id id)) (take maxFunTableRows (map stripV vs)) results)
return $ renderTable (byRows ++ [[(B.left, "...")] | length vs == maxFunTableRows + 1])
_otherTy -> do
tyStr <- prettyStr ty
return $ "Don't know how to make a table for type " ++ tyStr
_vars -> return "Can't make a table for a polymorphic type"

-- | Strip the unit type from the end of a chain like (tA :*: (tB :*: (tC :*: Unit))),
-- which is an output of 'uncurryTy', and return a function to make the corresponding
-- change to a value of that type.
stripFinalUnit :: Type -> (Type, Value -> Value)
stripFinalUnit (tA :*: TyUnit) = (tA, fst . vpair id id)
stripFinalUnit (tA :*: tB) = (tA :*: tB', pairv id id . second v' . vpair id id)
where
(tB', v') = stripFinalUnit tB
stripFinalUnit ty = (ty, id)

data Level = TopLevel | NestedPair | InnerLevel
deriving (Eq, Ord, Show)

-- | Turn a value into a list of formatted columns in a type-directed
-- way. Lists and tuples are only split out into columns if they
-- occur at the top level; lists or tuples nested inside of other
-- data structures are simply pretty-printed. However, note we have
-- to make a special case for nested tuples: if a pair type occurs
-- at the top level we keep recursively splitting out its children
-- into columns as long as they are also pair types.
--
-- Any value of a type other than a list or tuple is simply
-- pretty-printed.
formatCols ::
(Member LFresh r, Member (Input TyDefCtx) r) =>
Level ->
Type ->
Value ->
Sem r [(B.Alignment, String)]
formatCols l (t1 :*: t2) (vpair id id -> (v1, v2))
| l `elem` [TopLevel, NestedPair] =
(++) <$> formatCols NestedPair t1 v1 <*> formatCols NestedPair t2 v2
-- Special case for String (= List Char), just print as string value
formatCols TopLevel TyString v = formatColDefault TyString v
-- For any other lists @ top level, print each element in a separate column
formatCols TopLevel (TyList ety) (vlist id -> vs) =
concat <$> mapM (formatCols InnerLevel ety) vs
formatCols _ ty v = formatColDefault ty v

-- | Default formatting of a typed column value by simply
-- pretty-printing it, and using the alignment appropriate for its
-- type.
formatColDefault ::
(Member (Input TyDefCtx) r, Member LFresh r) =>
Type ->
Value ->
Sem r [(B.Alignment, String)]
formatColDefault ty v = (: []) . (alignmentForType ty,) <$> renderDoc (prettyValue ty v)

alignmentForType :: Type -> B.Alignment
alignmentForType ty | ty `elem` [TyN, TyZ, TyF, TyQ] = B.right
alignmentForType _ = B.left

-- | Render a table, given as a list of rows, formatting it so that
-- each column is aligned.
renderTable :: [[(B.Alignment, String)]] -> String
renderTable = stripTrailingWS . B.render . B.hsep 2 B.top . map renderCol . transpose . pad
where
pad :: [[(B.Alignment, String)]] -> [[(B.Alignment, String)]]
pad rows = map (padTo (maximum0 . map length $ rows)) rows
padTo n = take n . (++ repeat (B.left, ""))

renderCol :: [(B.Alignment, String)] -> B.Box
renderCol [] = B.nullBox
renderCol ((align, x) : xs) = B.vcat align . map B.text $ x : map snd xs

stripTrailingWS = unlines . map stripEnd . lines
stripEnd = reverse . dropWhile isSpace . reverse

------------------------------------------------------------
-- :reload

Expand Down Expand Up @@ -837,7 +982,7 @@ handleShowDefn (ShowDefn x) = do
let ds = map (pretty' . snd) xdefs ++ maybe [] (pure . pretty' . (name2s,)) mtydef
case ds of
[] -> text "No definition for" <+> pretty' x
_ -> vcat ds
_nonEmptyList -> vcat ds

------------------------------------------------------------
-- :test
Expand Down Expand Up @@ -877,7 +1022,7 @@ typeCheckCmd =
, category = Dev
, cmdtype = ColonCmd
, action = inputToState @TopInfo . handleTypeCheck
, parser = parseTypeCheck
, parser = TypeCheck <$> parseTermOrOp
}

handleTypeCheck ::
Expand All @@ -888,17 +1033,17 @@ handleTypeCheck (TypeCheck t) = do
(_, sig) <- typecheckTop $ inferTop t
info $ pretty' t <+> text ":" <+> pretty' sig

parseTypeCheck :: Parser (REPLExpr 'CTypeCheck)
parseTypeCheck =
TypeCheck
<$> ( (try term <?> "expression")
<|> (parseNakedOp <?> "operator")
)
------------------------------------------------------------

-- In a :type or :doc command, allow naked operators, as in :type + ,
-- In :type, :doc, or :table commands, allow naked operators, as in :type + ,
-- even though + by itself is not a syntactically valid term.
-- However, this seems like it may be a common thing for a student to
-- ask and there is no reason we can't have this as a special case.
parseTermOrOp :: Parser Term
parseTermOrOp =
(try term <?> "expression")
<|> (parseNakedOp <?> "operator")

parseNakedOp :: Parser Term
parseNakedOp = TPrim <$> parseNakedOpPrim

Expand Down
Loading

0 comments on commit 5a694e3

Please sign in to comment.