From 35447054f5fb2d6f7f901361e74fdfcb761de06e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 5 Dec 2024 15:47:20 +0100 Subject: [PATCH] Update Mistral conversion script (#34829) * Update convert_mistral_weights_to_hf.py * Update convert_mistral_weights_to_hf.py * Update convert_mistral_weights_to_hf.py --- .../mistral/convert_mistral_weights_to_hf.py | 402 +++++++++--------- 1 file changed, 194 insertions(+), 208 deletions(-) diff --git a/src/transformers/models/mistral/convert_mistral_weights_to_hf.py b/src/transformers/models/mistral/convert_mistral_weights_to_hf.py index 266812b3972dff..1a89ade8fa6dbd 100644 --- a/src/transformers/models/mistral/convert_mistral_weights_to_hf.py +++ b/src/transformers/models/mistral/convert_mistral_weights_to_hf.py @@ -12,20 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse -import gc import json import os -import shutil +import re import warnings import torch -from safetensors.torch import load_file as safe_load_file +from safetensors.torch import load_file -from transformers import ( - LlamaTokenizer, - MistralConfig, - MistralForCausalLM, -) +from transformers import LlamaTokenizer, MistralConfig, MistralForCausalLM try: @@ -39,32 +34,40 @@ ) tokenizer_class = LlamaTokenizer -""" -Sample usage: +# fmt: off +STATE_DICT_MAPPING = { + # CausalLM keys + r"^output.weight": r"lm_head.weight", -``` -python src/transformers/models/mistral/convert_mistral_weights_to_hf.py \ - --input_dir /path/to/downloaded/mistral/weights --model_size 7B --output_dir /output/path -``` + # Model keys + r"^norm.weight": r"model.norm.weight", + r"^tok_embeddings.weight": r"model.embed_tokens.weight", -Thereafter, models can be loaded via: + # Layers keys + r"^layers.(\d+).attention_norm.weight": r"model.layers.\1.input_layernorm.weight", + r"^layers.(\d+).ffn_norm.weight": r"model.layers.\1.post_attention_layernorm.weight", -```py -from transformers import MistralForCausalLM, LlamaTokenizer + # Attention keys + r"^layers.(\d+).attention.w(q|k|v|o).weight": r"model.layers.\1.self_attn.\2_proj.weight", -model = MistralForCausalLM.from_pretrained("/output/path") -tokenizer = LlamaTokenizer.from_pretrained("/output/path") -``` -Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions -come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 
-""" + # MLP keys + r"^layers.(\d+).feed_forward.w1.weight": r"model.layers.\1.mlp.gate_proj.weight", + r"^layers.(\d+).feed_forward.w2.weight": r"model.layers.\1.mlp.down_proj.weight", + r"^layers.(\d+).feed_forward.w3.weight": r"model.layers.\1.mlp.up_proj.weight", +} +# fmt: on -NUM_SHARDS = {"7B": 1} +def map_old_key_to_new(old_key): + """Map of a key of the original state dict to the equivalent key in HF format""" + for pattern, replacement in STATE_DICT_MAPPING.items(): + new_key, n_replace = re.subn(pattern, replacement, old_key) + # Early exit of the loop + if n_replace > 0: + return new_key -def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): - return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).") def read_json(path): @@ -72,218 +75,201 @@ def read_json(path): return json.load(f) -def write_json(text, path): - with open(path, "w") as f: - json.dump(text, f) +def permute_for_rope(tensor, n_heads, dim1, dim2): + """Permute the weights for the ROPE formulation.""" + tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2) + tensor = tensor.transpose(1, 2) + tensor = tensor.reshape(dim1, dim2) + return tensor -def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, is_v3=False): - # for backward compatibility, before you needed the repo to be called `my_repo/model_size` - if not os.path.isfile(os.path.join(input_base_path, "params.json")): - input_base_path = os.path.join(input_base_path, model_size) +def convert_state_dict(original_state_dict: dict, config: MistralConfig): + """Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case).""" + new_dict = {} - os.makedirs(model_path, exist_ok=True) - tmp_model_path = os.path.join(model_path, "tmp") - os.makedirs(tmp_model_path, exist_ok=True) + n_heads = config.num_attention_heads + dim = config.hidden_size + dims_per_head = dim // n_heads + num_key_value_heads = config.num_key_value_heads + key_value_dim = dims_per_head * num_key_value_heads - params = read_json(os.path.join(input_base_path, "params.json")) - num_shards = NUM_SHARDS[model_size] + for old_key, tensor in original_state_dict.items(): + new_key = map_old_key_to_new(old_key) - sliding_window = params.get("sliding_window", None) + if "q_proj" in new_key: + tensor = tensor.view(n_heads, dims_per_head, dim).reshape(dim, dim) + tensor = permute_for_rope(tensor, n_heads, dim, dim) + elif "k_proj" in new_key: + tensor = tensor.view(num_key_value_heads, dims_per_head, dim).reshape(key_value_dim, dim) + tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, dim) + elif "v_proj" in new_key: + tensor = tensor.view(num_key_value_heads, dims_per_head, dim).reshape(key_value_dim, dim) - # For some reason this is a string in the params.json - if sliding_window is not None: - sliding_window = int(sliding_window) + new_dict[new_key] = tensor + return new_dict - n_layers = params["n_layers"] - n_heads = params["n_heads"] - n_heads_per_shard = n_heads // num_shards - dim = params["dim"] + +def get_concat_dim(key): + """Return the dimension to concatenate the weights on.""" + concat_dim_1 = [ + r"model.embed_tokens.weight", + r"model.layers.(\d+).self_attn.o_proj.weight", + r"model.layers.(\d+).mlp.down_proj.weight", + ] + if any(re.search(pattern, key) for pattern in concat_dim_1): + return 1 + return 0 + + +def 
convert_state_dict_sharded(loaded_shards: list[dict], config: MistralConfig): + """Convert the state dict, when a single `nn.Module` is sharded accross different files.""" + new_dict = {} + + num_shards = len(loaded_shards) + + n_heads = config.num_attention_heads + dim = config.hidden_size dims_per_head = dim // n_heads - base = params.get("rope_theta", 10000.0) - inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) - max_position_embeddings = 4096 * 8 - - if tokenizer_path is not None: - tokenizer = tokenizer_class(tokenizer_path + ".v3" if is_v3 else "") - tokenizer.save_pretrained(model_path) - vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000 - - if "n_kv_heads" in params: - num_key_value_heads = params["n_kv_heads"] # for GQA / MQA - num_local_key_value_heads = num_key_value_heads // num_shards - key_value_dim = dims_per_head * num_local_key_value_heads - else: # compatibility with other checkpoints - num_key_value_heads = n_heads - num_local_key_value_heads = n_heads_per_shard - key_value_dim = dim - - # permute for sliced rotary - def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): - return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) - - print(f"Fetching all parameters from the checkpoint at {input_base_path}.") - - # Load weights - for v3 models the consolidated weights are in a single file format in safetensors - if is_v3: - loaded = [safe_load_file(os.path.join(input_base_path, "consolidated.safetensors"))] - else: - loaded = [ - torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu") - for i in range(num_shards) - ] - param_count = 0 - index_dict = {"weight_map": {}} - for layer_i in range(n_layers): - filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" - - # Sharded - # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share - # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is - # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. 
- - state_dict = { - f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ - f"layers.{layer_i}.attention_norm.weight" - ].clone(), - f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ - f"layers.{layer_i}.ffn_norm.weight" - ].clone(), - } - state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( - torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) - for i in range(num_shards) - ], - dim=0, + num_key_value_heads = config.num_key_value_heads + n_heads_per_shard = n_heads // num_shards + num_local_key_value_heads = num_key_value_heads // num_shards + key_value_dim = dim if n_heads == num_key_value_heads else dims_per_head * num_local_key_value_heads + + original_keys = loaded_shards[0].keys() + for old_key in original_keys: + new_key = map_old_key_to_new(old_key) + cat_dim = get_concat_dim(new_key) + + if "q_proj" in new_key: + tensor = torch.cat( + [shard.pop(old_key).view(n_heads_per_shard, dims_per_head, dim) for shard in loaded_shards], + dim=cat_dim, ).reshape(dim, dim) - ) - state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( - torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( - num_local_key_value_heads, dims_per_head, dim - ) - for i in range(num_shards) - ], - dim=0, - ).reshape(key_value_dim, dim), - num_key_value_heads, - key_value_dim, - dim, - ) - state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( - [ - loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(num_local_key_value_heads, dims_per_head, dim) - for i in range(num_shards) - ], - dim=0, - ).reshape(key_value_dim, dim) - - state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 - ) - state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 - ) - state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 - ) - state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( - [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 - ) - - state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq - for k, v in state_dict.items(): - index_dict["weight_map"][k] = filename - param_count += v.numel() - torch.save(state_dict, os.path.join(tmp_model_path, filename)) - - filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" - state_dict = { - "model.norm.weight": loaded[0]["norm.weight"], - "model.embed_tokens.weight": torch.cat([loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1), - "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), + tensor = permute_for_rope(tensor, n_heads, dim, dim) + elif "k_proj" in new_key: + tensor = torch.cat( + [shard.pop(old_key).view(num_local_key_value_heads, dims_per_head, dim) for shard in loaded_shards], + dim=cat_dim, + ).reshape(key_value_dim, dim) + tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, dim) + elif "v_proj" in new_key: + tensor = torch.cat( + [shard.pop(old_key).view(num_local_key_value_heads, dims_per_head, dim) for shard in loaded_shards], + dim=cat_dim, + ).reshape(key_value_dim, dim) + elif "input_layernorm" in new_key or 
"post_attention_layernorm" in new_key: + tensor = loaded_shards[0][old_key].clone() + elif "model.norm.weight" in new_key: + tensor = loaded_shards[0][old_key] + else: + tensor = torch.cat([shard.pop(old_key) for shard in loaded_shards], dim=cat_dim) + + new_dict[new_key] = tensor + + return new_dict + + +def convert_config(original_config: dict, max_position_embeddings: int): + key_mapping = { + "hidden_size": "dim", + "num_hidden_layers": "n_layers", + "intermediate_size": "hidden_dim", + "num_attention_heads": "n_heads", + "rms_norm_eps": "norm_eps", } - - for k, v in state_dict.items(): - index_dict["weight_map"][k] = filename - param_count += v.numel() - torch.save(state_dict, os.path.join(tmp_model_path, filename)) - - # Write configs - index_dict["metadata"] = {"total_size": param_count * 2} - write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) - config = MistralConfig( - hidden_size=dim, - intermediate_size=params["hidden_dim"], - num_attention_heads=params["n_heads"], - num_hidden_layers=params["n_layers"], - rms_norm_eps=params["norm_eps"], - num_key_value_heads=num_key_value_heads, - vocab_size=vocab_size, - rope_theta=base, - max_position_embeddings=max_position_embeddings, - sliding_window=sliding_window, + similar_keys_to_keep = [ + "head_dim", + "vocab_size", + ] + + new_config_kwargs = {k: original_config[v] for k, v in key_mapping.items()} + new_config_kwargs.update({k: v for k, v in original_config.items() if k in similar_keys_to_keep}) + + # These are not always defined depending on `params.json` + new_config_kwargs["sliding_window"] = original_config.get("sliding_window", None) + new_config_kwargs["num_key_value_heads"] = original_config.get( + "n_kv_heads", new_config_kwargs["num_attention_heads"] ) - config.save_pretrained(tmp_model_path) + new_config_kwargs["rope_theta"] = original_config.get("rope_theta", 10000.0) + + # This is never provided in `params.json`, we provide it manually + new_config_kwargs["max_position_embeddings"] = max_position_embeddings - # Make space so we can load the model properly now. - del state_dict - del loaded - gc.collect() + # This may sometimes be a string in `params.json` + if new_config_kwargs["sliding_window"] is not None: + new_config_kwargs["sliding_window"] = int(new_config_kwargs["sliding_window"]) - print("Loading the checkpoint in a Mistral model.") - model = MistralForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) - # Avoid saving this as part of the config. 
- del model.config._name_or_path - model.config.torch_dtype = torch.float16 - print("Saving in the Transformers format.") + new_config = MistralConfig(**new_config_kwargs) + return new_config - model.save_pretrained(model_path, safe_serialization=safe_serialization) - shutil.rmtree(tmp_model_path) +def convert_and_write_model(input_dir: str, output_dir: str, max_position_embeddings: int, modules_are_split: bool): + """Convert the model and save it (this implicitly save the config as well).""" + params = read_json(os.path.join(input_dir, "params.json")) + config = convert_config(params, max_position_embeddings) + + full_state_dict = {} + # The model may be split between different files, but a single nn.Module is always fully present in a single file + if not modules_are_split: + shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")] + for shard_file in shards: + original_state_dict = load_file(os.path.join(input_dir, shard_file)) + new_dict = convert_state_dict(original_state_dict, config) + full_state_dict.update(new_dict) + # A single nn.Module is split between different checkpoint files + else: + shards = [file for file in os.listdir(input_dir) if re.match(r"consolidated.\d+.pth", file)] + shards = sorted(shards, key=lambda x: int(x.split(".")[1])) + loaded_shards = [torch.load(os.path.join(input_dir, file), map_location="cpu") for file in shards] + full_state_dict = convert_state_dict_sharded(loaded_shards, config) + + # Load weights into model and resave them + with torch.device("meta"): + model = MistralForCausalLM(config) + model.load_state_dict(full_state_dict, strict=True, assign=True) + model.save_pretrained(output_dir) -def write_tokenizer(tokenizer_path, input_tokenizer_path): - # Initialize the tokenizer based on the `spm` model - print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") - tokenizer = tokenizer_class(input_tokenizer_path) - tokenizer.save_pretrained(tokenizer_path) + +def convert_and_write_tokenizer(input_dir: str, output_dir: str): + """Convert the tokenizer and save it.""" + # May have .v3 or .v7 at the end + tokenizer_file = [file for file in os.listdir(input_dir) if "tokenizer.model" in file][0] + tokenizer = tokenizer_class(os.path.join(input_dir, tokenizer_file)) + tokenizer.save_pretrained(output_dir) def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--input_dir", + "input_dir", help="Location of Mistral weights, which contains tokenizer.model and model folders", ) parser.add_argument( - "--model_size", - choices=["7B", "tokenizer_only"], - help="'f' models correspond to the finetuned versions, and are specific to the Mistral2 official release. For more details on Mistral2, checkout the original repo: https://huggingface.co/meta-mistral", + "output_dir", + help="Location to write HF model and tokenizer", ) parser.add_argument( - "--output_dir", - help="Location to write HF model and tokenizer", + "--max_position_embeddings", + type=int, + default=32768, + help="`max_position_embeddings` field in the config. This needs to be manually passed (not present anywhere otherwise).", ) - parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") parser.add_argument( - "--is_v3", action="store_true", help="Whether the checkpoints correspond to the 3rd version or not." 
+        "--modules_are_split",
+        action="store_true",
+        help="If passed, then the weights of a single `nn.Module` are assumed to be split between different files.",
     )
+    parser.add_argument(
+        "--tokenizer_only",
+        action="store_true",
+        help="If passed, will only convert the tokenizer.",
+    )
+
     args = parser.parse_args()
-    spm_path = os.path.join(args.input_dir, "tokenizer.model")
-    if args.model_size != "tokenizer_only":
-        write_model(
-            model_path=args.output_dir,
-            input_base_path=args.input_dir,
-            model_size=args.model_size,
-            safe_serialization=args.safe_serialization,
-            tokenizer_path=spm_path,
-            is_v3=args.is_v3,
-        )
-    else:
-        write_tokenizer(args.output_dir, spm_path)
+
+    if not args.tokenizer_only:
+        convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings, args.modules_are_split)
+    convert_and_write_tokenizer(args.input_dir, args.output_dir)
 
 
 if __name__ == "__main__":
     main()
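
A few standalone sketches follow; they are not part of the patch, only illustrations of the new conversion flow. The first concerns the key renaming: instead of hand-writing every `model.layers.{i}.*` key, the new script matches each original key against the `STATE_DICT_MAPPING` regexes and rewrites it with `re.subn`, returning on the first pattern that fires. A minimal, self-contained sketch of that behaviour (abridged mapping, made-up keys):

```py
import re

# Abridged from the patch's STATE_DICT_MAPPING, for illustration only.
MAPPING = {
    r"^output.weight": r"lm_head.weight",
    r"^layers.(\d+).attention.w(q|k|v|o).weight": r"model.layers.\1.self_attn.\2_proj.weight",
    r"^layers.(\d+).feed_forward.w2.weight": r"model.layers.\1.mlp.down_proj.weight",
}

def map_old_key_to_new(old_key):
    for pattern, replacement in MAPPING.items():
        new_key, n_replace = re.subn(pattern, replacement, old_key)
        if n_replace > 0:  # first matching pattern wins
            return new_key
    raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).")

print(map_old_key_to_new("layers.7.attention.wk.weight"))     # model.layers.7.self_attn.k_proj.weight
print(map_old_key_to_new("layers.0.feed_forward.w2.weight"))  # model.layers.0.mlp.down_proj.weight
```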
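`permute_for_rope` mirrors the permutation used in the Llama conversion scripts: within each attention head, it regroups the rows of the q/k projection from the checkpoint's interleaved rotary layout (pairs `(0, 1), (2, 3), ...`) into the half-split layout used by the Transformers rotary implementation (even-indexed rows first, then odd-indexed rows). A toy check of that row reordering, using a single fake head of size 4 rather than real model dimensions:

```py
import torch

def permute_for_rope(tensor, n_heads, dim1, dim2):
    # Identical reshape/transpose dance to the script's helper.
    tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
    tensor = tensor.transpose(1, 2)
    return tensor.reshape(dim1, dim2)

w = torch.arange(8.0).reshape(4, 2)               # 4 rows: 0, 1, 2, 3 (one toy head)
out = permute_for_rope(w, n_heads=1, dim1=4, dim2=2)
print(out[:, 0].tolist())                         # [0.0, 4.0, 2.0, 6.0] -> rows reordered 0, 2, 1, 3
```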
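In the `--modules_are_split` path, `get_concat_dim` records which axis each weight was split on across the `consolidated.XX.pth` shards: `o_proj`, `down_proj` and `embed_tokens` are stitched back together along dim 1, everything else along dim 0. The shapes below are invented, but they show the difference the choice makes:

```py
import torch

shard_a, shard_b = torch.zeros(8, 4), torch.ones(8, 4)  # two fake shards of the same weight

# q/k/v/gate/up/lm_head-style weights are split on the output dimension -> concat on dim 0
col_parallel = torch.cat([shard_a, shard_b], dim=0)      # (16, 4)

# o_proj/down_proj/embed_tokens are split on the other axis -> concat on dim 1
row_parallel = torch.cat([shard_a, shard_b], dim=1)      # (8, 8)

print(col_parallel.shape, row_parallel.shape)
```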
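`convert_config` is essentially a field renaming plus a few defaults: `dim`, `n_layers`, `hidden_dim`, `n_heads` and `norm_eps` from `params.json` become `hidden_size`, `num_hidden_layers`, `intermediate_size`, `num_attention_heads` and `rms_norm_eps`; `head_dim` and `vocab_size` are kept as-is; `n_kv_heads`, `rope_theta` and `sliding_window` fall back to defaults when absent; and `max_position_embeddings` always comes from the CLI flag. A sketch with a made-up `params.json` (the numbers are illustrative, not taken from any particular release):

```py
from transformers import MistralConfig

params = {  # hypothetical params.json contents
    "dim": 1024, "n_layers": 4, "hidden_dim": 4096, "n_heads": 16,
    "n_kv_heads": 4, "head_dim": 64, "norm_eps": 1e-5, "vocab_size": 32000,
}

config = MistralConfig(
    hidden_size=params["dim"],
    num_hidden_layers=params["n_layers"],
    intermediate_size=params["hidden_dim"],
    num_attention_heads=params["n_heads"],
    rms_norm_eps=params["norm_eps"],
    head_dim=params["head_dim"],
    vocab_size=params["vocab_size"],
    num_key_value_heads=params.get("n_kv_heads", params["n_heads"]),
    rope_theta=params.get("rope_theta", 10000.0),
    sliding_window=params.get("sliding_window", None),
    max_position_embeddings=32768,  # from --max_position_embeddings, never from params.json
)
print(config.num_key_value_heads, config.head_dim)  # 4 64
```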
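Instead of the old temporary-shard-then-`from_pretrained` round trip, `convert_and_write_model` now builds the model on the meta device and attaches the converted tensors with `load_state_dict(..., assign=True)`, so no memory is spent on a randomly initialised copy of the weights. A minimal sketch of that pattern on a toy module (it assumes a PyTorch version with meta-device construction and `assign=True`, roughly 2.1+):

```py
import torch
from torch import nn

converted = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}  # stand-in for the converted state dict

with torch.device("meta"):
    layer = nn.Linear(4, 4)          # parameters live on the meta device: shapes only, no storage

layer.load_state_dict(converted, assign=True)   # swap the meta parameters for the real tensors
print(layer.weight.device, layer.weight.shape)  # cpu torch.Size([4, 4])
```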
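Finally, since the old sample-usage docstring was dropped: with the new CLI the two paths are positional, e.g. `python src/transformers/models/mistral/convert_mistral_weights_to_hf.py /path/to/downloaded/mistral/weights /output/path`, with `--modules_are_split` when individual modules are split across `consolidated.XX.pth` files, `--tokenizer_only` to convert just the tokenizer, and `--max_position_embeddings` to override the 32768 default. The converted folder loads as before; the paths below are placeholders:

```py
from transformers import LlamaTokenizer, MistralForCausalLM

model = MistralForCausalLM.from_pretrained("/output/path")
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
```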