From e64ffb2140cb42f1f2373e36092ec700dff8ff0b Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Tue, 28 May 2024 14:24:14 +0200 Subject: [PATCH] Write class-wise ntk computation --- CI/unit_tests/ntk_computation/test_jax_ntk.py | 28 ++- .../ntk_computation/test_jax_ntk_classwise.py | 195 ++++++++++++++++++ .../test_jax_ntk_subsampling.py | 16 +- znnl/ntk_computation/jax_ntk.py | 2 + znnl/ntk_computation/jax_ntk_classwise.py | 148 +++++++++++++ znnl/ntk_computation/jax_ntk_subsampling.py | 10 +- 6 files changed, 370 insertions(+), 29 deletions(-) create mode 100644 CI/unit_tests/ntk_computation/test_jax_ntk_classwise.py diff --git a/CI/unit_tests/ntk_computation/test_jax_ntk.py b/CI/unit_tests/ntk_computation/test_jax_ntk.py index d4f58d2..43403e9 100644 --- a/CI/unit_tests/ntk_computation/test_jax_ntk.py +++ b/CI/unit_tests/ntk_computation/test_jax_ntk.py @@ -77,7 +77,7 @@ def test_constructor(self): apply_fn = lambda x: x batch_size = 10 ntk_implementation = None - trace_axes = () + trace_axes = (-1,) store_on_device = False flatten = True data_keys = ["image", "label"] @@ -99,22 +99,28 @@ def test_constructor(self): assert jax_ntk_computation.flatten == flatten assert jax_ntk_computation.data_keys == data_keys - # Default ntk_implementation should be NTK_VECTOR_PRODUCTS - assert ( - jax_ntk_computation.ntk_implementation - == nt.NtkImplementation.NTK_VECTOR_PRODUCTS - ) + def test_constructor_default(self): + """ + Test the default setting of the constructor of the JAX NTK computation class. + """ + apply_fn = lambda x: x - # Test the default trace_axes jax_ntk_computation = JAXNTKComputation( apply_fn=apply_fn, - batch_size=batch_size, - ntk_implementation=ntk_implementation, - store_on_device=store_on_device, - flatten=flatten, ) + assert jax_ntk_computation.apply_fn == apply_fn + assert jax_ntk_computation.batch_size == 10 assert jax_ntk_computation.trace_axes == () + assert jax_ntk_computation.store_on_device == False + assert jax_ntk_computation.flatten == True + assert jax_ntk_computation.data_keys == ["inputs", "targets"] + + # Default ntk_implementation should be NTK_VECTOR_PRODUCTS + assert ( + jax_ntk_computation.ntk_implementation + == nt.NtkImplementation.NTK_VECTOR_PRODUCTS + ) def test_check_shape(self): """ diff --git a/CI/unit_tests/ntk_computation/test_jax_ntk_classwise.py b/CI/unit_tests/ntk_computation/test_jax_ntk_classwise.py new file mode 100644 index 0000000..d619bdd --- /dev/null +++ b/CI/unit_tests/ntk_computation/test_jax_ntk_classwise.py @@ -0,0 +1,195 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +import jax.numpy as np +import neural_tangents as nt +import optax +from flax import linen as nn +from jax import random + +from znnl.models import FlaxModel +from znnl.ntk_computation import JAXNTKClassWise + + +class FlaxTestModule(nn.Module): + """ + Test model for the Flax tests. + """ + + @nn.compact + def __call__(self, x): + x = nn.Dense(5, use_bias=True)(x) + x = nn.relu(x) + x = nn.Dense(features=2, use_bias=True)(x) + return x + + +class TestJAXNTKClassWise: + """ + Test class for the class-wise JAX NTK computation. + """ + + @classmethod + def setup_class(cls): + """ + Setup the test class. + """ + cls.flax_model = FlaxModel( + flax_module=FlaxTestModule(), + optimizer=optax.adam(learning_rate=0.001), + input_shape=(8,), + seed=17, + ) + + # Create random labels between zero and two + targets = np.array([0, 1, 2, 0, 1, 2, 0, 0]) + one_hot_targets = np.eye(3)[targets] + + cls.dataset_int = { + "inputs": random.normal(random.PRNGKey(0), (10, 8)), + "targets": np.expand_dims(targets, axis=1), + } + cls.dataset_onehot = { + "inputs": random.normal(random.PRNGKey(0), (10, 8)), + "targets": one_hot_targets, + } + + def test_constructor(self): + """ + Test the constructor of the JAX NTK computation class. + """ + jax_ntk = JAXNTKClassWise( + apply_fn=self.flax_model.apply, + ) + + assert jax_ntk.batch_size == 10 + assert jax_ntk._sample_indices == None + + def test_get_sample_indices(self): + """ + Test the _get_sample_indices method. + """ + jax_ntk = JAXNTKClassWise( + apply_fn=self.flax_model.apply, + ) + + # Test the one-hot targets + sample_idx_one_hot = jax_ntk._get_sample_indices(self.dataset_onehot) + assert len(sample_idx_one_hot) == 3 + assert len(sample_idx_one_hot[0]) == 4 + assert len(sample_idx_one_hot[1]) == 2 + assert len(sample_idx_one_hot[2]) == 2 + + # Test the integer targets + sample_idx_int = jax_ntk._get_sample_indices(self.dataset_int) + assert len(sample_idx_int) == 3 + assert len(sample_idx_int[0]) == 4 + assert len(sample_idx_int[1]) == 2 + assert len(sample_idx_int[2]) == 2 + + # Test upper bound of ntk_size + jax_ntk.ntk_size = 3 + sample_idx_one_hot = jax_ntk._get_sample_indices(self.dataset_onehot) + assert len(sample_idx_one_hot) == 3 + assert len(sample_idx_one_hot[0]) == 3 + assert len(sample_idx_one_hot[1]) == 2 + assert len(sample_idx_one_hot[2]) == 2 + + def test_subsample_data(self): + """ + Test the _subsample_data method. + """ + jax_ntk = JAXNTKClassWise( + apply_fn=self.flax_model.apply, + ) + + # Test the one-hot targets + subsampled_data_one_hot = jax_ntk._subsample_data( + self.dataset_onehot["inputs"], + jax_ntk._get_sample_indices(self.dataset_onehot), + ) + assert len(subsampled_data_one_hot) == 3 + assert subsampled_data_one_hot[0].shape == (4, 8) + assert subsampled_data_one_hot[1].shape == (2, 8) + assert subsampled_data_one_hot[2].shape == (2, 8) + + # Test the integer targets + subsampled_data_int = jax_ntk._subsample_data( + self.dataset_int["inputs"], jax_ntk._get_sample_indices(self.dataset_int) + ) + assert len(subsampled_data_int) == 3 + assert subsampled_data_int[0].shape == (4, 8) + assert subsampled_data_int[1].shape == (2, 8) + assert subsampled_data_int[2].shape == (2, 8) + + def test_compute_ntk(self): + """ + Test the compute_ntk method. + """ + jax_ntk = JAXNTKClassWise( + apply_fn=self.flax_model.ntk_apply_fn, + batch_size=10, + ) + + params = {"params": self.flax_model.model_state.params} + + # Test the one-hot targets + ntks = jax_ntk.compute_ntk(params, self.dataset_onehot) + assert len(ntks) == 3 + assert ntks[0].shape == (8, 8) + assert ntks[1].shape == (4, 4) + assert ntks[2].shape == (4, 4) + + # Test the integer targets + ntks = jax_ntk.compute_ntk(params, self.dataset_int) + print(ntks) + assert len(ntks) == 3 + assert ntks[0].shape == (8, 8) + assert ntks[1].shape == (4, 4) + assert ntks[2].shape == (4, 4) + + # Test if not all classes are present + dataset = { + "inputs": self.dataset_int["inputs"], + "targets": np.array([0, 0, 0, 0, 0, 0, 0, 0]), + } + ntks = jax_ntk.compute_ntk(params, dataset) + assert len(ntks) == 1 + assert ntks[0].shape == (16, 16) + + dataset = { + "inputs": self.dataset_int["inputs"], + "targets": np.array([0, 0, 0, 0, 0, 0, 0, 5]), + } + ntks = jax_ntk.compute_ntk(params, dataset) + assert len(ntks) == 6 + assert ntks[0].shape == (14, 14) + assert ntks[1].shape == (0, 0) + assert ntks[2].shape == (0, 0) + assert ntks[3].shape == (0, 0) + assert ntks[4].shape == (0, 0) + assert ntks[5].shape == (2, 2) diff --git a/CI/unit_tests/ntk_computation/test_jax_ntk_subsampling.py b/CI/unit_tests/ntk_computation/test_jax_ntk_subsampling.py index 0cd2394..14e7dfa 100644 --- a/CI/unit_tests/ntk_computation/test_jax_ntk_subsampling.py +++ b/CI/unit_tests/ntk_computation/test_jax_ntk_subsampling.py @@ -74,24 +74,10 @@ def test_constructor(self): Test the constructor of the JAX NTK computation class. """ jax_ntk = JAXNTKSubsampling( - apply_fn=self.flax_model.ntk_apply_fn, - ntk_size=3, - seed=0, - batch_size=10, - trace_axes=(), - store_on_device=False, - flatten=True, - data_keys=["inputs", "targets"], + apply_fn=self.flax_model.ntk_apply_fn, ntk_size=3, seed=0 ) - assert jax_ntk.apply_fn == self.flax_model.ntk_apply_fn assert jax_ntk.ntk_size == 3 - assert jax_ntk.seed == 0 - assert jax_ntk.batch_size == 10 - assert jax_ntk.trace_axes == () - assert jax_ntk.store_on_device is False - assert jax_ntk.flatten is True - assert jax_ntk.data_keys == ["inputs", "targets"] def test_get_sample_indices(self): """ diff --git a/znnl/ntk_computation/jax_ntk.py b/znnl/ntk_computation/jax_ntk.py index 19fd93a..2b795e2 100644 --- a/znnl/ntk_computation/jax_ntk.py +++ b/znnl/ntk_computation/jax_ntk.py @@ -91,6 +91,8 @@ def apply_fn(params, x): The keys used to define inputs and targets in the dataset. These keys are used to extract values from the dataset dictionary in the `compute_ntk` method. + Note that the first key has to refer the input data and the second key + to the targets / labels of the dataset. """ self.apply_fn = apply_fn self.batch_size = batch_size diff --git a/znnl/ntk_computation/jax_ntk_classwise.py b/znnl/ntk_computation/jax_ntk_classwise.py index b6f1bc5..f5e64a1 100644 --- a/znnl/ntk_computation/jax_ntk_classwise.py +++ b/znnl/ntk_computation/jax_ntk_classwise.py @@ -28,6 +28,7 @@ from typing import Callable, List, Optional import jax.numpy as np +import jax.tree as jt import neural_tangents as nt from jax import random, vmap @@ -35,6 +36,21 @@ class JAXNTKClassWise(JAXNTKComputation): + """ + Class for computing the empirical Neural Tangent Kernel (NTK) using the + neural-tangents library (implemented in JAX) with class-wise subsampling. + + This class is a subclass of JAXNTKComputation and adds the functionality of + subsampling the data according to the classes before computing the NTK. + In this way, the NTK is computed for each class separately. + + Note + ---- + This class is only implemented for the computing the NTK of a single dataset. + This menas that axis 0 and 1 of the NTK matrix correspond to the same dataset. + More information can be found in the `compute_ntk` method. + """ + def __init__( self, apply_fn: Callable, @@ -43,6 +59,8 @@ def __init__( trace_axes: tuple = (), store_on_device: bool = False, flatten: bool = True, + data_keys: Optional[List[str]] = None, + ntk_size: int = None, ): """ Constructor the JAX NTK computation class. @@ -82,6 +100,14 @@ def apply_fn(params, x): flatten : bool, default True If True, the NTK shape is checked and flattened into a 2D matrix, if required. + data_keys : List[str], default ["inputs", "targets"] + The keys used to define inputs and targets in the dataset. + These keys are used to extract values from the dataset dictionary in + the `compute_ntk` method. + Note that the first key has to refer the input data and the second key + to the targets / labels of the dataset. + ntk_size : int (default = None) + Upper limit for the number of samples used for the NTK computation. """ super().__init__( apply_fn=apply_fn, @@ -90,4 +116,126 @@ def apply_fn(params, x): trace_axes=trace_axes, store_on_device=store_on_device, flatten=flatten, + data_keys=data_keys, ) + + self._sample_indices = None + self.ntk_size = ntk_size + + def _get_sample_indices(self, dataset: dict) -> List[np.ndarray]: + """ + Group the data by class and return the indices of the samples to use for the + NTK computation. + + Parameters + ---------- + dataset : dict + The dataset containing the inputs and targets. + + Returns + ------- + sample_indices : dict + A dictionary containing the indices of the samples for each class, with + the class label as the key. + """ + targets = dataset[self.data_keys[1]] + + if len(targets.shape) > 1: + # If one-hot encoding is used, convert it to class labels + if targets.shape[1] > 1: + targets = np.argmax(targets, axis=1) + # If the targets are already class labels, squeeze the array + elif targets.shape[1] == 1: + targets = np.squeeze(targets, axis=1) + + unique_classes = np.unique(targets) + _indices = np.arange(targets.shape[0]) + sample_indices = {} + + for class_label in unique_classes: + # Create mask for samples of the current class + mask = targets == class_label + indices = np.compress(mask, _indices, axis=0) + if self.ntk_size is not None: + indices = indices[: self.ntk_size] + sample_indices[int(class_label)] = indices + + return sample_indices + + def _subsample_data(self, x: np.ndarray, sample_indices: dict) -> np.ndarray: + """ + Subsample the data based on indices. + + Parameters + ---------- + x : np.ndarray + The input data. + sample_indices : dict + The indices of the samples to use for the NTK computation. + + Returns + ------- + np.ndarray + The subsampled data. + """ + return jt.map(lambda indices: np.take(x, indices, axis=0), sample_indices) + + def _compute_ntk(self, params: dict, x_i: np.ndarray) -> np.ndarray: + """ + Compute the NTK for the neural network. + + Parameters + ---------- + params : dict + The parameters of the neural network. + x_i : np.ndarray + The input to the neural network. + + Returns + ------- + np.ndarray + The NTK matrix. + """ + ntk = self.empirical_ntk(x_i, None, params) + ntk = self._check_shape(ntk) + return ntk + + def compute_ntk(self, params: dict, dataset: dict) -> List[np.ndarray]: + """ + Compute the Neural Tangent Kernel (NTK) for the neural network. + + Note + ---- + This method only accepts a single dataset for the NTK computation. This means + both axes of the NTK matrix correspond to the same dataset. + + Parameters + ---------- + params : dict + The parameters of the neural network. + dataset_i : dict + The input dataset for the NTK computation. + + Returns + ------- + List[np.ndarray] + The NTK matrix. + """ + + self._sample_indices = self._get_sample_indices(dataset) + + x_i = self._subsample_data(dataset[self.data_keys[0]], self._sample_indices) + + ntks = jt.map(lambda x_i: self._compute_ntk(params, x_i), x_i) + + ntks = list(ntks.values()) + + # Get the maximum key in the sample indices i + max_key = max(self._sample_indices.keys()) + + # Fill in the missing classes with empty NTKs + for i in range(max_key): + if i not in self._sample_indices.keys(): + ntks.insert(i, np.zeros((0, 0))) + + return ntks diff --git a/znnl/ntk_computation/jax_ntk_subsampling.py b/znnl/ntk_computation/jax_ntk_subsampling.py index 5ab8f39..4586ee3 100644 --- a/znnl/ntk_computation/jax_ntk_subsampling.py +++ b/znnl/ntk_computation/jax_ntk_subsampling.py @@ -28,6 +28,7 @@ from typing import Callable, List, Optional import jax.numpy as np +import jax.tree as jt import neural_tangents as nt from jax import random @@ -110,6 +111,8 @@ def apply_fn(params, x): The keys used to define inputs and targets in the dataset. These keys are used to extract values from the dataset dictionary in the `compute_ntk` method. + Note that the first key has to refer the input data and the second key + to the targets / labels of the dataset. """ super().__init__( apply_fn=apply_fn, @@ -121,7 +124,7 @@ def apply_fn(params, x): data_keys=data_keys, ) self.ntk_size = ntk_size - self.seed = seed + self.key = random.PRNGKey(seed) self._sample_indices: List[np.ndarray] = [] self.n_parts = None @@ -144,7 +147,8 @@ def _get_sample_indices(self, x: np.ndarray) -> List[np.ndarray]: data_len = x.shape[0] self.n_parts = data_len // self.ntk_size - key = random.PRNGKey(self.seed) + key, self.key = random.split(self.key) + indices = random.permutation(key, np.arange(data_len)) return [ @@ -220,6 +224,6 @@ def compute_ntk( x_j = self._subsample_data(x_j) if x_j is not None else [None] * self.n_parts - ntks = [self._compute_ntk(params, x_i[i], x_j[i]) for i in range(self.n_parts)] + ntks = jt.map(lambda x_i, x_j: self._compute_ntk(params, x_i, x_j), x_i, x_j) return ntks