From d22bb682bdc8de314bfc0933aabac7e455af45f4 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Mon, 22 Jan 2024 17:51:52 +0100 Subject: [PATCH] writing test --- .../analysis/test_loss_ntk_calculation.py | 82 +++++++++++++++++++ znnl/analysis/loss_ntk_calculation.py | 10 ++- 2 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 CI/unit_tests/analysis/test_loss_ntk_calculation.py diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py new file mode 100644 index 0000000..90a4632 --- /dev/null +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -0,0 +1,82 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import jax.numpy as np +import pytest + +from znnl.analysis import loss_ntk_calculation +from znnl.training_recording import JaxRecorder +from znnl.models import NTModel +from znnl.data import MNISTGenerator +from neural_tangents import stax + +import optax +import tensorflow_datasets as tfds + + +class TestLossNTKCalculation: + """ + Test Suite for the loss NTK calculation module. + """ + + def test_loss_ntk_calculation(self): + """ + Test the loss NTK calculation. + """ + + # Define a test Network + dense_network = stax.serial( + stax.Dense(32), + stax.Relu(), + stax.Dense(32), + ) + + # Define a test model + fuel_model = NTModel( + nt_module=dense_network, + optimizer=optax.adam(learning_rate=0.005), + input_shape=(9,), + trace_axes=(), + batch_size=314, + ) + + # Initialize model parameters + + data_generator = MNISTGenerator(ds_size=10) + data_set = { + "inputs": data_generator.train_ds["inputs"], + "targets": data_generator.train_ds["targets"], + } + + print(fuel_model.model_state.params) + + +TestLossNTKCalculation().test_loss_ntk_calculation() diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 07deab0..952e389 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -61,6 +61,10 @@ def __init__( batch_size=self.ntk_batch_size, store_on_device=self.store_on_device, ) + empirical_ntk = nt.empirical_ntk_fn( + f=self._function_for_loss_ntk, + trace_axes=self.trace_axes, + ) self.empirical_ntk_jit = jax.jit(empirical_ntk) def _function_for_loss_ntk(self, params, datapoint) -> float: @@ -82,7 +86,11 @@ def _function_for_loss_ntk(self, params, datapoint) -> float: ) def compute_loss_ntk( - self, x_i: np.ndarray, x_j: np.ndarray, model: JaxModel, infinite: bool = False + self, + x_i: np.ndarray, + model: JaxModel, + x_j: np.ndarray = None, + infinite: bool = False, ): """ Compute the loss NTK matrix for the model.