Feature/mfkans #201

Open
wants to merge 8 commits into base: develop
12 changes: 8 additions & 4 deletions examples/KANs/README.md
@@ -1,19 +1,23 @@
# Kolmogorov-Arnold Networks in Neuromancer

This directory contains interactive examples that can serve as a step-by-step tutorial
showcasing the capabilities of Kolmogorov-Arnold Networks (KANs) and finite basis KANs (FBKANs) in Neuromancer.
This directory contains interactive examples that can serve as a step-by-step tutorial showcasing the capabilities of Kolmogorov-Arnold Networks (KANs), finite basis KANs (FBKANs), and multi-fidelity KANs (MFKANs) in Neuromancer.

Examples of learning from multiscale, noisy data with KANs and FBKANs:
+ <a target="_blank" href="https://colab.research.google.com/github/pnnl/neuromancer/blob/feature/fbkans/examples/KANs/p1_fbkan_vs_kan_noise_data_1d.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> Part 1: A comparison of KANs and FBKANs in learning a 1D multiscale function with noise
+ <a target="_blank" href="https://colab.research.google.com/github/pnnl/neuromancer/blob/feature/fbkans/examples/KANs/p2_fbkan_vs_kan_noise_data_2d.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> Part 2: A comparison of KANs and FBKANs in learning a 2D multiscale function with noise

Examples of learning from multi-fidelity data with MFKANs:
+ <a target="_blank" href="https://colab.research.google.com/github/pnnl/neuromancer/blob/feature/mfkans/examples/KANs/p3_mfkan_example_1d.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> Part 3: A comparison of KANs and MFKANs in learning a 1D jump function with abundant, low-fidelity data and sparse, high-fidelity data

## Kolmogorov-Arnold Networks (KANs)
Based on the Kolmogorov-Arnold representation theorem, KANs offer an alternative architecture: where traditional neural networks use fixed activation functions, KANs employ learnable activation functions on the edges of the network, replacing linear weight parameters with parametrized spline functions. This fundamental shift can enhance model interpretability and improve computational efficiency and accuracy [1]. KANs are available in Neuromancer via `blocks.KANBlock`, which leverages the efficient-kan implementation of [2]. Moreover, users can leverage finite basis KANs (FBKANs), a domain decomposition method for KANs proposed by Howard et al. (2024) [3], by simply setting the `num_domains` argument of `blocks.KANBlock`.
Based on the Kolmogorov-Arnold representation theorem, KANs offer an alternative architecture: where traditional neural networks use fixed activation functions, KANs employ learnable activation functions on the edges of the network, replacing linear weight parameters with parametrized spline functions. This fundamental shift can enhance model interpretability and improve computational efficiency and accuracy [1]. KANs are available in Neuromancer via `blocks.KANBlock`, which leverages the efficient-kan implementation of [2]. Moreover, users can leverage finite basis KANs (FBKANs), a domain decomposition method for KANs proposed by Howard et al. (2024) [3], by simply setting the `num_domains` argument of `blocks.KANBlock`. Users can also leverage multi-fidelity KANs (MFKANs) [4] via `blocks.MultiFidelityKAN`.
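
For orientation, here is a minimal usage sketch. It assumes `insize`/`outsize` arguments in line with other Neuromancer blocks; the exact signatures of `blocks.KANBlock` and `blocks.MultiFidelityKAN` are assumptions and should be checked against the notebooks above.

```python
# Minimal sketch, assuming insize/outsize arguments as in other Neuromancer
# blocks; see the notebooks above for the authoritative signatures.
import torch
from neuromancer.modules import blocks

x = torch.rand(64, 1)  # a batch of 64 one-dimensional inputs

# Plain KAN: learnable spline activations on the network edges [1].
kan = blocks.KANBlock(insize=1, outsize=1)

# FBKAN: num_domains > 1 turns on the finite-basis domain decomposition [3].
fbkan = blocks.KANBlock(insize=1, outsize=1, num_domains=4)

# MFKAN: multi-fidelity KAN [4].
mfkan = blocks.MultiFidelityKAN(insize=1, outsize=1)

z_hat = kan(x)  # each block maps (batch, insize) -> (batch, outsize)
```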

### References

[1] [Liu, Ziming, et al. (2024). KAN: Kolmogorov-Arnold Networks.](https://arxiv.org/abs/2404.19756)

[2] https://github.com/Blealtan/efficient-kan

[3] Howard, Amanda A., et al. (2024). Finite basis Kolmogorov-Arnold networks: domain decomposition for data-driven and physics-informed problems.
[3] [Howard, Amanda A., et al. (2024). Finite basis Kolmogorov-Arnold networks: domain decomposition for data-driven and physics-informed problems.](https://arxiv.org/abs/2406.19662)

[4] [Howard, Amanda A., et al. (2024) Multifidelity Kolmogorov-Arnold networks.](https://arxiv.org/abs/2410.14764)
200 changes: 103 additions & 97 deletions examples/KANs/p1_fbkan_vs_kan_noise_data_1d.ipynb

Large diffs are not rendered by default.

98 changes: 51 additions & 47 deletions examples/KANs/p2_fbkan_vs_kan_noise_data_2d.ipynb
@@ -89,15 +89,17 @@
"source": [
"import torch\n",
"import numpy as np\n",
"import torch.nn as nn\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from neuromancer.dataset import DictDataset\n",
"from neuromancer.modules import blocks\n",
"from neuromancer.system import Node\n",
"from neuromancer.system import Node, System\n",
"from neuromancer.constraint import variable\n",
"from neuromancer.loss import PenaltyLoss\n",
"from neuromancer.problem import Problem\n",
"from neuromancer.trainer import Trainer\n",
"from neuromancer.callbacks import Callback\n",
"from neuromancer.loggers import LossLogger\n"
]
},
@@ -344,14 +346,14 @@
"\n",
"- `x`, `y`: Input variables, where $x, y \\in [0, 1]^2$.\n",
"- `z`: True target values from the function $f(x, y)$.\n",
"- `z_hat`: Predicted values produced by either the KAN or FBKAN model.\n",
"- `z_hat`: Predicted values produced by either the KAN or FBKAN model, $\\hat{z}$.\n",
"\n",
"**Data Loss for FBKAN:**\n",
"\n",
"The data loss for FBKAN, denoted as `loss_data_fbkan`, measures the mean squared error (MSE) between the FBKAN predictions, `z_hat`, and the true values, `z`:\n",
"\n",
"$$\n",
"\\ell_{\\text{data, FBKAN}} = \\text{scaling} \\cdot \\frac{1}{N_{\\text{data}}} \\sum_{i=1}^{N_{\\text{data}}} \\left( z_i - z_{\\text{hat, FBKAN}} \\right)^2\n",
"\\ell_{\\text{data, FBKAN}} = \\text{scaling} \\cdot \\frac{1}{N_{\\text{data}}} \\sum_{i=1}^{N_{\\text{data}}} \\left( z_i - \\hat{z}_i \\right)^2\n",
"$$\n",
"\n",
"This loss guides the FBKAN model to approximate the target function values accurately.\n",
@@ -361,7 +363,7 @@
"Similarly, the data loss for KAN, denoted as `loss_data_kan`, is the mean squared error between the KAN predictions, `z_hat`, and the true target values, `z`:\n",
"\n",
"$$\n",
"\\ell_{\\text{data, KAN}} = \\text{scaling} \\cdot \\frac{1}{N_{\\text{data}}} \\sum_{i=1}^{N_{\\text{data}}} \\left( z_i - z_{\\text{hat, KAN}} \\right)^2\n",
"\\ell_{\\text{data, KAN}} = \\text{scaling} \\cdot \\frac{1}{N_{\\text{data}}} \\sum_{i=1}^{N_{\\text{data}}} \\left( z_i - \\hat{z}_i \\right)^2\n",
"$$\n",
"\n",
"This loss term helps the KAN model learn to approximate the target function.\n",
@@ -409,7 +411,7 @@
"id": "9af5ad88-a719-4da0-b75c-b5fe2ac6b41b",
"metadata": {},
"source": [
"### Construct the Neuromancer Problem objects and train"
"### Construct the Neuromancer `Problem` objects and train"
]
},
{
@@ -441,6 +443,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"None\n",
"None\n",
"Number of parameters: 600\n",
"Number of parameters: 150\n"
]
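
The construction cells themselves are collapsed in this diff. Below is a self-contained, hedged sketch of the `Node`/`Problem`/`Trainer` wiring implied by the imports and the parameter-count output above; the data, hyperparameters, and the `KANBlock` signature are illustrative assumptions, not the notebook's actual code.

```python
# Hedged sketch of the Problem/Trainer wiring implied by the heading and
# imports above; data, hyperparameters, and the KANBlock signature are
# illustrative assumptions.
import torch
from torch.utils.data import DataLoader

from neuromancer.dataset import DictDataset
from neuromancer.modules import blocks
from neuromancer.system import Node
from neuromancer.constraint import variable
from neuromancer.loss import PenaltyLoss
from neuromancer.problem import Problem
from neuromancer.trainer import Trainer

# Stand-in 2D data (the notebook uses a noisy multiscale f(x, y)).
xy = torch.rand(512, 2)
z = torch.sin(8 * xy[:, :1]) * torch.cos(8 * xy[:, 1:])
train_data = DictDataset({"xy": xy, "z": z}, name="train")
dev_data = DictDataset({"xy": xy, "z": z}, name="dev")
train_loader = DataLoader(train_data, batch_size=128,
                          collate_fn=train_data.collate_fn, shuffle=True)
dev_loader = DataLoader(dev_data, batch_size=128,
                        collate_fn=dev_data.collate_fn)

# FBKAN surrogate wrapped as a computational node: xy -> z_hat.
fbkan = blocks.KANBlock(insize=2, outsize=1, num_domains=4)  # assumed signature
fbkan_node = Node(fbkan, ["xy"], ["z_hat"], name="fbkan")

# MSE data loss, as defined earlier in the notebook.
z_var, z_hat = variable("z"), variable("z_hat")
obj = (z_hat == z_var) ^ 2
obj.name = "loss_data"
loss = PenaltyLoss(objectives=[obj], constraints=[])

problem = Problem(nodes=[fbkan_node], loss=loss)
print("Number of parameters:",
      sum(p.numel() for p in problem.parameters() if p.requires_grad))

optimizer = torch.optim.Adam(problem.parameters(), lr=1e-3)
trainer = Trainer(problem, train_loader, dev_loader,
                  optimizer=optimizer, epochs=1000)
best_model = trainer.train()
```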
@@ -502,27 +506,27 @@
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0\ttrain_loss: 0.19858\tdev_loss: 0.18781\teltime: 0.06342\n",
"epoch: 50\ttrain_loss: 0.12017\tdev_loss: 0.11974\teltime: 1.78908\n",
"epoch: 100\ttrain_loss: 0.03547\tdev_loss: 0.04448\teltime: 3.25373\n",
"epoch: 150\ttrain_loss: 0.01443\tdev_loss: 0.01751\teltime: 4.10649\n",
"epoch: 200\ttrain_loss: 0.01119\tdev_loss: 0.01346\teltime: 5.55846\n",
"epoch: 250\ttrain_loss: 0.00829\tdev_loss: 0.00942\teltime: 6.65685\n",
"epoch: 300\ttrain_loss: 0.00721\tdev_loss: 0.00813\teltime: 8.02346\n",
"epoch: 350\ttrain_loss: 0.00956\tdev_loss: 0.00810\teltime: 9.76293\n",
"epoch: 400\ttrain_loss: 0.00555\tdev_loss: 0.00580\teltime: 11.26922\n",
"epoch: 450\ttrain_loss: 0.00501\tdev_loss: 0.00493\teltime: 12.25867\n",
"epoch: 500\ttrain_loss: 0.00445\tdev_loss: 0.00428\teltime: 13.93477\n",
"epoch: 550\ttrain_loss: 0.00413\tdev_loss: 0.00444\teltime: 15.11188\n",
"epoch: 600\ttrain_loss: 0.00534\tdev_loss: 0.00503\teltime: 16.35734\n",
"epoch: 650\ttrain_loss: 0.00339\tdev_loss: 0.00344\teltime: 18.05322\n",
"epoch: 700\ttrain_loss: 0.00299\tdev_loss: 0.00280\teltime: 19.55740\n",
"epoch: 750\ttrain_loss: 0.00307\tdev_loss: 0.00257\teltime: 20.53209\n",
"epoch: 800\ttrain_loss: 0.00279\tdev_loss: 0.00238\teltime: 22.24692\n",
"epoch: 850\ttrain_loss: 0.00280\tdev_loss: 0.00239\teltime: 23.44692\n",
"epoch: 900\ttrain_loss: 0.00482\tdev_loss: 0.00311\teltime: 24.60944\n",
"epoch: 950\ttrain_loss: 0.00264\tdev_loss: 0.00213\teltime: 26.23807\n",
"epoch: 1000\ttrain_loss: 0.00259\tdev_loss: 0.00205\teltime: 27.82801\n"
"epoch: 0\ttrain_loss: 0.19858\tdev_loss: 0.18781\teltime: 0.04408\n",
"epoch: 50\ttrain_loss: 0.12017\tdev_loss: 0.11974\teltime: 0.98760\n",
"epoch: 100\ttrain_loss: 0.03547\tdev_loss: 0.04448\teltime: 1.88288\n",
"epoch: 150\ttrain_loss: 0.01443\tdev_loss: 0.01751\teltime: 2.69070\n",
"epoch: 200\ttrain_loss: 0.01119\tdev_loss: 0.01346\teltime: 3.48212\n",
"epoch: 250\ttrain_loss: 0.00829\tdev_loss: 0.00942\teltime: 4.28909\n",
"epoch: 300\ttrain_loss: 0.00721\tdev_loss: 0.00813\teltime: 5.13658\n",
"epoch: 350\ttrain_loss: 0.00956\tdev_loss: 0.00810\teltime: 5.89438\n",
"epoch: 400\ttrain_loss: 0.00555\tdev_loss: 0.00580\teltime: 6.76285\n",
"epoch: 450\ttrain_loss: 0.00501\tdev_loss: 0.00493\teltime: 7.56369\n",
"epoch: 500\ttrain_loss: 0.00445\tdev_loss: 0.00428\teltime: 8.55607\n",
"epoch: 550\ttrain_loss: 0.00413\tdev_loss: 0.00444\teltime: 9.78196\n",
"epoch: 600\ttrain_loss: 0.00534\tdev_loss: 0.00503\teltime: 10.75619\n",
"epoch: 650\ttrain_loss: 0.00339\tdev_loss: 0.00344\teltime: 11.84819\n",
"epoch: 700\ttrain_loss: 0.00299\tdev_loss: 0.00280\teltime: 12.94066\n",
"epoch: 750\ttrain_loss: 0.00307\tdev_loss: 0.00257\teltime: 14.01460\n",
"epoch: 800\ttrain_loss: 0.00279\tdev_loss: 0.00238\teltime: 15.07571\n",
"epoch: 850\ttrain_loss: 0.00280\tdev_loss: 0.00239\teltime: 16.25827\n",
"epoch: 900\ttrain_loss: 0.00482\tdev_loss: 0.00311\teltime: 17.41800\n",
"epoch: 950\ttrain_loss: 0.00264\tdev_loss: 0.00213\teltime: 18.54224\n",
"epoch: 1000\ttrain_loss: 0.00259\tdev_loss: 0.00205\teltime: 19.57284\n"
]
}
],
@@ -546,27 +550,27 @@
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0\ttrain_loss: 0.20121\tdev_loss: 0.19113\teltime: 27.86444\n",
"epoch: 50\ttrain_loss: 0.16842\tdev_loss: 0.16688\teltime: 28.25043\n",
"epoch: 100\ttrain_loss: 0.14238\tdev_loss: 0.14534\teltime: 28.72706\n",
"epoch: 150\ttrain_loss: 0.11018\tdev_loss: 0.11803\teltime: 29.45312\n",
"epoch: 200\ttrain_loss: 0.07499\tdev_loss: 0.08277\teltime: 30.31180\n",
"epoch: 250\ttrain_loss: 0.06059\tdev_loss: 0.06680\teltime: 31.02909\n",
"epoch: 300\ttrain_loss: 0.05511\tdev_loss: 0.06288\teltime: 31.39583\n",
"epoch: 350\ttrain_loss: 0.05344\tdev_loss: 0.06120\teltime: 31.81557\n",
"epoch: 400\ttrain_loss: 0.05226\tdev_loss: 0.05941\teltime: 32.31351\n",
"epoch: 450\ttrain_loss: 0.05113\tdev_loss: 0.05787\teltime: 32.85939\n",
"epoch: 500\ttrain_loss: 0.04937\tdev_loss: 0.05685\teltime: 33.51572\n",
"epoch: 550\ttrain_loss: 0.04648\tdev_loss: 0.05583\teltime: 34.15521\n",
"epoch: 600\ttrain_loss: 0.04318\tdev_loss: 0.04911\teltime: 34.89664\n",
"epoch: 650\ttrain_loss: 0.04042\tdev_loss: 0.04821\teltime: 35.47894\n",
"epoch: 700\ttrain_loss: 0.03868\tdev_loss: 0.04386\teltime: 36.02937\n",
"epoch: 750\ttrain_loss: 0.03666\tdev_loss: 0.04242\teltime: 36.55968\n",
"epoch: 800\ttrain_loss: 0.03454\tdev_loss: 0.03878\teltime: 37.04607\n",
"epoch: 850\ttrain_loss: 0.03277\tdev_loss: 0.03831\teltime: 37.70012\n",
"epoch: 900\ttrain_loss: 0.03190\tdev_loss: 0.03761\teltime: 38.33910\n",
"epoch: 950\ttrain_loss: 0.03140\tdev_loss: 0.03736\teltime: 38.94758\n",
"epoch: 1000\ttrain_loss: 0.03165\tdev_loss: 0.03805\teltime: 39.31244\n"
"epoch: 0\ttrain_loss: 0.20121\tdev_loss: 0.19113\teltime: 19.64045\n",
"epoch: 50\ttrain_loss: 0.16842\tdev_loss: 0.16688\teltime: 20.06003\n",
"epoch: 100\ttrain_loss: 0.14238\tdev_loss: 0.14534\teltime: 20.53543\n",
"epoch: 150\ttrain_loss: 0.11018\tdev_loss: 0.11803\teltime: 20.97161\n",
"epoch: 200\ttrain_loss: 0.07499\tdev_loss: 0.08277\teltime: 21.44121\n",
"epoch: 250\ttrain_loss: 0.06059\tdev_loss: 0.06680\teltime: 21.91846\n",
"epoch: 300\ttrain_loss: 0.05511\tdev_loss: 0.06288\teltime: 22.31885\n",
"epoch: 350\ttrain_loss: 0.05344\tdev_loss: 0.06120\teltime: 22.70428\n",
"epoch: 400\ttrain_loss: 0.05226\tdev_loss: 0.05941\teltime: 23.09720\n",
"epoch: 450\ttrain_loss: 0.05113\tdev_loss: 0.05787\teltime: 23.51434\n",
"epoch: 500\ttrain_loss: 0.04937\tdev_loss: 0.05685\teltime: 24.02329\n",
"epoch: 550\ttrain_loss: 0.04648\tdev_loss: 0.05583\teltime: 24.45336\n",
"epoch: 600\ttrain_loss: 0.04318\tdev_loss: 0.04911\teltime: 24.96189\n",
"epoch: 650\ttrain_loss: 0.04042\tdev_loss: 0.04821\teltime: 25.38162\n",
"epoch: 700\ttrain_loss: 0.03868\tdev_loss: 0.04386\teltime: 25.88007\n",
"epoch: 750\ttrain_loss: 0.03666\tdev_loss: 0.04242\teltime: 26.37139\n",
"epoch: 800\ttrain_loss: 0.03454\tdev_loss: 0.03878\teltime: 26.78114\n",
"epoch: 850\ttrain_loss: 0.03277\tdev_loss: 0.03831\teltime: 27.16932\n",
"epoch: 900\ttrain_loss: 0.03190\tdev_loss: 0.03761\teltime: 27.55376\n",
"epoch: 950\ttrain_loss: 0.03140\tdev_loss: 0.03736\teltime: 28.00943\n",
"epoch: 1000\ttrain_loss: 0.03165\tdev_loss: 0.03805\teltime: 28.40723\n"
]
}
],