diff --git a/README.md b/README.md index 1d4e9d41742b7..72b5a2c881961 100644 --- a/README.md +++ b/README.md @@ -1117,7 +1117,9 @@ docker run --gpus all -v /path/to/models:/models local/llama.cpp:server-cuda -m - Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a` - See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions - Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices -- Matrix multiplication is unconventional: [`z = ggml_mul_mat(ctx, x, y)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means `zT = x @ yT` +- Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ + +![matmul](media/matmul.png) ### Docs diff --git a/common/common.cpp b/common/common.cpp index 6b41d1bfdcb45..02367a914c90e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -242,7 +242,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } + // This is temporary, in the future the samplign state will be moved fully to llama_sampling_context. params.seed = std::stoul(argv[i]); + sparams.seed = std::stoul(argv[i]); return true; } if (arg == "-t" || arg == "--threads") { @@ -2359,12 +2361,12 @@ std::vector llama_tokenize( return result; } -std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { +std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), true); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), true); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); diff --git a/common/common.h b/common/common.h index 78a1a34021b60..91f6b924bef81 100644 --- a/common/common.h +++ b/common/common.h @@ -86,8 +86,8 @@ struct gpt_params { ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; - llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; - llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings + enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; + enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings // // sampling parameters struct llama_sampling_params sparams; @@ -238,11 +238,12 @@ std::vector llama_tokenize( bool add_special, bool parse_special = false); -// tokenizes a token into a piece +// tokenizes a token into a piece, optionally renders special/control tokens // should work similar to Python's `tokenizer.id_to_piece` std::string llama_token_to_piece( const struct llama_context * ctx, - llama_token token); + llama_token token, + bool special = true); // TODO: these should be moved in llama.h C-style API under single 
`llama_detokenize` function // that takes into account the tokenizer type and decides how to handle the leading space diff --git a/common/sampling.cpp b/common/sampling.cpp index c5b93d9932ed7..289c0215bf541 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,4 +1,6 @@ +#define LLAMA_API_INTERNAL #include "sampling.h" +#include struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); @@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ result->prev.resize(params.n_prev); + llama_sampling_set_rng_seed(result, params.seed); + return result; } @@ -66,6 +70,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) { ctx->cur.clear(); } +void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { if (dst->grammar) { llama_grammar_free(dst->grammar); @@ -207,7 +218,7 @@ static llama_token llama_sampling_sample_impl( sampler_queue(ctx_main, params, cur_p, min_keep); - id = llama_sample_token(ctx_main, &cur_p); + id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); //{ // const int n_top = 10; diff --git a/common/sampling.h b/common/sampling.h index 85bbd8d00b587..7de531cdf71e1 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -4,9 +4,10 @@ #include "grammar-parser.h" +#include #include -#include #include +#include // sampler types enum class llama_sampler_type : char { @@ -20,25 +21,26 @@ enum class llama_sampler_type : char { // sampling parameters typedef struct llama_sampling_params { - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. 
+ int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool penalize_nl = false; // consider newlines as a repeatable token + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f float dry_base = 1.75f; int dry_allowed_length = 2; @@ -83,6 +85,8 @@ struct llama_sampling_context { // TODO: replace with ring-buffer std::vector prev; std::vector cur; + + std::mt19937 rng; }; #include "common.h" @@ -100,6 +104,9 @@ void llama_sampling_reset_grammar(struct llama_sampling_context * ctx); // - reset grammar void llama_sampling_reset(llama_sampling_context * ctx); +// Set the sampler seed +void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed); + // Copy the sampler context void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 4fd916cba3ed7..5763b6664e832 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -363,6 +363,16 @@ def _set_vocab_sentencepiece(self): scores.append(-1000.0) toktypes.append(SentencePieceTokenTypes.USER_DEFINED) + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + print( + f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]" + ) + for i in range(1, pad_count + 1): + tokens.append(f"[PAD{i}]") + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.UNUSED) + assert len(tokens) == vocab_size self.gguf_writer.add_tokenizer_model("llama") @@ -1789,6 +1799,12 @@ def write_tensors(self): class Qwen2Model(Model): model_arch = gguf.MODEL_ARCH.QWEN2 + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + @Model.register("Qwen2MoeForCausalLM") class Qwen2MoeModel(Model): @@ -1979,6 +1995,91 @@ def set_gguf_parameters(self): self.gguf_writer.add_add_bos_token(False) +@Model.register("Phi3ForCausalLM") +class Phi3MiniModel(Model): + model_arch = gguf.MODEL_ARCH.PHI3 + + def set_vocab(self): + from sentencepiece import SentencePieceProcessor + + tokenizer_path = self.dir_model / 'tokenizer.model' + + if not tokenizer_path.is_file(): + print(f'Error: Missing {tokenizer_path}', file=sys.stderr) + sys.exit(1) + + tokenizer = SentencePieceProcessor(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * 
vocab_size + + for token_id in range(tokenizer.vocab_size()): + + piece = tokenizer.id_to_piece(token_id) + text = piece.encode("utf-8") + score = tokenizer.get_score(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + + for key in added_tokens_json: + token_id = added_tokens_json[key] + if (token_id >= vocab_size): + print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + tokens[token_id] = key.encode("utf-8") + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) + + rot_pct = 1.0 + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + rms_eps = self.find_hparam(["rms_norm_eps"]) + + self.gguf_writer.add_name("Phi3") + self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"])) + + self.gguf_writer.add_embedding_length(n_embd) + self.gguf_writer.add_feed_forward_length(8192) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head) + self.gguf_writer.add_layer_norm_rms_eps(rms_eps) + self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) + self.gguf_writer.add_file_type(self.ftype) + + @Model.register("PlamoForCausalLM") class PlamoModel(Model): model_arch = gguf.MODEL_ARCH.PLAMO diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index 41b62c2fe9f76..87ecc0a4f1394 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -30,7 +30,6 @@ int main(int argc, char ** argv){ // load the model std::tie(model, ctx) = llama_init_from_gpt_params(params); - llama_set_rng_seed(ctx, params.seed); GGML_ASSERT(llama_n_vocab(model) < (1 << 16)); // tokenize the prompt diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 9526e898fe763..eebbd00a58e66 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -38,7 +38,6 @@ int main(int argc, char ** argv){ // load the model std::tie(model, ctx) = llama_init_from_gpt_params(params); - llama_set_rng_seed(ctx, params.seed); GGML_ASSERT(llama_n_vocab(model) < (1 << 16)); // tokenize the prompt diff --git a/examples/main/main.cpp b/examples/main/main.cpp index ef803d5cecd9f..dc13e91a3d8d9 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -240,7 +240,6 @@ int main(int argc, char ** argv) { return 1; } session_tokens.resize(n_token_count_out); - llama_set_rng_seed(ctx, params.seed); LOG_TEE("%s: 
loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size()); } } diff --git a/examples/server/public/index.html b/examples/server/public/index.html index 9fe61eb1b1b13..2961999f2451a 100644 --- a/examples/server/public/index.html +++ b/examples/server/public/index.html @@ -881,11 +881,11 @@ .replace(/&/g, '&') .replace(//g, '>') - .replace(/^#{1,6} (.*)$/gim, '

<h3>$1</h3>')
-            .replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>')
-            .replace(/__(.*?)__/g, '<strong>$1</strong>')
-            .replace(/\*(.*?)\*/g, '<em>$1</em>')
-            .replace(/_(.*?)_/g, '<em>$1</em>')
+            .replace(/(^|\n)#{1,6} ([^\n]*)(?=([^`]*`[^`]*`)*[^`]*$)/g, '$1<h3>$2</h3>')
+            .replace(/\*\*(.*?)\*\*(?=([^`]*`[^`]*`)*[^`]*$)/g, '<strong>$1</strong>')
+            .replace(/__(.*?)__(?=([^`]*`[^`]*`)*[^`]*$)/g, '<strong>$1</strong>')
+            .replace(/\*(.*?)\*(?=([^`]*`[^`]*`)*[^`]*$)/g, '<em>$1</em>')
+            .replace(/_(.*?)_(?=([^`]*`[^`]*`)*[^`]*$)/g, '<em>$1</em>')
             .replace(/```.*?\n([\s\S]*?)```/g, '<pre><code>$1</code></pre>')
             .replace(/`(.*?)`/g, '<code>$1</code>')
             .replace(/\n/gim, '<br />
'); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 048c680084176..cc7faf6cdd7f8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -854,7 +854,7 @@ struct server_context { slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); - slot.params.seed = json_value(data, "seed", default_params.seed); + slot.sparams.seed = json_value(data, "seed", default_sparams.seed); slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); @@ -1028,7 +1028,6 @@ struct server_context { send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); return false; } - llama_set_rng_seed(ctx, slot.params.seed); } slot.command = SLOT_COMMAND_LOAD_PROMPT; @@ -1118,7 +1117,7 @@ struct server_context { bool process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = llama_token_to_piece(ctx, result.tok); + const std::string token_str = llama_token_to_piece(ctx, result.tok, false); slot.sampled = result.tok; // search stop word and delete it diff --git a/examples/server/tests/features/results.feature b/examples/server/tests/features/results.feature new file mode 100644 index 0000000000000..f17120f7b9fbc --- /dev/null +++ b/examples/server/tests/features/results.feature @@ -0,0 +1,57 @@ +@llama.cpp +@results +Feature: Results + + Background: Server startup + Given a server listening on localhost:8080 + And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models + And a model file test-model-00001-of-00003.gguf + And 128 as batch size + And 256 KV cache size + And 128 max tokens to predict + + Scenario Outline: Multi users completion + Given slots + And continuous batching + Then the server is starting + Then the server is healthy + + Given 42 as seed + And a prompt: + """ + Write a very long story about AI. + """ + + Given 42 as seed + And a prompt: + """ + Write a very long story about AI. + """ + + Given 42 as seed + And a prompt: + """ + Write a very long story about AI. + """ + + Given 42 as seed + And a prompt: + """ + Write a very long story about AI. + """ + + Given 42 as seed + And a prompt: + """ + Write a very long story about AI. 
+ """ + + Given concurrent completion requests + Then the server is busy + Then the server is idle + And all slots are idle + Then all predictions are equal + Examples: + | n_slots | + | 1 | + | 2 | diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index ca400efa41b9e..f71e0d706cca9 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port): context.server_metrics = False context.server_process = None context.seed = None + context.draft = None context.server_seed = None context.user_api_key = None context.response_format = None @@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl): context.n_gpu_layer = ngl +@step('{draft:d} as draft') +def step_draft(context, draft): + context.draft = draft + + @step('{n_ctx:d} KV cache size') def step_n_ctx(context, n_ctx): context.n_ctx = n_ctx @@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n): assert_n_tokens_predicted(context.completion, predicted_n) +@step('all predictions are equal') +@async_run_until_complete +async def step_predictions_equal(context): + n_completions = await gather_tasks_results(context) + assert n_completions >= 2, "need at least 2 completions" + assert_all_predictions_equal(context.tasks_result) + context.tasks_result = [] + + @step('the completion is truncated') def step_assert_completion_truncated(context): step_assert_completion_truncated(context, '') @@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:' f' {n_predicted} <> {expected_predicted_n}') +def assert_all_predictions_equal(completion_responses): + content_0 = completion_responses[0]['content'] + + if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': + print(f"content 0: {content_0}") + + i = 1 + for response in completion_responses[1:]: + content = response['content'] + + if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': + print(f"content {i}: {content}") + + assert content == content_0, "contents not equal" + + i += 1 + async def gather_tasks_results(context): n_tasks = len(context.concurrent_tasks) @@ -1148,6 +1180,8 @@ def start_server_background(context): server_args.extend(['--ubatch-size', context.n_ubatch]) if context.n_gpu_layer: server_args.extend(['--n-gpu-layers', context.n_gpu_layer]) + if context.draft is not None: + server_args.extend(['--draft', context.draft]) if context.server_continuous_batching: server_args.append('--cont-batching') if context.server_embeddings: diff --git a/ggml-impl.h b/ggml-impl.h index 0c997d3ed521f..2ffacc299e4f1 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -45,7 +45,7 @@ extern "C" { // 16-bit float // on Arm, we use __fp16 // on x86, we use uint16_t -#if defined(__ARM_NEON) && !defined(_MSC_VER) +#if defined(__ARM_NEON) // if YCM cannot find , make a symbolic link to it, for example: // @@ -53,8 +53,262 @@ extern "C" { // #include +#ifdef _MSC_VER + +typedef uint16_t ggml_fp16_internal_t; + +#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } + +#else + typedef __fp16 ggml_fp16_internal_t; +#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } + +#endif // _MSC_VER + +#if !defined(__aarch64__) + +// 32-bit ARM compatibility + +// vaddvq_s16 +// vpaddq_s16 +// vpaddq_s32 +// vaddvq_s32 +// vaddvq_f32 +// vmaxvq_f32 +// 
vcvtnq_s32_f32 +// vzip1_u8 +// vzip2_u8 + +inline static int32_t vaddvq_s16(int16x8_t v) { + return + (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + + (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + + (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + + (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { + int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); + int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); + return vcombine_s32(a0, b0); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +inline static float vmaxvq_f32(float32x4_t v) { + return + MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), + MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { + int32x4_t res; + + res[0] = roundf(vgetq_lane_f32(v, 0)); + res[1] = roundf(vgetq_lane_f32(v, 1)); + res[2] = roundf(vgetq_lane_f32(v, 2)); + res[3] = roundf(vgetq_lane_f32(v, 3)); + + return res; +} + +inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[0]; res[1] = b[0]; + res[2] = a[1]; res[3] = b[1]; + res[4] = a[2]; res[5] = b[2]; + res[6] = a[3]; res[7] = b[3]; + + return res; +} + +inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[4]; res[1] = b[4]; + res[2] = a[5]; res[3] = b[5]; + res[4] = a[6]; res[5] = b[6]; + res[6] = a[7]; res[7] = b[7]; + + return res; +} + +// vld1q_s16_x2 +// vld1q_u8_x2 +// vld1q_u8_x4 +// vld1q_s8_x2 +// vld1q_s8_x4 +// TODO: double-check these work correctly + +typedef struct ggml_int16x8x2_t { + int16x8_t val[2]; +} ggml_int16x8x2_t; + +inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { + ggml_int16x8x2_t res; + + res.val[0] = vld1q_s16(ptr + 0); + res.val[1] = vld1q_s16(ptr + 8); + + return res; +} + +typedef struct ggml_uint8x16x2_t { + uint8x16_t val[2]; +} ggml_uint8x16x2_t; + +inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { + ggml_uint8x16x2_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + + return res; +} + +typedef struct ggml_uint8x16x4_t { + uint8x16_t val[4]; +} ggml_uint8x16x4_t; + +inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { + ggml_uint8x16x4_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + res.val[2] = vld1q_u8(ptr + 32); + res.val[3] = vld1q_u8(ptr + 48); + + return res; +} + +typedef struct ggml_int8x16x2_t { + int8x16_t val[2]; +} ggml_int8x16x2_t; + +inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { + ggml_int8x16x2_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + + return res; +} + +typedef struct ggml_int8x16x4_t { + int8x16_t val[4]; +} ggml_int8x16x4_t; + +inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { + ggml_int8x16x4_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + res.val[2] = 
vld1q_s8(ptr + 32); + res.val[3] = vld1q_s8(ptr + 48); + + return res; +} + +// NOTE: not tested +inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { + int8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +// NOTE: not tested +inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { + uint8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +#else + +#define ggml_int16x8x2_t int16x8x2_t +#define ggml_uint8x16x2_t uint8x16x2_t +#define ggml_uint8x16x4_t uint8x16x4_t +#define ggml_int8x16x2_t int8x16x2_t +#define ggml_int8x16x4_t int8x16x4_t + +#define ggml_vld1q_s16_x2 vld1q_s16_x2 +#define ggml_vld1q_u8_x2 vld1q_u8_x2 +#define ggml_vld1q_u8_x4 vld1q_u8_x4 +#define ggml_vld1q_s8_x2 vld1q_s8_x2 +#define ggml_vld1q_s8_x4 vld1q_s8_x4 +#define ggml_vqtbl1q_s8 vqtbl1q_s8 +#define ggml_vqtbl1q_u8 vqtbl1q_u8 + +#endif // !defined(__aarch64__) + +#if !defined(__ARM_FEATURE_DOTPROD) + +inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { + const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); + const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + + return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); +} + +#else + +#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) + +#endif // !defined(__ARM_FEATURE_DOTPROD) + +#endif // defined(__ARM_NEON) + +#if defined(__ARM_NEON) && !defined(__MSC_VER) + #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) @@ -75,8 +329,6 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #else -typedef uint16_t ggml_fp16_internal_t; - #ifdef __wasm_simd128__ #include #else @@ -221,7 +473,7 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #endif // __F16C__ -#endif // __ARM_NEON +#endif // defined(__ARM_NEON) && (!defined(__MSC_VER) // precomputed f32 table for f16 (256 KB) // defined in ggml.c, initialized in ggml_init() diff --git a/ggml-quants.c b/ggml-quants.c index 32360a1f1be50..11e11c2196c2c 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -14,41 +14,6 @@ #include // for qsort #include // for GGML_ASSERT -#ifdef __ARM_NEON - -// if YCM cannot find , make a symbolic link to it, for example: -// -// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ -// -#include - -#else - -#ifdef __wasm_simd128__ -#include -#else -#if defined(__POWER9_VECTOR__) || defined(__powerpc64__) -#include -#undef bool -#define bool _Bool -#else -#if defined(_MSC_VER) || defined(__MINGW32__) -#include -#else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) -#if !defined(__riscv) -#include -#endif -#endif -#endif -#endif -#endif -#endif - -#ifdef __riscv_v_intrinsic -#include -#endif - #undef MIN #undef MAX @@ -276,258 +241,6 @@ 
static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 #endif // __AVX__ || __AVX2__ || __AVX512F__ #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -#if defined(__ARM_NEON) - -#ifdef _MSC_VER - -#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } - -#else - -#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } - -#endif - -#if !defined(__aarch64__) - -// 64-bit compatibility - -// vaddvq_s16 -// vpaddq_s16 -// vpaddq_s32 -// vaddvq_s32 -// vaddvq_f32 -// vmaxvq_f32 -// vcvtnq_s32_f32 -// vzip1_u8 -// vzip2_u8 - -inline static int32_t vaddvq_s16(int16x8_t v) { - return - (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + - (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + - (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + - (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); -} - -inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { - int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); - int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); - return vcombine_s16(a0, b0); -} - -inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { - int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); - int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); - return vcombine_s32(a0, b0); -} - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} - -inline static float vaddvq_f32(float32x4_t v) { - return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); -} - -inline static float vmaxvq_f32(float32x4_t v) { - return - MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), - MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); -} - -inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { - int32x4_t res; - - res[0] = roundf(vgetq_lane_f32(v, 0)); - res[1] = roundf(vgetq_lane_f32(v, 1)); - res[2] = roundf(vgetq_lane_f32(v, 2)); - res[3] = roundf(vgetq_lane_f32(v, 3)); - - return res; -} - -inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { - uint8x8_t res; - - res[0] = a[0]; res[1] = b[0]; - res[2] = a[1]; res[3] = b[1]; - res[4] = a[2]; res[5] = b[2]; - res[6] = a[3]; res[7] = b[3]; - - return res; -} - -inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { - uint8x8_t res; - - res[0] = a[4]; res[1] = b[4]; - res[2] = a[5]; res[3] = b[5]; - res[4] = a[6]; res[5] = b[6]; - res[6] = a[7]; res[7] = b[7]; - - return res; -} - -// vld1q_s16_x2 -// vld1q_u8_x2 -// vld1q_u8_x4 -// vld1q_s8_x2 -// vld1q_s8_x4 -// TODO: double-check these work correctly - -typedef struct ggml_int16x8x2_t { - int16x8_t val[2]; -} ggml_int16x8x2_t; - -inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { - ggml_int16x8x2_t res; - - res.val[0] = vld1q_s16(ptr + 0); - res.val[1] = vld1q_s16(ptr + 8); - - return res; -} - -typedef struct ggml_uint8x16x2_t { - uint8x16_t val[2]; -} ggml_uint8x16x2_t; - -inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { - ggml_uint8x16x2_t res; - - res.val[0] = vld1q_u8(ptr + 0); - res.val[1] = vld1q_u8(ptr + 16); - - return res; -} - -typedef struct ggml_uint8x16x4_t { - uint8x16_t val[4]; -} ggml_uint8x16x4_t; - -inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { - ggml_uint8x16x4_t res; - - res.val[0] = vld1q_u8(ptr + 0); - res.val[1] = vld1q_u8(ptr + 16); - res.val[2] = 
vld1q_u8(ptr + 32); - res.val[3] = vld1q_u8(ptr + 48); - - return res; -} - -typedef struct ggml_int8x16x2_t { - int8x16_t val[2]; -} ggml_int8x16x2_t; - -inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { - ggml_int8x16x2_t res; - - res.val[0] = vld1q_s8(ptr + 0); - res.val[1] = vld1q_s8(ptr + 16); - - return res; -} - -typedef struct ggml_int8x16x4_t { - int8x16_t val[4]; -} ggml_int8x16x4_t; - -inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { - ggml_int8x16x4_t res; - - res.val[0] = vld1q_s8(ptr + 0); - res.val[1] = vld1q_s8(ptr + 16); - res.val[2] = vld1q_s8(ptr + 32); - res.val[3] = vld1q_s8(ptr + 48); - - return res; -} - -// NOTE: not tested -inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { - int8x16_t res; - - res[ 0] = a[b[ 0]]; - res[ 1] = a[b[ 1]]; - res[ 2] = a[b[ 2]]; - res[ 3] = a[b[ 3]]; - res[ 4] = a[b[ 4]]; - res[ 5] = a[b[ 5]]; - res[ 6] = a[b[ 6]]; - res[ 7] = a[b[ 7]]; - res[ 8] = a[b[ 8]]; - res[ 9] = a[b[ 9]]; - res[10] = a[b[10]]; - res[11] = a[b[11]]; - res[12] = a[b[12]]; - res[13] = a[b[13]]; - res[14] = a[b[14]]; - res[15] = a[b[15]]; - - return res; -} - -// NOTE: not tested -inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { - uint8x16_t res; - - res[ 0] = a[b[ 0]]; - res[ 1] = a[b[ 1]]; - res[ 2] = a[b[ 2]]; - res[ 3] = a[b[ 3]]; - res[ 4] = a[b[ 4]]; - res[ 5] = a[b[ 5]]; - res[ 6] = a[b[ 6]]; - res[ 7] = a[b[ 7]]; - res[ 8] = a[b[ 8]]; - res[ 9] = a[b[ 9]]; - res[10] = a[b[10]]; - res[11] = a[b[11]]; - res[12] = a[b[12]]; - res[13] = a[b[13]]; - res[14] = a[b[14]]; - res[15] = a[b[15]]; - - return res; -} - -#else - -#define ggml_int16x8x2_t int16x8x2_t -#define ggml_uint8x16x2_t uint8x16x2_t -#define ggml_uint8x16x4_t uint8x16x4_t -#define ggml_int8x16x2_t int8x16x2_t -#define ggml_int8x16x4_t int8x16x4_t - -#define ggml_vld1q_s16_x2 vld1q_s16_x2 -#define ggml_vld1q_u8_x2 vld1q_u8_x2 -#define ggml_vld1q_u8_x4 vld1q_u8_x4 -#define ggml_vld1q_s8_x2 vld1q_s8_x2 -#define ggml_vld1q_s8_x4 vld1q_s8_x4 -#define ggml_vqtbl1q_s8 vqtbl1q_s8 -#define ggml_vqtbl1q_u8 vqtbl1q_u8 - -#endif - -#if !defined(__ARM_FEATURE_DOTPROD) - -inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { - const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); - const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); - - return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); -} - -#else - -#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) - -#endif - -#endif - #if defined(__ARM_NEON) || defined(__wasm_simd128__) #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 06cb26a7d495b..d2f1de19831e3 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -124,6 +124,7 @@ class MODEL_ARCH(IntEnum): QWEN2 = auto() QWEN2MOE = auto() PHI2 = auto() + PHI3 = auto() PLAMO = auto() CODESHELL = auto() ORION = auto() @@ -200,6 +201,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.QWEN2: "qwen2", MODEL_ARCH.QWEN2MOE: "qwen2moe", MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PLAMO: "plamo", MODEL_ARCH.CODESHELL: "codeshell", MODEL_ARCH.ORION: "orion", @@ -550,6 +552,20 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PHI3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + 
MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.CODESHELL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.POS_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 10de36fa883ee..e5750d4191f6b 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -117,6 +117,7 @@ class TensorNameMap: "h.{bid}.attn.c_attn", # gpt2 "transformer.h.{bid}.mixer.Wqkv", # phi2 "encoder.layers.{bid}.attn.Wqkv", # nomic-bert + "model.layers.{bid}.self_attn.qkv_proj" # phi3 ), # Attention query @@ -234,6 +235,7 @@ class TensorNameMap: "h.{bid}.mlp.c_fc", # gpt2 "transformer.h.{bid}.mlp.fc1", # phi2 "model.layers.{bid}.mlp.fc1", # phi2 + "model.layers.{bid}.mlp.gate_up_proj", # phi3 "model.layers.layers.{bid}.mlp.up_proj", # plamo "model.layers.{bid}.feed_forward.w3", # internlm2 "encoder.layers.{bid}.mlp.fc11", # nomic-bert diff --git a/llama.cpp b/llama.cpp index aae9f1a8adac7..b75742a9e35aa 100644 --- a/llama.cpp +++ b/llama.cpp @@ -210,6 +210,7 @@ enum llm_arch { LLM_ARCH_QWEN2, LLM_ARCH_QWEN2MOE, LLM_ARCH_PHI2, + LLM_ARCH_PHI3, LLM_ARCH_PLAMO, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, @@ -245,6 +246,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_QWEN2MOE, "qwen2moe" }, { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, @@ -792,6 +794,23 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PHI3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_PLAMO, { @@ -3955,6 +3974,16 @@ static void llm_load_hparams( { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_PHI3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { case 24: model.type = e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_3B; break; @@ -4352,6 +4381,7 @@ static void llm_load_vocab( //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL && (t.first == "<|eot_id|>" || t.first == "<|im_end|>" || + t.first == "<|end|>" || t.first == "" ) ) { @@ -5375,6 +5405,33 @@ static bool llm_load_tensors( layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; + case LLM_ARCH_PHI3: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context* ctx_layer = 
ctx_for_layer(i); + ggml_context* ctx_split = ctx_for_layer_split(i); + + auto& layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }); + } + } break; case LLM_ARCH_PLAMO: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -6338,6 +6395,12 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * cur; + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + if (cparams.flash_attn) { GGML_UNUSED(model); GGML_UNUSED(n_ctx); @@ -9031,12 +9094,140 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, model.output_b); cb(cur, "result_output", -1); + ggml_build_forward_expand(gf, cur); + return gf; + } + + struct ggml_cgraph * build_phi3() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + auto residual = inpL; + + // self-attention + { + struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, cb, il); + cb(attn_norm_output, "attn_norm", il); + + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv) { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + } + else { + Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, 
n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_custom( + ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor* inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, residual); + residual = cur; + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // FF + // special-case: the up and gate tensors are merged into a single tensor + // TOOD: support into llm_build_ffn + { + struct ggml_tensor* up = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); + cb(up, "ffn_up", il); + + auto g = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), 0)); + auto y = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), up->nb[1] / 2)); + + y = ggml_mul(ctx0, y, ggml_silu(ctx0, g)); + cb(y, "ffn_gate", il); + + auto down = ggml_mul_mat(ctx0, model.layers[il].ffn_down, y); + cb(down, "ffn_down", il); + + cur = down; + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, residual, cur); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); return gf; } + struct ggml_cgraph * build_plamo() { struct ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -10538,6 +10729,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_phi2(); } break; + case LLM_ARCH_PHI3: + { + result = llm.build_phi3(); + } break; case LLM_ARCH_PLAMO: { result = llm.build_plamo(); @@ -13628,7 +13823,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da return result; } -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { +llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { GGML_ASSERT(ctx); const int64_t t_start_sample_us = ggml_time_us(); @@ -13641,7 +13836,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra } std::discrete_distribution<> dist(probs.begin(), probs.end()); - auto & rng = ctx->rng; int idx = dist(rng); llama_token result = candidates->data[idx].id; @@ -13651,6 +13845,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return result; } +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { + return llama_sample_token_with_rng(ctx, candidates, ctx->rng); +} + void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { const int64_t 
t_start_sample_us = ggml_time_us(); @@ -15549,6 +15747,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: case LLM_ARCH_PHI2: + case LLM_ARCH_PHI3: case LLM_ARCH_GEMMA: case LLM_ARCH_STARCODER2: return LLAMA_ROPE_TYPE_NEOX; @@ -15562,6 +15761,10 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { return LLAMA_ROPE_TYPE_NONE; } +enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { + return ctx->cparams.pooling_type; +} + int32_t llama_n_vocab(const struct llama_model * model) { return model->hparams.n_vocab; } @@ -17413,6 +17616,15 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } + } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos )) { + // Phi 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << trim(message->content) << "<|end|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } } else { // template not supported return -1; diff --git a/llama.h b/llama.h index dc52d3e36f4d6..3e4a68465c091 100644 --- a/llama.h +++ b/llama.h @@ -391,8 +391,10 @@ extern "C" { LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); + + LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); @@ -999,7 +1001,7 @@ extern "C" { struct llama_context * ctx, llama_token_data_array * candidates); - /// @details Randomly selects a token from the candidates based on their probabilities. + /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx. LLAMA_API llama_token llama_sample_token( struct llama_context * ctx, llama_token_data_array * candidates); @@ -1086,8 +1088,9 @@ extern "C" { // Internal API to be implemented by llama.cpp and used by tests/benchmarks only #ifdef LLAMA_API_INTERNAL -#include +#include #include +#include struct ggml_tensor; @@ -1124,6 +1127,10 @@ std::pair, llama_partial_utf8> decode_utf8( const std::string & src, llama_partial_utf8 partial_start); +// Randomly selects a token from the candidates based on their probabilities using given std::mt19937. +// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences. 
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng); + #endif // LLAMA_API_INTERNAL #endif // LLAMA_H diff --git a/media/matmul.png b/media/matmul.png new file mode 100644 index 0000000000000..786a20492c02b Binary files /dev/null and b/media/matmul.png differ diff --git a/media/matmul.svg b/media/matmul.svg new file mode 100644 index 0000000000000..1d6cb4bb78a22 --- /dev/null +++ b/media/matmul.svg @@ -0,0 +1,1238 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ARow-major + BTColumn-major + CT=ABTColumn-major + + ne00 + + ne01 + + ne1 + + ne0 + + ne10 + + ne11 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + BRow-major + ATColumn-major + C=BATRow-major + + ne10 + + ne11 + + ne0 + + ne1 + + ne00 + + ne01 + + + diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index cddf86a4105ea..4fe9183b92cfd 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -49,6 +49,8 @@ int main(void) { "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", // Llama-3 "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}", + // Phi-3 + "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + ' ' + message['content'] + '<|end|> ' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> ' }}{% else %}{{ eos_token }}{% endif %}" }; std::vector expected_output = { // teknium/OpenHermes-2.5-Mistral-7B @@ -77,6 +79,8 @@ int main(void) { "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an 
assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", // Llama 3 "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + // Phi 3 + "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\nI am an assistant<|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", }; std::vector formatted_chat(1024); int32_t res;
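
Note on the sampling changes in this diff: `llama_sampling_context` now owns its own `std::mt19937` (seeded via the new `llama_sampling_set_rng_seed`), and `llama_sample_token_with_rng` draws from the candidate probabilities with that per-context RNG instead of the single RNG stored in `llama_context`. The snippet below is a minimal standalone sketch of that idea in plain C++ — `toy_sampling_context` and `toy_sample` are made-up names for illustration, not llama.cpp API — showing why two contexts seeded identically produce identical draws even when their sampling calls interleave, which is the property the new `results.feature` scenario exercises across server slots.

```cpp
// Standalone sketch (no llama.cpp dependency) of per-context RNG sampling.
#include <cstdio>
#include <random>
#include <vector>

struct toy_sampling_context {
    std::mt19937 rng; // per-context RNG, analogous to llama_sampling_context::rng
};

// analogous to llama_sample_token_with_rng: draw an index from the candidate
// probabilities using the context's own RNG rather than a shared/global one
static int toy_sample(toy_sampling_context & ctx, const std::vector<float> & probs) {
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    return dist(ctx.rng);
}

int main() {
    const std::vector<float> probs = {0.1f, 0.2f, 0.3f, 0.4f};

    toy_sampling_context slot_a; slot_a.rng.seed(42);
    toy_sampling_context slot_b; slot_b.rng.seed(42);

    // interleaved draws: with per-context RNGs the two "slots" stay in lockstep,
    // so identical seeds give identical token sequences per slot
    for (int i = 0; i < 4; ++i) {
        printf("a=%d b=%d\n", toy_sample(slot_a, probs), toy_sample(slot_b, probs));
    }
    return 0;
}
```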