Skip to content

Commit

Permalink
Adapt the trace opt to more flexible ntk calculation.
Browse files Browse the repository at this point in the history
This allows for subsampling the ntk.
  • Loading branch information
knikolaou committed May 17, 2024
1 parent 59b8bee commit fd80fc1
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions CI/unit_tests/optimizers/test_trace_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def test_apply_operation(self):
ntk = ntk_computation.compute_ntk(
{"params": model.model_state.params}, data.train_ds["inputs"]
)
ntk = np.array(ntk).mean(axis=0)
expected_lr = scale_factor / np.trace(ntk)

# Compute actual values
Expand Down
1 change: 1 addition & 0 deletions znnl/optimizers/trace_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def apply_optimizer(
if epoch % self.rescale_interval == 0:
# Compute the ntk trace.
ntk = ntk_fn({"params": model_state.params}, data_set)
ntk = np.array(ntk).mean(axis=0)
trace = np.trace(ntk)

# Create the new optimizer.
Expand Down

0 comments on commit fd80fc1

Please sign in to comment.