diff --git a/py-forust/forust/__init__.py b/py-forust/forust/__init__.py
index c10b84c..aa705ed 100644
--- a/py-forust/forust/__init__.py
+++ b/py-forust/forust/__init__.py
@@ -467,6 +467,10 @@ def fit(
                         eval_w_,
                     )
                 )
+            if len(evaluation_data_) > 1:
+                warnings.warn(
+                    "Multiple evaluation datasets passed; only the first one will be used to determine early stopping."
+                )
         else:
             evaluation_data_ = None

diff --git a/py-forust/tests/test_booster.py b/py-forust/tests/test_booster.py
index 536299c..e4f50ce 100644
--- a/py-forust/tests/test_booster.py
+++ b/py-forust/tests/test_booster.py
@@ -702,7 +702,41 @@ def test_early_stopping_with_dev(X_y):

     # Did we actually stop?
     n_trees = json.loads(model.json_dump())["trees"]
-    assert len(n_trees) == model.get_best_iteration() + 4
+    # Didn't improve for 4 rounds, and one more round triggered the stop.
+    assert len(n_trees) == model.get_best_iteration() + 5
+    assert len(n_trees) == model.get_evaluation_history().shape[0]
+    assert model.get_best_iteration() < 99
+
+
+def test_early_stopping_with_dev_val(X_y):
+    X, y = X_y
+
+    val = y.index.to_series().isin(y.sample(frac=0.25, random_state=0).index)
+
+    w = np.random.default_rng(0).uniform(0.5, 1.0, size=val.sum())
+
+    model = GradientBooster(log_iterations=1, early_stopping_rounds=4, iterations=100)
+
+    with pytest.warns(UserWarning, match="Multiple evaluation datasets"):
+        model.fit(
+            X.loc[~val, :],
+            y.loc[~val],
+            evaluation_data=[
+                (X.loc[val, :], y.loc[val]),
+                (X.loc[val, :], y.loc[val], w),
+                (X.loc[~val, :], y.loc[~val]),
+            ],
+        )
+
+    # The sample weight is being used: the weighted eval set yields a different metric.
+    history = model.get_evaluation_history()
+    assert history[2, 0] != history[2, 1]
+
+    # Did we actually stop?
+    n_trees = json.loads(model.json_dump())["trees"]
+    # Didn't improve for 4 rounds, and one more round triggered the stop.
+    assert len(n_trees) == model.get_best_iteration() + 5
+    assert len(n_trees) == model.get_evaluation_history().shape[0]
     assert model.get_best_iteration() < 99


diff --git a/src/gradientbooster.rs b/src/gradientbooster.rs
index da5bdae..7ddedfa 100644
--- a/src/gradientbooster.rs
+++ b/src/gradientbooster.rs
@@ -504,6 +504,9 @@ impl GradientBooster {

         let mut best_metric: Option<f64> = None;

+        // This will always be false, unless early stopping rounds are used.
+        let mut stop_early = false;
+
         for i in 0..self.iterations {
             let verbose = if self.log_iterations == 0 {
                 false
@@ -531,9 +534,6 @@

             self.update_predictions_inplace(&mut yhat, &tree, data);

-            // This will always be false, unless early stopping rounds are used.
-            let mut stop_early = false;
-
             // Update Evaluation data, if it's needed.
             if let Some(eval_sets) = &mut evaluation_sets {
                 if self.evaluation_history.is_none() {
@@ -592,11 +592,14 @@
                 if let Some(history) = &mut self.evaluation_history {
                     history.append_row(metrics);
                 }
-                if stop_early {
-                    break;
-                }
             }
             self.trees.push(tree);
+
+            // Did we trigger the early stopping rounds criteria?
+            if stop_early {
+                break;
+            }
+
             (grad, hess) = calc_grad_hess(y, &yhat, sample_weight);
             if verbose {
                 info!("Completed iteration {} of {}", i, self.iterations);