diff --git a/examples/trace_regularization.ipynb b/examples/trace_regularization.ipynb
new file mode 100644
index 0000000..374888d
--- /dev/null
+++ b/examples/trace_regularization.ipynb
@@ -0,0 +1,486 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "a85c316a-b950-448e-a264-4d0e6983790e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-09-13 14:27:14.105479: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
+ "2023-09-13 14:27:16.852688: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
+ "2023-09-13 14:27:16.932821: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
+ "2023-09-13 14:27:26.407757: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
+ "/tikhome/knikolaou/miniconda3/envs/jax/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "2023-09-13 14:27:50.219477: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
+ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
Using backend: cpu\n",
+ " \n"
+ ],
+ "text/plain": [
+ "Using backend: cpu\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Available hardware:\n",
+ " \n"
+ ],
+ "text/plain": [
+ "Available hardware:\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "TFRT_CPU_0\n",
+ " \n"
+ ],
+ "text/plain": [
+ "TFRT_CPU_0\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import os\n",
+ "os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n",
+ "\n",
+ "import jax\n",
+ "from jax.flatten_util import ravel_pytree\n",
+ "import znnl\n",
+ "from neural_tangents import stax\n",
+ "import copy\n",
+ "import optax\n",
+ "\n",
+ "from flax import linen as nn\n",
+ "import flax\n",
+ "import jax.nn.initializers as inits\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import numpy as onp\n",
+ "import jax.numpy as np\n",
+ "# import time"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4d5500f3",
+ "metadata": {},
+ "source": [
+ "# Including the trace regularization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "afc55b14",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from flax.training.train_state import TrainState\n",
+ "from znnl.training_strategies import SimpleTraining\n",
+ "from typing import Callable, List, Optional, Tuple, Union\n",
+ "from znnl.accuracy_functions.accuracy_function import AccuracyFunction\n",
+ "from znnl.models.jax_model import JaxModel\n",
+ "from znnl.training_strategies.recursive_mode import RecursiveMode"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 333,
+ "id": "a8eea84c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class RegularizedTraining(SimpleTraining):\n",
+ "    \"\"\"\n",
+ "    SimpleTraining variant whose train step adds a squared-gradient penalty.\n",
+ "\n",
+ "    The penalty is the mean of the squared entries of the gradient (with respect\n",
+ "    to the parameters) of the batch-averaged, output-summed network prediction,\n",
+ "    scaled by `regularization`. The notebook refers to this as trace\n",
+ "    regularization.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(\n",
+ "        self,\n",
+ "        model: Union[JaxModel, None],\n",
+ "        loss_fn: Callable,\n",
+ "        accuracy_fn: AccuracyFunction = None,\n",
+ "        seed: int = None,\n",
+ "        recursive_mode: RecursiveMode = None,\n",
+ "        disable_loading_bar: bool = False,\n",
+ "        recorders: List[\"JaxRecorder\"] = None,\n",
+ "        regularization: float = 0.0,\n",
+ "    ):\n",
+ "        \"\"\"\n",
+ "        Construct a regularized training strategy for a model.\n",
+ "\n",
+ "        Parameters\n",
+ "        ----------\n",
+ "        model : Union[JaxModel, None]\n",
+ "                Model class for a Jax model.\n",
+ "                \"None\" is only used if the training strategy is passed as an input\n",
+ "                to a bigger framework. The strategy then is applied to the framework\n",
+ "                and the model instantiation is handled by that framework.\n",
+ "        loss_fn : Callable\n",
+ "                A function to use in the loss computation.\n",
+ "        accuracy_fn : AccuracyFunction (default = None)\n",
+ "                Function class for computing the accuracy of model and given data.\n",
+ "        seed : int (default = None)\n",
+ "                Random seed for the RNG. Uses a random int if not specified.\n",
+ "        recursive_mode : RecursiveMode\n",
+ "                Defining the recursive mode that can be used in training.\n",
+ "                If the recursive mode is used, the training will be performed until a\n",
+ "                condition is fulfilled.\n",
+ "        disable_loading_bar : bool\n",
+ "                Disable the output visualization of the loading bar.\n",
+ "        recorders : List[JaxRecorder]\n",
+ "                A list of recorders to monitor model training.\n",
+ "        regularization : float (default = 0.0)\n",
+ "                Prefactor of the squared-gradient penalty added to the loss in\n",
+ "                `_train_step`. The penalty is only added when this value is\n",
+ "                greater than 0.\n",
+ "        \"\"\"\n",
+ "        super().__init__(\n",
+ "            model=model,\n",
+ "            loss_fn=loss_fn,\n",
+ "            accuracy_fn=accuracy_fn,\n",
+ "            seed=seed,\n",
+ "            recursive_mode=recursive_mode,\n",
+ "            disable_loading_bar=disable_loading_bar,\n",
+ "            recorders=recorders,\n",
+ "        )\n",
+ "        self.regularization = regularization\n",
+ "\n",
+ "    def _train_step(self, state: TrainState, batch: dict):\n",
+ "        \"\"\"\n",
+ "        Train a single step on the (optionally regularized) loss.\n",
+ "\n",
+ "        Parameters\n",
+ "        ----------\n",
+ "        state : TrainState\n",
+ "                Current state of the neural network.\n",
+ "        batch : dict\n",
+ "                Batch of data to train on. The entries \"inputs\" and \"targets\"\n",
+ "                are read.\n",
+ "\n",
+ "        Returns\n",
+ "        -------\n",
+ "        state : TrainState\n",
+ "                Updated state of the neural network.\n",
+ "        metrics : dict\n",
+ "                Metrics for the current model.\n",
+ "        \"\"\"\n",
+ "\n",
+ "        def summed_output_fn(params):\n",
+ "            \"\"\"\n",
+ "            Scalar whose parameter gradient is penalized: the network outputs\n",
+ "            summed over axis 1 and averaged over the remaining entries.\n",
+ "            \"\"\"\n",
+ "            outputs = self.model.apply(params, batch[\"inputs\"])\n",
+ "            return np.mean(outputs.sum(axis=1))\n",
+ "\n",
+ "        def loss_fn(params):\n",
+ "            \"\"\"\n",
+ "            Loss on the batch plus, if enabled, the squared-gradient penalty.\n",
+ "            The predictions are returned as auxiliary output.\n",
+ "            \"\"\"\n",
+ "            inner_predictions = self.model.apply(params, batch[\"inputs\"])\n",
+ "            loss = self.loss_fn(inner_predictions, batch[\"targets\"])\n",
+ "\n",
+ "            # Add gradient regularization. The check is evaluated in Python, so\n",
+ "            # the extra gradient pass is skipped entirely when it is disabled.\n",
+ "            if self.regularization > 0.0:\n",
+ "                grad = jax.grad(summed_output_fn)(params)\n",
+ "                # jax.tree_map is deprecated; jax.tree_util.tree_map replaces it.\n",
+ "                grad_square = jax.tree_util.tree_map(lambda x: x**2, grad)\n",
+ "                # Flatten the parameter pytree so the mean runs over every entry.\n",
+ "                flat_grad_square, _ = ravel_pytree(grad_square)\n",
+ "                loss += self.regularization * flat_grad_square.mean()\n",
+ "\n",
+ "            return loss, inner_predictions\n",
+ "\n",
+ "        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
+ "\n",
+ "        (_, predictions), grads = grad_fn(state.params)\n",
+ "\n",
+ "        # apply_gradients hands back the updated state, which is rebound here.\n",
+ "        state = state.apply_gradients(grads=grads)\n",
+ "        metrics = self._compute_metrics(\n",
+ "            predictions=predictions, targets=batch[\"targets\"]\n",
+ "        )\n",
+ "\n",
+ "        return state, metrics"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "3547d5aa",
+ "metadata": {},
+ "source": [
+ "# Set up data and model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 334,
+ "id": "af442d14-0791-48cc-a9e8-aa0c5ee9f9c4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# MNIST data generator; its train_ds and test_ds are handed to the recorders and the trainer below.\n",
+ "# NOTE(review): ds_size=50 presumably limits the number of samples - confirm against znnl.\n",
+ "data_generator = znnl.data.MNISTGenerator(ds_size=50)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 335,
+ "id": "11123b2a-b981-4218-98bf-47b0a2bfc271",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Network(nn.Module):\n",
+ "    \"\"\"\n",
+ "    Fully connected network: flatten -> Dense(128) -> ReLU -> Dense(64) -> ReLU -> Dense(10).\n",
+ "    \"\"\"\n",
+ "\n",
+ "    @nn.compact\n",
+ "    def __call__(self, x):\n",
+ "        # Keep the leading axis and merge all remaining axes into one.\n",
+ "        hidden = x.reshape((x.shape[0], -1))\n",
+ "\n",
+ "        # Hidden layers, created in the same order as before, each followed by a ReLU.\n",
+ "        for width in (128, 64):\n",
+ "            hidden = nn.relu(nn.Dense(features=width)(hidden))\n",
+ "\n",
+ "        # Output layer with 10 features and no activation.\n",
+ "        return nn.Dense(10)(hidden)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "634177cd",
+ "metadata": {},
+ "source": [
+ "# Execute training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 343,
+ "id": "7936f03f-ee9b-46cb-a399-ba916cad09c2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Wrap the flax Network module in a znnl model optimized with Adam (learning rate 0.01).\n",
+ "# NOTE(review): input_shape=(1, 28, 28, 1) presumably describes one 28x28 single-channel image - confirm.\n",
+ "model = znnl.models.FlaxModel(\n",
+ " flax_module=Network(),\n",
+ " optimizer=optax.adam(learning_rate=0.01),\n",
+ " input_shape=(1, 28, 28, 1),\n",
+ " seed=0, \n",
+ " )\n",
+ "\n",
+ "# Recorder attached to the training set, with the loss, ntk, entropy, trace and\n",
+ "# loss_derivative flags switched on.\n",
+ "train_recorder = znnl.training_recording.JaxRecorder(\n",
+ " name=\"train_recorder\",\n",
+ " loss=True,\n",
+ " ntk=True,\n",
+ " entropy= True, \n",
+ " trace=True,\n",
+ " loss_derivative=True,\n",
+ " update_rate=1, \n",
+ " chunk_size=1000\n",
+ ")\n",
+ "train_recorder.instantiate_recorder(\n",
+ " data_set=data_generator.train_ds\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# Recorder attached to the test set; only the loss flag is switched on.\n",
+ "test_recorder = znnl.training_recording.JaxRecorder(\n",
+ " name=\"test_recorder\",\n",
+ " loss=True,\n",
+ " update_rate=1,\n",
+ " chunk_size=1000\n",
+ ")\n",
+ "test_recorder.instantiate_recorder(\n",
+ " data_set=data_generator.test_ds\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 370,
+ "id": "05e60cd9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Training strategy with the squared-gradient penalty enabled (prefactor 1e-2);\n",
+ "# both recorders are passed in so they can monitor the run.\n",
+ "trainer = RegularizedTraining(\n",
+ " model=model, \n",
+ " loss_fn=znnl.loss_functions.CrossEntropyLoss(),\n",
+ " accuracy_fn=znnl.accuracy_functions.LabelAccuracy(), \n",
+ " recorders=[train_recorder, test_recorder], \n",
+ " regularization=1e-2, \n",
+ " # Use regularization=0.0 to train without the penalty (_train_step only adds it when > 0.0).\n",
+ " seed=0\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 381,
+ "id": "da9ecc3f-dab4-4bc6-bd3a-35a3e5b6f855",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch: 1: 0%| | 0/100 [00:00, ?batch/s]/tmp/ipykernel_503965/4271076503.py:88: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.\n",
+ " grad_square = jax.tree_map(lambda x: x ** 2, grad)\n",
+ "Epoch: 100: 100%|███████████████████████████████| 100/100 [00:17<00:00, 5.77batch/s, accuracy=0.54]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Train for 100 epochs with a batch size of 32; the value returned by\n",
+ "# train_model is kept in batched_training_metrics.\n",
+ "batched_training_metrics = trainer.train_model(\n",
+ " train_ds=data_generator.train_ds, \n",
+ " test_ds=data_generator.test_ds,\n",
+ " batch_size=32,\n",
+ " epochs=100,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 382,
+ "id": "4af69dd9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Collect what each recorder gathered during training; the plots below read\n",
+ "# the loss, entropy and trace attributes of these reports.\n",
+ "train_report = train_recorder.gather_recording()\n",
+ "test_report = test_recorder.gather_recording()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 386,
+ "id": "e0d7d0fc",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Train and test loss drawn as hollow circle markers on one set of axes.\n",
+ "fig, ax = plt.subplots()\n",
+ "ax.plot(train_report.loss, 'o', mfc='None', label=\"Train\")\n",
+ "ax.plot(test_report.loss, 'o', mfc='None', label=\"Test\")\n",
+ "\n",
+ "ax.set(xlabel=\"Epoch\", ylabel=\"Loss\")\n",
+ "# For a logarithmic loss axis, add: ax.set_yscale(\"log\")\n",
+ "ax.legend()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 384,
+ "id": "a6fd3a3c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Entropy values from the training-set report, drawn as hollow circle markers.\n",
+ "fig, ax = plt.subplots()\n",
+ "ax.plot(train_report.entropy, 'o', mfc='None', label=\"Entropy\")\n",
+ "ax.set(xlabel=\"Epoch\", ylabel=\"Entropy\")\n",
+ "ax.legend()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 385,
+ "id": "f8eea5f3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Trace values from the training-set report, drawn as hollow circle markers.\n",
+ "fig, ax = plt.subplots()\n",
+ "ax.plot(train_report.trace, 'o', mfc='None', label=\"Trace\")\n",
+ "ax.set(xlabel=\"Epoch\", ylabel=\"Trace\")\n",
+ "ax.legend()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "87c195e0",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "244cf1a9",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "jax",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}