diff --git a/CI/unit_tests/training_recording/test_training_recording.py b/CI/unit_tests/training_recording/test_training_recording.py index 90b4b81..a6c35ad 100644 --- a/CI/unit_tests/training_recording/test_training_recording.py +++ b/CI/unit_tests/training_recording/test_training_recording.py @@ -60,6 +60,7 @@ def test_instantiation(self): magnitude_ntk=True, entropy=True, magnitude_entropy=True, + magnitude_variance=True, covariance_entropy=True, eigenvalues=True, trace=True, @@ -99,6 +100,7 @@ def test_data_dump(self): magnitude_ntk=True, entropy=False, magnitude_entropy=False, + magnitude_variance=False, covariance_entropy=False, eigenvalues=False, ) @@ -135,3 +137,29 @@ def test_overwriting(self): # Test overwriting. recorder.instantiate_recorder(data_set=self.dummy_data_set, overwrite=True) assert recorder._ntk_array == [] + + def test_magnitude_variance(self): + """ + Test the magnitude variance function. + """ + recorder = JaxRecorder( + loss=False, + accuracy=False, + ntk=False, + entropy=False, + magnitude_variance=True, + eigenvalues=False, + ) + recorder.instantiate_recorder(data_set=self.dummy_data_set) + + # Create some test data. 
+ data = onp.random.uniform(1.0, 2.0, size=(100)) + ntk = onp.eye(100) * data + # calculate the magnitude variance + recorder._update_magnitude_variance(parsed_data={"ntk": ntk}) + # calculate the expected variance + expected_variance = onp.var(onp.sqrt(data) / onp.sqrt(data).mean()) + # check that the variance is correct + testing.assert_almost_equal( + recorder._magnitude_variance_array, expected_variance + ) diff --git a/examples/Computing-Collective-Variables.ipynb b/examples/Computing-Collective-Variables.ipynb index e5c3b4a..36a0781 100644 --- a/examples/Computing-Collective-Variables.ipynb +++ b/examples/Computing-Collective-Variables.ipynb @@ -186,7 +186,8 @@ " name=\"train_recorder\",\n", " loss=True,\n", " ntk=True,\n", - " entropy= True, \n", + " covariance_entropy=True,\n", + " magnitude_variance=True, \n", " trace=True,\n", " loss_derivative=True,\n", " update_rate=1\n", @@ -303,13 +304,27 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(train_report.entropy, 'o', mfc='None', label=\"Entropy\")\n", + "plt.plot(train_report.covariance_entropy, 'o', mfc='None', label=\"Entropy\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Entropy\")\n", "plt.legend()\n", "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8772e69", + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(train_report.magnitude_variance, 'o', mfc='None', label=\"Magnitude Variance\")\n", + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Magnitude Variance\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -342,6 +357,14 @@ "plt.legend()\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d28033be", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -360,7 +383,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.5" } }, "nbformat": 4, diff --git 
a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 90b01d6..2a702bd 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -29,6 +29,7 @@ from os import path from pathlib import Path +import jax.numpy as np import numpy as onp from znnl.accuracy_functions.accuracy_function import AccuracyFunction @@ -82,6 +83,9 @@ class JaxRecorder: covariance_entropy : bool (default=False) If true, the entropy of the covariance ntk will be recorded. Warning, large overhead. + magnitude_variance : bool (default=False) + If true, the variance of the gradient magnitudes of the ntk will be + recorded. magnitude_entropy : bool (default=False) If true, the entropy of the gradient magnitudes of the ntk will be recorded. Warning, large overhead. @@ -132,6 +136,10 @@ class JaxRecorder: covariance_entropy: bool = False _covariance_entropy_array: list = None + # Magnitude Variance of the model + magnitude_variance: bool = False + _magnitude_variance_array: list = None + # Magnitude Entropy of the model magnitude_entropy: bool = False _magnitude_entropy_array: list = None @@ -252,6 +260,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): "magnitude_ntk" in self._selected_properties, "entropy" in self._selected_properties, "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, "covariance_entropy" in self._selected_properties, "eigenvalues" in self._selected_properties, "trace" in self._selected_properties, @@ -308,6 +317,7 @@ def update_recorder(self, epoch: int, model: JaxModel): self.magnitude_ntk = False self.entropy = False self.magnitude_entropy = False + self.magnitude_variance = False self.covariance_entropy = False self.eigenvalues = False self._read_selected_attributes() @@ -509,6 +519,33 @@ def _update_magnitude_entropy(self, parsed_data: dict): entropy = EntropyAnalysis.compute_shannon_entropy(magnitude_dist) 
self._magnitude_entropy_array.append(entropy) + def _update_magnitude_variance(self, parsed_data: dict): + """ + Update the magnitude variance of the NTK. + + The magnitude variance is defined as the variance of the normalized gradient + magnitudes. + As the normalization to obtain the magnitude distribution is done by dividing + by the sum of the magnitudes, the variance is calculated as: + + magnitude_variance = var(magnitudes * magnitudes.shape[0]) + + This ensures that the variance is not dependent on the number of entries in the + magnitude distribution. + It is equivalent to the following: + + ntk_diag = sqrt( diag(ntk) ) + magnitude_variance = var( ntk_diag / mean(ntk_diag) ) + + Parameters + ---------- + parsed_data : dict + Data computed before the update to prevent repeated calculations. + """ + magnitude_dist = compute_magnitude_density(gram_matrix=parsed_data["ntk"]) + magvar = np.var(magnitude_dist * magnitude_dist.shape[0]) + self._magnitude_variance_array.append(magvar) + def _update_eigenvalues(self, parsed_data: dict): """ Update the eigenvalue array. @@ -531,7 +568,7 @@ def _update_trace(self, parsed_data: dict): parsed_data : dict Data computed before the update to prevent repeated calculations. """ - trace = onp.trace(parsed_data["ntk"]) + trace = np.trace(parsed_data["ntk"]) self._trace_array.append(trace) def _update_loss_derivative(self, parsed_data):