Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utils for Circuit (Vec n c) (Vec n c) #116

Merged
merged 4 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clash-protocols/clash-protocols.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ library
Protocols.Idle
Protocols.Internal
Protocols.Internal.TH
Protocols.Vec
Protocols.Wishbone
Protocols.Wishbone.Standard
Protocols.Wishbone.Standard.Hedgehog
Expand Down Expand Up @@ -184,6 +185,7 @@ test-suite unittests
Tests.Protocols.Avalon
Tests.Protocols.Axi4
Tests.Protocols.Plugin
Tests.Protocols.Vec
Tests.Protocols.Wishbone
Util

Expand Down
5 changes: 3 additions & 2 deletions clash-protocols/src/Protocols/DfConv.hs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ import Protocols.Axi4.WriteResponse
import Protocols.Df (Data (..), Df)
import qualified Protocols.Df as Df
import Protocols.Internal
import qualified Protocols.Vec as Vec

{- | Class for protocols that are "similar" to 'Df', i.e. they can be converted
to and from a pair of 'Df' ports (one going 'Fwd', one going 'Bwd'), using
Expand Down Expand Up @@ -599,7 +600,7 @@ vecToDfConv ::
(Vec n df)
vecToDfConv proxy =
mapCircuit (uncurry C.zip) unzip id id
$ vecCircuits
$ Vec.vecCircuits
$ repeat
$ toDfCircuit proxy

Expand All @@ -616,7 +617,7 @@ vecFromDfConv ::
)
vecFromDfConv proxy =
mapCircuit id id unzip (uncurry C.zip)
$ vecCircuits
$ Vec.vecCircuits
$ repeat
$ fromDfCircuit proxy

Expand Down
8 changes: 0 additions & 8 deletions clash-protocols/src/Protocols/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@

-- | Protocol-agnostic acknowledgement
newtype Ack = Ack Bool
deriving (Generic, C.NFDataX, Show, C.Bundle, Eq, Ord)

Check warning on line 62 in clash-protocols/src/Protocols/Internal.hs

View workflow job for this annotation

GitHub Actions / Cabal tests - ghc 9.2.8 / clash 1.8.1

• Both DeriveAnyClass and GeneralizedNewtypeDeriving are enabled

Check warning on line 62 in clash-protocols/src/Protocols/Internal.hs

View workflow job for this annotation

GitHub Actions / Cabal tests - ghc 9.2.8 / clash 1.8.1

• Both DeriveAnyClass and GeneralizedNewtypeDeriving are enabled

Check warning on line 62 in clash-protocols/src/Protocols/Internal.hs

View workflow job for this annotation

GitHub Actions / Cabal tests - ghc 9.4.8 / clash 1.8.1

• Both DeriveAnyClass and GeneralizedNewtypeDeriving are enabled

Check warning on line 62 in clash-protocols/src/Protocols/Internal.hs

View workflow job for this annotation

GitHub Actions / Cabal tests - ghc 9.4.8 / clash 1.8.1

• Both DeriveAnyClass and GeneralizedNewtypeDeriving are enabled

Check warning on line 62 in clash-protocols/src/Protocols/Internal.hs

View workflow job for this annotation

GitHub Actions / Cabal tests - ghc 9.6.4 / clash 1.8.1

• Both DeriveAnyClass and GeneralizedNewtypeDeriving are enabled

Check warning on line 62 in clash-protocols/src/Protocols/Internal.hs

View workflow job for this annotation

GitHub Actions / Cabal tests - ghc 9.6.4 / clash 1.8.1

• Both DeriveAnyClass and GeneralizedNewtypeDeriving are enabled

-- | Acknowledge. Used in circuit-notation plugin to drive ignore components.
instance Default Ack where
Expand Down Expand Up @@ -539,14 +539,6 @@
Circuit a' b'
mapCircuit ia oa ob ib (Circuit f) = Circuit ((oa *** ob) . f . (ia *** ib))

