
Learn rate scheduler #207

Merged
merged 18 commits into main from learn_rate_scheduler on Nov 30, 2023

Conversation

@jeffjennings (Contributor) commented Nov 29, 2023

NOTE: This PR should only be reviewed after #206 is merged into main and main is merged into this branch (it was branched from there).

  • Optionally adds a scheduler to cross-validation in order to dynamically adjust the learning rate.
  • Tracks the learning rate over the optimization loop for a given k-fold.
  • Adds a test for the training loop with the scheduler.

Choice and defaults of scheduler:
Of the several torch.optim.lr_scheduler schedulers, ReduceLROnPlateau is one of the few that updates the learning rate according to some metric (docs here; not well written), rather than simply after some user-supplied number of epochs (which would not be at all general). I use this scheduler and give it the loss as the metric.

The scheduler has a threshold below which it judges the metric to no longer be changing, triggering a decrease in the learning rate; I keep this threshold at its default value of 1e-4. For the factor by which the learning rate is multiplied, I found 0.995 to be a good choice (each trigger reduces the learning rate to 99.5% of its previous value); the factor is an arg in the TrainTest class. This choice (for the 1 DSHARP dataset I tested) keeps gradually reducing the brightness scale of the gradient image after the loss has plateaued by eye, while avoiding transient spikes in the loss at large iteration counts. The learning rate update is done at each iteration of the training loop, after optimizer.step().
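For reference, a minimal sketch of that wiring (the model, learning rate, and loss below are placeholders, not the PR's actual code; in the PR the optimizer and scheduler are constructed in CrossValidate and passed to TrainTest):

```python
import torch

model = torch.nn.Linear(10, 1)  # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=0.3)

# ReduceLROnPlateau multiplies the learning rate by `factor` once the metric's
# improvement stays below `threshold` (relative, default 1e-4) for `patience`
# consecutive steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",      # the metric (the loss) should decrease
    factor=0.995,    # new_lr = 0.995 * old_lr on each trigger
    threshold=1e-4,  # default relative threshold
)

lr_history = []
for iteration in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()  # stand-in loss
    loss.backward()
    optimizer.step()
    scheduler.step(loss)  # LR update each iteration, after optimizer.step()
    lr_history.append(optimizer.param_groups[0]["lr"])  # track the LR per iteration
```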

Because the scheduler gradually improves the gradient image even when the loss appears to plateau, I've also tested strengthening the convergence tolerance for the loss in the training loop. I found the best result by setting the tolerance to 1 part in 10^5 (the loss must change by less than this for 10 consecutive iterations to be considered converged; previously it was 1 part in 10^3). This tolerance is an arg in TrainTest.
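A sketch of what that criterion amounts to (the function and variable names here are illustrative, not the PR's):

```python
def converged(losses, tol=1e-5, patience=10):
    """True if the relative change in the loss has stayed below `tol`
    for the last `patience` consecutive iterations."""
    if len(losses) < patience + 1:
        return False
    recent = losses[-(patience + 1):]
    return all(
        abs(recent[i + 1] - recent[i]) / abs(recent[i]) < tol
        for i in range(patience)
    )
```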

@iancze (Collaborator) left a comment


Looks good. I wonder how robust the threshold settings will be to the choice of ALMA dataset?

@jeffjennings (Contributor, Author)

Yeah, right now I'm defining the torch.optim optimizer and torch.optim.lr_scheduler scheduler in CrossValidate to pass to TrainTest, but ultimately it could be better to have the user pass in a template optimizer and scheduler that are re-initialized per k-fold.

For now the scheduler has a threshold that's a relative factor of the loss, so I think it should be pretty flexible. I'm less certain the scheduler factor will be flexible across datasets, which is why I made it an arg of CrossValidate. I'll get a sense by running cross-validation on more datasets.
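One hypothetical way to express that "template" idea (not what this PR implements) is to have the user supply partially-applied constructors, so each k-fold gets a fresh optimizer and scheduler:

```python
from functools import partial

import torch

# User-supplied "templates": constructors missing only the per-fold arguments.
make_optimizer = partial(torch.optim.Adam, lr=0.3)
make_scheduler = partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau, factor=0.995, threshold=1e-4
)

for kfold in range(5):
    model = torch.nn.Linear(10, 1)  # stand-in for the per-fold model
    optimizer = make_optimizer(model.parameters())  # fresh optimizer per fold
    scheduler = make_scheduler(optimizer)           # fresh scheduler per fold
    # ... run the training loop for this k-fold ...
```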

@jeffjennings jeffjennings merged commit b6ba439 into main Nov 30, 2023
4 checks passed
@jeffjennings jeffjennings deleted the learn_rate_scheduler branch November 30, 2023 16:41