Goal: Manage the checkpointing of models in torchtune
Current solution
Our current solution is organized around the format in which a checkpoint arrives to us. See below for our three checkpointer classes. These handle loading pretrained weights and saving trained weights during and after training. They also handle saving and loading recipe state for use when resuming training. Notably, these classes expect pointers to local files and only write out to local files.
```python
# Load and save a model from Hugging Face w/ option to save back
# into safetensors
FullModelHFCheckpointer()

# Load and save a model in the "xlformers" (Meta) format for Llama checkpoints
FullModelMetaCheckpointer()

# Internal format of the checkpoint that exactly matches torchtune's
# implementation of models
FullModelTorchTuneCheckpointer()
```
These checkpointers are instantiated in all of our recipes, and users interact with them through the load_checkpoint() and save_checkpoint() APIs. You can read about our current checkpointer design in more depth in the deep dives section of our documentation.
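To make the current flow concrete, here is a minimal sketch of how a recipe drives one of these checkpointers. This is hedged: the import path has moved across torchtune versions (torchtune.utils in early releases, torchtune.training later), and the exact constructor arguments and file names shown are illustrative and may not match the version you are on.

```python
# Hedged sketch, not a definitive usage guide -- check your torchtune version.
from torchtune.training import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Llama-2-7b-hf",  # local files only
    checkpoint_files=[
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
    ],
    model_type="LLAMA2",
    output_dir="/tmp/finetune-output",
)

# load_checkpoint() converts the HF-keyed weights into torchtune's format.
ckpt_dict = checkpointer.load_checkpoint()

# ... the recipe trains the model here ...

# save_checkpoint() converts back to the HF format and writes local files.
checkpointer.save_checkpoint(ckpt_dict, epoch=0)
```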
Some background
When we first started designing the checkpointing functionality in torchtune, we knew we wanted to support more checkpoints than those that came directly from the Hugging Face Hub. In fact, it was a requirement because at that point Meta hadn't yet integrated its Llama checkpoints with the Hub, so users had to download those weights separately. Since there was wide support in the OSS community for the Llama models and we didn't want to give Meta's legal team a heart attack by hosting our own weights, we needed to support loading directly from files in the original format.
The decision to bundle the recipe state with the checkpointer largely followed the intuition that the information needed to restart a finetuning run belonged in the same location as the model weights it would be used with. Bundling the recipe state into the checkpointer class kept things logically simple: a single object was responsible for writing out everything (weights, state, configs) that could be needed later.
Aside: what's the deal with the 'FullModel' prefix? When we first started, we wanted the checkpointer interfaces to directly support loading both LoRA adapters and full models, and these checkpointers were only supposed to handle the full-model case. Since we never actually got around to loading adapters through them, and instead made it very simple to load your trained adapters into an existing Hugging Face model and merge all weights back in before saving, the name ended up a little overly descriptive.
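For reference, the adapter workflow that aside describes looks roughly like the following outside of torchtune. This is a hedged sketch using Hugging Face's transformers and peft libraries; the model name and paths are placeholders, not torchtune outputs.

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model, then attach the trained LoRA adapters on top of it.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = PeftModel.from_pretrained(base, "/tmp/my-trained-adapters")

# Merge the adapter weights back into the base weights and drop the adapter
# modules, leaving a plain model that saves in the usual HF format.
merged = model.merge_and_unload()
merged.save_pretrained("/tmp/merged-model")
```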
How is the design not meeting the user's needs?
The fully modular concept for checkpointing is hard to understand. In many other libraries, checkpointing does not (in the user's mind) cover both loading and saving. The prime example is Hugging Face, where the from_pretrained() API loads pretrained weights up front and "checkpointing" refers simply to saving weights. We inadvertently lean on this mental model ourselves by requiring a separate path for the tokenizer model, causing further confusion for torchtune users. What's worse, the checkpointing abstraction has become inconsistent with our design principles. As the library has grown quickly, the Checkpointer classes have diverged from "modular building blocks over monolithic components": their implementations now contain messy, model-specific state_dict conversion logic that is difficult to understand and debug.
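For contrast, here is the mental model most users bring from Hugging Face, where loading and saving are two separate, single-purpose calls rather than one checkpointer object (the model name is illustrative):

```python
from transformers import AutoModelForCausalLM

# Loading: a one-shot call that fetches and materializes pretrained weights.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# ... training happens here ...

# "Checkpointing": simply saving the current weights to a directory.
model.save_pretrained("/tmp/my-finetuned-model")
```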
We centered our checkpointing design on the concept of "epochs". We doubled down on this by only allowing checkpointing at epoch boundaries, meaning users had to iterate through their entire data source before they could checkpoint. With dataset sizes only ever increasing, this paradigm no longer makes sense. (No param to control save checkpoints every N steps? #988, Save intermediate checkpoints during training #1107)
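What those issues ask for is roughly the following pattern: a save triggered every N optimizer steps rather than only at epoch boundaries. A self-contained sketch of the shape of that loop (the toy model, the save_every_n_steps knob, and the torch.save call are illustrative, not torchtune APIs):

```python
import torch
from torch import nn

# Toy setup just to make the pattern concrete; in a real recipe these come
# from the config and the checkpointer classes above.
model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [torch.randn(4, 8) for _ in range(2000)]

save_every_n_steps = 500  # the knob requested in #988 (name is illustrative)

global_step = 0
for batch in data:
    loss = model(batch).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    global_step += 1
    # Save mid-epoch instead of waiting for an epoch boundary.
    if global_step % save_every_n_steps == 0:
        torch.save(
            {"model": model.state_dict(), "step": global_step},
            f"/tmp/ckpt_step_{global_step}.pt",
        )
```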
Looking forward
All of the above information has been synthesized into the following requirements, user-experience considerations, and restrictions. There are several aspects to a successful checkpoint management solution, so I'll be posting individual RFCs for each of them in order to collect more targeted feedback. This Issue will serve as general context for all of those RFCs. Feel free to leave comments here if you have thoughts on the problems identified and the requirements collected.
Requirements
- Easy to use from a config and from code
- Support loading and saving in various formats (Transformers, Meta, DCP)
- Easy to implement for new model architectures
- Seamless end-to-end experience for tasks like generation and evaluation
Additional UX considerations
- Post-training recipes like PPO and KD, where multiple models are involved and therefore multiple checkpoints need to be handled
Restrictions
- State dictionaries must follow the standard PyTorch format and utilize PyTorch core and distributed APIs (see the sketch after this list)
- Backward compatibility with existing torchtune functionality
- Cannot delete torchtune and start from scratch (sad!)
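As a concrete anchor for the first restriction, a standard-format state dict saved and loaded with PyTorch's distributed checkpointing (DCP) APIs looks roughly like this. This is a hedged sketch against torch.distributed.checkpoint as of recent PyTorch releases; exact signatures have shifted across versions.

```python
import torch
import torch.distributed.checkpoint as dcp
from torch import nn

model = nn.Linear(8, 8)

# A plain PyTorch state dict: a flat mapping of names to tensors.
state_dict = {"model": model.state_dict()}

# DCP writes a format-agnostic (and, under distribution, sharded)
# checkpoint directory.
dcp.save(state_dict, checkpoint_id="/tmp/dcp_checkpoint")

# Loading happens in place into an existing state dict of the same shape.
dcp.load(state_dict, checkpoint_id="/tmp/dcp_checkpoint")
```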