diff --git a/znnl/analysis/__init__.py b/znnl/analysis/__init__.py index 94701c5..09105e9 100644 --- a/znnl/analysis/__init__.py +++ b/znnl/analysis/__init__.py @@ -28,9 +28,11 @@ from znnl.analysis.eigensystem import EigenSpaceAnalysis from znnl.analysis.entropy import EntropyAnalysis from znnl.analysis.loss_fn_derivative import LossDerivative +from znnl.analysis.loss_ntk_calculation import loss_ntk_calculation __all__ = [ EntropyAnalysis.__name__, EigenSpaceAnalysis.__name__, LossDerivative.__name__, + loss_ntk_calculation.__name__, ] diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index f6c02c7..0fc9d7d 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -37,17 +37,14 @@ def __init__( self, metric_fn: Callable, model: JaxModel, - ntk_batch_size: int = 10, - store_on_device: bool = True, - trace_axes: Union[int, Sequence[int]] = (-1,), ): """Constructor for the loss ntk calculation class.""" # Set the attributes self.metric_fn = metric_fn - self.ntk_batch_size = ntk_batch_size - self.store_on_device = store_on_device - self.trace_axes = trace_axes + self.ntk_batch_size = model.ntk_batch_size + self.store_on_device = model.store_on_device + self.trace_axes = model.trace_axes # Set the loss ntk function _function_for_loss_ntk = lambda x, y: self._function_for_loss_ntk_helper( diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 862b7d4..08c5ea9 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -37,6 +37,7 @@ from znnl.analysis.eigensystem import EigenSpaceAnalysis from znnl.analysis.entropy import EntropyAnalysis from znnl.analysis.loss_fn_derivative import LossDerivative +from znnl.analysis.loss_ntk_calculation import loss_ntk_calculation from znnl.loss_functions import SimpleLoss from znnl.models.jax_model import JaxModel from znnl.training_recording.data_storage import DataStorage @@ -157,6 +158,18 @@ class JaxRecorder: loss_derivative: bool = False _loss_derivative_array: list = None + # Loss NTK + loss_ntk: bool = False + _loss_ntk_array: list = None + + # Loss NTK eigenvalues + loss_ntk_eigenvalues: bool = False + _loss_ntk_eigenvalues_array: list = None + + # Loss NTK entropy + loss_ntk_entropy: bool = False + _loss_ntk_entropy_array: list = None + # Class helpers update_rate: int = 1 _loss_fn: SimpleLoss = None @@ -165,6 +178,9 @@ class JaxRecorder: _model: JaxModel = None _data_set: dict = None _compute_ntk: bool = False # Helps to know if we can compute it once and share. + _compute_loss_ntk: bool = ( + False # Helps to know if we can compute it once and share. + ) _compute_loss_derivative: bool = False _loss_derivative_fn: LossDerivative = False _index_count: int = 0 # Helps to avoid problems with non-1 update rates. @@ -254,19 +270,35 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._index_count = 0 # Check if we need an NTK computation and update the class accordingly - if any([ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ]): + if any( + [ + "ntk" in self._selected_properties, + "covariance_ntk" in self._selected_properties, + "magnitude_ntk" in self._selected_properties, + "entropy" in self._selected_properties, + "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, + "covariance_entropy" in self._selected_properties, + "eigenvalues" in self._selected_properties, + "trace" in self._selected_properties, + ] + ): self._compute_ntk = True + # Check if we need a loss NTK computation and update the class accordingly + + if any( + [ + "loss_ntk" in self._selected_properties, + "loss_ntk_eigenvalues" in self._selected_properties, + "loss_ntk_entropy" in self._selected_properties, + ] + ): + self._compute_loss_ntk = True + self._loss_ntk_calculator = loss_ntk_calculation( + metric_fn=self._loss_fn.metric, model=self._model + ) + if "loss_derivative" in self._selected_properties: self._loss_derivative_fn = LossDerivative(self._loss_fn) @@ -321,6 +353,14 @@ def update_recorder(self, epoch: int, model: JaxModel): self.eigenvalues = False self._read_selected_attributes() + # Compute loss ntk here to avoid repeated computation. + if self._compute_loss_ntk: + parsed_data["loss_ntk"] = self._loss_ntk_calculator.compute_loss_ntk( + x_i=self._data_set, + model=self._model, + infinite=False, # Set true to compute infinite width limit of loss ntk + ) + for item in self._selected_properties: call_fn = getattr(self, f"_update_{item}") # get the callable function @@ -587,6 +627,45 @@ def _update_loss_derivative(self, parsed_data): loss_derivative = calculate_l_pq_norm(vector_loss_derivative) self._loss_derivative_array.append(loss_derivative) + def _update_loss_ntk(self, parsed_data): + """ + Update the loss ntk array. + + Parameters + ---------- + parsed_data : dict + Data computed before the update to prevent repeated calculations. + """ + self._loss_ntk_array.append(parsed_data["loss_ntk"]) + + def _update_loss_ntk_eigenvalues(self, parsed_data): + """ + Update the loss ntk eigenvalue array. + + Parameters + ---------- + parsed_data : dict + Data computed before the update to prevent repeated calculations. + """ + calculator = EigenSpaceAnalysis(matrix=parsed_data["loss_ntk"]) + eigenvalues = calculator.compute_eigenvalues(normalize=False) + self._loss_ntk_eigenvalues_array.append(eigenvalues) + + def _update_loss_ntk_entropy(self, parsed_data): + """ + Update the loss ntk entropy array. + + Parameters + ---------- + parsed_data : dict + Data computed before the update to prevent repeated calculations. + """ + calculator = EntropyAnalysis(matrix=parsed_data["loss_ntk"]) + entropy = calculator.compute_von_neumann_entropy( + effective=False, normalize_eig=True + ) + self._loss_ntk_entropy_array.append(entropy) + def gather_recording(self, selected_properties: list = None) -> dataclass: """ Export a dataclass of used properties.