-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce Stateful Callbacks #29666
Introduce Stateful Callbacks #29666
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@amyeroberts let me know if we need more docs or if I should rewrite the code a bit if it feels to complex/magical :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for much for working on this - and for the detailed description in the PR!
I understand the choice to match with accelerate patterns, but I think we might want to iterate on that. Looking at the behaviour in the tests I found it surprising, especially in the case when one can pass in a callback with default and it's previous state is loaded in: it seems to go against needing to fully specify the same training to resume.
I'm also not sure on the loading of args vs attributes - it seems attributes we'd definitely want to store, but args we might want to be able to overwrite?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for iterating on this! I like this state control + warnings
return | ||
# Callback states are stored in stateful_callbacks | ||
not_found = [] | ||
for stored_callback, data in self.state.stateful_callbacks.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we ever have more than one callback of the same type e.g. different metric loggers? Only concern with logic below is that we might load state from one class into another
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call Amy, just in case we can, I've gone ahead and added logic in + a test to verify :)
@amyeroberts I'd like one more review please, just to make sure the multiple-versions logic makes sense to you! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for iterating and for all the tests - looks great!
assert len(cbs) == 2 | ||
assert cbs[0].my_test_state == "first" | ||
assert cbs[1].my_test_state == "second" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice :)
stateful_callbacks[name] = [stateful_callbacks[name]] | ||
stateful_callbacks[name].append(callback.state()) | ||
else: | ||
stateful_callbacks[name] = callback.state() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason for not just always adding as a list directly. If we use a default dict we can just do
# Saveable callbacks get stored as dict of kwargs
stateful_callbacks = default_dict(list)
for callback in self.stateful_callbacks:
if not isinstance(callback, (ExportableState)):
raise TypeError(
f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
)
name = callback.__class__.__name__
stateful_callbacks[name].append(callback.state())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Python serialization does not like defaultdict, come to find out 😢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🥲
Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
e46a42e
to
7e10394
Compare
* Introduce saveable callbacks * Add note * Test for non-present and flag * Support early stopping and refusing to train further * Update docstring * More saving * Import oopsie * Apply suggestions from code review Co-authored-by: amyeroberts <[email protected]> * Make it go through TrainerArguments * Document * Fix test * Apply suggestions from code review Co-authored-by: amyeroberts <[email protected]> * Rework to allow for duplicates * CLean * Fix failing tests --------- Co-authored-by: amyeroberts <[email protected]>
Hi, it looks like the state of stateful callbacks are not updated before saving state to trainer_state.json on _save_checkpoint method: transformers/src/transformers/trainer.py Lines 2917 to 2920 in fe008d6
Only TrainerControl state is saved, instead of getting all ExportableState callbacks like here: transformers/src/transformers/trainer.py Lines 674 to 675 in fe008d6
So it should be:
Do you agree? |
What does this PR do?
This PR builds a foundation for stateful callbacks inside
Trainer
. Right now these are isolated to the core callbacks that exist, but can be expanded upon as needed later.Specifically a user must enable
restore_callback_states
in theTrainingArguments
to enable this behavior.System Design:
To keep with not having to deal with pickles/interact with the
TrainerState
properly (where callback data should be stored, likely), we need a way to recreate the exact state of aCallback
. To do so callbacks should implement asave_state
function which will create adict
ofargs
andattributes
that should be set.From this we can then recreate the constructor of an existing callback with these new options. (this is done in case there is logic in a callback's
__init__
that could be needed).For example, here is the one for
EarlyStoppingCallback
:What about Callbacks that I forget to add back in?
As this relies on a similar state, I chose to keep aligned with how we deal with states/checkpointing in Accelerate, wherein users must recreate the exact same initial scenario when resuming training.
Or, in other words:
If you included the
EarlyStoppingCheckpoint
initially, but when you doresume_from_checkpoint
you did not include that callback, we do not magically init and resume those states. Instead we will simply keep going and give you a nice warning message saying that things went amiss/some callback states weren't resumed.Limitations
As we're maintaining that these exist in the
TrainerState
(which again, makes sense for us to use this here), items should be JSON-serializable.The aim here is most likely not many callbacks actually need this, so only a few as needed over time can be added to this without much complexity.
Fixes #28544
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@amyeroberts