Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ArmoniKGraph class #4

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ classifiers = [
]
dependencies = [
'armonik>=3.16.1',
'rustworkx',
'numpy',
]

Expand Down Expand Up @@ -45,6 +46,7 @@ tests = [
'pytest',
'pytest-cov',
'pytest-benchmark[histogram]',
'pytest-mock',
]
samples = [
'matplotlib'
Expand Down
127 changes: 127 additions & 0 deletions src/armonik_analytics/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from collections import defaultdict
from dataclasses import dataclass

import rustworkx as rx

from armonik.client import ArmoniKTasks
from armonik.common import Filter
from grpc import Channel


@dataclass
class ArmoniKGraphAttr:
"""A class for ArmoniKGraph attributes.

Parameters
----------
name : str
Name of the graph.
description : str
Description of the graph.
task_filter : armonik.common.Filter
Task filter defining the graph.
"""

name: str
description: str
task_filter: Filter


class ArmoniKGraph:
"""A class to represent workloads on ArmoniK. The execution of a program corresponds to a graph,
each node of which is a task. This class provides this representation and makes it possible to
analyse an execution.

The flexibility of ArmoniK means that a workload can share its session with other workloads, or
be dispersed between several sessions or partitions. To identify a workload, and therefore the
corresponding graph, the user must provide a filter on the tasks within a cluster.

In this graph, only the tasks are represented. Results are not included in the graph.

Each node contains an 'armonik.common.Task' object which contains the task metadata up to date
at the time the graph is loaded.

Parameters
----------
task_filter : armonik.common.Filter
A filter identifying the tasks belonging to the graph within the ArmoniK cluster.
name : str | None
An optional name for the graph. Default is None.
description: str | None
An optional description for the graph. Default is None.


Example
-------
>>> from armonik.client import TaskFieldFilter
>>> from armonik_analytics import ArmoniKGraph
>>> g = ArmoniKGraph(task_filter=(TaskFieldFilter.SESSION_ID == "session_id"))

"""

def __init__(
self, task_filter: Filter, name: str | None = None, description: str | None = None
) -> None:
self.graph = rx.PyDiGraph(check_cycle=True, multigraph=True)
self.graph.attrs = ArmoniKGraphAttr(
name=name if name else "",
description=description if description else "",
task_filter=task_filter,
)

def update(self, channel: Channel) -> None:
"""Updates in-place the contents of the graph from the state database of a running cluster.

Note that this operation deletes and then re-downloads the content, even if it has not changed.
This can take a significant amount of time.

Parameters
----------
channel : grpc.Channel
An open gRPC channel to the running cluster.
"""
# Clear current content. Should be improve in future versions.
self.graph.clear()

client = ArmoniKTasks(channel)

# Tasks depend on each other through their input/output data. The following dictionary is
# used to build dependencies between tasks. It stores for each input/output data item which
# unique task produces it and which task(s) consume(s) it.
edges = defaultdict(lambda: [None, []])

# Iterates over all tasks corresponding to the filter defining the graph
page = 0
total, tasks = client.list_tasks(task_filter=self.graph.attrs.task_filter, with_errors=True)
while tasks:
for task in tasks:
# Add task to graph
node_id = self.graph.add_node(task)
# Add task inputs/outputs to 'edges' dictionnary
for in_data_dep in task.data_dependencies:
edges[in_data_dep][1].append(node_id)
for out_data_dep in task.expected_output_ids:
# An output data can only be produced by a single task. However, an ArmoniK
# graph is dynamic and a task can transfer its responsibility for generating
# an output to another task. This operation is not reflected in the task's
# metadata. So care must be taken to select only the task that actually produces
# the result. This is the task that has all the other tasks as parents among the
# tasks claiming to produce this output data.
old_tail_id = edges[out_data_dep][0]
if old_tail_id:
if task.id in self.graph.get_node_data(old_tail_id).parent_task_ids:
node_id = old_tail_id
edges[out_data_dep][0] = node_id

page += 1
_, tasks = client.list_tasks(task_filter=self.graph.attrs.task_filter)

# Once built, the 'edges' dictionary is used to construct dependencies between tasks in
# the graph.
for tail, heads in edges.values():
# Root input data have no tails (not produced by any task) and don't correspond to any
# dependency between two tasks. Such data are ignored.
if tail is not None:
self.graph.add_edges_from_no_data([(tail, head) for head in heads])

assert self.graph.num_nodes() == total
94 changes: 94 additions & 0 deletions tests/unit/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import grpc
import rustworkx as rx
import pytest

from armonik.client import ArmoniKTasks, TaskFieldFilter
from armonik.common import Task

from armonik_analytics.graph import ArmoniKGraph


class ListTasks:
def __init__(self, tasks):
self.tasks = tasks
self.call_count = 0

def __call__(self, *args, **kwds):
self.call_count += 1
if self.call_count == 1:
return len(self.tasks), self.tasks
else:
return len(self.tasks), []


def single_node():
tasks = [
Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]),
]
graph = rx.PyDiGraph(check_cycle=True)
graph.add_nodes_from(tasks)

return tasks, graph


def three_parallel_nodes():
tasks = [
Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]),
Task(id="t1", data_dependencies=["i1"], expected_output_ids=["o1"]),
Task(id="t2", data_dependencies=["i2"], expected_output_ids=["o2"]),
]
graph = rx.PyDiGraph(check_cycle=True)
graph.add_nodes_from(tasks)

return tasks, graph


def three_dependant_nodes():
tasks = [
Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]),
Task(id="t1", data_dependencies=["o0"], expected_output_ids=["o1"]),
Task(id="t2", data_dependencies=["o0"], expected_output_ids=["o2"]),
]
graph = rx.PyDiGraph(check_cycle=True)
graph.add_nodes_from(tasks)
graph.add_edges_from_no_data([(0, 1), (0, 2)])

return tasks, graph


def seven_dependant_nodes():
tasks = [
Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]),
Task(id="t1", data_dependencies=["o0"], expected_output_ids=["o1"]),
Task(id="t2", data_dependencies=["o0"], expected_output_ids=["o2"]),
Task(id="t3", data_dependencies=["o1", "o2"], expected_output_ids=["o3"]),
Task(id="t4", data_dependencies=["o1"], expected_output_ids=["o4"]),
Task(id="t5", data_dependencies=["o4"], expected_output_ids=["o5"]),
Task(id="t6", data_dependencies=["o3"], expected_output_ids=["o6", "o7"]),
]
graph = rx.PyDiGraph(check_cycle=True)
graph.add_nodes_from(tasks)
graph.add_edges_from_no_data([(0, 1), (0, 2), (1, 3), (1, 4), (2, 3), (4, 5), (3, 6)])

return tasks, graph


@pytest.mark.parametrize(
("tasks", "expected_graph"),
[
generator.__call__()
for generator in [
single_node,
three_parallel_nodes,
three_dependant_nodes,
seven_dependant_nodes,
]
],
)
def test_graph_update(mocker, tasks, expected_graph):
mocker.patch.object(ArmoniKTasks, "list_tasks", new=ListTasks(tasks))
g = ArmoniKGraph(task_filter=(TaskFieldFilter.SESSION_ID == "session_id"))
with grpc.insecure_channel("host") as channel:
g.update(channel)
assert g.graph.nodes() == expected_graph.nodes()
assert g.graph.edges() == expected_graph.edges()
Loading