Checkpoint management in torchtune #2070

Open
joecummings opened this issue Nov 25, 2024 · 0 comments
Labels
rfc Request for comments

joecummings commented Nov 25, 2024

Goal: Manage the checkpointing of models in torchtune

Current solution

Our current solution is based on the format in which the checkpoint arrives. See below for our three checkpointer classes. These handle loading pretrained weights and saving trained weights during and after training. They also handle saving and loading the recipe state used when resuming training. Notably, these classes expect pointers to local files and only write out to local files.

# Load and save a model from Hugging Face w/ option to save back
# into safetensors
FullModelHFCheckpointer()

# Load and save a model in the "xlformers" (Meta) format for Llama checkpoints
FullModelMetaCheckpointer()

# Internal format of the checkpoint that exactly matches torchtune's
# implementation of models
FullModelTorchTuneCheckpointer()

These checkpointers are instantiated in all of our recipes and users interact with them through the load_checkpoint() and save_checkpoint() APIs. You can read more in depth about our current checkpointer design in the deep dives section of our documentation.
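
As a rough illustration of that interaction pattern, here is a minimal sketch. ToyLocalCheckpointer is a hypothetical stand-in, not a torchtune class: it uses JSON files instead of real weight formats, and its constructor arguments and the signatures of load_checkpoint() and save_checkpoint() are simplified assumptions. It only shows the shape of "one object reads local files in, the same object writes local files out".

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict


class ToyLocalCheckpointer:
    """Hypothetical stand-in for a checkpointer; NOT a torchtune class."""

    def __init__(self, checkpoint_file: Path, output_dir: Path) -> None:
        # Both locations are local paths, mirroring the local-files-only limitation.
        self.checkpoint_file = checkpoint_file
        self.output_dir = output_dir

    def load_checkpoint(self) -> Dict[str, Any]:
        # Read the "pretrained" state from a local file.
        return json.loads(self.checkpoint_file.read_text())

    def save_checkpoint(self, state_dict: Dict[str, Any], epoch: int) -> Path:
        # Write the trained state to a local file named after the epoch.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        out = self.output_dir / f"toy_model_{epoch}.json"
        out.write_text(json.dumps(state_dict))
        return out


def run_toy_recipe(workdir: Path) -> Path:
    """Load, 'train' for one epoch, and save, the way a recipe would."""
    pretrained = workdir / "pretrained.json"
    pretrained.write_text(json.dumps({"w": 1.0}))
    checkpointer = ToyLocalCheckpointer(pretrained, workdir / "out")
    state_dict = checkpointer.load_checkpoint()
    state_dict["w"] += 1.0  # stand-in for an epoch of training
    return checkpointer.save_checkpoint(state_dict, epoch=0)
```

The recipe never touches file formats directly, which is the convenience of this design and also the reason format-specific logic accumulates inside the checkpointer.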

Some background

When we started designing the checkpointing functionality in torchtune, we knew we wanted to support more checkpoints than those that came directly from the Hugging Face Hub. In fact, it was a requirement b/c at that point Meta's Llama checkpoints weren't well integrated with the Hub, so users had to download those weights separately. Since there was wide support in the OSS community for the Llama models and we didn't want to give Meta's legal team a heart attack by hosting our own weights, we needed to support loading directly from files in the original format.

The decision to bundle the recipe state with the checkpointer largely followed the intuition that this information for restarting a finetuning run belonged in the same location as the model weights that were to be used. Therefore, bundling the recipe state with the checkpointer class made it logically simple in that there was only one object writing out information (weights, state, configs) that could be used later.

Aside: what's the deal with the 'FullModel' prefix? When we first started, we wanted the checkpointer interfaces to directly support loading both LoRA adapters and full models, and these checkpointers were only supposed to be for the full model. We never actually got around to loading the adapters; instead we made it very simple to load your trained adapters into an existing Hugging Face model and merged all weights back in before saving. As a result, the name is a little overly descriptive.

How is the design not meeting the user's needs?

  1. The completely modular concept for checkpointing is hard to understand. In many other libraries, checkpointing does not (in the user's mind) cover both loading and saving. The prime example is Hugging Face, where the from_pretrained() API loads pre-trained weights and "checkpointing" simply means saving weights. We inadvertently use this design ourselves by requiring a path for the tokenizer model weights, causing more confusion for torchtune users. What's worse, this checkpointing abstraction is inconsistent with our design principles. As the library has grown quickly, the Checkpointer classes have diverged from the principle of "modular building blocks over monolithic components". The implementations now contain messy, model-specific state_dict conversion logic that is difficult to understand and debug.
  2. We centered our checkpointing design around the concept of "epochs" and doubled down by only allowing checkpointing at epoch boundaries, meaning users must iterate through their entire data source before a checkpoint is written. With dataset sizes only ever increasing, this paradigm no longer makes sense. (See "No param to control save checkpoints every N steps ?" #988 and "Save intermediate checkpoints during training" #1107.)
  3. Lack of attention to detail. There are many small things about checkpointing that are frustrating for torchtune users, and because of bandwidth constraints we've quite frankly cut some corners. These include: config bloat from specifying every single file (partially mitigated by @ebsmothers's implementation for FormattedFiles); not utilizing the Hugging Face cache, which means we use excess memory on a user's system to store model weights; resume_from_checkpoint not actually working in all cases b/c we don't checkpoint the learning rate and only checkpoint the recipe state in "intermediate" checkpoints; and checkpointing taking a looooooooooooong time.
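
To make points 2 and 3 concrete, here is a toy sketch of step-based checkpointing that stores the learning rate alongside the step counter and weights, so a resumed run ends up where an uninterrupted one would. Everything here is hypothetical (save_recipe_state, train, the state keys, the halving "scheduler"); it uses JSON and plain floats rather than torchtune or PyTorch objects, and it is not a proposal for the actual implementation.

```python
import json
import tempfile
from pathlib import Path


def save_recipe_state(path: Path, state: dict) -> None:
    """Write to a temp file first, then rename, so a crash mid-write
    does not leave a truncated checkpoint behind."""
    tmp = path.with_suffix(".tmp")
    tmp.write_text(json.dumps(state))
    tmp.replace(path)


def train(total_steps: int, save_every: int, ckpt: Path, base_lr: float = 0.1) -> dict:
    # Resume from the last saved state if one exists, else start fresh.
    if ckpt.exists():
        state = json.loads(ckpt.read_text())
    else:
        state = {"global_step": 0, "lr": base_lr, "weight": 0.0}

    while state["global_step"] < total_steps:
        # Toy "optimizer step": a real recipe would update model parameters here.
        state["weight"] += state["lr"]
        state["global_step"] += 1
        # Toy "scheduler": halve the learning rate after every step.
        state["lr"] *= 0.5
        # Checkpoint every N steps instead of only at epoch boundaries.
        if state["global_step"] % save_every == 0:
            save_recipe_state(ckpt, state)
    return state
```

If "lr" were left out of the saved state, the resumed run would restart from base_lr and drift away from the uninterrupted run, which is the failure mode described for resume_from_checkpoint above.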

Looking forward

All of the above has been synthesized into the following requirements, user-experience considerations, and restrictions. A successful checkpoint management solution has several aspects, so I'll be posting individual RFCs for each of them in order to collect more targeted feedback. This issue will serve as general context for all of those RFCs. Feel free to comment here on the problems identified and requirements collected.

Requirements

  • Easy to use from a config and from code
  • Support loading and saving in various formats (Transformers, Meta, DCP)
  • Easy to implement for new model architectures
  • Seamless end-to-end experience for tasks like generation and evaluation

Additional UX considerations

  • Post-training recipes like PPO and KD where multiple models are involved and therefore multiple checkpoints need to be handled

Restrictions

  • State dictionaries must follow the standard PyTorch format and utilize PyTorch core and distributed APIs.
  • Backward compatibility with existing torchtune functionality
  • Cannot delete torchtune and start from scratch (sad!)
@joecummings joecummings added the rfc Request for comments label Nov 25, 2024