Checkpointing in torchtune

This deep-dive will walk you through the design and behavior of the checkpointer and associated utilities.

.. grid:: 1

    .. grid-item-card:: :octicon:`mortar-board;1em;` What this deep-dive will cover:

      * Checkpointer design for torchtune
      * Checkpoint formats and how we handle them
      * Checkpointing scenarios: Intermediate vs Final and LoRA vs Full-finetune


Overview

torchtune checkpointers are designed to be composable components which can be plugged into any recipe: training, evaluation or generation. Each checkpointer supports a specific set of models and scenarios, which keeps them easy to understand, debug and extend.

Before we dive into the checkpointer in torchtune, let's define some concepts.


Checkpoint Format

In this deep-dive, we'll talk about different checkpoint formats and how torchtune handles them. Let's take a close look at these different formats.

Very simply put, the format of a checkpoint is dictated by the state_dict and how it is stored in files on disk. Each weight is associated with a string key that identifies it in the state dict. If the keys in the stored checkpoint don't match those in the model definition exactly, you'll either run into explicit errors (loading the state dict will raise an exception) or, worse, silent errors (loading will succeed but training or inference will not work as expected). In addition to the keys lining up, the shapes of the weights (the values in the state_dict) must also match exactly the shapes expected by the model definition.
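Both failure modes can be reproduced with plain PyTorch. The sketch below uses a hypothetical one-layer TinyModel, which is not a torchtune class. It shows three cases. First, a renamed key makes the default strict load raise. Second, strict=False only reports the mismatch and leaves the randomly initialized weight in place. Third, a shape mismatch raises even with strict=False.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a model definition with one embedding table."""

    def __init__(self, vocab_size: int = 8, dim: int = 4):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)


model = TinyModel()

# 1) Explicit error: the checkpoint names the same weight differently.
renamed_sd = {"model.embed_tokens.weight": torch.randn(8, 4)}
try:
    model.load_state_dict(renamed_sd)  # strict=True by default
    key_error = None
except RuntimeError as err:
    key_error = str(err)

# 2) Silent error: strict=False only reports the mismatch; the model keeps
#    its random initialization and the "load" succeeds.
before = model.tok_embeddings.weight.detach().clone()
result = model.load_state_dict(renamed_sd, strict=False)
print(result.missing_keys, result.unexpected_keys)
# ['tok_embeddings.weight'] ['model.embed_tokens.weight']

# 3) Right key, wrong shape: raises even with strict=False.
try:
    model.load_state_dict({"tok_embeddings.weight": torch.randn(8, 5)}, strict=False)
    shape_error = None
except RuntimeError as err:
    shape_error = str(err)
```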

Let's look at the two popular formats for Llama 3.2.

Meta Format

This is the format supported by the official Llama 3.2 implementation. When you download the Llama 3.2 3B model from the meta-llama website, you'll get access to a single .pth checkpoint file. You can easily inspect the contents of this checkpoint with torch.load:

>>> import torch
>>> state_dict = torch.load('consolidated.00.pth', mmap=True, weights_only=True, map_location='cpu')
>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
...     print(f'{key}: {value.shape}')

tok_embeddings.weight: torch.Size([128256, 3072])
...
...
>>> print(len(state_dict.keys()))
255

The state_dict contains 255 keys, including an input embedding table called tok_embeddings. The model definition for this state_dict expects an embedding layer with 128256 tokens, each with an embedding of dimension 3072.

HF Format

This is the most popular format within the Hugging Face Model Hub and is the default format in every torchtune config. This is also the format you get when you download the Llama 3.2 model from the Llama-3.2-3B-Instruct repo.

The first big difference is that the state_dict is split across two .safetensors files. To correctly load the checkpoint, you'll need to piece these files together. Let's inspect one of the files.

>>> from safetensors import safe_open
>>> state_dict = {}
>>> with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
...     for k in f.keys():
...         state_dict[k] = f.get_tensor(k)

>>> # inspect the keys and the shapes of the associated tensors
>>> for key, value in state_dict.items():
...     print(f'{key}: {value.shape}')

model.embed_tokens.weight: torch.Size([128256, 3072])
...
...
>>> print(len(state_dict.keys()))
187

Not only does the state_dict contain fewer keys (expected since this is one of two files), but the embedding table is called model.embed_tokens instead of tok_embeddings. This mismatch in names will cause an exception when you try to load the state_dict. The size of this layer is the same between the two, which is as expected.
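So the first job of a converter is to rename keys. The sketch below assumes a hypothetical and deliberately incomplete mapping from HF names to the Meta names used by the original Llama code. Layer indices are abstracted as {}, so one entry covers every layer. The sketch only renames keys. A real converter maps every weight and also permutes some of them. Here, unknown keys raise a KeyError instead of being silently dropped.

```python
import re

import torch

# Hypothetical, deliberately incomplete HF -> Meta key mapping for a
# Llama-style model. "{}" stands for the layer index.
HF_TO_META = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
}


def rename_hf_to_meta(hf_sd: dict) -> dict:
    """Rename keys only; tensors are passed through unchanged."""
    converted = {}
    for key, value in hf_sd.items():
        match = re.search(r"\.(\d+)\.", key)
        if match:
            layer = match.group(1)
            # "model.layers.3.self_attn..." -> "model.layers.{}.self_attn..."
            template = key.replace(f".{layer}.", ".{}.", 1)
            new_key = HF_TO_META[template].format(layer)
        else:
            new_key = HF_TO_META[key]  # KeyError on unmapped names, never a silent drop
        converted[new_key] = value
    return converted


hf_sd = {
    "model.embed_tokens.weight": torch.zeros(16, 4),
    "model.layers.3.self_attn.q_proj.weight": torch.zeros(4, 4),
}
print(sorted(rename_hf_to_meta(hf_sd)))
# ['layers.3.attention.wq.weight', 'tok_embeddings.weight']
```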


As you can see, if you're not careful you'll likely end up making a number of errors just during checkpoint load and save. The torchtune checkpointer makes this less error-prone by managing state dicts for you. torchtune is designed to be "state-dict invariant".

  • When loading, torchtune accepts checkpoints from multiple sources in multiple formats. You don't have to worry about explicitly converting checkpoints every time you run a recipe.
  • When saving, torchtune produces checkpoints in the same format as the source. This includes converting the state_dict back into the original form and splitting the keys and weights across the same number of files.
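Loading a multi-file checkpoint means merging the shards into one state dict. Saving in the source format means splitting it back the same way. The sketch below assumes in-memory dicts stand in for the .safetensors files. merge_shards and split_shards are hypothetical helpers, and weight_map mirrors the key-to-filename map stored in an HF model.safetensors.index.json. This is not torchtune's implementation.

```python
import torch


def merge_shards(shards: dict) -> tuple:
    """Merge {filename: state_dict} into one state dict plus a key -> filename map."""
    merged, weight_map = {}, {}
    for filename in sorted(shards):  # shard names sort by their index
        for key, tensor in shards[filename].items():
            if key in merged:
                raise ValueError(f"duplicate key {key!r} in {filename}")
            merged[key] = tensor
            weight_map[key] = filename
    return merged, weight_map


def split_shards(state_dict: dict, weight_map: dict) -> dict:
    """Partition a (possibly updated) state dict back into the original files."""
    shards = {}
    for key, tensor in state_dict.items():
        shards.setdefault(weight_map[key], {})[key] = tensor
    return shards


shards = {
    "model-00002-of-00002.safetensors": {"model.norm.weight": torch.ones(4)},
    "model-00001-of-00002.safetensors": {"model.embed_tokens.weight": torch.zeros(8, 4)},
}
merged, weight_map = merge_shards(shards)
restored = split_shards(merged, weight_map)
print(sorted(restored))
# ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
```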

One big advantage of being "state-dict invariant" is that you should be able to use fine-tuned checkpoints from torchtune with any post-training tool (quantization, eval, inference) which supports the source format, without any code changes or conversion scripts. This is one of the ways in which torchtune interoperates with the surrounding ecosystem.

Note

To be state-dict "invariant" in this way, the load_checkpoint and save_checkpoint methods of each checkpointer make use of weight converters which correctly map weights between checkpoint formats. For example, when loading weights from Hugging Face, we permute the query and key projection weights on load and undo the permutation on save. The two implementations lay out rotary positional embeddings differently, so this permutation ensures checkpoints behave exactly the same. The Llama family of models uses a generic weight converter function, whilst some other models like Phi3 have their own conversion functions which can be found within their model folders.
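To illustrate such a permutation, here is a sketch of the row reordering between the two rotary layouts. HF keeps the two halves of each head's dimensions apart, while the original Llama code interleaves them as pairs. The function names are hypothetical. The sketch assumes a 2D weight of shape [n_heads * head_dim, dim]. For k_proj you would pass the number of key/value heads instead.

```python
import torch


def hf_to_meta_qk(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Reorder q/k projection rows from HF's half-split rotary layout to interleaved pairs."""
    out_dim, dim = w.shape
    head_dim = out_dim // n_heads
    # [n_heads, 2 halves, head_dim/2, dim] -> [n_heads, head_dim/2, 2, dim]
    return w.reshape(n_heads, 2, head_dim // 2, dim).transpose(1, 2).reshape(out_dim, dim)


def meta_to_hf_qk(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Inverse of hf_to_meta_qk: interleaved pairs back to two separate halves."""
    out_dim, dim = w.shape
    head_dim = out_dim // n_heads
    # [n_heads, head_dim/2, 2, dim] -> [n_heads, 2, head_dim/2, dim]
    return w.reshape(n_heads, head_dim // 2, 2, dim).transpose(1, 2).reshape(out_dim, dim)


rows = torch.arange(8.0).reshape(8, 1)  # one head, head_dim = 8, dim = 1
print(hf_to_meta_qk(rows, n_heads=1).flatten().tolist())
# [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]
```

Row i of the first half is paired with row i of the second half, and applying meta_to_hf_qk restores the original order.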


Handling different Checkpoint Formats

torchtune supports three different :ref:`checkpointers<checkpointing_label>`, each of which supports a different checkpoint format.

HFCheckpointer

This checkpointer reads and writes checkpoints in a format which is compatible with the transformers framework from Hugging Face. As mentioned above, this is the most popular format within the Hugging Face Model Hub and is the default format in every torchtune config.

For this checkpointer to work correctly, we assume that checkpoint_dir contains the necessary checkpoint and json files. The easiest way to make sure everything works correctly is to use the following flow:

  • Download the model from the HF repo using tune download. The --ignore-patterns flag below skips the "pth" file, since we will be loading the "safetensors" files.

    tune download meta-llama/Llama-3.2-3B-Instruct \
    --output-dir /tmp/Llama-3.2-3B-Instruct \
    --ignore-patterns "original/consolidated.00.pth"

  • Use the output_dir specified here as the checkpoint_dir argument for the checkpointer.


The following snippet explains how the HFCheckpointer is set up in torchtune config files.

checkpointer:

    # checkpointer to use
    _component_: torchtune.training.FullModelHFCheckpointer

    # directory with the checkpoint files
    # this should match the folder you used when downloading the model
    checkpoint_dir: /tmp/Llama-3.2-3B-Instruct

    # checkpoint files. For the Llama-3.2-3B-Instruct model we have
    # 2 .safetensors files. The checkpointer takes care of sorting
    # by id, so the order here does not matter
    checkpoint_files: [
        model-00001-of-00002.safetensors,
        model-00002-of-00002.safetensors,
    ]

    # dir for saving the output checkpoints
    output_dir: <output_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA3_2

# set to True if restarting training. More on that later.
resume_from_checkpoint: False

Note

Checkpoint conversion to and from HF's format requires access to model params which are read directly from the config.json file. This helps ensure we either load the weights correctly or error out in case of discrepancy between the HF checkpoint file and torchtune's model implementations. This json file is downloaded from the hub along with the model checkpoints.
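Here is a sketch of that kind of check. It uses a hypothetical helper that reads hidden_size, num_attention_heads and, when present, num_key_value_heads and head_dim from config.json. It then compares the attention projection shapes against HF-style keys. torchtune's own converters are organized differently. This only shows why the params are needed to catch a discrepancy instead of loading mismatched weights.

```python
import json
from pathlib import Path


def expected_attention_shapes(checkpoint_dir) -> dict:
    """Derive q/k/v projection shapes from config.json (hypothetical helper)."""
    with open(Path(checkpoint_dir) / "config.json") as f:
        cfg = json.load(f)
    dim = cfg["hidden_size"]
    n_heads = cfg["num_attention_heads"]
    # Configs without grouped-query attention may omit these two fields.
    n_kv_heads = cfg.get("num_key_value_heads", n_heads)
    head_dim = cfg.get("head_dim", dim // n_heads)
    return {
        "q_proj": (n_heads * head_dim, dim),
        "k_proj": (n_kv_heads * head_dim, dim),
        "v_proj": (n_kv_heads * head_dim, dim),
    }


def check_attention_weights(state_dict: dict, shapes: dict) -> None:
    """Raise on the first attention weight that disagrees with config.json."""
    for key, tensor in state_dict.items():
        for proj, shape in shapes.items():
            if key.endswith(f"self_attn.{proj}.weight") and tuple(tensor.shape) != shape:
                raise ValueError(f"{key}: got {tuple(tensor.shape)}, config.json implies {shape}")
```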


MetaCheckpointer

This checkpointer reads and writes checkpoints in a format which is compatible with the original meta-llama github repository.

For this checkpointer to work correctly, we assume that checkpoint_dir contains the necessary checkpoint and json files. The easiest way to make sure everything works correctly is to use the following flow:

  • Download the model from the HF repo using tune download. The --ignore-patterns flag below skips the "safetensors" files, since we will be loading the "pth" checkpoint.

    tune download meta-llama/Llama-3.2-3B-Instruct \
    --output-dir /tmp/Llama-3.2-3B-Instruct \
    --ignore-patterns "*.safetensors"

  • Use the original/ subfolder of the output_dir above, which holds consolidated.00.pth, as the checkpoint_dir for the checkpointer.


The following snippet explains how the MetaCheckpointer is set up in torchtune config files.

checkpointer:

    # checkpointer to use
    _component_: torchtune.training.FullModelMetaCheckpointer

    # directory with the checkpoint files
    # this should match the folder you used when downloading the model
    checkpoint_dir: <checkpoint_dir>

    # checkpoint files. For the llama3.2 3B model we have
    # a single .pth file
    checkpoint_files: [consolidated.00.pth]

    # dir for saving the output checkpoints.
    output_dir: <checkpoint_dir>

    # model_type which specifies how to convert the state_dict
    # into a format which torchtune understands
    model_type: LLAMA3_2

# set to True if restarting training. More on that later.
resume_from_checkpoint: False

TorchTuneCheckpointer

This checkpointer reads and writes checkpoints in a format that is compatible with torchtune's model definition. It does not perform any state_dict conversions and is currently used either for testing or for loading quantized models for generation.


Checkpoint Output

Congrats on getting this far! Let's say you have followed our :ref:`End-to-End Workflow with torchtune <e2e_flow>` and trained Llama 3.2 3B using one of our LoRA recipes.

Now let's visualize the outputs. A simple way of doing this is by running tree -a path/to/outputdir, which should show something like the tree below. There are 3 types of folders:

  1. recipe_state: Holds recipe_state.pt with the information necessary to restart your training run from the last intermediate epoch. More on that later;
  2. logs: Outputs of your metric_logger, if any;
  3. epoch_{}: Contains your trained model weights plus model metadata. If running inference or pushing to a model hub, you should use this folder directly;

Note

For each epoch, we copy the contents of the original checkpoint folder, excluding the original checkpoints and large files. These files are lightweight, mostly configuration files, and make it easier for the user to use the epoch folders directly in downstream applications.

For more details about each file, please check the End-to-End tutorial mentioned above.

$ tree -a /tmp/torchtune/llama3_2_3B/lora_single_device
/tmp/torchtune/llama3_2_3B/lora_single_device
├── epoch_0
│   ├── adapter_config.json
│   ├── adapter_model.pt
│   ├── adapter_model.safetensors
│   ├── config.json
│   ├── ft-model-00001-of-00002.safetensors
│   ├── ft-model-00002-of-00002.safetensors
│   ├── generation_config.json
│   ├── LICENSE.txt
│   ├── model.safetensors.index.json
│   ├── original
│   │   ├── orig_params.json
│   │   ├── params.json
│   │   └── tokenizer.model
│   ├── original_repo_id.json
│   ├── README.md
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   ├── tokenizer.json
│   └── USE_POLICY.md
├── epoch_1
│   ├── adapter_config.json
│   ├── adapter_model.pt
│   ├── adapter_model.safetensors
│   ├── config.json
│   ├── ft-model-00001-of-00002.safetensors
│   ├── ft-model-00002-of-00002.safetensors
│   ├── generation_config.json
│   ├── LICENSE.txt
│   ├── model.safetensors.index.json
│   ├── original
│   │   ├── orig_params.json
│   │   ├── params.json
│   │   └── tokenizer.model
│   ├── original_repo_id.json
│   ├── README.md
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   ├── tokenizer.json
│   └── USE_POLICY.md
├── logs
│   └── log_1734652101.txt
└── recipe_state
    └── recipe_state.pt

Intermediate vs Final Checkpoints

torchtune checkpointers support two checkpointing scenarios:

End-of-training Checkpointing

The model weights at the end of a completed training run are written out to file. The checkpointer ensures that the output checkpoint files have the same keys as the input checkpoint file used to begin training. The checkpointer also ensures that the keys are partitioned across the same number of files as the original checkpoint. The output state dict has the following standard format:

{
    "key_1": weight_1,
    "key_2": weight_2,
    ...
}

Mid-training Checkpointing

If checkpointing in the middle of training, the output checkpoint needs to store additional information to ensure that subsequent training runs can be correctly restarted. In addition to the model checkpoint files, we output a recipe_state.pt file for intermediate checkpoints. These are currently output at the end of each epoch, and contain information such as optimizer state, number of epochs completed etc.

To prevent us from flooding output_dir with checkpoint files, the recipe state is overwritten at the end of each epoch.

The output state dicts have the following formats:

Model:
    {
        "key_1": weight_1,
        "key_2": weight_2,
        ...
    }

Recipe State:
    {
        "optimizer": ...,
        "epoch": ...,
        ...
    }
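Here is a minimal sketch of writing and reading such a recipe state with torch.save and torch.load. It assumes only the two keys shown above, plus a hypothetical convention that "epoch" holds the index of the last completed epoch. Saving to the same path each time reproduces the overwrite behavior described above. torchtune's recipes store more fields than this.

```python
import torch


def save_recipe_state(path, optimizer: torch.optim.Optimizer, epoch: int) -> None:
    """Write a (hypothetical) recipe state; reusing `path` overwrites the previous file."""
    torch.save({"optimizer": optimizer.state_dict(), "epoch": epoch}, path)


def load_recipe_state(path, optimizer: torch.optim.Optimizer) -> int:
    """Restore optimizer state in place and return the epoch to resume from."""
    state = torch.load(path, weights_only=True)
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"] + 1  # "epoch" is the last completed epoch
```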

Resuming from checkpoint - Full Finetuning

If training is interrupted, you can restart it from a previous checkpoint. To do so, update the following fields in your config:

resume_from_checkpoint: Set it to True;

checkpoint_files: change the paths to epoch_{YOUR_EPOCH}/ft-model-{}-of-{}.safetensors;

Notice that we do not change our checkpoint_dir or output_dir. Since we are resuming from a checkpoint, the checkpointer looks for these files in output_dir.

checkpointer:
    # checkpoint files. Note that you will need to update this
    # section of the config with the intermediate checkpoint files
    checkpoint_files: [
        epoch_{YOUR_EPOCH}/ft-model-00001-of-00002.safetensors,
        epoch_{YOUR_EPOCH}/ft-model-00002-of-00002.safetensors,
    ]

# set to True if restarting training
resume_from_checkpoint: True
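Instead of typing the epoch number by hand, you could locate the newest epoch folder in output_dir. The helpers below are hypothetical and not part of torchtune. They sort epoch_{N} folders numerically, so epoch_10 comes after epoch_9, which a plain string sort would get wrong. They then list the ft-model shards relative to output_dir, ready to paste into checkpoint_files.

```python
import re
from pathlib import Path


def latest_epoch_dir(output_dir) -> Path:
    """Return the epoch_{N} folder with the largest N, comparing N as an integer."""
    epochs = []
    for path in Path(output_dir).iterdir():
        match = re.fullmatch(r"epoch_(\d+)", path.name)
        if match and path.is_dir():
            epochs.append((int(match.group(1)), path))
    if not epochs:
        raise FileNotFoundError(f"no epoch_* folders in {output_dir}")
    return max(epochs)[1]


def resume_checkpoint_files(output_dir, pattern: str = "ft-model-*-of-*.safetensors") -> list:
    """Sharded model files of the latest epoch, relative to output_dir, in shard order."""
    epoch_dir = latest_epoch_dir(output_dir)
    return sorted(p.relative_to(output_dir).as_posix() for p in epoch_dir.glob(pattern))
```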

Resuming from checkpoint - LoRA Finetuning

As with full finetuning, we only need to modify two fields: resume_from_checkpoint and adapter_checkpoint, which is loaded from output_dir. We do not have to modify checkpoint_files, because the base model being loaded is still the same.

checkpointer:

    # adapter_checkpoint. Note that you will need to update this
    # section of the config with the intermediate checkpoint files
    adapter_checkpoint: epoch_{YOUR_EPOCH}/adapter_model.safetensors

# set to True if restarting training
resume_from_checkpoint: True

# set to True to save only the adapter weights
# it does not influence resuming from checkpoint
save_adapter_weights_only: False

Note

In torchtune, we output both the adapter weights and the full model merged weights for LoRA. The merged checkpoint is a convenience, since it can be used without special tooling to handle the adapters. However, it should not be used when resuming training. The merged weights already contain the adapter's update, so loading them together with the adapter would apply that update twice. Therefore, when resuming for LoRA, we take the original untrained weights from checkpoint_dir and the trained adapters from output_dir. For more details, take a look at our :ref:`LoRA Finetuning Tutorial <lora_finetune_label>`.
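The double counting is easy to see numerically. This sketch assumes the standard LoRA formulation, where merging computes W + (alpha / rank) * B @ A. The sizes and tensors are made up.

```python
import torch

torch.manual_seed(0)
d_in, d_out, rank, alpha = 6, 4, 2, 4.0
scale = alpha / rank

W = torch.randn(d_out, d_in)  # frozen base weight, as in checkpoint_dir
A = torch.randn(rank, d_in)   # trained LoRA down-projection
B = torch.randn(d_out, rank)  # trained LoRA up-projection
x = torch.randn(3, d_in)


def lora_forward(weight, x, A, B, scale):
    """Base linear layer plus the low-rank adapter path."""
    return x @ weight.T + scale * (x @ A.T) @ B.T


W_merged = W + scale * (B @ A)  # what a merged checkpoint stores

resume_ok = lora_forward(W, x, A, B, scale)              # base weights + adapter
merged_ok = x @ W_merged.T                               # merged weights alone
double_counted = lora_forward(W_merged, x, A, B, scale)  # merged weights + adapter

print(torch.allclose(resume_ok, merged_ok, atol=1e-4))       # the two correct paths agree
print(torch.allclose(double_counted, merged_ok, atol=1e-4))  # update applied twice
```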

Note

Additionally, by setting the option save_adapter_weights_only, you can choose to only save the adapter weights. This reduces the amount of storage and time needed to save the checkpoint, but has no influence over resuming from checkpoint.
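As a rough illustration of what save_adapter_weights_only keeps, the sketch below filters a state dict down to its adapter entries and compares storage. It assumes adapter parameters can be recognized by lora_a / lora_b in their key names. That is a simplification, and the helper names and tensor sizes are hypothetical.

```python
import torch


def adapter_only(state_dict: dict, markers=("lora_a", "lora_b")) -> dict:
    """Keep entries whose key contains an adapter marker (a naming assumption)."""
    return {k: v for k, v in state_dict.items() if any(m in k for m in markers)}


def num_bytes(state_dict: dict) -> int:
    """Total storage of the tensors in a state dict."""
    return sum(t.numel() * t.element_size() for t in state_dict.values())


dim, rank = 64, 4
full_sd = {
    "layers.0.attn.q_proj.weight": torch.zeros(dim, dim),
    "layers.0.attn.q_proj.lora_a.weight": torch.zeros(rank, dim),
    "layers.0.attn.q_proj.lora_b.weight": torch.zeros(dim, rank),
}
adapters = adapter_only(full_sd)
print(num_bytes(adapters), "of", num_bytes(full_sd), "bytes")
# 2048 of 18432 bytes
```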


Putting this all together

Let's now put all of this knowledge together! We'll load some checkpoints, create some models and run a simple forward pass.

For this section we'll use the Llama-3.2-3B-Instruct model in HF format.

import torch
from torchtune.models.llama3_2 import llama3_2_3b
from torchtune.training import FullModelHFCheckpointer

# Set the right directory and files
checkpoint_dir = "/tmp/Llama-3.2-3B-Instruct/"
output_dir = "/tmp/torchtune/llama3_2_3B/full_single_device"

pytorch_files = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]

# Set up the checkpointer and load state dict
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=checkpoint_dir,
    checkpoint_files=pytorch_files,
    output_dir=output_dir,
    model_type="LLAMA3_2",
)
torchtune_sd = checkpointer.load_checkpoint()

# Setup the model and the input
model = llama3_2_3b()

# Model weights are stored with the key="model"
model.load_state_dict(torchtune_sd["model"])
model.to("cuda")

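# We have 128256 vocab tokens; let's generate an input with 24 tokens
x = torch.randint(0, 128256, (1, 24), dtype=torch.long, device="cuda")

# Run a forward pass; the output holds logits of shape [1, 24, 128256]
with torch.no_grad():
    output = model(x)
print(output)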

tensor([[[ 1.4299,  1.1658,  4.2459,  ..., -2.3259, -2.3262, -2.3259],
        [ 6.5942,  7.2284,  2.4090,  ..., -6.0129, -6.0121, -6.0127],
        [ 5.6462,  4.8787,  4.0950,  ..., -4.6460, -4.6455, -4.6457],
        ...,
        [-0.4156, -0.0626, -0.0362,  ..., -3.6432, -3.6437, -3.6427],
        [-0.5679, -0.6902,  0.5267,  ..., -2.6137, -2.6138, -2.6127],
        [ 0.3688, -0.1350,  1.1764,  ..., -3.4563, -3.4565, -3.4564]]],
    device='cuda:0')

You can do this with any model supported by torchtune. You can find a full list of models and model builders :ref:`here <models>`.

We hope this deep-dive gave you a clearer picture of the checkpointer and associated utilities in torchtune. Happy tuning!