API: SVMs: eps -> tol
Change the eps argument to 'tol' and expose it in base classes.

The reasoning behind calling this argument 'tol' rather than 'eps' is that
it is a bound on the tolerance of the optimization, and it does not
relate to machine precision (eps).
GaelVaroquaux committed Mar 1, 2011
1 parent 4820ae7 commit 09940a9
Showing 9 changed files with 61 additions and 55 deletions.
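
For illustration, a minimal sketch (not part of the commit; the data and variable names are invented) of how calling code changes with this rename. The class and the 1e-3 default come from the diff below:

>>> from scikits.learn.svm import SVC
>>> # under the old API this read SVC(eps=1e-3); after this commit the
>>> # same stopping tolerance is passed as tol (1e-3 is the default):
>>> clf = SVC(tol=1e-3)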
14 changes: 7 additions & 7 deletions doc/modules/svm.rst
@@ -72,7 +72,7 @@ training samples::
>>> Y = [0, 1]
>>> clf = svm.SVC()
>>> clf.fit(X, Y)
SVC(kernel='rbf', C=1.0, probability=False, degree=3, coef0=0.0, eps=0.001,
SVC(kernel='rbf', C=1.0, probability=False, degree=3, coef0=0.0, tol=0.001,
cache_size=100.0, shrinking=True, gamma=0.5)

After being fitted, the model can then be used to predict new values::
@@ -110,7 +110,7 @@ classifiers are constructed and each one trains data from two classes.
>>> Y = [0, 1, 2, 3]
>>> clf = svm.SVC()
>>> clf.fit(X, Y)
SVC(kernel='rbf', C=1.0, probability=False, degree=3, coef0=0.0, eps=0.001,
SVC(kernel='rbf', C=1.0, probability=False, degree=3, coef0=0.0, tol=0.001,
cache_size=100.0, shrinking=True, gamma=0.25)
>>> dec = clf.decision_function([[1]])
>>> dec.shape[1] # 4 classes: 4*3/2 = 6
@@ -124,8 +124,8 @@ two classes, only one model is trained.

>>> lin_clf = svm.LinearSVC()
>>> lin_clf.fit(X, Y)
LinearSVC(loss='l2', C=1.0, intercept_scaling=1, fit_intercept=True,
eps=0.0001, penalty='l2', multi_class=False, dual=True)
LinearSVC(loss='l2', C=1.0, dual=True, fit_intercept=True, penalty='l2',
multi_class=False, tol=0.0001, intercept_scaling=1)
>>> dec = lin_clf.decision_function([[1]])
>>> dec.shape[1]
4
@@ -169,8 +169,8 @@ floating point values instead of integer values.
>>> y = [0.5, 2.5]
>>> clf = svm.SVR()
>>> clf.fit(X, y)
SVR(kernel='rbf', C=1.0, probability=False, degree=3, shrinking=True,
eps=0.001, p=0.1, cache_size=100.0, coef0=0.0, nu=0.5, gamma=0.5)
SVR(kernel='rbf', C=1.0, probability=False, degree=3, shrinking=True, p=0.1,
tol=0.001, cache_size=100.0, coef0=0.0, nu=0.5, gamma=0.5)
>>> clf.predict([[1, 1]])
array([ 1.5])

@@ -270,7 +270,7 @@ Tips on Practical Use
* The underlying :class:`LinearSVC` implementation uses a random
number generator to select features when fitting the model. It is
thus not uncommon to have slightly different results for the same
input data. If that happens, try with a smaller eps parameter.
input data. If that happens, try with a smaller tol parameter.


.. _svm_kernels:
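As a usage note for the practical-use tip in the svm.rst hunk above, a minimal sketch (not part of the commit; the toy data is invented) of passing a tol tighter than the 1e-4 default to LinearSVC:

>>> from scikits.learn.svm import LinearSVC
>>> X = [[0, 0], [1, 1], [2, 2], [3, 3]]   # invented toy data
>>> y = [0, 0, 1, 1]
>>> # per the tip above, a tol smaller than the 1e-4 default tightens the
>>> # stopping criterion and reduces run-to-run variability:
>>> clf = LinearSVC(tol=1e-5).fit(X, y)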
11 changes: 7 additions & 4 deletions scikits/learn/linear_model/logistic.py
@@ -41,6 +41,9 @@ class LogisticRegression(BaseLibLinear, ClassifierMixin,
To lessen the effect of regularization on synthetic feature weight
(and therefore on the intercept) intercept_scaling has to be increased
tol: float, optional
tolerance for stopping criteria
Attributes
----------
@@ -60,19 +63,19 @@ class LogisticRegression(BaseLibLinear, ClassifierMixin,
The underlying C implementation uses a random number generator to
select features when fitting the model. It is thus not uncommon
to have slightly different results for the same input data. If
that happens, try with a smaller eps parameter.
that happens, try with a smaller tol parameter.
References
----------
LIBLINEAR -- A Library for Large Linear Classification
http://www.csie.ntu.edu.tw/~cjlin/liblinear/
"""

def __init__(self, penalty='l2', dual=False, eps=1e-4, C=1.0,
def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0,
fit_intercept=True, intercept_scaling=1):

super(LogisticRegression, self).__init__ (penalty=penalty,
dual=dual, loss='lr', eps=eps, C=C,
dual=dual, loss='lr', tol=tol, C=C,
fit_intercept=fit_intercept, intercept_scaling=intercept_scaling)

def predict_proba(self, X):
@@ -96,7 +99,7 @@ def predict_proba(self, X):
X = np.asanyarray(X, dtype=np.float64, order='C')
probas = _liblinear.predict_prob_wrap(X, self.raw_coef_,
self._get_solver_type(),
self.eps, self.C,
self.tol, self.C,
self.class_weight_label,
self.class_weight, self.label_,
self._get_bias())
11 changes: 7 additions & 4 deletions scikits/learn/linear_model/sparse/logistic.py
@@ -48,6 +48,9 @@ class LogisticRegression(SparseBaseLibLinear, ClassifierMixin,
To lessen the effect of regularization on synthetic feature weight
(and therefore on the intercept) intercept_scaling has to be increased
tol: float, optional
tolerance for stopping criteria
Attributes
----------
@@ -67,19 +70,19 @@ class LogisticRegression(SparseBaseLibLinear, ClassifierMixin,
The underlying C implementation uses a random number generator to
select features when fitting the model. It is thus not uncommon
to have slightly different results for the same input data. If
that happens, try with a smaller eps parameter.
that happens, try with a smaller tol parameter.
References
----------
LIBLINEAR -- A Library for Large Linear Classification
http://www.csie.ntu.edu.tw/~cjlin/liblinear/
"""

def __init__(self, penalty='l2', dual=False, eps=1e-4, C=1.0,
def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0,
fit_intercept=True, intercept_scaling=1):

super(LogisticRegression, self).__init__ (penalty=penalty,
dual=dual, loss='lr', eps=eps, C=C,
dual=dual, loss='lr', tol=tol, C=C,
fit_intercept=fit_intercept, intercept_scaling=intercept_scaling)

def predict_proba(self, X):
@@ -95,7 +98,7 @@ def predict_proba(self, X):
probas = csr_predict_prob(X.shape[1], X.data, X.indices,
X.indptr, self.raw_coef_,
self._get_solver_type(),
self.eps, self.C,
self.tol, self.C,
self.class_weight_label,
self.class_weight, self.label_,
self._get_bias())
2 changes: 1 addition & 1 deletion scikits/learn/pipeline.py
@@ -73,7 +73,7 @@ class Pipeline(BaseEstimator):
>>> # For instance, fit using a k of 10 in the SelectKBest
>>> # and a parameter 'C' of the svn
>>> anova_svm.fit(X, y, anova__k=10, svc__C=.1) #doctest: +ELLIPSIS
Pipeline(steps=[('anova', SelectKBest(k=10, score_func=<function f_regression at ...>)), ('svc', SVC(kernel='linear', C=0.1, probability=False, degree=3, coef0=0.0, eps=0.001,
Pipeline(steps=[('anova', SelectKBest(k=10, score_func=<function f_regression at ...>)), ('svc', SVC(kernel='linear', C=0.1, probability=False, degree=3, coef0=0.0, tol=0.001,
cache_size=100.0, shrinking=True, gamma=0.0))])
>>> prediction = anova_svm.predict(X)
22 changes: 11 additions & 11 deletions scikits/learn/svm/base.py
@@ -37,7 +37,7 @@ class BaseLibSVM(BaseEstimator):
_svm_types = ['c_svc', 'nu_svc', 'one_class', 'epsilon_svr', 'nu_svr']

def __init__(self, impl, kernel, degree, gamma, coef0, cache_size,
eps, C, nu, p, shrinking, probability):
tol, C, nu, p, shrinking, probability):

if not impl in self._svm_types:
raise ValueError("impl should be one of %s, %s was given" % (
@@ -53,7 +53,7 @@ def __init__(self, impl, kernel, degree, gamma, coef0, cache_size,
self.gamma = gamma
self.coef0 = coef0
self.cache_size = cache_size
self.eps = eps
self.tol = tol
self.C = C
self.nu = nu
self.p = p
@@ -143,7 +143,7 @@ def fit(self, X, y, class_weight={}, sample_weight=[], **params):
self.dual_coef_, self.intercept_, self.label_, self.probA_, \
self.probB_ = \
libsvm_train(_X, y, solver_type, kernel_type, self.degree,
self.gamma, self.coef0, self.eps, self.C,
self.gamma, self.coef0, self.tol, self.C,
self.nu, self.cache_size, self.p,
self.class_weight_label, self.class_weight,
sample_weight, int(self.shrinking),
@@ -185,7 +185,7 @@ def predict(self, X):
return libsvm_predict(X, self.support_vectors_,
self.dual_coef_, self.intercept_,
self._svm_types.index(self.impl), kernel_type,
self.degree, self.gamma, self.coef0, self.eps,
self.degree, self.gamma, self.coef0, self.tol,
self.C, self.class_weight_label,
self.class_weight, self.nu, self.cache_size,
self.p, int(self.shrinking),
@@ -225,7 +225,7 @@ def predict_proba(self, T):
pprob = libsvm_predict_proba(T, self.support_vectors_,
self.dual_coef_, self.intercept_,
self._svm_types.index(self.impl), kernel_type,
self.degree, self.gamma, self.coef0, self.eps,
self.degree, self.gamma, self.coef0, self.tol,
self.C, self.class_weight_label,
self.class_weight, self.nu, self.cache_size,
self.p, int(self.shrinking),
@@ -280,7 +280,7 @@ def decision_function(self, T):
dec_func = libsvm_decision_function(T, self.support_vectors_,
self.dual_coef_, self.intercept_,
self._svm_types.index(self.impl), kernel_type,
self.degree, self.gamma, self.coef0, self.eps,
self.degree, self.gamma, self.coef0, self.tol,
self.C, self.class_weight_label,
self.class_weight, self.nu, self.cache_size,
self.p, int(self.shrinking),
@@ -319,12 +319,12 @@ class BaseLibLinear(BaseEstimator):
'PL2_LLR_D1' : 7, # L2 penalty, logistic regression, dual form
}

def __init__(self, penalty='l2', loss='l2', dual=True, eps=1e-4, C=1.0,
def __init__(self, penalty='l2', loss='l2', dual=True, tol=1e-4, C=1.0,
multi_class=False, fit_intercept=True, intercept_scaling=1):
self.penalty = penalty
self.loss = loss
self.dual = dual
self.eps = eps
self.tol = tol
self.C = C
self.fit_intercept = fit_intercept
self.intercept_scaling = intercept_scaling
@@ -377,7 +377,7 @@ def fit(self, X, y, class_weight={}, **params):
y = np.asanyarray(y, dtype=np.int32, order='C')

self.raw_coef_, self.label_ = _liblinear.train_wrap(X, y,
self._get_solver_type(), self.eps,
self._get_solver_type(), self.tol,
self._get_bias(), self.C,
self.class_weight_label, self.class_weight)

@@ -402,7 +402,7 @@ def predict(self, X):

return _liblinear.predict_wrap(X, coef,
self._get_solver_type(),
self.eps, self.C,
self.tol, self.C,
self.class_weight_label,
self.class_weight, self.label_,
self._get_bias())
@@ -426,7 +426,7 @@ def decision_function(self, X):
self._check_n_features(X)

dec_func = _liblinear.decision_function_wrap(
X, self.raw_coef_, self._get_solver_type(), self.eps,
X, self.raw_coef_, self._get_solver_type(), self.tol,
self.C, self.class_weight_label, self.class_weight,
self.label_, self._get_bias())

6 changes: 3 additions & 3 deletions scikits/learn/svm/liblinear.py
@@ -28,8 +28,8 @@ class LinearSVC(BaseLibLinear, ClassifierMixin, CoefSelectTransformerMixin):
Select the algorithm to either solve the dual or primal
optimization problem.
eps: float, optional
precision for stopping criteria
tol: float, optional
tolerance for stopping criteria
multi_class: boolean, optional
perform multi-class SVM by Cramer and Singer. If active,
@@ -60,7 +60,7 @@ class LinearSVC(BaseLibLinear, ClassifierMixin, CoefSelectTransformerMixin):
The underlying C implementation uses a random number generator to
select features when fitting the model. It is thus not uncommon
to have slightly different results for the same input data. If
that happens, try with a smaller eps parameter.
that happens, try with a smaller tol parameter.
See also
--------
34 changes: 17 additions & 17 deletions scikits/learn/svm/libsvm.py
@@ -35,7 +35,7 @@ class SVC(BaseLibSVM, ClassifierMixin):
shrinking: boolean, optional
whether to use the shrinking heuristic.
eps: float, optional
tol: float, optional
precision for stopping criteria
cache_size: float, optional
@@ -72,7 +72,7 @@ class SVC(BaseLibSVM, ClassifierMixin):
>>> from scikits.learn.svm import SVC
>>> clf = SVC()
>>> clf.fit(X, y)
SVC(kernel='rbf', C=1.0, probability=False, degree=3, coef0=0.0, eps=0.001,
SVC(kernel='rbf', C=1.0, probability=False, degree=3, coef0=0.0, tol=0.001,
cache_size=100.0, shrinking=True, gamma=0.25)
>>> print clf.predict([[-0.8, -1]])
[ 1.]
@@ -84,10 +84,10 @@ class SVC(BaseLibSVM, ClassifierMixin):

def __init__(self, C=1.0, kernel='rbf', degree=3, gamma=0.0,
coef0=0.0, shrinking=True, probability=False,
eps=1e-3, cache_size=100.0):
tol=1e-3, cache_size=100.0):

BaseLibSVM.__init__(self, 'c_svc', kernel, degree, gamma, coef0,
cache_size, eps, C, 0., 0.,
cache_size, tol, C, 0., 0.,
shrinking, probability)


@@ -125,7 +125,7 @@ class NuSVC(BaseLibSVM, ClassifierMixin):
shrinking: boolean, optional
whether to use the shrinking heuristic.
eps: float, optional
tol: float, optional
precision for stopping criteria
cache_size: float, optional
@@ -179,7 +179,7 @@ class NuSVC(BaseLibSVM, ClassifierMixin):
>>> from scikits.learn.svm import NuSVC
>>> clf = NuSVC()
>>> clf.fit(X, y)
NuSVC(kernel='rbf', probability=False, degree=3, coef0=0.0, eps=0.001,
NuSVC(kernel='rbf', probability=False, degree=3, coef0=0.0, tol=0.001,
cache_size=100.0, shrinking=True, nu=0.5, gamma=0.25)
>>> print clf.predict([[-0.8, -1]])
[ 1.]
@@ -191,10 +191,10 @@ class NuSVC(BaseLibSVM, ClassifierMixin):

def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma=0.0,
coef0=0.0, shrinking=True, probability=False,
eps=1e-3, cache_size=100.0):
tol=1e-3, cache_size=100.0):

BaseLibSVM.__init__(self, 'nu_svc', kernel, degree, gamma, coef0,
cache_size, eps, 0., nu, 0.,
cache_size, tol, 0., nu, 0.,
shrinking, probability)


@@ -231,7 +231,7 @@ class SVR(BaseLibSVM, RegressorMixin):
enable probability estimates. This must be enabled prior
to calling prob_predict.
eps: float, optional
tol: float, optional
precision for stopping criteria
coef0 : float, optional
@@ -267,11 +267,11 @@ class SVR(BaseLibSVM, RegressorMixin):
NuSVR
"""
def __init__(self, kernel='rbf', degree=3, gamma=0.0, coef0=0.0,
cache_size=100.0, eps=1e-3, C=1.0, nu=0.5, p=0.1,
cache_size=100.0, tol=1e-3, C=1.0, nu=0.5, p=0.1,
shrinking=True, probability=False):

BaseLibSVM.__init__(self, 'epsilon_svr', kernel, degree, gamma, coef0,
cache_size, eps, C, nu, p,
cache_size, tol, C, nu, p,
shrinking, probability)

def fit(self, X, y, sample_weight=[]):
@@ -325,7 +325,7 @@ class NuSVR(BaseLibSVM, RegressorMixin):
kernel coefficient for rbf and poly, by default 1/n_features
will be taken.
eps: float, optional
tol: float, optional
precision for stopping criteria
probability: boolean, optional (False by default)
@@ -367,10 +367,10 @@ class NuSVR(BaseLibSVM, RegressorMixin):

def __init__(self, nu=0.5, C=1.0, kernel='rbf', degree=3,
gamma=0.0, coef0=0.0, shrinking=True,
probability=False, cache_size=100.0, eps=1e-3):
probability=False, cache_size=100.0, tol=1e-3):

BaseLibSVM.__init__(self, 'epsilon_svr', kernel, degree, gamma, coef0,
cache_size, eps, C, nu, 0.,
cache_size, tol, C, nu, 0.,
shrinking, probability)

def fit(self, X, y):
@@ -423,7 +423,7 @@ class OneClassSVM(BaseLibSVM):
Independent term in kernel function. It is only significant in
poly/sigmoid.
eps: float, optional
tol: float, optional
precision for stopping criteria
shrinking: boolean, optional
@@ -452,9 +452,9 @@ class OneClassSVM(BaseLibSVM):
"""
def __init__(self, kernel='rbf', degree=3, gamma=0.0, coef0=0.0,
cache_size=100.0, eps=1e-3, nu=0.5, shrinking=True):
cache_size=100.0, tol=1e-3, nu=0.5, shrinking=True):
BaseLibSVM.__init__(self, 'one_class', kernel, degree, gamma, coef0,
cache_size, eps, 0.0, nu, 0.0, shrinking, False)
cache_size, tol, 0.0, nu, 0.0, shrinking, False)

def fit(self, X, class_weight={}, sample_weight=[], **params):
"""