Commit
Implement regularization schedules.
A schedule is a function that depends on the current epoch
and rescales the regularization factor.
This function can also be defined by the user.
knikolaou committed Oct 3, 2023
1 parent 9e7e073 commit d975c1f
Showing 9 changed files with 200 additions and 184 deletions.
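As a concrete illustration of the commit message, a minimal user-defined schedule could look like the sketch below (the function name and decay rate are arbitrary, not part of the commit; the hook it plugs into is the reg_schedule_fn argument introduced in the diffs that follow):

    # A schedule maps (epoch, reg_factor) to a rescaled regularization factor.
    def my_schedule(epoch: int, reg_factor: float) -> float:
        # Hypothetical choice: exponentially decay the regularization strength.
        return reg_factor * 0.99**epoch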
144 changes: 69 additions & 75 deletions examples/trace_regularization.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions znnl/regularizers/__init__.py
@@ -24,14 +24,12 @@
 Summary
 -------
 """
-from znnl.regularizers.regularizer import Regularizer
 from znnl.regularizers.norm_regularizer import NormRegularizer
+from znnl.regularizers.regularizer import Regularizer
 from znnl.regularizers.trace_regularizer import TraceRegularizer
-from znnl.regularizers.grad_variance_regularizer import GradVarianceRegularizer
 
 __all__ = [
     Regularizer.__name__,
     NormRegularizer.__name__,
     TraceRegularizer.__name__,
-    GradVarianceRegularizer.__name__,
 ]
79 changes: 0 additions & 79 deletions znnl/regularizers/grad_variance_regularizer.py

This file was deleted.

29 changes: 16 additions & 13 deletions znnl/regularizers/norm_regularizer.py
@@ -24,27 +24,32 @@
 Summary
 -------
 """
-from znnl.regularizers.regularizer import Regularizer
+from functools import partial
 from typing import Callable, Optional
+
 import jax.flatten_util
-import jax.tree_util
 import jax.numpy as np
+import jax.tree_util
 from jax import jit
-from functools import partial
+
+from znnl.regularizers.regularizer import Regularizer
 
 
 class NormRegularizer(Regularizer):
     """
-    Class to regularize on the norm of the parameters.
+    Regularizing training using the norm of the parameters.
     Any function can be used as norm, as long as it takes the parameters as input
     and returns a scalar.
     The function is applied to each parameter.
     """
 
     def __init__(
-        self, reg_factor: float = 1e-2, norm_fn: Optional[Callable] = None
+        self,
+        reg_factor: float = 1e-2,
+        reg_schedule_fn: Optional[Callable] = None,
+        norm_fn: Optional[Callable] = None,
     ) -> None:
         """
         Constructor of the regularizer class.
@@ -57,31 +62,29 @@ def __init__(
             Function to compute the norm of the parameters.
             If None, the default norm is the mean squared error.
         """
-        super().__init__(reg_factor)
+        super().__init__(reg_factor, reg_schedule_fn)
 
         self.norm_fn = norm_fn
         if self.norm_fn is None:
             self.norm_fn = lambda x: np.mean(x**2)
-    def __call__(self, params: dict, **kwargs: dict) -> float:
+
+    def _calculate_regularization(self, params: dict, **kwargs: dict) -> float:
         """
-        Call function of the trace regularizer class.
+        Calculate the regularization contribution to the loss using the norm of the
+        parameters.
 
         Parameters
         ----------
         params : dict
             Parameters of the model.
         kwargs : dict
             Additional arguments.
             Individual regularizers can define their own arguments.
 
         Returns
         -------
         reg_loss : float
             Loss contribution from the regularizer.
         """
-
         param_vector = jax.flatten_util.ravel_pytree(params)[0]
         reg_loss = self.reg_factor * self.norm_fn(param_vector)
         return reg_loss
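A usage sketch of the constructor introduced by this diff (the L1 norm and the decay rate are arbitrary illustrative choices, not part of the commit):

    import jax.numpy as np

    from znnl.regularizers import NormRegularizer

    # Regularize on the mean absolute value of the flattened parameter vector,
    # with the factor decayed over epochs by a user-defined schedule.
    l1_reg = NormRegularizer(
        reg_factor=1e-2,
        reg_schedule_fn=lambda epoch, reg_factor: reg_factor * 0.99**epoch,
        norm_fn=lambda x: np.mean(np.abs(x)),
    )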

94 changes: 89 additions & 5 deletions znnl/regularizers/regularizer.py
@@ -24,40 +24,124 @@
 Summary
 -------
 """
+import logging
 from abc import ABC
+from typing import Callable, Optional
+
+logger = logging.getLogger(__name__)
 
 
 class Regularizer(ABC):
     """
     Parent class for a regularizer. All regularizers should inherit from this class.
     """
 
-    def __init__(self, reg_factor) -> None:
+    def __init__(
+        self, reg_factor: float, reg_schedule_fn: Optional[Callable] = None
+    ) -> None:
         """
         Constructor of the regularizer class.
 
         Parameters
         ----------
         reg_factor : float
             Regularization factor.
+        reg_schedule_fn : Optional[Callable]
+            Function to schedule the regularization factor.
+            The function takes the current epoch and the regularization factor
+            as input and returns the scheduled regularization factor (float).
+            An example function is:
+                def reg_schedule(epoch: int, reg_factor: float) -> float:
+                    return reg_factor * 0.99**epoch
+            where the regularization factor is reduced by 1% each epoch.
+            The default is None, which means no scheduling is applied:
+                def reg_schedule(epoch: int, reg_factor: float) -> float:
+                    return reg_factor
         """
         self.reg_factor = reg_factor
+        self.reg_schedule_fn = reg_schedule_fn
+
+        if self.reg_schedule_fn:
+            logger.info(
+                "Setting a regularization schedule. "
+                "The set regularization factor will be overwritten."
+            )
+            if not callable(self.reg_schedule_fn):
+                raise TypeError("Regularization schedule must be a Callable.")
+
+        if self.reg_schedule_fn is None:
+            self.reg_schedule_fn = self._schedule_fn_default
+
+    @staticmethod
+    def _schedule_fn_default(epoch: int, reg_factor: float) -> float:
+        """
+        Default function for the regularization factor.
+
+        Parameters
+        ----------
+        epoch : int
+            Current epoch.
+        reg_factor : float
+            Regularization factor.
+
+        Returns
+        -------
+        scheduled_reg_factor : float
+            Scheduled regularization factor.
+        """
+        return reg_factor
 
-    def __call__(self, params: dict, **kwargs: dict) -> float:
+    def _calculate_regularization(self, params: dict, **kwargs: dict) -> float:
         """
-        Call function of the regularizer class.
+        Calculate the regularization contribution to the loss.
 
         Parameters
         ----------
         params : dict
             Parameters of the model.
         kwargs : dict
             Additional arguments.
-            Individual regularizers can define their own arguments.
+            Individual regularizers can utilize arguments from the set:
+                apply_fn : Callable
+                    Function to apply the model to inputs.
+                batch : dict
+                    Batch of data.
+                epoch : int
+                    Current epoch.
 
         Returns
         -------
         reg_loss : float
             Loss contribution from the regularizer.
         """
         raise NotImplementedError
+
+    def __call__(
+        self, apply_fn: Callable, params: dict, batch: dict, epoch: int
+    ) -> float:
+        """
+        Call function of the regularizer class.
+
+        Parameters
+        ----------
+        apply_fn : Callable
+            Function to apply the model to inputs.
+        params : dict
+            Parameters of the model.
+        batch : dict
+            Batch of data.
+        epoch : int
+            Current epoch.
+
+        Returns
+        -------
+        scaled_reg_loss : float
+            Scaled loss contribution from the regularizer.
+        """
+        self.reg_factor = self.reg_schedule_fn(epoch, self.reg_factor)
+        return self.reg_factor * self._calculate_regularization(
+            apply_fn=apply_fn, params=params, batch=batch, epoch=epoch
+        )
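To summarize the control flow this introduces: a subclass implements only _calculate_regularization, while the base __call__ first applies the schedule to reg_factor and then scales the result. A self-contained sketch of a subclass (the class name and its norm are hypothetical, not part of the commit):

    import jax.flatten_util

    from znnl.regularizers import Regularizer


    class SquaredNormRegularizer(Regularizer):
        """Hypothetical subclass: penalize the squared L2 norm of the parameters."""

        def _calculate_regularization(self, params: dict, **kwargs: dict) -> float:
            # Flatten the parameter pytree into one vector and square-sum it.
            param_vector = jax.flatten_util.ravel_pytree(params)[0]
            return (param_vector**2).sum()


    reg = SquaredNormRegularizer(reg_factor=1e-2)
    # In a training step, a strategy would evaluate:
    #     reg_loss = reg(apply_fn=apply_fn, params=params, batch=batch, epoch=epoch)
    # which computes reg_schedule_fn(epoch, reg_factor) * _calculate_regularization(...).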
26 changes: 17 additions & 9 deletions znnl/regularizers/trace_regularizer.py
@@ -25,43 +25,51 @@
 -------
 Module containing the trace regularizer class.
 """
-from znnl.regularizers.regularizer import Regularizer
 from typing import Callable
 
 import jax.flatten_util
 import jax.tree_util
 
+from znnl.regularizers.regularizer import Regularizer
+
 
 class TraceRegularizer(Regularizer):
     """
     Trace regularizer class.
 
     Regularizing the loss of gradient based learning proportional to the trace of the
     NTK. As:
         Trace(NTK) = sum_i (d f(x_i)/d theta)^2
     the trace of the NTK is the sum of the squared gradients of the model, the trace
     regularizer is equivalent to regularizing on the sum of the squared gradients of
     the model.
     """
 
-    def __init__(self, reg_factor: float = 1e-1) -> None:
+    def __init__(
+        self, reg_factor: float = 1e-1, reg_schedule_fn: Callable = None
+    ) -> None:
         """
         Constructor of the trace regularizer class.
 
         Parameters
         ----------
         reg_factor : float
             Regularization factor.
+        reg_schedule_fn : Callable
+            Function to schedule the regularization factor.
         """
-        super().__init__(reg_factor)
+        super().__init__(reg_factor, reg_schedule_fn)
 
-    def __call__(self, apply_fn: Callable, params: dict, batch: dict) -> float:
+    def _calculate_regularization(
+        self, apply_fn: Callable, params: dict, batch: dict, epoch: int
+    ) -> float:
         """
         Call function of the trace regularizer class.
 
         Parameters
         ----------
         apply_fn : Callable
             Function to apply the model to inputs.
         params : dict
             Parameters of the model.
         batch : dict
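For intuition, a sketch of the quantity being penalized, written directly in jax. This helper is illustrative only (the actual implementation sits in the truncated part of the diff) and assumes an effectively scalar model output:

    import jax
    import jax.flatten_util


    def ntk_trace(apply_fn, params, x_batch):
        """Trace(NTK) as the sum over samples of the squared parameter gradients."""

        def scalar_output(p, x):
            # Reduce the model output to a scalar so jax.grad applies;
            # assumes a single-output model.
            return apply_fn(p, x).sum()

        def squared_grad_norm(x):
            # (d f(x)/d theta)^2 summed over all parameters, for one sample.
            grads = jax.grad(scalar_output)(params, x)
            flat = jax.flatten_util.ravel_pytree(grads)[0]
            return (flat**2).sum()

        return jax.vmap(squared_grad_norm)(x_batch).sum()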
2 changes: 2 additions & 0 deletions znnl/training_strategies/loss_aware_reservoir.py
@@ -418,6 +418,8 @@ def train_model(
         train_losses = []
         train_accuracy = []
         for i in loading_bar:
+            self.epoch = i
+
             # Update the recorder properties
             if self.recorders is not None:
                 for item in self.recorders:
2 changes: 2 additions & 0 deletions znnl/training_strategies/partitioned_training.py
@@ -278,6 +278,8 @@ def train_model(
             )
 
         for i in loading_bar:
+            self.epoch = i
+
             # Update the recorder properties
             if self.recorders is not None:
                 for item in self.recorders:
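Both training loops now record the running epoch on the strategy so the regularizer call can receive it. The call site itself is outside this diff, so the wiring below is an assumption (attribute and variable names hypothetical):

    # Hypothetical loss assembly inside a training strategy:
    reg_loss = self.regularizer(
        apply_fn=apply_fn,
        params=params,
        batch=batch,
        epoch=self.epoch,  # the loop index i stored above
    )
    loss = data_loss + reg_loss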