{- | "Bundle" together a 'C.Vec' of 'Circuit's into a 'Circuit' with 'C.Vec' input and output.
The 'Circuit's all run in parallel.
-}
vecCircuits :: (C.KnownNat n) => C.Vec n (Circuit a b) -> Circuit (C.Vec n a) (C.Vec n b)
vecCircuits fs = Circuit (\inps -> C.unzip $ f <$> fs <*> uncurry C.zip inps)
where
f (Circuit ff) x = ff x

{- | "Bundle" together a pair of 'Circuit's into a 'Circuit' with two inputs and outputs.
The 'Circuit's run in parallel.
-}
Expand Down
133 changes: 133 additions & 0 deletions clash-protocols/src/Protocols/Vec.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
-- | Utility functions for working with `Vec`s of `Circuit`s.
module Protocols.Vec (
vecCircuits,
append,
append3,
split,
split3,
zip,
zip3,
unzip,
unzip3,
concat,
unconcat,
) where

-- base
import Data.Tuple
import Prelude ()

-- clash-prelude
import Clash.Prelude hiding (concat, split, unconcat, unzip, unzip3, zip, zip3)
import qualified Clash.Prelude as C

-- clash-protocols-base
import Protocols.Plugin

import Data.Bifunctor

{- | "Bundle" together a 'Vec' of 'Circuit's into a 'Circuit' with 'Vec' input and output.
The 'Circuit's all run in parallel.

The inverse of 'vecCircuits' can not exist, as we can not guarantee that that the @n@th
manager interface only depends on the @n@th subordinate interface.
-}
vecCircuits :: (C.KnownNat n) => C.Vec n (Circuit a b) -> Circuit (C.Vec n a) (C.Vec n b)
vecCircuits fs = Circuit (\inps -> C.unzip $ f <$> fs <*> uncurry C.zip inps)
where
f (Circuit ff) = ff

-- | Append two separate vectors of the same circuits into one vector of circuits
append ::
(C.KnownNat n0) =>
Circuit (C.Vec n0 circuit, C.Vec n1 circuit) (C.Vec (n0 + n1) circuit)
append = Circuit (swap . bimap (uncurry (++)) splitAtI)

-- | Append three separate vectors of the same circuits into one vector of circuits
append3 ::
(C.KnownNat n0, C.KnownNat n1, KnownNat n2) =>
Circuit
(C.Vec n0 circuit, C.Vec n1 circuit, C.Vec n2 circuit)
(C.Vec (n0 + n1 + n2) circuit)
append3 = Circuit (swap . bimap (uncurry3 append3Vec) split3Vec)

-- | Split a vector of circuits into two vectors of circuits.
split ::
(C.KnownNat n0) =>
Circuit (C.Vec (n0 + n1) circuit) (C.Vec n0 circuit, C.Vec n1 circuit)
split = Circuit go
where
go ~(splitAtI -> (fwd0, fwd1), (bwd0, bwd1)) = (bwd0 ++ bwd1, (fwd0, fwd1))

-- | Split a vector of circuits into three vectors of circuits.
split3 ::
(C.KnownNat n0, C.KnownNat n1, C.KnownNat n2) =>
Circuit
(C.Vec (n0 + n1 + n2) circuit)
(C.Vec n0 circuit, C.Vec n1 circuit, C.Vec n2 circuit)
split3 = Circuit (swap . bimap split3Vec (uncurry3 append3Vec))

{- | Transforms two vectors of circuits into a vector of tuples of circuits.
Only works if the two vectors have the same length.
-}
zip ::
(C.KnownNat n) =>
Circuit (C.Vec n a, C.Vec n b) (C.Vec n (a, b))
zip = Circuit (swap . bimap (uncurry C.zip) C.unzip)

{- | Transforms three vectors of circuits into a vector of tuples of circuits.
Only works if the three vectors have the same length.
-}
zip3 ::
(C.KnownNat n) =>
Circuit (C.Vec n a, C.Vec n b, C.Vec n c) (C.Vec n (a, b, c))
zip3 = Circuit (swap . bimap (uncurry3 C.zip3) C.unzip3)

