Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support wrapped models in grid search #2133

Conversation

andresliszt
Copy link

@andresliszt andresliszt commented Dec 25, 2023

Hello!, thanks for this amazing project. I am experiencing the same issue as the one mentioned in here . This pull request aims to add a way to include the arguments of the wrapped model in the grid, instead of just passing a list of instances of the wrapped model to the grid constructor

Fixes #2104 .

Summary

When the model key is found in the parameters dictionary passed to the gridsearch classmethod (meaning we are in the context of a model that wraps another), the class method expects either a list of wrapped model instances , or it expects a dictionary with a special key called model_class whose value is the class of the model to be wrapped. The other keys in the dictionary are the parameters that will be used to construct the grid dedicated to the wrapped model. Example

from sklearn.ensemble import RandomForestRegressor

from darts.models import RegressionModel
from darts.utils import timeseries_generation as tg

parameters = {
    "model": {
        "model_class": RandomForestRegressor,
        "min_samples_split": [2,3],
        "min_samples_leaf": [1,2],
    },
    "lags": [1,2,3],
}
series = tg.sine_timeseries(length=100)

RegressionModel.gridsearch(
    parameters=parameters, series=series, forecast_horizon=1
)

@andresliszt
Copy link
Author

andresliszt commented Dec 29, 2023

Interesting the test darts/tests/models/forecasting/test_transformer_model.py::TestTransformerModel::test_fit is failing with :
FileNotFoundError: [Errno 2] No such file or directory: '/var/folders/0t/rk2xtsbd6ws0rz79rbxkjrv00000gr/T/dartse01ues80/unittest-model-transformer/checkpoints/last-epoch=1.ckpt'

In the master branch i get the same error, so i think something with torch save method might be broken and is not related with this PR

@dennisbader
Copy link
Collaborator

Hi @andresliszt, the failing unit tests come from PyTorch Lightning version 2.1.3.
The version changed saving the last checkpoint as a symlink which brakes our tests.
They already reverted this here and will be released in the next version.

For now I'll make a PR to set an an upper cap on pytorch_lightning<=2.1.2, and will relax it once they release the new version.

@codecov-commenter
Copy link

codecov-commenter commented Dec 29, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 93.91%. Comparing base (3115bb6) to head (dbc31d5).

Current head dbc31d5 differs from pull request most recent head 1ec6e12

Please upload reports for the commit 1ec6e12 to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2133      +/-   ##
==========================================
+ Coverage   93.77%   93.91%   +0.14%     
==========================================
  Files         138      135       -3     
  Lines       14647    13397    -1250     
==========================================
- Hits        13735    12582    -1153     
+ Misses        912      815      -97     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@madtoinou madtoinou self-requested a review as a code owner June 19, 2024 08:39
@madtoinou madtoinou mentioned this pull request Nov 12, 2024
3 tasks
@madtoinou
Copy link
Collaborator

Replaced by #2594, thanks a lot @andresliszt for this great PR! I pulled from your branch to apply the linting and attributed the changes to you of course :)

@madtoinou madtoinou closed this Nov 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] gridsearch with RegressionModel
4 participants