Skip to content

Commit

Permalink
FIX Allow sparse input data for OutputCodeClassifier (scikit-learn#17233
Browse files Browse the repository at this point in the history
)

Co-authored-by: Guillaume Lemaitre <[email protected]>
  • Loading branch information
zoj613 and glemaitre authored May 26, 2020
1 parent 76df39f commit 6b68144
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 2 deletions.
9 changes: 9 additions & 0 deletions doc/whats_new/v0.24.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ Changelog
:pr:`16530` by :user:`Shuhua Fan <jim0421>`.


:mod:`sklearn.multiclass`
.........................

- |Fix| A fix to allow :class:`multiclass.OutputCodeClassifier` to accept
sparse input data in its `fit` and `predict` methods. The check for
validity of the input is now delegated to the base estimator.
:pr:`17233` by :user:`Zolisa Bleki <zoj613>`.


Code and Documentation Contributors
-----------------------------------

Expand Down
4 changes: 2 additions & 2 deletions sklearn/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def fit(self, X, y):
-------
self
"""
X, y = self._validate_data(X, y)
X, y = self._validate_data(X, y, accept_sparse=True)
if self.code_size <= 0:
raise ValueError("code_size should be greater than 0, got {0}"
"".format(self.code_size))
Expand Down Expand Up @@ -850,7 +850,7 @@ def predict(self, X):
Predicted multi-class targets.
"""
check_is_fitted(self)
X = check_array(X)
X = check_array(X, accept_sparse=True)
Y = np.array([_predict_binary(e, X) for e in self.estimators_]).T
pred = euclidean_distances(Y, self.code_book_).argmin(axis=1)
return self.classes_[pred]
29 changes: 29 additions & 0 deletions sklearn/tests/test_multiclass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import scipy.sparse as sp
import pytest

from re import escape

Expand All @@ -9,11 +10,13 @@
from sklearn.utils._testing import assert_warns
from sklearn.utils._testing import assert_raise_message
from sklearn.utils._testing import assert_raises_regexp
from sklearn.utils._mocking import CheckingClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multiclass import OneVsOneClassifier
from sklearn.multiclass import OutputCodeClassifier
from sklearn.utils.multiclass import (check_classification_targets,
type_of_target)
from sklearn.utils import check_array
from sklearn.utils import shuffle

from sklearn.metrics import precision_score
Expand Down Expand Up @@ -705,6 +708,32 @@ def test_ecoc_float_y():
" got -1", ovo.fit, X, y)


def test_ecoc_delegate_sparse_base_estimator():
# Non-regression test for
# https://github.com/scikit-learn/scikit-learn/issues/17218
X, y = iris.data, iris.target
X_sp = sp.csc_matrix(X)

# create an estimator that does not support sparse input
base_estimator = CheckingClassifier(
check_X=check_array,
check_X_params={"ensure_2d": True, "accept_sparse": False},
)
ecoc = OutputCodeClassifier(base_estimator, random_state=0)

with pytest.raises(TypeError, match="A sparse matrix was passed"):
ecoc.fit(X_sp, y)

ecoc.fit(X, y)
with pytest.raises(TypeError, match="A sparse matrix was passed"):
ecoc.predict(X_sp)

# smoke test to check when sparse input should be supported
ecoc = OutputCodeClassifier(LinearSVC(random_state=0))
ecoc.fit(X_sp, y).predict(X_sp)
assert len(ecoc.estimators_) == 4


def test_pairwise_indices():
clf_precomputed = svm.SVC(kernel='precomputed')
X, y = iris.data, iris.target
Expand Down

0 comments on commit 6b68144

Please sign in to comment.