-- | Unzip a vector of tuples of circuits into a tuple of vectors of circuits.
unzip ::
(C.KnownNat n) =>
Circuit (C.Vec n (a, b)) (C.Vec n a, C.Vec n b)
unzip = Circuit (swap . bimap C.unzip (uncurry C.zip))

-- | Unzip a vector of 3-tuples of circuits into a 3-tuple of vectors of circuits.
unzip3 ::
(C.KnownNat n) =>
Circuit (C.Vec n (a, b, c)) (C.Vec n a, C.Vec n b, C.Vec n c)
unzip3 = Circuit (swap . bimap C.unzip3 (uncurry3 C.zip3))

-- | transform a vector of vectors of circuits into a vector of circuits.
concat ::
(C.KnownNat n0, C.KnownNat n1) =>
Circuit (C.Vec n0 (C.Vec n1 circuit)) (C.Vec (n0 * n1) circuit)
concat = Circuit (swap . bimap C.concat (C.unconcat SNat))

-- | transform a vector of circuits into a vector of vectors of circuits.
unconcat ::
(C.KnownNat n, C.KnownNat m) =>
SNat m ->
Circuit (C.Vec (n * m) circuit) (C.Vec n (C.Vec m circuit))
unconcat SNat = Circuit (swap . bimap (C.unconcat SNat) C.concat)

-- Internal utilities

-- | Uncurry a function with three arguments into a function that takes a 3-tuple as argument.
uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
uncurry3 f (a, b, c) = f a b c

-- Append three vectors of `a` into one vector of `a`.
append3Vec ::
(KnownNat n0, KnownNat n1, KnownNat n2) =>
C.Vec n0 a ->
C.Vec n1 a ->
C.Vec n2 a ->
C.Vec (n0 + n1 + n2) a
append3Vec v0 v1 v2 = v0 ++ v1 ++ v2

-- Split a C.Vector of 3-tuples into three vectors of the same length.
split3Vec ::
(KnownNat n0, KnownNat n1, KnownNat n2) =>
C.Vec (n0 + n1 + n2) a ->
(C.Vec n0 a, C.Vec n1 a, C.Vec n2 a)
split3Vec v = (v0, v1, v2)
where
(v0, splitAtI -> (v1, v2)) = splitAtI v
2 changes: 2 additions & 0 deletions clash-protocols/tests/Tests/Protocols.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import qualified Tests.Protocols.Avalon
import qualified Tests.Protocols.Axi4
import qualified Tests.Protocols.Df
import qualified Tests.Protocols.DfConv
import qualified Tests.Protocols.Vec
import qualified Tests.Protocols.Wishbone

tests :: TestTree
Expand All @@ -16,6 +17,7 @@ tests =
, Tests.Protocols.Avalon.tests
, Tests.Protocols.Axi4.tests
, Tests.Protocols.Wishbone.tests
, Tests.Protocols.Vec.tests
]

