diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index 48dc491d32624..958b6210475bb 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -113,6 +113,15 @@ Changelog :pr:`16530` by :user:`Shuhua Fan `. +:mod:`sklearn.multiclass` +......................... + +- |Fix| A fix to allow :class:`multiclass.OutputCodeClassifier` to accept + sparse input data in its `fit` and `predict` methods. The check for + validity of the input is now delegated to the base estimator. + :pr:`17233` by :user:`Zolisa Bleki `. + + Code and Documentation Contributors ----------------------------------- diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 377e875760f72..019cf9f4683f9 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -802,7 +802,7 @@ def fit(self, X, y): ------- self """ - X, y = self._validate_data(X, y) + X, y = self._validate_data(X, y, accept_sparse=True) if self.code_size <= 0: raise ValueError("code_size should be greater than 0, got {0}" "".format(self.code_size)) @@ -850,7 +850,7 @@ def predict(self, X): Predicted multi-class targets. """ check_is_fitted(self) - X = check_array(X) + X = check_array(X, accept_sparse=True) Y = np.array([_predict_binary(e, X) for e in self.estimators_]).T pred = euclidean_distances(Y, self.code_book_).argmin(axis=1) return self.classes_[pred] diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index 03ada399d2af2..a4bdd6ef2688f 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -1,5 +1,6 @@ import numpy as np import scipy.sparse as sp +import pytest from re import escape @@ -9,11 +10,13 @@ from sklearn.utils._testing import assert_warns from sklearn.utils._testing import assert_raise_message from sklearn.utils._testing import assert_raises_regexp +from sklearn.utils._mocking import CheckingClassifier from sklearn.multiclass import OneVsRestClassifier from sklearn.multiclass import OneVsOneClassifier from sklearn.multiclass import OutputCodeClassifier from sklearn.utils.multiclass import (check_classification_targets, type_of_target) +from sklearn.utils import check_array from sklearn.utils import shuffle from sklearn.metrics import precision_score @@ -705,6 +708,32 @@ def test_ecoc_float_y(): " got -1", ovo.fit, X, y) +def test_ecoc_delegate_sparse_base_estimator(): + # Non-regression test for + # https://github.com/scikit-learn/scikit-learn/issues/17218 + X, y = iris.data, iris.target + X_sp = sp.csc_matrix(X) + + # create an estimator that does not support sparse input + base_estimator = CheckingClassifier( + check_X=check_array, + check_X_params={"ensure_2d": True, "accept_sparse": False}, + ) + ecoc = OutputCodeClassifier(base_estimator, random_state=0) + + with pytest.raises(TypeError, match="A sparse matrix was passed"): + ecoc.fit(X_sp, y) + + ecoc.fit(X, y) + with pytest.raises(TypeError, match="A sparse matrix was passed"): + ecoc.predict(X_sp) + + # smoke test to check when sparse input should be supported + ecoc = OutputCodeClassifier(LinearSVC(random_state=0)) + ecoc.fit(X_sp, y).predict(X_sp) + assert len(ecoc.estimators_) == 4 + + def test_pairwise_indices(): clf_precomputed = svm.SVC(kernel='precomputed') X, y = iris.data, iris.target