Skip to content

Commit

Permalink
[MRG+1] Fixes scikit-learn#7578 added check_decision_proba_consistenc…
Browse files Browse the repository at this point in the history
…y in estimator_checks (scikit-learn#8253)
  • Loading branch information
Shubham Bhardwaj authored and lesteve committed Mar 7, 2017
1 parent 5135c56 commit 02c705e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,13 @@ API changes summary
selection classes to be used with tools such as
:func:`sklearn.model_selection.cross_val_predict`.
:issue:`2879` by :user:`Stephen Hoover <stephen-hoover>`.

- Estimators with both methods ``decision_function`` and ``predict_proba``
are now required to have a monotonic relation between them. The
method ``check_decision_proba_consistency`` has been added in
**sklearn.utils.estimator_checks** to check their consistency.
:issue:`7578` by :user:`Shubham Bhardwaj <shubham0704>`.


.. _changes_0_18_1:

Expand Down
44 changes: 33 additions & 11 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import traceback
import pickle
from copy import deepcopy

import numpy as np
from scipy import sparse
from scipy.stats import rankdata
import struct

from sklearn.externals.six.moves import zip
Expand Down Expand Up @@ -113,10 +113,10 @@ def _yield_classifier_checks(name, Classifier):
# basic consistency testing
yield check_classifiers_train
yield check_classifiers_regression_target
if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"]
if (name not in
["MultinomialNB", "LabelPropagation", "LabelSpreading"] and
# TODO some complication with -1 label
and name not in ["DecisionTreeClassifier",
"ExtraTreeClassifier"]):
name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]):
# We don't raise a warning in these classifiers, as
# the column y interface is used by the forests.

Expand All @@ -127,6 +127,8 @@ def _yield_classifier_checks(name, Classifier):
yield check_class_weight_classifiers

yield check_non_transformer_estimators_n_iter
# test if predict_proba is a monotonic transformation of decision_function
yield check_decision_proba_consistency


@ignore_warnings(category=DeprecationWarning)
Expand Down Expand Up @@ -269,8 +271,7 @@ def set_testing_parameters(estimator):
# set parameters to speed up some estimators and
# avoid deprecated behaviour
params = estimator.get_params()
if ("n_iter" in params
and estimator.__class__.__name__ != "TSNE"):
if ("n_iter" in params and estimator.__class__.__name__ != "TSNE"):
estimator.set_params(n_iter=5)
if "max_iter" in params:
warnings.simplefilter("ignore", ConvergenceWarning)
Expand Down Expand Up @@ -1112,8 +1113,7 @@ def check_classifiers_train(name, Classifier):
assert_equal(decision.shape, (n_samples,))
dec_pred = (decision.ravel() > 0).astype(np.int)
assert_array_equal(dec_pred, y_pred)
if (n_classes is 3
and not isinstance(classifier, BaseLibSVM)):
if (n_classes is 3 and not isinstance(classifier, BaseLibSVM)):
# 1on1 of LibSVM works differently
assert_equal(decision.shape, (n_samples, n_classes))
assert_array_equal(np.argmax(decision, axis=1), y_pred)
Expand Down Expand Up @@ -1574,9 +1574,9 @@ def check_parameters_default_constructible(name, Estimator):
try:
def param_filter(p):
"""Identify hyper parameters of an estimator"""
return (p.name != 'self'
and p.kind != p.VAR_KEYWORD
and p.kind != p.VAR_POSITIONAL)
return (p.name != 'self' and
p.kind != p.VAR_KEYWORD and
p.kind != p.VAR_POSITIONAL)

init_params = [p for p in signature(init).parameters.values()
if param_filter(p)]
Expand Down Expand Up @@ -1721,3 +1721,25 @@ def check_classifiers_regression_target(name, Estimator):
e = Estimator()
msg = 'Unknown label type: '
assert_raises_regex(ValueError, msg, e.fit, X, y)


@ignore_warnings(category=DeprecationWarning)
def check_decision_proba_consistency(name, Estimator):
    """Check rank consistency of decision_function and predict_proba.

    For an estimator exposing both ``decision_function`` and
    ``predict_proba``, the two outputs must be monotonically related:
    ranking test samples by positive-class probability must give the
    same ordering as ranking them by decision score.
    """
    centers = [(2, 2), (4, 4)]
    X, y = make_blobs(n_samples=100, random_state=0, n_features=4,
                      centers=centers, cluster_std=1.0, shuffle=True)
    # Use a seeded RNG: the unseeded global np.random.randn made this
    # check non-deterministic across runs (potentially flaky when rank
    # ties resolve differently).
    rnd = np.random.RandomState(0)
    # Test points drawn around the second blob center (4, 4).
    X_test = rnd.randn(20, 2) + 4
    estimator = Estimator()

    set_testing_parameters(estimator)

    if (hasattr(estimator, "decision_function") and
            hasattr(estimator, "predict_proba")):

        estimator.fit(X, y)
        # Positive-class probabilities and decision scores must rank the
        # test samples identically (perfect rank correlation).
        a = estimator.predict_proba(X_test)[:, 1]
        b = estimator.decision_function(X_test)
        assert_array_equal(rankdata(a), rankdata(b))

0 comments on commit 02c705e

Please sign in to comment.