-
Notifications
You must be signed in to change notification settings - Fork 0
/
Flow.hs
217 lines (182 loc) · 9.05 KB
/
Flow.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
{-# LANGUAGE MultiParamTypeClasses, Rank2Types, ScopedTypeVariables #-}
module Flow where
import Control.Applicative
import Control.Monad (ap, liftM, unless, when)
import Control.Monad.State.Strict
import Data.Map.Strict (Map, (!))
import Data.Set (Set)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
-- A recursive dataflow graph consists of a set of nodes, each with an initial
-- value, and a step function that computes the "next" value of a node from the
-- "previous" values of all nodes.
--
-- When every node's next value is the same as its current value, the dataflow
-- graph has reached a fixed point: the computation has terminated.
--
-- If we let nodes be evaluated in no particular order, this forms a concurrent,
-- nondeterministic model of computation.
--
-- However, I believe that if the step function is monotone (with respect to
-- some ordering of the value type) in the "previous values" map, this suffices
-- to show that the dataflow graph is essentially deterministic: if it can
-- reach a fixed point at all, then, assuming a fair scheduler, it eventually
-- reaches the same fixed point regardless of the order in which nodes are
-- evaluated. (The graph may have several fixed points; for example, with an
-- identity step every state is one. The claim is only that the one reached
-- does not depend on evaluation order.)
-- | A dataflow graph: a starting value for every node, plus a rule for
-- recomputing any one node from the values of the nodes it reads.
data Graph node value = Graph
  { graphInit :: Init node value
    -- ^ starting value of each node
  , graphStep :: Step node value
    -- ^ how to compute a node's next value
  }

-- | Starting values, keyed by node. The push-based solver treats the keys of
-- this map as the complete set of nodes.
type Init node value = Map node value

-- | A step rule receives a way to read any node's current value (wrapped in
-- some Applicative @f@) and the node to recompute. It returns that node's
-- next value in the same Applicative.
--
-- Restricting the rule to 'Applicative' (rather than 'Monad') means it cannot
-- choose which nodes to read based on values it has already read, so the
-- dependency graph is static. That rules out some useful dataflow programs,
-- but enables some useful implementation strategies (for example, collecting
-- every node's dependencies once, up front). It is a tradeoff.
type Step node value =
  forall f. Applicative f => (node -> f value) -> node -> f value
-- Some utility functions.
-- | Build a map sending each of the given keys to the function's value at
-- that key.
tabulate :: Ord a => [a] -> (a -> b) -> Map a b
tabulate keys func = Map.fromList (map (\key -> (key, func key)) keys)
-- | Reverse a one-to-many relation. Each @b@ maps to the set of every @a@
-- whose set contains it. Keys that map to an empty set contribute nothing,
-- so they do not appear in the result.
invert :: (Ord a, Ord b) => Map a (Set b) -> Map b (Set a)
invert m = Map.fromListWith Set.union (concatMap flipEntry (Map.toList m))
  where
    flipEntry (k, vs) = [(v, Set.singleton k) | v <- Set.toList vs]
--------------------------------------------------------
---------- PUSH-BASED dataflow implementation ----------
--------------------------------------------------------
-- | Expressions for the push-based solver: a computation over the current
-- node values, paired with the set of nodes it reads. The set is available
-- without running the computation, so dependency edges can be collected
-- before any evaluation happens.
data PushExp node value a = PushExp
  { pushDeps  :: Set node              -- ^ nodes this expression reads
  , pushThunk :: Map node value -> a   -- ^ evaluate against a value cache
  }

instance Ord node => Functor (PushExp node value) where
  -- Mapping over the result reads exactly the same nodes.
  fmap f (PushExp deps thunk) = PushExp deps (f . thunk)

instance Ord node => Applicative (PushExp node value) where
  -- A constant reads no nodes.
  pure x = PushExp Set.empty (const x)
  -- Reads from both sides are combined. Both thunks see the same cache via
  -- the function Applicative: (f <*> g) m = f m (g m).
  PushExp fDeps fThunk <*> PushExp xDeps xThunk =
    PushExp (Set.union fDeps xDeps) (fThunk <*> xThunk)
-- | An expression that reads a single node. It depends on exactly that node
-- and looks its value up in whatever cache it is evaluated against.
readNode :: Ord node => node -> PushExp node value value
readNode node = PushExp { pushDeps  = Set.singleton node
                        , pushThunk = \cache -> cache ! node
                        }
-- | State threaded through the push-based solver: the set of dirty nodes
-- (nodes that must be re-run) and the current value of every node.
type Push node value = State (Set node, Map node value)

-- | Solve a dataflow graph by pushing changes forward.
--
-- Every node starts out dirty. The solver repeatedly pops a dirty node and
-- recomputes it from the current values. If the value changed, it marks every
-- node that reads it as dirty and stores the new value. When no dirty node
-- remains, the map of current values is returned.
--
-- Each node's expression, and hence its set of dependencies, is built once,
-- up front. This works because a 'Step' is only Applicative, so the reads it
-- performs do not depend on node values.
pushFix :: forall node value. (Ord node, Eq value) =>
           Graph node value -> Map node value
pushFix (Graph start stepFn) = evalState drain (Set.fromList allNodes, start)
  where
    -- The node set is the key set of the initial-value map.
    allNodes :: [node]
    allNodes = Map.keys start

    exprs :: Map node (PushExp node value value)
    exprs = tabulate allNodes (stepFn readNode)

    -- Reverse dependency edges: node -> nodes whose expression reads it.
    readers :: Map node (Set node)
    readers = invert (Map.map pushDeps exprs)

    -- A node that nobody reads is absent from 'readers'. Treat it as having
    -- no readers rather than failing the lookup.
    readersOf :: node -> Set node
    readersOf n = Map.findWithDefault Set.empty n readers

    -- Process dirty nodes until none are left, then return the values.
    drain :: Push node value (Map node value)
    drain = popDirty >>= maybe (gets snd) (\n -> refresh n >> drain)

    -- Recompute one node. On change, dirty its readers and store the value.
    refresh :: node -> Push node value ()
    refresh n = do
      values <- gets snd
      let before = values ! n
          after  = pushThunk (exprs ! n) values
      unless (before == after) $ do
        markDirty (readersOf n)
        writeNode n after
-- | Overwrite the current value of one node, leaving the dirty set untouched.
--
-- Only 'Ord' on nodes is required (for the map insert). Values are never
-- compared here, so no 'Eq' constraint is needed.
writeNode :: Ord node => node -> value -> Push node value ()
writeNode node value =
  modify (\(dirty, cache) -> (dirty, Map.insert node value cache))
-- | Add the given nodes to the dirty set (nodes still waiting to be re-run),
-- leaving the node values untouched.
--
-- Only 'Ord' on nodes is required (for the set union). Values are never
-- compared here, so no 'Eq' constraint is needed.
markDirty :: Ord node => Set node -> Push node value ()
markDirty nodes = modify (\(dirty, cache) -> (Set.union dirty nodes, cache))
-- | Remove the dirty node picked by 'choose' from the dirty set and return it.
-- Returns 'Nothing', leaving the state unchanged, when no node is dirty.
--
-- Only 'Ord' on nodes is required (by 'choose'). Values are never compared
-- here, so no 'Eq' constraint is needed.
popDirty :: Ord node => Push node value (Maybe node)
popDirty = do
  (dirty, cache) <- get
  case choose dirty of
    Nothing -> return Nothing
    Just (node, dirty') -> do
      put (dirty', cache)
      return (Just node)
-- | Scheduling strategy for dirty nodes. It picks the smallest node (by its
-- 'Ord' instance) and returns it together with the remaining set, or
-- 'Nothing' if the set is empty. It is not clear that any smarter policy
-- would help.
choose :: Ord a => Set a -> Maybe (a, Set a)
choose = Set.minView
--------------------------------------------------------
---------- PULL-BASED dataflow implementation ----------
--------------------------------------------------------
-- | Expressions for the pull-based solver. Given the frozen set (nodes that
-- must not be re-evaluated) and the current cache of node values, a
-- computation returns:
--
--   * its result,
--   * a "changed" flag (raised explicitly, e.g. by markChanged; writing the
--     cache through 'MonadState' does not raise it),
--   * the set of nodes it visited, and
--   * the updated cache.
newtype PullExp node value a =
  PullExp { runPullExp :: Set node -> Map node value
                       -> (a, Bool, Set node, Map node value) }

-- Instances for PullExp. 'pure' is defined in Applicative and 'return' is
-- left at its default (= pure), the canonical form that avoids GHC's
-- -Wnoncanonical-monad-instances warning.
instance Ord node => Functor (PullExp node value) where
  fmap = liftM

instance Ord node => Applicative (PullExp node value) where
  pure x = PullExp (\_ cache -> (x, False, Set.empty, cache))
  (<*>) = ap

instance Ord node => Monad (PullExp node value) where
  -- Sequencing: every node visited by the first computation is frozen while
  -- the second one runs. The change flags are or-ed and the visited sets
  -- are merged.
  PullExp a >>= f = PullExp g
    where
      g frozen cache =
        let (av, achange, avisit, cache1) = a frozen cache
            frozen1 = Set.union frozen avisit
            (fv, fchange, fvisit, cache2) =
              runPullExp (f av) frozen1 cache1
        in (fv, achange || fchange, Set.union avisit fvisit, cache2)

-- | The state is the value cache. Accessing it neither raises the changed
-- flag nor visits any node.
instance Ord node => MonadState (Map node value) (PullExp node value) where
  state f = PullExp (\_ cache -> let (v, cache') = f cache
                                 in (v, False, Set.empty, cache'))
-- | A computation that does nothing except raise the "changed" flag.
markChanged :: PullExp node value ()
markChanged = PullExp { runPullExp = \_ cache -> ((), True, Set.empty, cache) }
-- | Report the given nodes as visited, passing the cache through unchanged.
-- Computations sequenced after this one with '>>=' see these nodes in their
-- frozen set.
markVisited :: Set node -> PullExp node value ()
markVisited nodes = PullExp { runPullExp = \_ cache -> ((), False, nodes, cache) }
-- | Return the current frozen set without changing, visiting or flagging
-- anything.
getFrozen :: PullExp node value (Set node)
getFrozen = PullExp { runPullExp = \frozen cache -> (frozen, False, Set.empty, cache) }
-- | Run a computation and also expose its changed flag and visited set as
-- part of its result. The flag and set are still propagated outward, exactly
-- as if the computation had been run directly.
listen :: PullExp node value a -> PullExp node value (a, Bool, Set node)
listen inner = PullExp $ \frozen cache ->
  let (result, changed, visited, cache') = runPullExp inner frozen cache
  in ((result, changed, visited), changed, visited, cache')
-- | Persistent state of the pull-based solver, carried between pullGet calls:
-- the set of finished nodes (their cached values are returned without
-- re-evaluation), paired with the current value of every node.
type PullState node value = (Set node, Map node value)

-- | Initial solver state: no node is finished yet, and every node holds the
-- initial value given by the graph.
pullInit :: Ord node => Graph node value -> PullState node value
pullInit = (,) Set.empty . graphInit
-- | Demand-driven evaluation: compute the value of a single node, evaluating
-- (and iterating to stability) only the nodes it transitively depends on.
--
-- Every node visited during the call is added to the finished set of the
-- 'PullState', so later calls return its cached value directly.
--
-- Inside a 'PullExp' computation, the "frozen" set holds nodes that must not
-- be re-entered. These are: finished nodes, nodes currently being evaluated
-- further up the call chain, and nodes already evaluated earlier in the
-- current step. Visiting a frozen node just returns its cached value.
pullGet :: forall node value. (Ord node, Eq value) =>
           Graph node value -> node -> State (PullState node value) value
pullGet graph node = do
  (finished, cache) <- get
  let (value, _changed, visited, cache') =
        runPullExp (visit node) finished cache
  put (Set.union finished visited, cache')
  return value
  where
    -- Evaluate a node: return its cached value if it is frozen, otherwise
    -- iterate it until it stabilises.
    visit :: node -> PullExp node value value
    visit n = do
      cachedValue <- gets (! n)
      -- we have to get the frozen set before we markVisited,
      -- b/c markVisited adds to the frozen set.
      frozen <- getFrozen
      markVisited (Set.singleton n)
      if Set.member n frozen
        then return cachedValue
        else settle (Set.insert n frozen) n cachedValue

    -- Re-run the step of node n until its value stops changing, or until it
    -- turns out not to depend on itself.
    --
    -- Every step runs against the same frozen set: the one in force when n
    -- was entered, plus n itself. If instead each step inherited everything
    -- visited by the previous one, dependencies read during the first step
    -- would stay frozen and keep stale values. Example: node 0 = min 5
    -- (1 + node 1) and node 1 = node 0, both starting at 0. Pulling node 0
    -- would then stop at {0: 1, 1: 0}, which is not a fixed point, instead of
    -- reaching {0: 5, 1: 5}.
    settle :: Set node -> node -> value -> PullExp node value value
    settle frozen n oldValue = do
      (newValue, changed, visited) <-
        listen (withFrozen frozen (step n oldValue))
      if not changed || not (Set.member n visited)
        -- if we didn't change or we didn't depend on ourselves, no need to
        -- iterate further.
        then return newValue
        else settle frozen n newValue

    -- One application of the graph's step rule to node n. A changed value is
    -- written to the cache and reported via markChanged.
    step :: node -> value -> PullExp node value value
    step n oldValue = do
      newValue <- graphStep graph visit n
      when (oldValue /= newValue) $ do
        modify (Map.insert n newValue)
        markChanged
      return newValue

    -- Run a computation with the given frozen set instead of the inherited one.
    -- Its visited set and changed flag still propagate outward as usual.
    withFrozen :: Set node -> PullExp node value a -> PullExp node value a
    withFrozen frozen inner =
      PullExp (\_ cache -> runPullExp inner frozen cache)
------------------------------
---------- Examples ----------
------------------------------
-- | Pull the value of one node, starting from the graph's initial state and
-- discarding the solver state afterwards.
testPull :: (Ord n, Eq v) => Graph n v -> n -> v
testPull g n = evalState (pullGet g n) (pullInit g)
-- | A chain of nodes 0..10, all starting at 0. Node 0 just copies its own
-- value (so it stays 0), and every other node is one more than its
-- predecessor, so node k settles at k.
ex1 :: Graph Int Int
ex1 = Graph initial rule
  where
    initial = Map.fromList [(k, 0) | k <- [0 .. 10]]

    rule :: Step Int Int
    rule self 0 = self 0
    rule self k = (1 +) <$> self (k - 1)
-- | A chain of nodes 0..10, all starting at 0. Node 0 increments its own
-- value while it is below 2 (a self-loop that needs iterating), and every
-- other node is one more than its predecessor, so node k settles at k + 2.
ex2 :: Graph Int Int
ex2 = Graph initial rule
  where
    initial = Map.fromList [(k, 0) | k <- [0 .. 10]]

    rule :: Step Int Int
    rule self 0 = bump <$> self 0
    rule self k = (1 +) <$> self (k - 1)

    bump x
      | x < 2     = x + 1
      | otherwise = x