Skip to content

Commit

Permalink
Merge pull request #359 from rg2410/issue-358
Browse files Browse the repository at this point in the history
Sample weight sliced to work with cross validation, issue #358
  • Loading branch information
lopuhin authored Jan 22, 2020
2 parents 4839d19 + 729e557 commit 017c738
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
6 changes: 5 additions & 1 deletion eli5/sklearn/permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,12 @@ def _cv_scores_importances(self, X, y, groups=None, **fit_params):
cv = check_cv(self.cv, y, is_classifier(self.estimator))
feature_importances = [] # type: List
base_scores = [] # type: List[float]
weights = fit_params.pop('sample_weight', None)
fold_fit_params = fit_params.copy()
for train, test in cv.split(X, y, groups):
est = clone(self.estimator).fit(X[train], y[train], **fit_params)
if weights is not None:
fold_fit_params['sample_weight'] = weights[train]
est = clone(self.estimator).fit(X[train], y[train], **fold_fit_params)
score_func = partial(self.scorer_, est)
_base_score, _importances = self._get_score_importances(
score_func, X[test], y[test])
Expand Down
17 changes: 16 additions & 1 deletion tests/test_sklearn_permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from sklearn.base import is_classifier, is_regressor
from sklearn.svm import SVR, SVC
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.feature_selection import SelectFromModel
Expand Down Expand Up @@ -165,6 +165,7 @@ def test_explain_weights(iris_train):
for _expl in res:
assert "petal width (cm)" in _expl


def test_pandas_xgboost_support(iris_train):
xgboost = pytest.importorskip('xgboost')
pd = pytest.importorskip('pandas')
Expand All @@ -175,3 +176,17 @@ def test_pandas_xgboost_support(iris_train):
est.fit(X, y)
# we expect no exception to be raised here when using xgboost with pd.DataFrame
perm = PermutationImportance(est).fit(X, y)


def test_cv_sample_weight(iris_train):
X, y, feature_names, target_names = iris_train
weights_ones = np.ones(len(y))
model = RandomForestClassifier(random_state=42)

# we expect no exception to be raised when passing weights with a CV
perm_weights = PermutationImportance(model, cv=5, random_state=42).\
fit(X, y, sample_weight=weights_ones)
perm = PermutationImportance(model, cv=5, random_state=42).fit(X, y)

# passing a vector of weights filled with one should be the same as passing no weights
assert (perm.feature_importances_ == perm_weights.feature_importances_).all()

0 comments on commit 017c738

Please sign in to comment.