Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(training,rollout)!: Rollout Schedulers #46

Draft
wants to merge 14 commits into
base: develop
Choose a base branch
from

Conversation

HCookie
Copy link
Member

@HCookie HCookie commented Dec 20, 2024

Closes #14

Rollout Schedulers

Expand the ways to describe rollout, and provide an interface to schedule updates

New default rollout config

# length of the "rollout" window (see Keisler's paper)
rollout:
  _target_: anemoi.training.schedulers.rolllout.stepped.EpochStepped
  minimum: 1
  maximum: 12
  # increase rollout every n epochs
  every_n_epochs: 1
  # Control the incrementing of the rollout window
  increment:
    step:
      0: 0
      200000: 1 # After 200k steps, increment by 1 every 1 epoch

Can step by epoch, step, and control the increment based on either the step or epoch.

Additonally, formally add random steppers.

Todo

  • Integrate with the data loader
  • Ensure that the randomness is seeded appropriately
  • Randomness broadcast
  • Ensure restartability
  • Ability to change config

@HCookie HCookie self-assigned this Dec 20, 2024
@HCookie HCookie changed the title feat(rollout)!: Rollout Schedulers feat(training,rollout)!: Rollout Schedulers Dec 20, 2024
@HCookie HCookie added documentation Improvements or additions to documentation enhancement New feature or request labels Dec 20, 2024
Copy link
Contributor

@anaprietonem anaprietonem left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Started to go through the PR and left some comments! I still need to understand better some of the functionality so hope te questions makes sense. Thanks for this Harrison!

@@ -405,6 +419,7 @@ def train(self) -> None:
use_distributed_sampler=False,
profiler=self.profiler,
enable_progress_bar=self.config.diagnostics.enable_progress_bar,
reload_dataloaders_every_n_epochs=self._need_to_reload_dataloaders,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why the type of reload_dataloaders_every_n_epochs has been changes from int to bool? Wondering because looking at PTL docs the type of flag is int (https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/trainer/trainer.html#Trainer.__init__) and that's used by the [data_connector](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/trainer/connectors/data_connector.py#L51)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to python duck typing I suppose a True will evaluate to 1 when used as an int. So it works, but if it is clearer, I can change it to an int.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see now, thanks for the clarification! all good then

"""
self._update_rollout(trainer, pl_module, epoch=checkpoint["epoch"], step=checkpoint["global_step"])

def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *_) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If someone sets the limit_batches for validation to 0, to skip validation this hook wouldn't be triggered?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I'll need to take a look.

@@ -451,7 +451,7 @@ def rollout_step(
)
assert batch.shape[1] >= rollout + self.multi_step, msg

for rollout_step in range(rollout or self.rollout):
for rollout_step in range(rollout or int(self.rollout)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need to pass rollout as a value to this function is you are already making it a self variable of the GraphForecaster ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing it as an arg allows for override, which I think is used in some of the callbacks.


def on_train_epoch_start(self) -> None:
# Sync the rollout at the start of each epoch
# Cannot use stepping due to inconsistent behaviour with Pytorch Lightning
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to understand it better, what do you mean here with the inconsistent behaviour with PTL?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I use the step function, it gets triggered in sanity checking, and other places where I don't want it,

@FussyDuck
Copy link

FussyDuck commented Jan 9, 2025

CLA assistant check
All committers have signed the CLA.

increment:
step:
0: 0
200000: 1 # After 200k steps, increment by 1 every 1 epoch
Copy link
Contributor

@anaprietonem anaprietonem Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am probably just being slow but how does this interact with the limit batches? What would be the difference between doing the above, and the 'old configuration' with a limit batches of 200000?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The limit_batches ends the training, this will continue on, and then begin updating the rollout.

@@ -0,0 +1,8 @@
# (C) Copyright 2024 Anemoi contributors.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor thing, but just a remainder that all of these headers would need to be updated to 2025 before merging

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even the code I wrote in 2024? I honestly have no idea

@@ -377,6 +377,20 @@ def strategy(self) -> DDPGroupStrategy:
static_graph=not self.config.training.accum_grad_batches > 1,
)

@cached_property
def _need_to_reload_dataloaders(self) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you seen any difference in terms of runtime from using this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, yes, I need to still quantify it, but it does slow down the transition between epochs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation enhancement New feature or request training
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Rollout Scheduling
3 participants