Skip to content

Commit

Permalink
fix bug in stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Jul 26, 2024
1 parent 6498f0f commit 5721909
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions n3fit/src/n3fit/stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
which will tell `Validation` that no validation set was found and that the training is to
be used instead.
"""

import logging

import numpy as np
Expand Down Expand Up @@ -345,6 +346,8 @@ def __init__(
self._threshold_chi2 = threshold_chi2
self._stopping_degrees = np.zeros(self._n_replicas, dtype=int)
self._counts = np.zeros(self._n_replicas, dtype=int)
# Keep track of the replicas that should not be stopped yet
self._dont_stop_me_now = np.ones(self._n_replicas, dtype=bool)

self._dont_stop = dont_stop
self._stop_now = False
Expand Down Expand Up @@ -451,6 +454,8 @@ def monitor_chi2(self, training_info, epoch, print_stats=False):
passes &= fitstate.vl_loss < self._best_val_chi2s
# And the ones that pass positivity
passes &= self._positivity(fitstate)
# Stop replicas that are ok being stopped (because they are finished or otherwise)
passes &= self._dont_stop_me_now

self._stopping_degrees += self._counts

Expand All @@ -470,6 +475,7 @@ def monitor_chi2(self, training_info, epoch, print_stats=False):
for i_replica in np.where(stop_replicas)[0]:
self._stop_epochs[i_replica] = epoch
self._counts[i_replica] = 0
self._dont_stop_me_now[i_replica] = False

# By using the stopping degree we only stop when none of the replicas are improving anymore
if min(self._stopping_degrees) > self.stopping_patience:
Expand Down

0 comments on commit 5721909

Please sign in to comment.