diff --git a/clash-protocols/src/Protocols/Internal.hs b/clash-protocols/src/Protocols/Internal.hs index 22558b66..f42e726b 100644 --- a/clash-protocols/src/Protocols/Internal.hs +++ b/clash-protocols/src/Protocols/Internal.hs @@ -5,6 +5,7 @@ {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fconstraint-solver-iterations=20 #-} #if !MIN_VERSION_clash_prelude(1, 8, 2) {-# OPTIONS_GHC -fno-warn-orphans #-} #endif @@ -33,8 +34,10 @@ import qualified Clash.Explicit.Prelude as CE import Clash.Prelude (type (*), type (+)) import qualified Clash.Prelude as C +import Protocols.Internal.TH (simulateTupleInstances) import Protocols.Internal.Types import Protocols.Plugin +import Protocols.Plugin.Cpp (maxTupleSize) import Protocols.Plugin.TaggedBundle import Protocols.Plugin.Units @@ -234,6 +237,8 @@ instance (Simulate a, Simulate b) => Simulate (a, b) where in ((fwdL1, fwdR1), (bwdL1, bwdR1)) +simulateTupleInstances 3 maxTupleSize + instance (Drivable a, Drivable b) => Drivable (a, b) where type ExpectType (a, b) = (ExpectType a, ExpectType b) diff --git a/clash-protocols/src/Protocols/Internal/TH.hs b/clash-protocols/src/Protocols/Internal/TH.hs index 72743204..02b7d290 100644 --- a/clash-protocols/src/Protocols/Internal/TH.hs +++ b/clash-protocols/src/Protocols/Internal/TH.hs @@ -2,9 +2,12 @@ module Protocols.Internal.TH where +import qualified Clash.Prelude as C import Control.Monad.Extra (concatMapM) +import GHC.TypeNats import Language.Haskell.TH import Protocols.Internal.Types +import Protocols.Plugin {- | Template haskell function to generate IdleCircuit instances for the tuples n through m inclusive. To see a 2-tuple version of the pattern we generate, @@ -31,3 +34,77 @@ idleCircuitTupleInstance n = mkFwdExpr ty = [e|idleFwd $ Proxy @($ty)|] bwdExpr = tupE $ map mkBwdExpr circTys mkBwdExpr ty = [e|idleBwd $ Proxy @($ty)|] + +simulateTupleInstances :: Int -> Int -> DecsQ +simulateTupleInstances n m = concatMapM simulateTupleInstance [n .. m] + +simulateTupleInstance :: Int -> DecsQ +simulateTupleInstance n = + [d| + instance ($instCtx) => Simulate $instTy where + type SimulateFwdType $instTy = $fwdType + type SimulateBwdType $instTy = $bwdType + type SimulateChannels $instTy = $channelSum + + simToSigFwd _ $fwdPat0 = $(tupE $ zipWith (\ty expr -> [e|simToSigFwd (Proxy @($ty)) $expr|]) circTys fwdExpr) + simToSigBwd _ $bwdPat0 = $(tupE $ zipWith (\ty expr -> [e|simToSigBwd (Proxy @($ty)) $expr|]) circTys bwdExpr) + sigToSimFwd _ $fwdPat0 = $(tupE $ zipWith (\ty expr -> [e|sigToSimFwd (Proxy @($ty)) $expr|]) circTys fwdExpr) + sigToSimBwd _ $bwdPat0 = $(tupE $ zipWith (\ty expr -> [e|sigToSimBwd (Proxy @($ty)) $expr|]) circTys bwdExpr) + + stallC $(varP $ mkName "conf") $(varP $ mkName "rem0") = $(letE (stallVecs ++ stallCircuits) stallCExpr) + |] + where + -- Generate the types for the instance + circTys = map (\i -> varT $ mkName $ "c" <> show i) [1 .. n] + instTy = foldl appT (tupleT n) circTys + instCtx = foldl appT (tupleT n) $ map (\ty -> [t|Simulate $ty|]) circTys + fwdType = foldl appT (tupleT n) $ map (\ty -> [t|SimulateFwdType $ty|]) circTys + bwdType = foldl appT (tupleT n) $ map (\ty -> [t|SimulateBwdType $ty|]) circTys + channelSum = foldl1 (\a b -> [t|$a + $b|]) $ map (\ty -> [t|SimulateChannels $ty|]) circTys + + -- Relevant expressions and patterns + fwdPat0 = tupP $ map (\i -> varP $ mkName $ "fwd" <> show i) [1 .. n] + bwdPat0 = tupP $ map (\i -> varP $ mkName $ "bwd" <> show i) [1 .. n] + fwdExpr = map (\i -> varE $ mkName $ "fwd" <> show i) [1 .. n] + bwdExpr = map (\i -> varE $ mkName $ "bwd" <> show i) [1 .. n] + fwdExpr1 = map (\i -> varE $ mkName $ "fwdStalled" <> show i) [1 .. n] + bwdExpr1 = map (\i -> varE $ mkName $ "bwdStalled" <> show i) [1 .. n] + + -- stallC Declaration: Split off the stall vectors from the large input vector + stallVecs = zipWith mkStallVec [1 .. n] circTys + mkStallVec i ty = + valD + mkStallPat + ( normalB [e|(C.splitAtI @(SimulateChannels $ty) $(varE (mkName $ "rem" <> show (i - 1))))|] + ) + [] + where + mkStallPat = + tupP + [ varP (mkName $ "stalls" <> show i) + , varP (mkName $ if i == n then "_" else "rem" <> show i) + ] + + -- stallC Declaration: Generate stalling circuits + stallCircuits = zipWith mkStallCircuit [1 .. n] circTys + mkStallCircuit i ty = + valD + [p|Circuit $(varP $ mkName $ "stalled" <> show i)|] + (normalB [e|stallC @($ty) conf $(varE $ mkName $ "stalls" <> show i)|]) + [] + + -- Generate the stallC expression + stallCExpr = + [e| + Circuit $ \($fwdPat0, $bwdPat0) -> $(letE stallCResultDecs [e|($(tupE fwdExpr1), $(tupE bwdExpr1))|]) + |] + + stallCResultDecs = map mkStallCResultDec [1 .. n] + mkStallCResultDec i = + valD + (tupP [varP $ mkName $ "fwdStalled" <> show i, varP $ mkName $ "bwdStalled" <> show i]) + ( normalB $ + appE (varE $ mkName $ "stalled" <> show i) $ + tupE [varE $ mkName $ "fwd" <> show i, varE $ mkName $ "bwd" <> show i] + ) + []