diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py
index c2465c0..456d7a8 100644
--- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py
+++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py
@@ -30,6 +30,9 @@
 os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
 import jax.numpy as np
+import numpy as onp
+from numpy.testing import assert_array_almost_equal
+
 import pytest
 
 from znnl.analysis import loss_ntk_calculation
@@ -75,7 +78,10 @@ class TestLossNTKCalculation:
 
     def test_loss_ntk_calculation(self):
         """
-        Test the loss NTK calculation.
+        Test the Loss NTK calculation.
+        Here we test whether the loss NTK calculated through the neural-tangents
+        module matches the loss NTK computed from the already implemented NTK and
+        loss derivatives.
         """
 
         # Define a test Network
@@ -118,16 +124,28 @@
         ]
 
         # Calculate Loss derivative fn
         loss_derivative_calculator = LossDerivative(LPNormLoss(order=2))
 
+        # Calculate predictions analogously to the jax recording module.
         predictions = production_model(data_set["inputs"])
         if type(predictions) is tuple:
             predictions = predictions[0]
-        # calculation of loss derivatives
-        loss_derivatives = loss_derivative_calculator.calculate(
-            predictions=predictions,
-            targets=data_set["targets"],
-        )
-        print(loss_derivatives.shape)
-
-TestLossNTKCalculation().test_loss_ntk_calculation()
+        # Calculate the loss derivatives.
+        # Note: we need the derivatives of the subloss, not the full loss function.
+        loss_derivatives = onp.empty(shape=(len(predictions), len(predictions[0])))
+        for i in range(len(loss_derivatives)):
+            # The per-sample slicing is required by axis constraints in LPNormLoss.
+            loss_derivatives[i] = loss_derivative_calculator.calculate(
+                predictions[i : i + 1], data_set["targets"][i : i + 1]
+            )[0]
+
+        # Calculate the loss NTK from the loss derivatives and the NTK.
+        loss_ntk_2 = onp.zeros_like(loss_ntk)
+        for i in range(len(loss_ntk_2)):
+            for j in range(len(loss_ntk_2[0])):
+                loss_ntk_2[i, j] = np.einsum(
+                    "i, j, ij", loss_derivatives[i], loss_derivatives[j], ntk[i, j]
+                )
+
+        # Assert that the two loss NTKs agree.
+        assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=2)
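
The pairwise double loop added in this test evaluates, for each pair of data points (i, j), the contraction loss_ntk[i, j] = sum over a, b of dL_i/df_a * ntk[i, j, a, b] * dL_j/df_b. Below is a minimal standalone sketch of that contraction, assuming hypothetical shapes `(n, d)` for the loss derivatives and `(n, n, d, d)` for the NTK; all names and shapes here are illustrative assumptions, not the znnl API.

```python
# Standalone sketch of the loss-NTK contraction exercised by the test above.
# All names and shapes are illustrative assumptions, not taken from znnl.
import numpy as onp

n, d = 4, 3  # n data points, d network outputs per point
rng = onp.random.default_rng(seed=0)
loss_derivatives = rng.normal(size=(n, d))  # dL_i/df_a for each data point
ntk = rng.normal(size=(n, n, d, d))  # per-pair NTK block over output indices

# Pairwise double loop, mirroring the test body.
loss_ntk_loop = onp.zeros((n, n))
for i in range(n):
    for j in range(n):
        loss_ntk_loop[i, j] = onp.einsum(
            "a, b, ab", loss_derivatives[i], loss_derivatives[j], ntk[i, j]
        )

# The same contraction vectorized over all pairs in one call.
loss_ntk_vec = onp.einsum(
    "ia, jb, ijab -> ij", loss_derivatives, loss_derivatives, ntk
)

assert onp.allclose(loss_ntk_loop, loss_ntk_vec)
```

The vectorized `einsum` form would be the natural choice outside a test; the explicit double loop in the diff keeps the element-wise comparison easy to read.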