Skip to content

Commit

Permalink
Move vec operations to correct AST
Browse files Browse the repository at this point in the history
  • Loading branch information
HugoPeters1024 committed Dec 13, 2021
1 parent faa139b commit 0e250b8
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 113 deletions.
47 changes: 25 additions & 22 deletions src/Data/Array/Accelerate/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Array.Accelerate.AST
Expand Down Expand Up @@ -149,7 +151,6 @@ import Data.Array.Accelerate.Sugar.Foreign
import Data.Array.Accelerate.Type
import Data.Primitive.Vec

import Data.Primitive.Types
import Control.DeepSeq
import Data.Kind
import Data.Maybe
Expand Down Expand Up @@ -560,6 +561,21 @@ data OpenExp env aenv t where
-> OpenExp env aenv (Vec n s)
-> OpenExp env aenv tup

VecIndex :: (KnownNat n, v ~ Vec n s)
=> VectorType v
-> IntegralType i
-> OpenExp env aenv (Vec n s)
-> OpenExp env aenv i
-> OpenExp env aenv s

VecWrite :: (KnownNat n, v ~ Vec n s)
=> VectorType v
-> IntegralType i
-> OpenExp env aenv (Vec n s)
-> OpenExp env aenv i
-> OpenExp env aenv s
-> OpenExp env aenv (Vec n s)

-- Array indices & shapes
IndexSlice :: SliceIndex slix sl co sh
-> OpenExp env aenv slix
Expand Down Expand Up @@ -748,10 +764,6 @@ data PrimFun sig where
PrimLOr :: PrimFun ((PrimBool, PrimBool) -> PrimBool)
PrimLNot :: PrimFun (PrimBool -> PrimBool)

-- local array operators
PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a)
PrimVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, (i, a)) -> Vec n a)

-- general conversion between types
PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b)
PrimToFloating :: NumType a -> FloatingType b -> PrimFun (a -> b)
Expand Down Expand Up @@ -818,6 +830,8 @@ expType = \case
Nil -> TupRunit
VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR
VecUnpack vecR _ -> vecRtuple vecR
VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s
VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT
IndexSlice si _ _ -> shapeType $ sliceShapeR si
IndexFull si _ _ -> shapeType $ sliceDomainR si
ToIndex{} -> TupRsingle scalarTypeInt
Expand Down Expand Up @@ -850,9 +864,6 @@ primConstType = \case
floating :: FloatingType t -> ScalarType t
floating = SingleScalarType . NumSingleType . FloatingNumType

vector :: forall n a. (KnownNat n) => VectorType (Vec n a) -> ScalarType (Vec n a)
vector = VectorScalarType

primFunType :: PrimFun (a -> b) -> (TypeR a, TypeR b)
primFunType = \case
-- Num
Expand Down Expand Up @@ -931,17 +942,6 @@ primFunType = \case
PrimLOr -> binary' tbool
PrimLNot -> unary' tbool

-- Local Vector operations
PrimVectorIndex v'@(VectorType _ a) i' ->
let v = singleVector v'
i = integral i'
in (v `TupRpair` i, single a)

PrimVectorWrite v'@(VectorType _ a) i' ->
let v = singleVector v'
i = integral i'
in (v `TupRpair` (i `TupRpair` single a), v)

-- general conversion between types
PrimFromIntegral a b -> unary (integral a) (num b)
PrimToFloating a b -> unary (num a) (floating b)
Expand All @@ -954,7 +954,6 @@ primFunType = \case
compare' a = binary (single a) tbool

single = TupRsingle . SingleScalarType
singleVector = TupRsingle . VectorScalarType
num = TupRsingle . SingleScalarType . NumSingleType
integral = num . IntegralNumType
floating = num . FloatingNumType
Expand Down Expand Up @@ -1092,6 +1091,8 @@ rnfOpenExp topExp =
Nil -> ()
VecPack vecr e -> rnfVecR vecr `seq` rnfE e
VecUnpack vecr e -> rnfVecR vecr `seq` rnfE e
VecIndex vt it v i -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i
VecWrite vt it v i e -> rnfVectorType vt `seq` rnfIntegralType it `seq` rnfE v `seq` rnfE i `seq` rnfE e
IndexSlice slice slix sh -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sh
IndexFull slice slix sl -> rnfSliceIndex slice `seq` rnfE slix `seq` rnfE sl
ToIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix
Expand Down Expand Up @@ -1184,7 +1185,6 @@ rnfPrimFun (PrimMin t) = rnfSingleType t
rnfPrimFun PrimLAnd = ()
rnfPrimFun PrimLOr = ()
rnfPrimFun PrimLNot = ()
rnfPrimFun (PrimVectorIndex v i) = rnfVectorType v `seq` rnfIntegralType i
rnfPrimFun (PrimFromIntegral i n) = rnfIntegralType i `seq` rnfNumType n
rnfPrimFun (PrimToFloating n f) = rnfNumType n `seq` rnfFloatingType f

