diff --git a/metric_learn/scml.py b/metric_learn/scml.py
index 2bdd0d57..412fdcec 100644
--- a/metric_learn/scml.py
+++ b/metric_learn/scml.py
@@ -23,7 +23,8 @@ class _BaseSCML(MahalanobisMixin):
 
   def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
                gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10,
-               verbose=False, preprocessor=None, random_state=None):
+               verbose=False, preprocessor=None, random_state=None,
+               warm_start=False):
     self.beta = beta
     self.basis = basis
     self.n_basis = n_basis
@@ -34,6 +35,7 @@ def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
     self.verbose = verbose
     self.preprocessor = preprocessor
     self.random_state = random_state
+    self.warm_start = warm_start
     super(_BaseSCML, self).__init__(preprocessor)
 
   def _fit(self, triplets, basis=None, n_basis=None):
@@ -74,13 +76,14 @@ def _fit(self, triplets, basis=None, n_basis=None):
 
     n_triplets = triplets.shape[0]
 
-    # weight vector
-    w = np.zeros((1, n_basis))
-    # avarage obj gradient wrt weights
-    avg_grad_w = np.zeros((1, n_basis))
+    if not self.warm_start or not hasattr(self, "w_"):
+      # weight vector
+      self.w_ = np.zeros((1, n_basis))
+      # average obj gradient wrt weights
+      self.avg_grad_w_ = np.zeros((1, n_basis))
+      # l2 norm in time of all obj gradients wrt weights
+      self.ada_grad_w_ = np.zeros((1, n_basis))
 
-    # l2 norm in time of all obj gradients wrt weights
-    ada_grad_w = np.zeros((1, n_basis))
     # slack for not dividing by zero
     delta = 0.001
@@ -93,27 +96,29 @@ def _fit(self, triplets, basis=None, n_basis=None):
 
       idx = rand_int[iter]
 
-      slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
+      slack_val = 1 + np.matmul(dist_diff[idx, :], self.w_.T)
       slack_mask = np.squeeze(slack_val > 0, axis=1)
 
       grad_w = np.sum(dist_diff[idx[slack_mask], :],
                       axis=0, keepdims=True)/self.batch_size
-      avg_grad_w = (iter * avg_grad_w + grad_w) / (iter+1)
-      ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))
+      self.avg_grad_w_ = (iter * self.avg_grad_w_ + grad_w) / (iter+1)
 
-      scale_f = -(iter+1) / (self.gamma * (delta + ada_grad_w))
+      self.ada_grad_w_ = np.sqrt(np.square(self.ada_grad_w_) +
+                                 np.square(grad_w))
+
+      scale_f = -(iter+1) / (self.gamma * (delta + self.ada_grad_w_))
 
       # proximal operator with negative trimming equivalent
-      w = scale_f * np.minimum(avg_grad_w + self.beta, 0)
+      self.w_ = scale_f * np.minimum(self.avg_grad_w_ + self.beta, 0)
 
       if (iter + 1) % self.output_iter == 0:
         # regularization part of obj function
-        obj1 = np.sum(w)*self.beta
+        obj1 = np.sum(self.w_)*self.beta
 
         # Every triplet distance difference in the space given by L
         # plus a slack of one
-        slack_val = 1 + np.matmul(dist_diff, w.T)
+        slack_val = 1 + np.matmul(dist_diff, self.w_.T)
         # Mask of places with positive slack
         slack_mask = slack_val > 0
@@ -129,7 +134,7 @@ def _fit(self, triplets, basis=None, n_basis=None):
         # update the best
         if obj < best_obj:
           best_obj = obj
-          best_w = w
+          best_w = self.w_
 
     if self.verbose:
       print("max iteration reached.")
@@ -355,6 +360,13 @@ class SCML(_BaseSCML, _TripletsClassifierMixin):
   random_state : int or numpy.RandomState or None, optional (default=None)
     A pseudo random number generator object or a seed for it if int.
 
+  warm_start : bool, default=False
+    When set to True, reuse the solution of the previous call to fit as
+    initialization, otherwise, just erase the previous solution.
+    Repeatedly calling fit when warm_start is True can result in a different
+    solution than when calling fit a single time because of the way the data
+    is shuffled.
+
   Attributes
   ----------
   components_ : `numpy.ndarray`, shape=(n_features, n_features)
@@ -465,6 +477,13 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
   random_state : int or numpy.RandomState or None, optional (default=None)
     A pseudo random number generator object or a seed for it if int.
 
+  warm_start : bool, default=False
+    When set to True, reuse the solution of the previous call to fit as
+    initialization, otherwise, just erase the previous solution.
+    Repeatedly calling fit when warm_start is True can result in a different
+    solution than when calling fit a single time because of the way the data
+    is shuffled.
+
   Attributes
   ----------
   components_ : `numpy.ndarray`, shape=(n_features, n_features)
@@ -506,13 +525,14 @@ class SCML_Supervised(_BaseSCML, TransformerMixin):
   def __init__(self, k_genuine=3, k_impostor=10, beta=1e-5, basis='lda',
                n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500,
                batch_size=10, verbose=False, preprocessor=None,
-               random_state=None):
+               random_state=None, warm_start=False):
     self.k_genuine = k_genuine
     self.k_impostor = k_impostor
     _BaseSCML.__init__(self, beta=beta, basis=basis, n_basis=n_basis,
                        max_iter=max_iter, output_iter=output_iter,
                        batch_size=batch_size, verbose=verbose,
-                       preprocessor=preprocessor, random_state=random_state)
+                       preprocessor=preprocessor, random_state=random_state,
+                       warm_start=warm_start)
 
   def fit(self, X, y):
     """Create constraints from labels and learn the SCML model.
diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py
index fe1560c2..8ab8e7b4 100644
--- a/test/metric_learn_test.py
+++ b/test/metric_learn_test.py
@@ -8,7 +8,7 @@
 from sklearn.datasets import (load_iris, make_classification, make_regression,
                               make_spd_matrix)
 from numpy.testing import (assert_array_almost_equal, assert_array_equal,
-                           assert_allclose)
+                           assert_allclose, assert_raises)
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.utils.validation import check_X_y
 from sklearn.preprocessing import StandardScaler
@@ -323,6 +323,37 @@ def test_large_output_iter(self):
       scml.fit(triplets)
     assert msg == raised_error.value.args[0]
 
+  @pytest.mark.parametrize("basis", ("lda", "triplet_diffs"))
+  def test_warm_start(self, basis):
+    X, y = load_iris(return_X_y=True)
+    # Test that warm_start=True leads to different weights in each fit call
+    scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
+                           random_state=42, warm_start=True)
+    scml.fit(X, y)
+    w_1 = scml.w_
+    avg_grad_w_1 = scml.avg_grad_w_
+    ada_grad_w_1 = scml.ada_grad_w_
+    scml.fit(X, y)
+    w_2 = scml.w_
+    assert_raises(AssertionError, assert_array_almost_equal, w_1, w_2)
+    # And that the default warm_start value is False and leads to the same
+    # weights in each fit call
+    scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5,
+                           random_state=42)
+    scml.fit(X, y)
+    w_3 = scml.w_
+    scml.fit(X, y)
+    w_4 = scml.w_
+    assert_array_almost_equal(w_3, w_4)
+    # But it would lead to the same results with warm_start=True if the same
+    # initial parameters were used
+    scml.warm_start = True
+    scml.w_ = w_1
+    scml.avg_grad_w_ = avg_grad_w_1
+    scml.ada_grad_w_ = ada_grad_w_1
+    scml.fit(X, y)
+    w_5 = scml.w_
+    assert_array_almost_equal(w_2, w_5)
 class TestLSML(MetricTestCase):
 
   def test_iris(self):
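
Reviewer note: below is a minimal usage sketch of the new flag, not part of the patch itself. It assumes only the SCML_Supervised API as modified above plus sklearn's load_iris; the hyperparameter values (n_basis=85, k_genuine=7, k_impostor=5, random_state=42) are borrowed from test_warm_start, and get_mahalanobis_matrix() is the existing accessor inherited from MahalanobisMixin.

import numpy as np
from sklearn.datasets import load_iris
from metric_learn import SCML_Supervised

X, y = load_iris(return_X_y=True)

# Default behaviour (warm_start=False): every call to fit() erases the
# previous solution, because w_, avg_grad_w_ and ada_grad_w_ are re-created
# as zero vectors at the start of _fit.
cold = SCML_Supervised(n_basis=85, k_genuine=7, k_impostor=5, random_state=42)
cold.fit(X, y)
w_first = cold.w_
cold.fit(X, y)
print(np.allclose(w_first, cold.w_))  # True: refitting reproduces the weights

# warm_start=True: the second fit() starts from the stored w_, avg_grad_w_
# and ada_grad_w_ instead of zeros, so repeated fits can reach a different
# solution than a single fit (cf. test_warm_start above).
warm = SCML_Supervised(n_basis=85, k_genuine=7, k_impostor=5,
                       random_state=42, warm_start=True)
warm.fit(X, y)  # no w_ attribute exists yet, so this is still a cold init
warm.fit(X, y)  # resumes from the weights learned by the previous call
print(np.allclose(w_first, warm.w_))  # False: the solution has moved on

# Either way, the fitted metric is consumed as usual, e.g.:
M = warm.get_mahalanobis_matrix()

Design note: _fit only skips re-initialization when warm_start is True and a previous w_ exists, so the first fit of a fresh estimator behaves identically with or without the flag.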