diff --git a/README.md b/README.md
index 82f7663..b4a151f 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
 # keras_lr_finder
 Plots the change of the loss function of a Keras model when the learning rate is exponentially increasing.
+It will also estimate the best learning rate.
 
 ## Purpose
 See details in ["Estimating an Optimal Learning Rate For a Deep Neural Network"](https://towardsdatascience.com/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0).
@@ -32,6 +33,42 @@ lr_finder.plot_loss_change(sma=20, n_skip_beginning=20, n_skip_end=5, y_lim=(-0.
 
 ![Rate of change of the loss function](https://cdn-images-1.medium.com/max/1600/1*87mKq_XomYyJE29l91K0dw.png)
 
+Once the finder has picked your best learning rate, update your model to use it:
+```python
+# Set the learning rate of your model to the newly found one
+import keras.backend as K
+new_lr = lr_finder.get_best_lr(sma=4)
+K.set_value(model.optimizer.lr, new_lr)
+```
+You can wrap this up nicely in a `LambdaCallback`, so that your learning rate is updated periodically:
+
+```python
+from keras.callbacks import LambdaCallback
+def find_lr(epoch, logs):
+    # You could make this more efficient by only running it
+    # when the loss has stopped improving, a la ReduceLROnPlateau
+    if epoch % 30 == 0:
+        lrf = LRFinder(model)
+        lrf.find(x_train, y_train, start_lr=0.0001, end_lr=1, batch_size=512, epochs=5)
+        new_lr = lrf.get_best_lr(sma=4)
+        K.set_value(model.optimizer.lr, new_lr)
+
+lcb = LambdaCallback(on_epoch_end=find_lr)
+model.fit(..., callbacks=[lcb])
+```
+
+### Use with a generator
+
+This library can also be used with generators (where `num_samples` is the total number of samples in your training set):
+
+```python
+lrf = LRFinder(model)
+lrf.find_generator(train_gen,
+                   start_lr=0.0001,
+                   end_lr=1,
+                   epochs=5,
+                   steps_per_epoch=num_samples // batch_size)
+```
 
 ## Contributions
 
 Contributions are welcome. Please, file issues and submit pull requests on GitHub, or contact me directly.
diff --git a/keras_lr_finder/lr_finder.py b/keras_lr_finder/lr_finder.py
index 558e82e..e13f396 100644
--- a/keras_lr_finder/lr_finder.py
+++ b/keras_lr_finder/lr_finder.py
@@ -2,6 +2,7 @@
 import math
 from keras.callbacks import LambdaCallback
 import keras.backend as K
+import numpy as np
 
 
 class LRFinder:
@@ -26,7 +27,7 @@ def on_batch_end(self, batch, logs):
         self.losses.append(loss)
 
         # Check whether the loss got too large or NaN
-        if math.isnan(loss) or loss > self.best_loss * 4:
+        if batch > 5 and (math.isnan(loss) or loss > self.best_loss * 4):
             self.model.stop_training = True
             return
 
@@ -39,7 +40,7 @@ def on_batch_end(self, batch, logs):
 
     def find(self, x_train, y_train, start_lr, end_lr, batch_size=64, epochs=1):
         num_batches = epochs * x_train.shape[0] / batch_size
-        self.lr_mult = (float(end_lr) / float(start_lr)) ** (float(1) / float(num_batches))
+        self.lr_mult = (float(end_lr) / float(start_lr)) ** (1.0 / float(num_batches))
 
         # Save weights into a file
         self.model.save_weights('tmp.h5')
@@ -72,7 +73,7 @@ def find_generator(self, generator, start_lr, end_lr, epochs=1, steps_per_epoch=
                         '`keras.utils.Sequence`'
                         ' class. Please specify `steps_per_epoch` '
                         'or use the `keras.utils.Sequence` class.')
-        self.lr_mult = (float(end_lr) / float(start_lr)) ** (float(1) / float(steps_per_epoch))
+        self.lr_mult = (float(end_lr) / float(start_lr)) ** (1.0 / (float(steps_per_epoch) * float(epochs)))
 
         # Save weights into a file
         self.model.save_weights('tmp.h5')
@@ -119,14 +120,22 @@ def plot_loss_change(self, sma=1, n_skip_beginning=10, n_skip_end=5, y_lim=(-0.0
             n_skip_end - number of batches to skip on the right.
             y_lim - limits for the y axis.
         """
-        assert sma >= 1
-        derivatives = [0] * sma
-        for i in range(sma, len(self.lrs)):
-            derivative = (self.losses[i] - self.losses[i - sma]) / sma
-            derivatives.append(derivative)
-
+        derivatives = self.get_derivatives(sma)[n_skip_beginning:-n_skip_end]
+        lrs = self.lrs[n_skip_beginning:-n_skip_end]
         plt.ylabel("rate of loss change")
         plt.xlabel("learning rate (log scale)")
-        plt.plot(self.lrs[n_skip_beginning:-n_skip_end], derivatives[n_skip_beginning:-n_skip_end])
+        plt.plot(lrs, derivatives)
         plt.xscale('log')
         plt.ylim(y_lim)
+
+    def get_derivatives(self, sma):
+        assert sma >= 1
+        derivatives = [0] * sma
+        for i in range(sma, len(self.lrs)):
+            derivatives.append((self.losses[i] - self.losses[i - sma]) / sma)
+        return derivatives
+
+    def get_best_lr(self, sma, n_skip_beginning=10, n_skip_end=5):
+        derivatives = self.get_derivatives(sma)
+        best_der_idx = np.argmax(derivatives[n_skip_beginning:-n_skip_end])
+        return self.lrs[n_skip_beginning:-n_skip_end][best_der_idx]
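
As a reviewer-side note, here is a small standalone sketch (not part of the patch; the sweep settings and loss values are made up) that checks the patched `find_generator` multiplier spans the full learning-rate range and mirrors the moving-average difference used by the new `get_derivatives` helper:

```python
import numpy as np

# Hypothetical sweep settings, chosen only for illustration.
start_lr, end_lr = 0.0001, 1.0
steps_per_epoch, epochs = 100, 5
num_batches = steps_per_epoch * epochs

# Patched formula: one constant per-batch multiplier so the learning rate
# grows exponentially from start_lr to end_lr over the whole sweep.
lr_mult = (end_lr / start_lr) ** (1.0 / num_batches)
print(start_lr * lr_mult ** num_batches)  # ~1.0, i.e. end_lr after the last batch

# Same smoothing as get_derivatives: a simple-moving-average difference of
# the recorded losses (toy loss values here, just to show the mechanics).
losses = np.cos(np.linspace(0.0, np.pi, num_batches)) + 1.0
sma = 4
derivatives = [0] * sma
for i in range(sma, num_batches):
    derivatives.append((losses[i] - losses[i - sma]) / sma)
print(len(derivatives) == num_batches)  # True: one smoothed slope per recorded batch
```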