Skip to content

Commit

Permalink
Merge pull request scikit-learn#5678 from betatim/no-warning-iforest
Browse files Browse the repository at this point in the history
[MRG+1] IsolationForest max_samples warning and calculation
  • Loading branch information
glouppe committed Nov 11, 2015
2 parents a5d6144 + b767117 commit 889e2d4
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 23 deletions.
13 changes: 10 additions & 3 deletions sklearn/ensemble/bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ def fit(self, X, y, sample_weight=None):
self : object
Returns self.
"""
return self._fit(X, y, self.max_samples, sample_weight)
return self._fit(X, y, self.max_samples, sample_weight=sample_weight)

def _fit(self, X, y, max_samples, sample_weight=None):
def _fit(self, X, y, max_samples, max_depth=None, sample_weight=None):
"""Build a Bagging ensemble of estimators from the training
set (X, y).
Expand All @@ -267,6 +267,10 @@ def _fit(self, X, y, max_samples, sample_weight=None):
max_samples : int or float, optional (default=None)
Argument to use instead of self.max_samples.
max_depth : int, optional (default=None)
Override value used when constructing base estimator. Only
supported if the base estimator has a max_depth parameter.
sample_weight : array-like, shape = [n_samples] or None
Sample weights. If None, then samples are equally weighted.
Note that this is supported only if the base estimator supports
Expand All @@ -289,9 +293,12 @@ def _fit(self, X, y, max_samples, sample_weight=None):
# Check parameters
self._validate_estimator()

if max_depth is not None:
self.base_estimator_.max_depth = max_depth

# if max_samples is float:
if not isinstance(max_samples, (numbers.Integral, np.integer)):
max_samples = int(self.max_samples * X.shape[0])
max_samples = int(max_samples * X.shape[0])

if not (0 < max_samples <= X.shape[0]):
raise ValueError("max_samples must be in (0, n_samples]")
Expand Down
48 changes: 31 additions & 17 deletions sklearn/ensemble/iforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from scipy.sparse import issparse

from ..externals import six
from ..externals.joblib import Parallel, delayed
from ..tree import ExtraTreeRegressor
from ..utils import check_random_state, check_array
Expand Down Expand Up @@ -47,10 +48,11 @@ class IsolationForest(BaseBagging):
n_estimators : int, optional (default=100)
The number of base estimators in the ensemble.
max_samples : int or float, optional (default=256)
max_samples : int or float, optional (default="auto")
The number of samples to draw from X to train each base estimator.
- If int, then draw `max_samples` samples.
- If float, then draw `max_samples * X.shape[0]` samples.
- If "auto", then `max_samples=min(256, n_samples)`.
If max_samples is larger than the number of samples provided,
all samples will be used for all trees (no sampling).
Expand Down Expand Up @@ -99,15 +101,14 @@ class IsolationForest(BaseBagging):