Expand Down Expand Up @@ -1313,6 +1313,8 @@ liftOpenExp pexp =
Nil -> [|| Nil ||]
VecPack vecr e -> [|| VecPack $$(liftVecR vecr) $$(liftE e) ||]
VecUnpack vecr e -> [|| VecUnpack $$(liftVecR vecr) $$(liftE e) ||]
VecIndex vt it v i -> [|| VecIndex $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) ||]
VecWrite vt it v i e -> [|| VecWrite $$(liftVectorType vt) $$(liftIntegralType it) $$(liftE v) $$(liftE i) $$(liftE e) ||]
IndexSlice slice slix sh -> [|| IndexSlice $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sh) ||]
IndexFull slice slix sl -> [|| IndexFull $$(liftSliceIndex slice) $$(liftE slix) $$(liftE sl) ||]
ToIndex shr sh ix -> [|| ToIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||]
Expand Down Expand Up @@ -1411,7 +1413,6 @@ liftPrimFun (PrimMin t) = [|| PrimMin $$(liftSingleType t) ||]
liftPrimFun PrimLAnd = [|| PrimLAnd ||]
liftPrimFun PrimLOr = [|| PrimLOr ||]
liftPrimFun PrimLNot = [|| PrimLNot ||]
liftPrimFun (PrimVectorIndex v i) = [|| PrimVectorIndex $$(liftVectorType v) $$(liftIntegralType i) ||]
liftPrimFun (PrimFromIntegral ta tb) = [|| PrimFromIntegral $$(liftIntegralType ta) $$(liftNumType tb) ||]
liftPrimFun (PrimToFloating ta tb) = [|| PrimToFloating $$(liftNumType ta) $$(liftFloatingType tb) ||]

Expand Down Expand Up @@ -1461,6 +1462,8 @@ formatExpOp = later $ \case
Nil{} -> "Nil"
VecPack{} -> "VecPack"
VecUnpack{} -> "VecUnpack"
VecIndex{} -> "VecIndex"
VecWrite{} -> "VecWrite"
IndexSlice{} -> "IndexSlice"
IndexFull{} -> "IndexFull"
ToIndex{} -> "ToIndex"
Expand Down
4 changes: 2 additions & 2 deletions src/Data/Array/Accelerate/Analysis/Hash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ encodeOpenExp exp =
Pair e1 e2 -> intHost $(hashQ "Pair") <> travE e1 <> travE e2
VecPack _ e -> intHost $(hashQ "VecPack") <> travE e
VecUnpack _ e -> intHost $(hashQ "VecUnpack") <> travE e
VecIndex _ _ v i -> intHost $(hashQ "VecIndex") <> travE v <> travE i
VecWrite _ _ v i e -> intHost $(hashQ "VecWrite") <> travE v <> travE i <> travE e
Const tp c -> intHost $(hashQ "Const") <> encodeScalarConst tp c
Undef tp -> intHost $(hashQ "Undef") <> encodeScalarType tp
IndexSlice spec ix sh -> intHost $(hashQ "IndexSlice") <> travE ix <> travE sh <> encodeSliceIndex spec
Expand Down Expand Up @@ -448,8 +450,6 @@ encodePrimFun (PrimEq a) = intHost $(hashQ "PrimEq")
encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq") <> encodeSingleType a
encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a
encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a
encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b)
encodePrimFun (PrimVectorWrite (VectorType _ a) b) = intHost $(hashQ "PrimVectorWrite") <> encodeSingleType a <> encodeNumType (IntegralNumType b)
encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b
encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b
encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd")
Expand Down
5 changes: 4 additions & 1 deletion src/Data/Array/Accelerate/Classes/Vector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GADTs #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
Expand All @@ -18,12 +20,13 @@
--
module Data.Array.Accelerate.Classes.Vector where

import Data.Kind
import GHC.TypeLits
import Data.Array.Accelerate.Sugar.Vec
import Data.Array.Accelerate.Smart
import Data.Primitive.Vec



instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where
type IndexType (Exp (Vec n a)) = Exp Int
vecIndex = mkVectorIndex
Expand Down
2 changes: 0 additions & 2 deletions src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1145,8 +1145,6 @@ evalPrim (PrimMin ty) = evalMin ty
evalPrim PrimLAnd = evalLAnd
evalPrim PrimLOr = evalLOr
evalPrim PrimLNot = evalLNot
evalPrim (PrimVectorIndex v i) = evalVectorIndex v i
evalPrim (PrimVectorWrite v i) = evalVectorWrite v i
evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb
evalPrim (PrimToFloating ta tb) = evalToFloating ta tb

Expand Down
4 changes: 4 additions & 0 deletions src/Data/Array/Accelerate/Representation/Vec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,17 @@ data VecR (n :: Nat) single tuple where
VecRnil :: SingleType s -> VecR 0 s ()
VecRsucc :: VecR n s t -> VecR (n + 1) s (t, s)


vecRvector :: KnownNat n => VecR n s tuple -> VectorType (Vec n s)
vecRvector = uncurry VectorType . go
where
go :: VecR n s tuple -> (Int, SingleType s)
go (VecRnil tp) = (0, tp)
go (VecRsucc vec) | (n, tp) <- go vec = (n + 1, tp)

vecRSingle :: KnownNat n => VecR n s tuple -> SingleType s
vecRSingle vecr = let (VectorType _ s) = vecRvector vecr in s

vecRtuple :: VecR n s tuple -> TypeR tuple
vecRtuple = snd . go
where
Expand Down
30 changes: 23 additions & 7 deletions src/Data/Array/Accelerate/Smart.hs
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,21 @@ data PreSmartExp acc exp t where
-> exp (Vec n s)
-> PreSmartExp acc exp tup

VecIndex :: (KnownNat n, v ~ Vec n s)
=> VectorType v
-> IntegralType i
-> exp (Vec n s)
-> exp i
-> PreSmartExp acc exp s

VecWrite :: (KnownNat n, v ~ Vec n s)
=> VectorType v
-> IntegralType i
-> exp (Vec n s)
-> exp i
-> exp s
-> PreSmartExp acc exp (Vec n s)

ToIndex :: ShapeR sh
-> exp sh
-> exp sh
Expand Down Expand Up @@ -860,6 +875,8 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where
Prj _ _ -> error "I never joke about my work"
VecPack vecR _ -> TupRsingle $ VectorScalarType $ vecRvector vecR
VecUnpack vecR _ -> vecRtuple vecR
VecIndex vecT _ _ _ -> let (VectorType _ s) = vecT in TupRsingle $ SingleScalarType s
VecWrite vecT _ _ _ _ -> TupRsingle $ VectorScalarType vecT
ToIndex _ _ _ -> TupRsingle scalarTypeInt
FromIndex shr _ _ -> shapeType shr
Case _ ((_,c):_) -> typeR c
Expand Down Expand Up @@ -1179,16 +1196,15 @@ mkLNot (Exp a) = mkExp $ SmartExp (PrimApp PrimLNot x) `Pair` SmartExp Nil
where
x = SmartExp $ Prj PairIdxLeft a

-- Operators from Vec

inferNat :: forall n. KnownNat n => Int
inferNat = fromInteger $ natVal (Proxy @n)

mkVectorIndex :: forall n a. (KnownNat n, Elt a, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a
mkVectorIndex = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType
mkVectorIndex (Exp v) (Exp i) = mkExp $ VecIndex (VectorType (inferNat @n) singleType) integralType v i

mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a)
mkVectorWrite = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimTernary $ PrimVectorWrite @n (VectorType n singleType) integralType
mkVectorWrite (Exp v) (Exp i) (Exp el) = mkExp $ VecWrite (VectorType (inferNat @n) singleType) integralType v i el

-- Numeric conversions

Expand Down
2 changes: 0 additions & 2 deletions src/Data/Array/Accelerate/Trafo/Algebra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ evalPrimApp env f x
PrimNEq ty -> evalNEq ty x env
PrimMax ty -> evalMax ty x env
PrimMin ty -> evalMin ty x env
PrimVectorIndex _ _ -> Nothing
PrimVectorWrite _ _ -> Nothing
PrimLAnd -> evalLAnd x env
PrimLOr -> evalLOr x env
PrimLNot -> evalLNot x env
Expand Down
68 changes: 37 additions & 31 deletions src/Data/Array/Accelerate/Trafo/Sharing.hs
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,8 @@ convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp
Pair e1 e2 -> AST.Pair (cvt e1) (cvt e2)
VecPack vec e -> AST.VecPack vec (cvt e)
VecUnpack vec e -> AST.VecUnpack vec (cvt e)
VecIndex vt it v i -> AST.VecIndex vt it (cvt v) (cvt i)
VecWrite vt it v i e -> AST.VecWrite vt it (cvt v) (cvt i) (cvt e)
ToIndex shr sh ix -> AST.ToIndex shr (cvt sh) (cvt ix)
FromIndex shr sh e -> AST.FromIndex shr (cvt sh) (cvt e)
Case e rhs -> cvtCase (cvt e) (over (mapped . _2) cvt rhs)
Expand Down Expand Up @@ -1841,37 +1843,39 @@ makeOccMapSharingExp config accOccMap expOccMap = travE
return (UnscopedExp [] (ExpSharing (StableNameHeight sn height) exp), height)

