Save last tree, and make multiple evals throw warning.
jinlow committed Sep 7, 2023
1 parent 7faeb03 commit 7b27c38
Showing 3 changed files with 48 additions and 7 deletions.
4 changes: 4 additions & 0 deletions py-forust/forust/__init__.py
@@ -467,6 +467,10 @@ def fit(
                        eval_w_,
                    )
                )
            if len(evaluation_data_) > 1:
                warnings.warn(
                    "Multiple evaluation datasets passed, only the first one will be used to determine early stopping."
                )
        else:
            evaluation_data_ = None

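For context, a hypothetical usage sketch of the behavior this change introduces (synthetic data, and it assumes numpy arrays are accepted by fit() the same way the pandas objects in the tests are): passing more than one evaluation dataset now emits a UserWarning, and only the first dataset decides early stopping, while the others are still evaluated.

import warnings

import numpy as np
from forust import GradientBooster

rng = np.random.default_rng(0)
X = rng.normal(size=(1_000, 5))
y = (X[:, 0] + rng.normal(scale=0.5, size=1_000) > 0).astype("float64")
X_train, y_train = X[:600], y[:600]
X_val_a, y_val_a = X[600:800], y[600:800]
X_val_b, y_val_b = X[800:], y[800:]

model = GradientBooster(early_stopping_rounds=4, iterations=100)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.fit(
        X_train,
        y_train,
        evaluation_data=[
            (X_val_a, y_val_a),  # drives early stopping
            (X_val_b, y_val_b),  # still evaluated, but ignored for stopping
        ],
    )
assert any("Multiple evaluation datasets" in str(w.message) for w in caught)

The constructor keywords and the evaluation_data argument mirror the ones exercised in the new test below.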
36 changes: 35 additions & 1 deletion py-forust/tests/test_booster.py
@@ -702,7 +702,41 @@ def test_early_stopping_with_dev(X_y):

    # Did we actually stop?
    n_trees = json.loads(model.json_dump())["trees"]
    assert len(n_trees) == model.get_best_iteration() + 4
    # Didn't improve for 4 rounds, and one more triggered it.
    assert len(n_trees) == model.get_best_iteration() + 5
    assert len(n_trees) == model.get_evaluation_history().shape[0]
    assert model.get_best_iteration() < 99


def test_early_stopping_with_dev_val(X_y):
    X, y = X_y

    val = y.index.to_series().isin(y.sample(frac=0.25, random_state=0))

    w = np.random.default_rng(0).uniform(0.5, 1.0, size=val.sum())

    model = GradientBooster(log_iterations=1, early_stopping_rounds=4, iterations=100)

    with pytest.warns(UserWarning, match="Multiple evaluation datasets"):
        model.fit(
            X.loc[~val, :],
            y.loc[~val],
            evaluation_data=[
                (X.loc[val, :], y.loc[val]),
                (X.loc[val, :], y.loc[val], w),
                (X.loc[~val, :], y.loc[~val]),
            ],
        )

    # the weight is being used.
    history = model.get_evaluation_history()
    assert history[2, 0] != history[2, 1]

    # Did we actually stop?
    n_trees = json.loads(model.json_dump())["trees"]
    # Didn't improve for 4 rounds, and one more triggered it.
    assert len(n_trees) == model.get_best_iteration() + 5
    assert len(n_trees) == model.get_evaluation_history().shape[0]
    assert model.get_best_iteration() < 99


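As a reading aid for the assertions above, a small illustrative sketch of how get_evaluation_history() appears to be laid out (the metric values are invented): one row per boosting round and one column per evaluation dataset in the order passed to fit(), so the unweighted first set (column 0) scoring differently from its weighted copy (column 1) at the same row shows that the sample weights are applied.

import numpy as np

# Invented log-loss-style values: 3 rounds, 3 evaluation sets as in the test above.
history = np.array(
    [
        [0.6931, 0.6902, 0.6929],
        [0.6523, 0.6498, 0.6519],
        [0.6204, 0.6175, 0.6199],
    ]
)
assert history[2, 0] != history[2, 1]  # weighting the second set changes its metric
assert history.shape[0] == 3           # one row per tree kept in the model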
15 changes: 9 additions & 6 deletions src/gradientbooster.rs
@@ -504,6 +504,9 @@ impl GradientBooster {

        let mut best_metric: Option<f64> = None;

        // This will always be false, unless early stopping rounds are used.
        let mut stop_early = false;

        for i in 0..self.iterations {
            let verbose = if self.log_iterations == 0 {
                false
@@ -531,9 +534,6 @@

            self.update_predictions_inplace(&mut yhat, &tree, data);

            // This will always be false, unless early stopping rounds are used.
            let mut stop_early = false;

            // Update Evaluation data, if it's needed.
            if let Some(eval_sets) = &mut evaluation_sets {
                if self.evaluation_history.is_none() {
@@ -592,11 +592,14 @@
                if let Some(history) = &mut self.evaluation_history {
                    history.append_row(metrics);
                }
                if stop_early {
                    break;
                }
            }
            self.trees.push(tree);

            // Did we trigger the early stopping rounds criteria?
            if stop_early {
                break;
            }

            (grad, hess) = calc_grad_hess(y, &yhat, sample_weight);
            if verbose {
                info!("Completed iteration {} of {}", i, self.iterations);
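To make the control-flow change above easier to follow, here is a minimal Python sketch of the reordered training loop (a simplification, not the actual Rust implementation; build_tree and evaluate are hypothetical placeholders): the early-stopping break now happens after the tree is appended, so the tree that produced the final evaluation row is kept and the number of trees matches the number of rows in the evaluation history.

def fit_loop(iterations, early_stopping_rounds, build_tree, evaluate):
    trees, history = [], []
    best_metric, best_iteration = None, 0
    stop_early = False
    for i in range(iterations):
        tree = build_tree()
        metric = evaluate(tree)  # metric on the first evaluation dataset
        history.append(metric)
        # Assumes a lower-is-better metric such as log loss.
        if best_metric is None or metric < best_metric:
            best_metric, best_iteration = metric, i
        elif i - best_iteration >= early_stopping_rounds:
            stop_early = True
        trees.append(tree)  # the tree that triggered the stop is saved first...
        if stop_early:      # ...and only then does the loop break
            break
    return trees, history

With early_stopping_rounds=4 and the best score at iteration b, this sketch stops at iteration b + 4 and keeps b + 5 trees, which lines up with the "+ 5" assertions in the updated tests.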
