RoPE scaling, document how to convert HuggingFace checkpoints (#111)
Showing 14 changed files with 291 additions and 35 deletions.
@@ -0,0 +1,19 @@
HuggingFace models
==================

The OLMo-core :class:`~olmo_core.train.Trainer` can be used to fine-tune language models from HuggingFace's ``transformers`` library.

One way to do this would be to manually apply a data parallel wrapper (like DDP or FSDP) to your ``AutoModelForCausalLM`` and then pass that model directly to the trainer. The downside of this approach is that you won't be able to take advantage of all of the optimizations in this library.

Instead we recommend converting your HuggingFace checkpoint into a format that can be loaded into an equivalent OLMo-core :class:`~olmo_core.nn.transformer.Transformer` model, when possible, using the functions provided by :mod:`olmo_core.distributed.checkpoint`.

Below is an example that shows how to convert a Llama-3.2 checkpoint on HuggingFace into the right format for OLMo-core.
It would be straightforward to adapt this script to convert in the other direction as well.

.. seealso::
   See the `train a Llama model <llama.html>`_ example to learn how to use OLMo-core's training API to pretrain or fine-tune any Llama-like language model.

.. tab:: ``src/examples/huggingface/convert_checkpoint.py``

   .. literalinclude:: ../../../src/examples/huggingface/convert_checkpoint.py
      :language: py
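
For reference, the "manual wrapper" alternative this page advises against might look like the following minimal sketch. It is generic PyTorch plus ``transformers``, not OLMo-core API; the model name just mirrors the example script below, and it assumes one process per GPU launched via ``torchrun``:

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from transformers import AutoModelForCausalLM

    # One process per GPU, launched by torchrun.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B").cuda()
    model = DDP(model)  # plain DDP; OLMo-core's optimized code paths don't apply here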
@@ -1,19 +1,17 @@
-``Train a Llama model``
-=======================
+Train a Llama model
+===================

-The following snippet is the code from ``src/examples/llama/train.py``.
-It's a script meant to be launched via ``torchrun``.
+The following snippets can be found in `src/examples/llama/ <https://github.com/allenai/OLMo-core/tree/main/src/examples/llama>`_.
+The ``train.py`` script is meant to be launched via ``torchrun``.
 You can also use the :mod:`olmo_core.launch` API to quickly launch this script on Beaker.
-See below for an example of that.
+See the ``train_launch.py`` snippet for an example of that.

-``src/examples/llama/train.py``
--------------------------------
+.. tab:: ``train.py``

-.. literalinclude:: ../../../src/examples/llama/train.py
-   :language: py
+   .. literalinclude:: ../../../src/examples/llama/train.py
+      :language: py

-``src/examples/llama/train_launch.py``
---------------------------------------
+.. tab:: ``train_launch.py``

-.. literalinclude:: ../../../src/examples/llama/train_launch.py
-   :language: py
+   .. literalinclude:: ../../../src/examples/llama/train_launch.py
+      :language: py
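
As a usage sketch, launching the script comes down to a standard ``torchrun`` invocation; the process count below is a placeholder, and any CLI arguments are whatever ``train.py`` itself defines:

    # Equivalent of: torchrun --nproc_per_node=4 src/examples/llama/train.py
    from torch.distributed.run import main as torchrun_main

    torchrun_main(["--nproc_per_node=4", "src/examples/llama/train.py"])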
@@ -1,11 +1,10 @@
-``Train an nGPT model``
-=======================
+Train an nGPT model
+===================

-The following snippet is the code from ``src/examples/ngpt/train.py``.
+The following snippet can be found in `src/examples/ngpt/ <https://github.com/allenai/OLMo-core/tree/main/src/examples/ngpt>`_.
 It's a script meant to be launched via ``torchrun``.

-``src/examples/ngpt/train.py``
-------------------------------
+.. tab:: ``train.py``

-.. literalinclude:: ../../../src/examples/ngpt/train.py
-   :language: py
+   .. literalinclude:: ../../../src/examples/ngpt/train.py
+      :language: py
Empty file.
@@ -0,0 +1,123 @@
"""
Example script showing how you could convert model weights on HuggingFace for a Llama-3.2 model
into a format that can be loaded by OLMo-core for fine-tuning.

Note that this script is architecture-dependent, meaning it may only work for Llama-3.2 models on
HuggingFace.
"""

import logging

import torch
from transformers import AutoModelForCausalLM

from olmo_core.data.tokenizer import TokenizerConfig
from olmo_core.distributed.checkpoint import load_model_and_optim_state, save_state_dict
from olmo_core.io import clear_directory, dir_is_empty
from olmo_core.nn.rope import RoPEScalingConfig
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.utils import get_default_device, prepare_cli_environment

log = logging.getLogger(__name__)

HF_MODEL = "meta-llama/Llama-3.2-1B"
SAVE_PATH = f"/tmp/checkpoints/{HF_MODEL}"
SAVE_OVERWRITE = False

TOKENIZER_CONFIG = TokenizerConfig.from_hf(HF_MODEL)
MODEL_CONFIG = TransformerConfig.llama3_1B(
    TOKENIZER_CONFIG.vocab_size, fused_ops=False, use_flash=False, rope_scaling=RoPEScalingConfig()
)
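# Note: fused ops and flash attention are disabled here, presumably so the
# eager-mode forward pass matches the HF model's numerics in
# validate_conversion() below; RoPEScalingConfig() enables the scaled RoPE
# variant that Llama-3.2 checkpoints are trained with.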


def convert_checkpoint() -> AutoModelForCausalLM:
    log.info(f"Loading HF checkpoint '{HF_MODEL}'")
    hf_model = AutoModelForCausalLM.from_pretrained(HF_MODEL)
    print(hf_model)

    if not dir_is_empty(SAVE_PATH):
        if SAVE_OVERWRITE:
            log.warning(f"Clearing existing checkpoint at '{SAVE_PATH}'")
            clear_directory(SAVE_PATH)
        else:
            log.warning(f"Using existing checkpoint at '{SAVE_PATH}'")
            return hf_model

    n_layers = len(hf_model.model.layers)
    state_dict = hf_model.state_dict()

    # Map old keys to OLMo-core keys.
    new_state_dict = {
        "embeddings.weight": state_dict.pop("model.embed_tokens.weight"),
        "lm_head.norm.weight": state_dict.pop("model.norm.weight"),
        "lm_head.w_out.weight": state_dict.pop("lm_head.weight"),
    }
    for block in range(n_layers):
        # Attention.
        new_state_dict[f"blocks.{block}.attention.w_q.weight"] = state_dict.pop(
            f"model.layers.{block}.self_attn.q_proj.weight"
        )
        new_state_dict[f"blocks.{block}.attention.w_k.weight"] = state_dict.pop(
            f"model.layers.{block}.self_attn.k_proj.weight"
        )
        new_state_dict[f"blocks.{block}.attention.w_v.weight"] = state_dict.pop(
            f"model.layers.{block}.self_attn.v_proj.weight"
        )
        new_state_dict[f"blocks.{block}.attention.w_out.weight"] = state_dict.pop(
            f"model.layers.{block}.self_attn.o_proj.weight"
        )

        # MLP.
        new_state_dict[f"blocks.{block}.feed_forward.w1.weight"] = state_dict.pop(
            f"model.layers.{block}.mlp.gate_proj.weight"
        )
        new_state_dict[f"blocks.{block}.feed_forward.w2.weight"] = state_dict.pop(
            f"model.layers.{block}.mlp.down_proj.weight"
        )
        new_state_dict[f"blocks.{block}.feed_forward.w3.weight"] = state_dict.pop(
            f"model.layers.{block}.mlp.up_proj.weight"
        )

        # Attention layer norm.
        new_state_dict[f"blocks.{block}.attention_norm.weight"] = state_dict.pop(
            f"model.layers.{block}.input_layernorm.weight"
        )

        # MLP layer norm.
        new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop(
            f"model.layers.{block}.post_attention_layernorm.weight"
        )
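
    # Every HF weight should now have been consumed by the mapping above;
    # anything left over means the checkpoint doesn't match the Llama-3.2
    # layout this script assumes.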
    assert len(state_dict) == 0

    log.info(f"Saving converted model checkpoint '{SAVE_PATH}'...")
    save_state_dict(SAVE_PATH, {"model": new_state_dict})

    return hf_model


def validate_conversion(hf_model):
    log.info("Loading converted checkpoint for validation...")

    device = get_default_device()

    model = MODEL_CONFIG.build(device=device, max_seq_len=131072).eval()
    load_model_and_optim_state(SAVE_PATH, model)

    hf_model = hf_model.to(device).eval()

    B, T = 1, 120
    input_ids = torch.randint(0, TOKENIZER_CONFIG.vocab_size, (B, T)).to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids)
        hf_logits, *_ = hf_model(input_ids=input_ids, return_dict=False)
        torch.testing.assert_close(hf_logits, logits)

    log.info("Conversion successful")


if __name__ == "__main__":
    prepare_cli_environment()
    hf_model = convert_checkpoint()
    validate_conversion(hf_model)
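
Once converted, loading the weights back for fine-tuning follows the same pattern as ``validate_conversion()`` above. A minimal sketch, reusing the script's ``MODEL_CONFIG`` and ``SAVE_PATH`` (the fine-tuning wiring itself belongs to the Trainer docs):

    # Build the equivalent OLMo-core Transformer and load the converted weights.
    model = MODEL_CONFIG.build(device=get_default_device(), max_seq_len=131072)
    load_model_and_optim_state(SAVE_PATH, model)
    # From here the model can be handed to olmo_core.train.Trainer for fine-tuning.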