Konsti add magnitude variance (#101)

* Add magnitude variance of the NTK to the recorder
Also add a test

* Add Magnitude Variance to example notebook on Collective Variables

* Clear outputs of example notebook on CVs

* Update README.rst (#100)

Fix spelling mistake in readme

* Sam tov new data generators (#102)

* Add new data generators

* Update pre-commit hooks and reformat.

* Remove protobuf from requirements.txt

* Specify NT version

* Add tests for generators and add MPG + Abalone

* Update comment

* Training recording: store CVs uniformly as jax.numpy arrays.

---------

Co-authored-by: knikolaou <>
Co-authored-by: Samuel Tovey <[email protected]>
KonstiNik and SamTov authored Sep 28, 2023
1 parent cb5dc33 commit 9e68f62
Showing 3 changed files with 92 additions and 4 deletions.
28 changes: 28 additions & 0 deletions CI/unit_tests/training_recording/test_training_recording.py
@@ -60,6 +60,7 @@ def test_instantiation(self):
magnitude_ntk=True,
entropy=True,
magnitude_entropy=True,
magnitude_variance=True,
covariance_entropy=True,
eigenvalues=True,
trace=True,
@@ -99,6 +100,7 @@ def test_data_dump(self):
magnitude_ntk=True,
entropy=False,
magnitude_entropy=False,
magnitude_variance=False,
covariance_entropy=False,
eigenvalues=False,
)
@@ -135,3 +137,29 @@ def test_overwriting(self):
# Test overwriting.
recorder.instantiate_recorder(data_set=self.dummy_data_set, overwrite=True)
assert recorder._ntk_array == []

def test_magnitude_variance(self):
    """
    Test the magnitude variance function.
    """
    recorder = JaxRecorder(
        loss=False,
        accuracy=False,
        ntk=False,
        entropy=False,
        magnitude_variance=True,
        eigenvalues=False,
    )
    recorder.instantiate_recorder(data_set=self.dummy_data_set)

    # Create some test data: a diagonal NTK with uniformly drawn entries.
    data = onp.random.uniform(1.0, 2.0, size=(100))
    ntk = onp.eye(100) * data
    # Calculate the magnitude variance.
    recorder._update_magnitude_variance(parsed_data={"ntk": ntk})
    # Calculate the expected variance.
    expected_variance = onp.var(onp.sqrt(data) / onp.sqrt(data).mean())
    # Check that the recorded variance matches.
    testing.assert_almost_equal(
        recorder._magnitude_variance_array, expected_variance
    )
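Note: the expected value in this test follows from how the magnitude distribution is built (see the _update_magnitude_variance docstring in jax_recording.py below). For a diagonal NTK the magnitudes are sqrt(diag) normalized by their sum, so scaling by the number of entries gives sqrt(diag) / mean(sqrt(diag)). A minimal standalone check of that identity in plain NumPy; the snippet and its variable names are illustrative only and not part of the commit:

import numpy as onp
from numpy import testing

diag = onp.random.uniform(1.0, 2.0, size=100)               # diagonal of a toy NTK
magnitudes = onp.sqrt(diag) / onp.sqrt(diag).sum()          # normalized magnitude density
recorded = onp.var(magnitudes * magnitudes.shape[0])        # what the recorder appends
expected = onp.var(onp.sqrt(diag) / onp.sqrt(diag).mean())  # reference value used in the test
testing.assert_almost_equal(recorded, expected)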
29 changes: 26 additions & 3 deletions examples/Computing-Collective-Variables.ipynb
@@ -186,7 +186,8 @@
" name=\"train_recorder\",\n",
" loss=True,\n",
" ntk=True,\n",
-" entropy= True, \n",
+" covariance_entropy=True,\n",
+" magnitude_variance=True, \n",
" trace=True,\n",
" loss_derivative=True,\n",
" update_rate=1\n",
@@ -303,13 +304,27 @@
"metadata": {},
"outputs": [],
"source": [
-"plt.plot(train_report.entropy, 'o', mfc='None', label=\"Entropy\")\n",
+"plt.plot(train_report.covariance_entropy, 'o', mfc='None', label=\"Entropy\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Entropy\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8772e69",
"metadata": {},
"outputs": [],
"source": [
"plt.plot(train_report.magnitude_variance, 'o', mfc='None', label=\"Magnitude Variance\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Magnitude Variance\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -342,6 +357,14 @@
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d28033be",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -360,7 +383,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.10.6"
+"version": "3.9.5"
}
},
"nbformat": 4,
39 changes: 38 additions & 1 deletion znnl/training_recording/jax_recording.py
@@ -29,6 +29,7 @@
from os import path
from pathlib import Path

import jax.numpy as np
import numpy as onp

from znnl.accuracy_functions.accuracy_function import AccuracyFunction
@@ -82,6 +83,9 @@ class JaxRecorder:
covariance_entropy : bool (default=False)
If true, the entropy of the covariance ntk will be recorded.
Warning, large overhead.
magnitude_variance : bool (default=False)
If true, the variance of the gradient magnitudes of the ntk will be
recorded.
magnitude_entropy : bool (default=False)
If true, the entropy of the gradient magnitudes of the ntk will be recorded.
Warning, large overhead.
@@ -132,6 +136,10 @@ class JaxRecorder:
covariance_entropy: bool = False
_covariance_entropy_array: list = None

# Magnitude Variance of the model
magnitude_variance: bool = False
_magnitude_variance_array: list = None

# Magnitude Entropy of the model
magnitude_entropy: bool = False
_magnitude_entropy_array: list = None
@@ -252,6 +260,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False):
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
@@ -308,6 +317,7 @@ def update_recorder(self, epoch: int, model: JaxModel):
self.magnitude_ntk = False
self.entropy = False
self.magnitude_entropy = False
self.magnitude_variance = False
self.covariance_entropy = False
self.eigenvalues = False
self._read_selected_attributes()
@@ -509,6 +519,33 @@ def _update_magnitude_entropy(self, parsed_data: dict):
entropy = EntropyAnalysis.compute_shannon_entropy(magnitude_dist)
self._magnitude_entropy_array.append(entropy)

def _update_magnitude_variance(self, parsed_data: dict):
    """
    Update the magnitude variance of the NTK.

    The magnitude variance is defined as the variance of the normalized gradient
    magnitudes. Since the normalization of the magnitude distribution divides by
    the sum of the magnitudes, the variance is calculated as:

        magnitude_variance = var(magnitudes * magnitudes.shape[0])

    This ensures that the variance does not depend on the number of entries in
    the magnitude distribution. It is equivalent to:

        ntk_diag = sqrt( diag(ntk) )
        magnitude_variance = var( ntk_diag / mean(ntk_diag) )

    Parameters
    ----------
    parsed_data : dict
        Data computed before the update to prevent repeated calculations.
    """
    magnitude_dist = compute_magnitude_density(gram_matrix=parsed_data["ntk"])
    magvar = np.var(magnitude_dist * magnitude_dist.shape[0])
    self._magnitude_variance_array.append(magvar)

def _update_eigenvalues(self, parsed_data: dict):
"""
Update the eigenvalue array.
@@ -531,7 +568,7 @@ def _update_trace(self, parsed_data: dict):
parsed_data : dict
Data computed before the update to prevent repeated calculations.
"""
-trace = onp.trace(parsed_data["ntk"])
+trace = np.trace(parsed_data["ntk"])
self._trace_array.append(trace)

def _update_loss_derivative(self, parsed_data):
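The equivalence stated in the _update_magnitude_variance docstring can be checked on any positive semi-definite Gram matrix, not only the diagonal one used in the unit test. Below is a minimal, self-contained sketch in jax.numpy that mirrors the docstring's formulas directly rather than calling the library's compute_magnitude_density; the toy NTK and all names are illustrative and not part of the commit:

import jax
import jax.numpy as np

# Build a toy positive semi-definite NTK as a Gram matrix of random features.
features = jax.random.normal(jax.random.PRNGKey(0), (50, 8))
ntk = features @ features.T

# Docstring form 1: variance of the normalized magnitude distribution, rescaled
# by the number of entries so the result does not depend on the NTK size.
ntk_diag = np.sqrt(np.diagonal(ntk))
magnitude_dist = ntk_diag / ntk_diag.sum()
magnitude_variance = np.var(magnitude_dist * magnitude_dist.shape[0])

# Docstring form 2: var( ntk_diag / mean(ntk_diag) ) -- the two agree.
assert np.isclose(magnitude_variance, np.var(ntk_diag / ntk_diag.mean()))
print(magnitude_variance)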
