Skip to content

Commit

Permalink
make loss derivative computation work and include in example
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 29, 2024
1 parent e64ffb2 commit 2630779
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
13 changes: 8 additions & 5 deletions examples/Computing-Collective-Variables.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@
},
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n",
"# import os\n",
"# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n",
"\n",
"import znnl as nl\n",
"from neural_tangents import stax\n",
"import optax\n",
"\n",
"from papyrus.measurements import (\n",
" Loss, Accuracy, NTKTrace, NTKEntropy, NTK, NTKSelfEntropy, NTKEigenvalues\n",
" Loss, Accuracy, NTKTrace, NTKEntropy, NTK, NTKSelfEntropy, NTKEigenvalues, LossDerivative,\n",
")\n",
"\n",
"import matplotlib.pyplot as plt\n",
Expand Down Expand Up @@ -168,7 +168,9 @@
"ntk_computation = nl.ntk_computation.JAXNTKComputation(\n",
" apply_fn=fuel_model.ntk_apply_fn, \n",
" batch_size=314,\n",
")"
")\n",
"\n",
"loss_derivative_computation = nl.analysis.LossDerivative(loss_fn=nl.loss_functions.LPNormLoss(order=2))"
]
},
{
Expand Down Expand Up @@ -212,6 +214,7 @@
" NTK(name=\"ntk\"),\n",
" NTKSelfEntropy(name=\"ntk_self_entropy\"),\n",
" NTKEigenvalues(name=\"ntk_eigenvalues\"),\n",
" LossDerivative(name=\"loss_derivative\", apply_fn=loss_derivative_computation.calculate),\n",
" ],\n",
" storage_path=\".\",\n",
" update_rate=1, \n",
Expand Down Expand Up @@ -399,7 +402,7 @@
"calculate_l_pq_norm = nl.utils.matrix_utils.calculate_l_pq_norm\n",
"\n",
"l_pq_norms = np.array([\n",
" calculate_l_pq_norm(i) for i in train_report.loss_derivative\n",
" calculate_l_pq_norm(i) for i in train_report[\"loss_derivative\"]\n",
"])\n",
"\n",
"plt.plot(\n",
Expand Down
6 changes: 3 additions & 3 deletions znnl/ntk_computation/jax_ntk_classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from typing import Callable, List, Optional

import jax.numpy as np
import jax.tree as jt
import neural_tangents as nt
from jax import random, vmap
from jax.tree_util import tree_map as jmap

from znnl.ntk_computation.jax_ntk import JAXNTKComputation

Expand Down Expand Up @@ -178,7 +178,7 @@ def _subsample_data(self, x: np.ndarray, sample_indices: dict) -> np.ndarray:
np.ndarray
The subsampled data.
"""
return jt.map(lambda indices: np.take(x, indices, axis=0), sample_indices)
return jmap(lambda indices: np.take(x, indices, axis=0), sample_indices)

def _compute_ntk(self, params: dict, x_i: np.ndarray) -> np.ndarray:
"""
Expand Down Expand Up @@ -226,7 +226,7 @@ def compute_ntk(self, params: dict, dataset: dict) -> List[np.ndarray]:

x_i = self._subsample_data(dataset[self.data_keys[0]], self._sample_indices)

ntks = jt.map(lambda x_i: self._compute_ntk(params, x_i), x_i)
ntks = jmap(lambda x_i: self._compute_ntk(params, x_i), x_i)

ntks = list(ntks.values())

Expand Down
4 changes: 2 additions & 2 deletions znnl/ntk_computation/jax_ntk_subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from typing import Callable, List, Optional

import jax.numpy as np
import jax.tree as jt
import neural_tangents as nt
from jax import random
from jax.tree_util import tree_map as jmap

from znnl.ntk_computation.jax_ntk import JAXNTKComputation

Expand Down Expand Up @@ -224,6 +224,6 @@ def compute_ntk(

x_j = self._subsample_data(x_j) if x_j is not None else [None] * self.n_parts

ntks = jt.map(lambda x_i, x_j: self._compute_ntk(params, x_i, x_j), x_i, x_j)
ntks = jmap(lambda x_i, x_j: self._compute_ntk(params, x_i, x_j), x_i, x_j)

return ntks

0 comments on commit 2630779

Please sign in to comment.