Calculation and test should work now
ma-sauter committed Feb 4, 2024
1 parent 22cfc90 commit 5f9b6bf
Showing 1 changed file with 27 additions and 9 deletions: CI/unit_tests/analysis/test_loss_ntk_calculation.py
@@ -30,6 +30,9 @@
 os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
 import jax.numpy as np
+import numpy as onp
+from numpy.testing import assert_array_almost_equal
+
 import pytest
 
 from znnl.analysis import loss_ntk_calculation
@@ -75,7 +78,10 @@ class TestLossNTKCalculation:

     def test_loss_ntk_calculation(self):
         """
-        Test the loss NTK calculation.
+        Test the loss NTK calculation.
+        Here we test whether the loss NTK calculated through the neural tangents
+        module is the same as the loss NTK calculated from the already implemented
+        NTK and the loss derivatives.
         """
 
         # Define a test Network
@@ -118,16 +124,28 @@ def test_loss_ntk_calculation(self):
         ]
         # Calculate the loss derivative fn
         loss_derivative_calculator = LossDerivative(LPNormLoss(order=2))
 
         # Predictions calculation analogous to the one in the JAX recording
         predictions = production_model(data_set["inputs"])
         if type(predictions) is tuple:
             predictions = predictions[0]
-        # calculation of loss derivatives
-        loss_derivatives = loss_derivative_calculator.calculate(
-            predictions=predictions,
-            targets=data_set["targets"],
-        )
-        print(loss_derivatives.shape)
-
-TestLossNTKCalculation().test_loss_ntk_calculation()
+
+        # Calculation of the loss derivatives.
+        # Note: here we need the derivatives of the subloss, not the regular loss fn.
+        loss_derivatives = onp.empty(shape=(len(predictions), len(predictions[0])))
+        for i in range(len(loss_derivatives)):
+            # The single-sample indexing is needed because of axis constraints
+            # in the LPNormLoss module.
+            loss_derivatives[i] = loss_derivative_calculator.calculate(
+                predictions[i : i + 1], data_set["targets"][i : i + 1]
+            )[0]
+
+        # Calculate the loss NTK from the loss derivatives and the NTK.
+        loss_ntk_2 = onp.zeros_like(loss_ntk)
+        for i in range(len(loss_ntk_2)):
+            for j in range(len(loss_ntk_2[0])):
+                loss_ntk_2[i, j] = np.einsum(
+                    "i, j, ij", loss_derivatives[i], loss_derivatives[j], ntk[i, j]
+                )
+
+        # Assert that the two loss NTKs are the same.
+        assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=2)
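For reference, the nested loop in the new code evaluates the contraction loss_ntk_2[i, j] = sum_ab loss_derivatives[i, a] * ntk[i, j, a, b] * loss_derivatives[j, b]. A minimal vectorized sketch of the same check, assuming the empirical NTK is stored with full output axes, i.e. ntk has shape (n_samples, n_samples, n_outputs, n_outputs) and loss_derivatives has shape (n_samples, n_outputs); the helper name is introduced here for illustration and is not part of the repository:

import jax.numpy as np

def loss_ntk_from_parts(loss_derivatives, ntk):
    # Contract the per-sample loss gradients with the output-by-output
    # NTK block of every sample pair (i, j) in a single einsum; this is
    # the batched form of the per-pair einsum "i, j, ij" used in the test.
    return np.einsum("ia, jb, ijab -> ij", loss_derivatives, loss_derivatives, ntk)

Replacing the double loop with one einsum avoids the Python-level O(n^2) iteration and keeps the whole check on the accelerator.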

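The single-sample loop works around the axis constraints of the LPNormLoss module noted in the comment above. For order=2 the derivative of the subloss L_i = ||f(x_i) - y_i||_2 with respect to the network output has the closed form (f(x_i) - y_i) / ||f(x_i) - y_i||_2, which offers an independent cross-check of the looped LossDerivative values. A sketch under the assumption that LPNormLoss(order=2) is a plain p = 2 norm over the output axis; the helper name is hypothetical:

import jax.numpy as np

def l2_subloss_derivatives(predictions, targets):
    # Residuals r_i = f(x_i) - y_i, shape (n_samples, n_outputs).
    residual = predictions - targets
    # dL_i / df(x_i) = r_i / ||r_i||_2, one row per sample.
    norms = np.linalg.norm(residual, axis=1, keepdims=True)
    return residual / norms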