def __init__(self,
n_estimators=100,
max_samples=256,
max_samples="auto",
max_features=1.,
bootstrap=False,
n_jobs=1,
random_state=None,
verbose=0):
super(IsolationForest, self).__init__(
base_estimator=ExtraTreeRegressor(
max_depth=int(np.ceil(np.log2(max(max_samples, 2)))),
max_features=1,
splitter='random',
random_state=random_state),
Expand Down Expand Up @@ -151,16 +152,34 @@ def fit(self, X, y=None, sample_weight=None):
y = rnd.uniform(size=X.shape[0])

# ensure that max_sample is in [1, n_samples]:
max_samples = self.max_samples
n_samples = X.shape[0]
if max_samples > n_samples:
warn("max_samples (%s) is greater than the "
"total number of samples (%s). max_samples "
"will be set to n_samples for estimation."
% (self.max_samples, n_samples))
max_samples = n_samples

if isinstance(self.max_samples, six.string_types):
if self.max_samples == 'auto':
max_samples = min(256, n_samples)
else:
raise ValueError('max_samples (%s) is not supported.'
'Valid choices are: "auto", int or'
'float' % self.max_samples)

elif isinstance(self.max_samples, six.integer_types):
if self.max_samples > n_samples:
warn("max_samples (%s) is greater than the "
"total number of samples (%s). max_samples "
"will be set to n_samples for estimation."
% (self.max_samples, n_samples))
max_samples = n_samples
else:
max_samples = self.max_samples
else: # float
if not (0. < self.max_samples <= 1.):
raise ValueError("max_samples must be in (0, 1]")
max_samples = int(self.max_samples * X.shape[0])

self.max_samples_ = max_samples
max_depth = int(np.ceil(np.log2(max(max_samples, 2))))
super(IsolationForest, self)._fit(X, y, max_samples,
max_depth=max_depth,
sample_weight=sample_weight)
return self

Expand Down Expand Up @@ -206,12 +225,7 @@ def predict(self, X):

depths += _average_path_length(n_samples_leaf)

if not isinstance(self.max_samples, (numbers.Integral, np.integer)):
max_samples = int(self.max_samples * X.shape[0])
else:
max_samples = self.max_samples

scores = 2 ** (-depths.mean(axis=1) / _average_path_length(max_samples))
scores = 2 ** (-depths.mean(axis=1) / _average_path_length(self.max_samples_))

return scores

Expand Down Expand Up @@ -249,7 +263,7 @@ def _average_path_length(n_samples_leaf):
average_path_length : array, same shape as n_samples_leaf
"""
if isinstance(n_samples_leaf, int):
if isinstance(n_samples_leaf, six.integer_types):
if n_samples_leaf <= 1:
return 1.
else:
Expand Down
37 changes: 34 additions & 3 deletions sklearn/ensemble/tests/test_iforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_warns
from sklearn.utils.testing import assert_warns_message
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_no_warnings
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import ignore_warnings

Expand Down Expand Up @@ -95,9 +98,37 @@ def test_iforest_error():
IsolationForest(max_samples=0.0).fit, X)
assert_raises(ValueError,
IsolationForest(max_samples=2.0).fit, X)
assert_warns(UserWarning,
IsolationForest(max_samples=1000).fit, X)
# cannot check for string values
# The dataset has less than 256 samples, explicitly setting max_samples > n_samples
# should result in a warning. If not set explicitly there should be no warning
assert_warns_message(UserWarning,
"max_samples will be set to n_samples for estimation",
IsolationForest(max_samples=1000).fit, X)
assert_no_warnings(IsolationForest(max_samples='auto').fit, X)
assert_raises(ValueError,
IsolationForest(max_samples='foobar').fit, X)


def test_recalculate_max_depth():
    """max_depth of each tree must track the effective sample count.

    With the default max_samples='auto' and iris (n_samples < 256),
    every sample is used, so max_depth should be ceil(log2(n_samples)).
    """
    X = iris.data
    forest = IsolationForest().fit(X)
    expected_depth = int(np.ceil(np.log2(X.shape[0])))
    for tree in forest.estimators_:
        assert_equal(tree.max_depth, expected_depth)


def test_max_samples_attribute():
    """Check that the fitted ``max_samples_`` attribute reflects the
    effective number of samples used to build each tree."""
    X = iris.data

    # Default 'auto' caps at min(256, n_samples); iris has fewer than 256
    # samples, so all of them are used.
    clf = IsolationForest().fit(X)
    assert_equal(clf.max_samples_, X.shape[0])

    # An explicit int larger than n_samples warns and is clipped to n_samples.
    clf = IsolationForest(max_samples=500)
    assert_warns_message(UserWarning,
                         "max_samples will be set to n_samples for estimation",
                         clf.fit, X)
    assert_equal(clf.max_samples_, X.shape[0])

    # A float is a fraction of n_samples, truncated to int by the estimator.
    # Compare against the truncated value rather than the raw float product
    # (0.4 * X.shape[0]), which is fragile to floating-point rounding.
    clf = IsolationForest(max_samples=0.4).fit(X)
    assert_equal(clf.max_samples_, int(0.4 * X.shape[0]))


def test_iforest_parallel_regression():
Expand Down

0 comments on commit 889e2d4

Please sign in to comment.