Skip to content

Commit

Permalink
TST Remove Boston dataset in test_mlp (scikit-learn#17337)
Browse files (browse the repository at this point in the history)
  • Loading branch information
lucyleeow authored May 26, 2020
1 parent 6b68144 commit 6f33c5c
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions sklearn/neural_network/tests/test_mlp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,17 +12,21 @@

import numpy as np

- from numpy.testing import assert_almost_equal, assert_array_equal
+ from numpy.testing import (
+ assert_almost_equal,
+ assert_array_equal,
+ assert_allclose,
+ )

- from sklearn.datasets import load_digits, load_boston, load_iris
+ from sklearn.datasets import load_digits, load_iris
from sklearn.datasets import make_regression, make_multilabel_classification
from sklearn.exceptions import ConvergenceWarning
from io import StringIO
from sklearn.metrics import roc_auc_score
from sklearn.neural_network import MLPClassifier
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import LabelBinarizer
- from sklearn.preprocessing import StandardScaler, MinMaxScaler
+ from sklearn.preprocessing import MinMaxScaler, scale
from scipy.sparse import csr_matrix
from sklearn.utils._testing import ignore_warnings

Expand All @@ -42,12 +46,10 @@
classification_datasets = [(X_digits_multi, y_digits_multi),
(X_digits_binary, y_digits_binary)]

- boston = load_boston()
-
- Xboston = StandardScaler().fit_transform(boston.data)[: 200]
- yboston = boston.target[:200]
-
- regression_datasets = [(Xboston, yboston)]
+ X_reg, y_reg = make_regression(n_samples=200, n_features=10, bias=20.,
+ noise=100., random_state=7)
+ y_reg = scale(y_reg)
+ regression_datasets = [(X_reg, y_reg)]

iris = load_iris()

Expand Down Expand Up @@ -252,17 +254,17 @@ def test_lbfgs_classification(X, y):

@pytest.mark.parametrize('X,y', regression_datasets)
def test_lbfgs_regression(X, y):
- # Test lbfgs on the boston dataset, a regression problems.
+ # Test lbfgs on the regression dataset.
for activation in ACTIVATION_TYPES:
mlp = MLPRegressor(solver='lbfgs', hidden_layer_sizes=50,
max_iter=150, shuffle=True, random_state=1,
activation=activation)
mlp.fit(X, y)
if activation == 'identity':
- assert mlp.score(X, y) > 0.84
+ assert mlp.score(X, y) > 0.80
else:
# Non linear models perform much better than linear bottleneck:
- assert mlp.score(X, y) > 0.95
+ assert mlp.score(X, y) > 0.98


@pytest.mark.parametrize('X,y', classification_datasets)
Expand All @@ -287,7 +289,7 @@ def test_lbfgs_regression_maxfun(X, y):
max_fun = 10
# regression tests
for activation in ACTIVATION_TYPES:
- mlp = MLPRegressor(solver='lbfgs', hidden_layer_sizes=50,
+ mlp = MLPRegressor(solver='lbfgs', hidden_layer_sizes=50, tol=0.0,
max_iter=150, max_fun=max_fun, shuffle=True,
random_state=1, activation=activation)
with pytest.warns(ConvergenceWarning):
Expand Down Expand Up @@ -400,8 +402,8 @@ def test_partial_fit_unseen_classes():
def test_partial_fit_regression():
# Test partial_fit on regression.
# `partial_fit` should yield the same results as 'fit' for regression.
- X = Xboston
- y = yboston
+ X = X_reg
+ y = y_reg

for momentum in [0, .9]:
mlp = MLPRegressor(solver='sgd', max_iter=100, activation='relu',
Expand All @@ -418,9 +420,9 @@ def test_partial_fit_regression():
mlp.partial_fit(X, y)

pred2 = mlp.predict(X)
- assert_almost_equal(pred1, pred2, decimal=2)
+ assert_allclose(pred1, pred2)
score = mlp.score(X, y)
- assert score > 0.75
+ assert score > 0.65


def test_partial_fit_errors():
Expand Down

0 comments on commit 6f33c5c

Please sign in to comment.