From 226cea10cc715b5949491eea039dd206f0de6b73 Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Wed, 15 May 2024 21:41:47 +0200 Subject: [PATCH] make NTK computation return ntk in list --- znnl/ntk_computation/jax_ntk.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/znnl/ntk_computation/jax_ntk.py b/znnl/ntk_computation/jax_ntk.py index 96b09ab..80e9b7d 100644 --- a/znnl/ntk_computation/jax_ntk.py +++ b/znnl/ntk_computation/jax_ntk.py @@ -25,7 +25,7 @@ ------- """ -from typing import Callable, Optional +from typing import Callable, List, Optional import jax.numpy as np import neural_tangents as nt @@ -105,7 +105,7 @@ def apply_fn(params, x): def compute_ntk( self, params: dict, x_i: np.ndarray, x_j: Optional[np.ndarray] = None - ) -> np.ndarray: + ) -> List[np.ndarray]: """ Compute the Neural Tangent Kernel (NTK) for the neural network. @@ -118,7 +118,7 @@ def compute_ntk( Returns ------- - np.ndarray + List[np.ndarray] The NTK matrix. """ - return self.empirical_ntk(x_i, x_j, params) + return [self.empirical_ntk(x_i, x_j, params)]