Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph Scheduler Task #232

Draft
wants to merge 6 commits into
base: task_state
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cloud/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions cloud/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"@types/react-router-dom": "^5.1.5",
"@types/react-syntax-highlighter": "^11.0.4",
"@types/redux-logger": "^3.0.7",
"dagre-d3-react": "^0.2.4",
"eslint-webpack-plugin": "^2.4.3",
"lodash": "^4.17.19",
"polished": "^3.6.3",
Expand Down
27 changes: 25 additions & 2 deletions cloud/src/components/task/TaskState.tsx
Original file line number Diff line number Diff line change
@@ -1,18 +1,41 @@
import React from 'react'
import { ContentBlock, Code } from '../ui'
import TaskGraph from './task-ui/TaskGraph'

type Props = {
state?: object
state: State
}

type State = {
ui?: TaskComponent[]
[key: string]: any
}

type TaskComponent = {
component: string
path: string
}

function renderComponent(def: TaskComponent, state: State) {
switch(def.component) {
case 'ui.cowait.io/task-graph':
return <TaskGraph graph={state[def.path]} />
default:
throw new Error(`Unknown Task Component ${def.component}`)
}
}

export const TaskState: React.FC<Props> = ({ state }) => {
if (!state) {
return null
}
const components = state.ui || []

return <ContentBlock>
<h4>State</h4>
<Code language="json">{JSON.stringify(state, null, 4)}</Code>
{components.map(c => renderComponent(c, state))}
</ContentBlock>
// <Code language="json">{JSON.stringify(state, null, 4)}</Code>
}

export default TaskState
1 change: 1 addition & 0 deletions cloud/src/components/task/styled/Log.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export const LogOutput = styled.pre<Props>`
font-family: ${p => p.theme.fonts.monospace};
color: ${p => p.theme.colors.text.secondary};
line-height: 1.25em;
max-width: 100vh;
`

export const LogContainer = styled.div<Props>`
Expand Down
99 changes: 99 additions & 0 deletions cloud/src/components/task/task-ui/TaskGraph.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import _ from 'lodash'
import React from 'react'
import { useSelector } from 'react-redux'
import { RootState } from '../../../store'
import DagreGraph from 'dagre-d3-react'
import styled from 'styled-components'

const DagStyle = styled.div`
.nodes {
fill: darkgray;
cursor: pointer;
}

