Merge pull request #68 from jinlow/feature/add-eval-warning
Save the last tree, and make multiple evals throw a warning.
jinlow authored Sep 7, 2023
2 parents 7faeb03 + 8787abc commit 5c0d30e
Showing 7 changed files with 53 additions and 12 deletions.
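
In practice, the Python `fit` wrapper now emits a `UserWarning` whenever `evaluation_data` contains more than one dataset, because only the first dataset is used to decide early stopping. Below is a minimal sketch of that behavior, reusing the `GradientBooster` arguments exercised by the updated test; the synthetic data and the `warnings.catch_warnings` scaffolding are illustrative assumptions, not part of this commit.

```python
import warnings

import numpy as np
import pandas as pd
from forust import GradientBooster

# Illustrative synthetic data: numeric features and a binary target.
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(1000, 5)), columns=[f"f{i}" for i in range(5)])
y = pd.Series((X["f0"] + rng.normal(size=1000) > 0).astype(float))

train = X.index < 750
model = GradientBooster(iterations=100, early_stopping_rounds=4)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.fit(
        X.loc[train, :],
        y.loc[train],
        # Only the first evaluation set determines early stopping;
        # any additional set now triggers a UserWarning.
        evaluation_data=[
            (X.loc[~train, :], y.loc[~train]),
            (X.loc[train, :], y.loc[train]),
        ],
    )

assert any("Multiple evaluation datasets" in str(w.message) for w in caught)
```
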
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "forust-ml"
-version = "0.2.22"
+version = "0.2.23"
edition = "2021"
authors = ["James Inlow <[email protected]>"]
homepage = "https://github.com/jinlow/forust"
2 changes: 1 addition & 1 deletion README.md
@@ -29,7 +29,7 @@ pip install forust

To use in a rust project add the following to your Cargo.toml file.
```toml
-forust-ml = "0.2.22"
+forust-ml = "0.2.23"
```

## Usage
4 changes: 2 additions & 2 deletions py-forust/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "py-forust"
-version = "0.2.22"
+version = "0.2.23"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -10,7 +10,7 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.19.0", features = ["extension-module"] }
-forust-ml = { version = "0.2.22", path = "../" }
+forust-ml = { version = "0.2.23", path = "../" }
numpy = "0.19.0"
ndarray = "0.15.1"
serde_plain = { version = "1.0" }
4 changes: 4 additions & 0 deletions py-forust/forust/__init__.py
@@ -467,6 +467,10 @@ def fit(
eval_w_,
)
)
+if len(evaluation_data_) > 1:
+warnings.warn(
+"Multiple evaluation datasets passed, only the first one will be used to determine early stopping."
+)
else:
evaluation_data_ = None

36 changes: 35 additions & 1 deletion py-forust/tests/test_booster.py
@@ -702,7 +702,41 @@ def test_early_stopping_with_dev(X_y):

# Did we actually stop?
n_trees = json.loads(model.json_dump())["trees"]
-assert len(n_trees) == model.get_best_iteration() + 4
+# Didn't improve for 4 rounds, and one more triggered it.
+assert len(n_trees) == model.get_best_iteration() + 5
assert len(n_trees) == model.get_evaluation_history().shape[0]
assert model.get_best_iteration() < 99


+def test_early_stopping_with_dev_val(X_y):
+X, y = X_y

+val = y.index.to_series().isin(y.sample(frac=0.25, random_state=0))

+w = np.random.default_rng(0).uniform(0.5, 1.0, size=val.sum())

+model = GradientBooster(log_iterations=1, early_stopping_rounds=4, iterations=100)

+with pytest.warns(UserWarning, match="Multiple evaluation datasets"):
+model.fit(
+X.loc[~val, :],
+y.loc[~val],
+evaluation_data=[
+(X.loc[val, :], y.loc[val]),
+(X.loc[val, :], y.loc[val], w),
+(X.loc[~val, :], y.loc[~val]),
+],
+)

+# the weight is being used.
+history = model.get_evaluation_history()
+assert history[2, 0] != history[2, 1]

+# Did we actually stop?
+n_trees = json.loads(model.json_dump())["trees"]
+# Didn't improve for 4 rounds, and one more triggered it.
+assert len(n_trees) == model.get_best_iteration() + 5
+assert len(n_trees) == model.get_evaluation_history().shape[0]
+assert model.get_best_iteration() < 99


2 changes: 1 addition & 1 deletion rs-example.md
@@ -3,7 +3,7 @@
To run this example, add the following code to your `Cargo.toml` file.
```toml
[dependencies]
-forust-ml = "0.2.22"
+forust-ml = "0.2.23"
polars = "0.28"
reqwest = { version = "0.11", features = ["blocking"] }
```
15 changes: 9 additions & 6 deletions src/gradientbooster.rs
@@ -504,6 +504,9 @@ impl GradientBooster {

let mut best_metric: Option<f64> = None;

+// This will always be false, unless early stopping rounds are used.
+let mut stop_early = false;

for i in 0..self.iterations {
let verbose = if self.log_iterations == 0 {
false
@@ -531,9 +534,6 @@

self.update_predictions_inplace(&mut yhat, &tree, data);

-// This will always be false, unless early stopping rounds are used.
-let mut stop_early = false;

// Update Evaluation data, if it's needed.
if let Some(eval_sets) = &mut evaluation_sets {
if self.evaluation_history.is_none() {
@@ -592,11 +592,14 @@
if let Some(history) = &mut self.evaluation_history {
history.append_row(metrics);
}
-if stop_early {
-break;
-}
}
self.trees.push(tree);

+// Did we trigger the early stopping rounds criteria?
+if stop_early {
+break;
+}

(grad, hess) = calc_grad_hess(y, &yhat, sample_weight);
if verbose {
info!("Completed iteration {} of {}", i, self.iterations);
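
The src/gradientbooster.rs change implements the "save last tree" half of the commit message: the `stop_early` flag is hoisted above the boosting loop, and the early-stopping `break` now runs only after `self.trees.push(tree)`, so the tree grown on the round that triggered stopping is kept; that is why the tests above now expect `get_best_iteration() + 5` trees instead of `+ 4`. A small Python sketch of the corrected loop order (the loss sequence, the stand-in tree objects, and the helper name `boost` are hypothetical simplifications, not the real Rust internals):

```python
def boost(losses, iterations=100, early_stopping_rounds=4):
    """Toy model of the corrected loop: push the tree first, then check the flag."""
    trees = []  # stand-in for self.trees
    best_metric = None
    best_iteration = 0
    # This will always be False, unless early stopping rounds are used.
    stop_early = False

    for i in range(min(iterations, len(losses))):
        metric = losses[i]  # stand-in for the first evaluation set's metric
        if best_metric is None or metric < best_metric:
            best_metric = metric
            best_iteration = i
        elif i - best_iteration >= early_stopping_rounds:
            stop_early = True

        trees.append(f"tree_{i}")  # the last tree is saved before breaking

        # Did we trigger the early stopping rounds criteria?
        if stop_early:
            break

    return trees, best_iteration


# Best metric at round 2, then no further improvement: the loop stops after
# early_stopping_rounds more rounds, and the final tree is still kept.
trees, best = boost([0.9, 0.8, 0.5, 0.6, 0.7, 0.65, 0.7, 0.8])
assert best == 2
assert len(trees) == best + 5  # matches the updated test assertions
```
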
