Skip to content

Commit

Permalink
Add magnitude variance of the NTK to the recorder
Browse files Browse the repository at this point in the history
Also add a test
  • Loading branch information
knikolaou committed Sep 8, 2023
1 parent 244d707 commit 5be3fa1
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
28 changes: 28 additions & 0 deletions CI/unit_tests/training_recording/test_training_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_instantiation(self):
magnitude_ntk=True,
entropy=True,
magnitude_entropy=True,
magnitude_variance=True,
covariance_entropy=True,
eigenvalues=True,
trace=True,
Expand Down Expand Up @@ -99,6 +100,7 @@ def test_data_dump(self):
magnitude_ntk=True,
entropy=False,
magnitude_entropy=False,
magnitude_variance=False,
covariance_entropy=False,
eigenvalues=False,
)
Expand Down Expand Up @@ -135,3 +137,29 @@ def test_overwriting(self):
# Test overwriting.
recorder.instantiate_recorder(data_set=self.dummy_data_set, overwrite=True)
assert recorder._ntk_array == []

def test_magnitude_variance(self):
"""
Test the magnitude variance function.
"""
recorder = JaxRecorder(
loss=False,
accuracy=False,
ntk=False,
entropy=False,
magnitude_variance=True,
eigenvalues=False,
)
recorder.instantiate_recorder(data_set=self.dummy_data_set)

# Create some test data.
data = onp.random.uniform(1.0, 2.0, size=(100))
ntk = onp.eye(100) * data
# calculate the magnitude variance
recorder._update_magnitude_variance(parsed_data={"ntk": ntk})
# calculate the expected variance
expected_variance = onp.var(onp.sqrt(data) / onp.sqrt(data).mean())
# check that the variance is correct
testing.assert_almost_equal(
recorder._magnitude_variance_array, expected_variance
)
36 changes: 36 additions & 0 deletions znnl/training_recording/jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class JaxRecorder:
covariance_entropy : bool (default=False)
If true, the entropy of the covariance ntk will be recorded.
Warning, large overhead.
magnitude_variance : bool (default=False)
If true, the variance of the gradient magnitudes of the ntk will be
recorded.
magnitude_entropy : bool (default=False)
If true, the entropy of the gradient magnitudes of the ntk will be recorded.
Warning, large overhead.
Expand Down Expand Up @@ -132,6 +135,10 @@ class JaxRecorder:
covariance_entropy: bool = False
_covariance_entropy_array: list = None

# Magnitude Variance of the model
magnitude_variance: bool = False
_magnitude_variance_array: list = None

# Magnitude Entropy of the model
magnitude_entropy: bool = False
_magnitude_entropy_array: list = None
Expand Down Expand Up @@ -252,6 +259,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False):
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
Expand Down Expand Up @@ -308,6 +316,7 @@ def update_recorder(self, epoch: int, model: JaxModel):
self.magnitude_ntk = False
self.entropy = False
self.magnitude_entropy = False
self.magnitude_variance = False
self.covariance_entropy = False
self.eigenvalues = False
self._read_selected_attributes()
Expand Down Expand Up @@ -509,6 +518,33 @@ def _update_magnitude_entropy(self, parsed_data: dict):
entropy = EntropyAnalysis.compute_shannon_entropy(magnitude_dist)
self._magnitude_entropy_array.append(entropy)

def _update_magnitude_variance(self, parsed_data: dict):
"""
Update the magnitude variance of the NTK.
The magnitude variance is defined as the variance of the normalized gradient
magnitudes.
As the normalization to obtain the magnitude distribution is done by dividing
by the sum of the magnitudes, the variance is calculated as:
magnitude_variance = var(magnitudes * onp.shape(magnitudes)[0])
This ensures that the variance is not dependent on the number entries in the
magnitude distribution.
It is equivalent to the following:
ntk_diag = sqrt( diag(ntk) )
magnitude_variance = var( diag / mean(ntk_diag) )
Parameters
----------
parsed_data : dict
Data computed before the update to prevent repeated calculations.
"""
magnitude_dist = compute_magnitude_density(gram_matrix=parsed_data["ntk"])
magvar = onp.var(magnitude_dist * onp.shape(magnitude_dist)[0])
self._magnitude_variance_array.append(magvar)

def _update_eigenvalues(self, parsed_data: dict):
"""
Update the eigenvalue array.
Expand Down

0 comments on commit 5be3fa1

Please sign in to comment.