// status colors
.nodes .work { fill: #82b332; }
.nodes .done { fill: green; }
.nodes .fail { fill: red; }
.nodes .stop { fill: orange; }

.nodes text {
fill: white;
}

path {
stroke: white;
fill: white;
stroke-width: 3px;
}
`

type Props = {
graph: TaskGraph
}

type TaskGraph = {
[id: string]: TaskNode,
}

type TaskNode = {
id: string
task: string
depends_on: string[]
task_id?: string
}

type d3Link = {
source: string
target: string
class?: string
label?: string
config?: object
}

export const TaskGraph: React.FC<Props> = ({ graph }) => {
const tasks = useSelector((state: RootState) => state.tasks.items)
let nodes = _.map(graph, node => {
if (node.task_id && tasks[node.task_id]) {
let task = tasks[node.task_id]
return {
id: node.id,
label: task.id,
class: task.status,
}
}
return {
id: node.id,
label: node.task,
class: 'pending',
}
})
let links: d3Link[] = []
_.each(graph, node => {
_.each(node.depends_on, edge => {
links.push({
source: edge,
target: node.id,
})
})
})
return <DagStyle>
<DagreGraph
nodes={nodes}
links={links}
config={{
rankdir: 'LR',
align: 'UL',
ranker: 'tight-tree'
}}
width='100%'
height='500'
animate={100}
shape='rect'
zoomable
onNodeClick={(e: any) => console.log(e)}
/>
</DagStyle>
}

export default TaskGraph
5 changes: 5 additions & 0 deletions cowait/tasks/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# flake8: noqa: 401
from .graph import Graph
from .node import Node
from .result import Result
from .graph_task import GraphTask
100 changes: 100 additions & 0 deletions cowait/tasks/graph/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from .node import Node
from .result import Result


def wrap_input(input) -> Result:
if isinstance(input, Node):
return Result(input, lambda output: output)
return input


class Graph(object):
def __init__(self):
self.nodes = []
self.todo = []
self.results = {}
self.errors = {}

@property
def completed(self):
return len(self.todo) == 0

def node(self, task: str, inputs: dict = {}):
node = Node(task, {
key: wrap_input(input)
for key, input in inputs.items()
})
self.nodes.append(node)
self.todo.append(node)
return node

def has_missing_input(self, node: Node) -> bool:
for _, input in node.inputs.items():
if not isinstance(input, Result):
continue
if input.node not in self.results:
return True
return False

def has_upstream_failure(self, node: Node) -> bool:
for _, input in node.inputs.items():
if not isinstance(input, Result):
continue
if input.node in self.errors:
return True
return False

def next(self) -> Node:
idx = 0
while idx < len(self.todo):
node = self.todo[idx]
# check for upstream failures
if self.has_upstream_failure(node):
self.fail(node, Exception('Upstream dependency failure'))
continue

# increment index after the failure check, since calling self.fail()
# will remove a node from the list.
idx += 1

# check if the node is ready for execution
if self.has_missing_input(node):
continue

# collect input values
args = {}
for key, input in node.inputs.items():
if isinstance(input, Result):
outputs = self.results[input.node]
args[key] = input.get(outputs)
else:
args[key] = input

self.todo.remove(node)
return node.with_inputs(args)

return None

def reset(self):
self.todo = self.nodes.copy()
self.errors = {}
self.results = {}

def complete(self, node, result):
if node not in self.nodes:
raise Exception('Unknown node', node)
if node in self.errors:
raise Exception('Node already failed')
if node in self.todo:
self.todo.remove(node)
self.results[node] = result

def fail(self, node, exception):
if node not in self.nodes:
raise Exception('Unknown node', node)
if node in self.errors:
raise Exception('Node already completed')
if node in self.todo:
self.todo.remove(node)
self.errors[node] = exception

89 changes: 89 additions & 0 deletions cowait/tasks/graph/graph_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import asyncio
from cowait.tasks import Task
from .graph import Graph


class GraphTask(Task):
def init(self) -> dict:
return {
'ui': [
{
'component': 'ui.cowait.io/task-graph',
'path': 'graph',
},
]
}

async def define(self, graph, **inputs):
# this is where you would define your graph nodes
# to create a dag, override this function in a subclass
pass

async def run(self, **inputs):
graph = Graph()
await self.define(graph, **inputs)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before we run this we should make sure it has no loops

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a validation step sounds like a pretty good idea 😬

pending = []
task_nodes = {}
node_tasks = {}

async def send_state():
state = {}
for node in graph.nodes:
task = node_tasks.get(node, None)
state[node.id] = {
'id': str(node.id),
'task': node.task if not issubclass(node.task, Task) else node.task.__module__,
'depends_on': [str(edge.id) for edge in node.edges],
'task_id': None if not task else task.id,
}
await self.set_state({'graph': state})

await send_state()

# run until all nodes complete
while not graph.completed:
# launch tasks for each node that is ready for execution
while True:
node = graph.next()
if node is None:
break

task = self.spawn(node.task, inputs=node.inputs)
node_tasks[node] = task

# wrap the task in a future and store it in a mapping from futures -> node
# so we can find the node once the task completes
task = asyncio.ensure_future(task)
task_nodes[task] = node

pending.append(task)

await send_state()

# if everything is completed, exit
if len(pending) == 0:
break

# wait until any task finishes
done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)

# mark finished nodes as completed
for task in done:
node = task_nodes[task]

try:
# unpacking the result will throw an exception if the task failed
graph.complete(node, task.result())
except Exception as e:
graph.fail(node, e)

pending.remove(task)

if not graph.completed:
raise Exception('Some tasks failed to finish')

await send_state()

# return what?
return True
Loading