diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 7514a146b5635..d689e6bc63c84 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1186,6 +1186,19 @@ def test_only_constant_features(): assert_equal(est.tree_.max_depth, 0) +def test_behaviour_constant_feature_after_splits(): + X = np.transpose(np.vstack(([[0, 0, 0, 0, 0, 1, 2, 4, 5, 6, 7]], + np.zeros((4, 11))))) + y = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3] + for name, TreeEstimator in ALL_TREES.items(): + # do not check extra random trees + if "ExtraTree" not in name: + est = TreeEstimator(random_state=0, max_features=1) + est.fit(X, y) + assert_equal(est.tree_.max_depth, 2) + assert_equal(est.tree_.node_count, 5) + + def test_with_only_one_non_constant_features(): X = np.hstack([np.array([[1.], [1.], [0.], [0.]]), np.zeros((4, 1000))])