Skip to content

Commit

Permalink
dag tests
Browse files Browse the repository at this point in the history
  • Loading branch information
johanhenriksson committed Jan 29, 2021
1 parent e05a7be commit 42b7bb8
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 7 deletions.
15 changes: 8 additions & 7 deletions cowait/tasks/graph/graph_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ async def define(self, graph, **inputs):
pass

async def run(self, **inputs):
g = Graph()
await self.define(g, **inputs)
graph = Graph()
await self.define(graph, **inputs)

# run until all nodes complete
pending = []
node_tasks = {}
while not g.completed:
while not graph.completed:
# launch tasks for each node that is ready for execution
while True:
node = g.next()
node = graph.next()
if node is None:
break

Expand All @@ -44,13 +44,14 @@ async def run(self, **inputs):
node = node_tasks[task]

try:
g.complete(node, task.result())
# unpacking the result will throw an exception if the task failed
graph.complete(node, task.result())
except Exception as e:
g.fail(node, e)
graph.fail(node, e)

pending.remove(task)

if not g.completed:
if not graph.completed:
raise Exception('Some tasks failed to finish')

# return what?
Expand Down
82 changes: 82 additions & 0 deletions test/tasks/graph/graph_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest
from cowait.tasks.graph import Graph, Node, Result


@pytest.mark.task_graph
def test_node_dependencies():
g = Graph()
b = g.node('B')
g.node('A', {
'b': b,
'number': 123,
})

# expect B to be returned first
first = g.next()
assert first is not None
assert first.task == 'B'
assert first.inputs == {}

# at this point, no new node should be available until b is completed
assert g.next() is None

g.complete(b, 'yey')

# finally, A should be returned with the output of B as input
second = g.next()
assert second is not None
assert second.task == 'A'
assert second.inputs == {'b': 'yey', 'number': 123}


@pytest.mark.task_graph
def test_node_upstream_error():
g = Graph()
b = g.node('B')
a = g.node('A', {
'b': b,
})

g.fail(b, Exception('test'))
assert b in g.errors
assert a not in g.errors

# node of the nodes should be ready
# A should be marked as failed
assert g.next() is None

assert a in g.errors


@pytest.mark.task_graph
def test_node_output_accessor():
g = Graph()
b = g.node('B')
g.node('A', inputs={
'one': b.output('value'),
'two': b.output(lambda x: x['value'] * 2),
})
g.complete(b, {'value': 2})

result = g.next()
assert result is not None
assert result.inputs['one'] == 2
assert result.inputs['two'] == 4


@pytest.mark.task_graph
def test_unpack_result():
outputs = {'a': 123}

r1 = Result(None, 'a')
assert r1.get(outputs) == 123

r2 = Result(None, lambda x: x['a'] * 2)
assert r2.get(outputs) == 246


@pytest.mark.task_graph
def test_graph_node_ids():
a = Node.next_id()
b = Node.next_id()
assert b > a

0 comments on commit 42b7bb8

Please sign in to comment.