From 9d0adc965d19bf50150cb1e94e1bf94fab75582b Mon Sep 17 00:00:00 2001
From: "Guillaume \"Vermeille\" Sanchez"
Date: Sat, 9 Jun 2018 17:34:17 +0200
Subject: [PATCH] Stop when we reach end_lr even if the loss did not diverge

Stop when we reach end_lr even if the loss did not diverge. In some
cases, the loss will not increase and diverge, so make sure we stop at
end_lr (as the user asked anyway)
---
 keras_lr_finder/lr_finder.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/keras_lr_finder/lr_finder.py b/keras_lr_finder/lr_finder.py
index 8c1cf25..6840fe0 100644
--- a/keras_lr_finder/lr_finder.py
+++ b/keras_lr_finder/lr_finder.py
@@ -26,7 +26,7 @@ def on_batch_end(self, batch, logs):
         self.losses.append(loss)
 
         # Check whether the loss got too large or NaN
-        if math.isnan(loss) or loss > self.best_loss * 4:
+        if lr >= self.end_lr or math.isnan(loss) or loss > self.best_loss * 4:
             self.model.stop_training = True
             return
 
@@ -40,6 +40,7 @@ def on_batch_end(self, batch, logs):
     def find(self, x_train, y_train, start_lr, end_lr, batch_size=64, epochs=1):
         num_batches = epochs * x_train.shape[0] / batch_size
         self.lr_mult = (end_lr / start_lr) ** (1 / num_batches)
+        self.end_lr = end_lr
 
         # Save weights into a file
         self.model.save_weights('tmp.h5')
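
For illustration, below is a minimal, self-contained sketch of the stopping rule
after this patch, written against tf.keras. It is not the library's code: the
class name LRFinderSketch, the constructor signature, and the tf.keras import
path are assumptions made for the example; only the exponential LR schedule and
the `lr >= self.end_lr or math.isnan(loss) or loss > self.best_loss * 4` check
mirror the patched lr_finder.py.

import math

from tensorflow import keras
from tensorflow.keras import backend as K


class LRFinderSketch(keras.callbacks.Callback):
    """Hypothetical sketch of an LR-range-test callback with the patched stop rule."""

    def __init__(self, start_lr, end_lr, num_batches):
        super().__init__()
        self.start_lr = start_lr
        self.end_lr = end_lr
        # Exponential sweep: multiply the LR by a constant factor every batch
        # so that it goes from start_lr to end_lr over num_batches batches.
        self.lr_mult = (end_lr / start_lr) ** (1.0 / num_batches)
        self.best_loss = math.inf
        self.lrs, self.losses = [], []

    def on_train_begin(self, logs=None):
        K.set_value(self.model.optimizer.lr, self.start_lr)

    def on_batch_end(self, batch, logs=None):
        lr = float(K.get_value(self.model.optimizer.lr))
        loss = logs['loss']
        self.lrs.append(lr)
        self.losses.append(loss)

        # Stop when the sweep reaches end_lr (the condition added by this
        # patch) or when the loss turns NaN / blows up (the original checks).
        if lr >= self.end_lr or math.isnan(loss) or loss > self.best_loss * 4:
            self.model.stop_training = True
            return

        self.best_loss = min(self.best_loss, loss)

        # Increase the learning rate for the next batch.
        K.set_value(self.model.optimizer.lr, lr * self.lr_mult)


# Hypothetical usage: sweep the LR over one epoch of mini-batches.
# finder = LRFinderSketch(start_lr=1e-6, end_lr=1.0, num_batches=len(x) // 64)
# model.fit(x, y, batch_size=64, epochs=1, callbacks=[finder])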