From 0bcc84044ee2bc91ccd748fd8260bddb7f7cd03f Mon Sep 17 00:00:00 2001 From: Pete Walsh Date: Thu, 21 Nov 2024 12:40:24 -0800 Subject: [PATCH] RoPE scaling, document how to convert HuggingFace checkpoints (#111) --- CHANGELOG.md | 2 + docs/source/conf.py | 1 + docs/source/examples/huggingface.rst | 19 +++ docs/source/examples/llama.rst | 24 ++-- docs/source/examples/ngpt.rst | 13 +- docs/source/index.rst | 1 + pyproject.toml | 9 +- src/examples/huggingface/__init__.py | 0 .../huggingface/convert_checkpoint.py | 123 ++++++++++++++++++ src/olmo_core/data/tokenizer.py | 25 +++- src/olmo_core/nn/layer_norm.py | 8 ++ src/olmo_core/nn/rope.py | 76 +++++++++-- src/olmo_core/nn/transformer/config.py | 21 ++- src/test/data/tokenizer_test.py | 4 + 14 files changed, 291 insertions(+), 35 deletions(-) create mode 100644 docs/source/examples/huggingface.rst create mode 100644 src/examples/huggingface/__init__.py create mode 100644 src/examples/huggingface/convert_checkpoint.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f17991ca..6bf6050a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 for loading checkpoints with different key names. - Added `load_key_mapping` field to the trainer, same idea as the new `key_mapping` argument above. - Added an implementation of nGPT called `NormalizedTransformer`. +- Added an example showing how to convert a HuggingFace Llama 3.2 checkpoint into the right format for OLMo-core. +- Added an API for scaling RoPE embeddings. ### Changed diff --git a/docs/source/conf.py b/docs/source/conf.py index 0366ec37..9118f6ae 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -41,6 +41,7 @@ "sphinx.ext.viewcode", "sphinx_copybutton", "sphinx_autodoc_typehints", + "sphinx_inline_tabs", ] # Tell myst-parser to assign header anchors for h1-h3. diff --git a/docs/source/examples/huggingface.rst b/docs/source/examples/huggingface.rst new file mode 100644 index 00000000..68d88446 --- /dev/null +++ b/docs/source/examples/huggingface.rst @@ -0,0 +1,19 @@ +HuggingFace models +================== + +The OLMo-core :class:`~olmo_core.train.Trainer` can be used to fine-tune language models from HuggingFace's ``transformers`` library. + +One way to do this would be to manually apply a data parallel wrapper (like DDP or FSDP) to your ``AutoModelForCausalLM`` and then pass that model directly to the trainer. The downside with this approach is that you won't be able to take advantage of all of the optimizations in this library. + +Instead we recommend converting your HuggingFace checkpoint into a format that can be loaded into an equivalent OLMo-core :class:`~olmo_core.nn.transformer.Transformer` model, when possible, using the functions provided by :mod:`olmo_core.distributed.checkpoint`. + +Below is an example that shows how to convert a Llama-3.2 checkpoint on HuggingFace into the right format for OLMo-core. +It would be straight forward to adapt this script to convert in the other direction as well. + +.. seealso:: + See the `train a Llama model `_ example to learn how to use OLMo-core's training API to pretrain or fine-tune any Llama-like language model. + +.. tab:: ``src/examples/huggingface/convert_checkpoint.py`` + + .. literalinclude:: ../../../src/examples/huggingface/convert_checkpoint.py + :language: py diff --git a/docs/source/examples/llama.rst b/docs/source/examples/llama.rst index 80a0e95d..5855ac20 100644 --- a/docs/source/examples/llama.rst +++ b/docs/source/examples/llama.rst @@ -1,19 +1,17 @@ -``Train a Llama model`` -======================= +Train a Llama model +=================== -The following snippet is the code from ``src/examples/llama/train.py``. -It's a script meant to be launched via ``torchrun``. +The following snippets can be found in `src/examples/llama/ `_. +The ``train.py`` script is meant to be launched via ``torchrun``. You can also use the :mod:`olmo_core.launch` API to quickly launch this script on Beaker. -See below for an example of that. +See the ``train_launch.py`` snippet for an example of that. -``src/examples/llama/train.py`` -------------------------------- +.. tab:: ``train.py`` -.. literalinclude:: ../../../src/examples/llama/train.py - :language: py + .. literalinclude:: ../../../src/examples/llama/train.py + :language: py -``src/examples/llama/train_launch.py`` --------------------------------------- +.. tab:: ``train_launch.py`` -.. literalinclude:: ../../../src/examples/llama/train_launch.py - :language: py + .. literalinclude:: ../../../src/examples/llama/train_launch.py + :language: py diff --git a/docs/source/examples/ngpt.rst b/docs/source/examples/ngpt.rst index 5fd1c1e3..c532092b 100644 --- a/docs/source/examples/ngpt.rst +++ b/docs/source/examples/ngpt.rst @@ -1,11 +1,10 @@ -``Train an nGPT model`` -======================= +Train an nGPT model +=================== -The following snippet is the code from ``src/examples/ngpt/train.py``. +The following snippet can be found in `src/examples/ngpt/ `_. It's a script meant to be launched via ``torchrun``. -``src/examples/ngpt/train.py`` ------------------------------- +.. tab:: ``train.py`` -.. literalinclude:: ../../../src/examples/ngpt/train.py - :language: py + .. literalinclude:: ../../../src/examples/ngpt/train.py + :language: py diff --git a/docs/source/index.rst b/docs/source/index.rst index 977f2227..e7b57038 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -35,6 +35,7 @@ specific to your environment. Then you can install OLMo-core from PyPI with: :maxdepth: 2 :caption: Examples + examples/huggingface.rst examples/llama.rst examples/ngpt.rst diff --git a/pyproject.toml b/pyproject.toml index ff24285f..c04d2cc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,12 +46,13 @@ dev = [ "build", "boto3", "google-cloud-storage", - "Sphinx>=6.0,<7.0.2", + "Sphinx>=6.0,<9.0", "furo==2024.8.6", - "myst-parser>=1.0,<2.1", - "sphinx-copybutton==0.5.2", - "sphinx-autobuild==2021.3.14", + "myst-parser>=1.0", + "sphinx-copybutton", + "sphinx-autobuild", "sphinx-autodoc-typehints==1.23.3", + "sphinx-inline-tabs", ] beaker = [ "beaker-py>=1.32.0", diff --git a/src/examples/huggingface/__init__.py b/src/examples/huggingface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/examples/huggingface/convert_checkpoint.py b/src/examples/huggingface/convert_checkpoint.py new file mode 100644 index 00000000..34ed07e1 --- /dev/null +++ b/src/examples/huggingface/convert_checkpoint.py @@ -0,0 +1,123 @@ +""" +Example script showing how you could convert model weights on HuggingFace for a Llama-3.2 model +into a format that can be loaded by OLMo-core for fine-tuning. + +Note that this script is architecture-dependent, meaning it may only work for Llama-3.2 models on +HuggingFace. +""" + +import logging + +import torch +from transformers import AutoModelForCausalLM + +from olmo_core.data.tokenizer import TokenizerConfig +from olmo_core.distributed.checkpoint import load_model_and_optim_state, save_state_dict +from olmo_core.io import clear_directory, dir_is_empty +from olmo_core.nn.rope import RoPEScalingConfig +from olmo_core.nn.transformer import TransformerConfig +from olmo_core.utils import get_default_device, prepare_cli_environment + +log = logging.getLogger(__name__) + +HF_MODEL = "meta-llama/Llama-3.2-1B" +SAVE_PATH = f"/tmp/checkpoints/{HF_MODEL}" +SAVE_OVERWRITE = False + +TOKENIZER_CONFIG = TokenizerConfig.from_hf(HF_MODEL) +MODEL_CONFIG = TransformerConfig.llama3_1B( + TOKENIZER_CONFIG.vocab_size, fused_ops=False, use_flash=False, rope_scaling=RoPEScalingConfig() +) + + +def convert_checkpoint() -> AutoModelForCausalLM: + log.info(f"Loading HF checkpoint '{HF_MODEL}'") + hf_model = AutoModelForCausalLM.from_pretrained(HF_MODEL) + print(hf_model) + + if not dir_is_empty(SAVE_PATH): + if SAVE_OVERWRITE: + log.warning(f"Clearing existing checkpoint at '{SAVE_PATH}'") + clear_directory(SAVE_PATH) + else: + log.warning(f"Using existing checkpoint at '{SAVE_PATH}'") + return hf_model + + n_layers = len(hf_model.model.layers) + state_dict = hf_model.state_dict() + + # Map old keys to OLMo-core keys. + new_state_dict = { + "embeddings.weight": state_dict.pop("model.embed_tokens.weight"), + "lm_head.norm.weight": state_dict.pop("model.norm.weight"), + "lm_head.w_out.weight": state_dict.pop("lm_head.weight"), + } + for block in range(n_layers): + # Attention. + new_state_dict[f"blocks.{block}.attention.w_q.weight"] = state_dict.pop( + f"model.layers.{block}.self_attn.q_proj.weight" + ) + new_state_dict[f"blocks.{block}.attention.w_k.weight"] = state_dict.pop( + f"model.layers.{block}.self_attn.k_proj.weight" + ) + new_state_dict[f"blocks.{block}.attention.w_v.weight"] = state_dict.pop( + f"model.layers.{block}.self_attn.v_proj.weight" + ) + new_state_dict[f"blocks.{block}.attention.w_out.weight"] = state_dict.pop( + f"model.layers.{block}.self_attn.o_proj.weight" + ) + + # MLP. + new_state_dict[f"blocks.{block}.feed_forward.w1.weight"] = state_dict.pop( + f"model.layers.{block}.mlp.gate_proj.weight" + ) + new_state_dict[f"blocks.{block}.feed_forward.w2.weight"] = state_dict.pop( + f"model.layers.{block}.mlp.down_proj.weight" + ) + new_state_dict[f"blocks.{block}.feed_forward.w3.weight"] = state_dict.pop( + f"model.layers.{block}.mlp.up_proj.weight" + ) + + # Attention layer norm. + new_state_dict[f"blocks.{block}.attention_norm.weight"] = state_dict.pop( + f"model.layers.{block}.input_layernorm.weight" + ) + + # MLP layer norm. + new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop( + f"model.layers.{block}.post_attention_layernorm.weight" + ) + + assert len(state_dict) == 0 + + log.info(f"Saving converted model checkpoint '{SAVE_PATH}'...") + save_state_dict(SAVE_PATH, {"model": new_state_dict}) + + return hf_model + + +def validate_conversion(hf_model): + log.info("Loading converted checkpoint for validation...") + + device = get_default_device() + + model = MODEL_CONFIG.build(device=device, max_seq_len=131072).eval() + load_model_and_optim_state(SAVE_PATH, model) + + hf_model = hf_model.to(device).eval() + + B, T = 1, 120 + input_ids = torch.randint(0, TOKENIZER_CONFIG.vocab_size, (B, T)).to(device) + + with torch.no_grad(): + logits = model(input_ids=input_ids) + hf_logits, *_ = hf_model(input_ids=input_ids, return_dict=False) + torch.testing.assert_close(hf_logits, logits) + + log.info("Conversion successful") + + +if __name__ == "__main__": + prepare_cli_environment() + hf_model = convert_checkpoint() + validate_conversion(hf_model) diff --git a/src/olmo_core/data/tokenizer.py b/src/olmo_core/data/tokenizer.py index 4cf950a3..0100ec34 100644 --- a/src/olmo_core/data/tokenizer.py +++ b/src/olmo_core/data/tokenizer.py @@ -98,8 +98,31 @@ def gpt2(cls) -> "TokenizerConfig": Get a :data:`~TokenizerName.gpt2` tokenizer config. """ return cls( - vocab_size=50280, + vocab_size=50257, eos_token_id=50256, + bos_token_id=50256, pad_token_id=50256, identifier=TokenizerName.gpt2, ) + + @classmethod + def from_hf(cls, identifier: str) -> "TokenizerConfig": + """ + Initialize a tokenizer config from a model on HuggingFace. + + :param identifier: The HF model identifier, e.g. "meta-llama/Llama-3.2-1B". + """ + import json + + from cached_path import cached_path + + with cached_path(f"hf://{identifier}/config.json").open() as f: + config = json.load(f) + + return cls( + vocab_size=config["vocab_size"], + eos_token_id=config["eos_token_id"], + pad_token_id=config.get("pad_token_id", config["eos_token_id"]), + bos_token_id=config.get("bos_token_id"), + identifier=identifier, + ) diff --git a/src/olmo_core/nn/layer_norm.py b/src/olmo_core/nn/layer_norm.py index ba5d5261..92efd578 100644 --- a/src/olmo_core/nn/layer_norm.py +++ b/src/olmo_core/nn/layer_norm.py @@ -120,6 +120,14 @@ def __init__( self.register_parameter("bias", None) self.register_parameter("weight", None) + def extra_repr(self): + if self.weight is not None and self.bias is not None: + return f"{tuple(self.weight.shape)}, bias=True, eps={self.eps}" + elif self.weight is not None: + return f"{tuple(self.weight.shape)}, eps={self.eps}" + else: + return f"eps={self.eps}" + def reset_parameters(self): if self.weight is not None: torch.nn.init.ones_(self.weight) diff --git a/src/olmo_core/nn/rope.py b/src/olmo_core/nn/rope.py index d5e5c2bc..137ec889 100644 --- a/src/olmo_core/nn/rope.py +++ b/src/olmo_core/nn/rope.py @@ -1,6 +1,7 @@ +import math from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple import torch import torch.nn as nn @@ -11,6 +12,7 @@ __all__ = [ "RoPEType", "RoPEConfig", + "RoPEScalingConfig", "RotaryEmbeddingBase", "RotaryEmbedding", "FusedRotaryEmbedding", @@ -32,6 +34,37 @@ class RoPEType(StrEnum): complex = "complex" +@dataclass +class RoPEScalingConfig(Config): + """ + Defines how to scale RoPE to longer sequence lengths. + """ + + factor: float = 32.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + old_context_len: int = 8192 + + def scale_inv_freq( + self, + inv_freq: torch.Tensor, + ) -> torch.Tensor: + low_freq_wavelen = self.old_context_len / self.low_freq_factor + high_freq_wavelen = self.old_context_len / self.high_freq_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq = torch.where(wavelen > low_freq_wavelen, inv_freq / self.factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq / self.factor + smooth_factor * inv_freq + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + return torch.where(is_medium_freq, smoothed_inv_freq, inv_freq) + + @dataclass class RoPEConfig(Config): """ @@ -48,6 +81,7 @@ class RoPEConfig(Config): """ theta: int = 500_000 full_precision: bool = True + scaling: Optional[RoPEScalingConfig] = None def build( self, @@ -59,12 +93,10 @@ def build( See :class:`RotaryEmbedding` for a description of the parameters. """ - kwargs: Dict[str, Any] = dict( - head_shape=head_shape, - theta=self.theta, - full_precision=self.full_precision, - cache=cache, - ) + kwargs = self.as_dict(exclude_none=True, recurse=False) + kwargs.pop("name") + kwargs["head_shape"] = head_shape + kwargs["cache"] = cache if self.name == "default": return RotaryEmbedding(**kwargs) @@ -88,11 +120,13 @@ def __init__( theta: int = 500_000, full_precision: bool = True, cache: Optional[BufferCache] = None, + scaling: Optional[RoPEScalingConfig] = None, ): super().__init__() self.dim = head_shape self.theta = theta self.full_precision = full_precision + self.scaling = scaling self._cache = cache or BufferCache() @abstractmethod @@ -141,12 +175,16 @@ def _get_rotary_embedding( self.theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float) / self.dim) ) + if self.scaling is not None: + inv_freq = self.scaling.scale_inv_freq(inv_freq) seq = torch.arange(seq_len, device=device, dtype=torch.float) freqs = torch.einsum("i , j -> i j", seq, inv_freq) positions = torch.cat((freqs, freqs), dim=-1) pos_sin, pos_cos = positions.sin(), positions.cos() + self._cache["rope_pos_sin"] = pos_sin self._cache["rope_pos_cos"] = pos_cos + return pos_sin, pos_cos def _rotate_half(self, x: torch.Tensor) -> torch.Tensor: @@ -231,11 +269,16 @@ def __init__( theta: int = 500_000, full_precision: bool = True, cache: Optional[BufferCache] = None, + scaling: Optional[RoPEScalingConfig] = None, ): from flash_attn.layers.rotary import apply_rotary_emb_qkv_ # type: ignore super().__init__( - head_shape=head_shape, theta=theta, full_precision=full_precision, cache=cache + head_shape=head_shape, + theta=theta, + full_precision=full_precision, + cache=cache, + scaling=scaling, ) self._apply_rotary_emb_qkv_ = apply_rotary_emb_qkv_ @@ -264,6 +307,8 @@ def _get_rotary_embedding( self.theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float) / self.dim) ) + if self.scaling is not None: + inv_freq = self.scaling.scale_inv_freq(inv_freq) seq = torch.arange(seq_len, device=device, dtype=torch.float) freqs = torch.einsum("i , j -> i j", seq, inv_freq) pos_sin, pos_cos = freqs.sin(), freqs.cos() @@ -304,6 +349,21 @@ class ComplexRotaryEmbedding(RotaryEmbeddingBase): :param full_precision: Always apply RoPE in full precision regardless of the input data type. """ + def __init__( + self, + *, + head_shape: int, + theta: int = 500_000, + full_precision: bool = True, + cache: Optional[BufferCache] = None, + ): + super().__init__( + head_shape=head_shape, + theta=theta, + full_precision=full_precision, + cache=cache, + ) + def warmup_cache(self, max_seq_len: int, device: torch.device): self._get_rotary_embedding(max_seq_len, device) diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index 49beff8e..5d8e1eef 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -16,7 +16,7 @@ from ..feed_forward import FeedForwardConfig, FeedForwardType from ..layer_norm import LayerNormConfig, LayerNormType from ..lm_head import LMHeadConfig, LMHeadType -from ..rope import RoPEConfig, RoPEType +from ..rope import RoPEConfig, RoPEScalingConfig, RoPEType from .block import TransformerBlockConfig, TransformerBlockType from .init import InitMethod from .model import ( @@ -499,6 +499,22 @@ def llama2_70B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": **kwargs, ) + @classmethod + def llama3_1B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": + """ + A 1B Llama3-like model config. + """ + return cls.llama_like( + d_model=2048, + vocab_size=vocab_size, + n_layers=kwargs.pop("n_layers", 16), + n_heads=kwargs.pop("n_heads", 32), + n_kv_heads=kwargs.pop("n_kv_heads", 8), + rope_theta=kwargs.pop("rope_theta", 500_000), + hidden_size_multiplier=1.5, + **kwargs, + ) + @classmethod def llama3_8B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": """ @@ -574,6 +590,7 @@ def llama_like( block_name: TransformerBlockType = TransformerBlockType.default, dtype: DType = DType.float32, compile: bool = False, + rope_scaling: Optional[RoPEScalingConfig] = None, **kwargs, ) -> "TransformerConfig": """ @@ -624,7 +641,7 @@ def llama_like( n_heads=n_heads, n_kv_heads=n_kv_heads, bias=False, - rope=RoPEConfig(name=rope_type, theta=rope_theta), + rope=RoPEConfig(name=rope_type, theta=rope_theta, scaling=rope_scaling), qk_norm=layer_norm if qk_norm else None, use_flash=use_flash, dtype=dtype, diff --git a/src/test/data/tokenizer_test.py b/src/test/data/tokenizer_test.py index e444e1dc..158372bf 100644 --- a/src/test/data/tokenizer_test.py +++ b/src/test/data/tokenizer_test.py @@ -4,3 +4,7 @@ def test_padded_vocab_size(): assert TokenizerConfig.dolma2().padded_vocab_size() == 100352 assert TokenizerConfig.gpt_neox_olmo_dolma_v1_5().padded_vocab_size() == 50304 + + +def test_from_hf(): + assert TokenizerConfig.from_hf("gpt2") == TokenizerConfig.gpt2()