diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d3e4651c749e5..c1e36ee28d202 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,10 +10,10 @@ on: push: branches: - master - paths: ['.github/workflows/**', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu'] + paths: ['.github/workflows/**', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift'] pull_request: types: [opened, synchronize, reopened] - paths: ['**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu'] + paths: ['**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift'] env: BRANCH_NAME: ${{ github.head_ref || github.ref_name }} @@ -258,7 +258,7 @@ jobs: strategy: matrix: - destination: ['platform=macOS,name=Any Mac', 'platform=iOS,name=Any iOS Device', 'platform=tvOS,name=Any tvOS Device'] + destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS'] steps: - name: Clone diff --git a/.gitignore b/.gitignore index 57209d3d3bf95..8eac3accec189 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ *.gcno *.gcda *.dot +*.metallib .DS_Store .build/ .cache/ diff --git a/Package.swift b/Package.swift index 3ee3b2a209f75..1ea414cc149fd 100644 --- a/Package.swift +++ b/Package.swift @@ -10,15 +10,18 @@ let platforms: [SupportedPlatform]? = [ .tvOS(.v14) ] let exclude: [String] = [] -let additionalSources: [String] = ["ggml-metal.m", "ggml-metal.metal"] +let resources: [Resource] = [ + .process("ggml-metal.metal") +] +let additionalSources: [String] = ["ggml-metal.m"] let additionalSettings: [CSetting] = [ .unsafeFlags(["-fno-objc-arc"]), - .define("GGML_SWIFT"), .define("GGML_USE_METAL") ] #else let platforms: [SupportedPlatform]? = nil let exclude: [String] = ["ggml-metal.metal"] +let resources: [Resource] = [] let additionalSources: [String] = [] let additionalSettings: [CSetting] = [] #endif @@ -40,6 +43,7 @@ let package = Package( "ggml-alloc.c", "k_quants.c", ] + additionalSources, + resources: resources, publicHeadersPath: "spm-headers", cSettings: [ .unsafeFlags(["-Wno-shorten-64-to-32"]), diff --git a/README.md b/README.md index e436818fa92c4..0562795620e69 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,7 @@ as the main playground for developing new features for the [ggml](https://github - [X] [Aquila-7B](https://huggingface.co/BAAI/Aquila-7B) / [AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) - [X] [Starcoder models](https://github.com/ggerganov/llama.cpp/pull/3187) - [X] [Mistral AI v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) +- [X] [Refact](https://huggingface.co/smallcloudai/Refact-1_6B-fim) **Bindings:** @@ -377,7 +378,7 @@ Building the program with BLAS support may lead to some performance improvements - #### cuBLAS - This provides BLAS acceleration using the CUDA cores of your Nvidia GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). + This provides BLAS acceleration using the CUDA cores of your Nvidia GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager (e.g. `apt install nvidia-cuda-toolkit`) or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). - Using `make`: ```bash make LLAMA_CUBLAS=1 @@ -613,6 +614,18 @@ For more information, see [https://huggingface.co/docs/transformers/perplexity]( The perplexity measurements in table above are done against the `wikitext2` test dataset (https://paperswithcode.com/dataset/wikitext-2), with context length of 512. The time per token is measured on a MacBook M1 Pro 32GB RAM using 4 and 8 threads. +#### How to run + +1. Download/extract: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research +2. Run `./perplexity -m models/7B/ggml-model-q4_0.gguf -f wiki.test.raw` +3. Output: +``` +perplexity : calculating perplexity over 655 chunks +24.43 seconds per pass - ETA 4.45 hours +[1]4.5970,[2]5.1807,[3]6.0382,... +``` +And after 4.45 hours, you will have the final perplexity. + ### Interactive mode If you want a more ChatGPT-like experience, you can run in interactive mode by passing `-i` as a parameter. @@ -775,18 +788,6 @@ If your issue is with model generation quality, then please at least scan the fo - [Aligning language models to follow instructions](https://openai.com/research/instruction-following) - [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) -#### How to run - -1. Download/extract: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research -2. Run `./perplexity -m models/7B/ggml-model-q4_0.gguf -f wiki.test.raw` -3. Output: -``` -perplexity : calculating perplexity over 655 chunks -24.43 seconds per pass - ETA 4.45 hours -[1]4.5970,[2]5.1807,[3]6.0382,... -``` -And after 4.45 hours, you will have the final perplexity. - ### Android #### Building the Project using Android NDK diff --git a/common/common.cpp b/common/common.cpp index 6b9b4695ca288..60b00b5fbb8f1 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -167,6 +167,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } + // store the external file name in params + params.prompt_file = argv[i]; std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); if (params.prompt.back() == '\n') { params.prompt.pop_back(); @@ -1020,10 +1022,11 @@ llama_token llama_sample_token( id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else { // Temperature sampling - llama_sample_top_k (ctx, &cur_p, top_k, 1); - llama_sample_tail_free (ctx, &cur_p, tfs_z, 1); - llama_sample_typical (ctx, &cur_p, typical_p, 1); - llama_sample_top_p (ctx, &cur_p, top_p, 1); + size_t min_keep = std::max(1, params.n_probs); + llama_sample_top_k (ctx, &cur_p, top_k, min_keep); + llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep); + llama_sample_typical (ctx, &cur_p, typical_p, min_keep); + llama_sample_top_p (ctx, &cur_p, top_p, min_keep); llama_sample_temp(ctx, &cur_p, temp); { diff --git a/common/common.h b/common/common.h index e095c56e309c2..c802152791797 100644 --- a/common/common.h +++ b/common/common.h @@ -79,6 +79,7 @@ struct gpt_params { std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias std::string prompt = ""; + std::string prompt_file = ""; // store the external prompt file name std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with diff --git a/convert-persimmon-to-gguf.py b/convert-persimmon-to-gguf.py new file mode 100644 index 0000000000000..e022ffe46189e --- /dev/null +++ b/convert-persimmon-to-gguf.py @@ -0,0 +1,130 @@ +import torch +import os +from pprint import pprint +import sys +import argparse +from pathlib import Path +from sentencepiece import SentencePieceProcessor +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf')) +import gguf + +def _flatten_dict(dct, tensors, prefix=None): + assert isinstance(dct, dict) + for key in dct.keys(): + new_prefix = prefix + '.' + key if prefix is not None else key + if isinstance(dct[key], torch.Tensor): + tensors[new_prefix] = dct[key] + elif isinstance(dct[key], dict): + _flatten_dict(dct[key], tensors, new_prefix) + else: + raise ValueError(type(dct[key])) + return None + +def _get_sentencepiece_tokenizer_info(dir_model: Path): + tokenizer_path = dir_model / 'adept_vocab.model' + print('gguf: getting sentencepiece tokenizer from', tokenizer_path) + tokenizer = SentencePieceProcessor(str(tokenizer_path)) + print('gguf: adding tokens') + tokens: list[bytes] = [] + scores: list[float] = [] + toktypes: list[int] = [] + + for i in range(tokenizer.vocab_size()): + text: bytes + score: float + + piece = tokenizer.id_to_piece(i) + text = piece.encode("utf-8") + score = tokenizer.get_score(i) + + toktype = 1 + if tokenizer.is_unknown(i): + toktype = 2 + if tokenizer.is_control(i): + toktype = 3 + if tokenizer.is_unused(i): + toktype = 5 + if tokenizer.is_byte(i): + toktype = 6 + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + pass + return tokens, scores, toktypes + +def main(): + parser = argparse.ArgumentParser(description="Convert a Persimmon model from Adept (e.g. Persimmon 8b chat) to a GGML compatible file") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("--ckpt-path", type=Path, help="path to persimmon checkpoint .pt file") + parser.add_argument("--model-dir", type=Path, help="directory containing model e.g. 8b_chat_model_release") + parser.add_argument("--adept-inference-dir", type=str, help="path to adept-inference code directory") + args = parser.parse_args() + sys.path.append(str(args.adept_inference_dir)) + persimmon_model = torch.load(args.ckpt_path) + hparams = persimmon_model['args'] + pprint(hparams) + tensors = {} + _flatten_dict(persimmon_model['model'], tensors, None) + + arch = gguf.MODEL_ARCH.PERSIMMON + gguf_writer = gguf.GGUFWriter(args.outfile, gguf.MODEL_ARCH_NAMES[arch]) + + block_count = hparams.num_layers + head_count = hparams.num_attention_heads + head_count_kv = head_count + ctx_length = hparams.seq_length + hidden_size = hparams.hidden_size + + gguf_writer.add_name('persimmon-8b-chat') + gguf_writer.add_context_length(ctx_length) + gguf_writer.add_embedding_length(hidden_size) + gguf_writer.add_block_count(block_count) + gguf_writer.add_feed_forward_length(hparams.ffn_hidden_size) + gguf_writer.add_rope_dimension_count(hidden_size // head_count) + gguf_writer.add_head_count(head_count) + gguf_writer.add_head_count_kv(head_count_kv) + gguf_writer.add_rope_freq_base(hparams.rotary_emb_base) + gguf_writer.add_layer_norm_eps(hparams.layernorm_epsilon) + + tokens, scores, toktypes = _get_sentencepiece_tokenizer_info(args.model_dir) + gguf_writer.add_tokenizer_model('llama') + gguf_writer.add_token_list(tokens) + gguf_writer.add_token_scores(scores) + gguf_writer.add_token_types(toktypes) + gguf_writer.add_bos_token_id(71013) + gguf_writer.add_eos_token_id(71013) + + tensor_map = gguf.get_tensor_name_map(arch, block_count) + print(tensor_map) + for name in tensors.keys(): + data = tensors[name] + if name.endswith(".self_attention.rotary_emb.inv_freq"): + continue + old_dtype = data.dtype + # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?) + data = data.to(torch.float32).squeeze().numpy() + new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias")) + if new_name is None: + print("Can not map tensor '" + name + "'") + sys.exit() + n_dims = len(data.shape) + print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype)) + gguf_writer.add_tensor(new_name, data) + print("gguf: write header") + gguf_writer.write_header_to_file() + print("gguf: write metadata") + gguf_writer.write_kv_data_to_file() + print("gguf: write tensors") + gguf_writer.write_tensors_to_file() + + gguf_writer.close() + + print(f"gguf: model successfully exported to '{args.outfile}'") + print("") + + + +if __name__ == '__main__': + main() diff --git a/examples/jeopardy/README.md b/examples/jeopardy/README.md index 4c42e3cdbf526..ffa13cbf349b2 100644 --- a/examples/jeopardy/README.md +++ b/examples/jeopardy/README.md @@ -2,7 +2,7 @@ This is pretty much just a straight port of aigoopy/llm-jeopardy/ with an added graph viewer. -The jeopardy test can be used to compare the fact knowledge of different models and compare them to eachother. This is in contrast to some other tests, which test logical deduction, creativity, writing skills, etc. +The jeopardy test can be used to compare the fact knowledge of different models and compare them to each other. This is in contrast to some other tests, which test logical deduction, creativity, writing skills, etc. Step 1: Open jeopardy.sh and modify the following: diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index ffd7b1db4abdd..721888da7de94 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -10,6 +10,7 @@ #include #include #include +#include // trim whitespace from the beginning and end of a string static std::string trim(const std::string & str) { @@ -70,6 +71,26 @@ struct client { std::vector tokens_prev; }; +static void print_date_time() { + std::time_t current_time = std::time(nullptr); + std::tm* local_time = std::localtime(¤t_time); + char buffer[80]; + strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", local_time); + + printf("\n\033[35mrun parameters as at %s\033[0m\n", buffer); +} + +// Define a split string function to ... +static std::vector split_string(const std::string& input, char delimiter) { + std::vector tokens; + std::istringstream stream(input); + std::string token; + while (std::getline(stream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + int main(int argc, char ** argv) { srand(1234); @@ -104,6 +125,23 @@ int main(int argc, char ** argv) { params.logits_all = true; std::tie(model, ctx) = llama_init_from_gpt_params(params); + // load the prompts from an external file if there are any + if (params.prompt.empty()) { + printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); + } else { + // Output each line of the input params.prompts vector and copy to k_prompts + int index = 0; + printf("\n\033[32mNow printing the external prompt file %s\033[0m\n\n", params.prompt_file.c_str()); + + std::vector prompts = split_string(params.prompt, '\n'); + for (const auto& prompt : prompts) { + k_prompts.resize(index + 1); + k_prompts[index] = prompt; + index++; + printf("%3d prompt: %s\n", index, prompt.c_str()); + } + } + fprintf(stderr, "\n\n"); fflush(stderr); @@ -233,7 +271,7 @@ int main(int argc, char ** argv) { client.n_decoded = 0; client.i_batch = batch.n_tokens - 1; - LOG_TEE("\033[1mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); + LOG_TEE("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); g_seq_id += 1; @@ -336,8 +374,8 @@ int main(int argc, char ** argv) { const auto t_main_end = ggml_time_us(); - LOG_TEE("\033[1mClient %3d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \n\nInput: %s\nResponse: %s\n\n", - client.id, client.seq_id, client.n_prompt, client.n_decoded, + LOG_TEE("\033[31mClient %3d, seq %3d/%3d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \nInput: %s\n\033[35mResponse: %s\033[0m\n\n", + client.id, client.seq_id, n_seq, client.n_prompt, client.n_decoded, (t_main_end - client.t_start_prompt) / 1e6, (double) (client.n_prompt + client.n_decoded) / (t_main_end - client.t_start_prompt) * 1e6, n_cache_miss, @@ -357,13 +395,21 @@ int main(int argc, char ** argv) { const auto t_main_end = ggml_time_us(); - LOG_TEE("\n\n"); + print_date_time(); + + LOG_TEE("\n%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system); + if (params.prompt_file.empty()) { + params.prompt_file = "used built-in defaults"; + } + LOG_TEE("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str()); + LOG_TEE("Model and path used: \033[32m%s\033[0m\n\n", params.model.c_str()); + LOG_TEE("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6); LOG_TEE("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6); LOG_TEE("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6); LOG_TEE("Cache misses: %6d\n", n_cache_miss); - LOG_TEE("\n\n"); + LOG_TEE("\n"); llama_print_timings(ctx); diff --git a/examples/server/README.md b/examples/server/README.md index 9ee62d06aeab6..8a079ae261cfa 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -114,9 +114,9 @@ node index.js `top_k`: Limit the next token selection to the K most probable tokens (default: 40). - `top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9). + `top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.95). - `n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. (default: 128, -1 = infinity). + `n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. (default: -1, -1 = infinity). `n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context. By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt. @@ -156,6 +156,8 @@ node index.js `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced (default: []). + `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0) + - **POST** `/tokenize`: Tokenize a given text. *Options:* diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5f9cdecd548e2..c53a64867336f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -534,98 +534,20 @@ struct llama_server_context return result; } - // out of user input, sample next token - const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(model) : params.top_k; - const float top_p = params.top_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; - const float repeat_penalty = params.repeat_penalty; - const float alpha_presence = params.presence_penalty; - const float alpha_frequency = params.frequency_penalty; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; - const bool penalize_nl = params.penalize_nl; - const int32_t n_probs = params.n_probs; - { - auto *logits = llama_get_logits(ctx); - auto n_vocab = llama_n_vocab(model); - - // Apply params.logit_bias map - for (const auto &it : params.logit_bias) - { - logits[it.first] += it.second; - } - + // out of user input, sample next token std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) - { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } + candidates.reserve(llama_n_vocab(model)); - llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; - - // Apply penalties - float nl_logit = logits[llama_token_nl(ctx)]; - auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); - llama_sample_repetition_penalty(ctx, &candidates_p, - last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - last_n_repeat, repeat_penalty); - llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, - last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - last_n_repeat, alpha_frequency, alpha_presence); - if (!penalize_nl) - { - logits[llama_token_nl(ctx)] = nl_logit; - } + result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates); - if (grammar != nullptr) { - llama_sample_grammar(ctx, &candidates_p, grammar); - } + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - if (temp <= 0) - { - // Greedy sampling - result.tok = llama_sample_token_greedy(ctx, &candidates_p); - if (n_probs > 0) - { - llama_sample_softmax(ctx, &candidates_p); - } - } - else + const int32_t n_probs = params.n_probs; + if (params.temp <= 0 && n_probs > 0) { - if (mirostat == 1) - { - static float mirostat_mu = 2.0f * mirostat_tau; - const int mirostat_m = 100; - llama_sample_temp(ctx, &candidates_p, temp); - result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); - } - else if (mirostat == 2) - { - static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temp(ctx, &candidates_p, temp); - result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); - } - else - { - // Temperature sampling - size_t min_keep = std::max(1, n_probs); - llama_sample_top_k(ctx, &candidates_p, top_k, min_keep); - llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep); - llama_sample_typical(ctx, &candidates_p, typical_p, min_keep); - llama_sample_top_p(ctx, &candidates_p, top_p, min_keep); - llama_sample_temp(ctx, &candidates_p, temp); - result.tok = llama_sample_token(ctx, &candidates_p); - } - } - - if (grammar != nullptr) { - llama_grammar_accept_token(ctx, grammar, result.tok); + // For llama_sample_token_greedy we need to sort candidates + llama_sample_softmax(ctx, &candidates_p); } for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i) diff --git a/ggml-metal.m b/ggml-metal.m index 866fed43442a8..f8fa05dd9f4a7 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -109,6 +109,8 @@ GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f16_f16); + GGML_METAL_DECL_KERNEL(concat); + GGML_METAL_DECL_KERNEL(sqr); #undef GGML_METAL_DECL_KERNEL }; @@ -183,56 +185,44 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); -#ifdef GGML_SWIFT - // load the default.metallib file + // load library { - NSError * error = nil; - - NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; - NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"]; - NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath]; - NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"]; - NSURL * libURL = [NSURL fileURLWithPath:libPath]; - - // Load the metallib file into a Metal library - ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; - - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - } + NSBundle * bundle = nil; +#ifdef SWIFT_PACKAGE + bundle = SWIFTPM_MODULE_BUNDLE; #else - UNUSED(msl_library_source); - - // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource - { + bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; +#endif NSError * error = nil; + NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"]; + if (libPath != nil) { + NSURL * libURL = [NSURL fileURLWithPath:libPath]; + GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); + ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + } else { + GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); + + NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]); + NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } - //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; - NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; - NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; - GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]); - - NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - + MTLCompileOptions* options = nil; #ifdef GGML_QKK_64 - MTLCompileOptions* options = [MTLCompileOptions new]; - options.preprocessorMacros = @{ @"QK_K" : @(64) }; - ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; -#else - ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error]; + options = [MTLCompileOptions new]; + options.preprocessorMacros = @{ @"QK_K" : @(64) }; #endif + ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + } + if (error) { GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; } } -#endif // load kernels { @@ -300,6 +290,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f16_f16); + GGML_METAL_ADD_KERNEL(concat); + GGML_METAL_ADD_KERNEL(sqr); #undef GGML_METAL_ADD_KERNEL } @@ -375,6 +367,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f16_f16); + GGML_METAL_DEL_KERNEL(concat); + GGML_METAL_DEL_KERNEL(sqr); #undef GGML_METAL_DEL_KERNEL @@ -431,7 +425,7 @@ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { for (int i = 0; i < ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; - //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); + //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { *offs = (size_t) ioffs; @@ -766,6 +760,43 @@ void ggml_metal_graph_compute( { // noop } break; + case GGML_OP_CONCAT: + { + + int64_t nb = ne00; + [encoder setComputePipelineState:ctx->pipeline_concat]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + + const int nth = MIN(1024, ne0); + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_ADD: { GGML_ASSERT(ggml_is_contiguous(src0)); @@ -903,6 +934,17 @@ void ggml_metal_graph_compute( GGML_ASSERT(false); } } break; + case GGML_OP_SQR: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + [encoder setComputePipelineState:ctx->pipeline_sqr]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_SOFT_MAX: { const int nth = MIN(32, ne00); diff --git a/ggml-metal.metal b/ggml-metal.metal index 5a860098f157c..9bd94e82bfe1e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -132,6 +132,13 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + constant float GELU_COEF_A = 0.044715f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -1098,6 +1105,62 @@ kernel void kernel_cpy_f32_f32( } } +kernel void kernel_concat( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i02 < ne02) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; + src0_ptr += ntg.x*nb00; + } else { + ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; + src1_ptr += ntg.x*nb10; + } + dst_ptr += ntg.x*nb0; + } +} + //============================================ k-quants ====================================================== #ifndef QK_K diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index a2c570d7ebf1b..fb677a6ed7283 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -85,6 +85,7 @@ class MODEL_ARCH(IntEnum): GPTNEOX : int = auto() MPT : int = auto() STARCODER : int = auto() + PERSIMMON : int = auto() REFACT : int = auto() BERT : int = auto() @@ -108,6 +109,8 @@ class MODEL_TENSOR(IntEnum): FFN_DOWN : int = auto() FFN_UP : int = auto() FFN_NORM : int = auto() + ATTN_Q_NORM : int = auto() + ATTN_K_NORM : int = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -119,6 +122,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GPTNEOX: "gptneox", MODEL_ARCH.MPT: "mpt", MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.PERSIMMON: "persimmon", MODEL_ARCH.REFACT: "refact", MODEL_ARCH.BERT: "bert", } @@ -130,7 +134,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.ROPE_FREQS: "rope_freqs", - MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", @@ -139,6 +142,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", + MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", @@ -249,6 +254,20 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PERSIMMON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], MODEL_ARCH.REFACT: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -279,6 +298,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.PERSIMMON: [ + MODEL_TENSOR.ROPE_FREQS, + ] } @@ -286,12 +308,13 @@ class TensorNameMap: mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { # Token embeddings MODEL_TENSOR.TOKEN_EMBD: ( - "gpt_neox.embed_in", # gptneox - "transformer.wte", # gpt2 gpt-j mpt refact - "transformer.word_embeddings", # falcon - "model.embed_tokens", # llama-hf - "tok_embeddings", # llama-pth - "embeddings.word_embeddings", # bert + "gpt_neox.embed_in", # gptneox + "transformer.wte", # gpt2 gpt-j mpt refact + "transformer.word_embeddings", # falcon + "model.embed_tokens", # llama-hf + "tok_embeddings", # llama-pth + "embeddings.word_embeddings", # bert + "language_model.embedding.word_embeddings", # persimmon ), # Token type embeddings @@ -307,20 +330,22 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( - "embed_out", # gptneox - "lm_head", # gpt2 gpt-j mpt falcon llama-hf baichuan - "output", # llama-pth + "embed_out", # gptneox + "lm_head", # gpt2 mpt falcon llama-hf baichuan + "output", # llama-pth + "word_embeddings_for_head", # persimmon ), # Output norm MODEL_TENSOR.OUTPUT_NORM: ( - "gpt_neox.final_layer_norm", # gptneox - "transformer.ln_f", # gpt2 gpt-j falcon - "model.norm", # llama-hf baichuan - "norm", # llama-pth - "embeddings.LayerNorm", # bert - "transformer.norm_f", # mpt - "ln_f", # refact + "gpt_neox.final_layer_norm", # gptneox + "transformer.ln_f", # gpt2 gpt-j falcon + "model.norm", # llama-hf baichuan + "norm", # llama-pth + "embeddings.LayerNorm", # bert + "transformer.norm_f", # mpt + "ln_f", # refact + "language_model.encoder.final_layernorm", # persimmon ), # Rope frequencies @@ -332,14 +357,15 @@ class TensorNameMap: block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { # Attention norm MODEL_TENSOR.ATTN_NORM: ( - "gpt_neox.layers.{bid}.input_layernorm", # gptneox - "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact - "transformer.blocks.{bid}.norm_1", # mpt - "transformer.h.{bid}.input_layernorm", # falcon7b - "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf - "layers.{bid}.attention_norm", # llama-pth - "encoder.layer.{bid}.attention.output.LayerNorm", # bert + "gpt_neox.layers.{bid}.input_layernorm", # gptneox + "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact + "transformer.blocks.{bid}.norm_1", # mpt + "transformer.h.{bid}.input_layernorm", # falcon7b + "transformer.h.{bid}.ln_mlp", # falcon40b + "model.layers.{bid}.input_layernorm", # llama-hf + "layers.{bid}.attention_norm", # llama-pth + "encoder.layer.{bid}.attention.output.LayerNorm", # bert + "language_model.encoder.layers.{bid}.input_layernorm", # persimmon ), # Attention norm 2 @@ -349,10 +375,11 @@ class TensorNameMap: # Attention query-key-value MODEL_TENSOR.ATTN_QKV: ( - "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox - "transformer.h.{bid}.attn.c_attn", # gpt2 - "transformer.blocks.{bid}.attn.Wqkv", # mpt - "transformer.h.{bid}.self_attention.query_key_value", # falcon + "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox + "transformer.h.{bid}.attn.c_attn", # gpt2 + "transformer.blocks.{bid}.attn.Wqkv", # mpt + "transformer.h.{bid}.self_attention.query_key_value", # falcon + "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon ), # Attention query @@ -381,14 +408,15 @@ class TensorNameMap: # Attention output MODEL_TENSOR.ATTN_OUT: ( - "gpt_neox.layers.{bid}.attention.dense", # gptneox - "transformer.h.{bid}.attn.c_proj", # gpt2 refact - "transformer.blocks.{bid}.attn.out_proj", # mpt - "transformer.h.{bid}.self_attention.dense", # falcon - "model.layers.{bid}.self_attn.o_proj", # llama-hf - "layers.{bid}.attention.wo", # llama-pth - "encoder.layer.{bid}.attention.output.dense", # bert - "transformer.h.{bid}.attn.out_proj", # gpt-j + "gpt_neox.layers.{bid}.attention.dense", # gptneox + "transformer.h.{bid}.attn.c_proj", # gpt2 refact + "transformer.blocks.{bid}.attn.out_proj", # mpt + "transformer.h.{bid}.self_attention.dense", # falcon + "model.layers.{bid}.self_attn.o_proj", # llama-hf + "layers.{bid}.attention.wo", # llama-pth + "encoder.layer.{bid}.attention.output.dense", # bert + "transformer.h.{bid}.attn.out_proj", # gpt-j + "language_model.encoder.layers.{bid}.self_attention.dense" # persimmon ), # Rotary embeddings @@ -399,24 +427,26 @@ class TensorNameMap: # Feed-forward norm MODEL_TENSOR.FFN_NORM: ( - "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox - "transformer.h.{bid}.ln_2", # gpt2 refact - "transformer.blocks.{bid}.norm_2", # mpt - "model.layers.{bid}.post_attention_layernorm", # llama-hf - "layers.{bid}.ffn_norm", # llama-pth - "encoder.layer.{bid}.output.LayerNorm", # bert + "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox + "transformer.h.{bid}.ln_2", # gpt2 refact + "transformer.blocks.{bid}.norm_2", # mpt + "model.layers.{bid}.post_attention_layernorm", # llama-hf + "layers.{bid}.ffn_norm", # llama-pth + "encoder.layer.{bid}.output.LayerNorm", # bert + "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon ), # Feed-forward up MODEL_TENSOR.FFN_UP: ( - "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox - "transformer.h.{bid}.mlp.c_fc", # gpt2 - "transformer.blocks.{bid}.ffn.up_proj", # mpt - "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon - "model.layers.{bid}.mlp.up_proj", # llama-hf refact - "layers.{bid}.feed_forward.w3", # llama-pth - "encoder.layer.{bid}.intermediate.dense", # bert - "transformer.h.{bid}.mlp.fc_in", # gpt-j + "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox + "transformer.h.{bid}.mlp.c_fc", # gpt2 + "transformer.blocks.{bid}.ffn.up_proj", # mpt + "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon + "model.layers.{bid}.mlp.up_proj", # llama-hf refact + "layers.{bid}.feed_forward.w3", # llama-pth + "encoder.layer.{bid}.intermediate.dense", # bert + "transformer.h.{bid}.mlp.fc_in", # gpt-j + "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon ), # Feed-forward gate @@ -427,15 +457,28 @@ class TensorNameMap: # Feed-forward down MODEL_TENSOR.FFN_DOWN: ( - "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox - "transformer.h.{bid}.mlp.c_proj", # gpt2 refact - "transformer.blocks.{bid}.ffn.down_proj", # mpt - "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon - "model.layers.{bid}.mlp.down_proj", # llama-hf - "layers.{bid}.feed_forward.w2", # llama-pth - "encoder.layer.{bid}.output.dense", # bert - "transformer.h.{bid}.mlp.fc_out", # gpt-j + "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox + "transformer.h.{bid}.mlp.c_proj", # gpt2 refact + "transformer.blocks.{bid}.ffn.down_proj", # mpt + "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon + "model.layers.{bid}.mlp.down_proj", # llama-hf + "layers.{bid}.feed_forward.w2", # llama-pth + "encoder.layer.{bid}.output.dense", # bert + "transformer.h.{bid}.mlp.fc_out", # gpt-j + "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon + ), + + MODEL_TENSOR.ATTN_Q_NORM: ( + "language_model.encoder.layers.{bid}.self_attention.q_layernorm", ), + + MODEL_TENSOR.ATTN_K_NORM: ( + "language_model.encoder.layers.{bid}.self_attention.k_layernorm", + ), + + MODEL_TENSOR.ROPE_FREQS: ( + "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon + ) } mapping: dict[str, tuple[MODEL_TENSOR, str]] diff --git a/llama.cpp b/llama.cpp index 40e4ac586e7f3..2e868cfd79c8d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -125,6 +125,27 @@ static void replace_all(std::string & s, const std::string & search, const std:: } s = std::move(result); } + +static bool is_float_close(float a, float b, float abs_tol) { + // Check for non-negative tolerance + if (abs_tol < 0.0) { + throw std::invalid_argument("Tolerance must be non-negative"); + } + + // Exact equality check + if (a == b) { + return true; + } + + // Check for infinities + if (std::isinf(a) || std::isinf(b)) { + return false; + } + + // Regular comparison using the provided absolute tolerance + return std::fabs(b - a) <= abs_tol; +} + #ifdef GGML_USE_CPU_HBM #include #endif @@ -165,6 +186,7 @@ enum llm_arch { LLM_ARCH_GPTNEOX, LLM_ARCH_MPT, LLM_ARCH_STARCODER, + LLM_ARCH_PERSIMMON, LLM_ARCH_REFACT, LLM_ARCH_UNKNOWN, }; @@ -178,6 +200,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_MPT, "mpt" }, { LLM_ARCH_BAICHUAN, "baichuan" }, { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_PERSIMMON, "persimmon" }, { LLM_ARCH_REFACT, "refact" }, }; @@ -297,6 +320,8 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, LLM_TENSOR_FFN_NORM, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, }; static std::map> LLM_TENSOR_NAMES = { @@ -378,6 +403,23 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PERSIMMON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd"}, + { LLM_TENSOR_OUTPUT_NORM, "output_norm"}, + { LLM_TENSOR_OUTPUT, "output"}, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm"}, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv"}, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output"}, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"}, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"}, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm"}, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down"}, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up"}, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd"}, + }, + }, { LLM_ARCH_MPT, { @@ -938,6 +980,7 @@ enum e_model { MODEL_1B, MODEL_3B, MODEL_7B, + MODEL_8B, MODEL_13B, MODEL_15B, MODEL_30B, @@ -969,7 +1012,24 @@ struct llama_hparams { float rope_freq_scale_train; bool operator!=(const llama_hparams & other) const { - return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT + if (this->vocab_only != other.vocab_only) return true; + if (this->n_vocab != other.n_vocab) return true; + if (this->n_ctx_train != other.n_ctx_train) return true; + if (this->n_embd != other.n_embd) return true; + if (this->n_head != other.n_head) return true; + if (this->n_head_kv != other.n_head_kv) return true; + if (this->n_layer != other.n_layer) return true; + if (this->n_rot != other.n_rot) return true; + if (this->n_ff != other.n_ff) return true; + + const float EPSILON = 1e-9; + + if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; + if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; + if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; + if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; + + return false; } uint32_t n_gqa() const { @@ -1003,6 +1063,10 @@ struct llama_layer { struct ggml_tensor * attn_norm_b; struct ggml_tensor * attn_norm_2; struct ggml_tensor * attn_norm_2_b; + struct ggml_tensor * attn_q_norm; + struct ggml_tensor * attn_q_norm_b; + struct ggml_tensor * attn_k_norm; + struct ggml_tensor * attn_k_norm_b; // attention struct ggml_tensor * wq; @@ -1044,6 +1108,9 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; + // Note: The value of head isn't only used to optimize searching + // for a free KV slot. llama_decode_internal also uses it, so it + // cannot be freely changed after a slot has been allocated. uint32_t head = 0; uint32_t size = 0; @@ -1301,6 +1368,8 @@ static bool llama_kv_cache_init( // find an empty slot of size "n_tokens" in the cache // updates the cache head +// Note: On success, it's important that cache.head points +// to the first cell of the slot. static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, const struct llama_batch & batch) { @@ -1316,8 +1385,8 @@ static bool llama_kv_cache_find_slot( while (true) { if (cache.head + n_tokens > n_ctx) { + n_tested += n_ctx - cache.head; cache.head = 0; - n_tested += n_ctx - cache.head; continue; } @@ -1368,6 +1437,9 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); } + + // Searching for a free slot can start here since we know it will be empty. + cache.head = uint32_t(c0); } static void llama_kv_cache_seq_rm( @@ -1375,6 +1447,8 @@ static void llama_kv_cache_seq_rm( llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = cache.size; + if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); @@ -1383,9 +1457,13 @@ static void llama_kv_cache_seq_rm( cache.cells[i].seq_id.erase(seq_id); if (cache.cells[i].seq_id.empty()) { cache.cells[i].pos = -1; + if (new_head == cache.size) new_head = i; } } } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size) cache.head = new_head; } static void llama_kv_cache_seq_cp( @@ -1397,6 +1475,8 @@ static void llama_kv_cache_seq_cp( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + cache.head = 0; + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.cells[i].seq_id.insert(seq_id_dst); @@ -1405,12 +1485,18 @@ static void llama_kv_cache_seq_cp( } static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { + uint32_t new_head = cache.size; + for (uint32_t i = 0; i < cache.size; ++i) { if (!cache.cells[i].has_seq_id(seq_id)) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); + if (new_head == cache.size) new_head = i; } } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size) cache.head = new_head; } static void llama_kv_cache_seq_shift( @@ -1419,6 +1505,8 @@ static void llama_kv_cache_seq_shift( llama_pos p0, llama_pos p1, llama_pos delta) { + uint32_t new_head = cache.size; + if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); @@ -1428,12 +1516,17 @@ static void llama_kv_cache_seq_shift( if (cache.cells[i].pos < 0) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); + if (new_head == cache.size) new_head = i; } else { cache.has_shift = true; cache.cells[i].delta = delta; } } } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.head = new_head != cache.size ? new_head : 0; } // @@ -1834,6 +1927,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_1B: return "1B"; case MODEL_3B: return "3B"; case MODEL_7B: return "7B"; + case MODEL_8B: return "8B"; case MODEL_13B: return "13B"; case MODEL_15B: return "15B"; case MODEL_30B: return "30B"; @@ -1946,6 +2040,14 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_PERSIMMON: + { + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + switch (hparams.n_layer) { + case 36: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } case LLM_ARCH_REFACT: { GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); @@ -2482,6 +2584,67 @@ static void llm_load_tensors( } } } break; + case LLM_ARCH_PERSIMMON: + { + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + { + ggml_backend backend_norm; + ggml_backend backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + const int i_gpu_start = n_layer - n_gpu_layers; + model.layers.resize(n_layer); + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; + auto & layer = model.layers[i]; + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + layer.attn_q_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {64}, backend); + layer.attn_q_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}, backend); + layer.attn_k_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {64}, backend); + layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -2591,8 +2754,8 @@ static bool llama_model_load( } static struct ggml_cgraph * llm_build_llama( - llama_context & lctx, - const llama_batch & batch) { + llama_context & lctx, + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; @@ -2630,11 +2793,9 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; - struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -3018,11 +3179,9 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; - struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -3419,11 +3578,9 @@ static struct ggml_cgraph * llm_build_refact( struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; - struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -3773,11 +3930,9 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; - struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -4133,11 +4288,9 @@ static struct ggml_cgraph * llm_build_starcoder( struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; - struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -4348,6 +4501,404 @@ static struct ggml_cgraph * llm_build_starcoder( return gf; } + +static struct ggml_cgraph * llm_build_persimmon( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const auto & cparams = lctx.cparams; + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_head = hparams.n_head; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + const size_t n_rot = n_embd_head / 2; + + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; + const float norm_eps = hparams.f_norm_eps; + + const int n_gpu_layers = model.n_gpu_layers; + + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + + auto & buf_compute = lctx.buf_compute; + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); + } + } + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); + } + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + // we rotate only the first n_rot dimensions. + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_rot, n_head, n_ctx, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il) + ), + K_shift, n_rot, 2, 0, freq_base, freq_scale); + offload_func_kq(tmp); + ggml_build_forward_expand(gf, tmp); + } + } + for (int il=0; il < n_layer; ++il) { + struct ggml_tensor * residual = inpL; + offload_func_t offload_func = llama_nop; + { + cur = ggml_norm(ctx0, inpL, norm_eps); + offload_func(cur); + cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); + offload_func(cur); + cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b); + offload_func(cur); + ggml_format_name(cur, "input_layernorm_%d", il); + } + // self attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + offload_func_kq(cur); + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + offload_func_kq(cur); + + // split qkv + GGML_ASSERT(n_head_kv == n_head); + ggml_set_name(cur, format("qkv_%d", il).c_str()); + struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens); + offload_func_kq(tmpqkv); + struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2)); + offload_func_kq(tmpqkv_perm); + ggml_format_name(tmpqkv_perm, "tmpqkv_perm_%d", il); + struct ggml_tensor * tmpq = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + 0 + ); + offload_func_kq(tmpq); + struct ggml_tensor * tmpk = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens + ); + offload_func_kq(tmpk); + // Q/K Layernorm + tmpq = ggml_norm(ctx0, tmpq, norm_eps); + offload_func_kq(tmpq); + tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm); + offload_func_kq(tmpq); + tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b); + offload_func_kq(tmpq); + + tmpk = ggml_norm(ctx0, tmpk, norm_eps); + offload_func_v(tmpk); + tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm); + offload_func_v(tmpk); + tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b); + offload_func_v(tmpk); + + // RoPE the first n_rot of q/k, pass the other half, and concat. + struct ggml_tensor * qrot = ggml_view_3d( + ctx0, tmpq, n_rot, n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + 0 + ); + offload_func_kq(qrot); + ggml_format_name(qrot, "qrot_%d", il); + struct ggml_tensor * krot = ggml_view_3d( + ctx0, tmpk, n_rot, n_head, n_tokens, + ggml_element_size(tmpk) * n_embd_head, + ggml_element_size(tmpk) * n_embd_head * n_head, + 0 + ); + offload_func_kq(krot); + ggml_format_name(krot, "krot_%d", il); + + // get the second half of tmpq, e.g tmpq[n_rot:, :, :] + struct ggml_tensor * qpass = ggml_view_3d( + ctx0, tmpq, n_rot, n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + ggml_element_size(tmpq) * n_rot + ); + offload_func_kq(qpass); + ggml_format_name(qpass, "qpass_%d", il); + struct ggml_tensor * kpass = ggml_view_3d( + ctx0, tmpk, n_rot, n_head, n_tokens, + ggml_element_size(tmpk) * n_embd_head, + ggml_element_size(tmpk) * n_embd_head * n_head, + ggml_element_size(tmpk) * n_rot + ); + offload_func_kq(kpass); + ggml_format_name(kpass, "kpass_%d", il); + + struct ggml_tensor * qrotated = ggml_rope_custom( + ctx0, qrot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale + ); + offload_func_kq(qrotated); + struct ggml_tensor * krotated = ggml_rope_custom( + ctx0, krot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale + ); + offload_func_kq(krotated); + // ggml currently only supports concatenation on dim=2 + // so we need to permute qrot, qpass, concat, then permute back. + qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3)); + offload_func_kq(qrotated); + krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3)); + offload_func_kq(krotated); + + qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3)); + offload_func_kq(qpass); + kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3)); + offload_func_kq(kpass); + + struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass); + offload_func_kq(Qcur); + struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass); + offload_func_kq(Kcur); + + struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3)); + offload_func_kq(Q); + + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3)); + offload_func_kq(Kcur); + { + struct ggml_tensor * tmpv = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2 + ); + offload_func_v(tmpv); + // store K, V in cache + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); + offload_func_v(Vcur); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d( + ctx0, kv_self.k, n_tokens*n_embd_gqa, + (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head) + ); + offload_func_kq(k); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + offload_func_v(v); + ggml_set_name(v, "v"); + + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + + offload_func_kq(K); + ggml_format_name(K, "K_%d", il); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); + ggml_set_name(KQ, "KQ"); + + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + offload_func_kq(KQ_masked); + ggml_set_name(KQ_masked, "KQ_masked"); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + offload_func_kq(KQ_soft_max); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); + ggml_set_name(KQV, "KQV"); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); + ggml_set_name(KQV_merged, "KQV_merged"); + + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + offload_func_v(cur); + ggml_set_name(cur, "KQV_merged_contiguous"); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + offload_func(cur); + cur = ggml_add(ctx0, cur, model.layers[il].bo); + offload_func(cur); + ggml_set_name(cur, "result_wo"); + } + + struct ggml_tensor * inpFF = ggml_add(ctx0, residual, cur); + offload_func(inpFF); + ggml_set_name(inpFF, "inpFF"); + { + // MLP + { + // Norm + cur = ggml_norm(ctx0, inpFF, norm_eps); + offload_func(cur); + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].ffn_norm), + model.layers[il].ffn_norm_b + ); + ggml_set_name(cur, "ffn_norm"); + offload_func(cur); + } + cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); + offload_func(cur); + + cur = ggml_add(ctx0, cur, model.layers[il].b3); + offload_func(cur); + ggml_set_name(cur, "result_ffn_up"); + + cur = ggml_sqr(ctx0, ggml_relu(ctx0, cur)); + ggml_set_name(cur, "result_ffn_act"); + offload_func(cur); + offload_func(cur->src[0]); + + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + offload_func(cur); + cur = ggml_add(ctx0, + cur, + model.layers[il].b2); + offload_func(cur); + ggml_set_name(cur, "outFF"); + } + cur = ggml_add(ctx0, cur, inpFF); + offload_func(cur); + ggml_set_name(cur, "inpFF_+_outFF"); + inpL = cur; + } + cur = inpL; + { + cur = ggml_norm(ctx0, cur, norm_eps); + offload_func_nr(cur); + cur = ggml_mul(ctx0, cur, model.output_norm); + offload_func_nr(cur); + + cur = ggml_add(ctx0, cur, model.output_norm_b); + // offload_func_nr(cur); + + ggml_set_name(cur, "result_norm"); + } + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + ggml_build_forward_expand(gf, cur); + ggml_free(ctx0); + return gf; +} + static struct ggml_cgraph * llama_build_graph( llama_context & lctx, const llama_batch & batch) { @@ -4372,6 +4923,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm_build_starcoder(lctx, batch); } break; + case LLM_ARCH_PERSIMMON: + { + result = llm_build_persimmon(lctx, batch); + } case LLM_ARCH_REFACT: { result = llm_build_refact(lctx, batch); @@ -4454,10 +5009,6 @@ static int llama_decode_internal( batch.seq_id = seq_id.data(); } - // we always start to search for a free slot from the start of the cache - // TODO: better strategies can be implemented - kv_self.head = 0; - if (!llama_kv_cache_find_slot(kv_self, batch)) { return 1; } @@ -4543,8 +5094,12 @@ static int llama_decode_internal( #endif // update the kv ring buffer - lctx.kv_self.head += n_tokens; lctx.kv_self.has_shift = false; + lctx.kv_self.head += n_tokens; + // Ensure kv cache head points to a valid index. + if (lctx.kv_self.head >= lctx.kv_self.size) { + lctx.kv_self.head = 0; + } #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) @@ -6639,6 +7194,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } std::ofstream fout(fname_out, std::ios::binary); + fout.exceptions(std::ofstream::failbit); // fail fast on write errors const size_t meta_size = gguf_get_meta_size(ctx_out); @@ -8166,8 +8722,9 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch buf[0] = llama_token_to_byte(model->vocab, token); return 1; } else { - // let's not crash the program on unknown tokens - //GGML_ASSERT(false); + // TODO: for now we accept all unsupported token types, + // suppressing them like CONTROL tokens. + // GGML_ASSERT(false); } break; } @@ -8183,15 +8740,14 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch } else if (llama_is_control_token(model->vocab, token)) { ; } else { - // let's not crash the program on unknown tokens - //GGML_ASSERT(false); + // TODO: for now we accept all unsupported token types, + // suppressing them like CONTROL tokens. + // GGML_ASSERT(false); } break; } default: - // let's not crash the program on unknown tokens - //GGML_ASSERT(false); - break; + GGML_ASSERT(false); } } return 0; @@ -8218,14 +8774,14 @@ void llama_print_timings(struct llama_context * ctx) { const llama_timings timings = llama_get_timings(ctx); LLAMA_LOG_INFO("\n"); - LLAMA_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings.t_load_ms); - LLAMA_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); + LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample); - LLAMA_LOG_INFO("%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); - LLAMA_LOG_INFO("%s: eval time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval); - LLAMA_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms)); + LLAMA_LOG_INFO("%s: total time = %10.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms)); } void llama_reset_timings(struct llama_context * ctx) { diff --git a/prompts/LLM-questions.txt b/prompts/LLM-questions.txt new file mode 100644 index 0000000000000..fdf3d52f4416a --- /dev/null +++ b/prompts/LLM-questions.txt @@ -0,0 +1,49 @@ +In the context of LLMs, what is "Attention"? +In the context of LLMs, what is a completion? +In the context of LLMs, what is a prompt? +In the context of LLMs, what is GELU? +In the context of LLMs, what is RELU? +In the context of LLMs, what is softmax? +In the context of LLMs, what is decoding? +In the context of LLMs, what is encoding? +In the context of LLMs, what is tokenizing? +In the context of LLMs, what is an embedding? +In the context of LLMs, what is quantization? +In the context of LLMs, what is a tensor? +In the context of LLMs, what is a sparse tensor? +In the context of LLMs, what is a vector? +In the context of LLMs, how is attention implemented? +In the context of LLMs, why is attention all you need? +In the context of LLMs, what is "RoPe" and what is it used for? +In the context of LLMs, what is "LoRA" and what is it used for? +In the context of LLMs, what are weights? +In the context of LLMs, what are biases? +In the context of LLMs, what are checkpoints? +In the context of LLMs, what is "perplexity"? +In the context of LLMs, what are models? +In the context of machine-learning, what is "catastrophic forgetting"? +In the context of machine-learning, what is "elastic weight consolidation (EWC)"? +In the context of neural nets, what is a hidden layer? +In the context of neural nets, what is a convolution? +In the context of neural nets, what is dropout? +In the context of neural nets, what is cross-entropy? +In the context of neural nets, what is over-fitting? +In the context of neural nets, what is under-fitting? +What is the difference between an interpreted computer language and a compiled computer language? +In the context of software development, what is a debugger? +When processing using a GPU, what is off-loading? +When processing using a GPU, what is a batch? +When processing using a GPU, what is a block? +When processing using a GPU, what is the difference between a batch and a block? +When processing using a GPU, what is a scratch tensor? +When processing using a GPU, what is a layer? +When processing using a GPU, what is a cache? +When processing using a GPU, what is unified memory? +When processing using a GPU, what is VRAM? +When processing using a GPU, what is a kernel? +When processing using a GPU, what is "metal"? +In the context of LLMs, what are "Zero-Shot", "One-Shot" and "Few-Shot" learning models? +In the context of LLMs, what is the "Transformer-model" architecture? +In the context of LLMs, what is "Multi-Head Attention"? +In the context of LLMs, what is "Self-Attention"? +In the context of transformer-model architectures, how do attention mechanisms use masks? \ No newline at end of file diff --git a/prompts/parallel-questions.txt b/prompts/parallel-questions.txt new file mode 100644 index 0000000000000..c9fc7b8b48418 --- /dev/null +++ b/prompts/parallel-questions.txt @@ -0,0 +1,43 @@ +What do you know about Hobbits? +What is quantum field theory? +Why did the chicken cross the road? +Who is the president of the United States? +How do I run CMake on MacOS? +Do you agree that C++ is a really finicky language compared with Python3? +Is it a good idea to invest in technology? +Do you like Wagner's Ring? +Do you think this file input option is really neat? +What should we all do about climate change? +Is time-travel possible within the laws of current physics? +Is it like anything to be a bat? +Once the chicken has crossed the road, does it try to go back? +Who is the greatest of all musical composers? +What is art? +Is there life elsewhere in the universe? +What is intelligence? +What is the difference between knowledge and intelligence? +Will religion ever die? +Do we understand ourselves? +What is the best way to cook eggs? +If you cannot see things, on what basis do you evaluate them? +Explain the role of the np junction in photovoltaic cells? +Is professional sport a good or bad influence on human behaviour? +Is capital punishment immoral? +Should we care about other people? +Who are you? +Which sense would you surrender if you could? +Was Henry Ford a hero or a villain? +Do we need leaders? +What is nucleosynthesis? +Who is the greatest scientist of all time? +Who first observed what came to be known as the photovoltaic effect? +What is nuclear fusion and why does it release energy? +Can you know that you exist? +What is an exoplanet? +Do you like cream? +What is the difference? +Can I know that I exist while I'm dreaming that I'm Descartes? +Who said "I didn't know I thought that until I heard myself saying it"? +Does anything really matter? +Can you explain the unreasonable effectiveness of mathematics? + diff --git a/requirements.txt b/requirements.txt index 7dc51edb14395..81c909d0ba7fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -numpy==1.24 +numpy==1.24.4 sentencepiece==0.1.98 gguf>=0.1.0