forked from dusty-nv/jetson-containers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagent.py
151 lines (120 loc) · 4.86 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python3
import logging
from .plugin import Plugin
class Agent:
    """
    Agents create/manage a pipeline of plugins.

    ``self.pipeline`` holds the source plugins of the graph; the rest of the
    graph is reached by following each plugin's ``outputs``.
    """
    def __init__(self, pipeline=None, **kwargs):
        """
        Args:
          pipeline (Plugin | list[Plugin] | None): the source plugin(s) from the graph.
            A single Plugin is wrapped in a list; a list is stored as-is (not copied).
            None (the default) gives an empty pipeline, in which case process()
            and run() raise NotImplementedError.
          save_mermaid (str, optional): if set, run() saves a mermaid diagram of the
            pipeline to this path.

        Raises:
          TypeError: if ``pipeline`` is not None, a Plugin, or a list.
        """
        # Use a fresh list per instance rather than a mutable default argument,
        # which would be one list object shared by every Agent built without a pipeline.
        if pipeline is None:
            self.pipeline = []
        elif isinstance(pipeline, Plugin):
            self.pipeline = [pipeline]
        elif isinstance(pipeline, list):
            self.pipeline = pipeline
        else:
            raise TypeError(f"expected Plugin or list[Plugin] for 'pipeline' argument (was {type(pipeline)})")

        self.save_mermaid = kwargs.get('save_mermaid')

    def process(self, input, channel=0, **kwargs):
        """
        Add data to the pipeline's input queue.

        Args:
          input: data handed to the selected source plugin's ``input()``.
          channel (int): index of the source plugin from the constructor.
          kwargs: forwarded to the plugin's ``input()``.

        Raises:
          NotImplementedError: if the agent has no pipeline.
        """
        if not self.pipeline:
            raise NotImplementedError(f"{type(self)} has not implemented a pipeline")
        self.pipeline[channel].input(input, **kwargs)

    def __call__(self, input, channel=0, **kwargs):
        """
        Operator overload for process().
        """
        return self.process(input, channel, **kwargs)

    def start(self):
        """
        Start threads for all plugins in the graph that have threading enabled.

        Only the source plugins are started directly here.
        NOTE(review): presumably Plugin.start() propagates to downstream plugins — confirm.

        Returns:
          self, so calls can be chained.
        """
        for plugin in self.pipeline:
            plugin.start()
        return self

    def run(self, timeout=None):
        """
        Run the agent forever or return after the specified timeout (in seconds).

        Starts the plugins, optionally saves the mermaid diagram (``save_mermaid``),
        then joins the first source plugin with ``timeout``.

        Raises:
          NotImplementedError: if the agent has no pipeline.
        """
        # Fail up front with the same error process() uses, instead of an
        # IndexError on self.pipeline[0] after the plugins were already started.
        if not self.pipeline:
            raise NotImplementedError(f"{type(self)} has not implemented a pipeline")

        self.start()

        if self.save_mermaid:
            self.to_mermaid(save=self.save_mermaid)

        # NOTE(review): logging.success is not part of the stdlib logging module;
        # presumably it is installed by the project's logging setup.
        logging.success(f"{type(self).__name__} - system ready")
        self.pipeline[0].join(timeout)
        return self

    @staticmethod
    def _unique_inst_name(base, taken):
        """
        Return ``base`` if it is not in ``taken``, otherwise the first
        ``f"{base}_{n}"`` (n = 1, 2, 3, ...) that is not in ``taken``.
        """
        if base not in taken:
            return base
        index = 1
        while f"{base}_{index}" in taken:
            index += 1
        return f"{base}_{index}"

    def to_mermaid(self, save=None):
        """
        Return or save mermaid diagram of the pipeline.

        The graph is walked depth-first (pre-order) from each source plugin in
        ``self.pipeline`` through every plugin's ``outputs``. Threaded plugins are
        bracketed with ``[ ]`` and non-threaded ones with ``[[ ]]``; edges coming
        from output channels other than 0 are labelled with the channel index.
        Repeated plugin types get ``_1``, ``_2``, ... appended to their node id.

        Args:
          save (str, optional): path to write the diagram text to.

        Returns:
          str: the mermaid diagram text.
        """
        from .utils import get_class_that_defined_method
        from .plugins import Callback

        nodes = []           # one dict per distinct plugin, in visit order
        node_by_id = {}      # id(plugin) -> node dict (plugins stay alive in `nodes`)
        taken_names = set()  # instance names already assigned

        def get_node_name(plugin):
            # Callbacks are labelled after the wrapped function ("Class.function").
            if isinstance(plugin, Callback):
                function = plugin.function
                owner = get_class_that_defined_method(function)
                # NOTE(review): guard in case the helper returns None for free
                # functions/lambdas — confirm against utils.get_class_that_defined_method.
                if owner is None:
                    return function.__name__
                return owner.__name__ + '.' + function.__name__
            return type(plugin).__name__

        def get_nodes(plugin):
            # A plugin is registered before recursing into its outputs, so cycles terminate.
            if id(plugin) in node_by_id:
                return

            type_name = get_node_name(plugin)
            inst_name = self._unique_inst_name(type_name, taken_names)
            taken_names.add(inst_name)

            node = {
                'plugin': plugin,
                'type_name': type_name,
                'inst_name': inst_name,
                'shape': ('[', ']') if plugin.threaded else ('[[', ']]'),
            }
            nodes.append(node)
            node_by_id[id(plugin)] = node

            for output_channel in plugin.outputs:
                for output in output_channel:
                    get_nodes(output)

        for plugin in self.pipeline:
            get_nodes(plugin)

        lines = ['---', f"title: {type(self).__name__}", '---', 'graph']

        for node in nodes:
            lines.append(f'{node["inst_name"]}{node["shape"][0]}"{node["type_name"]}"{node["shape"][1]}')

        for node in nodes:
            for c, output_channel in enumerate(node['plugin'].outputs):
                for output in output_channel:
                    target = node_by_id[id(output)]['inst_name']
                    if c == 0:
                        lines.append(f'{node["inst_name"]} ---> {target}')
                    else:
                        lines.append(f'{node["inst_name"]} -- channel {c} ---> {target}')

        text = ''.join(line + '\n' for line in lines)

        if save:
            with open(save, 'w', encoding='utf-8') as file:
                file.write(text)
            logging.info(f"saved pipeline mermaid to {save}")

        return text


def Pipeline(plugins):
    """
    Connect the `plugins` feed-forward style where each is an input to the next.

    This uses plugin.add(), but specifying pipelines in list notation can be cleaner.

    Args:
      plugins: an iterable of plugins in pipeline order (list, tuple, generator, ...).
        It is consumed once.

    Returns:
      A single-element list holding the first plugin (the form accepted by the
      ``pipeline`` argument of ``Agent``), from which the other plugins can be
      found through its outputs — or None if ``plugins`` is empty.
    """
    # Materialize so any iterable works and consecutive pairs can be formed.
    plugins = list(plugins)

    if not plugins:
        return None

    for upstream, downstream in zip(plugins, plugins[1:]):
        upstream.add(downstream)

    return [plugins[0]]