diff --git a/examples/trace_regularization.ipynb b/examples/trace_regularization.ipynb new file mode 100644 index 0000000..374888d --- /dev/null +++ b/examples/trace_regularization.ipynb @@ -0,0 +1,486 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "a85c316a-b950-448e-a264-4d0e6983790e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-09-13 14:27:14.105479: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2023-09-13 14:27:16.852688: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2023-09-13 14:27:16.932821: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2023-09-13 14:27:26.407757: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "/tikhome/knikolaou/miniconda3/envs/jax/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2023-09-13 14:27:50.219477: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n", + "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "data": { + "text/html": [ + "
Using backend: cpu\n",
+       "
\n" + ], + "text/plain": [ + "Using backend: cpu\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Available hardware:\n",
+       "
\n" + ], + "text/plain": [ + "Available hardware:\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
TFRT_CPU_0\n",
+       "
\n" + ], + "text/plain": [ + "TFRT_CPU_0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n", + "\n", + "import jax\n", + "import znnl\n", + "from neural_tangents import stax\n", + "import copy\n", + "import optax\n", + "\n", + "from flax import linen as nn\n", + "import flax\n", + "import jax.nn.initializers as inits\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import numpy as onp\n", + "import jax.numpy as np\n", + "# import time" + ] + }, + { + "cell_type": "markdown", + "id": "4d5500f3", + "metadata": {}, + "source": [ + "# Including the trace regularization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afc55b14", + "metadata": {}, + "outputs": [], + "source": [ + "from flax.training.train_state import TrainState\n", + "from znnl.training_strategies import SimpleTraining\n", + "from typing import Callable, List, Optional, Tuple, Union\n", + "from znnl.accuracy_functions.accuracy_function import AccuracyFunction\n", + "from znnl.models.jax_model import JaxModel\n", + "from znnl.training_strategies.recursive_mode import RecursiveMode" + ] + }, + { + "cell_type": "code", + "execution_count": 333, + "id": "a8eea84c", + "metadata": {}, + "outputs": [], + "source": [ + " class RegularizedTraining(SimpleTraining):\n", + "\n", + " def __init__(\n", + " self,\n", + " model: Union[JaxModel, None],\n", + " loss_fn: Callable,\n", + " accuracy_fn: AccuracyFunction = None,\n", + " seed: int = None,\n", + " recursive_mode: RecursiveMode = None,\n", + " disable_loading_bar: bool = False,\n", + " recorders: List[\"JaxRecorder\"] = None,\n", + " regularization: float = 0.0, \n", + " ):\n", + " \"\"\"\n", + " Construct a simple training strategy for a model.\n", + "\n", + " Parameters\n", + " ----------\n", + " model : Union[JaxModel, None]\n", + " Model class for a Jax model.\n", + " \"None\" is only used if the training strategy is passed 
as an input\n", + " to a bigger framework. The strategy then is applied to the framework\n", + " and the model instantiation is handled by that framework.\n", + " loss_fn : Callable\n", + " A function to use in the loss computation.\n", + " accuracy_fn : AccuracyFunction (default = None)\n", + " Function class for computing the accuracy of model and given data.\n", + " seed : int (default = None)\n", + " Random seed for the RNG. Uses a random int if not specified.\n", + " recursive_mode : RecursiveMode\n", + " Defining the recursive mode that can be used in training.\n", + " If the recursive mode is used, the training will be performed until a\n", + " condition is fulfilled.\n", + " disable_loading_bar : bool\n", + " Disable the output visualization of the loading bar.\n", + " recorders : List[JaxRecorder]\n", + " A list of recorders to monitor model training.\n", + " \"\"\"\n", + " super().__init__(\n", + " model=model,\n", + " loss_fn=loss_fn,\n", + " accuracy_fn=accuracy_fn,\n", + " seed=seed,\n", + " recursive_mode=recursive_mode,\n", + " disable_loading_bar=disable_loading_bar,\n", + " recorders=recorders,\n", + " )\n", + " self.regularization = regularization\n", + "\n", + " \n", + " def _train_step(self, state: TrainState, batch: dict):\n", + " \"\"\"\n", + " Train a single step.\n", + "\n", + " Parameters\n", + " ----------\n", + " state : TrainState\n", + " Current state of the neural network.\n", + " batch : dict\n", + " Batch of data to train on.\n", + "\n", + " Returns\n", + " -------\n", + " state : dict\n", + " Updated state of the neural network.\n", + " metrics : dict\n", + " Metrics for the current model.\n", + " \"\"\"\n", + "\n", + " def network_grad_fn(params):\n", + " \"\"\"\n", + " helper grad computation\n", + " \"\"\"\n", + " traced_predictions = self.model.apply(params, batch[\"inputs\"]).sum(axis=1)\n", + " ntk_trace_values = np.mean(traced_predictions)\n", + " return ntk_trace_values \n", + "\n", + " def loss_fn(params):\n", + " 
\"\"\"\n", + " helper loss computation\n", + " \"\"\"\n", + " inner_predictions = self.model.apply(params, batch[\"inputs\"])\n", + " loss = self.loss_fn(inner_predictions, batch[\"targets\"])\n", + "\n", + " # Add gradient regularization\n", + " if self.regularization > 0.0:\n", + " grad = jax.grad(network_grad_fn)(params)\n", + " grad_square = jax.tree_map(lambda x: x ** 2, grad)\n", + " loss += self.regularization * jax.flatten_util.ravel_pytree(grad_square)[0].mean()\n", + "\n", + " return loss, inner_predictions\n", + "\n", + " grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n", + "\n", + " (_, predictions), grads = grad_fn(state.params)\n", + "\n", + " state = state.apply_gradients(grads=grads) # rebind to the returned, updated state.\n", + " metrics = self._compute_metrics(\n", + " predictions=predictions, targets=batch[\"targets\"]\n", + " )\n", + "\n", + " return state, metrics\n", + " " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "3547d5aa", + "metadata": {}, + "source": [ + "# Set up Model" + ] + }, + { + "cell_type": "code", + "execution_count": 334, + "id": "af442d14-0791-48cc-a9e8-aa0c5ee9f9c4", + "metadata": {}, + "outputs": [], + "source": [ + "data_generator = znnl.data.MNISTGenerator(ds_size=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 335, + "id": "11123b2a-b981-4218-98bf-47b0a2bfc271", + "metadata": {}, + "outputs": [], + "source": [ + "class Network(nn.Module):\n", + " \"\"\"\n", + " Simple dense (fully connected) network module.\n", + " \"\"\"\n", + " @nn.compact\n", + " def __call__(self, x): \n", + " x = x.reshape((x.shape[0], -1)) # flatten\n", + " \n", + " x = nn.Dense(features=128)(x)\n", + " x = nn.relu(x)\n", + " \n", + " x = nn.Dense(features=64)(x)\n", + " x = nn.relu(x)\n", + " x = nn.Dense(10)(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "id": "634177cd", + "metadata": {}, + "source": [ + "# Execute training" + ] + }, + { + "cell_type": "code", + "execution_count": 343, + "id": 
"7936f03f-ee9b-46cb-a399-ba916cad09c2", + "metadata": {}, + "outputs": [], + "source": [ + "model = znnl.models.FlaxModel(\n", + " flax_module=Network(),\n", + " optimizer=optax.adam(learning_rate=0.01),\n", + " input_shape=(1, 28, 28, 1),\n", + " seed=0, \n", + " )\n", + "\n", + "train_recorder = znnl.training_recording.JaxRecorder(\n", + " name=\"train_recorder\",\n", + " loss=True,\n", + " ntk=True,\n", + " entropy= True, \n", + " trace=True,\n", + " loss_derivative=True,\n", + " update_rate=1, \n", + " chunk_size=1000\n", + ")\n", + "train_recorder.instantiate_recorder(\n", + " data_set=data_generator.train_ds\n", + ")\n", + "\n", + "\n", + "test_recorder = znnl.training_recording.JaxRecorder(\n", + " name=\"test_recorder\",\n", + " loss=True,\n", + " update_rate=1,\n", + " chunk_size=1000\n", + ")\n", + "test_recorder.instantiate_recorder(\n", + " data_set=data_generator.test_ds\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 370, + "id": "05e60cd9", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = RegularizedTraining(\n", + " model=model, \n", + " loss_fn=znnl.loss_functions.CrossEntropyLoss(),\n", + " accuracy_fn=znnl.accuracy_functions.LabelAccuracy(), \n", + " recorders=[train_recorder, test_recorder], \n", + " regularization=1e-2, \n", + " # regularization=0.0,\n", + " seed=0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 381, + "id": "da9ecc3f-dab4-4bc6-bd3a-35a3e5b6f855", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch: 1: 0%| | 0/100 [00:00" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(train_report.loss, 'o', mfc='None', label=\"Train\")\n", + "plt.plot(test_report.loss, 'o', mfc='None', label=\"Test\")\n", + "\n", + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Loss\")\n", + "# plt.yscale(\"log\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": 
"code", + "execution_count": 384, + "id": "a6fd3a3c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(train_report.entropy, 'o', mfc='None', label=\"Entropy\")\n", + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Entropy\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 385, + "id": "f8eea5f3", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(train_report.trace, 'o', mfc='None', label=\"Trace\")\n", + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Trace\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87c195e0", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "244cf1a9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}