Skip to content

Commit

Permalink
Implement a flax ResNet from HuggingFace
Browse files Browse the repository at this point in the history
- Make small changes to the JaxModel class to allow the ResNet implementation
- Write the Hugging Face Flax implementation
- Test the NTK calculation

Todo:
- test for models beyond resnets
- update example script
  • Loading branch information
knikolaou committed Nov 2, 2023
1 parent 49b1395 commit 1b5efc8
Show file tree
Hide file tree
Showing 6 changed files with 459 additions and 974 deletions.
93 changes: 93 additions & 0 deletions CI/unit_tests/models/test_huggingface_flax_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
"""
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import jax.numpy as np
import optax
import pytest
from flax import linen as nn
from jax import random
from transformers import FlaxResNetForImageClassification, ResNetConfig

from znnl.models import HuggingFaceFlaxModel


class TestFlaxHFModule:
    """
    Tests for the ZnNL wrapper around Hugging Face (HF) flax models.
    """

    @classmethod
    def setup_class(cls):
        """
        Build the shared model and input batch used by every test.

        A ResNet with two input channels and two labels is initialised
        from scratch, wrapped in a HuggingFaceFlaxModel, and stored on the
        class together with a random batch of shape (3, 2, 8, 8).
        """
        # Keyword arguments of the ResNet configuration, collected first so
        # the architecture can be read at a glance.
        config_kwargs = {
            "num_channels": 2,
            "embedding_size": 64,
            "hidden_sizes": [256, 512, 1024, 2048],
            "depths": [3, 4, 6, 3],
            "layer_type": "bottleneck",
            "hidden_act": "relu",
            "downsample_in_first_stage": False,
            "out_features": None,
            "out_indices": None,
            "id2label": {1: 1, 2: 2},
            "return_dict": True,
        }
        network = FlaxResNetForImageClassification(
            config=ResNetConfig(**config_kwargs),
            input_shape=(1, 8, 8, 2),
            seed=0,
            _do_init=True,
        )

        optimizer = optax.adam(learning_rate=0.001)
        cls.model = HuggingFaceFlaxModel(network, optimizer, batch_size=3)

        # Fixed PRNG key so the input batch is reproducible between runs.
        cls.x = random.normal(random.PRNGKey(0), (3, 2, 8, 8))

    def test_ntk_shape(self):
        """
        Check that the empirical NTK of three inputs has shape (3, 3).
        """
        ntk_results = self.model.compute_ntk(self.x)
        assert ntk_results["empirical"].shape == (3, 3)

    def test_infinite_failure(self):
        """
        Check that requesting the infinite NTK raises NotImplementedError.
        """
        with pytest.raises(NotImplementedError):
            self.model.compute_ntk(self.x, infinite=True)
200 changes: 200 additions & 0 deletions examples/HuggingFace_ResNet_Implementation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "fbd304b1",
"metadata": {},
"source": [
"# Using Transformers from Hugging Face\n",
"This example notebook shows how to use Hugging Face models with ZnNL."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b42c9519",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# import os\n",
"# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n",
"\n",
"import znnl as nl\n",
"\n",
"import numpy as np\n",
"import optax\n",
"from transformers import FlaxResNetForImageClassification, ResNetConfig\n",
"\n",
"from znnl.models import HuggingFaceFlaxModel\n",
"\n",
"import jax\n",
"print(jax.default_backend())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dba15f7c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"data_generator = nl.data.CIFAR10Generator(2)\n",
"\n",
"# Input data needs to have shape (num_points, channels, height, width).\n",
"# Assumes the generator yields channels-last data (num_points, height, width, channels);\n",
"# moveaxis keeps height and width in order, whereas swapaxes(1, 3) would transpose each image.\n",
"train_ds={\"inputs\": np.moveaxis(data_generator.train_ds[\"inputs\"], -1, 1), \"targets\": data_generator.train_ds[\"targets\"]}\n",
"test_ds={\"inputs\": np.moveaxis(data_generator.test_ds[\"inputs\"], -1, 1), \"targets\": data_generator.test_ds[\"targets\"]}\n",
"\n",
"data_generator.train_ds = train_ds\n",
"data_generator.test_ds = test_ds"
]
},
{
"cell_type": "markdown",
"id": "d4580ffd",
"metadata": {},
"source": [
"# Execute"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9392cd92",
"metadata": {},
"outputs": [],
"source": [
"# From scratch\n",
"\n",
"resnet_config = ResNetConfig(\n",
" num_channels = 3,\n",
" embedding_size = 24, \n",
" hidden_sizes = [12, 12, 12], \n",
" depths = [3, 4, 6], \n",
" layer_type = 'bottleneck', \n",
" hidden_act = 'relu', \n",
" downsample_in_first_stage = False, \n",
" out_features = None, \n",
" out_indices = None, \n",
" id2label = dict(zip(np.arange(10), np.arange(10))),\n",
" return_dict = True,\n",
")\n",
"\n",
"\n",
"model = FlaxResNetForImageClassification(\n",
" config=resnet_config,\n",
" input_shape=(1, 32, 32, 3),\n",
" seed=0,\n",
" _do_init = True,\n",
")\n",
"\n",
"znnl_model = HuggingFaceFlaxModel(\n",
" model, \n",
" optax.adamw(learning_rate=0.001),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5868f984",
"metadata": {},
"outputs": [],
"source": [
"train_recorder = nl.training_recording.JaxRecorder(\n",
" name=\"train_recorder\",\n",
" loss=True,\n",
" ntk=True,\n",
" covariance_entropy=True,\n",
" magnitude_variance=True, \n",
" trace=True,\n",
" loss_derivative=True,\n",
" update_rate=1\n",
")\n",
"train_recorder.instantiate_recorder(\n",
" data_set=data_generator.train_ds\n",
")\n",
"\n",
"trainer = nl.training_strategies.SimpleTraining(\n",
" model=znnl_model, \n",
" loss_fn=nl.loss_functions.CrossEntropyLoss(),\n",
" accuracy_fn=nl.accuracy_functions.LabelAccuracy(),\n",
" recorders=[train_recorder],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3215d048",
"metadata": {},
"outputs": [],
"source": [
"batch_wise_training_metrics = trainer.train_model(\n",
" train_ds=data_generator.train_ds,\n",
" test_ds=data_generator.test_ds,\n",
" batch_size=100,\n",
" epochs=50,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57f9421f",
"metadata": {},
"outputs": [],
"source": [
"train_report = train_recorder.gather_recording()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "355cd5d7",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93fa752a",
"metadata": {},
"outputs": [],
"source": [
"plt.plot(train_report.loss, label=\"loss\")\n",
"plt.plot(train_report.covariance_entropy, label=\"covariance_entropy\")\n",
"plt.plot(train_report.trace/5000, label=\"trace\")\n",
"plt.yscale(\"log\")\n",
"plt.legend()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit 1b5efc8

Please sign in to comment.