
Introduce Stateful Callbacks #29666

Merged
merged 15 commits into from
Apr 25, 2024
Conversation

muellerzr
Contributor

@muellerzr muellerzr commented Mar 15, 2024

What does this PR do?

This PR builds a foundation for stateful callbacks inside Trainer. Right now support is limited to the existing core callbacks, but it can be expanded as needed later.

Specifically, a user must enable restore_callback_states in the TrainingArguments to opt into this behavior.

System Design:

To avoid dealing with pickles and to interact with the TrainerState properly (which is likely where callback data should be stored), we need a way to recreate the exact state of a Callback. To do so, callbacks should implement a save_state function which creates a dict of the args and attributes that should be set.

From this we can then re-run the constructor of an existing callback with these options. (This is done in case there is logic in a callback's __init__ that could be needed.)

For example, here is the one for EarlyStoppingCallback:

    def save(self) -> dict:
        return {
            "args": {
                "early_stopping_patience": self.early_stopping_patience,
                "early_stopping_threshold": self.early_stopping_threshold,
            },
            "attributes": {
                "early_stopping_patience_counter": self.early_stopping_patience_counter,
            },
        }
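The round trip described above can be sketched as a small helper (a hypothetical illustration, not the actual Trainer implementation; the `rebuild_callback` helper and the stub class are assumptions made for this example):

```python
def rebuild_callback(callback_cls, saved):
    """Re-run the constructor with the saved args, then restore attributes."""
    callback = callback_cls(**saved.get("args", {}))
    for attribute, value in saved.get("attributes", {}).items():
        setattr(callback, attribute, value)
    return callback


class EarlyStoppingStub:
    """Hypothetical stand-in mirroring the fields EarlyStoppingCallback saves."""

    def __init__(self, early_stopping_patience=1, early_stopping_threshold=0.0):
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        self.early_stopping_patience_counter = 0


saved = {
    "args": {"early_stopping_patience": 3, "early_stopping_threshold": 0.01},
    "attributes": {"early_stopping_patience_counter": 2},
}
restored = rebuild_callback(EarlyStoppingStub, saved)
```

Running the constructor first means any validation or derived setup in `__init__` still happens, and only then are the runtime attributes (like the patience counter) overwritten.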

What about Callbacks that I forget to add back in?

As this relies on similar state, I chose to stay aligned with how we deal with states/checkpointing in Accelerate, wherein users must recreate the exact same initial scenario when resuming training.

Or, in other words:

If you included the EarlyStoppingCallback initially, but did not include that callback when calling resume_from_checkpoint, we do not magically init and resume its state. Instead we simply keep going and give you a clear warning message saying that some callback states weren't resumed.
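A minimal sketch of that resume behavior (the helper name and signature are assumptions for illustration, not the actual Trainer code): saved states whose callbacks were not re-added are skipped and reported, never revived.

```python
import logging

logger = logging.getLogger(__name__)


def report_unrestored(stateful_callbacks, present_callback_names):
    """Warn about saved callback states that have no matching callback
    on resume, instead of silently recreating them."""
    not_found = [
        name for name in stateful_callbacks if name not in present_callback_names
    ]
    if not_found:
        logger.warning(
            "Checkpoint included saved states for %s, but these callbacks were "
            "not passed to the Trainer, so their states were not restored.",
            not_found,
        )
    return not_found


missing = report_unrestored(
    {"EarlyStoppingCallback": {}, "TrainerControl": {}},
    {"TrainerControl"},
)
# missing == ["EarlyStoppingCallback"]
```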

Limitations

As we're keeping these in the TrainerState (which, again, makes sense to use here), items must be JSON-serializable.

The expectation is that not many callbacks actually need this, so support can be added for a few more over time without much complexity.
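Since TrainerState is serialized to trainer_state.json, a quick way to check that a callback's state dict meets the JSON-serializability requirement is a round trip through the standard library (illustrative only; the sample dict mirrors the EarlyStoppingCallback example above):

```python
import json

state = {
    "args": {"early_stopping_patience": 3, "early_stopping_threshold": 0.01},
    "attributes": {"early_stopping_patience_counter": 2},
}

# Plain numbers, strings, lists, and dicts survive the round trip unchanged;
# anything else (tensors, custom objects) would raise a TypeError in dumps().
roundtripped = json.loads(json.dumps(state))
```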

Fixes #28544

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts

@muellerzr muellerzr requested a review from amyeroberts March 15, 2024 04:56
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr
Contributor Author

@amyeroberts let me know if we need more docs or if I should rewrite the code a bit if it feels too complex/magical :)

Collaborator

@amyeroberts amyeroberts left a comment


Thanks so much for working on this - and for the detailed description in the PR!

I understand the choice to match Accelerate's patterns, but I think we might want to iterate on that. Looking at the behaviour in the tests I found it surprising, especially the case where one can pass in a callback with defaults and its previous state is loaded in: that seems to go against needing to fully specify the same training setup to resume.

I'm also not sure about the loading of args vs. attributes - it seems attributes are something we'd definitely want to store, but args we might want to be able to overwrite?

src/transformers/trainer_callback.py
tests/trainer/test_trainer_callback.py
@muellerzr muellerzr requested review from ArthurZucker and amyeroberts and removed request for ArthurZucker March 25, 2024 15:07
Collaborator

@amyeroberts amyeroberts left a comment


Thanks for iterating on this! I like this state control + warnings

src/transformers/trainer.py
        return
    # Callback states are stored in stateful_callbacks
    not_found = []
    for stored_callback, data in self.state.stateful_callbacks.items():
Collaborator


Can we ever have more than one callback of the same type, e.g. different metric loggers? My only concern with the logic below is that we might load state from one class into another.

Contributor Author


Good call Amy - just in case we can, I've gone ahead and added logic and a test to verify :)

@muellerzr muellerzr requested a review from amyeroberts April 4, 2024 16:21
@muellerzr
Contributor Author

@amyeroberts I'd like one more review please, just to make sure the multiple-versions logic makes sense to you!

Collaborator

@amyeroberts amyeroberts left a comment


Thanks for iterating and for all the tests - looks great!

Comment on lines +375 to +377
assert len(cbs) == 2
assert cbs[0].my_test_state == "first"
assert cbs[1].my_test_state == "second"
Collaborator


Nice :)

            stateful_callbacks[name] = [stateful_callbacks[name]]
        stateful_callbacks[name].append(callback.state())
    else:
        stateful_callbacks[name] = callback.state()
Collaborator


Is there a reason for not just always adding it as a list directly? If we use a defaultdict we can just do:

            # Saveable callbacks get stored as a dict of kwargs
            stateful_callbacks = defaultdict(list)
            for callback in self.stateful_callbacks:
                if not isinstance(callback, ExportableState):
                    raise TypeError(
                        f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
                    )
                name = callback.__class__.__name__
                stateful_callbacks[name].append(callback.state())

Contributor Author


Python serialization does not like defaultdict, come to find out 😢

Collaborator


🥲

@muellerzr muellerzr force-pushed the muellerzr-checkpoint-callbacks branch from e46a42e to 7e10394 on April 25, 2024 13:11
@muellerzr muellerzr merged commit ad697f1 into main Apr 25, 2024
21 checks passed
@muellerzr muellerzr deleted the muellerzr-checkpoint-callbacks branch April 25, 2024 15:00
itazap pushed a commit that referenced this pull request May 14, 2024
* Introduce saveable callbacks

* Add note

* Test for non-present and flag

* Support early stopping and refusing to train further

* Update docstring

* More saving

* Import oopsie

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* Make it go through TrainerArguments

* Document

* Fix test

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* Rework to allow for duplicates

* CLean

* Fix failing tests

---------

Co-authored-by: amyeroberts <[email protected]>
@pedrobrs
Contributor

pedrobrs commented Jul 19, 2024

Hi, it looks like the state of stateful callbacks is not updated before saving the state to trainer_state.json in the _save_checkpoint method:

    if self.args.should_save:
        # Update the `TrainerControl` state to where we are currently
        self.state.stateful_callbacks["TrainerControl"] = self.control.state()
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

Only TrainerControl state is saved, instead of getting all ExportableState callbacks like here:
    stateful_callbacks=[
        cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
    ]
So it should be:

    if self.args.should_save:
        for cb in self.callback_handler.callbacks + [self.control]:
            if isinstance(cb, ExportableState):
                self.state.stateful_callbacks[cb.__class__.__name__] = cb.state()
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

Do you agree?

Successfully merging this pull request may close these issues.

Early stopping patience does not work when resuming from checkpoint