
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Aug 1, 2023
1 parent f887760 commit b805c10
Showing 4 changed files with 144 additions and 1,535 deletions.
106 changes: 68 additions & 38 deletions notebooks/L96_offline_NN.ipynb
@@ -145,7 +145,9 @@
"val_size = 4000\n",
"\n",
"# Training Data\n",
"X_true_train = X_true[:-val_size, :] # Flatten because we first use single input as a sample\n",
"X_true_train = X_true[\n",
" :-val_size, :\n",
"] # Flatten because we first use single input as a sample\n",
"subgrid_tend_train = xy_true[:-val_size, :]\n",
"\n",
"# Test Data\n",
@@ -183,7 +185,7 @@
"outputs": [],
"source": [
"# Number of sample in each batch\n",
"BATCH_SIZE = 2000 \n",
"BATCH_SIZE = 2000\n",
"# this batch size would result in 2 test batches and 8 training batches."
]
},
@@ -262,7 +264,7 @@
"plt.figure(dpi=150)\n",
"plt.plot(X_iter, subgrid_tend_iter, \".\")\n",
"plt.xlabel(\"State - X\", fontsize=20)\n",
"plt.ylabel(\"Subgrid tendency - U\", fontsize=20);\n",
"plt.ylabel(\"Subgrid tendency - U\", fontsize=20)\n",
"plt.xlim([-12, 16])"
]
},
@@ -419,7 +421,9 @@
"net_input = torch.randn(1, 1)\n",
"out_linear = linear_network(net_input)\n",
"out_fcnn = fcnn_network(net_input)\n",
"print(f\"The output of the random input from the linear network is: {out_linear.item():.4f}\")\n",
"print(\n",
" f\"The output of the random input from the linear network is: {out_linear.item():.4f}\"\n",
")\n",
"print(f\"The output of the random input from the fcnn is: {out_fcnn.item():.4f}\")"
]
},
@@ -458,7 +462,7 @@
"\n",
"# Predict the output\n",
"y_tmp_linear = linear_network(torch.unsqueeze(X_tmp[0], 1))\n",
"y_tmp_fcnn = fcnn_network(torch.unsqueeze(X_tmp[0], 1))\n",
"y_tmp_fcnn = fcnn_network(torch.unsqueeze(X_tmp[0], 1))\n",
"\n",
"# Calculate the MSE loss\n",
"loss_linear = loss_fn(y_tmp_linear, torch.unsqueeze(X_tmp[1], 1))\n",
@@ -532,7 +536,7 @@
"outputs": [],
"source": [
"# switch the commenting out below to try a different optimizer.\n",
"#optimizer_linear = optim.SGD(linear_network.parameters(), lr=learning_rate, momentum=momentum)\n",
"# optimizer_linear = optim.SGD(linear_network.parameters(), lr=learning_rate, momentum=momentum)\n",
"optimizer_linear = optim.Adam(linear_network.parameters(), lr=learning_rate)\n",
"print(\"Before backward pass: \\n\", list(linear_network.parameters())[0].data.numpy())\n",
"\n",
@@ -549,7 +553,7 @@
"metadata": {},
"outputs": [],
"source": [
"#optimizer_fcnn = optim.SGD(fcnn_network.parameters(), lr=learning_rate, momentum=momentum)\n",
"# optimizer_fcnn = optim.SGD(fcnn_network.parameters(), lr=learning_rate, momentum=momentum)\n",
"optimizer_fcnn = optim.Adam(fcnn_network.parameters(), lr=learning_rate)"
]
},
@@ -695,7 +699,7 @@
"metadata": {},
"outputs": [],
"source": [
"#Epochs refer to the number of times we iterate over the entire training data during training.\n",
"# Epochs refer to the number of times we iterate over the entire training data during training.\n",
"n_epochs = 15"
]
},
@@ -732,16 +736,16 @@
"source": [
"plt.figure(dpi=150)\n",
"\n",
"plt.plot(train_loss_linear, label='Linear train loss')\n",
"plt.plot(test_loss_linear, linestyle='--', label='Linear test loss')\n",
"plt.plot(train_loss_linear, label=\"Linear train loss\")\n",
"plt.plot(test_loss_linear, linestyle=\"--\", label=\"Linear test loss\")\n",
"\n",
"plt.plot(train_loss_fcnn, label='FCNN train loss')\n",
"plt.plot(test_loss_fcnn, linestyle='--', label='FCNN test loss')\n",
"plt.plot(train_loss_fcnn, label=\"FCNN train loss\")\n",
"plt.plot(test_loss_fcnn, linestyle=\"--\", label=\"FCNN test loss\")\n",
"\n",
"plt.legend()\n",
"plt.xlabel(\"Iteration\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.yscale('log')\n",
"plt.yscale(\"log\")\n",
"plt.title(\"Loss vs Iteration\")\n",
"plt.show();"
]
@@ -770,9 +774,11 @@
"\n",
"\n",
"plt.figure(dpi=150)\n",
"plt.plot(predictions_linear.detach().numpy()[0:1000], label=\"Predicted from linear model\")\n",
"plt.plot(\n",
" predictions_linear.detach().numpy()[0:1000], label=\"Predicted from linear model\"\n",
")\n",
"plt.plot(predictions_fcnn.detach().numpy()[0:1000], label=\"Predicted from FCNN model\")\n",
"plt.plot(subgrid_tend_test[:1000, 1], label=\"True Values\", color='k', linestyle='--')\n",
"plt.plot(subgrid_tend_test[:1000, 1], label=\"True Values\", color=\"k\", linestyle=\"--\")\n",
"plt.legend(fontsize=7);"
]
},
@@ -791,10 +797,10 @@
"\n",
"\n",
"plt.figure(dpi=150)\n",
"plt.hist2d(np.reshape(X_true, -1), np.reshape(xy_true, -1), bins=91, cmap='Reds')\n",
"plt.hist2d(np.reshape(X_true, -1), np.reshape(xy_true, -1), bins=91, cmap=\"Reds\")\n",
"\n",
"plt.plot(X_points, linear_pred, \"-\", label='Linear predictions')\n",
"plt.plot(X_points, fcnn_pred, \"-\", label='FCNN predictions', color='g')\n",
"plt.plot(X_points, linear_pred, \"-\", label=\"Linear predictions\")\n",
"plt.plot(X_points, fcnn_pred, \"-\", label=\"FCNN predictions\", color=\"g\")\n",
"\n",
"plt.legend()\n",
"plt.xlim([-12, 16])\n",
@@ -837,22 +843,32 @@
"plt.figure(figsize=(12, 4), dpi=150)\n",
"\n",
"plt.subplot(131)\n",
"plt.hist2d(np.reshape(np.roll(X_true, -1, axis=1), -1), np.reshape(xy_true, -1), bins=91, cmap='Reds');\n",
"plt.hist2d(\n",
" np.reshape(np.roll(X_true, -1, axis=1), -1),\n",
" np.reshape(xy_true, -1),\n",
" bins=91,\n",
" cmap=\"Reds\",\n",
")\n",
"plt.xlim([-12, 16])\n",
"plt.xlabel(\"State - $X_{k-1}$\", fontsize=20)\n",
"plt.ylabel(\"Subgrid tendency - $U_{k}$\", fontsize=20);\n",
"plt.ylabel(\"Subgrid tendency - $U_{k}$\", fontsize=20)\n",
"\n",
"plt.subplot(132)\n",
"plt.hist2d(np.reshape(X_true, -1), np.reshape(xy_true, -1), bins=91, cmap='Reds');\n",
"plt.hist2d(np.reshape(X_true, -1), np.reshape(xy_true, -1), bins=91, cmap=\"Reds\")\n",
"plt.xlim([-12, 16])\n",
"plt.xlabel(\"State - $X_{k}$\", fontsize=20)\n",
"plt.ylabel(\"Subgrid tendency - $U_{k}$\", fontsize=20);\n",
"plt.ylabel(\"Subgrid tendency - $U_{k}$\", fontsize=20)\n",
"\n",
"plt.subplot(133)\n",
"plt.hist2d(np.reshape(np.roll(X_true, 1, axis=1), -1), np.reshape(xy_true, -1), bins=91, cmap='Reds');\n",
"plt.hist2d(\n",
" np.reshape(np.roll(X_true, 1, axis=1), -1),\n",
" np.reshape(xy_true, -1),\n",
" bins=91,\n",
" cmap=\"Reds\",\n",
")\n",
"plt.xlim([-12, 16])\n",
"plt.xlabel(\"State - $X_{k+1}$\", fontsize=20)\n",
"plt.ylabel(\"Subgrid tendency - $U_{k}$\", fontsize=20);\n",
"plt.ylabel(\"Subgrid tendency - $U_{k}$\", fontsize=20)\n",
"\n",
"plt.tight_layout()"
]
@@ -897,7 +913,8 @@
" torch.from_numpy(X_true_test), torch.from_numpy(subgrid_tend_test)\n",
")\n",
"nlocal_loader_test = Data.DataLoader(\n",
" dataset=nlocal_data_test, batch_size=BATCH_SIZE, shuffle=True)"
" dataset=nlocal_data_test, batch_size=BATCH_SIZE, shuffle=True\n",
")"
]
},
{
@@ -940,9 +957,11 @@
"metadata": {},
"outputs": [],
"source": [
"#optimizer_nonlocal_fcnn = optim.SGD(nonlocal_fcnn_network.parameters(),\n",
"# optimizer_nonlocal_fcnn = optim.SGD(nonlocal_fcnn_network.parameters(),\n",
"# lr=learning_rate, momentum=momentum)\n",
"optimizer_nonlocal_fcnn = optim.Adam(nonlocal_fcnn_network.parameters(), lr=learning_rate)"
"optimizer_nonlocal_fcnn = optim.Adam(\n",
" nonlocal_fcnn_network.parameters(), lr=learning_rate\n",
")"
]
},
{
@@ -954,8 +973,12 @@
"source": [
"n_epochs = 120\n",
"train_loss_nonlocal, test_loss_nonlocal = fit_model(\n",
" nonlocal_fcnn_network, loss_fn, optimizer_nonlocal_fcnn, \n",
" nlocal_loader_train, nlocal_loader_test, n_epochs\n",
" nonlocal_fcnn_network,\n",
" loss_fn,\n",
" optimizer_nonlocal_fcnn,\n",
" nlocal_loader_train,\n",
" nlocal_loader_test,\n",
" n_epochs,\n",
")"
]
},
@@ -968,16 +991,16 @@
"source": [
"plt.figure(dpi=150)\n",
"\n",
"plt.plot(train_loss_nonlocal, label='Non-local model train loss')\n",
"plt.plot(test_loss_nonlocal, linestyle='--', label='Non-local model test loss')\n",
"plt.plot(train_loss_nonlocal, label=\"Non-local model train loss\")\n",
"plt.plot(test_loss_nonlocal, linestyle=\"--\", label=\"Non-local model test loss\")\n",
"\n",
"plt.plot(train_loss_fcnn, label='local FCNN train loss')\n",
"plt.plot(test_loss_fcnn, linestyle='--', label='local FCNN test loss')\n",
"plt.plot(train_loss_fcnn, label=\"local FCNN train loss\")\n",
"plt.plot(test_loss_fcnn, linestyle=\"--\", label=\"local FCNN test loss\")\n",
"\n",
"plt.legend()\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.yscale('log')\n",
"plt.yscale(\"log\")\n",
"plt.title(\"Loss vs Epochs\")\n",
"plt.grid()\n",
"plt.show();"
@@ -1002,10 +1025,17 @@
"\n",
"\n",
"plt.figure(dpi=150)\n",
"plt.plot(predictions_linear.detach().numpy()[0:1000], label=\"Predicted from linear model\")\n",
"plt.plot(predictions_fcnn.detach().numpy()[0:1000], label=\"Predicted from local FCNN model\")\n",
"plt.plot(predictions_nonlocal_fcnn.detach().numpy()[0:1000,k_loc], label=\"Predicted from non-local FCNN model\")\n",
"plt.plot(subgrid_tend_test[:1000, 1], label=\"True Values\", color='k', linestyle='--')\n",
"plt.plot(\n",
" predictions_linear.detach().numpy()[0:1000], label=\"Predicted from linear model\"\n",
")\n",
"plt.plot(\n",
" predictions_fcnn.detach().numpy()[0:1000], label=\"Predicted from local FCNN model\"\n",
")\n",
"plt.plot(\n",
" predictions_nonlocal_fcnn.detach().numpy()[0:1000, k_loc],\n",
" label=\"Predicted from non-local FCNN model\",\n",
")\n",
"plt.plot(subgrid_tend_test[:1000, 1], label=\"True Values\", color=\"k\", linestyle=\"--\")\n",
"plt.legend(fontsize=7);"
]
},
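
Note: both changed notebooks train their models through a fit_model helper whose definition is not part of this commit; only its call sites are reformatted above. As a point of reference, here is a minimal sketch consistent with the visible call signature fit_model(network, loss_fn, optimizer, train_loader, test_loader, n_epochs). Every internal detail below is an assumption, not code from the repository.

import torch


def fit_model(network, loss_fn, optimizer, train_loader, test_loader, n_epochs):
    # Hypothetical sketch: only the signature appears in the diff above;
    # the loop structure and loss bookkeeping are assumptions.
    train_losses, test_losses = [], []
    for _ in range(n_epochs):
        network.train()
        running = 0.0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(network(inputs), targets)
            loss.backward()
            optimizer.step()
            running += loss.item()
        train_losses.append(running / len(train_loader))

        network.eval()
        running = 0.0
        with torch.no_grad():
            for inputs, targets in test_loader:
                running += loss_fn(network(inputs), targets).item()
        test_losses.append(running / len(test_loader))
    return train_losses, test_losses
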
46 changes: 28 additions & 18 deletions notebooks/L96_online_implement_NN.ipynb
@@ -84,6 +84,7 @@
"# The model architectures\n",
"# ---------------------------\n",
"\n",
"\n",
"class LinearRegression(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
@@ -94,7 +95,8 @@
" # we call a object of this class\n",
" x = self.linear1(x)\n",
" return x\n",
" \n",
"\n",
"\n",
"class FCNN(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
@@ -108,7 +110,8 @@
" x = self.relu(self.linear1(x))\n",
" x = self.relu(self.linear2(x))\n",
" x = self.linear3(x)\n",
" return x \n",
" return x\n",
"\n",
"\n",
"class NonLocal_FCNN(nn.Module):\n",
" def __init__(self):\n",
@@ -192,8 +195,8 @@
"forcing = 18\n",
"dt = 0.01\n",
"\n",
"k=8\n",
"j=32 \n",
"k = 8\n",
"j = 32\n",
"\n",
"W = L96(k, j, F=forcing)\n",
"\n",
@@ -361,7 +364,9 @@
"\n",
"# Evaluate with nonlocal FCNN\n",
"gcm_nonlocal_net = GCM_network(forcing, nonlocal_fcnn_network)\n",
"Xnn_nonlocal, t = gcm_nonlocal_net(init_conditions, dt, int(T_test / dt), nonlocal_fcnn_network)"
"Xnn_nonlocal, t = gcm_nonlocal_net(\n",
" init_conditions, dt, int(T_test / dt), nonlocal_fcnn_network\n",
")"
]
},
{
@@ -384,11 +389,11 @@
"time_i = 200\n",
"plt.figure(dpi=150)\n",
"plt.plot(t[:time_i], X_full[:time_i, 4], label=\"Full L96\")\n",
"plt.plot(t[:time_i], X_no_param[:time_i, 4], '--', label=\"No parameterization\")\n",
"plt.plot(t[:time_i], X_no_param[:time_i, 4], \"--\", label=\"No parameterization\")\n",
"\n",
"plt.plot(t[:time_i], Xnn_linear[:time_i, 4], label=\"linear parameterization\")\n",
"\n",
"plt.plot(t[:time_i], Xnn_local[:time_i, 4], label=\"local NN\")\n",
"plt.plot(t[:time_i], Xnn_local[:time_i, 4], label=\"local NN\")\n",
"plt.plot(t[:time_i], Xnn_nonlocal[:time_i, 4], label=\"nonlocal NN\")\n",
"plt.legend(loc=\"upper left\", fontsize=7);"
]
@@ -417,37 +422,42 @@
"\n",
" # Evaluate with linear network\n",
" gcm_linear_net = GCM_network(forcing, linear_network)\n",
" Xnn_linear, t = gcm_linear_net(init_conditions_i, dt, int(T_test / dt), linear_network)\n",
" Xnn_linear, t = gcm_linear_net(\n",
" init_conditions_i, dt, int(T_test / dt), linear_network\n",
" )\n",
"\n",
" # Evaluate with local FCNN\n",
" gcm_local_net = GCM_network(forcing, local_fcnn_network)\n",
" Xnn_local, t = gcm_local_net(init_conditions_i, dt, int(T_test / dt), local_fcnn_network)\n",
" Xnn_local, t = gcm_local_net(\n",
" init_conditions_i, dt, int(T_test / dt), local_fcnn_network\n",
" )\n",
"\n",
" # Evaluate with nonlocal FCNN\n",
" gcm_nonlocal_net = GCM_network(forcing, nonlocal_fcnn_network)\n",
" Xnn_nonlocal, t = gcm_nonlocal_net(init_conditions_i, dt, int(T_test / dt), nonlocal_fcnn_network)\n",
" Xnn_nonlocal, t = gcm_nonlocal_net(\n",
" init_conditions_i, dt, int(T_test / dt), nonlocal_fcnn_network\n",
" )\n",
"\n",
" # GCM parameterized by the global 3-layer network\n",
" #gcm_net_3layers = GCM_network(forcing, nn_3l)\n",
" #Xnn_3layer_i, t = gcm_net_3layers(init_conditions_i, dt, int(T_test / dt), nn_3l)\n",
" # gcm_net_3layers = GCM_network(forcing, nn_3l)\n",
" # Xnn_3layer_i, t = gcm_net_3layers(init_conditions_i, dt, int(T_test / dt), nn_3l)\n",
"\n",
" # GCM parameterized by the linear network\n",
" #gcm_net_1layers = GCM_network(forcing, linear_network)\n",
" #Xnn_1layer_i, t = gcm_net_1layers(init_conditions_i, dt, int(T_test / dt), linear_network)\n",
" # gcm_net_1layers = GCM_network(forcing, linear_network)\n",
" # Xnn_1layer_i, t = gcm_net_1layers(init_conditions_i, dt, int(T_test / dt), linear_network)\n",
"\n",
" err_linear.append(\n",
" np.sum(np.abs(X_full[i * 10 : i * 10 + T_test * 100 + 1] - Xnn_linear))\n",
" )\n",
" \n",
"\n",
" err_local.append(\n",
" np.sum(np.abs(X_full[i * 10 : i * 10 + T_test * 100 + 1] - Xnn_local))\n",
" )\n",
" \n",
"\n",
" err_nonlocal.append(\n",
" np.sum(np.abs(X_full[i * 10 : i * 10 + T_test * 100 + 1] - Xnn_nonlocal))\n",
" )\n",
" \n",
" \n",
"\n",
"\n",
"print(f\"Sum of errors for linear: {sum(err_linear):.2f}\")\n",
"print(f\"Sum of errors for local neural network: {sum(err_local):.2f}\")\n",
[Diffs for the remaining two changed files were not loaded.]
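
For context, the online notebook evaluates each trained network inside a GCM_network wrapper (defined outside this diff) via calls of the form Xnn, t = gcm(init_conditions, dt, int(T_test / dt), network). Below is a hedged sketch of such a wrapper for the one-level L96 model with a learned subgrid tendency, shaped here for the single-input local network; the class internals, the forward-Euler step, and the tensor handling are all assumptions rather than repository code.

import numpy as np
import torch


class GCM_network:
    # Hypothetical sketch: only the constructor and call signatures are
    # taken from the diff; the dynamics and integration scheme are assumed.
    def __init__(self, F, network):
        self.F = F
        self.network = network

    def rhs(self, X):
        # One-level L96 tendency, dX_k/dt = (X_{k+1} - X_{k-2}) * X_{k-1} - X_k + F,
        # plus the learned subgrid term U_k from the network.
        dXdt = (np.roll(X, -1) - np.roll(X, 2)) * np.roll(X, 1) - X + self.F
        with torch.no_grad():
            U = self.network(torch.from_numpy(X[:, None]).float()).numpy().ravel()
        return dXdt + U

    def __call__(self, X0, dt, nt, network=None):
        # The notebooks pass the network again at call time; honor that here.
        if network is not None:
            self.network = network
        X = np.zeros((nt + 1, X0.size))
        X[0] = X0
        t = np.arange(nt + 1) * dt
        for n in range(nt):
            X[n + 1] = X[n] + dt * self.rhs(X[n])  # forward-Euler step
        return X, t
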
