Skip to content

Commit

Permalink
updated docstring of MovingHorizon forward. The output should indeed …
Browse files Browse the repository at this point in the history
…by (ndelay, batch, dim) dimension
  • Loading branch information
Birmiwal, Rahul R committed Oct 30, 2023
1 parent 3419669 commit bf86d5b
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions src/neuromancer/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def forward(self, input):
repeated ndelay times to initialize the buffer.
:param input: (dict: str: 2-d tensor (batch, dim)) Dictionary of single step tensor inputs
:return: (dict: str: 3-d Tensor (ndelay, batch, dim)) Dictionary of tensor outputs for the last ndelay times
:return: (dict: str: 3-d Tensor (ndelay, batch, dim)) Dictionary of tensor outputs
"""
for k in self.input_keys:
self.history[k].append(input[k])
Expand Down Expand Up @@ -150,25 +150,22 @@ def graph(self):
graph.add_edge(pydot.Edge(src.name, dst.name, label=key))
unique_common_keys.add(key)

# get keys of recurrent nodes
# build I/O and node loop connections
loop_keys = []
for node in self.nodes:
init_keys = []
previous_output_keys = []
for idx_node, node in enumerate(self.nodes):
node_loop_keys = set(node.input_keys) & set(node.output_keys)
loop_keys += node_loop_keys
# get keys required as input and to initialize some nodes
init_keys = set(input_keys) - (set(output_keys) - set(loop_keys))
init_keys += set(node.input_keys) - set(previous_output_keys)
previous_output_keys += node.output_keys

# build I/O and node loop connections
previous_output_keys = []
for idx_node, node in enumerate(self.nodes):
# build single node recurrent connections
node_loop_keys = list(set(node.input_keys) & set(node.output_keys))
for key in node_loop_keys:
graph.add_edge(pydot.Edge(node.name, node.name, label=key))
# build connections to the dataset
for key in set(node.input_keys) & set(init_keys-set(previous_output_keys)):
for key in set(node.input_keys) & set(init_keys):
graph.add_edge(pydot.Edge("in", node.name, label=key))
previous_output_keys += node.output_keys
# build feedback connections for init nodes
feedback_src_nodes = reverse_order_nodes[:-1-idx_node]
if len(set(node.input_keys) & set(loop_keys) & set(init_keys)) > 0:
Expand All @@ -177,6 +174,7 @@ def graph(self):
if key in src.output_keys and key not in previous_output_keys:
graph.add_edge(pydot.Edge(src.name, node.name, label=key))
break

# build connections to the output of the system in a reversed order
previous_output_keys = []
for node in self.nodes[::-1]:
Expand All @@ -189,7 +187,7 @@ def graph(self):
return graph

def show(self, figname=None):
graph = self.system_graph
graph = self.graph()
if figname is not None:
plot_func = {'svg': graph.write_svg,
'png': graph.write_png,
Expand Down Expand Up @@ -256,10 +254,4 @@ def forward(self, input_dict):
indata = {k: data[k][:, i] for k in node.input_keys} # collect what the compute node needs from data nodes
outdata = node(indata) # compute
data = self.cat(data, outdata) # feed the data nodes
return data # return recorded system measurements






return data # return recorded system measurements

0 comments on commit bf86d5b

Please sign in to comment.