diff --git a/src/transformers/models/phi/convert_phi_weights_to_hf.py b/src/transformers/models/phi/convert_phi_weights_to_hf.py index 36d6eeb3e635a5..69ef4c5919ed9b 100644 --- a/src/transformers/models/phi/convert_phi_weights_to_hf.py +++ b/src/transformers/models/phi/convert_phi_weights_to_hf.py @@ -18,12 +18,15 @@ This script downloads both Phi-1 and Phi-1.5 checkpoints to "checkpoint_path" and then converts the weights to HugfgingFace model's format and saves them in "pytorch_dump_folder_path". + +Example : $python ./convert_phi_weights_to_hf.py --model_name "microsoft/phi-2" --pytorch_dump_folder ./dump_folder/ --checkpoint_path ./ckpt_path/ """ import argparse import gc import os +import safetensors import torch from huggingface_hub import hf_hub_download @@ -31,18 +34,21 @@ _MODELS = { - "microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/blob/main/pytorch_model.bin", - "microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/blob/main/pytorch_model.bin", + "microsoft/phi-1": ["https://huggingface.co/microsoft/phi-1/blob/main/pytorch_model.bin"], + "microsoft/phi-1_5": ["https://huggingface.co/microsoft/phi-1_5/blob/main/pytorch_model.bin"], + "microsoft/phi-2": [ + "https://huggingface.co/microsoft/phi-2/blob/main/model-00001-of-00002.safetensors", + "https://huggingface.co/microsoft/phi-2/blob/main/model-00002-of-00002.safetensors", + ], } - PHI_MAPPING = { - "layers.0.wte.weight": "model.embed_tokens.weight", - "layers.25.linear.bias": "lm_head.bias", - "layers.25.linear.weight": "lm_head.weight", - "layers.25.ln.bias": "model.final_layernorm.bias", - "layers.25.ln.weight": "model.final_layernorm.weight", + "transformer.embd.wte.weight": "model.embed_tokens.weight", + "lm_head.linear": "lm_head", + "lm_head.ln": "model.final_layernorm", "layers": "model.layers", + "transformer": "model", + ".h.": ".layers.", "ln": "input_layernorm", "mixer": "self_attn", "Wqkv": "query_key_value", @@ -54,14 +60,6 @@ def convert_weights(original_weights, mapping, config): converted_weights = {} original_weights_keys = sorted(original_weights.keys()) - # we change names (1-24) -> layers(0-23) for Phi model layers - range_change = { - f"layers.{k}.": f"layers.{v}." - for k, v in zip(range(1, config.num_hidden_layers + 1), range(0, config.num_hidden_layers)) - } - - mapping.update(**range_change) - for original_weights_key in original_weights_keys: new_key = original_weights_key @@ -104,27 +102,48 @@ def _download(url: str, root: str): ) -def convert_phi_weights(checkpoint_path, pytorch_dump_folder_path, use_cuda, save_weights_directly): +def convert_phi_weights( + model_name, checkpoint_path, pytorch_dump_folder_path, use_cuda, save_weights_directly, _MODELS +): + _MODELS = _MODELS if model_name not in _MODELS.keys() else {model_name: _MODELS.get(model_name)} device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" - for each_model_name, each_model_url in _MODELS.items(): + for model_name, model_url in _MODELS.items(): converted_checkpoint = {} - - model_path = os.path.join(checkpoint_path, each_model_name + "_" + each_model_url.split("/")[-1]) - if not os.path.exists(model_path): - print(f"\n{each_model_name} was not found! Downloading it to {model_path}") - _download(url=each_model_url, root=model_path) - model_checkpoint = torch.load(model_path, map_location=device) - model_type = each_model_name.split("/")[1] # phi-1 or phi-1_5 - config = PhiConfig.from_pretrained(f"susnato/{model_type}_dev") + model_checkpoint = {} + + # for phi-2 the weights are stored in 2 different safetensors file so we need to iterate over that list and download one at a time + for model_each_url in model_url: + model_path = os.path.join(checkpoint_path, model_name + "_" + model_each_url.split("/")[-1]) + if not os.path.exists(model_path): + print(f"\n{model_name} was not found! Downloading it to {model_path}") + _download(url=model_each_url, root=model_path) + + if model_path.endswith("safetensors"): + loaded_weights = safetensors.torch.load_file(model_path, device=device) + else: + loaded_weights = torch.load(model_path, map_location=device) + model_checkpoint.update(**loaded_weights) + + model_type = model_name.split("/")[1] # phi-1 or phi-1_5 or phi-2 + + # init the config for phi-1 and phi-1.5 + config = PhiConfig() + # if we are dealing with phi-2 then update the config + if model_type == "phi-2": + config.hidden_size = 2560 + config.intermediate_size = 10240 + config.num_hidden_layers = 32 + config.resid_pdrop = 0.1 + config.partial_rotary_factor = 0.4 + config.num_hidden_layers = 32 + config.torch_dtype = "float16" # Converting the weights converted_checkpoint.update(**convert_weights(model_checkpoint, PHI_MAPPING, config)) # Save either the whole model or the converted weights if save_weights_directly: - save_weights_path = os.path.join( - pytorch_dump_folder_path, each_model_name.split("/")[-1] + "_" + each_model_url.split("/")[-1] - ) + save_weights_path = os.path.join(pytorch_dump_folder_path, model_type + "_pytorch_model.bin") torch.save(converted_checkpoint, save_weights_path) print(f"Model weights saved at {save_weights_path}!") @@ -148,6 +167,12 @@ def convert_phi_weights(checkpoint_path, pytorch_dump_folder_path, use_cuda, sav if __name__ == "__main__": parser = argparse.ArgumentParser() # # Required parameters + parser.add_argument( + "--model_name", + type=str, + help="Name of the model to convert. (Please enter one of the following: phi-1, phi-1_5, phi-2). If nothing is provided, all models will be converted.", + default=None, + ) parser.add_argument( "--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)" ) @@ -172,4 +197,11 @@ def convert_phi_weights(checkpoint_path, pytorch_dump_folder_path, use_cuda, sav ) args = parser.parse_args() - convert_phi_weights(args.checkpoint_path, args.pytorch_dump_folder_path, args.use_cuda, args.save_weights_directly) + convert_phi_weights( + args.model_name, + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.use_cuda, + args.save_weights_directly, + _MODELS, + ) diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index 516dd1ee626e7f..f5fd51e98b9528 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -432,3 +432,38 @@ def test_model_phi_1_5_logits(self): EXPECTED_OUTPUT = torch.tensor([[12.2922, 13.3507, 8.6963, 9.1355, 9.3502, 9.2667, 14.2027, 13.1363, 13.5446, 11.1337, 9.9279, 16.7195, 13.0768, 14.9141, 11.9965, 8.0233, 10.3129, 10.6118, 10.0204, 9.3827, 8.8344, 8.2806, 8.0153, 8.0540, 7.0964, 16.5743, 11.1256, 9.6987, 11.4770, 10.5440], [12.3323, 14.6050, 8.9986, 8.1580, 9.5654, 6.6728, 12.5966, 12.6662, 12.2784, 11.7522, 8.2039, 16.3102, 11.2203, 13.6088, 12.0125, 9.1021, 9.8216, 10.0987, 9.0926, 8.4260, 8.8009, 7.6547, 6.8075, 7.7881, 7.4501, 15.7451, 10.5053, 8.3129, 10.0027, 9.2612]]).to(torch_device) # fmt: skip self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-4, rtol=1e-4)) + + def test_model_phi_2_logits(self): + input_ids = { + "input_ids": torch.tensor( + [[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]], dtype=torch.long, device=torch_device + ) + } + + model = PhiForCausalLM.from_pretrained("susnato/phi-2").to(torch_device) + model.eval() + + output = model(**input_ids).logits + + EXPECTED_OUTPUT = torch.tensor([[6.4830, 6.1644, 3.4055, 2.2848, 5.4654, 2.8360, 5.5975, 5.5391, 7.3101, 4.2498, 2.5913, 10.3885, 6.4359, 8.7982, 5.6534, 0.5150, 2.7498, 3.1930, 2.4334, 1.7781, 1.5613, 1.3067, 0.8291, 0.5633, 0.6522, 9.8191, 5.5771, 2.7987, 4.2845, 3.7030], [6.0642, 7.8242, 3.4634, 1.9259, 4.3169, 2.0913, 6.0446, 3.6804, 6.6736, 4.0727, 2.1791, 11.4139, 5.6795, 7.5652, 6.2039, 2.7174, 4.3266, 3.6930, 2.8058, 2.6721, 2.3047, 2.0848, 2.0972, 2.0441, 1.3160, 9.2085, 4.5557, 3.0296, 2.6045, 2.4059]]).to(torch_device) # fmt: skip + + self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-3, rtol=1e-3)) + + def test_phi_2_generation(self): + model = PhiForCausalLM.from_pretrained("susnato/phi-2") + tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2") + + inputs = tokenizer( + "Can you help me write a formal email to a potential business partner proposing a joint venture?", + return_tensors="pt", + return_attention_mask=False, + ) + + outputs = model.generate(**inputs, max_new_tokens=30) + output_text = tokenizer.batch_decode(outputs) + + EXPECTED_OUTPUT = [ + "Can you help me write a formal email to a potential business partner proposing a joint venture?\nInput: Company A: ABC Inc.\nCompany B: XYZ Ltd.\nJoint Venture: A new online platform for e-commerce" + ] + + self.assertListEqual(output_text, EXPECTED_OUTPUT)