diff --git a/examples/neural_operators/Part_1_antiderivative_aligned.ipynb b/examples/neural_operators/Part_1_antiderivative_aligned.ipynb
new file mode 100644
index 00000000..912c136d
--- /dev/null
+++ b/examples/neural_operators/Part_1_antiderivative_aligned.ipynb
@@ -0,0 +1,923 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Antiderivative Operator - Aligned Dataset\n",
+ "\n",
+ "This tutorial demonstrates the use of learning neural operators for a data driven use case (non-physics informed). \n",
+ "\n",
+ "### References\n",
+ "[1] [Antiderivative operator from an aligned dataset - DeepXDE](https://deepxde.readthedocs.io/en/latest/demos/operator/antiderivative_aligned.html)\n",
+ "\n",
+ "[2] [DeepONet Tutorial in JAX](https://github.com/Ceyron/machine-learning-and-simulation/blob/main/english/neural_operators/simple_deepOnet_in_JAX.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Install (Colab only)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:18.434791Z",
+ "start_time": "2024-08-28T23:45:18.432335Z"
+ }
+ },
+ "source": [
+ "#%pip install \"neuromancer[examples] @ git+https://github.com/pnnl/neuromancer.git@master\"\n",
+ "#%pip install watermark"
+ ],
+ "outputs": [],
+ "execution_count": 8
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:18.468181Z",
+ "start_time": "2024-08-28T23:45:18.465571Z"
+ }
+ },
+ "source": [
+ "import os\n",
+ "\n",
+ "from IPython.display import clear_output\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from pathlib import Path\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from torch.utils.data import DataLoader\n",
+ "import time\n",
+ "from scipy.integrate import simpson, cumulative_trapezoid\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "os.environ[\"DDE_BACKEND\"] = \"pytorch\"\n",
+ "import deepxde as dde\n",
+ "# FIXME only for development\n",
+ "import sys\n",
+ "sys.path.insert(0, '../../src')"
+ ],
+ "outputs": [],
+ "execution_count": 9
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:18.479529Z",
+ "start_time": "2024-08-28T23:45:18.477254Z"
+ }
+ },
+ "source": [
+ "from neuromancer.callbacks import Callback\n",
+ "from neuromancer.constraint import variable\n",
+ "from neuromancer.dataset import DictDataset\n",
+ "from neuromancer.loss import PenaltyLoss\n",
+ "from neuromancer.modules.blocks import MLP\n",
+ "from neuromancer.modules.activations import activations\n",
+ "from neuromancer.problem import Problem\n",
+ "from neuromancer.system import Node\n",
+ "from neuromancer.trainer import Trainer\n",
+ "from neuromancer.dynamics.operators import DeepONet"
+ ],
+ "outputs": [],
+ "execution_count": 10
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:18.489448Z",
+ "start_time": "2024-08-28T23:45:18.486697Z"
+ }
+ },
+ "source": [
+ "# PyTorch random seed\n",
+ "torch.manual_seed(1234)\n",
+ "\n",
+ "# NumPy random seed\n",
+ "np.random.seed(1234)\n",
+ "\n",
+ "# Device configuration\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
+ ],
+ "outputs": [],
+ "execution_count": 11
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Problem Setup\n",
+ "\n",
+ "original source: [https://deepxde.readthedocs.io/en/latest/demos/operator/antiderivative_aligned.html](https://deepxde.readthedocs.io/en/latest/demos/operator/antiderivative_aligned.html) \n",
+ "\n",
+ "We will learn the antiderivative operator \n",
+ "\n",
+ "$$G : v \\mapsto u$$\n",
+ "\n",
+ "defined by an ODE\n",
+ "\n",
+ "$$\\frac{du(x)}{dx} = v(x),\\;\\;x\\in [0,1]$$\n",
+ "\n",
+ "**Initial Condition:** \n",
+ "$$u(0) = (0)$$\n",
+ "\n",
+ "We learn *G* from a dataset. Each data point in the dataset is one pair of (v,u), generated as follows:\n",
+ "\n",
+ "1. A random function *v* is sampled from a Gaussian random field (GRF) with the resolution m = 100.\n",
+ "2. Solve *u* for *v* numerically. We assume that for each *u*, we have the values of *u(x)* in the same Nu = 100 locations. Because we have the values of *u(x)* in the same locations, we call this dataset as \"aligned data\".\n",
+ "\n",
+ "* Dataset information\n",
+ " * The training dataset has size 150.\n",
+ " * The testing dataset has size 1000. (We split this into a dev/test split of size 500 each)\n",
+ " * Input of the branch net: the functions *v*. It is a matrix of shape (dataset size, m), e.g., (150, 100) for the training dataset.\n",
+ " * Input of the trunk net: the locations *x* of *u(x)*. It is a matrix of shape (*Nu*, dimension)\n",
+ " * i.e., (100,1) for both training and testing datasets.\n",
+ " * Output: The values of *u(x)* in different locations for different *v*. It is a matrix of shape (dataset size, *Nu*).\n",
+ " * e.g., (150, 100) for the training dataset.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Dataset Prep"
+ ]
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:25.717053Z",
+ "start_time": "2024-08-28T23:45:18.498354Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "t = 100\n",
+ "space = dde.data.GRF(N=t, length_scale=1)\n",
+ "features = space.random(size=50000)\n",
+ "h = 1/t\n",
+ "sensors = np.linspace(0, 1, num=t)[:, None]\n",
+ "y = space.eval_batch(features, sensors)\n",
+ "anti_y = []\n",
+ "print()\n",
+ "for yi in y:\n",
+ " s0 = 0 # Initial Condition\n",
+ " # Explicit Euler Method\n",
+ " s = np.zeros(t)\n",
+ " s[0] = s0\n",
+ " for i in range(0, t - 1):\n",
+ " s[i + 1] = s[i] + h*yi[i]\n",
+ " #plt.figure()\n",
+ " #plt.plot(sensors, yi, 'g', label=\"yi\")\n",
+ " # integrate\n",
+ " anti_y.append(s)\n",
+ " #plt.plot(sensors, s, 'b', label=\"integral yi\")\n",
+ " #plt.legend(loc='lower right')\n",
+ "anti_y = np.array(anti_y)\n",
+ "#plt.show()"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "execution_count": 12
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:25.726261Z",
+ "start_time": "2024-08-28T23:45:25.723616Z"
+ }
+ },
+ "source": [
+ "def prepare_data(dataset, name):\n",
+ " ## Note: transposing branch input because DictDataset in Neuromancer needs all tensors in the dict to have the same shape at index 0\n",
+ " branch_inputs = dataset[\"X\"][0].T\n",
+ " trunk_inputs = dataset[\"X\"][1]\n",
+ " outputs = dataset[\"y\"].T\n",
+ "\n",
+ " Nu = outputs.shape[0]\n",
+ " Nsamples = outputs.shape[1]\n",
+ " print(f'{name} dataset: Nu = {Nu}, Nsamples = {Nsamples}')\n",
+ "\n",
+ " # convert to pytorch tensors of float type\n",
+ " t_branch_inputs = torch.from_numpy(branch_inputs).float()\n",
+ " t_trunk_inputs = torch.from_numpy(trunk_inputs).float()\n",
+ " t_outputs = torch.from_numpy(outputs).float()\n",
+ "\n",
+ " data = DictDataset({\n",
+ " \"branch_inputs\": t_branch_inputs,\n",
+ " \"trunk_inputs\": t_trunk_inputs,\n",
+ " \"outputs\": t_outputs\n",
+ " }, name=name)\n",
+ "\n",
+ " return data, Nu"
+ ],
+ "outputs": [],
+ "execution_count": 13
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create named dictionary datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:25.771091Z",
+ "start_time": "2024-08-28T23:45:25.732999Z"
+ }
+ },
+ "source": [
+ "# getting the shape of the generated data the same as the sample data\n",
+ "#branch_inputs = dataset_train[\"X\"][0].T\n",
+ "#trunk_inputs = dataset_train[\"X\"][1]\n",
+ "#outputs = dataset_train[\"y\"].T\n",
+ "new_branch_inputs = y.T\n",
+ "new_trunk_inputs = sensors\n",
+ "new_outputs = anti_y.T\n",
+ "branch_inputs_train, branch_inputs_test, outputs_train, outputs_test = train_test_split(y, anti_y, test_size=0.8)\n",
+ "branch_inputs_dev, branch_inputs_test, outputs_dev, outputs_test = train_test_split(branch_inputs_test, outputs_test, test_size=0.5)\n",
+ "new_branch_inputs = branch_inputs_train.T\n",
+ "new_trunk_inputs = sensors\n",
+ "new_outputs = outputs_train.T\n",
+ "\n",
+ "train_data = DictDataset({\n",
+ " \"branch_inputs\": torch.from_numpy(branch_inputs_train.T).float(),\n",
+ " \"trunk_inputs\": torch.from_numpy(new_trunk_inputs).float(),\n",
+ " \"outputs\": torch.from_numpy(outputs_train.T).float()\n",
+ "}, name=\"train\")\n",
+ "dev_data = DictDataset({\n",
+ " \"branch_inputs\": torch.from_numpy(branch_inputs_dev.T).float(),\n",
+ " \"trunk_inputs\": torch.from_numpy(new_trunk_inputs).float(),\n",
+ " \"outputs\": torch.from_numpy(outputs_dev.T).float()\n",
+ "}, name=\"dev\")\n",
+ "test_data = DictDataset({\n",
+ " \"branch_inputs\": torch.from_numpy(branch_inputs_test.T).float(),\n",
+ " \"trunk_inputs\": torch.from_numpy(new_trunk_inputs).float(),\n",
+ " \"outputs\": torch.from_numpy(outputs_test.T).float()\n",
+ "}, name=\"test\")\n",
+ "Nu = t\n",
+ "print()"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "execution_count": 14
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create torch DataLoaders for the Trainer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:25.839432Z",
+ "start_time": "2024-08-28T23:45:25.836501Z"
+ }
+ },
+ "source": [
+ "batch_size = 100\n",
+ "print(f\"batch_size: {batch_size}\")\n",
+ "train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=train_data.collate_fn, shuffle=False)\n",
+ "dev_loader = DataLoader(dev_data, batch_size=batch_size, collate_fn=dev_data.collate_fn, shuffle=False)\n",
+ "test_loader = DataLoader(test_data, batch_size=batch_size, collate_fn=test_data.collate_fn, shuffle=False)"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "batch_size: 100\n"
+ ]
+ }
+ ],
+ "execution_count": 15
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define node"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:25.934539Z",
+ "start_time": "2024-08-28T23:45:25.930468Z"
+ }
+ },
+ "source": [
+ "in_size_branch = Nu\n",
+ "width_size = 40\n",
+ "depth_branch = 2\n",
+ "interact_size = 40\n",
+ "in_size_trunk = 1\n",
+ "depth_trunk = 2\n",
+ "branch_net = MLP(\n",
+ " insize=in_size_branch,\n",
+ " outsize=interact_size,\n",
+ " nonlin=nn.ReLU,\n",
+ " hsizes=[width_size] * depth_branch,\n",
+ " bias=True,\n",
+ ")\n",
+ "trunk_net = MLP(\n",
+ " insize=in_size_trunk,\n",
+ " outsize=interact_size,\n",
+ " nonlin=nn.ReLU,\n",
+ " hsizes=[width_size] * depth_trunk,\n",
+ " bias=True,\n",
+ ")\n",
+ "deeponet = DeepONet(\n",
+ " branch_net=branch_net,\n",
+ " trunk_net=trunk_net,\n",
+ " bias=True\n",
+ ")"
+ ],
+ "outputs": [],
+ "execution_count": 16
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:25.961139Z",
+ "start_time": "2024-08-28T23:45:25.959066Z"
+ }
+ },
+ "source": [
+ "node_deeponet = Node(deeponet, ['branch_inputs', 'trunk_inputs'], ['g'], name=\"deeponet\")\n",
+ "print(node_deeponet)"
+ ],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "deeponet(branch_inputs, trunk_inputs) -> g\n"
+ ]
+ }
+ ],
+ "execution_count": 17
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Objective and Constraints in NeuroMANCER\n",
+ "\n",
+ "We use Mean Squared Error(MSE) for our loss function\n",
+ "\n",
+ "$$\\sum_{i=1}^{D}(x_i-y_i)^2$$\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:25.975872Z",
+ "start_time": "2024-08-28T23:45:25.971519Z"
+ }
+ },
+ "source": [
+ "var_y_est = variable(\"g\")\n",
+ "var_y_true = variable(\"outputs\")\n",
+ "\n",
+ "nodes = [node_deeponet]\n",
+ "\n",
+ "var_loss = (var_y_est == var_y_true.T)^2\n",
+ "var_loss.name = \"residual_loss\"\n",
+ "objectives = [var_loss]\n",
+ "\n",
+ "loss = PenaltyLoss(objectives, constraints=[])\n",
+ "\n",
+ "problem = Problem(nodes, loss=loss, grad_inference=True)\n"
+ ],
+ "outputs": [],
+ "execution_count": 18
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-28T23:45:26.480299Z",
+ "start_time": "2024-08-28T23:45:25.986037Z"
+ }
+ },
+ "source": [
+ "problem.show()"
+ ],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "