diff --git a/example/Example.hs b/example/Example.hs index f52b0f9..4835b27 100644 --- a/example/Example.hs +++ b/example/Example.hs @@ -36,7 +36,7 @@ module Example where import Circuit -import Clash.Prelude (Signal, Vec (..)) +import Clash.Prelude idCircuit :: Circuit a a idCircuit = idC @@ -88,7 +88,7 @@ sigExpr sig = circuit do -- sigPat :: (( Signal Int -> Signal Int )) sigPat :: Circuit (Signal domain Int) (Signal domain Int) sigPat = circuit $ \(Signal a) -> do - i <- (idC :: Circuit (Signal domain Int) (Signal domain Int)) -< Signal a + i <- idC -< Signal a idC -< i sigPat2 :: Circuit (Signal dom Int) (Signal dom Int) @@ -96,6 +96,17 @@ sigPat2 = circuit $ \(Signal a) -> do i <- (idC :: Circuit (Signal dom Int) (Signal dom Int)) -< Signal a idC -< i +fwdCircuit :: Circuit (Vec 3 (Signal dom Int)) (Vec 3 (Signal dom Int)) +fwdCircuit = circuit $ \(Fwd x) -> do + i <- idC -< Fwd (fmap (+1) x) + idC -< i + +fwdWithLetCircuit :: KnownNat n => Circuit (Vec n (Signal dom Int)) (Vec n (Signal dom Int)) +fwdWithLetCircuit = circuit $ \(Fwd x) -> do + let y = fmap (+1) x + i <- idC -< Fwd y + idC -< i + fstC :: Circuit (Signal domain a, Signal domain b) (Signal domain a) fstC = circuit $ \(a, _b) -> do idC -< a @@ -122,6 +133,21 @@ unfstC3 = circuit $ \a -> do ab' <- idC -< ab idC -< ab' +-- a version of `idC` on `Signal domain Int` which has bad type inference. +idCHard + :: (Fwd a ~ Signal domain Int, Bwd a ~ (), Fwd b ~ Signal domain Int, Bwd b ~ ()) + => Circuit a b +idCHard = Circuit $ \ (aFwd :-> ()) -> () :-> aFwd + +typedBus1 :: forall domain . Circuit (Signal domain Int) (Signal domain Int) +typedBus1 = circuit $ \a -> do + (b :: Signal domain Int) <- idCHard -< a + idCHard -< b + +typedBus2 :: forall domain . Circuit (Signal domain Int) (Signal domain Int) +typedBus2 = circuit $ \a -> do + b <- idCHard -< a + idCHard -< (b :: Signal domain Int) swapTest :: forall a b. Circuit (a,b) (b,a) -- swapTest = circuit $ \(a,b) -> (idCircuit :: Circuit (b, a) (b, a)) -< (b, a) @@ -152,16 +178,16 @@ dupSignalC1 = circuit $ \x -> do -- -- myDesire = Circuit (\(aM2S,bS2M) -> let -- -- (aM2S', bS2M') = runCircuit myCircuit (aM2S, bS2M) -- -- in (aM2S', bS2M')) - +-- -- -- var :: (Int, Int) -- -- var = (3, 5) - +-- -- -- myLet :: Int -- -- myLet = let (yo, yo') = var in yo - +-- -- -- ah :: (Int,Int) -- -- ah = (7,11) - +-- -- -- tupCir1 :: Circuit (Int, Char) (Char, Int) -- -- tupCir1 = circuit \ input -> do -- -- (c,i) <- swapC @Int -< input @@ -170,7 +196,7 @@ dupSignalC1 = circuit $ \x -> do -- -- c' <- myCircuitRev -< c -- -- c'' <- myIdCircuit -< c' -- -- idC -< (i', c'') - +-- -- tupleCircuit :: Circuit Int Char -- tupleCircuit = id $ circuit \a -> do -- let b = 3 @@ -179,7 +205,7 @@ dupSignalC1 = circuit $ \x -> do -- b' <- myCircuit -< a' -- b'' <- (circuit \aa -> do idC -< aa) -< b' -- idC -< b'' - +-- -- -- simpleCircuit :: Circuit Int Char -- -- simpleCircuit = id $ circuit \a -> do -- -- b <- (circuit \a -> do b <- myCircuit -< a;idC -< b) -< a @@ -187,7 +213,7 @@ dupSignalC1 = circuit $ \x -> do -- -- b' <- myCircuit -< a' -- -- b'' <- (circuit \aa -> do idC -< aa) -< b' -- -- idC -< b'' - +-- -- myCircuit :: Int -- myCircuit = circuit \(v1 :: DF d a) (v3 :: blah) -> do -- v1' <- total -< (v3 :: DF domain Int) -< (v4 :: DF domain Int) @@ -196,39 +222,39 @@ dupSignalC1 = circuit $ \x -> do -- -- v2' <- total2 -< v2 -- -- v3 <- zipC -< (v1', v2') -- v1 <- idC -< v3 - +-- -- -- type RunCircuit a b = (Circuit a b -> (M2S a, S2M b) -> (M2S b, S2M a)) -- -- type CircuitId a b = Circuit a b -> Circuit a b - +-- -- -- myCircuit = let -- -- _circuits :: (RunCircuit a b, RunCircuit c d, RunCircuit (b,d) e, CircuitId (a,c) e) -- -- _circuits@(runC1, runC2, runC2, cId) = (runCircuit, runCircuit, runCircuit, id) - +-- -- -- in cId $ Circuit $ \((v1M2S, v2M2S),outputS2M) -> let - +-- -- -- (v1'M2S, v1S2M) = runC1 total (v1M2s, v1'S2M) -- -- (v2'M2S, v2S2M) = runC2 total2 (v2M2s, v2'S2M) -- -- (v3M2S, (v1'S2M, v2'S2M)) = runC3 zipC ((v1'M2S, v2'M2S), v3S2M) - +-- -- -- in (v3M2S, (v1S2M, v2S2M)) - - - - +-- +-- +-- +-- -- -- circuitHelper -- -- :: Circuit a b -- -- -> Circuit c d -- -- -> Circuit (b,d) e - - +-- +-- -- -- myCircuit :: Int -- -- myCircuit = circuit (\(v1,v2) -> (v2,v1)) - +-- -- -- myCircuit :: Int -- -- myCircuit = circuit do -- -- (v2,v1) <- yeah -- -- idC -< (v1, v2) - +-- -- -- myCircuit = proc v1 -> do -- -- x <- total -< value -- -- fin -< a diff --git a/shell.nix b/shell.nix index 97f2f83..ed4af77 100644 --- a/shell.nix +++ b/shell.nix @@ -8,8 +8,8 @@ stdenv.mkDerivation { buildInputs = [ ghc - # cabal-install - # haskellPackages.ghcid + cabal-install + haskellPackages.ghcid haskellPackages.stylish-haskell ]; diff --git a/src/Circuit.hs b/src/Circuit.hs index 1eba6ce..12c9cee 100644 --- a/src/Circuit.hs +++ b/src/Circuit.hs @@ -10,15 +10,18 @@ This file contains the 'Circuit' type, that the notation describes. -} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE NoImplicitPrelude #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeFamilies #-} - -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} module Circuit where @@ -70,6 +73,26 @@ type instance Bwd (Signal dom a) = () newtype Circuit a b = Circuit { runCircuit :: CircuitT a b } type CircuitT a b = (Fwd a :-> Bwd b) -> (Bwd a :-> Fwd b) + +type TagCircuitT a b = (BusTag a (Fwd a) :-> BusTag b (Bwd b)) -> (BusTag a (Bwd a) :-> BusTag b (Fwd b)) + +newtype BusTag t b = BusTag {unBusTag :: b} + +mkTagCircuit :: TagCircuitT a b -> Circuit a b +mkTagCircuit f = Circuit $ \ (aFwd :-> bBwd) -> let + (BusTag aBwd :-> BusTag bFwd) = f (BusTag aFwd :-> BusTag bBwd) + in (aBwd :-> bFwd) + +runTagCircuit :: Circuit a b -> TagCircuitT a b +runTagCircuit (Circuit c) (aFwd :-> bBwd) = let + (aBwd :-> bFwd) = c (unBusTag aFwd :-> unBusTag bBwd) + in (BusTag aBwd :-> BusTag bFwd) + +pattern TagCircuit :: TagCircuitT a b -> Circuit a b +pattern TagCircuit f <- (runTagCircuit -> f) where + TagCircuit f = mkTagCircuit f + + class TrivialBwd a where unitBwd :: a @@ -96,3 +119,83 @@ instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d, TrivialBwd e) instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d, TrivialBwd e, TrivialBwd f) => TrivialBwd (a,b,c,d,e,f) where unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd) + +instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d, TrivialBwd e, TrivialBwd f, TrivialBwd g) => TrivialBwd (a,b,c,d,e,f,g) where + unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd) + +instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d, TrivialBwd e, TrivialBwd f, TrivialBwd g, TrivialBwd h) => TrivialBwd (a,b,c,d,e,f,g,h) where + unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd) + +instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d, TrivialBwd e, TrivialBwd f, TrivialBwd g, TrivialBwd h, TrivialBwd i) => TrivialBwd (a,b,c,d,e,f,g,h,i) where + unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd) + +instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d, TrivialBwd e, TrivialBwd f, TrivialBwd g, TrivialBwd h, TrivialBwd i, TrivialBwd j) => TrivialBwd (a,b,c,d,e,f,g,h,i,j) where + unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd) + +instance TrivialBwd a => TrivialBwd (BusTag t a) where + unitBwd = BusTag unitBwd + +class BusTagBundle t a where + type BusTagUnbundled t a = res | res -> t a + taggedBundle :: BusTagUnbundled t a -> BusTag t a + taggedUnbundle :: BusTag t a -> BusTagUnbundled t a + +instance BusTagBundle () () where + type BusTagUnbundled () () = () + taggedBundle = BusTag + taggedUnbundle = unBusTag + +instance BusTagBundle (ta, tb) (a, b) where + type BusTagUnbundled (ta, tb) (a, b) = (BusTag ta a, BusTag tb b) + taggedBundle (BusTag a, BusTag b) = BusTag (a, b) + taggedUnbundle (BusTag (a, b)) = (BusTag a, BusTag b) + +instance BusTagBundle (ta, tb, tc) (a, b, c) where + type BusTagUnbundled (ta, tb, tc) (a, b, c) = (BusTag ta a, BusTag tb b, BusTag tc c) + taggedBundle (BusTag a, BusTag b, BusTag c) = BusTag (a, b, c) + taggedUnbundle (BusTag (a, b, c)) = (BusTag a, BusTag b, BusTag c) + +instance BusTagBundle (ta, tb, tc, td) (a, b, c, d) where + type BusTagUnbundled (ta, tb, tc, td) (a, b, c, d) = (BusTag ta a, BusTag tb b, BusTag tc c, BusTag td d) + taggedBundle (BusTag a, BusTag b, BusTag c, BusTag d) = BusTag (a, b, c, d) + taggedUnbundle (BusTag (a, b, c, d)) = (BusTag a, BusTag b, BusTag c, BusTag d) + +instance BusTagBundle (ta, tb, tc, td, te) (a, b, c, d, e) where + type BusTagUnbundled (ta, tb, tc, td, te) (a, b, c, d, e) = (BusTag ta a, BusTag tb b, BusTag tc c, BusTag td d, BusTag te e) + taggedBundle (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e) = BusTag (a, b, c, d, e) + taggedUnbundle (BusTag (a, b, c, d, e)) = (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e) + +instance BusTagBundle (ta, tb, tc, td, te, tf) (a, b, c, d, e, f) where + type BusTagUnbundled (ta, tb, tc, td, te, tf) (a, b, c, d, e, f) = (BusTag ta a, BusTag tb b, BusTag tc c, BusTag td d, BusTag te e, BusTag tf f) + taggedBundle (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e, BusTag f) = BusTag (a, b, c, d, e, f) + taggedUnbundle (BusTag (a, b, c, d, e, f)) = (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e, BusTag f) + +instance BusTagBundle (ta, tb, tc, td, te, tf, tg) (a, b, c, d, e, f, g) where + type BusTagUnbundled (ta, tb, tc, td, te, tf, tg) (a, b, c, d, e, f, g) = (BusTag ta a, BusTag tb b, BusTag tc c, BusTag td d, BusTag te e, BusTag tf f, BusTag tg g) + taggedBundle (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e, BusTag f, BusTag g) = BusTag (a, b, c, d, e, f, g) + taggedUnbundle (BusTag (a, b, c, d, e, f, g)) = (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e, BusTag f, BusTag g) + +instance BusTagBundle (ta, tb, tc, td, te, tf, tg, th) (a, b, c, d, e, f, g, h) where + type BusTagUnbundled (ta, tb, tc, td, te, tf, tg, th) (a, b, c, d, e, f, g, h) = (BusTag ta a, BusTag tb b, BusTag tc c, BusTag td d, BusTag te e, BusTag tf f, BusTag tg g, BusTag th h) + taggedBundle (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e, BusTag f, BusTag g, BusTag h) = BusTag (a, b, c, d, e, f, g, h) + taggedUnbundle (BusTag (a, b, c, d, e, f, g, h)) = (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e, BusTag f, BusTag g, BusTag h) + +instance BusTagBundle (ta, tb, tc, td, te, tf, tg, th, ti) (a, b, c, d, e, f, g, h, i) where + type BusTagUnbundled (ta, tb, tc, td, te, tf, tg, th, ti) (a, b, c, d, e, f, g, h, i) = (BusTag ta a, BusTag tb b, BusTag tc c, BusTag td d, BusTag te e, BusTag tf f, BusTag tg g, BusTag th h, BusTag ti i) + taggedBundle (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e, BusTag f, BusTag g, BusTag h, BusTag i) = BusTag (a, b, c, d, e, f, g, h, i) + taggedUnbundle (BusTag (a, b, c, d, e, f, g, h, i)) = (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e, BusTag f, BusTag g, BusTag h, BusTag i) + +instance BusTagBundle (ta, tb, tc, td, te, tf, tg, th, ti, tj) (a, b, c, d, e, f, g, h, i, j) where + type BusTagUnbundled (ta, tb, tc, td, te, tf, tg, th, ti, tj) (a, b, c, d, e, f, g, h, i, j) = (BusTag ta a, BusTag tb b, BusTag tc c, BusTag td d, BusTag te e, BusTag tf f, BusTag tg g, BusTag th h, BusTag ti i, BusTag tj j) + taggedBundle (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e, BusTag f, BusTag g, BusTag h, BusTag i, BusTag j) = BusTag (a, b, c, d, e, f, g, h, i, j) + taggedUnbundle (BusTag (a, b, c, d, e, f, g, h, i, j)) = (BusTag a, BusTag b, BusTag c, BusTag d, BusTag e, BusTag f, BusTag g, BusTag h, BusTag i, BusTag j) + +instance BusTagBundle (Vec n t) (Vec n a) where + type BusTagUnbundled (Vec n t) (Vec n a) = Vec n (BusTag t a) + taggedBundle = BusTag . fmap unBusTag + taggedUnbundle = fmap BusTag . unBusTag + +pattern BusTagBundle :: BusTagBundle t a => BusTagUnbundled t a -> BusTag t a +pattern BusTagBundle a <- (taggedUnbundle -> a) where + BusTagBundle a = taggedBundle a +{-# COMPLETE BusTagBundle #-} diff --git a/src/CircuitNotation.hs b/src/CircuitNotation.hs index 66c91d7..b426e68 100644 --- a/src/CircuitNotation.hs +++ b/src/CircuitNotation.hs @@ -122,7 +122,7 @@ import GHC.Types.Unique.Map import GHC.Types.Unique.Map.Extra -- clash-prelude -import Clash.Prelude (Signal, Vec((:>), Nil)) +import Clash.Prelude (Vec((:>), Nil)) -- lens import qualified Control.Lens as L @@ -242,8 +242,8 @@ data PortDescription a | Ref a | RefMulticast a | Lazy SrcSpanAnnA (PortDescription a) - | SignalExpr (LHsExpr GhcPs) - | SignalPat (LPat GhcPs) + | FwdExpr (LHsExpr GhcPs) + | FwdPat (LPat GhcPs) | PortType (LHsType GhcPs) (PortDescription a) | PortErr SrcSpanAnnA MsgDoc deriving (Foldable, Functor, Traversable) @@ -453,42 +453,6 @@ thName nm = [name] -> name _ -> error "thName called on a non NameG Name" --- | Make a type signature from a port description. Things without a concrete type (e.g. Signal a), --- are given a type name based on the location of the port. -portTypeSigM :: (p ~ GhcPs, ?nms :: ExternalNames) => PortDescription PortName -> CircuitM (LHsType p) -portTypeSigM = \case - Tuple ps -> tupT <$> mapM portTypeSigM ps - Vec s ps -> vecT s <$> mapM portTypeSigM ps - Ref (PortName loc fs) -> do - L.use portVarTypes >>= \pvt -> - case lookupUniqMap pvt fs of - Nothing -> - let - -- GHC >= 9.2 interprets any type variable name starting with a "_" as - -- a wildcard and throws an error suggesting a concrete type. To prevent - -- this error from cropping up, we prefix it with "dflt" if we detect an - -- underscore. Note that we see "_" in cases where the user wants to ignore - -- a certain protocol, hence then name "dflt". - s0 = GHC.unpackFS fs - s1 | '_':_ <- s0 = "dflt" <> s0 - | otherwise = s0 - in - pure $ varT loc (s1 <> "Ty") - Just (_sigLoc, sig) -> pure sig - RefMulticast p -> portTypeSigM (Ref p) - PortErr loc msgdoc -> do - dflags <- GHC.getDynFlags - unsafePerformIO . throwOneError $ - mkLongErrMsg dflags (locA loc) Outputable.alwaysQualify (Outputable.text "portTypeSig") msgdoc - Lazy _ p -> portTypeSigM p - SignalExpr (L l _) -> do - n <- uniqueCounter <<+= 1 - pure $ (conT l (thName ''Signal)) `appTy` (varT l (genLocName l "dom")) `appTy` (varT l (genLocName l ("sig_" <> show n))) - SignalPat (L l _) -> do - n <- uniqueCounter <<+= 1 - pure $ (conT l (thName ''Signal)) `appTy` (varT l (genLocName l "dom")) `appTy` (varT l (genLocName l ("sig_" <> show n))) - PortType _ p -> portTypeSigM p - -- | Generate a "unique" name by appending the location as a string. genLocName :: SrcSpanAnnA -> String -> String #if __GLASGOW_HASKELL__ >= 902 @@ -671,7 +635,7 @@ bindSlave (L loc expr) = case expr of #else ConPatIn (L _ (GHC.Unqual occ)) (PrefixCon [lpat]) #endif - | OccName.occNameString occ == "Signal" -> SignalPat lpat + | OccName.occNameString occ `elem` fwdNames -> FwdPat lpat -- empty list is done as the constructor #if __GLASGOW_HASKELL__ >= 900 ConPat _ (L _ rdr) _ @@ -706,7 +670,7 @@ bindMaster (L loc expr) = case expr of | rdrName == thName '[] -> Vec loc [] -- XXX: vloc? | otherwise -> Ref (PortName loc (fromRdrName rdrName)) -- XXX: vloc? HsApp _xapp (L _ (HsVar _ (L _ (GHC.Unqual occ)))) sig - | OccName.occNameString occ == "Signal" -> SignalExpr sig + | OccName.occNameString occ `elem` fwdNames -> FwdExpr sig ExplicitTuple _ tups _ -> let #if __GLASGOW_HASKELL__ >= 902 vals = fmap (\(Present _ e) -> e) tups @@ -722,16 +686,18 @@ bindMaster (L loc expr) = case expr of Vec loc $ fmap bindMaster exprs #if __GLASGOW_HASKELL__ < 810 HsArrApp _xapp (L _ (HsVar _ (L _ (GHC.Unqual occ)))) sig _ _ - | OccName.occNameString occ == "Signal" -> SignalExpr sig + | OccName.occNameString occ `elem` fwdNames -> FwdExpr sig ExprWithTySig ty expr' -> PortType (hsSigWcType ty) (bindMaster expr') ELazyPat _ expr' -> Lazy loc (bindMaster expr') #else -- XXX: Untested? HsProc _ _ (L _ (HsCmdTop _ (L _ (HsCmdArrApp _xapp (L _ (HsVar _ (L _ (GHC.Unqual occ)))) sig _ _)))) - | OccName.occNameString occ == "Signal" -> SignalExpr sig + | OccName.occNameString occ `elem` fwdNames -> FwdExpr sig ExprWithTySig _ expr' ty -> PortType (hsSigWcType ty) (bindMaster expr') #endif + HsPar _ expr' -> bindMaster expr' + -- OpApp _xapp (L _ circuitVar) (L _ infixVar) appR -> k _ -> PortErr loc @@ -825,12 +791,12 @@ checkCircuit = do -- Creating ------------------------------------------------------------ -data Direc = Fwd | Bwd deriving Show +data Direction = Fwd | Bwd deriving Show -bindWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> Direc -> PortDescription PortName -> LPat p +bindWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> Direction -> PortDescription PortName -> LPat p bindWithSuffix dflags dir = \case - Tuple ps -> tildeP noSrcSpanA $ tupP $ fmap (bindWithSuffix dflags dir) ps - Vec s ps -> vecP s $ fmap (bindWithSuffix dflags dir) ps + Tuple ps -> tildeP noSrcSpanA $ taggedBundleP $ tupP $ fmap (bindWithSuffix dflags dir) ps + Vec s ps -> taggedBundleP $ vecP s $ fmap (bindWithSuffix dflags dir) ps Ref (PortName loc fs) -> varP loc (GHC.unpackFS fs <> "_" <> show dir) RefMulticast (PortName loc fs) -> case dir of Bwd -> L loc (WildPat noExtField) @@ -840,14 +806,14 @@ bindWithSuffix dflags dir = \case Lazy loc p -> tildeP loc $ bindWithSuffix dflags dir p #if __GLASGOW_HASKELL__ >= 902 -- XXX: propagate location - SignalExpr (L _ _) -> nlWildPat + FwdExpr (L _ _) -> nlWildPat #else - SignalExpr (L l _) -> L l (WildPat noExt) + FwdExpr (L l _) -> L l (WildPat noExt) #endif - SignalPat lpat -> lpat - PortType _ p -> bindWithSuffix dflags dir p + FwdPat lpat -> tagP lpat + PortType _ty p -> bindWithSuffix dflags dir p -revDirec :: Direc -> Direc +revDirec :: Direction -> Direction revDirec = \case Fwd -> Bwd Bwd -> Fwd @@ -855,7 +821,7 @@ revDirec = \case bindOutputs :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags - -> Direc + -> Direction -> PortDescription PortName -- ^ slave ports -> PortDescription PortName @@ -866,10 +832,10 @@ bindOutputs dflags direc slaves masters = noLoc $ conPatIn (noLoc (fwdBwdCon ?nm m2s = bindWithSuffix dflags direc masters s2m = bindWithSuffix dflags (revDirec direc) slaves -expWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => Direc -> PortDescription PortName -> LHsExpr p +expWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => Direction -> PortDescription PortName -> LHsExpr p expWithSuffix dir = \case - Tuple ps -> tupE noSrcSpanA $ fmap (expWithSuffix dir) ps - Vec s ps -> vecE s $ fmap (expWithSuffix dir) ps + Tuple ps -> taggedBundleE $ tupE noSrcSpanA $ fmap (expWithSuffix dir) ps + Vec s ps -> taggedBundleE $ vecE s $ fmap (expWithSuffix dir) ps Ref (PortName loc fs) -> varE loc (var $ GHC.unpackFS fs <> "_" <> show dir) RefMulticast (PortName loc fs) -> case dir of Bwd -> varE noSrcSpanA (trivialBwd ?nms) @@ -877,13 +843,13 @@ expWithSuffix dir = \case -- laziness only affects the pattern side Lazy _ p -> expWithSuffix dir p PortErr _ _ -> error "expWithSuffix PortErr!" - SignalExpr lexpr -> lexpr - SignalPat (L l _) -> tupE l [] - PortType _ p -> expWithSuffix dir p + FwdExpr lexpr -> tagE lexpr + FwdPat (L l _) -> tagE $ varE l (trivialBwd ?nms) + PortType ty p -> tagTypeE dir ty (expWithSuffix dir p) createInputs :: (p ~ GhcPs, ?nms :: ExternalNames) - => Direc + => Direction -> PortDescription PortName -- ^ slave ports -> PortDescription PortName @@ -894,18 +860,18 @@ createInputs dir slaves masters = noLoc $ OpApp noExt s2m (varE noSrcSpanA (fwdB m2s = expWithSuffix (revDirec dir) masters s2m = expWithSuffix dir slaves -decFromBinding :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> Int -> Binding (LHsExpr p) PortName -> HsBind p -decFromBinding dflags i Binding {..} = do +decFromBinding :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> Binding (LHsExpr p) PortName -> HsBind p +decFromBinding dflags Binding {..} = do let bindPat = bindOutputs dflags Bwd bIn bOut inputExp = createInputs Fwd bOut bIn - bod = varE noSrcSpanA (var $ "run" <> show i) `appE` bCircuit `appE` inputExp + bod = runCircuitFun noSrcSpanA `appE` bCircuit `appE` inputExp in patBind bindPat bod patBind :: LPat GhcPs -> LHsExpr GhcPs -> HsBind GhcPs patBind lhs expr = PatBind noExt lhs rhs ([], []) where rhs :: GRHSs GhcPs (GenLocated SrcSpanAnnA (HsExpr GhcPs)) - rhs = GRHSs emptyComments [gr] $ + rhs = GRHSs emptyComments [gr] $ #if __GLASGOW_HASKELL__ >= 902 EmptyLocalBinds noExtField #else @@ -921,6 +887,62 @@ circuitConstructor loc = varE loc (circuitCon ?nms) runCircuitFun :: (?nms :: ExternalNames) => SrcSpanAnnA -> LHsExpr GhcPs runCircuitFun loc = varE loc (runCircuitName ?nms) + +#if __GLASGOW_HASKELL__ < 902 +prefixCon :: [arg] -> HsConDetails arg rec +prefixCon a = PrefixCon a +#else +prefixCon :: [arg] -> HsConDetails tyarg arg rec +prefixCon a = PrefixCon [] a +#endif + +taggedBundleP :: (p ~ GhcPs, ?nms :: ExternalNames) => LPat p -> LPat p +taggedBundleP a = noLoc (conPatIn (noLoc (tagBundlePat ?nms)) (prefixCon [a])) + +taggedBundleE :: (p ~ GhcPs, ?nms :: ExternalNames) => LHsExpr p -> LHsExpr p +taggedBundleE a = varE noSrcSpanA (tagBundlePat ?nms) `appE` a + +tagP :: (p ~ GhcPs, ?nms :: ExternalNames) => LPat p -> LPat p +tagP a = noLoc (conPatIn (noLoc (tagName ?nms)) (prefixCon [a])) + +tagE :: (p ~ GhcPs, ?nms :: ExternalNames) => LHsExpr p -> LHsExpr p +tagE a = varE noSrcSpanA (tagName ?nms) `appE` a + +tagTypeCon :: (p ~ GhcPs, ?nms :: ExternalNames) => LHsType GhcPs +tagTypeCon = noLoc (HsTyVar noExt NotPromoted (noLoc (tagTName ?nms))) + +sigPat :: (p ~ GhcPs) => SrcSpanAnnA -> LHsType GhcPs -> LPat p -> LPat p +sigPat loc ty a = L loc $ +#if __GLASGOW_HASKELL__ < 810 + SigPat (HsWC noExt (HsIB noExt ty)) a +#elif __GLASGOW_HASKELL__ < 900 + SigPat noExt a (HsWC noExt (HsIB noExt ty)) +#else + SigPat noExt a (HsPS noExt ty) +#endif + +sigE :: (?nms :: ExternalNames) => SrcSpanAnnA -> LHsType GhcPs -> LHsExpr GhcPs -> LHsExpr GhcPs +sigE loc ty a = L loc $ +#if __GLASGOW_HASKELL__ < 810 + ExprWithTySig (HsWC noExt (HsIB noExt ty)) a +#elif __GLASGOW_HASKELL__ < 902 + ExprWithTySig noExt a (HsWC noExt (HsIB noExt ty)) +#else + ExprWithTySig noExt a (HsWC noExtField (L loc $ HsSig noExtField (HsOuterImplicit noExtField) ty)) +#endif + +tagTypeP :: (p ~ GhcPs, ?nms :: ExternalNames) => Direction -> LHsType GhcPs -> LPat p -> LPat p +tagTypeP dir ty + = sigPat noSrcSpanA (tagTypeCon `appTy` ty `appTy` busType) + where + busType = conT noSrcSpanA (fwdAndBwdTypes ?nms dir) `appTy` ty + +tagTypeE :: (p ~ GhcPs, ?nms :: ExternalNames) => Direction -> LHsType GhcPs -> LHsExpr p -> LHsExpr p +tagTypeE dir ty a + = sigE noSrcSpanA (tagTypeCon `appTy` ty `appTy` busType) a + where + busType = conT noSrcSpanA (fwdAndBwdTypes ?nms dir) `appTy` ty + constVar :: SrcSpanAnnA -> LHsExpr GhcPs constVar loc = varE loc (thName 'const) @@ -951,20 +973,6 @@ varT loc nm = L loc (HsTyVar noExt NotPromoted (noLoc (tyVar nm))) conT :: SrcSpanAnnA -> GHC.RdrName -> LHsType GhcPs conT loc nm = L loc (HsTyVar noExt NotPromoted (noLoc nm)) -circuitTy :: (p ~ GhcPs, ?nms :: ExternalNames) => LHsType p -> LHsType p -> LHsType p -circuitTy a b = conT noSrcSpanA (circuitTyCon ?nms) `appTy` a `appTy` b - -circuitTTy :: (p ~ GhcPs, ?nms :: ExternalNames) => LHsType p -> LHsType p -> LHsType p -circuitTTy a b = conT noSrcSpanA (circuitTTyCon ?nms) `appTy` a `appTy` b - --- a b -> (Circuit a b -> CircuitT a b) -mkRunCircuitTy :: (p ~ GhcPs, ?nms :: ExternalNames) => LHsType p -> LHsType p -> LHsType p -mkRunCircuitTy a b = noLoc $ hsFunTy (circuitTy a b) (circuitTTy a b) - --- a b -> (CircuitT a b -> Circuit a b) -mkCircuitTy :: (p ~ GhcPs, ?nms :: ExternalNames) => LHsType p -> LHsType p -> LHsType p -mkCircuitTy a b = noLoc $ hsFunTy (circuitTTy a b) (circuitTy a b) - -- perhaps this should happen on construction gatherTypes :: p ~ GhcPs @@ -973,7 +981,7 @@ gatherTypes gatherTypes = L.traverseOf_ L.cosmos addTypes where addTypes = \case - PortType ty (Ref (PortName loc fs)) -> + PortType ty (Ref (PortName loc fs)) -> portVarTypes %= \pvt -> alterUniqMap (const (Just (loc, ty))) pvt fs PortType ty p -> portTypes <>= [(ty, p)] _ -> pure () @@ -999,10 +1007,8 @@ circuitQQExpM = do masters <- L.use circuitMasters -- Construction of the circuit expression - let decs = concat - [ lets - , imap (\i -> noLoc . decFromBinding dflags i) binds - ] + let decs = lets <> map (noLoc . decFromBinding dflags) binds + let pats = bindOutputs dflags Fwd masters slaves res = createInputs Bwd slaves masters @@ -1015,77 +1021,7 @@ circuitQQExpM = do binds mapM_ gatherTypes [masters, slaves] - slavesTy <- portTypeSigM slaves - mastersTy <- portTypeSigM masters - let mkRunTy bind = - mkRunCircuitTy <$> - (portTypeSigM (bOut bind)) <*> - (portTypeSigM (bIn bind)) - bindTypes <- mapM mkRunTy binds - let runCircuitsType = - noLoc (HsParTy noExt (tupT bindTypes `arrTy` circuitTTy slavesTy mastersTy)) - `arrTy` circuitTy slavesTy mastersTy - - allTypes <- L.use portTypes - - context <- mapM (\(ty, p) -> tyEq <$> portTypeSigM p <*> pure ty) allTypes - - -- the full signature - loc <- L.use circuitLoc - let inferenceHelperName = genLocName loc "inferenceHelper" - inferenceSig :: LHsSigType GhcPs -#if __GLASGOW_HASKELL__ >= 902 - inferenceSig = noLoc $ - HsSig - noExtField - (HsOuterImplicit noExtField) - (noLoc $ HsQualTy noExtField (Just (noLoc context)) runCircuitsType) -#else - inferenceSig = HsIB noExt (noLoc $ HsQualTy noExt (noLoc context) runCircuitsType) -#endif - - inferenceHelperTy = - TypeSig noExt - [noLoc (var inferenceHelperName)] - (HsWC noExtField inferenceSig) - - let numBinds = length binds - runCircuitExprs = lamE [varP noSrcSpanA "f"] $ - circuitConstructor noSrcSpanA `appE` - noLoc (HsPar noExt - (varE noSrcSpanA (var "f") `appE` tupE noSrcSpanA (replicate numBinds (runCircuitFun noSrcSpanA)))) - runCircuitBinds = tupP $ map (\i -> varP noSrcSpanA ("run" <> show i)) [0 .. numBinds-1] - - let c = letE noSrcSpanA - [noLoc inferenceHelperTy] - [noLoc $ patBind (varP noSrcSpanA inferenceHelperName) (runCircuitExprs)] - (varE noSrcSpanA (var inferenceHelperName) `appE` lamE [runCircuitBinds, pats] body) - -- ppr c - pure c - - -- pure $ varE noSrcSpan (var "undefined") - --- [inference-helper] --- The inference helper constructs the circuit and provides all the `runCircuit`s with the types --- matching the structure of the port expressions. This way we can enforce that ports 'keep the --- same type' which normally gets lost when deconstructing and reconstructing types. It also means --- that we can add type annotations of the ports as a context to this helper function. For example --- --- swapIC c = circuit $ \(a :: Int, b) -> do --- a' <- c -< a --- b' <- c -< b --- idC -< (b',a') --- --- will produce the helper --- --- inferenceHelper :: --- aTy ~ Int => --- -> ( (Circuit aTy a'Ty -> CircuitT aTy a'Ty) --- -> (Circuit bTy b'Ty -> CircuitT bTy b'Ty) --- -> CircuitT (aTy, bTy) (b'Ty, a'Ty) --- ) -> CircuitT (aTy, bTy) (b'Ty, a'Ty) --- inferenceHelper = \f -> Circuit (f runCircuit runCircuit) - + pure $ circuitConstructor noSrcSpanA `appE` lamE [pats] body grr :: MonadIO m => OccName.NameSpace -> m () grr nm @@ -1097,7 +1033,7 @@ grr nm | nm == OccName.tvName = liftIO $ putStrLn "tvName" | otherwise = liftIO $ putStrLn "I dunno" -completeUnderscores :: CircuitM () +completeUnderscores :: (?nms :: ExternalNames) => CircuitM () completeUnderscores = do binds <- L.use circuitBinds masters <- L.use circuitMasters @@ -1105,7 +1041,7 @@ completeUnderscores = do let addDef :: String -> PortDescription PortName -> CircuitM () addDef suffix = \case Ref (PortName loc (unpackFS -> name@('_':_))) -> do - let bind = patBind (varP loc (name <> suffix)) (varE loc (thName 'def)) + let bind = patBind (varP loc (name <> suffix)) (tagE $ varE loc (thName 'def)) circuitLets <>= [L loc bind] _ -> pure () @@ -1195,7 +1131,7 @@ pluginImpl cliOptions _modSummary m = do debug <- case cliOptions of [] -> pure False ["debug"] -> pure True - _ -> do + _ -> do warningMsg $ Outputable.text $ "CircuitNotation: unknown cli options " <> show cliOptions pure False hpm_module' <- do @@ -1213,28 +1149,33 @@ ppr a = do showC :: Data.Data a => a -> String showC a = show (typeOf a) <> " " <> show (Data.toConstr a) --- ppp :: MonadIO m => String -> m () --- ppp s = case SP.parseValue s of --- Just a -> valToStr a - -- Names --------------------------------------------------------------- +fwdNames :: [String] +fwdNames = ["Fwd", "Signal"] + -- | Collection of names external to circuit-notation. data ExternalNames = ExternalNames { circuitCon :: GHC.RdrName - , circuitTyCon :: GHC.RdrName - , circuitTTyCon :: GHC.RdrName , runCircuitName :: GHC.RdrName + , tagBundlePat :: GHC.RdrName + , tagName :: GHC.RdrName + , tagTName :: GHC.RdrName , fwdBwdCon :: GHC.RdrName + , fwdAndBwdTypes :: Direction -> GHC.RdrName , trivialBwd :: GHC.RdrName } defExternalNames :: ExternalNames defExternalNames = ExternalNames - { circuitCon = GHC.Unqual (OccName.mkDataOcc "Circuit") - , circuitTyCon = GHC.Unqual (OccName.mkTcOcc "Circuit") - , circuitTTyCon = GHC.Unqual (OccName.mkTcOcc "CircuitT") - , runCircuitName = GHC.Unqual (OccName.mkVarOcc "runCircuit") + { circuitCon = GHC.Unqual (OccName.mkDataOcc "TagCircuit") + , runCircuitName = GHC.Unqual (OccName.mkVarOcc "runTagCircuit") + , tagBundlePat = GHC.Unqual (OccName.mkDataOcc "BusTagBundle") + , tagName = GHC.Unqual (OccName.mkDataOcc "BusTag") + , tagTName = GHC.Unqual (OccName.mkTcOcc "BusTag") , fwdBwdCon = GHC.Unqual (OccName.mkDataOcc ":->") + , fwdAndBwdTypes = \case + Fwd -> GHC.Unqual (OccName.mkTcOcc "Fwd") + Bwd -> GHC.Unqual (OccName.mkTcOcc "Bwd") , trivialBwd = GHC.Unqual (OccName.mkVarOcc "unitBwd") }