Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for OLMo's November release #34497

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/transformers/models/olmo/configuration_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ class OlmoConfig(PretrainedConfig):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_type (`str`, *optional*, defaults to `"default"`):
Type of layer norm to use.
use_q_norm (`bool`, *optional*, defaults to `False`):
Whether to apply norm to the queries within the attention mechanism.
use_k_norm (`bool`, *optional*, defaults to `False`):
Whether to apply norm to the keys within the attention mechanism.
norm_after (`bool`, *optional*, defaults to `False`):
Whether to apply norm after the attention/feedforward layers rather than before, as introduced
in the Swin transformer paper (Liu et al).
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
Expand Down Expand Up @@ -118,6 +129,11 @@ def __init__(
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
layer_norm_type="default",
use_q_norm=False,
use_k_norm=False,
norm_after=False,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=1,
bos_token_id=None,
Expand All @@ -144,6 +160,11 @@ def __init__(
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.layer_norm_type = layer_norm_type
self.use_q_norm = use_q_norm
self.use_k_norm = use_k_norm
self.norm_after = norm_after
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
Expand Down
97 changes: 77 additions & 20 deletions src/transformers/models/olmo/convert_olmo_weights_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import shutil
from pathlib import Path
from typing import Any, Dict

import torch
import yaml
Expand All @@ -28,21 +29,16 @@

"""
Sample usage:

```
python src/transformers/models/olmo/convert_olmo_weights_to_hf.py \
--input_dir /path/to/downloaded/olmo/weights --model_size 7B --output_dir /output/path
```

Thereafter, models can be loaded via:

```py
from transformers import OlmoForCausalLM, AutoTokenizer

model = OlmoForCausalLM.from_pretrained("/output/path")
tokenizer = AutoTokenizer.from_pretrained("/output/path")
```

Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""
Expand All @@ -62,7 +58,15 @@ def write_json(text, path):
json.dump(text, f)


def write_model(model_path, input_base_path, tokenizer_path=None, safe_serialization=True, fix_eos_token_id=True):
def write_model(
model_path,
input_base_path,
include_tokenizer=True,
tokenizer_path=None,
safe_serialization=True,
fix_eos_token_id=True,
tmp_cleanup=True,
):
os.makedirs(model_path, exist_ok=True)
tmp_model_path = os.path.join(model_path, "tmp")
os.makedirs(tmp_model_path, exist_ok=True)
Expand All @@ -74,7 +78,7 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa
n_heads = olmo_config["n_heads"]
dim = olmo_config["d_model"]
dims_per_head = dim // n_heads
base = 10000.0
base = olmo_config.get("rope_theta", 10000.0)
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
max_position_embeddings = olmo_config["max_sequence_length"]

Expand All @@ -94,7 +98,7 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa
loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu")

param_count = 0
index_dict = {"weight_map": {}}
index_dict: Dict[str, Any] = {"weight_map": {}}
for layer_i in range(n_layers):
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
# Unsharded
Expand All @@ -112,14 +116,28 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa
f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight,
f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight,
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"],
f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded.get(
f"transformer.blocks.{layer_i}.q_norm.weight"
),
f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded.get(
f"transformer.blocks.{layer_i}.k_norm.weight"
),
f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight,
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"],
f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight,
f"model.layers.{layer_i}.input_layernorm.weight": loaded.get(
f"transformer.blocks.{layer_i}.attn_norm.weight"
),
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded.get(
f"transformer.blocks.{layer_i}.ff_norm.weight"
),
}

state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq

for k, v in state_dict.items():
if v is None:
continue
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
Expand All @@ -130,12 +148,15 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa
# TODO: Deal with weight-tying
state_dict = {
"model.embed_tokens.weight": loaded["transformer.wte.weight"],
"model.norm.weight": loaded.get("transformer.ln_f.weight"),
"lm_head.weight": loaded["transformer.ff_out.weight"]
if "transformer.ff_out.weight" in loaded
else loaded["transformer.wte.weight"],
}

for k, v in state_dict.items():
if v is None:
continue
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
Expand All @@ -149,6 +170,11 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa
else:
intermediate_size = (dim * olmo_config["mlp_ratio"]) // 2

if fix_eos_token_id and olmo_config["eos_token_id"] == 0:
# Fixing a bug in OLMo where eos token id was incorrectly set
print("Changing eos_token_id from 0 to 50279.")
olmo_config["eos_token_id"] = 50279

config = OlmoConfig(
vocab_size=vocab_size,
hidden_size=dim,
Expand All @@ -160,7 +186,12 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa
pad_token_id=olmo_config["pad_token_id"],
bos_token_id=None,
eos_token_id=olmo_config["eos_token_id"],
tie_word_embeddings=olmo_config["weight_tying"],
tie_word_embeddings=olmo_config.get("weight_tying", True),
layer_norm_type=olmo_config.get("layer_norm_type", "default"),
use_q_norm=olmo_config.get("attention_layer_norm", False),
use_k_norm=olmo_config.get("attention_layer_norm", False),
norm_after=olmo_config.get("norm_after", False),
rms_norm_eps=olmo_config.get("layer_norm_eps", 1e-5),
rope_theta=base,
clip_qkv=olmo_config.get("clip_qkv"),
)
Expand All @@ -171,33 +202,44 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa
del loaded
gc.collect()

if tokenizer_path is not None:
_write_tokenizer(model_path, config, tokenizer_path, fix_eos_token_id)
if include_tokenizer:
_write_tokenizer(model_path, config, input_base_path, tokenizer_path)

print("Loading the checkpoint in a OLMo model.")
model = OlmoForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True)
# Avoid saving this as part of the config.
del model.config._name_or_path
print("Saving in the Transformers format.")
model.save_pretrained(model_path, safe_serialization=safe_serialization)
shutil.rmtree(tmp_model_path)
if tmp_cleanup:
# Make cleanup optional; attempting to `rmtree` the `tmp_model_path` causes
# errors if using NFS.
shutil.rmtree(tmp_model_path)


def _write_tokenizer(
output_path: Path, config: OlmoConfig, input_tokenizer_path: Path, fix_eos_token_id: bool = True
output_path: Path,
config: OlmoConfig,
checkpoint_dir: str,
input_tokenizer_path: Path | None,
) -> None:
print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.")

base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path))
if input_tokenizer_path is not None:
base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path))
else:
config_path = Path(checkpoint_dir) / "config.yaml"
tokenizer_config = yaml.safe_load(config_path.read_text())["tokenizer"]

# Initialize tokenizer and validate vocab size.
if Path(tokenizer_config["identifier"]).is_file():
base_tokenizer = Tokenizer.from_file(tokenizer_config["identifier"])
else:
base_tokenizer = Tokenizer.from_pretrained(tokenizer_config["identifier"])

eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1
pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id

if fix_eos_token_id and eos_token_id == 0:
# Fixing a bug in OLMo where eos token id was incorrectly set
print("Changing eos_token_id from 0 to 50279.")
eos_token_id = 50279

tokenizer = GPTNeoXTokenizerFast(
tokenizer_object=base_tokenizer,
eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False),
Expand All @@ -216,10 +258,17 @@ def main():
required=True,
help="Location of OLMo weights, which contains config.yaml and model.pt.",
)
parser.add_argument(
"--no_tokenizer",
action="store_false",
dest="include_tokenizer",
help="If set, do not convert OLMo tokenizer to HF tokenizer.",
)
parser.add_argument(
"--tokenizer_json_path",
type=Path,
default=None,
help="Location of OLMo tokenizer json file.",
help="Location of OLMo tokenizer json file. Defaults to what is set in the config file.",
)
parser.add_argument(
"--output_dir",
Expand All @@ -232,15 +281,23 @@ def main():
dest="fix_eos_token_id",
help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.",
)
parser.add_argument(
"--no_tmp_cleanup",
action="store_false",
dest="tmp_cleanup",
help="If passed, don't remove temp dir at end of HF conversion.",
)
parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
# Different OLMo versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
safe_serialization=args.safe_serialization,
include_tokenizer=args.include_tokenizer,
tokenizer_path=args.tokenizer_json_path,
fix_eos_token_id=args.fix_eos_token_id,
tmp_cleanup=args.tmp_cleanup,
)


Expand Down
Loading
Loading