Skip to content

Commit

Permalink
Include the loss NTK, its eigenvalues, and its entropy in the JAX recorder
Browse files Browse the repository at this point in the history
  • Loading branch information
ma-sauter committed Jan 19, 2024
1 parent 768fc13 commit c802010
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 17 deletions.
2 changes: 2 additions & 0 deletions znnl/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
from znnl.analysis.eigensystem import EigenSpaceAnalysis
from znnl.analysis.entropy import EntropyAnalysis
from znnl.analysis.loss_fn_derivative import LossDerivative
from znnl.analysis.loss_ntk_calculation import loss_ntk_calculation

# Public names of the analysis package, taken from the imported objects
# themselves so the list cannot drift from the actual identifiers.
__all__ = [
    exported.__name__
    for exported in (
        EntropyAnalysis,
        EigenSpaceAnalysis,
        LossDerivative,
        loss_ntk_calculation,
    )
]
9 changes: 3 additions & 6 deletions znnl/analysis/loss_ntk_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,14 @@ def __init__(
self,
metric_fn: Callable,
model: JaxModel,
ntk_batch_size: int = 10,
store_on_device: bool = True,
trace_axes: Union[int, Sequence[int]] = (-1,),
):
"""Constructor for the loss ntk calculation class."""

# Set the attributes
self.metric_fn = metric_fn
self.ntk_batch_size = ntk_batch_size
self.store_on_device = store_on_device
self.trace_axes = trace_axes
self.ntk_batch_size = model.ntk_batch_size
self.store_on_device = model.store_on_device
self.trace_axes = model.trace_axes

# Set the loss ntk function
_function_for_loss_ntk = lambda x, y: self._function_for_loss_ntk_helper(
Expand Down
101 changes: 90 additions & 11 deletions znnl/training_recording/jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from znnl.analysis.eigensystem import EigenSpaceAnalysis
from znnl.analysis.entropy import EntropyAnalysis
from znnl.analysis.loss_fn_derivative import LossDerivative
from znnl.analysis.loss_ntk_calculation import loss_ntk_calculation
from znnl.loss_functions import SimpleLoss
from znnl.models.jax_model import JaxModel
from znnl.training_recording.data_storage import DataStorage
Expand Down Expand Up @@ -157,6 +158,18 @@ class JaxRecorder:
loss_derivative: bool = False
_loss_derivative_array: list = None

# Loss NTK
loss_ntk: bool = False
_loss_ntk_array: list = None

# Loss NTK eigenvalues
loss_ntk_eigenvalues: bool = False
_loss_ntk_eigenvalues_array: list = None

# Loss NTK entropy
loss_ntk_entropy: bool = False
_loss_ntk_entropy_array: list = None

# Class helpers
update_rate: int = 1
_loss_fn: SimpleLoss = None
Expand All @@ -165,6 +178,9 @@ class JaxRecorder:
_model: JaxModel = None
_data_set: dict = None
_compute_ntk: bool = False # Helps to know if we can compute it once and share.
_compute_loss_ntk: bool = (
False # Helps to know if we can compute it once and share.
)
_compute_loss_derivative: bool = False
_loss_derivative_fn: LossDerivative = False
_index_count: int = 0 # Helps to avoid problems with non-1 update rates.
Expand Down Expand Up @@ -254,19 +270,35 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False):
self._index_count = 0

# Check if we need an NTK computation and update the class accordingly
if any([
"ntk" in self._selected_properties,
"covariance_ntk" in self._selected_properties,
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
]):
if any(
[
"ntk" in self._selected_properties,
"covariance_ntk" in self._selected_properties,
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
]
):
self._compute_ntk = True

# Check if we need a loss NTK computation and update the class accordingly

if any(
[
"loss_ntk" in self._selected_properties,
"loss_ntk_eigenvalues" in self._selected_properties,
"loss_ntk_entropy" in self._selected_properties,
]
):
self._compute_loss_ntk = True
self._loss_ntk_calculator = loss_ntk_calculation(
metric_fn=self._loss_fn.metric, model=self._model
)

if "loss_derivative" in self._selected_properties:
self._loss_derivative_fn = LossDerivative(self._loss_fn)

Expand Down Expand Up @@ -321,6 +353,14 @@ def update_recorder(self, epoch: int, model: JaxModel):
self.eigenvalues = False
self._read_selected_attributes()

# Compute loss ntk here to avoid repeated computation.
if self._compute_loss_ntk:
parsed_data["loss_ntk"] = self._loss_ntk_calculator.compute_loss_ntk(
x_i=self._data_set,
model=self._model,
infinite=False, # Set true to compute infinite width limit of loss ntk
)

for item in self._selected_properties:
call_fn = getattr(self, f"_update_{item}") # get the callable function

Expand Down Expand Up @@ -587,6 +627,45 @@ def _update_loss_derivative(self, parsed_data):
loss_derivative = calculate_l_pq_norm(vector_loss_derivative)
self._loss_derivative_array.append(loss_derivative)

def _update_loss_ntk(self, parsed_data):
"""
Update the loss ntk array.
Parameters
----------
parsed_data : dict
Data computed before the update to prevent repeated calculations.
"""
self._loss_ntk_array.append(parsed_data["loss_ntk"])

def _update_loss_ntk_eigenvalues(self, parsed_data):
    """
    Append the eigenvalues of the current loss NTK to their recording array.

    Parameters
    ----------
    parsed_data : dict
        Data computed before the update to prevent repeated calculations.
        Only the "loss_ntk" entry is read here.
    """
    loss_ntk = parsed_data["loss_ntk"]
    # Eigenvalues are requested with normalize=False.
    spectrum = EigenSpaceAnalysis(matrix=loss_ntk).compute_eigenvalues(
        normalize=False
    )
    self._loss_ntk_eigenvalues_array.append(spectrum)

def _update_loss_ntk_entropy(self, parsed_data):
    """
    Append the von Neumann entropy of the current loss NTK to its recording array.

    Parameters
    ----------
    parsed_data : dict
        Data computed before the update to prevent repeated calculations.
        Only the "loss_ntk" entry is read here.
    """
    loss_ntk = parsed_data["loss_ntk"]
    entropy_analysis = EntropyAnalysis(matrix=loss_ntk)
    # Entropy is requested with effective=False and normalize_eig=True.
    von_neumann_entropy = entropy_analysis.compute_von_neumann_entropy(
        effective=False,
        normalize_eig=True,
    )
    self._loss_ntk_entropy_array.append(von_neumann_entropy)

def gather_recording(self, selected_properties: list = None) -> dataclass:
"""
Export a dataclass of used properties.
Expand Down

0 comments on commit c802010

Please sign in to comment.