diff --git a/examples/ODEs/Part_7_DeepKoopman.ipynb b/examples/ODEs/Part_7_DeepKoopman.ipynb index 92643551..f49c2363 100644 --- a/examples/ODEs/Part_7_DeepKoopman.ipynb +++ b/examples/ODEs/Part_7_DeepKoopman.ipynb @@ -28,12 +28,14 @@ "[3] [E. Yeung, S. Kundu and N. Hodas, \"Learning Deep Neural Network Representations for Koopman Operators of Nonlinear Dynamical Systems,\" 2019 American Control Conference (ACC), Philadelphia, PA, USA, 2019, pp. 4832-4839](https://ieeexplore.ieee.org/document/8815339) \n", "[4] [Shaowu Pan, Karthik Duraisamy, Physics-Informed Probabilistic Learning of Linear Embeddings of Non-linear Dynamics With Guaranteed Stability, SIAM Journal on Applied Dynamical Systems, 2020](https://epubs.siam.org/doi/abs/10.1137/19M1267246?journalCode=sjaday) \n", "[5] [F. Fan, B. Yi, D. Rye, G. Shi and I. R. Manchester, \"Learning Stable Koopman Embeddings,\" 2022 American Control Conference (ACC), Atlanta, GA, USA, 2022, pp. 2742-2747](https://ieeexplore.ieee.org/document/9867865) \n", + "[6] https://pubs.aip.org/aip/cha/article-abstract/22/4/047510/341880/Applied-Koopmanisma \n", + "[7] https://arxiv.org/abs/1312.0041 \n", "\n", "\n", "### Generic Stable Layers References\n", - "[6] [E. Skomski, S. Vasisht, C. Wight, A. Tuor, J. Drgoňa and D. Vrabie, \"Constrained Block Nonlinear Neural Dynamical Models,\" 2021 American Control Conference (ACC), New Orleans, LA, USA, 2021, pp. 3993-4000](https://ieeexplore.ieee.org/document/9482930) \n", - "[7] [J. Drgoňa, A. Tuor, S. Vasisht and D. Vrabie, \"Dissipative Deep Neural Dynamical Systems,\" in IEEE Open Journal of Control Systems, 2022](https://ieeexplore.ieee.org/abstract/document/9809789) \n", - "[8] [Jiong Zhang, Qi Lei, Inderjit S. Dhillon, Stabilizing Gradients for Deep Neural Networks via Efficient SVD Parameterization, InternationalConferenceonMachine Learning, 2018](https://arxiv.org/abs/1803.09327)\n", + "[8] [E. Skomski, S. Vasisht, C. Wight, A. Tuor, J. Drgoňa and D. Vrabie, \"Constrained Block Nonlinear Neural Dynamical Models,\" 2021 American Control Conference (ACC), New Orleans, LA, USA, 2021, pp. 3993-4000](https://ieeexplore.ieee.org/document/9482930) \n", + "[9] [J. Drgoňa, A. Tuor, S. Vasisht and D. Vrabie, \"Dissipative Deep Neural Dynamical Systems,\" in IEEE Open Journal of Control Systems, 2022](https://ieeexplore.ieee.org/abstract/document/9809789) \n", + "[10] [Jiong Zhang, Qi Lei, Inderjit S. Dhillon, Stabilizing Gradients for Deep Neural Networks via Efficient SVD Parameterization, InternationalConferenceonMachine Learning, 2018](https://arxiv.org/abs/1803.09327)\n", "\n" ] }, @@ -321,7 +323,7 @@ "$$\\ell_{V} = || I - VV^T||_2 + || I - V^TV||_2 $$\n", "$$\\ell_{\\text{stable}} = \\ell_{U} + \\ell_{V} $$\n", "\n", - "For more details on the SVD and other linear algebra factorizations of trainable linear layers see the references [[6]](https://ieeexplore.ieee.org/document/9482930) and [[7]](https://ieeexplore.ieee.org/abstract/document/9809789), with Pytorch implementations in the [slim submodule](https://github.com/pnnl/neuromancer/tree/master/src/neuromancer/slim) of the Neuromancer library. " + "For more details on the SVD and other linear algebra factorizations of trainable linear layers see the references [[8]](https://ieeexplore.ieee.org/document/9482930) and [[9]](https://ieeexplore.ieee.org/abstract/document/9809789), with Pytorch implementations in the [slim submodule](https://github.com/pnnl/neuromancer/tree/master/src/neuromancer/slim) of the Neuromancer library. " ] }, { diff --git a/examples/ODEs/Part_7_DeepKoopman.py b/examples/ODEs/Part_7_DeepKoopman.py index 9d5fef70..4d0b7c41 100644 --- a/examples/ODEs/Part_7_DeepKoopman.py +++ b/examples/ODEs/Part_7_DeepKoopman.py @@ -6,9 +6,11 @@ [2] https://ieeexplore.ieee.org/document/8815339 [3] https://arxiv.org/abs/1710.04340 [4] https://nicholasgeneva.com/deep-learning/koopman/dynamics/2020/05/30/intro-to-koopman.html +[5] https://pubs.aip.org/aip/cha/article-abstract/22/4/047510/341880/Applied-Koopmanisma +[6] https://arxiv.org/abs/1312.0041 references stability: -[5] https://ieeexplore.ieee.org/document/9482930 +[7] https://ieeexplore.ieee.org/document/9482930 """ diff --git a/examples/ODEs/Part_8_nonauto_DeepKoopman.py b/examples/ODEs/Part_8_nonauto_DeepKoopman.py index 27c26ebe..a407f896 100644 --- a/examples/ODEs/Part_8_nonauto_DeepKoopman.py +++ b/examples/ODEs/Part_8_nonauto_DeepKoopman.py @@ -7,9 +7,11 @@ [3] https://ieeexplore.ieee.org/document/9799788 [4] https://ieeexplore.ieee.org/document/9022864 [5] https://github.com/HaojieSHI98/DeepKoopmanWithControl +[6] https://pubs.aip.org/aip/cha/article-abstract/22/4/047510/341880/Applied-Koopmanisma +[7] https://arxiv.org/abs/1312.0041 references stability: -[6] https://ieeexplore.ieee.org/document/9482930 +[8] https://ieeexplore.ieee.org/document/9482930 """ diff --git a/examples/control/Part_4_NODE_control.py b/examples/control/Part_4_NODE_control.py index 0875ceca..a012f842 100644 --- a/examples/control/Part_4_NODE_control.py +++ b/examples/control/Part_4_NODE_control.py @@ -136,6 +136,9 @@ def forward(self, x, R): features = torch.cat([x, R], dim=-1) return self.net(features) +# fix model parameters +system_nodel.requires_grad_(False) + insize = 2*nx policy = Policy(insize, nu) policy_node = Node(policy, ['xn', 'R'], ['U'], name='policy') diff --git a/examples/control/Part_6_DeepKoopman_DPC.py b/examples/control/Part_6_DeepKoopman_DPC.py index 9e24d1d2..d43e9f26 100644 --- a/examples/control/Part_6_DeepKoopman_DPC.py +++ b/examples/control/Part_6_DeepKoopman_DPC.py @@ -177,7 +177,7 @@ def forward(self, x, u): nonlin=torch.nn.ELU, hsizes=n_layers*[n_hidden]) # predicted trajectory decoder - decode_y = Node(f_y_inv, ['x'], ['yhat'], name='decoder_y') + decode_y = Node(f_y_inv, ['x'], ['y'], name='decoder_y') # instantiate SVD factorized Koopman operator with bounded eigenvalues K = slim.linear.SVDLinear(nx_koopman, nx_koopman, @@ -195,10 +195,11 @@ def forward(self, x, u): # latent Koopmann rollout dynamics_model = System([Koopman], name='Koopman', nsteps=nsteps) + dynamics_model.show() # variables Y = variable("Y") # observed - yhat = variable('yhat') # predicted output + yhat = variable('y') # predicted output x_latent = variable('x_latent') # encoded output trajectory in the latent space u_latent = variable('u_latent') # encoded input trajectory in the latent space x = variable('x') # Koopman latent space trajectory @@ -254,7 +255,7 @@ def forward(self, x, u): problem.nodes[3].nsteps = test_data['Y'].shape[1] test_outputs = problem.step(test_data) - pred_traj = test_outputs['yhat'][:, 1:-1, :].detach().numpy().reshape(-1, nx).T + pred_traj = test_outputs['y'][:, 1:-1, :].detach().numpy().reshape(-1, nx).T true_traj = test_data['Y'][:, 1:, :].detach().numpy().reshape(-1, nx).T input_traj = test_data['U'].detach().numpy().reshape(-1, nu).T @@ -302,7 +303,7 @@ def forward(self, x, u): ref = torch.cat(list_refs) batched_ref = ref.reshape([n_samples, nsteps+1, nref]) # Training dataset - train_data = DictDataset({'x': torch.rand(n_samples, 1, nx), + train_data = DictDataset({'y': torch.rand(n_samples, 1, ny), 'r': batched_ref}, name='train') # references for dev set @@ -310,7 +311,7 @@ def forward(self, x, u): ref = torch.cat(list_refs) batched_ref = ref.reshape([n_samples, nsteps+1, nref]) # Development dataset - dev_data = DictDataset({'x': torch.rand(n_samples, 1, nx), + dev_data = DictDataset({'y': torch.rand(n_samples, 1, ny), 'r': batched_ref}, name='dev') # torch dataloaders @@ -322,18 +323,25 @@ def forward(self, x, u): collate_fn=dev_data.collate_fn, shuffle=False) + """ + # # # Deep Koopman DPC architecture + """ + # state encoder + encode_y = Node(f_y, ['y'], ['x'], name='encoder_y') - - # initial condition encoder - encode_y = Node(f_y, ['y'], ['x'], name='encoder_Y0') + # fix parameters of the Koopman model + encode_y.requires_grad_(False) + encode_U.requires_grad_(False) + dynamics_model.requires_grad_(False) + decode_y.requires_grad_(False) # neural net control policy net = blocks.MLP(insize=nx + nref, outsize=nu, hsizes=4*[32], nonlin=torch.nn.ELU) - policy_node = Node(net, ['y', 'R'], ['U'], name='policy') + policy = Node(net, ['y', 'r'], ['U'], name='policy') - nodes = [encode_y, policy_node, encode_U, dynamics_model, decode_y] - cl_system = System(nodes, name='cl_system') + nodes = [encode_y, policy, encode_U, dynamics_model, decode_y] + cl_system = System(nodes, name='cl_system', nsteps=nsteps) cl_system.show() """ diff --git a/src/neuromancer/psl/nonautonomous.py b/src/neuromancer/psl/nonautonomous.py index b59fa9f4..8032bb15 100644 --- a/src/neuromancer/psl/nonautonomous.py +++ b/src/neuromancer/psl/nonautonomous.py @@ -255,6 +255,7 @@ def equations(self, t, x, u): * Caf = Feed Concentration (mol/m^3) States (2): * Concentration of A in CSTR (mol/m^3) + * Temperature in CSTR (K) """ Tc = u # Temperature of cooling jacket (K) Ca = x[0] # Concentration of A in CSTR (mol/m^3) diff --git a/src/neuromancer/system.py b/src/neuromancer/system.py index 13eed3c6..4ea053e9 100644 --- a/src/neuromancer/system.py +++ b/src/neuromancer/system.py @@ -150,25 +150,22 @@ def graph(self): graph.add_edge(pydot.Edge(src.name, dst.name, label=key)) unique_common_keys.add(key) - # get keys of recurrent nodes + # build I/O and node loop connections loop_keys = [] - for node in self.nodes: + init_keys = [] + previous_output_keys = [] + for idx_node, node in enumerate(self.nodes): node_loop_keys = set(node.input_keys) & set(node.output_keys) loop_keys += node_loop_keys - # get keys required as input and to initialize some nodes - init_keys = set(input_keys) - (set(output_keys) - set(loop_keys)) + init_keys += set(node.input_keys) - set(previous_output_keys) + previous_output_keys += node.output_keys - # build I/O and node loop connections - previous_output_keys = [] - for idx_node, node in enumerate(self.nodes): # build single node recurrent connections - node_loop_keys = list(set(node.input_keys) & set(node.output_keys)) for key in node_loop_keys: graph.add_edge(pydot.Edge(node.name, node.name, label=key)) # build connections to the dataset - for key in set(node.input_keys) & set(init_keys-set(previous_output_keys)): + for key in set(node.input_keys) & set(init_keys): graph.add_edge(pydot.Edge("in", node.name, label=key)) - previous_output_keys += node.output_keys # build feedback connections for init nodes feedback_src_nodes = reverse_order_nodes[:-1-idx_node] if len(set(node.input_keys) & set(loop_keys) & set(init_keys)) > 0: @@ -177,6 +174,7 @@ def graph(self): if key in src.output_keys and key not in previous_output_keys: graph.add_edge(pydot.Edge(src.name, node.name, label=key)) break + # build connections to the output of the system in a reversed order previous_output_keys = [] for node in self.nodes[::-1]: @@ -189,7 +187,7 @@ def graph(self): return graph def show(self, figname=None): - graph = self.system_graph + graph = self.graph() if figname is not None: plot_func = {'svg': graph.write_svg, 'png': graph.write_png,