Skip to content

Commit

Permalink
writing test
Browse files Browse the repository at this point in the history
  • Loading branch information
Marc Sauter committed Jan 22, 2024
1 parent fb1da2e commit d22bb68
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 1 deletion.
82 changes: 82 additions & 0 deletions CI/unit_tests/analysis/test_loss_ntk_calculation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
"""

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import jax.numpy as np
import pytest

from znnl.analysis import loss_ntk_calculation
from znnl.training_recording import JaxRecorder
from znnl.models import NTModel
from znnl.data import MNISTGenerator
from neural_tangents import stax

import optax
import tensorflow_datasets as tfds


class TestLossNTKCalculation:
"""
Test Suite for the loss NTK calculation module.
"""

def test_loss_ntk_calculation(self):
"""
Test the loss NTK calculation.
"""

# Define a test Network
dense_network = stax.serial(
stax.Dense(32),
stax.Relu(),
stax.Dense(32),
)

# Define a test model
fuel_model = NTModel(
nt_module=dense_network,
optimizer=optax.adam(learning_rate=0.005),
input_shape=(9,),
trace_axes=(),
batch_size=314,
)

# Initialize model parameters

data_generator = MNISTGenerator(ds_size=10)
data_set = {
"inputs": data_generator.train_ds["inputs"],
"targets": data_generator.train_ds["targets"],
}

print(fuel_model.model_state.params)


TestLossNTKCalculation().test_loss_ntk_calculation()
10 changes: 9 additions & 1 deletion znnl/analysis/loss_ntk_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def __init__(
batch_size=self.ntk_batch_size,
store_on_device=self.store_on_device,
)
empirical_ntk = nt.empirical_ntk_fn(
f=self._function_for_loss_ntk,
trace_axes=self.trace_axes,
)
self.empirical_ntk_jit = jax.jit(empirical_ntk)

def _function_for_loss_ntk(self, params, datapoint) -> float:
Expand All @@ -82,7 +86,11 @@ def _function_for_loss_ntk(self, params, datapoint) -> float:
)

def compute_loss_ntk(
self, x_i: np.ndarray, x_j: np.ndarray, model: JaxModel, infinite: bool = False
self,
x_i: np.ndarray,
model: JaxModel,
x_j: np.ndarray = None,
infinite: bool = False,
):
"""
Compute the loss NTK matrix for the model.
Expand Down

0 comments on commit d22bb68

Please sign in to comment.