reconstruct $ case pexp of
Tag tp i -> return (Tag tp i, 0) -- height is 0!
Const tp c -> return (Const tp c, 1)
Undef tp -> return (Undef tp, 1)
Nil -> return (Nil, 1)
Pair e1 e2 -> travE2 Pair e1 e2
Prj i e -> travE1 (Prj i) e
VecPack vec e -> travE1 (VecPack vec) e
VecUnpack vec e -> travE1 (VecUnpack vec) e
ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix
FromIndex shr sh e -> travE2 (FromIndex shr) sh e
Match t e -> travE1 (Match t) e
Case e rhs -> do
(e', h1) <- travE lvl e
(rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ]
return (Case e' rhs', h1 `max` maximum h2 + 1)
Cond e1 e2 e3 -> travE3 Cond e1 e2 e3
While t p iter init -> do
(p' , h1) <- traverseFun1 lvl t p
(iter', h2) <- traverseFun1 lvl t iter
(init', h3) <- travE lvl init
return (While t p' iter' init', h1 `max` h2 `max` h3 + 1)
PrimConst c -> return (PrimConst c, 1)
PrimApp p e -> travE1 (PrimApp p) e
Index tp a e -> travAE (Index tp) a e
LinearIndex tp a i -> travAE (LinearIndex tp) a i
Shape shr a -> travA (Shape shr) a
ShapeSize shr e -> travE1 (ShapeSize shr) e
Foreign tp ff f e -> do
(e', h) <- travE lvl e
return (Foreign tp ff f e', h+1)
Coerce t1 t2 e -> travE1 (Coerce t1 t2) e
Tag tp i -> return (Tag tp i, 0) -- height is 0!
Const tp c -> return (Const tp c, 1)
Undef tp -> return (Undef tp, 1)
Nil -> return (Nil, 1)
Pair e1 e2 -> travE2 Pair e1 e2
Prj i e -> travE1 (Prj i) e
VecPack vec e -> travE1 (VecPack vec) e
VecUnpack vec e -> travE1 (VecUnpack vec) e
VecIndex vt ti v i -> travE2 (VecIndex vt ti) v i
VecWrite vt ti v i e -> travE3 (VecWrite vt ti) v i e
ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix
FromIndex shr sh e -> travE2 (FromIndex shr) sh e
Match t e -> travE1 (Match t) e
Case e rhs -> do
(e', h1) <- travE lvl e
(rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ]
return (Case e' rhs', h1 `max` maximum h2 + 1)
Cond e1 e2 e3 -> travE3 Cond e1 e2 e3
While t p iter init -> do
(p' , h1) <- traverseFun1 lvl t p
(iter', h2) <- traverseFun1 lvl t iter
(init', h3) <- travE lvl init
return (While t p' iter' init', h1 `max` h2 `max` h3 + 1)
PrimConst c -> return (PrimConst c, 1)
PrimApp p e -> travE1 (PrimApp p) e
Index tp a e -> travAE (Index tp) a e
LinearIndex tp a i -> travAE (LinearIndex tp) a i
Shape shr a -> travA (Shape shr) a
ShapeSize shr e -> travE1 (ShapeSize shr) e
Foreign tp ff f e -> do
(e', h) <- travE lvl e
return (Foreign tp ff f e', h+1)
Coerce t1 t2 e -> travE1 (Coerce t1 t2) e

where
traverseAcc :: HasCallStack => Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
Expand Down Expand Up @@ -2755,6 +2759,8 @@ determineScopesSharingExp config accOccMap expOccMap = scopesExp
Prj i e -> travE1 (Prj i) e
VecPack vec e -> travE1 (VecPack vec) e
VecUnpack vec e -> travE1 (VecUnpack vec) e
VecIndex vt it v i -> travE2 (VecIndex vt it) v i
VecWrite vt it v i e -> travE3 (VecWrite vt it) v i e
ToIndex shr sh ix -> travE2 (ToIndex shr) sh ix
FromIndex shr sh e -> travE2 (FromIndex shr) sh e
Match t e -> travE1 (Match t) e
Expand Down
Loading

0 comments on commit 0e250b8

Please sign in to comment.