From c51fdd33381f60b839533f65002257e374ae723f Mon Sep 17 00:00:00 2001 From: Mikael Figueroa Date: Fri, 31 Jan 2020 11:24:13 -0800 Subject: [PATCH 1/5] Add Adamod implementation to optimizers. --- keras_contrib/optimizers/__init__.py | 2 + keras_contrib/optimizers/adamod.py | 93 ++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 keras_contrib/optimizers/adamod.py diff --git a/keras_contrib/optimizers/__init__.py b/keras_contrib/optimizers/__init__.py index bdf6a7bf4..f0b4dce91 100644 --- a/keras_contrib/optimizers/__init__.py +++ b/keras_contrib/optimizers/__init__.py @@ -2,7 +2,9 @@ from .padam import Padam from .yogi import Yogi from .lars import LARS +from .adamod import Adamod # aliases ftml = FTML lars = LARS +adamod = Adamod diff --git a/keras_contrib/optimizers/adamod.py b/keras_contrib/optimizers/adamod.py new file mode 100644 index 000000000..832ff798e --- /dev/null +++ b/keras_contrib/optimizers/adamod.py @@ -0,0 +1,93 @@ +from keras import backend as K +from keras.optimizers import Optimizer + + +class Adamod(Optimizer): + """Adamod optimizer. + + Default parameters follow those provided in the original paper. + + # Arguments + learning_rate: float >= 0. Learning rate. + beta_1: float, 0 < beta < 1. Generally close to 1. + beta_2: float, 0 < beta < 1. Generally close to 1. + beta_3: float, 0 < beta < 1. Generally close to 1. + + # References + - [An Adaptive and Momental Bound Method for Stochastic Learning]( + https://arxiv.org/abs/1910.12249) + """ + + def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, beta_3=0.999, + **kwargs): + self.initial_decay = kwargs.pop('decay', 0.0) + self.epsilon = kwargs.pop('epsilon', K.epsilon()) + learning_rate = kwargs.pop('lr', learning_rate) + super(Adamod, self).__init__(**kwargs) + with K.name_scope(self.__class__.__name__): + self.iterations = K.variable(0, dtype='int64', name='iterations') + self.learning_rate = K.variable(learning_rate, name='learning_rate') + self.beta_1 = K.variable(beta_1, name='beta_1') + self.beta_2 = K.variable(beta_2, name='beta_2') + self.beta_3 = K.variable(beta_3, name='beta_3') + self.decay = K.variable(self.initial_decay, name='decay') + + @interfaces.legacy_get_updates_support + @K.symbolic + def get_updates(self, loss, params): + grads = self.get_gradients(loss, params) + self.updates = [K.update_add(self.iterations, 1)] + + lr = self.learning_rate + if self.initial_decay > 0: + lr = lr * (1. / (1. + self.decay * K.cast(self.iterations, + K.dtype(self.decay)))) + + t = K.cast(self.iterations, K.floatx()) + 1 + lr_t = lr * K.sqrt(1. - K.pow(self.beta_2, t)) + m_bias_correction = 1.0 / (1. - K.pow(self.beta_1, t)) + + ms = [K.zeros(K.int_shape(p), + dtype=K.dtype(p), + name='m_' + str(i)) + for (i, p) in enumerate(params)] + vs = [K.zeros(K.int_shape(p), + dtype=K.dtype(p), + name='v_' + str(i)) + for (i, p) in enumerate(params)] + ss = [K.zeros(K.int_shape(p), + dtype=K.dtype(p), + name='s_' + str(i)) + for (i, p) in enumerate(params)] + + self.weights = [self.iterations] + ms + vs + ss + + for p, g, m, v, s in zip(params, grads, ms, vs, ss): + m_t = (self.beta_1 * m) + (1. - self.beta_1) * g + v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g) + n_t = lr_t / (K.sqrt(v_t) + self.epsilon) + s_t = (self.beta_3 * s) + (1. 
- self.beta_3) * n_t + nhat_t = K.min(n_t, s_t) + p_t = p - nhat_t * m_t * m_bias_correction + + self.updates.append(K.update(m, m_t)) + self.updates.append(K.update(v, v_t)) + self.updates.append(K.update(s, s_t)) + new_p = p_t + + # Apply constraints. + if getattr(p, 'constraint', None) is not None: + new_p = p.constraint(new_p) + + self.updates.append(K.update(p, new_p)) + return self.updates + + def get_config(self): + config = {'learning_rate': float(K.get_value(self.learning_rate)), + 'beta_1': float(K.get_value(self.beta_1)), + 'beta_2': float(K.get_value(self.beta_2)), + 'beta_3': float(K.get_value(self.beta_3)), + 'decay': float(K.get_value(self.decay)), + 'epsilon': self.epsilon} + base_config = super(Adamod, self).get_config() + return dict(list(base_config.items()) + list(config.items())) From 2a167875425939c0aa3101b1f3dcf74408be50c3 Mon Sep 17 00:00:00 2001 From: Mikael Figueroa Date: Fri, 31 Jan 2020 13:41:52 -0800 Subject: [PATCH 2/5] Add optimizer test for Adamod. --- tests/keras_contrib/optimizers/adamod_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 tests/keras_contrib/optimizers/adamod_test.py diff --git a/tests/keras_contrib/optimizers/adamod_test.py b/tests/keras_contrib/optimizers/adamod_test.py new file mode 100644 index 000000000..f09592edb --- /dev/null +++ b/tests/keras_contrib/optimizers/adamod_test.py @@ -0,0 +1,8 @@ +from __future__ import print_function +from keras_contrib.tests import optimizers +from keras_contrib.optimizers import Adamod + + +def test_adamod(): + optimiers._test_optimizer(Adamod()) + optimizers._test_optimizer(Adamod(beta_3=0.9999)) From 30fa49899fd63f9a281d0a5fe9c3cc94d2e9aed8 Mon Sep 17 00:00:00 2001 From: Mikael Figueroa Date: Fri, 31 Jan 2020 14:25:15 -0800 Subject: [PATCH 3/5] Fix Adamod tests and bug with K.min vs K.minimum. --- keras_contrib/optimizers/adamod.py | 4 +--- tests/keras_contrib/optimizers/adamod_test.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/keras_contrib/optimizers/adamod.py b/keras_contrib/optimizers/adamod.py index 832ff798e..38b3e85aa 100644 --- a/keras_contrib/optimizers/adamod.py +++ b/keras_contrib/optimizers/adamod.py @@ -32,8 +32,6 @@ def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, beta_3=0.999, self.beta_3 = K.variable(beta_3, name='beta_3') self.decay = K.variable(self.initial_decay, name='decay') - @interfaces.legacy_get_updates_support - @K.symbolic def get_updates(self, loss, params): grads = self.get_gradients(loss, params) self.updates = [K.update_add(self.iterations, 1)] @@ -67,7 +65,7 @@ def get_updates(self, loss, params): v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g) n_t = lr_t / (K.sqrt(v_t) + self.epsilon) s_t = (self.beta_3 * s) + (1. - self.beta_3) * n_t - nhat_t = K.min(n_t, s_t) + nhat_t = K.minimum(n_t, s_t) p_t = p - nhat_t * m_t * m_bias_correction self.updates.append(K.update(m, m_t)) diff --git a/tests/keras_contrib/optimizers/adamod_test.py b/tests/keras_contrib/optimizers/adamod_test.py index f09592edb..2288c4484 100644 --- a/tests/keras_contrib/optimizers/adamod_test.py +++ b/tests/keras_contrib/optimizers/adamod_test.py @@ -4,5 +4,5 @@ def test_adamod(): - optimiers._test_optimizer(Adamod()) + optimizers._test_optimizer(Adamod()) optimizers._test_optimizer(Adamod(beta_3=0.9999)) From 418cef0800d5e112a2589936d8be0dc3f97ad148 Mon Sep 17 00:00:00 2001 From: Mikael Figueroa Date: Fri, 31 Jan 2020 14:44:05 -0800 Subject: [PATCH 4/5] Fix optimizers test key error. 
--- keras_contrib/tests/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_contrib/tests/optimizers.py b/keras_contrib/tests/optimizers.py index 2ef7d88b8..3c637d155 100644 --- a/keras_contrib/tests/optimizers.py +++ b/keras_contrib/tests/optimizers.py @@ -35,7 +35,7 @@ def _test_optimizer(optimizer, target=0.75): optimizer=optimizer, metrics=['accuracy']) history = model.fit(x_train, y_train, epochs=2, batch_size=16, verbose=0) - assert history.history['acc'][-1] >= target + assert history.history['accuracy'][-1] >= target config = optimizers.serialize(optimizer) custom_objects = {optimizer.__class__.__name__: optimizer.__class__} optim = optimizers.deserialize(config, custom_objects) From c030b52e1b7c4bfe57e01b3d1519d02752e58cf8 Mon Sep 17 00:00:00 2001 From: Mikael Figueroa Date: Fri, 31 Jan 2020 15:40:26 -0800 Subject: [PATCH 5/5] Try fixing optimizers test again. --- keras_contrib/tests/optimizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_contrib/tests/optimizers.py b/keras_contrib/tests/optimizers.py index 3c637d155..7449bb2f5 100644 --- a/keras_contrib/tests/optimizers.py +++ b/keras_contrib/tests/optimizers.py @@ -33,9 +33,9 @@ def _test_optimizer(optimizer, target=0.75): model = get_model(x_train.shape[1], 10, y_train.shape[1]) model.compile(loss='categorical_crossentropy', optimizer=optimizer, - metrics=['accuracy']) + metrics=['acc']) history = model.fit(x_train, y_train, epochs=2, batch_size=16, verbose=0) - assert history.history['accuracy'][-1] >= target + assert history.history['acc'][-1] >= target config = optimizers.serialize(optimizer) custom_objects = {optimizer.__class__.__name__: optimizer.__class__} optim = optimizers.deserialize(config, custom_objects)
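For reference, not part of the patch series itself: a minimal NumPy sketch of the single-step update that keras_contrib/optimizers/adamod.py performs once PATCH 3/5 is applied. The function name and the eps default are illustrative; the optimizer builds the same expressions symbolically and takes epsilon from K.epsilon().

import numpy as np


def adamod_step(p, g, m, v, s, t,
                lr=0.001, beta_1=0.9, beta_2=0.999, beta_3=0.999, eps=1e-7):
    # One Adamod update for parameter p with gradient g; t starts at 1.
    m = beta_1 * m + (1. - beta_1) * g             # first-moment estimate
    v = beta_2 * v + (1. - beta_2) * np.square(g)  # second-moment estimate
    lr_t = lr * np.sqrt(1. - beta_2 ** t)          # beta_2 bias correction folded into the step size
    n = lr_t / (np.sqrt(v) + eps)                  # per-parameter Adam-style learning rate
    s = beta_3 * s + (1. - beta_3) * n             # exponential average of past rates
    n_hat = np.minimum(n, s)                       # momental bound: cap the rate by its running average
    p = p - n_hat * m / (1. - beta_1 ** t)         # beta_1 bias-corrected momentum step
    return p, m, v, s

The difference from plain Adam is the s / n_hat pair: the adaptive learning rate n is itself smoothed with beta_3, and the applied rate is clipped by that running average, which bounds unusually large steps early in training (the "momental bound" of the referenced paper).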
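A short usage sketch in the spirit of the harness in keras_contrib/tests/optimizers.py; the layer size and input shape below are made up for illustration.

from keras.models import Sequential
from keras.layers import Dense
from keras_contrib.optimizers import Adamod

# Toy model: a single softmax layer on 20 input features.
model = Sequential([Dense(10, activation='softmax', input_shape=(20,))])
model.compile(loss='categorical_crossentropy',
              optimizer=Adamod(learning_rate=1e-3, beta_3=0.999),
              metrics=['acc'])
# model.fit(x_train, y_train, epochs=2, batch_size=16), as _test_optimizer does.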