main :: IO ()
Expand Down
191 changes: 191 additions & 0 deletions clash-protocols/tests/Tests/Protocols/Vec.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
{-# LANGUAGE NumericUnderscores #-}

module Tests.Protocols.Vec where

-- base
import Prelude

-- clash-prelude
import Clash.Prelude (System)
import qualified Clash.Prelude as C

-- hedgehog
import Hedgehog

-- tasty
import Test.Tasty
import Test.Tasty.Hedgehog (HedgehogTestLimit (HedgehogTestLimit))
import Test.Tasty.Hedgehog.Extra (testProperty)
import Test.Tasty.TH (testGroupGenerator)

-- clash-protocols (me!)
import Protocols
import qualified Protocols.Vec as Vec

import Clash.Hedgehog.Sized.Vector (genVec)
import Protocols.Hedgehog

-- tests
import Tests.Protocols.Df (genData, genSmallInt, genVecData)

prop_append :: Property
prop_append =
idWithModel
@(C.Vec 2 (Df System Int), C.Vec 3 (Df System Int))
defExpectOptions
gen
model
dut
where
gen =
(,)
<$> genVecData genSmallInt
<*> genVecData genSmallInt
dut = Vec.append
model = uncurry (C.++)

prop_append3 :: Property
prop_append3 =
idWithModel
@(C.Vec 2 (Df System Int), C.Vec 3 (Df System Int), C.Vec 4 (Df System Int))
@(C.Vec 9 (Df System Int))
defExpectOptions
gen
model
dut
where
gen :: Gen (C.Vec 2 [Int], C.Vec 3 [Int], C.Vec 4 [Int])
gen =
(,,)
<$> genVecData genSmallInt
<*> genVecData genSmallInt
<*> genVecData genSmallInt
dut = Vec.append3
model (a, b, c) = (a C.++ b) C.++ c

prop_split :: Property
prop_split =
idWithModel
@(C.Vec 5 (Df System Int))
@(C.Vec 2 (Df System Int), C.Vec 3 (Df System Int))
defExpectOptions
gen
model
dut
where
gen = genVecData genSmallInt
dut = Vec.split
model = C.splitAtI

prop_split3 :: Property
prop_split3 =
idWithModel
@(C.Vec 9 (Df System Int))
@(C.Vec 2 (Df System Int), C.Vec 3 (Df System Int), C.Vec 4 (Df System Int))
defExpectOptions
gen
model
dut
where
gen = genVecData genSmallInt
dut = Vec.split3
model v = (v0, v1, v2)
where
(v0, C.splitAtI -> (v1, v2)) = C.splitAtI v

prop_zip :: Property
prop_zip =
idWithModel
@(C.Vec 2 (Df System Int), C.Vec 2 (Df System Int))
defExpectOptions
gen
model
dut
where
gen =
(,)
<$> genVecData genSmallInt
<*> genVecData genSmallInt
dut = Vec.zip
model (a, b) = C.zip a b

prop_zip3 :: Property
prop_zip3 =
idWithModel
@(C.Vec 2 (Df System Int), C.Vec 2 (Df System Int), C.Vec 2 (Df System Int))
defExpectOptions
gen
model
dut
where
gen =
(,,)
<$> genVecData genSmallInt
<*> genVecData genSmallInt
<*> genVecData genSmallInt
dut = Vec.zip3
model (a, b, c) = C.zip3 a b c

prop_unzip :: Property
prop_unzip =
idWithModel
@(C.Vec 2 (Df System Int, Df System Int))
defExpectOptions
gen
model
dut
where
gen = genVec ((,) <$> genData genSmallInt <*> genData genSmallInt)
dut = Vec.unzip
model = C.unzip

prop_unzip3 :: Property
prop_unzip3 =
idWithModel
@(C.Vec 2 (Df System Int, Df System Int, Df System Int))
defExpectOptions
gen
model
dut
where
gen = genVec ((,,) <$> genData genSmallInt <*> genData genSmallInt <*> genData genSmallInt)
dut = Vec.unzip3
model = C.unzip3

prop_concat :: Property
prop_concat =
idWithModel
@(C.Vec 2 (C.Vec 3 (Df System Int)))
defExpectOptions
gen
model
dut
where
gen = genVec (genVecData genSmallInt)
dut = Vec.concat
model = C.concat

prop_unconcat :: Property
prop_unconcat =
idWithModel
@(C.Vec 6 (Df System Int))
defExpectOptions
gen
model
dut
where
gen = genVecData genSmallInt
dut = Vec.unconcat C.d2
model = C.unconcat C.d2

tests :: TestTree
tests =
-- TODO: Move timeout option to hedgehog for better error messages.
-- TODO: Does not seem to work for combinatorial loops like @let x = x in x@??
localOption (mkTimeout 20_000_000 {- 20 seconds -}) $
localOption
(HedgehogTestLimit (Just 1000))
$(testGroupGenerator)

main :: IO ()
main = defaultMain tests
Loading