feat: EnsembleModel accepts pretrained global models #1815

Merged
merged 59 commits into from
Sep 16, 2023

Conversation

madtoinou
Collaborator

Fixes #1775 and fixes #1785.

Summary

  • EnsembleModel accepts pre-trained models if they are all global and all already fitted.
  • RegressionEnsembleModel.fit() accepts an additional argument, retrain_forecasting_model, which is False by default to avoid catastrophic forgetting. retrain_forecasting_model=True is accepted only if all the forecasting models are global and pre-trained.

Other Information

  • If a NaiveEnsembleModel is instantiated with pre-trained global models, the constructor sets self._fit_called=True so that predict() can be called directly without having to call fit(). I don't know if this is the most intuitive behavior, but for single time series training/prediction it makes the workflow more natural.
  • Added some very basic tests.
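
The acceptance rule and the _fit_called shortcut described above can be sketched in plain Python. This is only an illustration, not the Darts implementation: ToyModel, ToyNaiveEnsemble, is_global and fit_called are hypothetical names, and the error messages are made up.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyModel:
    """Hypothetical stand-in for a forecasting model (not a Darts class)."""
    name: str
    is_global: bool           # True if the model can be trained on several series
    fit_called: bool = False  # True once the model has been trained


class ToyNaiveEnsemble:
    """Hypothetical ensemble mirroring the rules from the PR description."""

    def __init__(self, forecasting_models: List[ToyModel]):
        fitted = [m.fit_called for m in forecasting_models]
        if any(fitted):
            # pre-trained models are accepted only if all are fitted and all are global
            if not all(fitted):
                raise ValueError("Cannot mix fitted and unfitted models.")
            if not all(m.is_global for m in forecasting_models):
                raise ValueError("Pre-trained models must all be global.")
        self.forecasting_models = forecasting_models
        # with only pre-trained models, predict() may be called without fit()
        self._fit_called = len(fitted) > 0 and all(fitted)

    def predict(self, n: int) -> str:
        if not self._fit_called:
            raise RuntimeError("fit() must be called before predict().")
        return f"forecast of {n} steps from {len(self.forecasting_models)} models"
```

With two fitted global toy models, predict() works straight after construction; with unfitted models it raises until fit() has been called (fit() is left out of the sketch).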

Collaborator

@dennisbader dennisbader left a comment

Thanks @madtoinou , this is a great start 🚀

I have a couple of points:

  • let's move retrain_forecasting_models to the constructor __init__() of the base EnsembleModel
  • retrain_forecasting_models should be True by default to maintain the same behavior as we had previously.
    • if retrain_forecasting_models=True, we use forecasting_models = [model.untrained_model() for model in forecasting_models]. This way we will not mess with any pre-trained model.
      • we can log an info if there were any pre-trained models in the original forecasting_models
  • I think it would be a good idea at this point to rename models to forecasting_models in all ensemble models (base, naive) to have a common naming. This is a breaking change, but I think it is acceptable along with all the other changes of the next release.
  • In theory NaiveEnsembleModel could be used without fit() if all the models were pretrained. Should we allow this?
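
A minimal sketch of the retraining suggestion above, assuming a toy model class. ToyModel and prepare_forecasting_models are hypothetical; only the untrained_model() idea and the info log come from the review comment.

```python
import logging
from typing import List

logger = logging.getLogger(__name__)


class ToyModel:
    """Hypothetical forecasting model; it only mimics the untrained_model() idea."""

    def __init__(self, name: str):
        self.name = name
        self.fit_called = False

    def fit(self) -> "ToyModel":
        self.fit_called = True
        return self

    def untrained_model(self) -> "ToyModel":
        # a fresh, unfitted instance with the same configuration
        return ToyModel(self.name)


def prepare_forecasting_models(
    forecasting_models: List[ToyModel],
    retrain_forecasting_models: bool = True,
) -> List[ToyModel]:
    """Return the models the ensemble should work with."""
    if not retrain_forecasting_models:
        # keep the pre-trained models untouched
        return forecasting_models
    if any(m.fit_called for m in forecasting_models):
        logger.info("Some forecasting models were pre-trained; using untrained copies.")
    # retraining happens on copies, so the caller's models are never modified
    return [m.untrained_model() for m in forecasting_models]
```

Because the ensemble retrains copies, a model the user fitted beforehand keeps its state even after the ensemble is fitted.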

…True), renamed models to forecasting_models in all ensemble, improved the logic around retraining

Collaborator

@dennisbader dennisbader left a comment

Really nice, thanks @madtoinou ! 🚀

Had some last suggestions mainly regarding documentation and some typos.

One thing I thought could be nice to add:

  • How about we allow -1 for regression_train_n_points when retrain_forecasting_models=False? That way, whenever the user calls fit(), the regression model would be trained on the entire input.
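
One way to read the -1 suggestion, as a hedged sketch: the helper resolve_regression_train_n_points and its series_length argument are hypothetical, and the validation rules are an assumption drawn from the comment rather than the merged code.

```python
def resolve_regression_train_n_points(
    regression_train_n_points: int,
    series_length: int,
    retrain_forecasting_models: bool,
) -> int:
    """Return how many trailing points the regression model is trained on."""
    if regression_train_n_points == -1:
        # assumption: -1 is only meaningful when the forecasting models are
        # not retrained, so no part of the series is needed to train them
        if retrain_forecasting_models:
            raise ValueError(
                "regression_train_n_points=-1 requires retrain_forecasting_models=False"
            )
        return series_length
    if not 0 < regression_train_n_points <= series_length:
        raise ValueError("regression_train_n_points must be in [1, series_length] or -1")
    return regression_train_n_points
```

For a series of 100 points, -1 resolves to 100 when the forecasting models are kept as is, while an explicit value such as 20 is passed through unchanged.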

@@ -172,9 +195,11 @@ def fit(
# prepare the forecasting models for further predicting by fitting them with the entire data

# Some models (incl. Neural-Network based models) may need to be 'reset' to allow being retrained from scratch
self.models = [model.untrained_model() for model in self.models]
self.forecasting_models = [
Collaborator

We need to return early here when retrain_forecasting_models=False:

if not self.retrain_forecasting_models:
    return self

Collaborator Author

I decided to update the training_series attribute of the models to make the behavior of predict() more intuitive (otherwise, there might be discrepancies between the forecasting models and the regressor, especially in the prediction time indexes).
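
The difference between returning early and updating training_series can be illustrated with a toy ensemble. This is a sketch under assumptions: ToyModel, ToyEnsemble and n_fits are hypothetical, and series are plain lists rather than TimeSeries objects.

```python
from typing import List, Optional


class ToyModel:
    """Hypothetical model that remembers the series it last saw."""

    def __init__(self) -> None:
        self.training_series: Optional[List[float]] = None
        self.n_fits = 0  # counts real training runs

    def fit(self, series: List[float]) -> "ToyModel":
        self.training_series = series
        self.n_fits += 1
        return self


class ToyEnsemble:
    """Hypothetical ensemble with the retrain switch discussed above."""

    def __init__(self, forecasting_models: List[ToyModel], retrain_forecasting_models: bool = True):
        self.forecasting_models = forecasting_models
        self.retrain_forecasting_models = retrain_forecasting_models

    def fit(self, series: List[float]) -> "ToyEnsemble":
        if self.retrain_forecasting_models:
            for model in self.forecasting_models:
                model.fit(series)
        else:
            # no retraining: only point the models at the new series so that a
            # later predict() starts after the end of the series passed to fit()
            for model in self.forecasting_models:
                model.training_series = series
        return self
```

In this sketch a pre-trained model keeps n_fits == 1 after the ensemble is fitted with retrain_forecasting_models=False, but its training_series now matches the series the regressor saw.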

Collaborator

@dennisbader dennisbader left a comment

Very nice, we're almost there 🚀

Had a couple of suggestions; the main ones revolve around regression_train_n_points=-1, where I think we can just use the series as is as the input to generate the regression model training data.

else None,
)
# update training_series attribute to make predict() behave as expected
else:
Collaborator

The same comment I added for the regression ensemble model applies here.

Collaborator

@dennisbader dennisbader left a comment

Looks great now, thanks for this PR @madtoinou 🚀

@dennisbader dennisbader merged commit 1929d9f into master Sep 16, 2023
8 of 9 checks passed
@dennisbader dennisbader deleted the feat/pretrained-ensemble branch September 16, 2023 15:28
Labels
None yet
Projects
Status: Released
3 participants