diff --git a/clash-protocols/src/Protocols/Internal.hs b/clash-protocols/src/Protocols/Internal.hs index 5bdb83b0..33ce4f58 100644 --- a/clash-protocols/src/Protocols/Internal.hs +++ b/clash-protocols/src/Protocols/Internal.hs @@ -265,6 +265,7 @@ instance (Drivable a, Drivable b) => Drivable (a, b) where ) drivableTupleInstances 3 maxTupleSize + instance (CE.KnownNat n, Simulate a) => Simulate (C.Vec n a) where type SimulateFwdType (C.Vec n a) = C.Vec n (SimulateFwdType a) type SimulateBwdType (C.Vec n a) = C.Vec n (SimulateBwdType a) diff --git a/clash-protocols/src/Protocols/Internal/TH.hs b/clash-protocols/src/Protocols/Internal/TH.hs index 4e682f03..03353dce 100644 --- a/clash-protocols/src/Protocols/Internal/TH.hs +++ b/clash-protocols/src/Protocols/Internal/TH.hs @@ -3,6 +3,7 @@ module Protocols.Internal.TH where import qualified Clash.Prelude as C +import Control.Monad (zipWithM) import Control.Monad.Extra (concatMapM) import Data.Proxy import GHC.TypeNats @@ -53,7 +54,7 @@ simulateTupleInstance n = sigToSimFwd _ $fwdPat0 = $(tupE $ zipWith (\ty expr -> [e|sigToSimFwd (Proxy @($ty)) $expr|]) circTys fwdExpr) sigToSimBwd _ $bwdPat0 = $(tupE $ zipWith (\ty expr -> [e|sigToSimBwd (Proxy @($ty)) $expr|]) circTys bwdExpr) - stallC $(varP $ mkName "conf") $(varP $ mkName "rem0") = $(letE (stallVecs ++ stallCircuits) stallCExpr) + stallC $(varP $ mkName "conf") $(varP $ mkName "rem0") = $stallCExpr |] where -- Generate the types for the instance @@ -73,43 +74,49 @@ simulateTupleInstance n = bwdExpr1 = map (\i -> varE $ mkName $ "bwdStalled" <> show i) [1 .. n] -- stallC Declaration: Split off the stall vectors from the large input vector - stallVecs = zipWith mkStallVec [1 .. n] circTys mkStallVec i ty = - valD - mkStallPat - ( normalB [e|(C.splitAtI @(SimulateChannels $ty) $(varE (mkName $ "rem" <> show (i - 1))))|] - ) - [] - where - mkStallPat = - tupP - [ varP (mkName $ "stalls" <> show i) - , varP (mkName $ if i == n then "_" else "rem" <> show i) - ] + [d| + $[p| + ( $(varP (mkName $ "stalls" <> show i)) + , $(varP (mkName $ if i == n then "_" else "rem" <> show i)) + ) + |] = + C.splitAtI @(SimulateChannels $ty) + $(varE $ mkName $ "rem" <> show (i - 1)) + |] -- stallC Declaration: Generate stalling circuits - stallCircuits = zipWith mkStallCircuit [1 .. n] circTys mkStallCircuit i ty = - valD - [p|Circuit $(varP $ mkName $ "stalled" <> show i)|] - (normalB [e|stallC @($ty) conf $(varE $ mkName $ "stalls" <> show i)|]) - [] + [d| + $[p|Circuit $(varP $ mkName $ "stalled" <> show i)|] = + stallC @($ty) conf $(varE $ mkName $ "stalls" <> show i) + |] -- Generate the stallC expression - stallCExpr = - [e| - Circuit $ \($fwdPat0, $bwdPat0) -> $(letE stallCResultDecs [e|($(tupE fwdExpr1), $(tupE bwdExpr1))|]) - |] + stallCExpr = do + stallVecs <- + concat <$> zipWithM mkStallVec [1 .. n] circTys + stallCircuits <- + concat <$> zipWithM mkStallCircuit [1 .. n] circTys + LetE (stallVecs <> stallCircuits) + <$> [e|Circuit $ \($fwdPat0, $bwdPat0) -> $circuitResExpr|] + + circuitResExpr = do + stallCResultDecs <- concatMapM mkStallCResultDec [1 .. n] + LetE stallCResultDecs <$> [e|($(tupE fwdExpr1), $(tupE bwdExpr1))|] - stallCResultDecs = map mkStallCResultDec [1 .. n] mkStallCResultDec i = - valD - (tupP [varP $ mkName $ "fwdStalled" <> show i, varP $ mkName $ "bwdStalled" <> show i]) - ( normalB $ - appE (varE $ mkName $ "stalled" <> show i) $ - tupE [varE $ mkName $ "fwd" <> show i, varE $ mkName $ "bwd" <> show i] - ) - [] + [d| + $[p| + ( $(varP $ mkName $ "fwdStalled" <> show i) + , $(varP $ mkName $ "bwdStalled" <> show i) + ) + |] = + $(varE $ mkName $ "stalled" <> show i) + ( $(varE $ mkName $ "fwd" <> show i) + , $(varE $ mkName $ "bwd" <> show i) + ) + |] drivableTupleInstances :: Int -> Int -> DecsQ drivableTupleInstances n m = concatMapM drivableTupleInstance [n .. m]