diff --git a/src/transformers/models/olmo/configuration_olmo.py b/src/transformers/models/olmo/configuration_olmo.py
index 77a3b18e364ecf..8ce4fe76fce9df 100644
--- a/src/transformers/models/olmo/configuration_olmo.py
+++ b/src/transformers/models/olmo/configuration_olmo.py
@@ -62,6 +62,17 @@ class OlmoConfig(PretrainedConfig):
             The maximum sequence length that this model might ever be used with.
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_type (`str`, *optional*, defaults to `"default"`):
+            Type of layer norm to use.
+        use_q_norm (`bool`, *optional*, defaults to `False`):
+            Whether to apply norm to the queries within the attention mechanism.
+        use_k_norm (`bool`, *optional*, defaults to `False`):
+            Whether to apply norm to the keys within the attention mechanism.
+        norm_after (`bool`, *optional*, defaults to `False`):
+            Whether to apply norm after the attention/feedforward layers rather than before, as introduced
+            in the Swin transformer paper (Liu et al).
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models). Only
             relevant if `config.is_decoder=True`.
@@ -118,6 +129,11 @@ def __init__(
         hidden_act="silu",
         max_position_embeddings=2048,
         initializer_range=0.02,
+        layer_norm_type="default",
+        use_q_norm=False,
+        use_k_norm=False,
+        norm_after=False,
+        rms_norm_eps=1e-6,
         use_cache=True,
         pad_token_id=1,
         bos_token_id=None,
@@ -144,6 +160,11 @@ def __init__(
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
+        self.layer_norm_type = layer_norm_type
+        self.use_q_norm = use_q_norm
+        self.use_k_norm = use_k_norm
+        self.norm_after = norm_after
+        self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
diff --git a/src/transformers/models/olmo/convert_olmo_weights_to_hf.py b/src/transformers/models/olmo/convert_olmo_weights_to_hf.py
index 0e77bdc69e7a0c..98b330875c41f2 100644
--- a/src/transformers/models/olmo/convert_olmo_weights_to_hf.py
+++ b/src/transformers/models/olmo/convert_olmo_weights_to_hf.py
@@ -17,6 +17,7 @@
 import os
 import shutil
 from pathlib import Path
+from typing import Any, Dict
 
 import torch
 import yaml
@@ -28,21 +29,16 @@
 
 """
 Sample usage:
-
 ```
 python src/transformers/models/olmo/convert_olmo_weights_to_hf.py \
     --input_dir /path/to/downloaded/olmo/weights --model_size 7B --output_dir /output/path
 ```
-
 Thereafter, models can be loaded via:
-
 ```py
 from transformers import OlmoForCausalLM, AutoTokenizer
-
 model = OlmoForCausalLM.from_pretrained("/output/path")
 tokenizer = AutoTokenizer.from_pretrained("/output/path")
 ```
-
 Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions come
 in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
""" @@ -62,7 +58,15 @@ def write_json(text, path): json.dump(text, f) -def write_model(model_path, input_base_path, tokenizer_path=None, safe_serialization=True, fix_eos_token_id=True): +def write_model( + model_path, + input_base_path, + include_tokenizer=True, + tokenizer_path=None, + safe_serialization=True, + fix_eos_token_id=True, + tmp_cleanup=True, +): os.makedirs(model_path, exist_ok=True) tmp_model_path = os.path.join(model_path, "tmp") os.makedirs(tmp_model_path, exist_ok=True) @@ -74,7 +78,7 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa n_heads = olmo_config["n_heads"] dim = olmo_config["d_model"] dims_per_head = dim // n_heads - base = 10000.0 + base = olmo_config.get("rope_theta", 10000.0) inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) max_position_embeddings = olmo_config["max_sequence_length"] @@ -94,7 +98,7 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu") param_count = 0 - index_dict = {"weight_map": {}} + index_dict: Dict[str, Any] = {"weight_map": {}} for layer_i in range(n_layers): filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" # Unsharded @@ -112,14 +116,28 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight, f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight, f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"], + f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded.get( + f"transformer.blocks.{layer_i}.q_norm.weight" + ), + f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded.get( + f"transformer.blocks.{layer_i}.k_norm.weight" + ), f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight, f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"], f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight, + f"model.layers.{layer_i}.input_layernorm.weight": loaded.get( + f"transformer.blocks.{layer_i}.attn_norm.weight" + ), + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded.get( + f"transformer.blocks.{layer_i}.ff_norm.weight" + ), } state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq for k, v in state_dict.items(): + if v is None: + continue index_dict["weight_map"][k] = filename param_count += v.numel() torch.save(state_dict, os.path.join(tmp_model_path, filename)) @@ -130,12 +148,15 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa # TODO: Deal with weight-tying state_dict = { "model.embed_tokens.weight": loaded["transformer.wte.weight"], + "model.norm.weight": loaded.get("transformer.ln_f.weight"), "lm_head.weight": loaded["transformer.ff_out.weight"] if "transformer.ff_out.weight" in loaded else loaded["transformer.wte.weight"], } for k, v in state_dict.items(): + if v is None: + continue index_dict["weight_map"][k] = filename param_count += v.numel() torch.save(state_dict, os.path.join(tmp_model_path, filename)) @@ -149,6 +170,11 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa else: intermediate_size = (dim * olmo_config["mlp_ratio"]) // 2 + if fix_eos_token_id and olmo_config["eos_token_id"] == 0: + # Fixing a bug in OLMo where eos token id was incorrectly set + print("Changing eos_token_id from 0 to 50279.") + 
olmo_config["eos_token_id"] = 50279 + config = OlmoConfig( vocab_size=vocab_size, hidden_size=dim, @@ -160,7 +186,12 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa pad_token_id=olmo_config["pad_token_id"], bos_token_id=None, eos_token_id=olmo_config["eos_token_id"], - tie_word_embeddings=olmo_config["weight_tying"], + tie_word_embeddings=olmo_config.get("weight_tying", True), + layer_norm_type=olmo_config.get("layer_norm_type", "default"), + use_q_norm=olmo_config.get("attention_layer_norm", False), + use_k_norm=olmo_config.get("attention_layer_norm", False), + norm_after=olmo_config.get("norm_after", False), + rms_norm_eps=olmo_config.get("layer_norm_eps", 1e-5), rope_theta=base, clip_qkv=olmo_config.get("clip_qkv"), ) @@ -171,8 +202,8 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa del loaded gc.collect() - if tokenizer_path is not None: - _write_tokenizer(model_path, config, tokenizer_path, fix_eos_token_id) + if include_tokenizer: + _write_tokenizer(model_path, config, input_base_path, tokenizer_path) print("Loading the checkpoint in a OLMo model.") model = OlmoForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True) @@ -180,24 +211,35 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa del model.config._name_or_path print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) - shutil.rmtree(tmp_model_path) + if tmp_cleanup: + # Make cleanup optional; attempting to `rmtree` the `tmp_model_path` causes + # errors if using NFS. + shutil.rmtree(tmp_model_path) def _write_tokenizer( - output_path: Path, config: OlmoConfig, input_tokenizer_path: Path, fix_eos_token_id: bool = True + output_path: Path, + config: OlmoConfig, + checkpoint_dir: str, + input_tokenizer_path: Path | None, ) -> None: print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.") - base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path)) + if input_tokenizer_path is not None: + base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path)) + else: + config_path = Path(checkpoint_dir) / "config.yaml" + tokenizer_config = yaml.safe_load(config_path.read_text())["tokenizer"] + + # Initialize tokenizer and validate vocab size. + if Path(tokenizer_config["identifier"]).is_file(): + base_tokenizer = Tokenizer.from_file(tokenizer_config["identifier"]) + else: + base_tokenizer = Tokenizer.from_pretrained(tokenizer_config["identifier"]) eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1 pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id - if fix_eos_token_id and eos_token_id == 0: - # Fixing a bug in OLMo where eos token id was incorrectly set - print("Changing eos_token_id from 0 to 50279.") - eos_token_id = 50279 - tokenizer = GPTNeoXTokenizerFast( tokenizer_object=base_tokenizer, eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False), @@ -216,10 +258,17 @@ def main(): required=True, help="Location of OLMo weights, which contains config.yaml and model.pt.", ) + parser.add_argument( + "--no_tokenizer", + action="store_false", + dest="include_tokenizer", + help="If set, do not convert OLMo tokenizer to HF tokenizer.", + ) parser.add_argument( "--tokenizer_json_path", + type=Path, default=None, - help="Location of OLMo tokenizer json file.", + help="Location of OLMo tokenizer json file. 
Defaults to what is set in the config file.", ) parser.add_argument( "--output_dir", @@ -232,6 +281,12 @@ def main(): dest="fix_eos_token_id", help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.", ) + parser.add_argument( + "--no_tmp_cleanup", + action="store_false", + dest="tmp_cleanup", + help="If passed, don't remove temp dir at end of HF conversion.", + ) parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") # Different OLMo versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. args = parser.parse_args() @@ -239,8 +294,10 @@ def main(): model_path=args.output_dir, input_base_path=args.input_dir, safe_serialization=args.safe_serialization, + include_tokenizer=args.include_tokenizer, tokenizer_path=args.tokenizer_json_path, fix_eos_token_id=args.fix_eos_token_id, + tmp_cleanup=args.tmp_cleanup, ) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 60225d4759c6ab..7c28bd1d4a1337 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -74,6 +74,38 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ALL_LAYERNORM_LAYERS.append(OlmoLayerNorm) +# copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Olmo +class OlmoRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(OlmoRMSNorm) + + +def get_layer_norm(norm_type: str, hidden_size: int, eps: float = 1e-6) -> nn.Module: + if norm_type == "default": + return OlmoLayerNorm(hidden_size) + if norm_type == "rms": + return OlmoRMSNorm(hidden_size, eps=eps) + raise NotImplementedError(f"No OLMo layer norm implementation of given type: {norm_type}") + + # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo # TODO(joao): add me back asap :) class OlmoRotaryEmbedding(nn.Module): @@ -238,6 +270,16 @@ def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self.q_norm = ( + get_layer_norm(config.layer_norm_type, self.num_heads * self.head_dim, config.rms_norm_eps) + if config.use_q_norm + else None + ) + self.k_norm = ( + get_layer_norm(config.layer_norm_type, self.num_key_value_heads * self.head_dim, config.rms_norm_eps) + if config.use_k_norm + else None + ) self._init_rope() def _init_rope(self): @@ -289,6 +331,11 @@ def forward( key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) 
             value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
 
+        if self.q_norm is not None:
+            query_states = self.q_norm(query_states)
+        if self.k_norm is not None:
+            key_states = self.k_norm(key_states)
+
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -373,6 +420,11 @@ def forward(
             key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
             value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
 
+        if self.q_norm is not None:
+            query_states = self.q_norm(query_states)
+        if self.k_norm is not None:
+            key_states = self.k_norm(key_states)
+
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
@@ -488,6 +540,11 @@ def forward(
             key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
             value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
 
+        if self.q_norm is not None:
+            query_states = self.q_norm(query_states)
+        if self.k_norm is not None:
+            key_states = self.k_norm(key_states)
+
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -551,8 +608,9 @@ def __init__(self, config: OlmoConfig, layer_idx: int):
         self.self_attn = OLMO_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
         self.mlp = OlmoMLP(config)
-        self.input_layernorm = OlmoLayerNorm(config.hidden_size)
-        self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
+        self.input_layernorm = get_layer_norm(config.layer_norm_type, config.hidden_size, config.rms_norm_eps)
+        self.post_attention_layernorm = get_layer_norm(config.layer_norm_type, config.hidden_size, config.rms_norm_eps)
+        self.norm_after = config.norm_after
 
     # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
     # TODO(joao): add me back asap :)
@@ -588,7 +646,8 @@ def forward(
         """
         residual = hidden_states
 
-        hidden_states = self.input_layernorm(hidden_states)
+        if not self.norm_after:
+            hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -601,12 +660,17 @@ def forward(
             cache_position=cache_position,
             **kwargs,
         )
+        if self.norm_after:
+            hidden_states = self.input_layernorm(hidden_states)
         hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        if not self.norm_after:
+            hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
+        if self.norm_after:
+            hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
@@ -762,7 +826,7 @@ def __init__(self, config: OlmoConfig):
         self.layers = nn.ModuleList(
             [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
-        self.norm = OlmoLayerNorm(config.hidden_size)
+        self.norm = get_layer_norm(config.layer_norm_type, config.hidden_size, config.rms_norm_eps)
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
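
For reference, a minimal sketch of how the configuration options introduced by this patch would be exercised end to end. It assumes a `transformers` build that includes these changes; the sizes below are illustrative only and do not correspond to any released OLMo checkpoint.

```python
from transformers import OlmoConfig, OlmoForCausalLM

# Tiny illustrative config exercising the new fields added in this patch.
config = OlmoConfig(
    vocab_size=50304,
    hidden_size=256,
    intermediate_size=1024,
    num_hidden_layers=2,
    num_attention_heads=4,
    layer_norm_type="rms",  # get_layer_norm builds OlmoRMSNorm modules instead of OlmoLayerNorm
    use_q_norm=True,        # normalize queries inside attention
    use_k_norm=True,        # normalize keys inside attention
    norm_after=True,        # apply the norms after attention/MLP rather than before
    rms_norm_eps=1e-6,
)

model = OlmoForCausalLM(config)

# The final norm and the per-attention q/k norms are now RMS norms.
print(model.model.norm)
print(model.model.layers[0].self_attn.q_norm)  # sized num_heads * head_dim
```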