Commit f93f560
ENH add support for sample weights in MAE (scikit-learn#17225)
lucyleeow authored May 27, 2020
1 parent 6f33c5c commit f93f560
Showing 7 changed files with 254 additions and 58 deletions.
8 changes: 7 additions & 1 deletion doc/whats_new/v0.24.rst
@@ -84,7 +84,13 @@ Changelog
attribute name/path or a `callable` for extracting feature importance from
the estimator. :pr:`15361` by :user:`Venkatachalam N <venkyyuvy>`


:mod:`sklearn.metrics`
......................

- |Enhancement| Add `sample_weight` parameter to
  :func:`metrics.median_absolute_error`.
  :pr:`17225` by :user:`Lucy Liu <lucyleeow>`.
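
For a sense of what this enables, here is a minimal sketch of the new parameter in use (illustrative values, not from the commit):

import numpy as np
from sklearn.metrics import median_absolute_error

y_true = np.zeros(5)
y_pred = np.array([1.0, 2.0, 3.0, 4.0, 100.0])  # absolute errors: 1, 2, 3, 4, 100

print(median_absolute_error(y_true, y_pred))  # 3.0

# Zero-weighting the outlier removes its influence. Note the weighted path
# returns the lower weighted percentile, hence 2.0 rather than an
# interpolated 2.5.
w = np.array([1.0, 1.0, 1.0, 1.0, 0.0])
print(median_absolute_error(y_true, y_pred, sample_weight=w))  # 2.0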

:mod:`sklearn.tree`
...................

31 changes: 0 additions & 31 deletions sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py
@@ -8,7 +8,6 @@
import pytest

from sklearn.utils import check_random_state
from sklearn.utils.stats import _weighted_percentile
from sklearn.ensemble._gb_losses import RegressionLossFunction
from sklearn.ensemble._gb_losses import LeastSquaresError
from sklearn.ensemble._gb_losses import LeastAbsoluteError
@@ -103,36 +102,6 @@ def test_sample_weight_init_estimators():
    assert_allclose(out, sw_out, rtol=1e-2)


def test_weighted_percentile():
    y = np.empty(102, dtype=np.float64)
    y[:50] = 0
    y[-51:] = 2
    y[-1] = 100000
    y[50] = 1
    sw = np.ones(102, dtype=np.float64)
    sw[-1] = 0.0
    score = _weighted_percentile(y, sw, 50)
    assert score == 1


def test_weighted_percentile_equal():
    y = np.empty(102, dtype=np.float64)
    y.fill(0.0)
    sw = np.ones(102, dtype=np.float64)
    sw[-1] = 0.0
    score = _weighted_percentile(y, sw, 50)
    assert score == 0


def test_weighted_percentile_zero_weight():
    y = np.empty(102, dtype=np.float64)
    y.fill(1.0)
    sw = np.ones(102, dtype=np.float64)
    sw.fill(0.0)
    score = _weighted_percentile(y, sw, 50)
    assert score == 1.0
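
The first of these removed checks encodes the weighted-CDF arithmetic directly: the total weight is 101, so the 50th percentile lands on the lone 1.0 entry and the zero-weighted extreme value can never surface. A standalone sketch, assuming the helper keeps its private home in sklearn.utils.stats:

import numpy as np
from sklearn.utils.stats import _weighted_percentile  # private helper

y = np.zeros(102)
y[50] = 1
y[-51:] = 2
y[-1] = 100000    # extreme value...
sw = np.ones(102)
sw[-1] = 0.0      # ...carrying zero weight

# The cumulative weight reaches 50% of the total (50.5 of 101) at the
# single 1.0 entry, so the unweighted outlier is never selected.
assert _weighted_percentile(y, sw, 50) == 1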


def test_quantile_loss_function():
    # Non regression test for the QuantileLossFunction object
    # There was a sign problem when evaluating the function
17 changes: 15 additions & 2 deletions sklearn/metrics/_regression.py
@@ -30,6 +30,8 @@
_num_samples)
from ..utils.validation import column_or_1d
from ..utils.validation import _deprecate_positional_args
from ..utils.validation import _check_sample_weight
from ..utils.stats import _weighted_percentile
from ..exceptions import UndefinedMetricWarning


@@ -340,7 +342,8 @@ def mean_squared_log_error(y_true, y_pred, *,


@_deprecate_positional_args
def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average'):
def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average',
                          sample_weight=None):
    """Median absolute error regression loss

    Median absolute error output is non-negative floating point. The best value
@@ -365,6 +368,11 @@ def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average'):
        'uniform_average' :
            Errors of all outputs are averaged with uniform weight.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

        .. versionadded:: 0.24

    Returns
    -------
    loss : float or ndarray of floats
@@ -392,7 +400,12 @@ def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average',
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput)
output_errors = np.median(np.abs(y_pred - y_true), axis=0)
if sample_weight is None:
output_errors = np.median(np.abs(y_pred - y_true), axis=0)
else:
sample_weight = _check_sample_weight(sample_weight, y_pred)
output_errors = _weighted_percentile(np.abs(y_pred - y_true),
sample_weight=sample_weight)
if isinstance(multioutput, str):
if multioutput == 'raw_values':
return output_errors
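
One behavioral note on the branch above: the unweighted path interpolates between the two middle errors (np.median), while the weighted path returns the lower weighted percentile, so the two can disagree even for unit weights when the sample count is even. A sketch with illustrative values:

import numpy as np
from sklearn.metrics import median_absolute_error

y_true = np.zeros(4)
y_pred = np.array([0., 1., 2., 3.])  # absolute errors: 0, 1, 2, 3

print(median_absolute_error(y_true, y_pred))  # 1.5 (np.median interpolates)
print(median_absolute_error(y_true, y_pred,
                            sample_weight=np.ones(4)))  # 1.0 (lower percentile)

This is why the scorer tests below require an odd number of effective samples for neg_median_absolute_error.
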
70 changes: 55 additions & 15 deletions sklearn/metrics/tests/test_score_objects.py
@@ -33,7 +33,7 @@
from sklearn.linear_model import Ridge, LogisticRegression, Perceptron
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.datasets import make_blobs
from sklearn.datasets import make_classification
from sklearn.datasets import make_classification, make_regression
from sklearn.datasets import make_multilabel_classification
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split, cross_val_score
@@ -89,7 +89,7 @@ def _make_estimators(X_train, y_train, y_ml_train):
    # Make estimators that make sense to test various scoring methods
    sensible_regr = DecisionTreeRegressor(random_state=0)
    # some of the regression scorers require strictly positive input.
    sensible_regr.fit(X_train, y_train + 1)
    sensible_regr.fit(X_train, _require_positive_y(y_train))
    sensible_clf = DecisionTreeClassifier(random_state=0)
    sensible_clf.fit(X_train, y_train)
    sensible_ml_clf = DecisionTreeClassifier(random_state=0)
@@ -474,8 +474,9 @@ def test_raises_on_score_list():


@ignore_warnings
def test_scorer_sample_weight():
    # Test that scorers support sample_weight or raise sensible errors
def test_classification_scorer_sample_weight():
    # Test that classification scorers support sample_weight or raise sensible
    # errors

    # Unlike the metrics invariance test, in the scorer case it's harder
    # to ensure that, on the classifier output, weighted and unweighted
@@ -493,31 +494,70 @@ def test_scorer_sample_weight():
    estimator = _make_estimators(X_train, y_train, y_ml_train)

    for name, scorer in SCORERS.items():
        if name in REGRESSION_SCORERS:
            # skip the regression scores
            continue
        if name in MULTILABEL_ONLY_SCORERS:
            target = y_ml_test
        else:
            target = y_test
        if name in REQUIRE_POSITIVE_Y_SCORERS:
            target = _require_positive_y(target)
        try:
            weighted = scorer(estimator[name], X_test, target,
                              sample_weight=sample_weight)
            ignored = scorer(estimator[name], X_test[10:], target[10:])
            unweighted = scorer(estimator[name], X_test, target)
            assert weighted != unweighted, (
                "scorer {0} behaves identically when "
                "called with sample weights: {1} vs "
                "{2}".format(name, weighted, unweighted))
                f"scorer {name} behaves identically when called with "
                f"sample weights: {weighted} vs {unweighted}")
            assert_almost_equal(weighted, ignored,
                                err_msg="scorer {0} behaves differently when "
                                        "ignoring samples and setting sample_weight to"
                                        " 0: {1} vs {2}".format(name, weighted,
                                                                ignored))
                                err_msg=f"scorer {name} behaves differently "
                                        f"when ignoring samples and setting "
                                        f"sample_weight to 0: {weighted} vs {ignored}")

        except TypeError as e:
            assert "sample_weight" in str(e), (
                "scorer {0} raises unhelpful exception when called "
                "with sample weights: {1}".format(name, str(e)))
                f"scorer {name} raises unhelpful exception when called "
                f"with sample weights: {str(e)}")


@ignore_warnings
def test_regression_scorer_sample_weight():
    # Test that regression scorers support sample_weight or raise sensible
    # errors

    # Odd number of test samples req for neg_median_absolute_error
    X, y = make_regression(n_samples=101, n_features=20, random_state=0)
    y = _require_positive_y(y)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    sample_weight = np.ones_like(y_test)
    # Odd number req for neg_median_absolute_error
    sample_weight[:11] = 0

    reg = DecisionTreeRegressor(random_state=0)
    reg.fit(X_train, y_train)

    for name, scorer in SCORERS.items():
        if name not in REGRESSION_SCORERS:
            # skip classification scorers
            continue
        try:
            weighted = scorer(reg, X_test, y_test,
                              sample_weight=sample_weight)
            ignored = scorer(reg, X_test[11:], y_test[11:])
            unweighted = scorer(reg, X_test, y_test)
            assert weighted != unweighted, (
                f"scorer {name} behaves identically when called with "
                f"sample weights: {weighted} vs {unweighted}")
            assert_almost_equal(weighted, ignored,
                                err_msg=f"scorer {name} behaves differently "
                                        f"when ignoring samples and setting "
                                        f"sample_weight to 0: {weighted} vs {ignored}")

        except TypeError as e:
            assert "sample_weight" in str(e), (
                f"scorer {name} raises unhelpful exception when called "
                f"with sample weights: {str(e)}")


@pytest.mark.parametrize('name', SCORERS)
36 changes: 36 additions & 0 deletions sklearn/utils/fixes.py
@@ -165,3 +165,39 @@ class loguniform(scipy.stats.reciprocal):
)
class MaskedArray(_MaskedArray):
    pass  # TODO: remove in 0.25


def _take_along_axis(arr, indices, axis):
    """Implements a simplified version of np.take_along_axis if numpy
    version < 1.15"""
    if np_version > (1, 14):
        return np.take_along_axis(arr=arr, indices=indices, axis=axis)
    else:
        if axis is None:
            arr = arr.flatten()

        if not np.issubdtype(indices.dtype, np.intp):
            raise IndexError('`indices` must be an integer array')
        if arr.ndim != indices.ndim:
            raise ValueError(
                "`indices` and `arr` must have the same number of dimensions")

        shape_ones = (1,) * indices.ndim
        dest_dims = (
            list(range(axis)) +
            [None] +
            list(range(axis+1, indices.ndim))
        )

        # build a fancy index, consisting of orthogonal aranges, with the
        # requested index inserted at the right location
        fancy_index = []
        for dim, n in zip(dest_dims, arr.shape):
            if dim is None:
                fancy_index.append(indices)
            else:
                ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim+1:]
                fancy_index.append(np.arange(n).reshape(ind_shape))

        fancy_index = tuple(fancy_index)
        return arr[fancy_index]
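
Since the fallback reimplements np.take_along_axis, a quick illustration of the shared semantics may help; on NumPy >= 1.15 the helper simply delegates (example values are illustrative):

import numpy as np

values = np.array([[3., 1.],
                   [1., 2.],
                   [2., 0.]])
weights = np.array([[.5, .1],
                    [.2, .3],
                    [.3, .6]])

# Reorder each column of `weights` by the per-column sort order of `values`,
# exactly how _weighted_percentile pairs weights with sorted values.
order = np.argsort(values, axis=0)            # [[1 2] [2 0] [0 1]]
print(np.take_along_axis(weights, order, axis=0))
# [[0.2 0.6]
#  [0.3 0.1]
#  [0.5 0.3]]
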
61 changes: 52 additions & 9 deletions sklearn/utils/stats.py
@@ -1,18 +1,61 @@
import numpy as np

from .extmath import stable_cumsum
from .fixes import _take_along_axis


def _weighted_percentile(array, sample_weight, percentile=50):
    """Compute weighted percentile

    Computes lower weighted percentile. If `array` is a 2D array, the
    `percentile` is computed along the axis 0.

        .. versionchanged:: 0.24
            Accepts 2D `array`.

    Parameters
    ----------
    array : 1D or 2D array
        Values to take the weighted percentile of.

    sample_weight: 1D or 2D array
        Weights for each value in `array`. Must be same shape as `array` or
        of shape `(array.shape[0],)`.

    percentile: int, default=50
        Percentile to compute. Must be value between 0 and 100.

    Returns
    -------
    percentile : int if `array` 1D, ndarray if `array` 2D
        Weighted percentile.
    """
    """
    Compute the weighted ``percentile`` of ``array`` with ``sample_weight``.
    """
    sorted_idx = np.argsort(array)
    n_dim = array.ndim
    if n_dim == 0:
        return array[()]
    if array.ndim == 1:
        array = array.reshape((-1, 1))
    # When sample_weight 1D, repeat for each array.shape[1]
    if (array.shape != sample_weight.shape and
            array.shape[0] == sample_weight.shape[0]):
        sample_weight = np.tile(sample_weight, (array.shape[1], 1)).T
    sorted_idx = np.argsort(array, axis=0)
    sorted_weights = _take_along_axis(sample_weight, sorted_idx, axis=0)

    # Find index of median prediction for each sample
    weight_cdf = stable_cumsum(sample_weight[sorted_idx])
    percentile_idx = np.searchsorted(
        weight_cdf, (percentile / 100.) * weight_cdf[-1])
    # in rare cases, percentile_idx equals to len(sorted_idx)
    percentile_idx = np.clip(percentile_idx, 0, len(sorted_idx)-1)
    return array[sorted_idx[percentile_idx]]
    weight_cdf = stable_cumsum(sorted_weights, axis=0)
    adjusted_percentile = percentile / 100 * weight_cdf[-1]
    percentile_idx = np.array([
        np.searchsorted(weight_cdf[:, i], adjusted_percentile[i])
        for i in range(weight_cdf.shape[1])
    ])
    percentile_idx = np.array(percentile_idx)
    # In rare cases, percentile_idx equals to sorted_idx.shape[0]
    max_idx = sorted_idx.shape[0] - 1
    percentile_idx = np.apply_along_axis(lambda x: np.clip(x, 0, max_idx),
                                         axis=0, arr=percentile_idx)

    col_index = np.arange(array.shape[1])
    percentile_in_sorted = sorted_idx[percentile_idx, col_index]
    percentile = array[percentile_in_sorted, col_index]
    return percentile[0] if n_dim == 1 else percentile
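
For intuition, the 1D case reduces to a search on the cumulative weight; a minimal sketch with plain np.cumsum standing in for stable_cumsum (illustrative values):

import numpy as np

array = np.array([1., 2., 3., 4.])
sample_weight = np.array([1., 1., 1., 3.])

sorted_idx = np.argsort(array)
weight_cdf = np.cumsum(sample_weight[sorted_idx])  # [1. 2. 3. 6.]
adjusted = 50 / 100 * weight_cdf[-1]               # 3.0
idx = np.searchsorted(weight_cdf, adjusted)        # 2
print(array[sorted_idx[idx]])                      # 3.0, the weighted median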