feat(training,rollout)!: Rollout Schedulers #46
base: develop
Conversation
- Allow for complex incrementing setup
- Calculation-based, not step-based
for more information, see https://pre-commit.ci
Started to go through the PR and left some comments! I still need to understand some of the functionality better, so I hope the questions make sense. Thanks for this Harrison!
@@ -405,6 +419,7 @@ def train(self) -> None:
            use_distributed_sampler=False,
            profiler=self.profiler,
            enable_progress_bar=self.config.diagnostics.enable_progress_bar,
            reload_dataloaders_every_n_epochs=self._need_to_reload_dataloaders,
Is there a reason why the type of reload_dataloaders_every_n_epochs has been changed from int to bool? Wondering because, looking at the PTL docs, the type of the flag is int (https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/trainer/trainer.html#Trainer.__init__) and that's what is used by the [data_connector](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/trainer/connectors/data_connector.py#L51)
Due to Python duck typing I suppose a True will evaluate to 1 when used as an int. So it works, but if it is clearer, I can change it to an int.
I see now, thanks for the clarification! All good then.
""" | ||
self._update_rollout(trainer, pl_module, epoch=checkpoint["epoch"], step=checkpoint["global_step"]) | ||
|
||
def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *_) -> None: |
If someone sets limit_batches for validation to 0 to skip validation, this hook wouldn't be triggered, would it?
Good point, I'll need to take a look.
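One possible way to cover that case, sketched here as an assumption rather than the PR's actual fix: fall back to on_train_epoch_end when validation is disabled, so the rollout still advances. The _update_rollout name is taken from the diff above; the guard logic is hypothetical.

import pytorch_lightning as pl


class RolloutScheduler(pl.Callback):
    """Hypothetical sketch: advance the rollout even when validation is skipped."""

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *_) -> None:
        self._update_rollout(trainer, pl_module)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # If validation is disabled (limit_val_batches == 0), the hook above
        # never fires, so update here instead.
        if trainer.limit_val_batches == 0:
            self._update_rollout(trainer, pl_module)

    def _update_rollout(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        ...  # the PR's implementation updates the rollout based on epoch/step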
@@ -451,7 +451,7 @@ def rollout_step(
        )
        assert batch.shape[1] >= rollout + self.multi_step, msg

-       for rollout_step in range(rollout or self.rollout):
+       for rollout_step in range(rollout or int(self.rollout)):
Do we still need to pass rollout as a value to this function if you are already making it a self variable of the GraphForecaster?
Passing it as an arg allows for override, which I think is used in some of the callbacks.
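A rough sketch of that pattern (names taken from the diff above, the rest is illustrative, not the PR's code): the argument defaults to the model's own rollout, but a caller such as a callback can pass an explicit value to override it.

from typing import Optional


class GraphForecaster:
    """Reduced sketch; only the override pattern is shown."""

    def __init__(self, rollout: int = 1, multi_step: int = 2) -> None:
        self.rollout = rollout
        self.multi_step = multi_step

    def rollout_step(self, batch, rollout: Optional[int] = None):
        # Fall back to the scheduled rollout unless the caller overrides it,
        # e.g. a callback that wants a longer rollout for plotting/validation.
        for rollout_step in range(rollout or int(self.rollout)):
            yield rollout_step  # placeholder for the real forecast step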
    def on_train_epoch_start(self) -> None:
        # Sync the rollout at the start of each epoch
        # Cannot use stepping due to inconsistent behaviour with Pytorch Lightning
Just to understand it better, what do you mean here by the inconsistent behaviour with PTL?
If I use the step function, it gets triggered in sanity checking and other places where I don't want it.
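For context, a hedged sketch of the kind of guard a step-based update would otherwise need (illustrative only, not the PR's code): Lightning also runs validation hooks during its pre-training sanity check, whereas on_train_epoch_start never fires in that phase.

import pytorch_lightning as pl


class ExampleModule(pl.LightningModule):
    """Illustrative only: why epoch-start syncing sidesteps sanity checking."""

    def on_validation_epoch_end(self) -> None:
        # Step-based updates would also fire here during Lightning's
        # pre-training sanity check, so they need an explicit guard:
        if self.trainer.sanity_checking:
            return
        ...  # a step()-style rollout update would otherwise go here

    def on_train_epoch_start(self) -> None:
        # Never runs during sanity checking, so syncing the rollout here
        # avoids the inconsistency without any guard.
        ...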
increment:
  step:
    0: 0
    200000: 1 # After 200k steps, increment by 1 every 1 epoch
I am probably just being slow, but how does this interact with limit_batches? What would be the difference between doing the above and the 'old configuration' with a limit_batches of 200000?
The limit_batches ends the training; this will continue on and then begin updating the rollout.
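To make the difference concrete, a hypothetical sketch of how a step-keyed increment mapping like the one above could be resolved (resolve_increment is an illustrative helper, not the PR's implementation):

def resolve_increment(increments: dict[int, int], current_step: int) -> int:
    """Return the increment whose step threshold was most recently passed.

    With {0: 0, 200_000: 1}: before step 200k the rollout increments by 0
    (i.e. stays fixed); from step 200k onwards it increments by 1. Training
    keeps running past 200k, unlike limit_batches, which would stop it there.
    """
    applicable = [step for step in increments if step <= current_step]
    return increments[max(applicable)] if applicable else 0


assert resolve_increment({0: 0, 200_000: 1}, 150_000) == 0
assert resolve_increment({0: 0, 200_000: 1}, 250_000) == 1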
@@ -0,0 +1,8 @@
# (C) Copyright 2024 Anemoi contributors.
Minor thing, but just a reminder that all of these headers would need to be updated to 2025 before merging.
Even the code I wrote in 2024? I honestly have no idea
@@ -377,6 +377,20 @@ def strategy(self) -> DDPGroupStrategy:
            static_graph=not self.config.training.accum_grad_batches > 1,
        )

    @cached_property
    def _need_to_reload_dataloaders(self) -> bool:
Have you seen any difference in terms of runtime from using this?
Unfortunately, yes. I still need to quantify it, but it does slow down the transition between epochs.
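For readers following along, a guess at the shape of that property (the class body, config path and condition are assumptions, not the PR's actual code): dataloaders only need reloading when the schedule can actually change the rollout between epochs, so static-rollout runs avoid the per-epoch cost mentioned above.

from functools import cached_property


class AnemoiTrainer:
    """Reduced sketch; only a hypothetical version of the property is shown."""

    def __init__(self, config) -> None:
        self.config = config

    @cached_property
    def _need_to_reload_dataloaders(self) -> bool:
        # Assumed condition: reload only when an increment schedule is set,
        # i.e. when the rollout can grow during training. The config path
        # is illustrative.
        rollout_cfg = self.config.training.rollout
        return bool(getattr(rollout_cfg, "increment", None))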
Closes #14
Rollout Schedulers
Expand the ways to describe rollout, and provide an interface to schedule updates.
New default rollout config
Can step by epoch, step, and control the increment based on either the step or epoch.
Additionally, formally add random steppers.
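A purely illustrative sketch of what such a schedule could look like, written as a plain Python mapping (only the increment/step keys appear in the diff above; the other key names are guesses, not the PR's schema):

# Illustrative only -- key names beyond "increment"/"step" are assumptions.
rollout_schedule = {
    "start": 1,            # initial rollout length
    "max": 12,             # upper bound on the rollout
    "increment": {
        "step": {
            0: 0,          # no increment for the first 200k optimiser steps
            200_000: 1,    # then increment by 1 per update thereafter
        },
    },
}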
Todo