From 41a98c618e5f7418367256fd2d69c37bdfafc390 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Thu, 2 Nov 2023 22:40:28 -0400 Subject: [PATCH 1/9] WIP: Initial setup for GGUF writer configuration - Created the `initialize_writer` function to set up GGUF writer with model metadata - Included validation for file type and architecture - Default hyperparameter values sourced from MixFormerSequentialConfig - Function annotations and documentation added for clarity - Prepared groundwork for MixFormer architecture integration --- convert-phi-1-to-gguf.py | 300 +++++++++++++++++++++++++++++++++++++++ gguf-py/gguf/gguf.py | 5 + 2 files changed, 305 insertions(+) create mode 100755 convert-phi-1-to-gguf.py diff --git a/convert-phi-1-to-gguf.py b/convert-phi-1-to-gguf.py new file mode 100755 index 0000000000000..e3ca34ff3fb67 --- /dev/null +++ b/convert-phi-1-to-gguf.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import logging +import os +import struct +import sys +from json import JSONDecodeError +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import torch +from transformers import AutoTokenizer + +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / "gguf-py" / "gguf")) +import gguf + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def check_required_files(directory: Path, required_files: List[str]) -> None: + missing_files = [ + file_name + for file_name in required_files + if not (directory / file_name).exists() + ] + if missing_files: + raise FileNotFoundError(f"Missing required files: {', '.join(missing_files)}") + + +def get_json_map(file_path: Path) -> dict[str, Any]: + with open(file_path, "r") as source_file: + try: + return json.load(source_file) + except JSONDecodeError: + raise ValueError(f"Failed to decode {file_path}") + + +def load_hyper_params(directory: Path, architecture: str) -> dict: + config_path = directory / "config.json" + hparams = get_json_map(config_path) + + # Ensure the expected architecture is present + expected_architecture = architecture + if hparams["architectures"][0] != expected_architecture: + raise ValueError( + f"Model architecture not supported: {hparams['architectures'][0]}" + ) + + return hparams + + +def initialize_writer( + fname_out: str, architecture: str, ftype: str, hparams: Dict[str, Any] +) -> gguf.GGUFWriter: + """ + Initializes the GGUF writer with the model metadata. + + :param fname_out: The filename for the output model. + :param architecture: The model architecture enum name. + :param ftype: The data type for the model file (e.g., 'F32', 'F16'). + :param hparams: The hyperparameters loaded from the model's config file. + :return: An initialized GGUF writer object. 
+ """ + # Validate the architecture name + if not hasattr(gguf.MODEL_ARCH, architecture): + raise ValueError(f"Unsupported architecture: {architecture}") + ARCH = getattr(gguf.MODEL_ARCH, architecture) + + # Validate the file type + if ftype not in ['F32', 'F16']: + raise ValueError(f"Unsupported file type: {ftype}") + + # Initialize the GGUF writer + gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) + + # Set the writer with the hyperparameters from MixFormerSequentialConfig + gguf_writer.add_name(gguf.MODEL_ARCH_NAMES[ARCH]) + gguf_writer.add_context_length(hparams.get("n_positions", 2048)) + gguf_writer.add_embedding_length(hparams.get("n_embd", 1024)) + n_inner = hparams.get("n_inner", 4 * hparams.get("n_embd", 1024)) + gguf_writer.add_feed_forward_length(n_inner) + gguf_writer.add_block_count(hparams.get("n_layer", 20)) + gguf_writer.add_head_count(hparams.get("n_head", 16)) + n_head_kv = hparams.get("n_head_kv", hparams.get("n_head", 16)) + gguf_writer.add_head_count_kv(n_head_kv) # NOTE: arxiv:2203.11082 + gguf_writer.add_layer_norm_eps(hparams.get("layer_norm_epsilon", 1e-5)) + + # Add the file type + gguf_writer.add_file_type(ftype) + + return gguf_writer + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert a Phi-1 model to a GGML compatible file" + ) + parser.add_argument( + "--vocab-only", action="store_true", help="extract only the vocab" + ) + parser.add_argument( + "--outfile", type=Path, help="path to write to; default: based on input" + ) + parser.add_argument( + "model", + type=Path, + help="directory containing model file, or model file itself (*.bin)", + ) + parser.add_argument( + "--ftype", + type=str, + choices=["f32", "f16"], + default="f16", # NOTE: Phi-1 is dtype float16. + help="output format - use 'float32' for 32-bit tensors, 'float16' for 16-bit tensors", + ) + return parser.parse_args() + + +def main(): + try: + args = parse_args() + + ftype = args.ftype + directory = args.model # Renamed for clarity + + if not directory.is_dir(): + raise NotADirectoryError(f"{directory} is not a directory.") + + required_files = ["pytorch_model.bin", "config.json", "tokenizer.json"] + check_required_files(directory, required_files) + + # Reference the actual model file + model = directory / "pytorch_model.bin" + if not model.exists(): + raise FileNotFoundError(f"Model file {model} does not exist.") + + hparams = load_hyper_params(directory, "MixFormerSequentialForCausalLM") + architecture = hparams["architectures"][0] + + if args.outfile is not None: + fname_out = args.outfile + else: + fname_out = directory / f"ggml-model-{ftype}.gguf" + + if not fname_out.parent.exists(): + logging.warning(f"Output directory {fname_out.parent} does not exist.") + + gguf_writer = initialize_writer(fname_out, architecture, ftype, hparams) + + # Proceed with the model processing using the 'model' path + # ... [rest of your existing code] ... + + except Exception as e: + logging.error(e) + sys.exit(1) + + +if __name__ == "__main__": + main() + + +# # TOKENIZATION + +# print("gguf: get tokenizer metadata") + +# tokens: list[bytearray] = [] +# scores: list[float] = [] +# toktypes: list[int] = [] + +# # gpt2 tokenizer +# gguf_writer.add_tokenizer_model("gpt2") + +# print("gguf: get gpt2 tokenizer vocab") + +# # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py +# tokenizer = AutoTokenizer.from_pretrained(dir_model) + +# # The number of tokens in tokenizer.json can differ from the expected vocab size. 
+# # This causes downstream issues with mismatched tensor sizes when running the inference +# vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) +# assert max(tokenizer.vocab.values()) < vocab_size + +# added_vocab = tokenizer.get_added_vocab() +# reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} + +# for i in range(vocab_size): +# if i not in reverse_vocab: +# tokens.append(f"[PAD{i}]") +# toktypes.append(gguf.TokenType.USER_DEFINED) +# elif reverse_vocab[i] in added_vocab: +# tokens.append(reverse_vocab[i]) +# if tokenizer.added_tokens_decoder[i].special: +# toktypes.append(gguf.TokenType.CONTROL) +# else: +# toktypes.append(gguf.TokenType.USER_DEFINED) +# else: +# tokens.append(reverse_vocab[i]) +# toktypes.append(gguf.TokenType.NORMAL) + +# gguf_writer.add_token_list(tokens) +# gguf_writer.add_token_types(toktypes) +# special_vocab = gguf.SpecialVocab(dir_model, load_merges=True, n_vocab=len(tokens)) +# special_vocab.add_to_gguf(gguf_writer) + +# # TENSORS + +# tensor_map = gguf.get_tensor_name_map(ARCH, block_count) + +# # params for qkv transform +# n_head = hparams["n_head"] +# n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1 + +# head_dim = hparams["n_embd"] // n_head + +# # tensor info +# print("gguf: get tensor metadata") + +# if num_parts == 0: +# part_names = iter(("pytorch_model.bin",)) +# else: +# part_names = ( +# f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) +# ) + +# for part_name in part_names: +# if args.vocab_only: +# break +# print("gguf: loading model part '" + part_name + "'") +# model_part = torch.load(dir_model / part_name, map_location="cpu") + +# for name in model_part.keys(): +# data = model_part[name] + +# old_dtype = data.dtype + +# # convert any unsupported data types to float32 +# if data.dtype != torch.float16 and data.dtype != torch.float32: +# data = data.to(torch.float32) + +# data = data.squeeze().numpy() + +# # map tensor names +# new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) +# if new_name is None: +# print("Can not map tensor '" + name + "'") +# sys.exit() + +# n_dims = len(data.shape) +# data_dtype = data.dtype + +# # if f32 desired, convert any float16 to float32 +# if ftype == 0 and data_dtype == np.float16: +# data = data.astype(np.float32) + +# # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 +# if ftype == 1 and data_dtype == np.float16 and n_dims == 1: +# data = data.astype(np.float32) + +# # if f16 desired, convert any float32 2-dim weight tensors to float16 +# if ( +# ftype == 1 +# and data_dtype == np.float32 +# and name.endswith(".weight") +# and n_dims == 2 +# ): +# data = data.astype(np.float16) + +# print( +# name, +# "=>", +# new_name +# + ", shape = " +# + str(data.shape) +# + ", " +# + str(old_dtype) +# + " --> " +# + str(data.dtype), +# ) + +# gguf_writer.add_tensor(new_name, data) + + +# print("gguf: write header") +# gguf_writer.write_header_to_file() +# print("gguf: write metadata") +# gguf_writer.write_kv_data_to_file() +# if not args.vocab_only: +# print("gguf: write tensors") +# gguf_writer.write_tensors_to_file() + +# gguf_writer.close() + +# print(f"gguf: model successfully exported to '{fname_out}'") +# print("") diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index 727b4e55495a7..e63efa4a7df8d 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -93,6 +93,7 @@ class MODEL_ARCH(IntEnum): REFACT : int = auto() BERT : int = auto() BLOOM : int = auto() + PHI_1 : int = auto() class MODEL_TENSOR(IntEnum): @@ -132,6 +133,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.REFACT: "refact", MODEL_ARCH.BERT: "bert", MODEL_ARCH.BLOOM: "bloom", + MODEL_ARCH.PHI_1: "phi-1", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -302,6 +304,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PHI_1: [ + # TODO + ], MODEL_ARCH.GPT2: [ # TODO ], From 748f37674636bd384c03a1a3395b7305fcc03bdf Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Tue, 12 Dec 2023 23:34:47 -0500 Subject: [PATCH 2/9] fix: Apply phi to merged updates Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- gguf-py/gguf/constants.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 685c88f1a3397..fde3195b79381 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -93,6 +93,7 @@ class MODEL_ARCH(IntEnum): BLOOM = auto() STABLELM = auto() QWEN = auto() + PHI = auto() class MODEL_TENSOR(IntEnum): @@ -336,6 +337,9 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GPT2: [ # TODO ], + MODEL_ARCH.PHI: [ + # TODO + ], # TODO } From e53d44c0bbd1bc993ee5929653ad765f2347ce0f Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Wed, 20 Dec 2023 15:15:35 -0500 Subject: [PATCH 3/9] Consolidate PHI and PHI2 architectures in gguf constants Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- gguf-py/gguf/constants.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 46cc6a1439d05..d99616d30f505 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -95,7 +95,6 @@ class MODEL_ARCH(IntEnum): BLOOM = auto() STABLELM = auto() QWEN = auto() - PHI2 = auto() PHI = auto() @@ -142,7 +141,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.QWEN: "qwen", - MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PHI: "phi", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -353,7 +352,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GPT2: [ # TODO ], - MODEL_ARCH.PHI2: [ + MODEL_ARCH.PHI: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, @@ -364,9 +363,6 
@@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ] - MODEL_ARCH.PHI: [ - # TODO - ], # TODO } From e96f40bf99eab4607aaf76f7f70406cd1b4d2277 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Wed, 20 Dec 2023 15:18:52 -0500 Subject: [PATCH 4/9] Update tensor mappings for Phi models (Phi-1, Phi-1.5, Phi-2) Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- gguf-py/gguf/tensor_mapping.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 6fcbdbc1c0d4c..cb86d3ee0e71f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -17,7 +17,7 @@ class TensorNameMap: "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert "language_model.embedding.word_embeddings", # persimmon - "transformer.embd.wte", # phi2 + "transformer.embd.wte", # phi1 phi1_5 phi2 ), # Token type embeddings @@ -42,7 +42,7 @@ class TensorNameMap: "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen "output", # llama-pth bloom "word_embeddings_for_head", # persimmon - "lm_head.linear", # phi2 + "lm_head.linear", # phi1 phi1_5 phi2 ), # Output norm @@ -55,7 +55,7 @@ class TensorNameMap: "transformer.norm_f", # mpt "ln_f", # refact bloom qwen "language_model.encoder.final_layernorm", # persimmon - "lm_head.ln", # phi2 + "lm_head.ln", # phi1 phi1_5 phi2 ), # Rope frequencies @@ -78,7 +78,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi - "transformer.h.{bid}.ln", # phi2 + "transformer.h.{bid}.ln", # phi1 phi1_5 phi2 ), # Attention norm 2 @@ -94,7 +94,7 @@ class TensorNameMap: "transformer.h.{bid}.self_attention.query_key_value", # falcon "h.{bid}.self_attention.query_key_value", # bloom "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon - "transformer.h.{bid}.mixer.Wqkv", # phi2 + "transformer.h.{bid}.mixer.Wqkv", # phi1 phi1_5 phi2 ), # Attention query @@ -133,7 +133,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.output.dense", # bert "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon - "transformer.h.{bid}.mixer.out_proj", # phi2 + "transformer.h.{bid}.mixer.out_proj", # phi1 phi1_5 phi2 ), # Rotary embeddings @@ -173,7 +173,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.fc_in", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon "transformer.h.{bid}.mlp.w1", # qwen - "transformer.h.{bid}.mlp.fc1", # phi2 + "transformer.h.{bid}.mlp.fc1", # phi1 phi1_5 phi2 ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -205,7 +205,7 @@ class TensorNameMap: "encoder.layer.{bid}.output.dense", # bert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon - "transformer.h.{bid}.mlp.fc2", # phi2 + "transformer.h.{bid}.mlp.fc2", # phi1 phi1_5 phi2 ), MODEL_TENSOR.FFN_DOWN_EXP: ( From ea6ae8d04c43ad0133a514cee90a71dddc923d87 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Wed, 20 Dec 2023 15:26:25 -0500 Subject: [PATCH 5/9] Consolidate Phi model conversion handling in convert-hf-to-gguf.py Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- convert-hf-to-gguf.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff 
--git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index e71a96c483313..0f8e0bb0051a4 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -10,7 +10,7 @@ import sys from enum import IntEnum from pathlib import Path -from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional +from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Optional, cast import numpy as np import torch @@ -22,7 +22,6 @@ sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) import gguf - ###### MODEL DEFINITIONS ###### class SentencePieceTokenTypes(IntEnum): @@ -183,7 +182,7 @@ def from_model_architecture(model_architecture): if model_architecture == "MixtralForCausalLM": return MixtralModel if model_architecture == "PhiForCausalLM": - return Phi2Model + return PhiModel return Model def _is_model_safetensors(self) -> bool: @@ -224,7 +223,7 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH: if arch == "MixtralForCausalLM": return gguf.MODEL_ARCH.LLAMA if arch == "PhiForCausalLM": - return gguf.MODEL_ARCH.PHI2 + return gguf.MODEL_ARCH.PHI raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -985,11 +984,11 @@ def write_tensors(self): self.gguf_writer.add_tensor(new_name, data) -class Phi2Model(Model): +class PhiModel(Model): def set_gguf_parameters(self): block_count = self.hparams["n_layer"] - self.gguf_writer.add_name("Phi2") + self.gguf_writer.add_name("Phi") self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) From e290792ae4e98937f5b4f14db27cbfef4d4d8fcb Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Wed, 20 Dec 2023 15:34:56 -0500 Subject: [PATCH 6/9] Consolidate Handling of Phi Models in llama.cpp - Replaced LLM_ARCH_PHI2 with LLM_ARCH_PHI to unify the handling of different Phi model variants (Phi-1, Phi-1.5, Phi-2). - Updated architecture names map to reflect the consolidated architecture name from "phi2" to "phi". - Adjusted the tensor names mapping to use the new architecture name "phi" for consistent tensor loading and processing. - Modified hyperparameter loading to include a case for 24 layers under LLM_ARCH_PHI, classifying it as MODEL_1B. This change accommodates different layer counts for various Phi model variants. - Updated tensor loading sections to use the new architecture enum, ensuring proper tensor creation based on the model architecture. - Renamed build_phi2() to build_phi() in the graph building section, aligning with the new architecture name and ensuring correct computational graph construction for Phi models. - Adjusted graph construction calls to use the renamed build_phi() function, ensuring seamless integration and functionality for different Phi model variants. These changes aim to streamline the handling of various Phi models within `llama.cpp`, enhancing the application's capability to work effectively with these models while maintaining code clarity and consistency. 
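For orientation, a minimal sketch of how the consolidated name lines up on the Python side (illustrative only, not part of this patch; it assumes the gguf-py constants from this series are importable as `gguf`). The "phi" string the converter stores in the GGUF metadata is the same string the LLM_ARCH_NAMES entry { LLM_ARCH_PHI, "phi" } below resolves before build_phi() is selected:

import gguf

arch = gguf.MODEL_ARCH.PHI                   # one enum value for Phi-1, Phi-1.5 and Phi-2
assert gguf.MODEL_ARCH_NAMES[arch] == "phi"  # architecture string written into the GGUF metadata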
--- llama.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/llama.cpp b/llama.cpp index edd2910b3ad29..bec9136ad4d0e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -195,7 +195,7 @@ enum llm_arch { LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, - LLM_ARCH_PHI2, + LLM_ARCH_PHI, LLM_ARCH_UNKNOWN, }; @@ -213,7 +213,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, - { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI, "phi" }, }; enum llm_kv { @@ -553,7 +553,7 @@ static std::map> LLM_TENSOR_NAMES = }, }, { - LLM_ARCH_PHI2, + LLM_ARCH_PHI, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, @@ -2651,11 +2651,12 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; - case LLM_ARCH_PHI2: + case LLM_ARCH_PHI: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_3B; break; default: model.type = e_model::MODEL_UNKNOWN; } @@ -3655,7 +3656,7 @@ static void llm_load_tensors( } } } break; - case LLM_ARCH_PHI2: + case LLM_ARCH_PHI: { model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); @@ -4117,7 +4118,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); - if (model.arch == LLM_ARCH_PHI2) { + if (model.arch == LLM_ARCH_PHI) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -5523,7 +5524,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_phi2() { + struct ggml_cgraph * build_phi() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); struct ggml_tensor * cur; @@ -5924,8 +5925,8 @@ static struct ggml_cgraph * llama_build_graph( if (!ggml_allocr_is_measure(lctx.alloc)) { const int64_t n_embd_head = model.hparams.n_embd_head(); - if (model.arch == LLM_ARCH_PHI2) { - // with phi2, we scale the Q to avoid precision issues + if (model.arch == LLM_ARCH_PHI) { + // with phi, we scale the Q to avoid precision issues // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 ggml_set_f32(cur, 1.0f); } else { @@ -6157,9 +6158,9 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_qwen(); } break; - case LLM_ARCH_PHI2: + case LLM_ARCH_PHI: { - result = llm.build_phi2(); + result = llm.build_phi(); } break; default: GGML_ASSERT(false); From 6becb1f943b5e55dd381607883c127923e95f039 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Wed, 20 Dec 2023 15:44:16 -0500 Subject: [PATCH 7/9] Remove deprecated conversion script Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- convert-phi-1-to-gguf.py | 300 --------------------------------------- 1 file changed, 300 deletions(-) delete mode 100755 convert-phi-1-to-gguf.py diff --git a/convert-phi-1-to-gguf.py b/convert-phi-1-to-gguf.py deleted file mode 100755 index e3ca34ff3fb67..0000000000000 --- a/convert-phi-1-to-gguf.py +++ /dev/null @@ -1,300 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import argparse -import json -import logging -import 
os -import struct -import sys -from json import JSONDecodeError -from pathlib import Path -from typing import Any, Dict, List - -import numpy as np -import torch -from transformers import AutoTokenizer - -if "NO_LOCAL_GGUF" not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / "gguf-py" / "gguf")) -import gguf - -# Configure logging -logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def check_required_files(directory: Path, required_files: List[str]) -> None: - missing_files = [ - file_name - for file_name in required_files - if not (directory / file_name).exists() - ] - if missing_files: - raise FileNotFoundError(f"Missing required files: {', '.join(missing_files)}") - - -def get_json_map(file_path: Path) -> dict[str, Any]: - with open(file_path, "r") as source_file: - try: - return json.load(source_file) - except JSONDecodeError: - raise ValueError(f"Failed to decode {file_path}") - - -def load_hyper_params(directory: Path, architecture: str) -> dict: - config_path = directory / "config.json" - hparams = get_json_map(config_path) - - # Ensure the expected architecture is present - expected_architecture = architecture - if hparams["architectures"][0] != expected_architecture: - raise ValueError( - f"Model architecture not supported: {hparams['architectures'][0]}" - ) - - return hparams - - -def initialize_writer( - fname_out: str, architecture: str, ftype: str, hparams: Dict[str, Any] -) -> gguf.GGUFWriter: - """ - Initializes the GGUF writer with the model metadata. - - :param fname_out: The filename for the output model. - :param architecture: The model architecture enum name. - :param ftype: The data type for the model file (e.g., 'F32', 'F16'). - :param hparams: The hyperparameters loaded from the model's config file. - :return: An initialized GGUF writer object. 
- """ - # Validate the architecture name - if not hasattr(gguf.MODEL_ARCH, architecture): - raise ValueError(f"Unsupported architecture: {architecture}") - ARCH = getattr(gguf.MODEL_ARCH, architecture) - - # Validate the file type - if ftype not in ['F32', 'F16']: - raise ValueError(f"Unsupported file type: {ftype}") - - # Initialize the GGUF writer - gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) - - # Set the writer with the hyperparameters from MixFormerSequentialConfig - gguf_writer.add_name(gguf.MODEL_ARCH_NAMES[ARCH]) - gguf_writer.add_context_length(hparams.get("n_positions", 2048)) - gguf_writer.add_embedding_length(hparams.get("n_embd", 1024)) - n_inner = hparams.get("n_inner", 4 * hparams.get("n_embd", 1024)) - gguf_writer.add_feed_forward_length(n_inner) - gguf_writer.add_block_count(hparams.get("n_layer", 20)) - gguf_writer.add_head_count(hparams.get("n_head", 16)) - n_head_kv = hparams.get("n_head_kv", hparams.get("n_head", 16)) - gguf_writer.add_head_count_kv(n_head_kv) # NOTE: arxiv:2203.11082 - gguf_writer.add_layer_norm_eps(hparams.get("layer_norm_epsilon", 1e-5)) - - # Add the file type - gguf_writer.add_file_type(ftype) - - return gguf_writer - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Convert a Phi-1 model to a GGML compatible file" - ) - parser.add_argument( - "--vocab-only", action="store_true", help="extract only the vocab" - ) - parser.add_argument( - "--outfile", type=Path, help="path to write to; default: based on input" - ) - parser.add_argument( - "model", - type=Path, - help="directory containing model file, or model file itself (*.bin)", - ) - parser.add_argument( - "--ftype", - type=str, - choices=["f32", "f16"], - default="f16", # NOTE: Phi-1 is dtype float16. - help="output format - use 'float32' for 32-bit tensors, 'float16' for 16-bit tensors", - ) - return parser.parse_args() - - -def main(): - try: - args = parse_args() - - ftype = args.ftype - directory = args.model # Renamed for clarity - - if not directory.is_dir(): - raise NotADirectoryError(f"{directory} is not a directory.") - - required_files = ["pytorch_model.bin", "config.json", "tokenizer.json"] - check_required_files(directory, required_files) - - # Reference the actual model file - model = directory / "pytorch_model.bin" - if not model.exists(): - raise FileNotFoundError(f"Model file {model} does not exist.") - - hparams = load_hyper_params(directory, "MixFormerSequentialForCausalLM") - architecture = hparams["architectures"][0] - - if args.outfile is not None: - fname_out = args.outfile - else: - fname_out = directory / f"ggml-model-{ftype}.gguf" - - if not fname_out.parent.exists(): - logging.warning(f"Output directory {fname_out.parent} does not exist.") - - gguf_writer = initialize_writer(fname_out, architecture, ftype, hparams) - - # Proceed with the model processing using the 'model' path - # ... [rest of your existing code] ... - - except Exception as e: - logging.error(e) - sys.exit(1) - - -if __name__ == "__main__": - main() - - -# # TOKENIZATION - -# print("gguf: get tokenizer metadata") - -# tokens: list[bytearray] = [] -# scores: list[float] = [] -# toktypes: list[int] = [] - -# # gpt2 tokenizer -# gguf_writer.add_tokenizer_model("gpt2") - -# print("gguf: get gpt2 tokenizer vocab") - -# # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py -# tokenizer = AutoTokenizer.from_pretrained(dir_model) - -# # The number of tokens in tokenizer.json can differ from the expected vocab size. 
-# # This causes downstream issues with mismatched tensor sizes when running the inference -# vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) -# assert max(tokenizer.vocab.values()) < vocab_size - -# added_vocab = tokenizer.get_added_vocab() -# reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} - -# for i in range(vocab_size): -# if i not in reverse_vocab: -# tokens.append(f"[PAD{i}]") -# toktypes.append(gguf.TokenType.USER_DEFINED) -# elif reverse_vocab[i] in added_vocab: -# tokens.append(reverse_vocab[i]) -# if tokenizer.added_tokens_decoder[i].special: -# toktypes.append(gguf.TokenType.CONTROL) -# else: -# toktypes.append(gguf.TokenType.USER_DEFINED) -# else: -# tokens.append(reverse_vocab[i]) -# toktypes.append(gguf.TokenType.NORMAL) - -# gguf_writer.add_token_list(tokens) -# gguf_writer.add_token_types(toktypes) -# special_vocab = gguf.SpecialVocab(dir_model, load_merges=True, n_vocab=len(tokens)) -# special_vocab.add_to_gguf(gguf_writer) - -# # TENSORS - -# tensor_map = gguf.get_tensor_name_map(ARCH, block_count) - -# # params for qkv transform -# n_head = hparams["n_head"] -# n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1 - -# head_dim = hparams["n_embd"] // n_head - -# # tensor info -# print("gguf: get tensor metadata") - -# if num_parts == 0: -# part_names = iter(("pytorch_model.bin",)) -# else: -# part_names = ( -# f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) -# ) - -# for part_name in part_names: -# if args.vocab_only: -# break -# print("gguf: loading model part '" + part_name + "'") -# model_part = torch.load(dir_model / part_name, map_location="cpu") - -# for name in model_part.keys(): -# data = model_part[name] - -# old_dtype = data.dtype - -# # convert any unsupported data types to float32 -# if data.dtype != torch.float16 and data.dtype != torch.float32: -# data = data.to(torch.float32) - -# data = data.squeeze().numpy() - -# # map tensor names -# new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) -# if new_name is None: -# print("Can not map tensor '" + name + "'") -# sys.exit() - -# n_dims = len(data.shape) -# data_dtype = data.dtype - -# # if f32 desired, convert any float16 to float32 -# if ftype == 0 and data_dtype == np.float16: -# data = data.astype(np.float32) - -# # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 -# if ftype == 1 and data_dtype == np.float16 and n_dims == 1: -# data = data.astype(np.float32) - -# # if f16 desired, convert any float32 2-dim weight tensors to float16 -# if ( -# ftype == 1 -# and data_dtype == np.float32 -# and name.endswith(".weight") -# and n_dims == 2 -# ): -# data = data.astype(np.float16) - -# print( -# name, -# "=>", -# new_name -# + ", shape = " -# + str(data.shape) -# + ", " -# + str(old_dtype) -# + " --> " -# + str(data.dtype), -# ) - -# gguf_writer.add_tensor(new_name, data) - - -# print("gguf: write header") -# gguf_writer.write_header_to_file() -# print("gguf: write metadata") -# gguf_writer.write_kv_data_to_file() -# if not args.vocab_only: -# print("gguf: write tensors") -# gguf_writer.write_tensors_to_file() - -# gguf_writer.close() - -# print(f"gguf: model successfully exported to '{fname_out}'") -# print("") From af9cd934137340d9a527b0e90da15613a75e04da Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Wed, 20 Dec 2023 21:44:41 -0500 Subject: [PATCH 8/9] Ignore local content Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 76b3d2861826e..65542f0a804b0 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ build*/ out/ tmp/ +local/ models/* models-mnt From 2f2b3e443a3a7de0b0d3629809b1eb7136bcd3b7 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Tue, 9 Jan 2024 16:01:26 -0500 Subject: [PATCH 9/9] chore: Apply flake8 formatting rules Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> --- convert-hf-to-gguf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 9892a17e7f2cc..32cad0987c80d 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -24,6 +24,7 @@ ###### MODEL DEFINITIONS ###### + class SentencePieceTokenTypes(IntEnum): NORMAL = 1 UNKNOWN = 2
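Taken together, the series routes Phi-1, Phi-1.5 and Phi-2 through a single "phi" architecture: gguf-py gains the MODEL_ARCH.PHI constant and tensor mappings, convert-hf-to-gguf.py maps PhiForCausalLM onto them via PhiModel, and llama.cpp loads and builds the graph under LLM_ARCH_PHI. As an illustration only (a minimal sketch, not code from any of the patches above), the metadata side of that path can be pictured with the gguf-py calls that appear in this series; `hparams` stands for the parsed config.json of a Phi checkpoint, and the exact call set in convert-hf-to-gguf.py may differ:

import gguf

def write_phi_metadata(fname_out: str, hparams: dict) -> gguf.GGUFWriter:
    # One writer for every Phi variant; the architecture string is "phi".
    writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.PHI])

    writer.add_name("Phi")
    writer.add_context_length(hparams["n_positions"])
    writer.add_embedding_length(hparams["n_embd"])
    writer.add_feed_forward_length(4 * hparams["n_embd"])  # as in PhiModel.set_gguf_parameters
    writer.add_block_count(hparams["n_layer"])             # 24 layers -> MODEL_1B, 32 -> MODEL_3B in llama.cpp
    writer.add_head_count(hparams["n_head"])
    writer.add_head_count_kv(hparams.get("n_head_kv", hparams["n_head"]))
    writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
    return writer

The design keeps a single architecture id on disk and lets llama.cpp distinguish the model size from n_layer at load time, which is why no per-variant enum is needed.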