From 0aa62bf04103c460fdfe820031e71cab12c037d7 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Tue, 16 Apr 2024 19:04:18 +0200 Subject: [PATCH] feat: add ArmoniKGraph class --- pyproject.toml | 2 + src/armonik_analytics/graph.py | 127 +++++++++++++++++++++++++++++++++ tests/unit/test_graph.py | 94 ++++++++++++++++++++++++ 3 files changed, 223 insertions(+) create mode 100644 src/armonik_analytics/graph.py create mode 100644 tests/unit/test_graph.py diff --git a/pyproject.toml b/pyproject.toml index e78b3a2..de550ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ classifiers = [ ] dependencies = [ 'armonik>=3.16.1', + 'rustworkx', 'numpy', ] @@ -45,6 +46,7 @@ tests = [ 'pytest', 'pytest-cov', 'pytest-benchmark[histogram]', + 'pytest-mock', ] samples = [ 'matplotlib' diff --git a/src/armonik_analytics/graph.py b/src/armonik_analytics/graph.py new file mode 100644 index 0000000..25dcbfd --- /dev/null +++ b/src/armonik_analytics/graph.py @@ -0,0 +1,127 @@ +from collections import defaultdict +from dataclasses import dataclass + +import rustworkx as rx + +from armonik.client import ArmoniKTasks +from armonik.common import Filter +from grpc import Channel + + +@dataclass +class ArmoniKGraphAttr: + """A class for ArmoniKGraph attributes. + + Parameters + ---------- + name : str + Name of the graph. + description : str + Description of the graph. + task_filter : armonik.common.Filter + Task filter defining the graph. + """ + + name: str + description: str + task_filter: Filter + + +class ArmoniKGraph: + """A class to represent workloads on ArmoniK. The execution of a program corresponds to a graph, + each node of which is a task. This class provides this representation and makes it possible to + analyse an execution. + + The flexibility of ArmoniK means that a workload can share its session with other workloads, or + be dispersed between several sessions or partitions. To identify a workload, and therefore the + corresponding graph, the user must provide a filter on the tasks within a cluster. + + In this graph, only the tasks are represented. Results are not included in the graph. + + Each node contains an 'armonik.common.Task' object which contains the task metadata up to date + at the time the graph is loaded. + + Parameters + ---------- + task_filter : armonik.common.Filter + A filter identifying the tasks belonging to the graph within the ArmoniK cluster. + name : str | None + An optional name for the graph. Default is None. + description: str | None + An optional description for the graph. Default is None. + + + Example + ------- + >>> from armonik.client import TaskFieldFilter + >>> from armonik_analytics import ArmoniKGraph + >>> g = ArmoniKGraph(task_filter=(TaskFieldFilter.SESSION_ID == "session_id")) + + """ + + def __init__( + self, task_filter: Filter, name: str | None = None, description: str | None = None + ) -> None: + self.graph = rx.PyDiGraph(check_cycle=True, multigraph=True) + self.graph.attrs = ArmoniKGraphAttr( + name=name if name else "", + description=description if description else "", + task_filter=task_filter, + ) + + def update(self, channel: Channel) -> None: + """Updates in-place the contents of the graph from the state database of a running cluster. + + Note that this operation deletes and then re-downloads the content, even if it has not changed. + This can take a significant amount of time. + + Parameters + ---------- + channel : grpc.Channel + An open gRPC channel to the running cluster. + """ + # Clear current content. Should be improve in future versions. + self.graph.clear() + + client = ArmoniKTasks(channel) + + # Tasks depend on each other through their input/output data. The following dictionary is + # used to build dependencies between tasks. It stores for each input/output data item which + # unique task produces it and which task(s) consume(s) it. + edges = defaultdict(lambda: [None, []]) + + # Iterates over all tasks corresponding to the filter defining the graph + page = 0 + total, tasks = client.list_tasks(task_filter=self.graph.attrs.task_filter, with_errors=True) + while tasks: + for task in tasks: + # Add task to graph + node_id = self.graph.add_node(task) + # Add task inputs/outputs to 'edges' dictionnary + for in_data_dep in task.data_dependencies: + edges[in_data_dep][1].append(node_id) + for out_data_dep in task.expected_output_ids: + # An output data can only be produced by a single task. However, an ArmoniK + # graph is dynamic and a task can transfer its responsibility for generating + # an output to another task. This operation is not reflected in the task's + # metadata. So care must be taken to select only the task that actually produces + # the result. This is the task that has all the other tasks as parents among the + # tasks claiming to produce this output data. + old_tail_id = edges[out_data_dep][0] + if old_tail_id: + if task.id in self.graph.get_node_data(old_tail_id).parent_task_ids: + node_id = old_tail_id + edges[out_data_dep][0] = node_id + + page += 1 + _, tasks = client.list_tasks(task_filter=self.graph.attrs.task_filter) + + # Once built, the 'edges' dictionary is used to construct dependencies between tasks in + # the graph. + for tail, heads in edges.values(): + # Root input data have no tails (not produced by any task) and don't correspond to any + # dependency between two tasks. Such data are ignored. + if tail is not None: + self.graph.add_edges_from_no_data([(tail, head) for head in heads]) + + assert self.graph.num_nodes() == total diff --git a/tests/unit/test_graph.py b/tests/unit/test_graph.py new file mode 100644 index 0000000..9724245 --- /dev/null +++ b/tests/unit/test_graph.py @@ -0,0 +1,94 @@ +import grpc +import rustworkx as rx +import pytest + +from armonik.client import ArmoniKTasks, TaskFieldFilter +from armonik.common import Task + +from armonik_analytics.graph import ArmoniKGraph + + +class ListTasks: + def __init__(self, tasks): + self.tasks = tasks + self.call_count = 0 + + def __call__(self, *args, **kwds): + self.call_count += 1 + if self.call_count == 1: + return len(self.tasks), self.tasks + else: + return len(self.tasks), [] + + +def single_node(): + tasks = [ + Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]), + ] + graph = rx.PyDiGraph(check_cycle=True) + graph.add_nodes_from(tasks) + + return tasks, graph + + +def three_parallel_nodes(): + tasks = [ + Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]), + Task(id="t1", data_dependencies=["i1"], expected_output_ids=["o1"]), + Task(id="t2", data_dependencies=["i2"], expected_output_ids=["o2"]), + ] + graph = rx.PyDiGraph(check_cycle=True) + graph.add_nodes_from(tasks) + + return tasks, graph + + +def three_dependant_nodes(): + tasks = [ + Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]), + Task(id="t1", data_dependencies=["o0"], expected_output_ids=["o1"]), + Task(id="t2", data_dependencies=["o0"], expected_output_ids=["o2"]), + ] + graph = rx.PyDiGraph(check_cycle=True) + graph.add_nodes_from(tasks) + graph.add_edges_from_no_data([(0, 1), (0, 2)]) + + return tasks, graph + + +def seven_dependant_nodes(): + tasks = [ + Task(id="t0", data_dependencies=["i0"], expected_output_ids=["o0"]), + Task(id="t1", data_dependencies=["o0"], expected_output_ids=["o1"]), + Task(id="t2", data_dependencies=["o0"], expected_output_ids=["o2"]), + Task(id="t3", data_dependencies=["o1", "o2"], expected_output_ids=["o3"]), + Task(id="t4", data_dependencies=["o1"], expected_output_ids=["o4"]), + Task(id="t5", data_dependencies=["o4"], expected_output_ids=["o5"]), + Task(id="t6", data_dependencies=["o3"], expected_output_ids=["o6", "o7"]), + ] + graph = rx.PyDiGraph(check_cycle=True) + graph.add_nodes_from(tasks) + graph.add_edges_from_no_data([(0, 1), (0, 2), (1, 3), (1, 4), (2, 3), (4, 5), (3, 6)]) + + return tasks, graph + + +@pytest.mark.parametrize( + ("tasks", "expected_graph"), + [ + generator.__call__() + for generator in [ + single_node, + three_parallel_nodes, + three_dependant_nodes, + seven_dependant_nodes, + ] + ], +) +def test_graph_update(mocker, tasks, expected_graph): + mocker.patch.object(ArmoniKTasks, "list_tasks", new=ListTasks(tasks)) + g = ArmoniKGraph(task_filter=(TaskFieldFilter.SESSION_ID == "session_id")) + with grpc.insecure_channel("host") as channel: + g.update(channel) + assert g.graph.nodes() == expected_graph.nodes() + assert g.graph.edges() == expected_graph.edges()