diff --git a/keras_lr_finder/lr_finder.py b/keras_lr_finder/lr_finder.py index 8c1cf25..6840fe0 100644 --- a/keras_lr_finder/lr_finder.py +++ b/keras_lr_finder/lr_finder.py @@ -26,7 +26,7 @@ def on_batch_end(self, batch, logs): self.losses.append(loss) # Check whether the loss got too large or NaN - if math.isnan(loss) or loss > self.best_loss * 4: + if lr >= self.end_lr or math.isnan(loss) or loss > self.best_loss * 4: self.model.stop_training = True return @@ -40,6 +40,7 @@ def on_batch_end(self, batch, logs): def find(self, x_train, y_train, start_lr, end_lr, batch_size=64, epochs=1): num_batches = epochs * x_train.shape[0] / batch_size self.lr_mult = (end_lr / start_lr) ** (1 / num_batches) + self.end_lr = end_lr # Save weights into a file self.model.save_weights('tmp.h5')