Skip to content

Commit

Permalink
make NTK computation return ntk in list
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 15, 2024
1 parent 81f06c8 commit 226cea1
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions znnl/ntk_computation/jax_ntk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
-------
"""

from typing import Callable, Optional
from typing import Callable, List, Optional

import jax.numpy as np
import neural_tangents as nt
Expand Down Expand Up @@ -105,7 +105,7 @@ def apply_fn(params, x):

def compute_ntk(
self, params: dict, x_i: np.ndarray, x_j: Optional[np.ndarray] = None
) -> np.ndarray:
) -> List[np.ndarray]:
"""
Compute the Neural Tangent Kernel (NTK) for the neural network.
Expand All @@ -118,7 +118,7 @@ def compute_ntk(
Returns
-------
np.ndarray
List[np.ndarray]
The NTK matrix.
"""
return self.empirical_ntk(x_i, x_j, params)
return [self.empirical_ntk(x_i, x_j, params)]

0 comments on commit 226cea1

Please sign in to comment.