diff --git a/CI/unit_tests/models/test_huggingface_flax_model.py b/CI/unit_tests/models/test_huggingface_flax_model.py new file mode 100644 index 0000000..13c2d42 --- /dev/null +++ b/CI/unit_tests/models/test_huggingface_flax_model.py @@ -0,0 +1,93 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import jax.numpy as np +import optax +import pytest +from flax import linen as nn +from jax import random +from transformers import FlaxResNetForImageClassification, ResNetConfig + +from znnl.models import HuggingFaceFlaxModel + + +class TestFlaxHFModule: + """ + Test suite for the flax Hugging Face (HF) module. + """ + + @classmethod + def setup_class(cls): + """ + Create a model and data for the tests. + The resnet config takes 2-channel inputs and has a 2 dimensional output. 
+ """ + + resnet_config = ResNetConfig( + num_channels=2, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + out_features=None, + out_indices=None, + id2label=dict(zip([1, 2], [1, 2])), + return_dict=True, + ) + hf_model = FlaxResNetForImageClassification( + config=resnet_config, + input_shape=(1, 8, 8, 2), + seed=0, + _do_init=True, + ) + cls.model = HuggingFaceFlaxModel( + hf_model, + optax.adam(learning_rate=0.001), + batch_size=3, + ) + + key = random.PRNGKey(0) + cls.x = random.normal(key, (3, 2, 8, 8)) + + def test_ntk_shape(self): + """ + Test whether the NTK shape is correct. + """ + ntk = self.model.compute_ntk(self.x)["empirical"] + assert ntk.shape == (3, 3) + + def test_infinite_failure(self): + """ + Test that the call to the infinite NTK fails. + """ + with pytest.raises(NotImplementedError): + self.model.compute_ntk(self.x, infinite=True) diff --git a/examples/HuggingFace_ResNet_Implementation.ipynb b/examples/HuggingFace_ResNet_Implementation.ipynb new file mode 100644 index 0000000..c7bf52f --- /dev/null +++ b/examples/HuggingFace_ResNet_Implementation.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fbd304b1", + "metadata": {}, + "source": [ + "# Using Transformers from Huggingface\n", + "This is an example notebook of how to use Huggingface models with ZnNL" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b42c9519", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# import os\n", + "# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n", + "\n", + "import znnl as nl\n", + "\n", + "import numpy as np\n", + "import optax\n", + "\n", + "from znnl.models import HuggingFaceFlaxModel\n", + "\n", + "import jax\n", + "print(jax.default_backend())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dba15f7c", + "metadata": { + "tags": [] + }, + "outputs": 
[], + "source": [ + "data_generator = nl.data.CIFAR10Generator(2)\n", + "\n", + "# Input data needs to have shape (num_points, channels, height, width)\n", + "train_ds={\"inputs\": np.swapaxes(data_generator.train_ds[\"inputs\"], 1, 3), \"targets\": data_generator.train_ds[\"targets\"]}\n", + "test_ds={\"inputs\": np.swapaxes(data_generator.test_ds[\"inputs\"], 1, 3), \"targets\": data_generator.test_ds[\"targets\"]}\n", + "\n", + "data_generator.train_ds = train_ds\n", + "data_generator.test_ds = test_ds" + ] + }, + { + "cell_type": "markdown", + "id": "d4580ffd", + "metadata": {}, + "source": [ + "# Execute" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9392cd92", + "metadata": {}, + "outputs": [], + "source": [ + "# From scratch\n", + "\n", + "resnet_config = ResNetConfig(\n", + " num_channels = 3,\n", + " embedding_size = 24, \n", + " hidden_sizes = [12, 12, 12], \n", + " depths = [3, 4, 6], \n", + " layer_type = 'bottleneck', \n", + " hidden_act = 'relu', \n", + " downsample_in_first_stage = False, \n", + " out_features = None, \n", + " out_indices = None, \n", + " id2label = dict(zip(np.arange(10), np.arange(10))),\n", + " return_dict = True,\n", + ")\n", + "\n", + "\n", + "model = FlaxResNetForImageClassification(\n", + " config=resnet_config,\n", + " input_shape=(1, 32, 32, 3),\n", + " seed=0,\n", + " _do_init = True,\n", + ")\n", + "\n", + "znnl_model = HuggingFaceFlaxModel(\n", + " model, \n", + " optax.adamw(learning_rate=0.001),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5868f984", + "metadata": {}, + "outputs": [], + "source": [ + "train_recorder = nl.training_recording.JaxRecorder(\n", + " name=\"train_recorder\",\n", + " loss=True,\n", + " ntk=True,\n", + " covariance_entropy=True,\n", + " magnitude_variance=True, \n", + " trace=True,\n", + " loss_derivative=True,\n", + " update_rate=1\n", + ")\n", + "train_recorder.instantiate_recorder(\n", + " data_set=data_generator.train_ds\n", + 
")\n", + "\n", + "trainer = nl.training_strategies.SimpleTraining(\n", + " model=znnl_model, \n", + " loss_fn=nl.loss_functions.CrossEntropyLoss(),\n", + " accuracy_fn=nl.accuracy_functions.LabelAccuracy(),\n", + " recorders=[train_recorder],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3215d048", + "metadata": {}, + "outputs": [], + "source": [ + "batch_wise_training_metrics = trainer.train_model(\n", + " train_ds=data_generator.train_ds,\n", + " test_ds=data_generator.test_ds,\n", + " batch_size=100,\n", + " epochs=50,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57f9421f", + "metadata": {}, + "outputs": [], + "source": [ + "train_report = train_recorder.gather_recording()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "355cd5d7", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93fa752a", + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(train_report.loss, label=\"loss\")\n", + "plt.plot(train_report.covariance_entropy, label=\"covariance_entropy\")\n", + "plt.plot(train_report.trace/5000, label=\"trace\")\n", + "plt.yscale(\"log\")\n", + "plt.legend()\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/ResNet_Implementation.ipynb b/examples/ResNet_Implementation.ipynb deleted file mode 100644 index 7563a29..0000000 --- a/examples/ResNet_Implementation.ipynb +++ /dev/null @@ -1,959 +0,0 @@ -{ - "cells": [ - { - "cell_type": 
"markdown", - "id": "fbd304b1", - "metadata": {}, - "source": [ - "# Using Transformers from Huggingface\n", - "This is an example notebook of how to use Huggingface models with ZnNL" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "b42c9519", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
Using backend: cpu\n",
-       "
\n" - ], - "text/plain": [ - "Using backend: cpu\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
Available hardware:\n",
-       "
\n" - ], - "text/plain": [ - "Available hardware:\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
TFRT_CPU_0\n",
-       "
\n" - ], - "text/plain": [ - "TFRT_CPU_0\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cpu\n" - ] - } - ], - "source": [ - "import os\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n", - "\n", - "import znnl as nl\n", - "\n", - "import tensorflow_datasets as tfds\n", - "\n", - "import numpy as np\n", - "from flax import linen as nn\n", - "import optax\n", - "from transformers import ResNetConfig, FlaxResNetForImageClassification\n", - "\n", - "from flax.training import\n", - "\n", - "import jax\n", - "print(jax.default_backend())" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "dba15f7c", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "data_generator = nl.data.CIFAR10Generator(10)" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "10f29532", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "znnl.data.cifar10.CIFAR10Generator" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_generator.__class__" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "e4cd2b25", - "metadata": {}, - "outputs": [], - "source": [ - "id2label = dict(zip(np.arange(10).tolist(), np.arange(20).tolist()))" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "aaf69e62", - "metadata": {}, - "outputs": [], - "source": [ - "\"\"\"\n", - "ZnNL: A Zincwarecode package.\n", - "\n", - "License\n", - "-------\n", - "This program and the accompanying materials are made available under the terms\n", - "of the Eclipse Public License v2.0 which accompanies this distribution, and is\n", - "available at https://www.eclipse.org/legal/epl-v20.html\n", - "\n", - "SPDX-License-Identifier: EPL-2.0\n", - "\n", - "Copyright Contributors to the Zincwarecode Project.\n", - "\n", - "Contact Information\n", - "-------------------\n", - "email: 
zincwarecode@gmail.com\n", - "github: https://github.com/zincware\n", - "web: https://zincwarecode.com/\n", - "\n", - "Citation\n", - "--------\n", - "If you use this module please cite us with:\n", - "\n", - "Summary\n", - "-------\n", - "\"\"\"\n", - "from typing import Callable, Sequence, Union\n", - "\n", - "import jax\n", - "import jax.numpy as np\n", - "import jax.random\n", - "import neural_tangents as nt\n", - "import optax\n", - "from flax.training.train_state import TrainState\n", - "\n", - "from znnl.optimizers.trace_optimizer import TraceOptimizer\n", - "from znnl.utils.prng import PRNGKey\n", - "\n", - "\n", - "class HFBaseModel:\n", - " \"\"\"\n", - " Base class for huggingface models.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " model: Callable,\n", - " optimizer: Union[Callable, TraceOptimizer],\n", - " input_shape: tuple,\n", - " ntk_batch_size: int = 10,\n", - " trace_axes: Union[int, Sequence[int]] = (-1,),\n", - " ):\n", - " \"\"\"\n", - " Construct a znrnd model.\n", - " Parameters\n", - " ----------\n", - " optimizer : Callable\n", - " optimizer to use in the training. OpTax is used by default and\n", - " cross-compatibility is not assured.\n", - " input_shape : tuple\n", - " Shape of the NN input.\n", - " seed : int, default None\n", - " Random seed for the RNG. 
Uses a random int if not specified.\n", - " ntk_batch_size : int, default 10\n", - " Batch size to use in the NTK computation.\n", - " trace_axes : Union[int, Sequence[int]]\n", - " Tracing over axes of the NTK.\n", - " The default value is trace_axes(-1,), which reduces the NTK to a tensor\n", - " of rank 2.\n", - " For a full NTK set trace_axes=().\n", - " \"\"\"\n", - " self.apply_fn = model.__call__\n", - " self.params = model.params\n", - " \n", - " self.optimizer = optimizer\n", - " self.input_shape = input_shape\n", - "\n", - " # Initialized in self.init_model\n", - " self.rng = None\n", - "\n", - " # initialize the model state\n", - " self.model_state = self._create_train_state()\n", - "\n", - " # Prepare NTK calculation\n", - " self.empirical_ntk = nt.batch(\n", - " nt.empirical_ntk_fn(f=self._ntk_apply_fn, trace_axes=trace_axes),\n", - " batch_size=ntk_batch_size,\n", - " )\n", - " self.empirical_ntk_jit = jax.jit(self.empirical_ntk)\n", - "\n", - " def _create_train_state(self) -> TrainState:\n", - " \"\"\"\n", - " Create a training state of the model.\n", - " Returns\n", - " -------\n", - " initial state of model to then be trained.\n", - " Notes\n", - " -----\n", - " TODO: Make the TrainState class passable by the user as it can track custom\n", - " model properties.\n", - " \"\"\"\n", - " # Set dummy optimizer for case of trace optimizer.\n", - " if isinstance(self.optimizer, TraceOptimizer):\n", - " optimizer = optax.sgd(1.0)\n", - " else:\n", - " optimizer = self.optimizer\n", - "\n", - " return TrainState.create(apply_fn=self.apply_fn, params=self.params, tx=optimizer)\n", - "\n", - " def _ntk_apply_fn(self, params: dict, inputs: np.ndarray):\n", - " \"\"\"\n", - " Apply function used in the NTK computation.\n", - " Parameters\n", - " ----------\n", - " params: dict\n", - " Contains the model parameters to use for the model computation.\n", - " inputs : np.ndarray\n", - " Feature vector on which to apply the model.\n", - " Returns\n", - " 
-------\n", - " The apply function used in the NTK computation.\n", - " \"\"\"\n", - " raise NotImplementedError(\"Implemented in child class\")\n", - "\n", - " def compute_ntk(\n", - " self,\n", - " x_i: np.ndarray,\n", - " x_j: np.ndarray = None,\n", - " infinite: bool = False,\n", - " ):\n", - " \"\"\"\n", - " Compute the NTK matrix for the model.\n", - " Parameters\n", - " ----------\n", - " x_i : np.ndarray\n", - " Dataset for which to compute the NTK matrix.\n", - " x_j : np.ndarray (optional)\n", - " Dataset for which to compute the NTK matrix.\n", - " infinite : bool (default = False)\n", - " If true, compute the infinite width limit as well.\n", - " Returns\n", - " -------\n", - " NTK : dict\n", - " The NTK matrix for both the empirical and infinite width computation.\n", - " \"\"\"\n", - " if x_j is None:\n", - " x_j = x_i\n", - " empirical_ntk = self.empirical_ntk_jit(x_i, x_j, self.model_state.params)\n", - "\n", - " if infinite:\n", - " try:\n", - " infinite_ntk = self.kernel_fn(x_i, x_j, \"ntk\")\n", - " except AttributeError:\n", - " raise NotImplementedError(\"Infinite NTK not available for this model.\")\n", - " else:\n", - " infinite_ntk = None\n", - "\n", - " return {\"empirical\": empirical_ntk, \"infinite\": infinite_ntk}\n", - " \n", - " def _apply_fn(self, feature_vector: np.ndarray):\n", - " \"\"\"\n", - " Apply the model.\n", - " Parameters\n", - " ----------\n", - " feature_vector : np.ndarray\n", - " Feature vector on which to apply operation.\n", - " Returns\n", - " -------\n", - " output of the model.\n", - " \"\"\"\n", - " raise \n", - "\n", - " def __call__(self, feature_vector: np.ndarray):\n", - " \"\"\"\n", - " Call the network.\n", - " Parameters\n", - " ----------\n", - " feature_vector : np.ndarray\n", - " Feature vector on which to apply operation.\n", - " Returns\n", - " -------\n", - " output of the model.\n", - " \"\"\"\n", - " return self.apply(self.model_state.params, feature_vector)\n" - ] - }, - { - "cell_type": "code", 
- "execution_count": 39, - "id": "cfd1644a", - "metadata": {}, - "outputs": [], - "source": [ - "\"\"\"\n", - "ZnNL: A Zincwarecode package.\n", - "\n", - "License\n", - "-------\n", - "This program and the accompanying materials are made available under the terms\n", - "of the Eclipse Public License v2.0 which accompanies this distribution, and is\n", - "available at https://www.eclipse.org/legal/epl-v20.html\n", - "\n", - "SPDX-License-Identifier: EPL-2.0\n", - "\n", - "Copyright Contributors to the Zincwarecode Project.\n", - "\n", - "Contact Information\n", - "-------------------\n", - "email: zincwarecode@gmail.com\n", - "github: https://github.com/zincware\n", - "web: https://zincwarecode.com/\n", - "\n", - "Citation\n", - "--------\n", - "If you use this module please cite us with:\n", - "\n", - "Summary\n", - "-------\n", - "\"\"\"\n", - "import logging\n", - "from typing import Callable, List, Sequence, Union\n", - "\n", - "import jax\n", - "import jax.numpy as np\n", - "from flax import linen as nn\n", - "\n", - "from znnl.models.jax_model import JaxModel\n", - "\n", - "logger = logging.getLogger(__name__)\n", - "\n", - "\n", - "class HFModel(HFBaseModel):\n", - " \"\"\"\n", - " Class for the Flax model in ZnRND.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " model: Callable,\n", - " optimizer: Callable,\n", - " input_shape: tuple,\n", - " batch_size: int = 10,\n", - " trace_axes: Union[int, Sequence[int]] = (-1,),\n", - " ):\n", - " \"\"\"\n", - " Construct a Flax model.\n", - "\n", - " Parameters\n", - " ----------\n", - " layer_stack : List[nn.Module]\n", - " A list of flax modules to be used in the call method.\n", - " optimizer : Callable\n", - " optimizer to use in the training. 
OpTax is used by default and\n", - " cross-compatibility is not assured.\n", - " input_shape : tuple\n", - " Shape of the NN input.\n", - " batch_size : int\n", - " Size of batch to use in the NTk calculation.\n", - " flax_module : nn.Module\n", - " Flax module to use instead of building one from scratch here.\n", - " trace_axes : Union[int, Sequence[int]]\n", - " Tracing over axes of the NTK.\n", - " The default value is trace_axes(-1,), which reduces the NTK to a tensor\n", - " of rank 2.\n", - " For a full NTK set trace_axes=().\n", - " seed : int, default None\n", - " Random seed for the RNG. Uses a random int if not specified.\n", - " \"\"\"\n", - " logger.info(\n", - " \"Flax models have occasionally experienced memory allocation issues on \"\n", - " \"GPU. This is an ongoing bug that we are striving to fix soon.\"\n", - " )\n", - "\n", - " self.apply_fn = jax.jit(model.__call__)\n", - "\n", - " # Save input parameters, call self.init_model\n", - " super().__init__(\n", - " model=model,\n", - " optimizer=optimizer,\n", - " input_shape=input_shape,\n", - " trace_axes=trace_axes,\n", - " ntk_batch_size=batch_size,\n", - " )\n", - "\n", - " def _ntk_apply_fn(self, params, inputs: np.ndarray):\n", - " \"\"\"\n", - " Return an NTK capable apply function.\n", - "\n", - " Parameters\n", - " ----------\n", - " params : dict\n", - " Network parameters to use in the calculation.\n", - " inputs : np.ndarray\n", - " Data on which to apply the network\n", - "\n", - " Returns\n", - " -------\n", - " Acts on the data with the model architecture and parameter set.\n", - " \"\"\"\n", - " return self.model_state.apply_fn({\"params\": params}, inputs, mutable=[\"batch_stats\"])[0]\n", - "\n", - "\n", - " def apply(self, params: dict, inputs: np.ndarray):\n", - " \"\"\"Apply the model to a feature vector.\n", - "\n", - " Parameters\n", - " ----------\n", - " params: dict\n", - " Contains the model parameters to use for the model computation.\n", - " inputs : np.ndarray\n", - " 
Feature vector on which to apply the model.\n", - "\n", - " Returns\n", - " -------\n", - " Output of the model.\n", - " \"\"\"\n", - " return self.model_state.apply_fn(inputs, params=params).logits\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d3ed757e", - "metadata": {}, - "outputs": [], - "source": [ - "# Write HF model as Jax model\n", - "\n", - "\"\"\"\n", - "ZnNL: A Zincwarecode package.\n", - "\n", - "License\n", - "-------\n", - "This program and the accompanying materials are made available under the terms\n", - "of the Eclipse Public License v2.0 which accompanies this distribution, and is\n", - "available at https://www.eclipse.org/legal/epl-v20.html\n", - "\n", - "SPDX-License-Identifier: EPL-2.0\n", - "\n", - "Copyright Contributors to the Zincwarecode Project.\n", - "\n", - "Contact Information\n", - "-------------------\n", - "email: zincwarecode@gmail.com\n", - "github: https://github.com/zincware\n", - "web: https://zincwarecode.com/\n", - "\n", - "Citation\n", - "--------\n", - "If you use this module please cite us with:\n", - "\n", - "Summary\n", - "-------\n", - "\"\"\"\n", - "import logging\n", - "from typing import Callable, List, Sequence, Union\n", - "\n", - "import jax\n", - "import jax.numpy as np\n", - "from flax import linen as nn\n", - "\n", - "from znnl.models.jax_model import JaxModel\n", - "\n", - "logger = logging.getLogger(__name__)\n", - "\n", - "\n", - "\n", - "class HFFlaxModel(JaxModel):\n", - " \"\"\"\n", - " Class for the Flax model in ZnRND.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " optimizer: Callable,\n", - " input_shape: tuple,\n", - " batch_size: int = 10,\n", - " layer_stack: List[nn.Module] = None,\n", - " flax_module: nn.Module = None,\n", - " trace_axes: Union[int, Sequence[int]] = (-1,),\n", - " seed: int = None,\n", - " ):\n", - " \"\"\"\n", - " Construct a Flax model.\n", - "\n", - " Parameters\n", - " ----------\n", - " layer_stack : List[nn.Module]\n", - " A 
list of flax modules to be used in the call method.\n", - " optimizer : Callable\n", - " optimizer to use in the training. OpTax is used by default and\n", - " cross-compatibility is not assured.\n", - " input_shape : tuple\n", - " Shape of the NN input.\n", - " batch_size : int\n", - " Size of batch to use in the NTk calculation.\n", - " flax_module : nn.Module\n", - " Flax module to use instead of building one from scratch here.\n", - " trace_axes : Union[int, Sequence[int]]\n", - " Tracing over axes of the NTK.\n", - " The default value is trace_axes(-1,), which reduces the NTK to a tensor\n", - " of rank 2.\n", - " For a full NTK set trace_axes=().\n", - " seed : int, default None\n", - " Random seed for the RNG. Uses a random int if not specified.\n", - " \"\"\"\n", - " logger.info(\n", - " \"Flax models have occasionally experienced memory allocation issues on \"\n", - " \"GPU. This is an ongoing bug that we are striving to fix soon.\"\n", - " )\n", - " if layer_stack is not None:\n", - " self.model = FundamentalModel(layer_stack)\n", - " if flax_module is not None:\n", - " self.model = flax_module\n", - " if layer_stack is None and flax_module is None:\n", - " raise TypeError(\"Provide either a Flax nn.Module or a layer stack.\")\n", - "\n", - " self.apply_fn = jax.jit(self.model.apply)\n", - "\n", - " # Save input parameters, call self.init_model\n", - " super().__init__(\n", - " optimizer=optimizer,\n", - " input_shape=input_shape,\n", - " seed=seed,\n", - " trace_axes=trace_axes,\n", - " ntk_batch_size=batch_size,\n", - " )\n", - "\n", - " def _ntk_apply_fn(self, params, inputs: np.ndarray):\n", - " \"\"\"\n", - " Return an NTK capable apply function.\n", - "\n", - " Parameters\n", - " ----------\n", - " params : dict\n", - " Network parameters to use in the calculation.\n", - " inputs : np.ndarray\n", - " Data on which to apply the network\n", - "\n", - " Returns\n", - " -------\n", - " Acts on the data with the model architecture and parameter set.\n", 
- " \"\"\"\n", - " return self.model.apply({\"params\": params}, inputs, mutable=[\"batch_stats\"])[0]\n", - "\n", - " def _init_params(self, kernel_init: Callable = None, bias_init: Callable = None):\n", - " \"\"\"Initialize a state for the model parameters.\n", - "\n", - " Parameters\n", - " ----------\n", - " kernel_init : Callable\n", - " Define the kernel initialization.\n", - " bias_init : Callable\n", - " Define the bias initialization.\n", - "\n", - " Returns\n", - " -------\n", - " Initial state for the model parameters.\n", - " \"\"\"\n", - " if kernel_init:\n", - " self.model.kernel_init = kernel_init\n", - " if bias_init:\n", - " self.model.bias_init = bias_init\n", - "\n", - " params = self.model.init(self.rng(), np.ones(list(self.input_shape)))[\"params\"]\n", - "\n", - " return params\n", - "\n", - " def apply(self, params: dict, inputs: np.ndarray):\n", - " \"\"\"Apply the model to a feature vector.\n", - "\n", - " Parameters\n", - " ----------\n", - " params: dict\n", - " Contains the model parameters to use for the model computation.\n", - " inputs : np.ndarray\n", - " Feature vector on which to apply the model.\n", - "\n", - " Returns\n", - " -------\n", - " Output of the model.\n", - " \"\"\"\n", - " return self.apply_fn({\"params\": params}, inputs)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "9392cd92", - "metadata": {}, - "outputs": [], - "source": [ - "# From scratch\n", - "\n", - "resnet_config = ResNetConfig(\n", - " num_channels = 3,\n", - " embedding_size = 64, \n", - " hidden_sizes = [256, 512, 1024, 2048], \n", - " depths = [3, 4, 6, 3], \n", - " layer_type = 'bottleneck', \n", - " hidden_act = 'relu', \n", - " downsample_in_first_stage = False, \n", - " out_features = None, \n", - " out_indices = None, \n", - " id2label = id2label,\n", - ")\n", - "\n", - "\n", - "model = FlaxResNetForImageClassification(\n", - " config=resnet_config,\n", - " input_shape=(1, 32, 32, 3),\n", - " seed=0,\n", - " _do_init = 
True,\n", - ")\n", - "\n", - "znnl_model = HFModel(\n", - " model, \n", - " optax.adam(learning_rate=0.01),\n", - " input_shape=(1, 32, 32, 3), \n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "bbfb7773", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[ 0.05988872, -0.20056242, -4.595742 , -1.6283834 , 4.054898 ,\n", - " -0.93490016, 0.42661703, 1.5616785 , 0.32605764, 1.8218877 ]], dtype=float32)" - ] - }, - "execution_count": 41, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "test_input = np.swapaxes(data_generator.train_ds['inputs'][:1], 1, 3)\n", - "\n", - "znnl_model(test_input)" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "00faca50", - "metadata": {}, - "outputs": [], - "source": [ - "from znnl.loss_functions.cross_entropy_loss import CrossEntropyDistance\n", - "import optax\n", - "\n", - "def loss_fn(prediction, target): return optax.softmax_cross_entropy(logits=prediction, labels=target)\n", - "\n", - "\n", - "def vmapped_cross_entropy_loss(inputs, targets):\n", - " mapped_loss = jax.vmap(loss_fn, in_axes=(0, 0))(inputs, targets)\n", - " return mapped_loss.mean()" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "5868f984", - "metadata": {}, - "outputs": [], - "source": [ - "def cross_entropy_loss(logits, labels):\n", - " return -np.mean(np.sum(labels * logits, axis=-1))\n", - "\n", - "trainer = nl.training_strategies.SimpleTraining(\n", - " model=znnl_model, \n", - " # loss_fn=vmapped_cross_entropy_loss,\n", - " loss_fn=nl.loss_functions.CrossEntropyLoss(),\n", - " # loss_fn=nl.loss_functions.MeanPowerLoss(order=2),\n", - " accuracy_fn=nl.accuracy_functions.LabelAccuracy(),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "3215d048", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/2 [00:00 TrainState: + def _create_train_state(self, 
params: dict) -> TrainState: """ Create a training state of the model. Returns @@ -119,8 +132,6 @@ def _create_train_state( TODO: Make the TrainState class passable by the user as it can track custom model properties. """ - params = self._init_params(kernel_init, bias_init) - # Set dummy optimizer for case of trace optimizer. if isinstance(self.optimizer, TraceOptimizer): optimizer = optax.sgd(1.0)