Skip to content

Commit

Permalink
Working on calculating loss derivatives for the loss NTK comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
ma-sauter committed Jan 30, 2024
1 parent e141dd8 commit 22cfc90
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 13 deletions.
65 changes: 54 additions & 11 deletions CI/unit_tests/analysis/test_loss_ntk_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,41 @@
import pytest

from znnl.analysis import loss_ntk_calculation
from znnl.distance_metrics import LPNorm
from znnl.training_recording import JaxRecorder
from znnl.models import NTModel
from znnl.loss_functions import LPNormLoss
from znnl.analysis import LossDerivative
from znnl.models import FlaxModel
from znnl.data import MNISTGenerator
from flax import linen as nn

from neural_tangents import stax

import optax


# Defines a simple CNN module
class ProductionModule(nn.Module):
    """
    Small convolutional network used as the model under test.

    The forward pass runs two conv -> ReLU -> max-pool stages, flattens
    everything except the leading axis, and applies a two-layer dense head
    whose final layer has 10 features.
    """

    @nn.compact
    def __call__(self, x):
        # Two identical feature-extraction stages; flax auto-names the
        # submodules in creation order, so the loop yields the same
        # parameter tree as writing the stages out by hand.
        for _ in range(2):
            x = nn.Conv(features=128, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))

        # Flatten: keep the leading axis, merge all remaining axes.
        leading = x.shape[0]
        x = x.reshape((leading, -1))

        hidden = nn.relu(nn.Dense(features=300)(x))
        return nn.Dense(10)(hidden)


class TestLossNTKCalculation:
"""
Test Suite for the loss NTK calculation module.
Expand All @@ -59,14 +86,12 @@ def test_loss_ntk_calculation(self):
)

# Define a test model
fuel_model = NTModel(
nt_module=dense_network,
optimizer=optax.adam(learning_rate=0.005),
input_shape=(9,),
production_model = FlaxModel(
flax_module=ProductionModule(),
optimizer=optax.adam(learning_rate=0.01),
input_shape=(1, 28, 28, 1),
trace_axes=(),
batch_size=314,
)

# Initialize model parameters

data_generator = MNISTGenerator(ds_size=10)
Expand All @@ -77,14 +102,32 @@ def test_loss_ntk_calculation(self):

# Initialize the loss NTK calculation
loss_ntk_calculator = loss_ntk_calculation(
metric_fn=lambda x, y: (x - y) ** 2,
model=fuel_model,
metric_fn=LPNorm(order=2),
model=production_model,
dataset=data_set,
)

# Compute the loss NTK
ntk = loss_ntk_calculator.compute_loss_ntk(x_i=data_set, model=fuel_model)
print(ntk.shape)
loss_ntk = loss_ntk_calculator.compute_loss_ntk(
x_i=data_set, model=production_model
)["empirical"]

# Now for comparison calculate regular ntk
ntk = production_model.compute_ntk(data_set["inputs"], infinite=False)[
"empirical"
]
# Calculate Loss derivative fn
loss_derivative_calculator = LossDerivative(LPNormLoss(order=2))
# predictions calculation analogous to the one in jax recording
predictions = production_model(data_set["inputs"])
if type(predictions) is tuple:
predictions = predictions[0]
# calculation of loss derivatives
loss_derivatives = loss_derivative_calculator.calculate(
predictions=predictions,
targets=data_set["targets"],
)
print(loss_derivatives.shape)


# Run the test only when this file is executed directly as a script.
# Without the guard, the test would execute as a side effect of merely
# importing the module (e.g. during pytest collection).
if __name__ == "__main__":
    TestLossNTKCalculation().test_loss_ntk_calculation()
10 changes: 8 additions & 2 deletions znnl/analysis/loss_ntk_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import jax.numpy as np
import numpy


class loss_ntk_calculation:
def __init__(
self,
Expand Down Expand Up @@ -75,8 +76,13 @@ def _function_for_loss_ntk(self, params, datapoint) -> float:
Seems like during the NTK calculation, this function needs to handle
the whole dataset at once instead of just one datapoint.
"""
_input = datapoint[:, : self.input_dimension]
_target = datapoint[:, self.input_dimension :]
batch_length = datapoint.shape[0]
_input = datapoint[:, : self.input_dimension].reshape(
batch_length, *self.input_shape[1:]
)
_target = datapoint[:, self.input_dimension :].reshape(
batch_length, *self.target_shape[1:]
)
return self.metric_fn(
self.apply_fn(params, _input),
_target,
Expand Down

0 comments on commit 22cfc90

Please sign in to comment.