From a0a806e2066ab75204f3e7c1476563075b0ee1be Mon Sep 17 00:00:00 2001 From: Zhenzhong1 <109137058+Zhenzhong1@users.noreply.github.com> Date: Wed, 10 Jan 2024 14:16:17 +0800 Subject: [PATCH] [LLM Runtime] Support the GGUF format (#40) --- neural_speed/application/main_run.cpp | 2 +- neural_speed/convert/convert_chatglm.py | 202 ++++- neural_speed/convert/convert_llama.py | 169 +++- neural_speed/models/model_utils/gguf.h | 531 ++++++++++++ neural_speed/models/model_utils/model_files.h | 777 +++++++++++++++++- requirements.txt | 3 +- 6 files changed, 1646 insertions(+), 38 deletions(-) create mode 100644 neural_speed/models/model_utils/gguf.h diff --git a/neural_speed/application/main_run.cpp b/neural_speed/application/main_run.cpp index 772382fa2..f2c3282b2 100644 --- a/neural_speed/application/main_run.cpp +++ b/neural_speed/application/main_run.cpp @@ -772,4 +772,4 @@ int main(int argc, char** argv) { // NOLINT model_free(ctx); return 0; -} // NOLINT +} // NOLINT \ No newline at end of file diff --git a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py index d6cda001b..38f986809 100644 --- a/neural_speed/convert/convert_chatglm.py +++ b/neural_speed/convert/convert_chatglm.py @@ -17,10 +17,11 @@ import numpy as np from pathlib import Path import argparse -from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, - Union) +from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, + TypeVar, Union) from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer from sentencepiece import SentencePieceProcessor # type: ignore +import gguf # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py @@ -34,7 +35,10 @@ def bytes_to_unicode(): To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. 
""" - bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + bs = list(range(ord("!"), + ord("~") + 1)) + list(range(ord("¡"), + ord("¬") + 1)) + list(range(ord("®"), + ord("ÿ") + 1)) cs = bs[:] n = 0 for b in range(2**8): @@ -117,8 +121,7 @@ def load_vocab_for_glm1(path: Path) -> SentencePieceVocab: else: raise FileNotFoundError( f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, \ - pass the directory as --vocab-dir" - ) + pass the directory as --vocab-dir") added_tokens_path = path.parent / "added_tokens.json" print(f"Loading vocab file {path}") return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) @@ -139,13 +142,188 @@ def load_vocab_for_glm2(path: Path) -> SentencePieceVocab: else: raise FileNotFoundError( f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, \ - pass the directory as --vocab-dir" - ) + pass the directory as --vocab-dir") added_tokens_path = path.parent / "added_tokens.json" print(f"Loading vocab file {path}") return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) +def chatglm2_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams): + print("ChatGLM-2.gguf converting: ") + list_vars = model.state_dict() + for name in list_vars.keys(): + print(name, list_vars[name].shape, list_vars[name].dtype) + + print(hparams) + fout = open(fname_out, "wb") + + gguf_file = fname_out + '.gguf' + gguf_writer = gguf.GGUFWriter(gguf_file, "chatglm2") + + gguf_writer.add_uint32('magic', 0x67676d66) + gguf_writer.add_uint32('version', 1) + gguf_writer.add_uint32('n_vocab', hparams["padded_vocab_size"]) + gguf_writer.add_uint32('n_embd', hparams["hidden_size"]) + gguf_writer.add_uint32('n_mult', 0) + gguf_writer.add_uint32('n_head', hparams["num_attention_heads"]) + gguf_writer.add_uint32('n_head_kv', 0) + + gguf_writer.add_uint32('n_layer', hparams["num_layers"]) + gguf_writer.add_uint32('n_rot', 0) + gguf_writer.add_uint32('ftype', ftype) + gguf_writer.add_uint32('max_seq_len', hparams["seq_length"]) + gguf_writer.add_uint32('alibi_bias_max', 0) + gguf_writer.add_uint32('clip_qkv', 0) + gguf_writer.add_uint32('par_res', 0) + + gguf_writer.add_uint32('word_embed_proj_dim', 0) + gguf_writer.add_uint32('do_layer_norm_before', 0) + + gguf_writer.add_uint32('multi_query_group_num', hparams["multi_query_group_num"]) + gguf_writer.add_uint32('ffn_hidden_size', hparams["ffn_hidden_size"]) + gguf_writer.add_uint32('inner_hidden_size', 0) + + gguf_writer.add_int32('bos_token_id', tokenizer.bos_token_id if tokenizer.bos_token_id is not None else -1) + gguf_writer.add_int32('eos_token_id', tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1) + gguf_writer.add_int32('pad_token_id', tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1) + gguf_writer.add_int32('sep_token_id', tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1) + + def write_vocab_gguf(dir_model): + print("gguf: get tokenizer metadata") + + tokens: List[bytes] = [] + scores: List[float] = [] + toktypes: List[int] = [] + + if Path(dir_model + "/tokenizer.model").is_file(): + # vocab type sentencepiece + print("gguf: get sentencepiece tokenizer vocab, scores and token types") + + vocab = load_vocab_for_glm2(Path(dir_model)) + + # NOTE: `all_tokens` returns the base vocabulary and added tokens + for text, score in vocab.all_tokens(): + tokens.append(text) + 
scores.append(score) + + if Path(dir_model + "/added_tokens.json").is_file(): + with open(dir_model + "/added_tokens.json", "r", encoding="utf-8") as f: + addtokens_json = json.load(f) + + print("gguf: get added tokens") + + for key in addtokens_json: + tokens.append(key.encode("utf-8")) + scores.append(-1000.0) + toktypes.append(4) # user-defined token type + + gguf_writer.add_tokenizer_model("chatglm2") + gguf_writer.add_token_list(tokens) + gguf_writer.add_token_scores(scores) + + print("gguf: get special token ids") + + if Path(dir_model + "/tokenizer.json").is_file(): + # Look for special tokens in tokenizer.json if it exists + + with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f: + tokenizer = json.load(f) + + if "added_tokens" in tokenizer and Path(dir_model + "/tokenizer_config.json").is_file(): + + with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_config = json.load(f) + + if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["bos_token"]["content"]: + gguf_writer.add_bos_token_id(key["id"]) + + if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["eos_token"]["content"]: + gguf_writer.add_eos_token_id(key["id"]) + + if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["unk_token"]["content"]: + gguf_writer.add_unk_token_id(key["id"]) + + if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["sep_token"]["content"]: + gguf_writer.add_sep_token_id(key["id"]) + + if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["pad_token"]["content"]: + gguf_writer.add_pad_token_id(key["id"]) + else: + # If no tokenizer.json: Look for special tokens in config.json + + if "bos_token_id" in hparams and hparams["bos_token_id"] != None: + gguf_writer.add_bos_token_id(hparams["bos_token_id"]) + + if "eos_token_id" in hparams and hparams["eos_token_id"] != None: + gguf_writer.add_eos_token_id(hparams["eos_token_id"]) + + if "unk_token_id" in hparams and hparams["unk_token_id"] != None: + gguf_writer.add_unk_token_id(hparams["unk_token_id"]) + + if "sep_token_id" in hparams and hparams["sep_token_id"] != None: + gguf_writer.add_sep_token_id(hparams["sep_token_id"]) + + if "pad_token_id" in hparams and hparams["pad_token_id"] != None: + gguf_writer.add_pad_token_id(hparams["pad_token_id"]) + + write_vocab_gguf(dir_model) + + # tensor info + print("gguf: get tensor metadata") + for name in list_vars.keys(): + data = list_vars[name].squeeze().numpy() + + print("Processing variable: " + name + " with shape: ", data.shape) + if 'inv_freq' in name: + continue + + n_dims = len(data.shape) + + # ftype == 0 -> float32, ftype == 1 -> float16 + ftype_cur = 0 + if ftype != 0: + if name[-7:] == ".weight" and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + # 
print(f"[{i+1:{padi}d}/{len(model)}] + # Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4}") + + gguf_writer.add_tensor(name, data) + + print("gguf: write header") + gguf_writer.write_header_to_file() + print("gguf: write metadata") + gguf_writer.write_kv_data_to_file() + print("gguf: write tensors") + gguf_writer.write_tensors_to_file() + + gguf_writer.close() + + print("Done. Output file: " + fname_out) + print("") + + def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): print("ChatGLM-2 converting: ") list_vars = model.state_dict() @@ -341,6 +519,11 @@ def main(args_in: Optional[List[str]] = None) -> None: parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") parser.add_argument("model", type=Path, help="directory containing model file") + parser.add_argument("--format", + type=str, + default="NE", + choices=["NE", "GGUF"], + help="convert to the GGUF or NE format") args = parser.parse_args(args_in) dir_model = args.model.as_posix() @@ -360,7 +543,10 @@ def main(args_in: Optional[List[str]] = None) -> None: model = AutoModel.from_pretrained(dir_model, low_cpu_mem_usage=True, trust_remote_code=True) if hasattr(model.config, "multi_query_attention"): - chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams) + if args.format == "GGUF": + chatglm2_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams) + else: + chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams) else: chatglm1_convert(model, tokenizer, dir_model, fname_out, ftype, hparams) diff --git a/neural_speed/convert/convert_llama.py b/neural_speed/convert/convert_llama.py index 86f9a8aac..8b448c95d 100644 --- a/neural_speed/convert/convert_llama.py +++ b/neural_speed/convert/convert_llama.py @@ -31,10 +31,11 @@ from abc import ABCMeta, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, - Union) +from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, + TypeVar, Union) import numpy as np from sentencepiece import SentencePieceProcessor # type: ignore +import gguf if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -153,8 +154,9 @@ class Params: @staticmethod def guessed(model: 'LazyModel') -> 'Params': - n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model[ - "tok_embeddings.weight"].shape + n_vocab, n_embd = model[ + "model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model[ + "tok_embeddings.weight"].shape return Params( n_vocab=n_vocab, @@ -192,7 +194,7 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: Path) -> 'Params': ) # LLaMA v2 70B params.json - # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, + # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, # "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1} @staticmethod def loadOriginalParamsJson(model: 'LazyModel', config_path: Path) -> 'Params': @@ -420,7 +422,8 @@ def __init__(self, ndarray: NDArray, shape: List[int], data_type: DataType) -> N assert isinstance(data_type, QuantizedDataType) # redundant, but mypy complains without this 
assert columns % data_type.groupsize == 0 words_in_block = 6 if data_type == DT_Q4_1 else 5 - self.ndarray = ndarray.view(dtype=np.uint32).reshape((rows, columns // data_type.groupsize, words_in_block)) + self.ndarray = ndarray.view(dtype=np.uint32).reshape( + (rows, columns // data_type.groupsize, words_in_block)) self.shape = shape[:] self.data_type = data_type @@ -612,8 +615,7 @@ def validate_conversion_to(self, data_type: DataType) -> None: sys.stderr.write( "Error: Input uses the newer GPTQ-for-LLaMa format (using g_idx), which is not yet natively\ supported by NE. For now you can still convert this model by passing `--outtype f16` to \ - dequantize, but that will result in a much larger output file for no quality benefit.\n" - ) + dequantize, but that will result in a much larger output file for no quality benefit.\n") sys.exit(1) assert not data_type.have_g_idx and self.data_type.have_addends and data_type.have_addends @@ -803,6 +805,7 @@ def load(offset: int, elm_count: int) -> NDArray: description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' return LazyStorage(load=load, kind=pid[1], description=description) + # @staticmethod def lazy_rebuild_tensor_v2( @@ -1064,7 +1067,7 @@ def write_file_header(self, params: Params, file_type: NEFileType) -> None: self.fout.write(struct.pack("f", params.rms_norm_eps)) self.fout.write(struct.pack("f", params.rope_theta)) - # TODO, bos_token_id = 0 in https://huggingface.co/decapoda-research/llama-7b-hf/blob/main/config.json + # TODO, bos_token_id = 0 in https://huggingface.co/decapoda-research/llama-7b-hf/blob/main/config.json # but bos_token_id = 1 in llama.cpp self.fout.write(struct.pack("i", 1)) self.fout.write(struct.pack("i", 2)) @@ -1087,7 +1090,12 @@ def write_vocab(self, vocab: Vocab) -> None: @staticmethod def write_vocab_only(fname_out: Path, vocab: Vocab) -> None: of = OutputFile(fname_out) - params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0, file_type=NEFileType.AllF32) + params = Params(n_vocab=vocab.vocab_size, + n_embd=0, + n_mult=0, + n_head=1, + n_layer=0, + file_type=NEFileType.AllF32) of = OutputFile(fname_out) of.write_file_header(params) of.write_vocab(vocab) @@ -1107,17 +1115,130 @@ def do_item(item: Tuple[str, LazyTensor]) -> NDArray: ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8) for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) padi = len(str(len(model))) - print( - f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | \ - type {lazy_tensor.data_type}" + print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | \ + type {lazy_tensor.data_type}") + of.write_tensor_header( + name, + lazy_tensor.shape, + lazy_tensor.data_type, ) - of.write_tensor_header(name, lazy_tensor.shape, lazy_tensor.data_type) ndarray.tofile(of.fout) of.fout.close() +class OutputFile_GGUF: + def __init__(self, fname_out: Path) -> None: + self.fout = open(fname_out, "wb") + self.gguf_file = str(fname_out) + '.gguf' + self.gguf_writer = gguf.GGUFWriter(self.gguf_file, "llama_ITREX") + + def write_file_header(self, params: Params, file_type: NEFileType) -> None: + # # [1, 32000, 4096, 256, 32, 32, 32, 128, 0] + self.gguf_writer.add_uint32('magic', 0x67676d66) + self.gguf_writer.add_uint32('version', 1) + self.gguf_writer.add_uint32('n_vocab', params.n_vocab) + self.gguf_writer.add_uint32('n_embd', params.n_embd) + 
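As a minimal sketch of the gguf.GGUFWriter call sequence these writers rely on (file name, metadata value, and tensor below are illustrative only; the calls themselves mirror the ones used in this patch):

    import numpy as np
    import gguf

    writer = gguf.GGUFWriter("example.gguf", "llama_ITREX")
    writer.add_uint32("n_vocab", 32000)            # scalar metadata, staged in memory
    writer.add_tensor("tok_embeddings.weight",     # tensors are also staged first
                      np.zeros((32000, 4096), dtype=np.float16))
    writer.write_header_to_file()                  # then flushed in three steps
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()
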
self.gguf_writer.add_uint32('n_mult', params.n_mult) + self.gguf_writer.add_uint32('n_head', params.n_head) + self.gguf_writer.add_uint32('n_head_kv', params.n_head_kv) + self.gguf_writer.add_uint32('n_layer', params.n_layer) + self.gguf_writer.add_uint32('n_rot', params.n_embd // params.n_head) + self.gguf_writer.add_uint32('ftype', file_type.value) + + self.gguf_writer.add_uint32('ffn_hidden_size', params.ffn_hidden_size) + self.gguf_writer.add_float32('rms_norm_eps', params.rms_norm_eps) + self.gguf_writer.add_float32('rope_theta', params.rope_theta) + + # TODO: + # bos_token_id = 0 in https://huggingface.co/decapoda-research/llama-7b-hf/blob/main/config.json + # but bos_token_id = 1 in llama.cpp + self.gguf_writer.add_int32('bos_token_id', 1) + self.gguf_writer.add_int32('eos_token_id', 2) + self.gguf_writer.add_int32('pad_token_id', 0) + self.gguf_writer.add_int32('sep_token_id', 0) + + def write_tensor_header_gguf(self, name: str, shape: Sequence[int], data_type: DataType, data) -> None: + # sname = name.encode('utf-8') + # self.fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[data_type])) + # self.fout.write(struct.pack("i" * len(shape), *shape[::-1])) + # self.fout.write(sname) + # self.fout.seek((self.fout.tell() + 31) & -32) + self.gguf_writer.add_tensor(name, data) + + def end(self): + + print("gguf: write header") + self.gguf_writer.write_header_to_file() + print("gguf: write metadata") + self.gguf_writer.write_kv_data_to_file() + print("gguf: write tensors") + self.gguf_writer.write_tensors_to_file() + + self.gguf_writer.close() + + def write_vocab_gguf(self, vocab: Vocab) -> None: + # for text, score in vocab.all_tokens(): + # self.fout.write(struct.pack("i", len(text))) + # self.fout.write(text) + # self.fout.write(struct.pack("f", score)) + + print("gguf: get tokenizer metadata") + + tokens: List[bytes] = [] + scores: List[float] = [] + toktypes: List[int] = [] + + for text, score in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + + print("gguf: get tokenizer metadata done") + + @staticmethod + def write_vocab_only(fname_out: Path, vocab: Vocab) -> None: + of = OutputFile_GGUF(fname_out) + params = Params(n_vocab=vocab.vocab_size, + n_embd=0, + n_mult=0, + n_head=1, + n_layer=0, + file_type=NEFileType.AllF32) + of = OutputFile_GGUF(fname_out) + of.write_file_header(params) + of.write_vocab_gguf(vocab) + of.fout.close() + + @staticmethod + def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab, file_type: NEFileType) -> None: + check_vocab_size(params, vocab) + of = OutputFile_GGUF(fname_out) + of.write_file_header(params, file_type) + print("Writing vocab...") + of.write_vocab_gguf(vocab) + + def do_item(item: Tuple[str, LazyTensor]) -> NDArray: + name, lazy_tensor = item + return lazy_tensor.load().to_ne().ndarray + + ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8) + for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + padi = len(str(len(model))) + print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | \ + type {lazy_tensor.data_type}") + of.write_tensor_header_gguf(name, lazy_tensor.shape, lazy_tensor.data_type, ndarray) + + of.end() + of.fout.close() + + def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> 
NEFileType: wq_type = model["layers.0.attention.wq.weight"].data_type if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)): @@ -1239,8 +1360,7 @@ def load_vocab(path: Path) -> SentencePieceVocab: else: raise FileNotFoundError( f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, \ - pass the directory as --vocab-dir" - ) + pass the directory as --vocab-dir") added_tokens_path = path.parent / "added_tokens.json" print(f"Loading vocab file {path}") return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) @@ -1258,8 +1378,7 @@ def default_outfile(model_paths: List[Path], params: Params) -> Path: if ret in model_paths: sys.stderr.write( f"Error: Default output path ({ret}) would overwrite the input. Please explicitly specify\ - a path using --outfile.\n" - ) + a path using --outfile.\n") sys.exit(1) return ret @@ -1289,6 +1408,11 @@ def main(args_in: Optional[List[str]] = None) -> None: parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") + parser.add_argument("--format", + type=str, + default="NE", + choices=["NE", "GGUF"], + help="convert to the GGUF or NE format") args = parser.parse_args(args_in) vocab: Vocab @@ -1306,18 +1430,25 @@ def main(args_in: Optional[List[str]] = None) -> None: if args.dump: do_dump_model(model_plus) return + if model_plus.vocab is not None and args.vocab_dir is None: vocab = model_plus.vocab else: vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent vocab = load_vocab(vocab_dir) + model = model_plus.model params = Params.load(model_plus) model = do_necessary_conversions(model, params) output_type = pick_output_type(model, args.outtype) model = convert_to_output_type(model, output_type) outfile = args.outfile or default_outfile(model_plus.paths, params) - OutputFile.write_all(outfile, params, model, vocab, output_type) + + if args.format == "GGUF": + OutputFile_GGUF.write_all(outfile, params, model, vocab, output_type) + else: + OutputFile.write_all(outfile, params, model, vocab, output_type) + print(f"Wrote {outfile}") diff --git a/neural_speed/models/model_utils/gguf.h b/neural_speed/models/model_utils/gguf.h new file mode 100644 index 000000000..71cf1b86a --- /dev/null +++ b/neural_speed/models/model_utils/gguf.h @@ -0,0 +1,531 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
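For orientation, the fixed-size file prefix that `struct gguf_header` declared below describes (4-byte magic, uint32 version, then uint64 tensor and key-value counts in GGUFv2+) can be read back with a few lines of Python; the path is a placeholder and the default little-endian layout is assumed:

    import struct

    with open("model.gguf", "rb") as f:                     # placeholder path
        magic = f.read(4)                                   # b"GGUF"
        version, = struct.unpack("<I", f.read(4))           # 2 or 3 today
        n_tensors, n_kv = struct.unpack("<QQ", f.read(16))  # uint64 counts (GGUFv2+)
    print(magic, version, n_tensors, n_kv)
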
+// Defines fileno on msys: + +#ifndef GGUF_H +#define GGUF_H + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#include +#include +#include +#endif + +#include "core/layers/bestla_common.hpp" +#include "core/ne_layers.h" +#include "models/model_utils/util.h" + +#define GGML_MAX_DIMS 4 +#define GGUF_MAGIC "GGUF" + +enum ggml_log_level { GGML_LOG_LEVEL_ERROR = 2, GGML_LOG_LEVEL_WARN = 3, GGML_LOG_LEVEL_INFO = 4 }; + +typedef void (*ggml_log_callback)(enum ggml_log_level level, const char* text, void* user_data); +static void llama_log_callback_default(ggml_log_level level, const char* text, void* user_data) { + (void)level; + (void)user_data; + fputs(text, stderr); + fflush(stderr); +} + +struct llama_state { + llama_state() {} + + // We save the log callback globally + ggml_log_callback log_callback = llama_log_callback_default; + void* log_callback_user_data = nullptr; +}; + +static llama_state g_state; + +static void llama_log_internal(ggml_log_level level, const char* format, ...); + +#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO, __VA_ARGS__) +#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN, __VA_ARGS__) +#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) + +static void llama_log_internal_v(ggml_log_level level, const char* format, va_list args) { + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len + 1]; + vsnprintf(buffer2, len + 1, format, args_copy); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args_copy); +} + +static void llama_log_internal(ggml_log_level level, const char* format, ...) 
{ + va_list args; + va_start(args, format); + llama_log_internal_v(level, format, args); + va_end(args); +} + +struct gguf_str { + uint64_t n; // GGUFv2 + char* data; +}; + +enum model_format { GGUF = 0, NE = 1, UNKNOWN = 2 }; + +enum llama_fver { + GGUF_FILE_VERSION_V1 = 1, + GGUF_FILE_VERSION_V2 = 2, + GGUF_FILE_VERSION_V3 = 3, +}; + +enum ggml_type { + GGML_TYPE_F32 = 0, + GGML_TYPE_F16 = 1, + GGML_TYPE_Q4_0 = 2, + GGML_TYPE_Q4_1 = 3, + // GGML_TYPE_Q4_2 = 4, support has been removed + // GGML_TYPE_Q4_3 (5) support has been removed + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 7, + GGML_TYPE_Q8_0 = 8, + GGML_TYPE_Q8_1 = 9, + // k-quantizations + GGML_TYPE_Q2_K = 10, + GGML_TYPE_Q3_K = 11, + GGML_TYPE_Q4_K = 12, + GGML_TYPE_Q5_K = 13, + GGML_TYPE_Q6_K = 14, + GGML_TYPE_Q8_K = 15, + GGML_TYPE_I8, + GGML_TYPE_I16, + GGML_TYPE_I32, + GGML_TYPE_COUNT, +}; + +enum gguf_type { + GGUF_TYPE_UINT8 = 0, + GGUF_TYPE_INT8 = 1, + GGUF_TYPE_UINT16 = 2, + GGUF_TYPE_INT16 = 3, + GGUF_TYPE_UINT32 = 4, + GGUF_TYPE_INT32 = 5, + GGUF_TYPE_FLOAT32 = 6, + GGUF_TYPE_BOOL = 7, + GGUF_TYPE_STRING = 8, + GGUF_TYPE_ARRAY = 9, + GGUF_TYPE_UINT64 = 10, + GGUF_TYPE_INT64 = 11, + GGUF_TYPE_FLOAT64 = 12, + GGUF_TYPE_COUNT, // marks the end of the enum +}; + +static const char* GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = { + [GGUF_TYPE_UINT8] = "u8", [GGUF_TYPE_INT8] = "i8", [GGUF_TYPE_UINT16] = "u16", [GGUF_TYPE_INT16] = "i16", + [GGUF_TYPE_UINT32] = "u32", [GGUF_TYPE_INT32] = "i32", [GGUF_TYPE_FLOAT32] = "f32", [GGUF_TYPE_BOOL] = "bool", + [GGUF_TYPE_STRING] = "str", [GGUF_TYPE_ARRAY] = "arr", [GGUF_TYPE_UINT64] = "u64", [GGUF_TYPE_INT64] = "i64", + [GGUF_TYPE_FLOAT64] = "f64", +}; + +union gguf_value { + uint8_t uint8; + int8_t int8; + uint16_t uint16; + int16_t int16; + uint32_t uint32; + int32_t int32; + float float32; + uint64_t uint64; + int64_t int64; + double float64; + bool bool_; + + struct gguf_str str; + + struct { + enum gguf_type type; + + uint64_t n; // GGUFv2 + void* data; + } arr; +}; + +struct gguf_kv { + struct gguf_str key; + + enum gguf_type type; + union gguf_value value; +}; + +struct gguf_header { + char magic[4]; + uint32_t version; + uint64_t n_tensors; // GGUFv2 + uint64_t n_kv; // GGUFv2 +}; + +struct gguf_context { + struct gguf_header header; + + struct gguf_kv* kv; + struct gguf_tensor_info* infos; + + size_t alignment; + size_t offset; // offset of `data` from beginning of file + size_t size; // size of `data` in bytes + + // uint8_t * padding; + void* data; +}; + +#if UINTPTR_MAX == 0xFFFFFFFF +#define GGML_MEM_ALIGN 4 +#else +#define GGML_MEM_ALIGN 16 +#endif + +#define GGUF_DEFAULT_ALIGNMENT 32 + +static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = { + [GGUF_TYPE_UINT8] = sizeof(uint8_t), + [GGUF_TYPE_INT8] = sizeof(int8_t), + [GGUF_TYPE_UINT16] = sizeof(uint16_t), + [GGUF_TYPE_INT16] = sizeof(int16_t), + [GGUF_TYPE_UINT32] = sizeof(uint32_t), + [GGUF_TYPE_INT32] = sizeof(int32_t), + [GGUF_TYPE_FLOAT32] = sizeof(float), + [GGUF_TYPE_BOOL] = sizeof(bool), + [GGUF_TYPE_STRING] = sizeof(struct gguf_str), + [GGUF_TYPE_ARRAY] = 0, // undefined + [GGUF_TYPE_UINT64] = sizeof(uint64_t), + [GGUF_TYPE_INT64] = sizeof(int64_t), + [GGUF_TYPE_FLOAT64] = sizeof(double), +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +enum llm_arch { + LLM_ARCH_LLAMA, + LLM_ARCH_FALCON, + LLM_ARCH_BAICHUAN, + LLM_ARCH_GPT2, + LLM_ARCH_GPTJ, + LLM_ARCH_GPTNEOX, + LLM_ARCH_MPT, + LLM_ARCH_STARCODER, + LLM_ARCH_PERSIMMON, + LLM_ARCH_REFACT, + LLM_ARCH_BLOOM, + LLM_ARCH_STABLELM, + LLM_ARCH_QWEN, + 
LLM_ARCH_CHATGLM2, + LLM_ARCH_UNKNOWN, +}; + +static std::map LLM_ARCH_NAMES = { + {LLM_ARCH_LLAMA, "llama"}, {LLM_ARCH_FALCON, "falcon"}, {LLM_ARCH_GPT2, "gpt2"}, + {LLM_ARCH_GPTJ, "gptj"}, {LLM_ARCH_GPTNEOX, "gptneox"}, {LLM_ARCH_MPT, "mpt"}, + {LLM_ARCH_BAICHUAN, "baichuan"}, {LLM_ARCH_STARCODER, "starcoder"}, {LLM_ARCH_PERSIMMON, "persimmon"}, + {LLM_ARCH_REFACT, "refact"}, {LLM_ARCH_BLOOM, "bloom"}, {LLM_ARCH_STABLELM, "stablelm"}, + {LLM_ARCH_QWEN, "qwen"}, {LLM_ARCH_CHATGLM2, "chatglm2"}, +}; + +struct gguf_tensor_info { + struct gguf_str name; + + uint32_t n_dims; + uint64_t ne[GGML_MAX_DIMS]; + + enum ggml_type type; + + uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` + + // for writing API + const void* data; + size_t size; +}; + +static bool gguf_fread_el(FILE* file, void* dst, size_t size, size_t* offset) { + const size_t n = fread(dst, 1, size, file); + *offset += n; + return n == size; +} + +static bool gguf_fread_str(FILE* file, struct gguf_str* p, size_t* offset) { + p->n = 0; + p->data = NULL; + + bool ok = true; + + ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset); + p->data = reinterpret_cast(calloc(p->n + 1, 1)); + ok = ok && gguf_fread_el(file, p->data, p->n, offset); + + return ok; +} + +static const char* llama_file_version_name(llama_fver version) { + switch (version) { + case GGUF_FILE_VERSION_V1: + return "GGUF V1 (support until nov 2023)"; + case GGUF_FILE_VERSION_V2: + return "GGUF V2"; + case GGUF_FILE_VERSION_V3: + return "GGUF V3 (latest)"; + } + + return "unknown"; +} + +inline static void* ggml_aligned_malloc(size_t size) { + if (size == 0) { + printf("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n"); + return NULL; + } + void* aligned_memory = NULL; +#ifdef GGML_USE_CPU_HBM + int result = hbw_posix_memalign(&aligned_memory, 16, size); +#elif GGML_USE_METAL + int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size); +#else + int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size); +#endif + if (result != 0) { + // Handle allocation failure + const char* error_desc = "unknown allocation error"; + switch (result) { + case EINVAL: + error_desc = "invalid alignment value"; + break; + case ENOMEM: + error_desc = "insufficient memory"; + break; + } + printf("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size / (1024.0 * 1024.0)); + return NULL; + } + return aligned_memory; +} +#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size) + +#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ + do { \ + const std::string skey(key); \ + const int kid = gguf_find_key(ctx, skey.c_str()); \ + if (kid >= 0) { \ + enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ + if (ktype != (type)) { \ + throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \ + } \ + (dst) = func(ctx, kid); \ + } else if (req) { \ + throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \ + } \ + } while (0) + +static void replace_all(std::string& s, const std::string& search, const std::string& replace) { + std::string result; + for (size_t pos = 0;; pos += search.length()) { + auto new_pos = s.find(search, pos); + if (new_pos == std::string::npos) { + result += s.substr(pos, s.size() - pos); + break; + } + result += s.substr(pos, new_pos - pos) + replace; + pos = new_pos; + } + s = std::move(result); +} + +static uint32_t codepoint_from_utf8(const std::string& utf8, size_t& offset) { + 
assert(offset < utf8.size()); + if (!(utf8[offset + 0] & 0x80)) { + auto result = utf8[offset + 0]; + offset += 1; + return result; + } else if (!(utf8[offset + 0] & 0x40)) { + throw std::invalid_argument("invalid character"); + } else if (!(utf8[offset + 0] & 0x20)) { + if (offset + 1 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80)) + throw std::invalid_argument("invalid character"); + auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f); + offset += 2; + return result; + } else if (!(utf8[offset + 0] & 0x10)) { + if (offset + 2 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80) || !((utf8[offset + 2] & 0xc0) == 0x80)) + throw std::invalid_argument("invalid character"); + auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f); + offset += 3; + return result; + } else if (!(utf8[offset + 0] & 0x08)) { + if (offset + 3 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80) || !((utf8[offset + 2] & 0xc0) == 0x80) || + !((utf8[offset + 3] & 0xc0) == 0x80)) + throw std::invalid_argument("invalid character"); + auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | + ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f); + offset += 4; + return result; + } + throw std::invalid_argument("invalid string"); +} + +static std::vector codepoints_from_utf8(const std::string& utf8) { + std::vector result; + size_t offset = 0; + while (offset < utf8.size()) { + result.push_back(codepoint_from_utf8(utf8, offset)); + } + return result; +} + +enum llm_kv { + LLM_KV_GENERAL_ARCHITECTURE, + LLM_KV_GENERAL_QUANTIZATION_VERSION, + LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_NAME, + LLM_KV_GENERAL_AUTHOR, + LLM_KV_GENERAL_URL, + LLM_KV_GENERAL_DESCRIPTION, + LLM_KV_GENERAL_LICENSE, + LLM_KV_GENERAL_SOURCE_URL, + LLM_KV_GENERAL_SOURCE_HF_REPO, + + LLM_KV_CONTEXT_LENGTH, + LLM_KV_EMBEDDING_LENGTH, + LLM_KV_BLOCK_COUNT, + LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_USE_PARALLEL_RESIDUAL, + LLM_KV_TENSOR_DATA_LAYOUT, + + LLM_KV_ATTENTION_HEAD_COUNT, + LLM_KV_ATTENTION_HEAD_COUNT_KV, + LLM_KV_ATTENTION_MAX_ALIBI_BIAS, + LLM_KV_ATTENTION_CLAMP_KQV, + LLM_KV_ATTENTION_LAYERNORM_EPS, + LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_SCALE_LINEAR, + LLM_KV_ROPE_SCALING_TYPE, + LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, + LLM_KV_ROPE_SCALING_FINETUNED, + + LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_LIST, + LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_SCORES, + LLM_KV_TOKENIZER_MERGES, + LLM_KV_TOKENIZER_BOS_ID, + LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_UNK_ID, + LLM_KV_TOKENIZER_SEP_ID, + LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_ADD_BOS, + LLM_KV_TOKENIZER_ADD_EOS, + LLM_KV_TOKENIZER_HF_JSON, + LLM_KV_TOKENIZER_RWKV, +}; + +static std::map LLM_KV_NAMES = { + {LLM_KV_GENERAL_ARCHITECTURE, "general.architecture"}, + {LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version"}, + {LLM_KV_GENERAL_ALIGNMENT, "general.alignment"}, + {LLM_KV_GENERAL_NAME, "general.name"}, + {LLM_KV_GENERAL_AUTHOR, "general.author"}, + {LLM_KV_GENERAL_URL, "general.url"}, + {LLM_KV_GENERAL_DESCRIPTION, "general.description"}, + {LLM_KV_GENERAL_LICENSE, "general.license"}, + {LLM_KV_GENERAL_SOURCE_URL, "general.source.url"}, + {LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository"}, + + {LLM_KV_CONTEXT_LENGTH, "%s.context_length"}, + {LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length"}, + {LLM_KV_BLOCK_COUNT, 
"%s.block_count"}, + {LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length"}, + {LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual"}, + {LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout"}, + + {LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count"}, + {LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv"}, + {LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias"}, + {LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv"}, + {LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon"}, + {LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon"}, + + {LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count"}, + {LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base"}, + {LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear"}, + {LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type"}, + {LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor"}, + {LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length"}, + {LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned"}, + + {LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model"}, + {LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens"}, + {LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type"}, + {LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores"}, + {LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges"}, + {LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id"}, + {LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id"}, + {LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id"}, + {LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id"}, + {LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id"}, + {LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token"}, + {LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token"}, + {LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json"}, + {LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world"}, +}; + +struct LLM_KV { + LLM_KV(llm_arch arch) : arch(arch) {} + + llm_arch arch; + + std::string operator()(llm_kv kv) const { return ::format(LLM_KV_NAMES[kv].c_str(), LLM_ARCH_NAMES[arch].c_str()); } +}; + +static std::string gguf_data_to_str(enum gguf_type type, const void* data, int i) { + switch (type) { + case GGUF_TYPE_UINT8: + return std::to_string(((const uint8_t*)data)[i]); + case GGUF_TYPE_INT8: + return std::to_string(((const int8_t*)data)[i]); + case GGUF_TYPE_UINT16: + return std::to_string(((const uint16_t*)data)[i]); + case GGUF_TYPE_INT16: + return std::to_string(((const int16_t*)data)[i]); + case GGUF_TYPE_UINT32: + return std::to_string(((const uint32_t*)data)[i]); + case GGUF_TYPE_INT32: + return std::to_string(((const int32_t*)data)[i]); + case GGUF_TYPE_UINT64: + return std::to_string(((const uint64_t*)data)[i]); + case GGUF_TYPE_INT64: + return std::to_string(((const int64_t*)data)[i]); + case GGUF_TYPE_FLOAT32: + return std::to_string(((const float*)data)[i]); + case GGUF_TYPE_FLOAT64: + return std::to_string(((const double*)data)[i]); + case GGUF_TYPE_BOOL: + return ((const bool*)data)[i] ? 
"true" : "false"; + default: + return format("unknown type %d", type); + } +} + +#endif // GGUF_H diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h index aa154c998..30366104e 100644 --- a/neural_speed/models/model_utils/model_files.h +++ b/neural_speed/models/model_utils/model_files.h @@ -33,6 +33,8 @@ #include "core/ne_layers.h" #include "models/model_utils/util.h" #include "models/models.h" +#include "models/model_utils/gguf.h" +#include template static T checked_mul(T a, T b) { @@ -211,20 +213,759 @@ struct model_load_tensors_map { std::unordered_map name_to_idx; }; +struct gguf_loader { + FILE* gguf_file; + + gguf_loader(FILE* ne_file) : gguf_file(ne_file) {} + + const char* gguf_type_name(enum gguf_type type) { return GGUF_TYPE_NAME[type]; } + + int gguf_get_version(const struct gguf_context* ctx) { return ctx->header.version; } + + size_t gguf_get_alignment(const struct gguf_context* ctx) { return ctx->alignment; } + + size_t gguf_get_data_offset(const struct gguf_context* ctx) { return ctx->offset; } + + void* gguf_get_data(const struct gguf_context* ctx) { return ctx->data; } + + int gguf_get_n_kv(const struct gguf_context* ctx) { return ctx->header.n_kv; } + + int gguf_find_key(const struct gguf_context* ctx, const char* key) { + // return -1 if key not found + int keyfound = -1; + + const int n_kv = gguf_get_n_kv(ctx); + + for (int i = 0; i < n_kv; ++i) { + if (strcmp(key, gguf_get_key(ctx, i)) == 0) { + keyfound = i; + break; + } + } + + return keyfound; + } + + const char* gguf_get_key(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].key.data; + } + + enum gguf_type gguf_get_kv_type(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].type; + } + + enum gguf_type gguf_get_arr_type(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.type; + } + + const void* gguf_get_arr_data(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.data; + } + + const char* gguf_get_arr_str(const struct gguf_context* ctx, int key_id, int i) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + struct gguf_kv* kv = &ctx->kv[key_id]; + struct gguf_str* str = &((struct gguf_str*)kv->value.arr.data)[i]; + return str->data; + } + + int gguf_get_arr_n(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.n; + } + + uint8_t gguf_get_val_u8(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8); + return ctx->kv[key_id].value.uint8; + } + + int8_t gguf_get_val_i8(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8); + return ctx->kv[key_id].value.int8; + } + + uint16_t gguf_get_val_u16(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == 
GGUF_TYPE_UINT16); + return ctx->kv[key_id].value.uint16; + } + + int16_t gguf_get_val_i16(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16); + return ctx->kv[key_id].value.int16; + } + + uint32_t gguf_get_val_u32(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32); + return ctx->kv[key_id].value.uint32; + } + + int32_t gguf_get_val_i32(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32); + return ctx->kv[key_id].value.int32; + } + + float gguf_get_val_f32(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32); + return ctx->kv[key_id].value.float32; + } + + uint64_t gguf_get_val_u64(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64); + return ctx->kv[key_id].value.uint64; + } + + int64_t gguf_get_val_i64(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64); + return ctx->kv[key_id].value.int64; + } + + double gguf_get_val_f64(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64); + return ctx->kv[key_id].value.float64; + } + + bool gguf_get_val_bool(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL); + return ctx->kv[key_id].value.bool_; + } + + const char* gguf_get_val_str(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING); + return ctx->kv[key_id].value.str.data; + } + + const void* gguf_get_val_data(const struct gguf_context* ctx, int key_id) { + NE_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + NE_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY); + NE_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING); + return &ctx->kv[key_id].value; + } + + int gguf_get_n_tensors(const struct gguf_context* ctx) { return ctx->header.n_tensors; } + + int gguf_find_tensor(const struct gguf_context* ctx, const char* name) { + // return -1 if tensor not found + int tensorfound = -1; + + const int n_tensors = gguf_get_n_tensors(ctx); + + for (int i = 0; i < n_tensors; ++i) { + if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) { + tensorfound = i; + break; + } + } + + return tensorfound; + } + + size_t gguf_get_tensor_offset(const struct gguf_context* ctx, int i) { return ctx->infos[i].offset; } + + char* gguf_get_tensor_name(const struct gguf_context* ctx, int i) { return ctx->infos[i].name.data; } + + // returns the index + // remove static + int gguf_get_or_add_key(struct gguf_context* ctx, const char* key) { + const int idx = gguf_find_key(ctx, key); + if (idx >= 0) { + return idx; + } + + const int n_kv = gguf_get_n_kv(ctx); + + ctx->kv = reinterpret_cast(realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv))); + ctx->kv[n_kv].key.n = strlen(key); + ctx->kv[n_kv].key.data = strdup(key); + ctx->header.n_kv++; + + return n_kv; + } + + // remove static + std::string 
gguf_kv_to_str(struct gguf_context* ctx_gguf, int i) { + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + + switch (type) { + case GGUF_TYPE_STRING: + return gguf_get_val_str(ctx_gguf, i); + case GGUF_TYPE_ARRAY: { + const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); + int arr_n = gguf_get_arr_n(ctx_gguf, i); + const void* data = gguf_get_arr_data(ctx_gguf, i); + std::stringstream ss; + ss << "["; + for (int j = 0; j < arr_n; j++) { + if (arr_type == GGUF_TYPE_STRING) { + std::string val = gguf_get_arr_str(ctx_gguf, i, j); + // escape quotes + replace_all(val, "\\", "\\\\"); + replace_all(val, "\"", "\\\""); + ss << '"' << val << '"'; + } else if (arr_type == GGUF_TYPE_ARRAY) { + ss << "???"; + } else { + ss << gguf_data_to_str(arr_type, data, j); + } + if (j < arr_n - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } + default: + return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0); + } + } + + size_t file_offset(const struct gguf_context* ctx_gguf, const char* name) { + const int idx = gguf_find_tensor(ctx_gguf, name); + + if (idx < 0) { + throw std::runtime_error(format("%s: tensor '%s' not found in the file", __func__, name)); + } + + size_t data_offset = gguf_get_data_offset(ctx_gguf); + size_t tensor_offset = gguf_get_tensor_offset(ctx_gguf, idx); + return data_offset + tensor_offset; + } + + void gguf_set_val_u8(struct gguf_context* ctx, const char* key, uint8_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT8; + ctx->kv[idx].value.uint8 = val; + } + + void gguf_set_val_i8(struct gguf_context* ctx, const char* key, int8_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT8; + ctx->kv[idx].value.int8 = val; + } + + void gguf_set_val_u16(struct gguf_context* ctx, const char* key, uint16_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT16; + ctx->kv[idx].value.uint16 = val; + } + + void gguf_set_val_i16(struct gguf_context* ctx, const char* key, int16_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT16; + ctx->kv[idx].value.int16 = val; + } + + void gguf_set_val_u32(struct gguf_context* ctx, const char* key, uint32_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT32; + ctx->kv[idx].value.uint32 = val; + } + + void gguf_set_val_i32(struct gguf_context* ctx, const char* key, int32_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT32; + ctx->kv[idx].value.int32 = val; + } + + void gguf_set_val_f32(struct gguf_context* ctx, const char* key, float val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_FLOAT32; + ctx->kv[idx].value.float32 = val; + } + + void gguf_set_val_u64(struct gguf_context* ctx, const char* key, uint64_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT64; + ctx->kv[idx].value.uint64 = val; + } + + void gguf_set_val_i64(struct gguf_context* ctx, const char* key, int64_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT64; + ctx->kv[idx].value.int64 = val; + } + + void gguf_set_val_f64(struct gguf_context* ctx, const char* key, double val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_FLOAT64; + ctx->kv[idx].value.float64 = val; + } + + void gguf_set_val_bool(struct gguf_context* ctx, const 
char* key, bool val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_BOOL; + ctx->kv[idx].value.bool_ = val; + } + + void gguf_set_val_str(struct gguf_context* ctx, const char* key, const char* val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_STRING; + ctx->kv[idx].value.str.n = strlen(val); + ctx->kv[idx].value.str.data = strdup(val); + } + + void gguf_set_arr_data(struct gguf_context* ctx, const char* key, enum gguf_type type, const void* data, int n) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_ARRAY; + ctx->kv[idx].value.arr.type = type; + ctx->kv[idx].value.arr.n = n; + ctx->kv[idx].value.arr.data = malloc(n * GGUF_TYPE_SIZE[type]); + memcpy(ctx->kv[idx].value.arr.data, data, n * GGUF_TYPE_SIZE[type]); + } + + void gguf_set_arr_str(struct gguf_context* ctx, const char* key, const char** data, int n) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_ARRAY; + ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING; + ctx->kv[idx].value.arr.n = n; + ctx->kv[idx].value.arr.data = malloc(n * sizeof(struct gguf_str)); + for (int i = 0; i < n; i++) { + struct gguf_str* str = &((struct gguf_str*)ctx->kv[idx].value.arr.data)[i]; + str->n = strlen(data[i]); + str->data = strdup(data[i]); + } + } + + void gguf_free(struct gguf_context* ctx) { + if (ctx == NULL) { + return; + } + + if (ctx->kv) { + // free string memory - not great.. + for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv* kv = &ctx->kv[i]; + + if (kv->key.data) { + free(kv->key.data); + } + + if (kv->type == GGUF_TYPE_STRING) { + if (kv->value.str.data) { + free(kv->value.str.data); + } + } + + if (kv->type == GGUF_TYPE_ARRAY) { + if (kv->value.arr.data) { + if (kv->value.arr.type == GGUF_TYPE_STRING) { + for (uint32_t j = 0; j < kv->value.arr.n; ++j) { + struct gguf_str* str = &((struct gguf_str*)kv->value.arr.data)[j]; + if (str->data) { + free(str->data); + } + } + } + free(kv->value.arr.data); + } + } + } + + free(ctx->kv); + } + + if (ctx->infos) { + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info* info = &ctx->infos[i]; + + if (info->name.data) { + free(info->name.data); + } + } + + free(ctx->infos); + } + } + + struct gguf_context* gguf_init_from_file(model_load_tensors_map& tensors_map, size_t& gguf_data_offset) { + if (!gguf_file) { + return nullptr; + } + + size_t offset = 0; + char magic[4]; + + gguf_fread_el(gguf_file, &magic, sizeof(magic), &offset); + + struct gguf_context* ctx = reinterpret_cast(GGML_ALIGNED_MALLOC(sizeof(struct gguf_context))); + ctx->offset = 0; + // read the header + strncpy(ctx->header.magic, magic, 4); + + bool ok = true; + ctx->kv = NULL; + ctx->infos = NULL; + ctx->data = NULL; + + ok = ok && gguf_fread_el(gguf_file, &ctx->header.version, sizeof(ctx->header.version), &offset); + ok = ok && gguf_fread_el(gguf_file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); + ok = ok && gguf_fread_el(gguf_file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); + + if (ctx->header.version == 1) { + fprintf(stderr, "%s: GGUFv1 is no longer supported. 
please use a more up-to-date version\n", __func__); + fclose(gguf_file); + gguf_free(ctx); + return nullptr; + } + + if (!ok) { + fprintf(stderr, "%s: failed to read header\n", __func__); + fclose(gguf_file); + gguf_free(ctx); + return nullptr; + } + + // read the kv pairs + ctx->kv = reinterpret_cast(malloc(ctx->header.n_kv * sizeof(struct gguf_kv))); + + for (uint64_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv* kv = &ctx->kv[i]; + + ok = ok && gguf_fread_str(gguf_file, &kv->key, &offset); + ok = ok && gguf_fread_el(gguf_file, &kv->type, sizeof(kv->type), &offset); + + switch (kv->type) { + case GGUF_TYPE_UINT8: + ok = ok && gguf_fread_el(gguf_file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); + break; + case GGUF_TYPE_INT8: + ok = ok && gguf_fread_el(gguf_file, &kv->value.int8, sizeof(kv->value.int8), &offset); + break; + case GGUF_TYPE_UINT16: + ok = ok && gguf_fread_el(gguf_file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); + break; + case GGUF_TYPE_INT16: + ok = ok && gguf_fread_el(gguf_file, &kv->value.int16, sizeof(kv->value.int16), &offset); + break; + case GGUF_TYPE_UINT32: + ok = ok && gguf_fread_el(gguf_file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); + break; + case GGUF_TYPE_INT32: + ok = ok && gguf_fread_el(gguf_file, &kv->value.int32, sizeof(kv->value.int32), &offset); + break; + case GGUF_TYPE_FLOAT32: + ok = ok && gguf_fread_el(gguf_file, &kv->value.float32, sizeof(kv->value.float32), &offset); + break; + case GGUF_TYPE_UINT64: + ok = ok && gguf_fread_el(gguf_file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); + break; + case GGUF_TYPE_INT64: + ok = ok && gguf_fread_el(gguf_file, &kv->value.int64, sizeof(kv->value.int64), &offset); + break; + case GGUF_TYPE_FLOAT64: + ok = ok && gguf_fread_el(gguf_file, &kv->value.float64, sizeof(kv->value.float64), &offset); + break; + case GGUF_TYPE_BOOL: + ok = ok && gguf_fread_el(gguf_file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); + break; + case GGUF_TYPE_STRING: + ok = ok && gguf_fread_str(gguf_file, &kv->value.str, &offset); + break; + case GGUF_TYPE_ARRAY: { + ok = ok && gguf_fread_el(gguf_file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset); + ok = ok && gguf_fread_el(gguf_file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); + + switch (kv->value.arr.type) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: + case GGUF_TYPE_BOOL: { + kv->value.arr.data = malloc(kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]); + ok = ok && gguf_fread_el(gguf_file, kv->value.arr.data, + kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type], &offset); + } break; + case GGUF_TYPE_STRING: { + kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct gguf_str)); + for (uint64_t j = 0; j < kv->value.arr.n; ++j) { + ok = ok && gguf_fread_str(gguf_file, &((struct gguf_str*)kv->value.arr.data)[j], &offset); + } + } break; + case GGUF_TYPE_ARRAY: + case GGUF_TYPE_COUNT: + printf("False && invalid type"); + break; // NE_ASSERT(false && "invalid type"); break; + } + } break; + case GGUF_TYPE_COUNT: + printf("False && invalid type"); // NE_ASSERT(false && "invalid type"); + } + + if (!ok) { + break; + } + } + + // read the tensor infos + ctx->infos = + reinterpret_cast(malloc(ctx->header.n_tensors * sizeof(struct gguf_tensor_info))); + + for (uint64_t i = 0; i < ctx->header.n_tensors; 
++i) { + struct gguf_tensor_info* info = &ctx->infos[i]; + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + info->ne[j] = 1; + } + + ok = ok && gguf_fread_str(gguf_file, &info->name, &offset); + ok = ok && gguf_fread_el(gguf_file, &info->n_dims, sizeof(info->n_dims), &offset); + for (uint32_t j = 0; j < info->n_dims; ++j) { + ok = ok && gguf_fread_el(gguf_file, &info->ne[j], sizeof(info->ne[j]), &offset); + } + ok = ok && gguf_fread_el(gguf_file, &info->type, sizeof(info->type), &offset); + ok = ok && gguf_fread_el(gguf_file, &info->offset, sizeof(info->offset), &offset); + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor info\n", __func__); + fclose(gguf_file); + gguf_free(ctx); + return nullptr; + } + + model_load_tensor_shard shard; + std::string name = gguf_get_tensor_name(ctx, i); + uint32_t name_len = name.length(); + shard.type = (enum ne_type)0; + + uint32_t n_dims = info->n_dims; + shard.ne.resize(n_dims); + for (uint32_t j = 0; j < info->n_dims; ++j) { + shard.ne[j] = info->ne[j]; + } + + if (n_dims < 1 || n_dims > 2) { + throw format("model.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims); + } + switch (shard.type) { + case NE_TYPE_F32: + case NE_TYPE_F16: + case NE_TYPE_Q4_0: + case NE_TYPE_Q4_1: + case NE_TYPE_Q5_0: + case NE_TYPE_Q5_1: + case NE_TYPE_Q8_0: + case NE_TYPE_BTLA: + break; + default: { + throw format("unrecognized tensor type %u\n", shard.type); + } + } + shard.file_idx = 0; + const size_t offs = file_offset(ctx, name.c_str()); + int length = info->ne[0] * info->ne[1] * info->ne[2] * info->ne[3] * 4; + + shard.file_off = offs; + + auto it = tensors_map.name_to_idx.find(name); + size_t idx; + if (it != tensors_map.name_to_idx.end()) { + idx = it->second; + } else { + tensors_map.tensors.emplace_back(name); + idx = tensors_map.tensors.size() - 1; + tensors_map.name_to_idx.emplace(name, idx); + } + tensors_map.tensors.at(idx).shards.push_back(shard); + } + + ctx->alignment = GGUF_DEFAULT_ALIGNMENT; + + int alignment_idx = gguf_find_key(ctx, "general.alignment"); + if (alignment_idx != -1) { + ctx->alignment = gguf_get_val_u32(ctx, alignment_idx); + } + + const size_t offset_pad = offset % ctx->alignment; + + if (offset_pad != 0) { + offset += ctx->alignment - offset_pad; + // fseek(file, offset, SEEK_SET); + } + + ctx->offset = offset; + gguf_data_offset = offset; + + return ctx; + } + + void gguf_load_from_file(struct gguf_context* ctx_gguf, model_hparams& hparams, model_vocab& vocab) { + int n_kv = 0; + n_kv = gguf_get_n_kv(ctx_gguf); + + int n_tensors = 0; + n_tensors = gguf_get_n_tensors(ctx_gguf); + + llama_fver fver; + fver = (enum llama_fver)gguf_get_version(ctx_gguf); + printf("%s: loaded meta data with %d key-value pairs and %d tensors (version %s)\n", __func__, n_kv, n_tensors, + llama_file_version_name(fver)); + + for (int i = 0; i < n_kv; i++) { + const char* name = gguf_get_key(ctx_gguf, i); + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + const std::string type_name = + type == GGUF_TYPE_ARRAY ? 
format("%s[%s,%d]", gguf_type_name(type), + gguf_type_name(gguf_get_arr_type(ctx_gguf, i)), gguf_get_arr_n(ctx_gguf, i)) + : gguf_type_name(type); + + std::string value = gguf_kv_to_str(ctx_gguf, i); + const size_t MAX_VALUE_LEN = 40; + if (value.size() > MAX_VALUE_LEN) { + value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); + } + replace_all(value, "\n", "\\n"); + + printf("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str()); + } + + uint32_t magic = -1; + uint32_t version = -1; + std::string arch = "unknown"; + GGUF_GET_KEY(ctx_gguf, arch, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.architecuture"); + GGUF_GET_KEY(ctx_gguf, magic, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "magic"); + GGUF_GET_KEY(ctx_gguf, version, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "version"); + + // get hparams kv + GGUF_GET_KEY(ctx_gguf, hparams.n_vocab, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "n_vocab"); + GGUF_GET_KEY(ctx_gguf, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "n_embd"); + GGUF_GET_KEY(ctx_gguf, hparams.n_mult, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "n_mult"); + GGUF_GET_KEY(ctx_gguf, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "n_head"); + GGUF_GET_KEY(ctx_gguf, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "n_head_kv"); + GGUF_GET_KEY(ctx_gguf, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "n_layer"); + GGUF_GET_KEY(ctx_gguf, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "n_rot"); + + uint32_t ftype = 1; + GGUF_GET_KEY(ctx_gguf, ftype, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "ftype"); + hparams.ftype = (enum ne_ftype)ftype; + + GGUF_GET_KEY(ctx_gguf, hparams.max_seq_len, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "max_seq_len"); + GGUF_GET_KEY(ctx_gguf, hparams.alibi_bias_max, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "alibi_bias_max"); + GGUF_GET_KEY(ctx_gguf, hparams.clip_qkv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "clip_qkv"); + GGUF_GET_KEY(ctx_gguf, hparams.par_res, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "par_res"); + + GGUF_GET_KEY(ctx_gguf, hparams.word_embed_proj_dim, gguf_get_val_u32, GGUF_TYPE_UINT32, false, + "word_embed_proj_dim"); + GGUF_GET_KEY(ctx_gguf, hparams.do_layer_norm_before, gguf_get_val_u32, GGUF_TYPE_UINT32, false, + "do_layer_norm_before"); + + GGUF_GET_KEY(ctx_gguf, hparams.multi_query_group_num, gguf_get_val_u32, GGUF_TYPE_UINT32, false, + "multi_query_group_num"); + GGUF_GET_KEY(ctx_gguf, hparams.ffn_hidden_size, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "ffn_hidden_size"); + GGUF_GET_KEY(ctx_gguf, hparams.inner_hidden_size, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "inner_hidden_size"); + + GGUF_GET_KEY(ctx_gguf, vocab.bos_token_id, gguf_get_val_i32, GGUF_TYPE_INT32, false, "bos_token_id"); + GGUF_GET_KEY(ctx_gguf, vocab.eos_token_id, gguf_get_val_i32, GGUF_TYPE_INT32, false, "eos_token_id"); + GGUF_GET_KEY(ctx_gguf, vocab.pad_token_id, gguf_get_val_i32, GGUF_TYPE_INT32, false, "pad_token_id"); + GGUF_GET_KEY(ctx_gguf, vocab.sep_token_id, gguf_get_val_i32, GGUF_TYPE_INT32, false, "sep_token_id"); + + // load vocab + std::string tokens = "tokenizer.ggml.tokens"; + const int token_idx = gguf_find_key(ctx_gguf, tokens.c_str()); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } + + const float* scores = nullptr; + std::string scores_name = "tokenizer.ggml.scores"; + const int score_idx = gguf_find_key(ctx_gguf, scores_name.c_str()); + if (score_idx != -1) { + scores 
+      scores = (const float*)gguf_get_arr_data(ctx_gguf, score_idx);
+    }
+
+    const uint32_t n_vocab = gguf_get_arr_n(ctx_gguf, token_idx);
+
+    vocab.id_to_token.resize(hparams.n_vocab);
+    for (uint32_t i = 0; i < n_vocab; i++) {
+      std::string word = gguf_get_arr_str(ctx_gguf, token_idx, i);
+      // NE_ASSERT(codepoints_from_utf8(word).size() > 0);
+
+      vocab.token_to_id[word] = i;
+
+      auto& tok_score = vocab.id_to_token[i];
+      tok_score.tok = std::move(word);
+      tok_score.score = scores ? scores[i] : 0.0f;
+    }
+  }
+};
+
 struct model_file_loader {
   model_file file;
   model_file_version file_version;
   model_hparams hparams;
   model_vocab vocab;
+  size_t gguf_data_offset = 0;  // offset of the GGUF tensor data from the beginning of the file.
+  enum model_format model_magic = UNKNOWN;
+
   model_file_loader(const char* fname, size_t file_idx, model_load_tensors_map& tensors_map) : file(fname, "rb") {
     fprintf(stderr, "model.cpp: loading model from %s\n", fname);
-    read_magic();
-    read_hparams();
-    read_vocab();
-    read_tensor_metadata(file_idx, tensors_map);
+    model_magic = read_file_magic();
+    if (model_magic == GGUF) {
+      std::cout << "Loading the bin file with GGUF format..." << std::endl;
+      fseek(file.fp, 0, SEEK_SET);
+      model_magic = GGUF;
+
+      gguf_loader gguf_loader(file.fp);
+
+      struct gguf_context* ctx_gguf = NULL;
+      ctx_gguf = gguf_loader.gguf_init_from_file(tensors_map, gguf_data_offset);
+      if (!ctx_gguf) {
+        throw std::runtime_error(format("%s: failed to load model\n", __func__));
+      }
+
+      gguf_loader.gguf_load_from_file(ctx_gguf, hparams, vocab);
+    } else if (model_magic == NE) {
+      std::cout << "Loading the bin file with NE format..." << std::endl;
+      fseek(file.fp, 0, SEEK_SET);
+      read_ne_magic();
+      read_ne_hparams();
+      read_ne_vocab();
+      read_tensor_metadata(file_idx, tensors_map);
+    } else {
+      throw format("unknown file format model_magic = %d", model_magic);
+    }
   }
-  void read_magic() {
+
+  void read_ne_magic() {
     uint32_t magic = file.read_u32();
     if (magic == MODEL_FILE_MAGIC_NE) {
@@ -258,7 +999,25 @@ struct model_file_loader {
     throw format("unknown (magic, version) combination: %08x, %08x; is this really a NE file?", magic, version);
   }
-  void read_hparams() {
+
+  enum model_format read_file_magic() {
+    char gguf_magic[4];
+    const size_t n = fread(&gguf_magic, 1, sizeof(gguf_magic), file.fp);
+    bool ok = true;
+    ok = ok & gguf_magic[0] == 'G';
+    ok = ok & gguf_magic[1] == 'G';
+    ok = ok & gguf_magic[2] == 'U';
+    ok = ok & gguf_magic[3] == 'F';
+
+    if (ok) {
+      model_magic = GGUF;
+    } else {
+      model_magic = NE;
+    }
+    return model_magic;
+  }
+
+  void read_ne_hparams() {
     hparams.n_vocab = file.read_u32();
     hparams.n_embd = file.read_u32();
     hparams.n_mult = file.read_u32();
@@ -286,13 +1045,13 @@ struct model_file_loader {
     file.read_raw(&hparams.freq_base, sizeof(float));
   }
-  void read_vocab() {
-    vocab.id_to_token.resize(hparams.n_vocab);
+  void read_ne_vocab() {
     file.read_raw(&vocab.bos_token_id, sizeof(model_vocab::id));
     file.read_raw(&vocab.eos_token_id, sizeof(model_vocab::id));
     file.read_raw(&vocab.pad_token_id, sizeof(model_vocab::id));
     file.read_raw(&vocab.sep_token_id, sizeof(model_vocab::id));
+    vocab.id_to_token.resize(hparams.n_vocab);
     for (uint32_t i = 0; i < hparams.n_vocab; i++) {
       uint32_t len = file.read_u32();
       std::string word = file.read_string(len);
@@ -645,7 +1404,7 @@ struct model_model_loader {
       lt.data = (uint8_t*)mapping->addr + lt.shards.at(0).file_off;
     } else if (lt.split_type == SPLIT_NONE) {
       model_file& file = file_loaders.at(lt.shards.at(0).file_idx)->file;
-      file.seek(lt.shards.at(0).file_off, SEEK_SET);
+      file.seek(lt.shards.at(0).file_off + file_loaders.at(0)->gguf_data_offset, SEEK_SET);
       file.read_raw(lt.data, lt.size);
     } else if (lt.split_type == SPLIT_BY_ROWS) {
       size_t offset = 0;
diff --git a/requirements.txt b/requirements.txt
index f82674779..e30a99c1d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,5 @@ datasets
 transformers_stream_generator
 tiktoken
 py-cpuinfo
-cmake
\ No newline at end of file
+gguf
+cmake
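
The loader added by this patch dispatches on the first four bytes of the model file: if they spell the ASCII magic "GGUF" it takes the GGUF path and records the aligned tensor-data offset, otherwise it falls back to the legacy NE reader. The following standalone C++ sketch (not part of the patch; ModelFormat and detect_format are illustrative names) shows the same magic-based dispatch and, under the assumption of a little-endian GGUF v2+ header ("GGUF" magic, uint32 version, uint64 tensor count, uint64 key-value count), prints the header counts:

    // Sketch only: peek at a model file and decide between GGUF and legacy NE,
    // mirroring the idea behind read_file_magic() above.
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    enum class ModelFormat { GGUF, NE, UNKNOWN };

    static ModelFormat detect_format(const char* fname) {
      FILE* fp = std::fopen(fname, "rb");
      if (!fp) return ModelFormat::UNKNOWN;

      char magic[4] = {0};
      ModelFormat fmt = ModelFormat::NE;  // anything that is not GGUF is treated as legacy NE
      if (std::fread(magic, 1, sizeof(magic), fp) == sizeof(magic) &&
          std::memcmp(magic, "GGUF", 4) == 0) {
        fmt = ModelFormat::GGUF;
        uint32_t version = 0;
        uint64_t n_tensors = 0, n_kv = 0;
        // Assumed GGUF v2+ header layout; v1 used 32-bit counts instead.
        if (std::fread(&version, sizeof(version), 1, fp) == 1 &&
            std::fread(&n_tensors, sizeof(n_tensors), 1, fp) == 1 &&
            std::fread(&n_kv, sizeof(n_kv), 1, fp) == 1) {
          std::printf("GGUF v%u: %llu tensors, %llu kv pairs\n", version,
                      (unsigned long long)n_tensors, (unsigned long long)n_kv);
        }
      }
      std::fclose(fp);
      return fmt;
    }

    int main(int argc, char** argv) {
      if (argc < 2) return 1;
      return detect_format(argv[1]) == ModelFormat::GGUF ? 0 : 2;
    }

Treating every non-GGUF magic as NE mirrors the fallback in the patch, which keeps previously converted NE files loadable without an explicit format flag.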