diff --git a/developer_document.md b/developer_document.md index 2bc7ac7f6..3b6e41cde 100644 --- a/developer_document.md +++ b/developer_document.md @@ -1,6 +1,6 @@ ## Before you start -ITREX LLM C++ Runtime has already supported some popular models like `LLAMA`,`GPT-J`, `GPT-NEOX`, `DOLLY`, etc.These LLMs have similar architectures and some of them share the same architect (`DOLLY` and `GPT-NEOX`). Before adding a new model, you can checkout its architecture (from Huggingface `config.json`) whether is in our [supported list](./models/model_utils/model_types.h#L68). +ITREX LLM C++ Runtime (`Neural Speed`) already supports some popular models like `LLAMA`, `GPT-J`, `GPT-NEOX`, `DOLLY`, etc. These LLMs have similar architectures, and some of them share the same architecture (`DOLLY` and `GPT-NEOX`). Before adding a new model, you can check whether its architecture (from the Huggingface `config.json`) is in our [supported list](./neural_speed/models/model_utils/model_types.h#L68). However, LLM inference is complicated. A model may have its own: 1. special tokenizer (or vocab); 2. architecture (or forward pipeline); 3. operators (or kernels). Generally speaking, the first and second points appear frequently for transformer LLMs. I will show you how to run a new model as quickly as possible when it has none of the problems above, or only problem 1. The next sections discuss problem 2; problem 3 is beyond the scope of this document. @@ -84,7 +84,7 @@ The term **"hyperparameters"** describes a value that is used to configure the be - n_vocab: the size of the model's vocabulary - n_embd: the size of the model's "embedding layer", which is used during prompt ingestion. - n_layer: the number of layers in the model; each layer represents a set of weights. -Here we will use [convert_gptneox.py](scripts/convert_gptneox.py#L96) as an example, +Here we will use [convert_gptneox.py](neural_speed/convert/convert_gptneox.py#L96) as an example, ```python fout.write(struct.pack("i", hparams["num_attention_heads"])) fout.write(struct.pack("i", hparams.get("n_head_kv", 0))) # multi-query attention ``` @@ -96,7 +96,7 @@ The above `fout` is the file we need to get, and the `num_attention`, `n_head_kv As the name implies, a model's vocabulary comprises components that are used by the model to generate language (text). However, unlike the vocabulary of a human, which consists of words, the vocabulary of a large language model consists of "tokens". A token can be an entire word, but oftentimes it is a word fragment. Just like humans can compose millions of words from just a dozen or two letters, large language models use tokens to express a large number of words from a relatively smaller number of components. Consider a vocabulary with the following tokens: `whi`, `ch`, `le`, `who`, and `a`; this vocabulary can be used to create the English words `"which"`, `"while"`, `"who"`, `"a"`, and `"leach"`. How would the behavior change if the model contained the following tokens: `wh`, `ich`, `ile`, `o`, and `leach`? Choices such as these allow model-creators to tune the behavior and performance of their models. As described above, the model's hyperparameters typically contain a value that specifies the number of tokens in the vocabulary. The vocabulary is encoded as a list of tokens, each of which includes a 32-bit integer that specifies the length of the token.
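To make the length-prefixed encoding above concrete, here is a minimal, hypothetical sketch (not the exact `convert_gptneox.py` code) of how a converter could serialize such a vocabulary. The helper name `write_vocab` and the `fout`, `tokenizer`, and `n_vocab` arguments are assumptions following the convert-script conventions shown above; the real script additionally remaps token text through a byte decoder, as shown in the next snippet.

```python
import struct

def write_vocab(fout, tokenizer, n_vocab):
    # Hypothetical sketch: id -> token text, including tokens added on top of the base vocab.
    vocab = {v: k for k, v in tokenizer.get_vocab().items()}
    for token_id in range(n_vocab):
        text = vocab[token_id].encode("utf-8")
        fout.write(struct.pack("i", len(text)))  # 32-bit token length
        fout.write(text)                         # raw token bytes
```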
If your model has some new tokenizers, we suggest using a Python tokenizer from transformers and feeding the input_ids to the model Python API (see the Python example in the scripts folder). -Here we will use [convert_gptneox.py](scripts/convert_gptneox.py#L122) as an example to processed the vocabulary of gptneox and written it into `fout`. +Here we will use [convert_gptneox.py](neural_speed/convert/convert_gptneox.py#L122) as an example of how to process the vocabulary of gptneox and write it into `fout`. ```python encoder = tokenizer.vocab encoder.update(tokenizer.get_added_vocab()) @@ -105,10 +105,10 @@ byte_decoder = {v:k for k, v in byte_encoder.items()} ``` ## 1.3. Model weights -Finally, and largest, component of a ITREX GRAPH file is the weights of the LLM that the file represents. Abstractly, a large language model is software that is used to generate language - just like software that is used to generate images can be improved by increasing the number of colors with which images can be rendered, large language models can be improved by increasing the number of weights in the model. The total number of weights in a model is referred to as the "size" of that model. For example, the dolly-v2-3b implementation of the gpt-neox-20b language model architecture is available in several sizes, like 3B and 20B, which stand for 3 billion and 20 billion, respectively. These numbers refer to the total number of weights in that model. +Finally, the largest component of a `Neural Speed` GRAPH file is the weights of the LLM that the file represents. Abstractly, a large language model is software that is used to generate language - just like software that is used to generate images can be improved by increasing the number of colors with which images can be rendered, large language models can be improved by increasing the number of weights in the model. The total number of weights in a model is referred to as the "size" of that model. For example, the dolly-v2-3b implementation of the gpt-neox-20b language model architecture is available in several sizes, like 3B and 20B, which stand for 3 billion and 20 billion, respectively. These numbers refer to the total number of weights in that model. As described in the hyperparameters section, weights are grouped in sets called "layers", which, like hyperparameters, have structures that are uniquely defined by the model architecture; within a layer, weights are grouped in structures called "tensors". So, for instance, both dolly-v2-3B and gpt-neox-20B use layers that comprise the same tensors, but dolly-v2-3B has relatively fewer layers when compared to gpt-neox-20B. -Here we will use [convert_gptneox.py](scripts/convert_gptneox.py#L149) as an example to convert model weights to `fout`. +Here we will use [convert_gptneox.py](neural_speed/convert/convert_gptneox.py#L149) as an example to convert model weights to `fout`. ```python fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) for i in range(n_dims): @@ -199,7 +199,7 @@ n_embd = hparams.n_embd; n_vocab = hparams.n_vocab; n_layer = hparams.n_layer; ``` -The weights of the model in the ITREX Graph file will be loaded in [model load function](neural_speed/models/gptneox/gptneox_utils.cpp#L71). Here, we'll re-read some of the parameters and weights of the converted binary,include ffn, attention, and norm weight and bias, We'll use the mapping between the name and the weight to read the weight we need. It is shown below.
+The weights of the model in the `Neural Speed` Graph file will be loaded in the [model load function](neural_speed/models/gptneox/gptneox_utils.cpp#L71). Here, we'll re-read some of the parameters and weights of the converted binary, including the ffn, attention, and norm weights and biases. We'll use the mapping between the name and the weight to read the weight we need, as shown below. ```cpp model.others[0] = ml->get_tensor("gpt_neox.embed_in.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); model.others[1] = ml->get_tensor("gpt_neox.final_layer_norm.weight", {n_embd}, NE_BACKEND_CPU); @@ -227,76 +227,170 @@ The `inpL` in the code above is equivalent to the `hidden_states` in the pytorch When enabling a new model, we should implement the `new_model.cpp` of the new model. Most of our model examples only support single prompt processing. You need to add `batch-dim` for tensors and concat `KV cache` per-batch if you want to try multi-batch inference. + +We recommend using continuous batching since it avoids padding overhead and can boost throughput in both offline and server scenarios. Here is an [example](https://github.com/intel/neural-speed/pull/145/files) of how to modify the `LLAMA` [source cpp file](neural_speed/models/llama/llama.cpp). We will show the important modifications below. + ```diff -// copy batch inputs -- struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N); -+ struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N * batch_size); - ne_set_name(embd, "embd"); -- memcpy(embd->data, tokens, N * ne_element_size(embd)); +// do not forget to copy all sequences in and out +// 1. analyze model inputs and prepare information for inference (especially for splitting sequences) ++ // continuous batching (no padding) ++ // input shape will be [1, l_sum] ++ if (batch_size > 1) ++ MODEL_ASSERT( ++ ("llama arch only supports continuous batching inference when giving multi prompts.", ++ lctx.cont_batching)); ++ const bool concat_multi_seqs = batch_size > 1 ? true : false; ++ std::vector n_tokens(batch_size); ++ std::vector n_pasts(batch_size); ++ std::vector n_totals(batch_size); ++ const int beam_size = lctx.beam_search ?
lctx.beam_size : 1; ++ std::vector block_ids(batch_size); + for (int i = 0; i < batch_size; ++i) { -+ memcpy(static_cast(embd->data) + i * N, tokens + i * N, N * ne_element_size(embd)); ++ n_tokens[i] = inputs[i].n_tokens; ++ n_pasts[i] = inputs[i].n_past; ++ n_totals[i] = inputs[i].n_total; ++ block_ids[i] = inputs[i].request_idx * beam_size + inputs[i].beam_idx; ++ // enforce that the first token is BOS ++ if (n_totals[i] == 0 && inputs[i].tokens[0] != lctx.vocab.bos_token_id) { ++ fprintf(stderr, "%s: first token must be BOS (token id is %ld) in %dth prompt\n", __func__, ++ lctx.vocab.bos_token_id, i); ++ return false; ++ } + } - -// add batch-dim for tensors -- struct ne_tensor* Qcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, n_embd / n_head, n_head, N, cur->nb[1] / n_head, cur->nb[1], 0 * sizeof(float) * n_embd / n_head)); -- struct ne_tensor* Kcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, n_embd / n_head, n_head, N, cur->nb[1] / n_head, cur->nb[1], 1 * sizeof(float) * n_embd / n_head)); -- struct ne_tensor* Vcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, n_embd / n_head, n_head, N, cur->nb[1] / n_head, cur->nb[1], 2 * sizeof(float) * n_embd / n_head)); -+ struct ne_tensor* Qcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N * batch_size, cur->nb[1] / n_head, cur->nb[1], 0 * sizeof(float) * head_dim)); -+ struct ne_tensor* Kcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N * batch_size, cur->nb[1] / n_head, cur->nb[1], 1 * sizeof(float) * head_dim)); -+ struct ne_tensor* Vcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N * batch_size, cur->nb[1] / n_head, cur->nb[1], 2 * sizeof(float) * head_dim)); - -// concat kv cache per-batch -- struct ne_tensor* k = -- ne_view_1d(ctx0, kv_self.k, N * n_embd, (ne_element_size(kv_self.k) * n_embd) * (il * n_ctx + n_past)); ++ const int seq_len_sum = std::accumulate(n_tokens.begin(), n_tokens.end(), 0); ++ const int infer_bs = 1; ++ const int infer_seq_len = seq_len_sum; +// max batch num for a inference, usually it's larger than num of model inputs (beam search or dynamic batch size inference) ++ const int kv_n_ctx_block = lctx.kv_n_ctx_block; +// divide kv_n_ctx_block into server groups and each of them has same shape inside ++ const std::vector> infer_groups = split_inputs_into_groups(inputs, n_input); + +// 2. 
for-loop RoPE +// reshape 4d for Q K tensors (add batch dimension) +- Qcur = ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_size, n_head, N); +- Kcur = ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head_kv, N); ++ Qcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_size, n_head, ++ infer_seq_len, infer_bs); ++ Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head_kv, ++ infer_seq_len, infer_bs); +// per_request rope ++ for (int gi = 0; gi < infer_groups.size(); ++gi) { ++ const int qk_bs = infer_groups[gi].size(); ++ const int qk_sl = n_tokens[infer_groups[gi].front()]; ++ const int qk_n_past = n_pasts[infer_groups[gi].front()]; ++ struct ne_tensor* Qcur_req = ++ ne_view_4d(ctx0, Qcur, head_size, n_head, qk_sl, qk_bs, ne_element_size(Qcur) * ++ head_size,ne_element_size(Qcur) * head_size * n_head, ne_element_size(Qcur) * head_size ++ * n_head * qk_sl, off_sl * n_head * ne_element_size(Qcur)); ++ ne_build_forward_expand( ++ &gf, ne_rope_inplace(ctx0, Qcur_req, qk_n_past, n_rot, 0, 0, hparams.freq_base, ++ hparams.freq_scale)); ++ struct ne_tensor* Kcur_req = ne_view_4d( ++ ctx0, Kcur, head_size, n_head_kv, qk_sl, qk_bs, ne_element_size(Kcur) * head_size, ++ ne_element_size(Kcur) * head_size * n_head_kv, ne_element_size(Kcur) * head_size * n_head_kv * ++ qk_sl, off_sl * n_head_kv * ne_element_size(Kcur)); ++ ne_build_forward_expand( ++ &gf, ne_rope_inplace(ctx0, Kcur_req, qk_n_past, n_rot, 0, 0, hparams.freq_base, ++ hparams.freq_scale)); ++ off_sl += head_size * qk_bs * qk_sl; ++ } + +// 3. for-loop kv cache concat +// we suggest storing permuted kv tensors for unified kv cache operations without MHA fusion +- struct ne_tensor* k = ne_view_1d(ctx0, kv_self.k, N * n_embd_gqa, +- (ne_element_size(kv_self.k) * n_embd_gqa) * (il * n_ctx + n_past)); - struct ne_tensor* v = -- ne_view_2d(ctx0, kv_self.v, N, n_embd, (n_ctx)*ne_element_size(kv_self.v), -- (il * n_ctx) * ne_element_size(kv_self.v) * n_embd + n_past * ne_element_size(kv_self.v)); - +- ne_view_2d(ctx0, kv_self.v, N, n_embd_gqa, (n_ctx)*ne_element_size(kv_self.v), +- (il * n_ctx) * ne_element_size(kv_self.v) * n_embd_gqa + n_past * ne_element_size(kv_self.v)); +- // important: storing RoPE-ed version of K in the KV cache! 
- ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur, k)); - ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur, v)); -+ std::vector Kcur_bs(batch_size); -+ std::vector Vcur_bs(batch_size); -+ std::vector k_bs(batch_size); -+ std::vector v_bs(batch_size); ++ struct ne_tensor* const k_cache = ++ ne_view_1d(ctx0, kv_self.k, n_ctx * n_embd_gqa * kv_n_ctx_block, ++ il * n_ctx * ne_element_size(kv_self.k) * n_embd_gqa * kv_n_ctx_block); ++ struct ne_tensor* const v_cache = ++ ne_view_1d(ctx0, kv_self.v, n_ctx * n_embd_gqa * kv_n_ctx_block, ++ il * n_ctx * ne_element_size(kv_self.v) * n_embd_gqa * kv_n_ctx_block); ++ // cache = [tokens, beams, requests, layers], ++ // tokens = [head_dim, head_num, n_ctx] (may different orders) ++ size_t off_N_i = 0; + for (int i = 0; i < batch_size; ++i) { - // batch K -+ Kcur_bs[i] = ne_permute(ctx0, -+ ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim, -+ ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N, -+ i * ne_element_size(Kcur) * n_embd * N), -+ 0, 2, 1, 3); -+ k_bs[i] = ne_view_4d( -+ ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim, -+ ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx, -+ ((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block + -+ i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k))); - - // batch V -+ Vcur_bs[i] = ne_permute(ctx0, -+ ne_reshape_4d(ctx0, -+ ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd, -+ i * ne_element_size(Vcur) * n_embd * N), -+ head_dim, n_head, N, 1), -+ 1, 2, 0, 3); -+ v_bs[i] = -+ ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v), -+ n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd, -+ ((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block + -+ i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v))); -+ ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i])); -+ ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i])); -+ } - -// copy batch output logits out - // return result for just the last token -+ size_t bs_stride = n_vocab * N; -- logits_out.resize(n_vocab); -- memcpy(logits_out.data(), (float*)ne_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); -+ logits_out.resize(n_vocab * batch_size); -+ for (int i = 0; i < batch_size; ++i) { -+ memcpy(logits_out.data() + (i * n_vocab), (float*)ne_get_data(inpL) + (i * bs_stride) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); -+ } ++ const int block_idx = block_ids[i]; ++ const int N_i = n_tokens[i]; ++ const int n_past_i = n_pasts[i]; ++ // batch K ++ struct ne_tensor* Kcur_bs_i = ++ ne_permute(ctx0, ++ ne_view_4d(ctx0, Kcur, head_size, n_head_kv, N_i, 1, ne_element_size(Kcur) * head_size, ++ ne_element_size(Kcur) * n_embd_gqa, ne_element_size(Kcur) * n_embd_gqa * N_i, ++ ne_element_size(Kcur) * off_N_i), ++ 0, 2, 1, 3); ++ struct ne_tensor* k_bs_i = ++ ne_view_4d(ctx0, k_cache, head_size, N_i, n_head_kv, 1, ne_element_size(k_cache) * head_size, ++ ne_element_size(k_cache) * head_size * n_ctx, ne_element_size(k_cache) * n_embd_gqa * n_ctx, ++ block_idx * n_ctx * n_embd_gqa * ne_element_size(k_cache) + ++ head_size * n_past_i * ne_element_size(k_cache)); ++ // batch V ++ struct ne_tensor* Vcur_bs_i = ++ ne_permute(ctx0, ++ ne_reshape_4d(ctx0, ++ ne_view_2d(ctx0, Vcur, n_embd_gqa, N_i, ne_element_size(Vcur) * n_embd_gqa, ++ 
ne_element_size(Vcur) * off_N_i), ++ head_size, n_head_kv, N_i, 1), ++ 1, 2, 0, 3); ++ struct ne_tensor* v_bs_i = ne_view_4d( ++ ctx0, v_cache, N_i, head_size, n_head_kv, 1, n_ctx * ne_element_size(v_cache), ++ n_ctx * ne_element_size(v_cache) * head_size, n_ctx * ne_element_size(v_cache) * n_embd_gqa, ++ block_idx * n_ctx * n_embd_gqa * ne_element_size(v_cache) + n_past_i * ne_element_size(v_cache)); ++ // concat ++ ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs_i, k_bs_i)); ++ ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs_i, v_bs_i)); ++ off_N_i += head_size * n_head_kv * N_i; + +// 4. for-loop attention +// prepare final QKV_merged tensor ++ struct ne_tensor* KQV_merged_contiguous = ++ ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC); +// prepare Q K V tensors for each prompt ++ size_t off_sl = 0; ++ for (int gi = 0; gi < infer_groups.size(); ++gi) { ++ const int attn_bs = infer_groups[gi].size(); ++ const int attn_sl = n_tokens[infer_groups[gi].front()]; ++ const int attn_block_id = block_ids[infer_groups[gi].front()]; ++ const int attn_n_past = n_pasts[infer_groups[gi].front()]; ++ const int attn_n_total = n_totals[infer_groups[gi].front()]; ++ struct ne_tensor* Q = ++ ne_permute(ctx0, ++ ne_view_4d(ctx0, Qcur, head_size, n_head, attn_sl, attn_bs, ne_element_size(Qcur) * head_size, ++ ne_element_size(Qcur) * head_size * n_head, ++ ne_element_size(Qcur) * head_size * n_head * attn_sl, off_sl * ne_element_size(Qcur)), ++ 0, 2, 1, 3); ++ std::string suffix = std::to_string(gi); ++ ne_set_name(Q, std::string("Q_" + suffix).c_str()); ++ const int n_cached_gi = shift_roped_k ? n_cached : attn_n_past + attn_sl; ++ std::vector attn_block_ids(infer_groups[gi].size()); ++ for (int j = 0; j < infer_groups[gi].size(); ++j) { ++ attn_block_ids[j] = block_ids[infer_groups[gi][j]]; ++ } ++ struct ne_tensor* K = ++ model_kv_cache_seq_concat(&gf, &lctx, ctx0, head_size, n_cached_gi, n_head_kv, attn_bs, attn_block_ids, il); ++ // split cached V into n_head heads ++ struct ne_tensor* V = model_kv_cache_seq_concat(&gf, &lctx, ctx0, n_cached_gi, head_size, n_head_kv, attn_bs, ++ attn_block_ids, il, false); ++ ne_set_name(K, std::string("K_" + suffix).c_str()); ++ ne_set_name(V, std::string("V_" + suffix).c_str()); +// compute V * softmax(mask(KQ)) ++ .... +// copy each KQV_merged_i into KQV_merged ++ struct ne_tensor* KQV_merged_gi = ne_permute(ctx0, KQV, 0, 2, 1, 3); ++ ne_set_name(KQV_merged_gi, std::string("KQV_merged_" + suffix).c_str()); + ++ ne_build_forward_expand(&gf, ++ ne_cpy(ctx0, KQV_merged_gi, ++ ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, head_size * n_head * ne_element_size(KQV_merged_contiguous), ne_element_size(KQV_merged_contiguous) * off_sl))); ++ off_sl += head_size * n_head * attn_sl * attn_bs; ``` +>Note: You can set larger [`NE_MAX_NODES`](neural_speed/core/ne.h#43) and [`model_scratch_enlarge_scale`](neural_speed/models/llama/llama.h#29) values if you run out of memory when the input batch size becomes larger. ## 2.3. Application - Q4_0 quant: We can quantize the model generated by convert by adding a quant layer class that quantizes it into an int4 low-bit file, so as to obtain better inference performance. Register the quant layer class in your new_model_utils.cpp, just like [gptneox_utils.cpp](neural_speed/models/gptneox/gptneox_utils.cpp#L163), replacing `gptneox_quant_layer` with your `new_model_quant_layer`.
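Before the Python and C++ changes below, here is a minimal usage sketch of the Python API that those changes affect. It assumes the high-level `Model` class and the `generate()` keyword arguments visible in the `neural_speed/__init__.py` diff (`continuous_batching`, `pad_token`); the `init()` arguments and the example checkpoint name are assumptions and may differ in your checkout.

```python
from transformers import AutoTokenizer
from neural_speed import Model

model_name = "meta-llama/Llama-2-7b-hf"  # assumed example checkpoint
prompts = ["What is the meaning of life?", "Tell me a story about a cat."]

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# Padding only builds the batched tensor; continuous batching strips it again inside generate().
inputs = tokenizer(prompts, padding=True, return_tensors="pt").input_ids

model = Model()
model.init(model_name, weight_dtype="int4", compute_dtype="int8")  # quantized low-bit weights
outputs = model.generate(inputs, max_new_tokens=32,
                         continuous_batching=True,  # default for multi-prompt inputs after this change
                         pad_token=tokenizer.pad_token_id)
```

With a multi-prompt batch like this, the `__init__.py` change below converts the padded `input_ids` into per-sequence token lists (`_cont_batching_input`) so that the C++ side receives concatenated, unpadded sequences.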
diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py index c8ecab1f3..7bb39ce16 100644 --- a/neural_speed/__init__.py +++ b/neural_speed/__init__.py @@ -212,7 +212,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa out_count = 0 input_list = None pad_token_id = generate_kwargs.get("pad_token", None) - if generate_kwargs.get("continuous_batching", False): + if input_ids.shape[0] > 1 and generate_kwargs.get("continuous_batching", True): input_list = self._cont_batching_input(input_ids, pad_token_id) else: input_list = input_ids.tolist() diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp index 4d064854e..f82fd803b 100644 --- a/neural_speed/application/main_pybind.cpp +++ b/neural_speed/application/main_pybind.cpp @@ -76,13 +76,14 @@ using Response = Query; using ResponseCallback = std::function, int)>; } // namespace +static std::set cont_batching_model_archs = {MODEL_GPTJ, MODEL_LLAMA}; void init_gpt_params(gpt_params* params, const std::string& model_path, int max_new_tokens = -1, int n_batch = 512, int ctx_size = 512, int seed = -1, int threads = 8, float repetition_penalty = 1.1f, int num_beams = 1, bool do_sample = false, int top_k = 40, float top_p = 0.95, float temperature = 0.8, int min_new_tokens = 0, float length_penalty = 1.0f, bool early_stopping = false, int n_keep = 0, int n_discard = -1, bool shift_roped_k = false, int batch_size = 1, model_vocab::id pad_token = -1, const std::string& memory_dtype = "auto", - const bool& continuous_batching = false, const int& max_request_num = MODEL_MAX_REQUEST_NUM, + bool continuous_batching = true, const int& max_request_num = MODEL_MAX_REQUEST_NUM, const float& model_scratch_enlarge_scale = 1.0f) { MODEL_ASSERT(params != nullptr); #ifdef MODEL_NAME @@ -114,10 +115,13 @@ void init_gpt_params(gpt_params* params, const std::string& model_path, int max_ params->memory_type = KV_MEM_TYPE_AUTO; else fprintf(stderr, "Unexpected memory dtype %s!", memory_dtype.c_str()); - if (batch_size > 1 && (!continuous_batching || params->model_arch != model_archs::MODEL_GPTJ)) { - params->memory_type = KV_MEM_TYPE_F16; // TODO(Yi & YZT): MHA IN MULTI-BATCH For More Model Archs - } + // TODO(Yi & YZT): MHA IN MULTI-BATCH For More Model Archs params->cont_batching = continuous_batching; + if (params->shift_roped_k) params->cont_batching = false; + if (cont_batching_model_archs.count(params->model_arch) == 0) params->cont_batching = false; + if (batch_size > 1 && !continuous_batching) { + params->memory_type = KV_MEM_TYPE_F16; + } params->max_request_num = std::max(batch_size, max_request_num); params->min_new_tokens = min_new_tokens; params->length_penalty = length_penalty; @@ -137,8 +141,8 @@ class ModelServer { int n_batch, int ctx_size, int seed, int threads, float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token, - const std::string& memory_dtype, const bool& continuous_batching, const int& max_request_num, - const float& model_scratch_enlarge_scale, const std::string& policy, const bool& print_log, + const std::string& memory_dtype, bool continuous_batching, const int& max_request_num, + const float& model_scratch_enlarge_scale, const std::string& policy, bool print_log, const std::function& init_cb) : response(response), waiting(), @@ -258,12 +262,16 @@ class ModelServer 
{ int threads, float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token, - const std::string& memory_dtype, const bool& continuous_batching, const int& max_request_num, + const std::string& memory_dtype, bool continuous_batching, const int& max_request_num, const float& model_scratch_enlarge_scale) { init_gpt_params(¶ms, model_path, max_new_tokens, n_batch, ctx_size, seed, threads, repetition_penalty, num_beams, do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty, early_stopping, n_keep, n_discard, shift_roped_k, batch_size, pad_token, memory_dtype, continuous_batching, max_request_num, model_scratch_enlarge_scale); + if (cont_batching_model_archs.count(params.model_arch) == 0) { + fprintf(stderr, "\nERROR: ModelServer only supports gpt-j, llama!\n"); + running = false; + } } ~ModelServer() { @@ -317,8 +325,7 @@ class Model { float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token, const std::string& memory_dtype, - const bool& continuous_batching, const int& max_request_num, - const float& model_scratch_enlarge_scale); + bool continuous_batching, const int& max_request_num, const float& model_scratch_enlarge_scale); void reinit(); std::vector> generate(const std::vector>& input_ids); // deprecated API @@ -411,7 +418,7 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_ int threads, float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token, - const std::string& memory_dtype, const bool& continuous_batching, const int& max_request_num, + const std::string& memory_dtype, bool continuous_batching, const int& max_request_num, const float& model_scratch_enlarge_scale) { init_gpt_params(¶ms, model_path, max_new_tokens, n_batch, ctx_size, seed, threads, repetition_penalty, num_beams, do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty, early_stopping, n_keep, @@ -466,9 +473,9 @@ bool Model::check_input_and_count_padding(const std::vectorbatch_size = input_ids.size(); MODEL_ASSERT(input_ids.size() <= ctx->max_request_num); - static std::set batched_model_archs = {MODEL_GPTJ, MODEL_GPTNEOX, MODEL_CHATGLM}; + static std::set batched_model_archs = {MODEL_GPTJ, MODEL_GPTNEOX, MODEL_CHATGLM, MODEL_LLAMA}; if (batched_model_archs.count(params.model_arch) == 0) { - fprintf(stderr, "\nERROR: Only gpt-j, gpt-neox, chatglm support multi-batch generation!\n"); + fprintf(stderr, "\nERROR: Only gpt-j, gpt-neox, chatglm, llama support multi-batch generation!\n"); return false; } if (ctx->vocab.pad_token_id == -1) { @@ -738,7 +745,7 @@ std::vector> Model::post_beam_search(model_context* lct const std::vector& inputs, const int& n_threads) { // TODO(Zhentao): to implement - static std::set supported_archs = {MODEL_GPTJ, MODEL_GPTNEOX}; + static std::set supported_archs = {MODEL_GPTJ, MODEL_GPTNEOX, MODEL_LLAMA}; if (supported_archs.count(params.model_arch) != 0) { return beam_search(lctx, n_predict, inputs, n_threads); } else { @@ -914,7 +921,7 @@ 
PYBIND11_MODULE(mixtral_cpp, m) py::arg("min_new_tokens") = 0, py::arg("length_penalty") = 1.0, py::arg("early_stopping") = false, py::arg("n_keep") = 0, py::arg("n_discard") = -1, py::arg("shift_roped_k") = false, py::arg("batch_size") = 1, py::arg("pad_token") = -1, py::arg("memory_dtype") = "auto", - py::arg("continuous_batching") = false, py::arg("max_request_num") = MODEL_MAX_REQUEST_NUM, + py::arg("continuous_batching") = true, py::arg("max_request_num") = MODEL_MAX_REQUEST_NUM, py::arg("model_scratch_enlarge_scale") = 1.0f) .def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids")) .def("evaluate", &Model::evaluate, "Evaluate token with input ids and output logits", @@ -946,9 +953,8 @@ PYBIND11_MODULE(mixtral_cpp, m) .def_readwrite("token_ids", &Query::token_ids); py::class_(m, "ModelServer", py::module_local()) .def(py::init&>(), + float, float, int, float, bool, int, int, bool, int, model_vocab::id, const std::string&, bool, + const int&, const float&, const std::string&, bool, const std::function&>(), py::arg("response"), py::arg("model_path"), py::arg("return_prompt") = false, py::arg("max_new_tokens") = -1, py::arg("n_batch") = 512, py::arg("ctx_size") = 512, py::arg("seed") = -1, py::arg("threads") = 8, py::arg("repetition_penalty") = 1.1f, py::arg("num_beams") = 1, py::arg("do_sample") = false, diff --git a/neural_speed/core/ne.h b/neural_speed/core/ne.h index 33bf4f0b6..9e029b6d9 100644 --- a/neural_speed/core/ne.h +++ b/neural_speed/core/ne.h @@ -40,7 +40,7 @@ #define NE_FILE_VERSION 1 #define NE_MAX_DIMS 4 -#define NE_MAX_NODES 16384 +#define NE_MAX_NODES 40960 #define NE_MAX_PARAMS 256 #define NE_MAX_CONTEXTS 64 #define NE_MAX_OPT 36 diff --git a/neural_speed/models/gptj/gptj.cpp b/neural_speed/models/gptj/gptj.cpp index 2dd1559b1..b907ed472 100644 --- a/neural_speed/models/gptj/gptj.cpp +++ b/neural_speed/models/gptj/gptj.cpp @@ -63,9 +63,10 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu const int N = inputs->n_tokens; const int n_past = inputs->n_past; const int n_total = inputs->n_total; - // continuous batching + // continuous batching (no padding) // if each sequence length l_i ! = l_k // input shape will be [1, l_sum] + const bool concat_multi_seqs = (batch_size > 1 && lctx.cont_batching) ? true : false; std::vector n_tokens(batch_size); std::vector n_pasts(batch_size); std::vector n_totals(batch_size); @@ -78,15 +79,15 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu n_pasts[i] = inputs[i].n_past; n_totals[i] = inputs[i].n_total; block_ids[i] = inputs[i].request_idx * beam_size + inputs[i].beam_idx; - if (!lctx.cont_batching) { + if (!concat_multi_seqs) { n_padding.push_back(inputs[i].n_padding); if (no_padding && inputs[i].n_padding != 0) no_padding = false; } } const int seq_len_sum = std::accumulate(n_tokens.begin(), n_tokens.end(), 0); - if (!lctx.cont_batching) MODEL_ASSERT(seq_len_sum == N * batch_size); - const int infer_bs = lctx.cont_batching ? 1 : batch_size; - const int infer_seq_len = lctx.cont_batching ? seq_len_sum : N; + if (!concat_multi_seqs) MODEL_ASSERT(seq_len_sum == N * batch_size); + const int infer_bs = concat_multi_seqs ? 1 : batch_size; + const int infer_seq_len = concat_multi_seqs ? 
seq_len_sum : N; const std::vector> infer_groups = split_inputs_into_groups(inputs, n_input); const auto& model = lctx.model; const auto& hparams = model.hparams; @@ -100,7 +101,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu const int n_ctx = lctx.n_ctx; // max number fo tokens to keep in the kv-cache const int n_keep = lctx.n_keep; const bool shift_roped_k = lctx.shift_roped_k; - MODEL_ASSERT(("continuous batching mechanism doesn't support shift rope.\n", !(lctx.cont_batching && shift_roped_k))); + MODEL_ASSERT(("continuous batching mechanism doesn't support shift rope.\n", !(concat_multi_seqs && shift_roped_k))); const bool is_ring_full = shift_roped_k && n_total > n_past; const int n_cached = shift_roped_k ? std::min(n_total + N, n_ctx) : (n_past + N); // #tokens cached after kv-append int n_head = hparams.n_head; @@ -120,7 +121,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu } #endif - MODEL_ASSERT(("continuous batching mechanism doesn't support TP.\n", !(lctx.cont_batching && enable_tp))); + MODEL_ASSERT(("continuous batching mechanism doesn't support TP.\n", !(concat_multi_seqs && enable_tp))); auto& mem_per_token = lctx.mem_per_token; auto& buf_compute = lctx.buf_compute; @@ -208,7 +209,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu infer_bs); Vcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur); } - if (lctx.cont_batching) { + if (concat_multi_seqs) { size_t off_sl = 0; // per_request rope for (int gi = 0; gi < infer_groups.size(); ++gi) { @@ -414,9 +415,9 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu K = ne_permute(ctx0, K, 0, 2, 1, 3); } } else { - std::vector attn_block_ids; - for (const auto& bsi : infer_groups[gi]) { - attn_block_ids.push_back(block_ids[bsi]); + std::vector attn_block_ids(infer_groups[gi].size()); + for (int j = 0; j < infer_groups[gi].size(); ++j) { + attn_block_ids[j] = block_ids[infer_groups[gi][j]]; } K = model_kv_cache_seq_concat(&gf, &lctx, ctx0, head_size, n_cached_gi, n_head, attn_bs, attn_block_ids, il); if (is_ring_full) { diff --git a/neural_speed/models/llama/llama.cpp b/neural_speed/models/llama/llama.cpp index 41aedf08d..2f4a1d65f 100644 --- a/neural_speed/models/llama/llama.cpp +++ b/neural_speed/models/llama/llama.cpp @@ -54,21 +54,43 @@ static const bool NE_ATTN_PREFER_FP32 = // static bool llama_model_eval_internal(model_context* ctx, const model_input* inputs, const int n_input, const int n_threads) { + const int64_t t_start_us = ne_time_us(); model_context& lctx = *ctx; - // static batching for now + // single prompt const int N = inputs->n_tokens; const int n_past = inputs->n_past; const int n_total = inputs->n_total; - // enforce that the first token is BOS - if (n_total == 0 && inputs->tokens[0] != lctx.vocab.bos_token_id) { - fprintf(stderr, "%s: first token must be BOS\n", __func__); - return false; - } const int batch_size = lctx.batch_size; MODEL_ASSERT(batch_size == n_input); - - const int64_t t_start_us = ne_time_us(); + // continuous batching (no padding) + // input shape will be [1, l_sum] + if (batch_size > 1) + MODEL_ASSERT( + ("llama arch only supports continuous batching inference when giving multi prompts.", lctx.cont_batching)); + const bool concat_multi_seqs = batch_size > 1 ? true : false; + std::vector n_tokens(batch_size); + std::vector n_pasts(batch_size); + std::vector n_totals(batch_size); + const int beam_size = lctx.beam_search ? 
lctx.beam_size : 1; + std::vector block_ids(batch_size); + for (int i = 0; i < batch_size; ++i) { + n_tokens[i] = inputs[i].n_tokens; + n_pasts[i] = inputs[i].n_past; + n_totals[i] = inputs[i].n_total; + block_ids[i] = inputs[i].request_idx * beam_size + inputs[i].beam_idx; + // enforce that the first token is BOS + if (n_totals[i] == 0 && inputs[i].tokens[0] != lctx.vocab.bos_token_id) { + fprintf(stderr, "%s: first token must be BOS (token id is %d) in %dth prompt\n", __func__, + lctx.vocab.bos_token_id, i); + return false; + } + } + const int seq_len_sum = std::accumulate(n_tokens.begin(), n_tokens.end(), 0); + const int infer_bs = 1; + const int infer_seq_len = seq_len_sum; + const int kv_n_ctx_block = lctx.kv_n_ctx_block; + const std::vector> infer_groups = split_inputs_into_groups(inputs, n_input); const auto& model = lctx.model; const auto& hparams = model.hparams; @@ -82,6 +104,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp const int n_ctx = lctx.n_ctx; // max number fo tokens to keep in the kv-cache const int n_keep = lctx.n_keep; const bool shift_roped_k = lctx.shift_roped_k; + MODEL_ASSERT(("continuous batching mechanism doesn't support shift rope.\n", !(concat_multi_seqs && shift_roped_k))); // Whether kv-cache uses ring-buffer and is already full in the current run of _model_eval const bool is_ring_full = shift_roped_k && n_total > n_past; const int n_cached = shift_roped_k ? std::min(n_total + N, n_ctx) : (n_past + N); // #tokens cached after kv-append @@ -104,6 +127,8 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp n_head_kv /= world_size; } #endif + MODEL_ASSERT(("continuous batching mechanism doesn't support TP.\n", !(concat_multi_seqs && enable_tp))); + const int n_vocab = hparams.n_vocab; const int n_rot = head_size; const int n_embd_gqa = head_size * n_head_kv; @@ -129,7 +154,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp if (run_mha_reordered) { NE_ASSERT(kv_self.v->type == NE_TYPE_BTLA); // kv type should be the same attn_shape_t attn_shape = { - /* .batch_size = */ 1, + /* .batch_size = */ batch_size, /* .head_num = */ n_head, /* .heads_kv = */ n_head_kv, /* .head_size = */ head_size, @@ -147,11 +172,12 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info); } - struct ne_tensor* embd = ne_new_tensor_1d(ctx0, NE_TYPE_I32, N, NE_SIZE_CALC); + struct ne_tensor* embd = ne_new_tensor_1d(ctx0, NE_TYPE_I32, seq_len_sum, NE_SIZE_CALC); ne_set_name(embd, "embd"); - + int cpy_off = 0; for (int i = 0; i < batch_size; ++i) { - memcpy(static_cast(embd->data) + i * N, (inputs + i)->tokens, N * ne_element_size(embd)); + memcpy(static_cast(embd->data) + cpy_off, inputs[i].tokens, n_tokens[i] * ne_element_size(embd)); + cpy_off += n_tokens[i]; } #ifdef NS_TP_MODEL @@ -178,160 +204,279 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp } ne_tensor *Qcur, *Kcur, *Vcur; if (bestla_fusion_QKV_f32f32_support(model.layers[il].attn[0]->data, model.layers[il].attn[1]->data, - model.layers[il].attn[2]->data, N, model.layers[il].attn[0]->ne[1], + model.layers[il].attn[2]->data, seq_len_sum, model.layers[il].attn[0]->ne[1], model.layers[il].attn[0]->ne[0]) && n_head == n_head_kv) { // fused execution of QKV struct ne_tensor* QKVcur = ne_mul_qkv(ctx0, model.layers[il].attn[0], model.layers[il].attn[1], model.layers[il].attn[2], cur); - const size_t 
qkv_size = head_size * n_head * N; + const size_t qkv_size = head_size * n_head * seq_len_sum; const size_t qkv_bytes = qkv_size * ne_element_size(QKVcur); - Qcur = ne_reshape_3d(ctx0, ne_view_1d(ctx0, QKVcur, qkv_size, 0 * qkv_bytes), head_size, n_head, N); - Kcur = ne_reshape_3d(ctx0, ne_view_1d(ctx0, QKVcur, qkv_size, 1 * qkv_bytes), head_size, n_head_kv, N); + Qcur = ne_reshape_4d(ctx0, ne_view_1d(ctx0, QKVcur, qkv_size, 0 * qkv_bytes), head_size, n_head, infer_seq_len, + infer_bs); + Kcur = ne_reshape_4d(ctx0, ne_view_1d(ctx0, QKVcur, qkv_size, 1 * qkv_bytes), head_size, n_head_kv, infer_seq_len, + infer_bs); Vcur = ne_view_1d(ctx0, QKVcur, qkv_size, 2 * qkv_bytes); } else { - Qcur = ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_size, n_head, N); - Kcur = ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head_kv, N); + Qcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_size, n_head, infer_seq_len, + infer_bs); + Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head_kv, infer_seq_len, + infer_bs); Vcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur); } - Qcur = - ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale); + if (concat_multi_seqs) { + size_t off_sl = 0; + // per_request rope + for (int gi = 0; gi < infer_groups.size(); ++gi) { + const int qk_bs = infer_groups[gi].size(); + const int qk_sl = n_tokens[infer_groups[gi].front()]; + const int qk_n_past = n_pasts[infer_groups[gi].front()]; + struct ne_tensor* Qcur_req = + ne_view_4d(ctx0, Qcur, head_size, n_head, qk_sl, qk_bs, ne_element_size(Qcur) * head_size, + ne_element_size(Qcur) * head_size * n_head, ne_element_size(Qcur) * head_size * n_head * qk_sl, + off_sl * n_head * ne_element_size(Qcur)); + ne_build_forward_expand( + &gf, ne_rope_inplace(ctx0, Qcur_req, qk_n_past, n_rot, 0, 0, hparams.freq_base, hparams.freq_scale)); + struct ne_tensor* Kcur_req = ne_view_4d( + ctx0, Kcur, head_size, n_head_kv, qk_sl, qk_bs, ne_element_size(Kcur) * head_size, + ne_element_size(Kcur) * head_size * n_head_kv, ne_element_size(Kcur) * head_size * n_head_kv * qk_sl, + off_sl * n_head_kv * ne_element_size(Kcur)); + ne_build_forward_expand( + &gf, ne_rope_inplace(ctx0, Kcur_req, qk_n_past, n_rot, 0, 0, hparams.freq_base, hparams.freq_scale)); + off_sl += head_size * qk_bs * qk_sl; + } + } else { + Qcur = ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base, + hparams.freq_scale); + Kcur = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K + ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale); + // Vcur = ne_transpose(ctx0, ne_reshape_2d(ctx0, Vcur, head_size * n_head_kv, N)); + } ne_set_name(Qcur, "Qcur"); - Kcur = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K - ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base, hparams.freq_scale); ne_set_name(Kcur, "Kcur"); - Vcur = ne_transpose(ctx0, ne_reshape_2d(ctx0, Vcur, head_size * n_head_kv, N)); ne_set_name(Vcur, "Vcur"); // self-attention const float attn_scale = 1.0f / sqrtf(static_cast(head_size)); + struct ne_tensor* KQV_merged_contiguous = + ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC); if (!run_mha_reordered) { // store key and value to memory + // important: + // 1. 
storing RoPE-ed version of K in the KV cache! + // 2. for loop self-attention in multi seqs infer (num_request > 1) { - struct ne_tensor* k = ne_view_1d(ctx0, kv_self.k, N * n_embd_gqa, - (ne_element_size(kv_self.k) * n_embd_gqa) * (il * n_ctx + n_past)); - struct ne_tensor* v = - ne_view_2d(ctx0, kv_self.v, N, n_embd_gqa, (n_ctx)*ne_element_size(kv_self.v), - (il * n_ctx) * ne_element_size(kv_self.v) * n_embd_gqa + n_past * ne_element_size(kv_self.v)); - // important: storing RoPE-ed version of K in the KV cache! - ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur, k)); - ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur, v)); + struct ne_tensor* const k_cache = + ne_view_1d(ctx0, kv_self.k, n_ctx * n_embd_gqa * kv_n_ctx_block, + il * n_ctx * ne_element_size(kv_self.k) * n_embd_gqa * kv_n_ctx_block); + struct ne_tensor* const v_cache = + ne_view_1d(ctx0, kv_self.v, n_ctx * n_embd_gqa * kv_n_ctx_block, + il * n_ctx * ne_element_size(kv_self.v) * n_embd_gqa * kv_n_ctx_block); + // cache = [tokens, beams, requests, layers], + // tokens = [head_dim, head_num, n_ctx] (may different orders) + size_t off_N_i = 0; + for (int i = 0; i < batch_size; ++i) { + const int block_idx = block_ids[i]; + const int N_i = n_tokens[i]; + const int n_past_i = n_pasts[i]; + // batch K + struct ne_tensor* Kcur_bs_i = + ne_permute(ctx0, + ne_view_4d(ctx0, Kcur, head_size, n_head_kv, N_i, 1, ne_element_size(Kcur) * head_size, + ne_element_size(Kcur) * n_embd_gqa, ne_element_size(Kcur) * n_embd_gqa * N_i, + ne_element_size(Kcur) * off_N_i), + 0, 2, 1, 3); + struct ne_tensor* k_bs_i = + ne_view_4d(ctx0, k_cache, head_size, N_i, n_head_kv, 1, ne_element_size(k_cache) * head_size, + ne_element_size(k_cache) * head_size * n_ctx, ne_element_size(k_cache) * n_embd_gqa * n_ctx, + block_idx * n_ctx * n_embd_gqa * ne_element_size(k_cache) + + head_size * n_past_i * ne_element_size(k_cache)); + // batch V + struct ne_tensor* Vcur_bs_i = + ne_permute(ctx0, + ne_reshape_4d(ctx0, + ne_view_2d(ctx0, Vcur, n_embd_gqa, N_i, ne_element_size(Vcur) * n_embd_gqa, + ne_element_size(Vcur) * off_N_i), + head_size, n_head_kv, N_i, 1), + 1, 2, 0, 3); + struct ne_tensor* v_bs_i = ne_view_4d( + ctx0, v_cache, N_i, head_size, n_head_kv, 1, n_ctx * ne_element_size(v_cache), + n_ctx * ne_element_size(v_cache) * head_size, n_ctx * ne_element_size(v_cache) * n_embd_gqa, + block_idx * n_ctx * n_embd_gqa * ne_element_size(v_cache) + n_past_i * ne_element_size(v_cache)); + // concat + ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs_i, k_bs_i)); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs_i, v_bs_i)); + off_N_i += head_size * n_head_kv * N_i; + } } - struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3); - ne_set_name(Q, "Q"); - - struct ne_tensor* K = ne_reshape_3d( - ctx0, - ne_view_1d(ctx0, kv_self.k, n_cached * n_embd_gqa, il * n_ctx * ne_element_size(kv_self.k) * n_embd_gqa), - n_embd_gqa / n_head_kv, n_head_kv, n_cached); - if (is_ring_full) { - struct ne_tensor* cossin_cache = nullptr; - // Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N in - // a single eval execution - if (N == 1) { - cossin_cache = kv_self.cossin; + // for-loop attention + size_t off_sl = 0; + for (int gi = 0; gi < infer_groups.size(); ++gi) { + const int attn_bs = infer_groups[gi].size(); + const int attn_sl = n_tokens[infer_groups[gi].front()]; + const int attn_block_id = block_ids[infer_groups[gi].front()]; + const int attn_n_past = n_pasts[infer_groups[gi].front()]; + const int attn_n_total = 
n_totals[infer_groups[gi].front()]; + struct ne_tensor* Q = + ne_permute(ctx0, + ne_view_4d(ctx0, Qcur, head_size, n_head, attn_sl, attn_bs, ne_element_size(Qcur) * head_size, + ne_element_size(Qcur) * head_size * n_head, + ne_element_size(Qcur) * head_size * n_head * attn_sl, off_sl * ne_element_size(Qcur)), + 0, 2, 1, 3); + std::string suffix = std::to_string(gi); + ne_set_name(Q, std::string("Q_" + suffix).c_str()); + const int n_cached_gi = shift_roped_k ? n_cached : attn_n_past + attn_sl; + std::vector attn_block_ids(infer_groups[gi].size()); + for (int j = 0; j < infer_groups[gi].size(); ++j) { + attn_block_ids[j] = block_ids[infer_groups[gi][j]]; } - K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base, - hparams.freq_scale); - } - K = ne_permute(ctx0, K, 0, 2, 1, 3); - ne_set_name(K, "K"); - // K * Q - struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); - ne_set_name(KQ, "KQ"); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - struct ne_tensor* KQ_scale = ne_new_f32(ctx0, attn_scale); - ne_set_name(KQ_scale, "1/sqrt(n_embd/n_head)"); - - // KQ_scaled shape [n_cached, N, n_head, 1] - struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, KQ_scale); - ne_set_name(KQ_scaled, "KQ_scaled"); - - // KQ_masked = mask_past(KQ_scaled) - if (N > 1 || !shift_roped_k) { // TODO(Yi): shift roped-k with N > 1 next-token - KQ_scaled = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); - ne_set_name(KQ_scaled, "KQ_masked"); - } + struct ne_tensor* K = + model_kv_cache_seq_concat(&gf, &lctx, ctx0, head_size, n_cached_gi, n_head_kv, attn_bs, attn_block_ids, il); + if (is_ring_full) { + K = ne_permute(ctx0, K, 0, 2, 1, 3); + struct ne_tensor* cossin_cache = nullptr; + // Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N + // in a single eval execution + if (N == 1) cossin_cache = kv_self.cossin; + K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base, + hparams.freq_scale); + K = ne_permute(ctx0, K, 0, 2, 1, 3); + } + + // split cached V into n_head heads + struct ne_tensor* V = model_kv_cache_seq_concat(&gf, &lctx, ctx0, n_cached_gi, head_size, n_head_kv, attn_bs, + attn_block_ids, il, false); + ne_set_name(K, std::string("K_" + suffix).c_str()); + ne_set_name(V, std::string("V_" + suffix).c_str()); + + // K * Q + struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); + ne_set_name(KQ, std::string("KQ_" + suffix).c_str()); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ne_tensor* KQ_scale = ne_new_f32(ctx0, attn_scale); + ne_set_name(KQ_scale, std::string("1/sqrt(n_embd/n_head)_" + suffix).c_str()); - // KQ = soft_max(KQ_masked) - struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_scaled); - ne_set_name(KQ_soft_max, "KQ_soft_max"); + // KQ_scaled shape [n_cached, N, n_head, 1] + struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, KQ_scale); + ne_set_name(KQ_scaled, std::string("KQ_scaled_" + suffix).c_str()); - // split cached V into n_head heads - struct ne_tensor* V = - ne_view_3d(ctx0, kv_self.v, n_cached, n_embd_gqa / n_head_kv, n_head_kv, n_ctx * ne_element_size(kv_self.v), - n_ctx * ne_element_size(kv_self.v) * n_embd_gqa / n_head_kv, - n_ctx * ne_element_size(kv_self.v) * n_embd_gqa * il); - ne_set_name(V, "V"); + // KQ_masked = mask_past(KQ_scaled) + if (N > 1 || !shift_roped_k || attn_n_total == 0) { // TODO(Yi): shift roped-k with N > 1 next-token + KQ_scaled = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, attn_n_past); + ne_set_name(KQ_scaled, 
std::string("KQ_masked_" + suffix).c_str()); + } - struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max); - ne_set_name(KQV, "KQV"); + // KQ = soft_max(KQ_masked) + struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_scaled); + ne_set_name(KQ_soft_max, std::string("KQ_soft_max_" + suffix).c_str()); - // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); - ne_set_name(KQV_merged, "KQV_merged"); + struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max); + ne_set_name(KQV, std::string("KQV_" + suffix).c_str()); - // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, N, NE_SIZE_CALC)); - ne_set_name(cur, "KQV_merged_contiguous"); + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ne_tensor* KQV_merged_gi = ne_permute(ctx0, KQV, 0, 2, 1, 3); + ne_set_name(KQV_merged_gi, std::string("KQV_merged_" + suffix).c_str()); + ne_build_forward_expand(&gf, + ne_cpy(ctx0, KQV_merged_gi, + ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, + head_size * n_head * ne_element_size(KQV_merged_contiguous), + ne_element_size(KQV_merged_contiguous) * off_sl))); + off_sl += head_size * n_head * attn_sl * attn_bs; + } + ne_set_name(KQV_merged_contiguous, "KQV_merged_contiguous"); // projection (no bias) - cur = ne_mul_mat(ctx0, model.layers[il].attn[3], cur); + cur = ne_mul_mat(ctx0, model.layers[il].attn[3], KQV_merged_contiguous); } else { const auto k_size = kv_cache_info.k_bytes; const auto v_size = kv_cache_info.v_bytes; // store key and value to memory { - const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor - head_size, n_ctx, n_head_kv, // ne - 0, 0, // nb (bestla managed) - il * k_size); // offset - ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past, is_ring_full)); - const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor - head_size, n_ctx, n_head_kv, // ne - 0, 0, // nb (bestla managed) - il * v_size); // offset - // bestla always view V as (D, n_head, seq) - const auto Vcur_plain = - ne_reshape_3d(ctx0, ne_view_1d(ctx0, Vcur, n_embd_gqa * N, 0), n_embd_gqa / n_head_kv, n_head_kv, N); - ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur_plain, n_past, is_ring_full)); + size_t off_sl = 0; + for (int gi = 0; gi < infer_groups.size(); ++gi) { + const int update_bs = infer_groups[gi].size(); + const int update_sl = n_tokens[infer_groups[gi].front()]; + const int update_block_id = block_ids[infer_groups[gi].front()]; + const int update_n_past = n_pasts[infer_groups[gi].front()]; + const auto k_cache_g = ne_view_4d(ctx0, kv_self.k, // tensor + head_size, n_ctx, n_head_kv, update_bs, // ne + 0, 0, k_size, // nb (bestla managed) + il * kv_n_ctx_block * k_size + update_block_id * k_size); // offset + const auto k_cur_g = + ne_view_4d(ctx0, Kcur, head_size, n_head_kv, update_sl, update_bs, ne_element_size(Kcur) * head_size, + ne_element_size(Kcur) * n_embd_gqa, ne_element_size(Kcur) * n_embd_gqa * update_sl, + ne_element_size(Kcur) * off_sl); + ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache_g, k_cur_g, update_n_past, is_ring_full)); + struct ne_tensor* v_cache_g = + ne_view_4d(ctx0, kv_self.v, // tensor + head_size, n_ctx, n_head_kv, update_bs, // ne + 0, 0, v_size, // nb (bestla managed) + il * kv_n_ctx_block * v_size + update_block_id * v_size); // offset); + // bestla always view V as (D, n_head, seq, bs) + const auto v_cur_g = + ne_view_4d(ctx0, Vcur, 
head_size, n_head_kv, update_sl, update_bs, ne_element_size(Vcur) * head_size, + ne_element_size(Vcur) * n_embd_gqa, ne_element_size(Vcur) * n_embd_gqa * update_sl, + ne_element_size(Vcur) * off_sl); + ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache_g, v_cur_g, update_n_past, is_ring_full)); + off_sl += n_embd_gqa * update_sl * update_bs; + } } - struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3); - ne_set_name(Q, "Q"); - - struct ne_tensor* K = - ne_view_3d(ctx0, kv_self.k, // tensor - head_size, n_cached, n_head_kv, // ne - kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (bestla managed) - il * k_size); // offset - *reinterpret_cast(&K->nb[0]) = kv_cache_info.k_layout; // us nb0 for layout - if (is_ring_full) { - struct ne_tensor* cossin_cache = nullptr; - // Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N in - // a single eval execution - if (N == 1) cossin_cache = kv_self.cossin; - K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base, - hparams.freq_scale); + // for-loop attention + size_t off_sl = 0; + for (int gi = 0; gi < infer_groups.size(); ++gi) { + const int attn_bs = infer_groups[gi].size(); + const int attn_sl = n_tokens[infer_groups[gi].front()]; + const int attn_block_id = block_ids[infer_groups[gi].front()]; + const int attn_n_past = n_pasts[infer_groups[gi].front()]; + const int attn_n_total = n_totals[infer_groups[gi].front()]; + struct ne_tensor* Q = + ne_permute(ctx0, + ne_view_4d(ctx0, Qcur, head_size, n_head, attn_sl, attn_bs, ne_element_size(Qcur) * head_size, + ne_element_size(Qcur) * head_size * n_head, + ne_element_size(Qcur) * head_size * n_head * attn_sl, off_sl * ne_element_size(Qcur)), + 0, 2, 1, 3); + std::string suffix = std::to_string(gi); + ne_set_name(Q, std::string("Q_" + suffix).c_str()); + const int n_cached_gi = shift_roped_k ? 
n_cached : attn_n_past + attn_sl; + struct ne_tensor* K = + ne_view_4d(ctx0, kv_self.k, // tensor + head_size, n_cached_gi, n_head_kv, attn_bs, // ne + kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, k_size, // nb (bestla managed) + il * kv_n_ctx_block * k_size + attn_block_id * k_size); // offset + *reinterpret_cast(&K->nb[0]) = kv_cache_info.k_layout; // use nb0 for layout + if (is_ring_full) { + struct ne_tensor* cossin_cache = nullptr; + // Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N + // in a single eval execution + if (N == 1) cossin_cache = kv_self.cossin; + K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base, + hparams.freq_scale); + } + struct ne_tensor* V = ne_view_4d(ctx0, kv_self.v, // tensor + n_cached_gi, head_size, n_head_kv, attn_bs, // ne + kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, + v_size, // nb (bestla managed) + il * kv_n_ctx_block * v_size + attn_block_id * v_size); // use nb0 for layout + *reinterpret_cast(&V->nb[0]) = kv_cache_info.v_layout; + ne_set_name(K, std::string("K_" + suffix).c_str()); + ne_set_name(V, std::string("V_" + suffix).c_str()); + + ne_attn_flags_t attn_flags = NE_ATTN_FLAG_NONE; + if (NE_ATTN_PREFER_FP32) attn_flags |= NE_ATTN_FLAG_PREFER_FP32; + if (n_total == 0 || !shift_roped_k) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases + struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags); + struct ne_tensor* KQV_merged_gi = ne_view_2d(ctx0, KQV_Out, head_size * n_head, attn_sl * attn_bs, + head_size * n_head * ne_element_size(KQV_Out), 0); + ne_set_name(KQV_merged_gi, std::string("KQV_merged_" + suffix).c_str()); + ne_build_forward_expand(&gf, + ne_cpy(ctx0, KQV_merged_gi, + ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, + head_size * n_head * ne_element_size(KQV_merged_contiguous), + ne_element_size(KQV_merged_contiguous) * off_sl))); + off_sl += head_size * n_head * attn_sl * attn_bs; } - ne_set_name(K, "K"); - - struct ne_tensor* V = - ne_view_3d(ctx0, kv_self.v, // tensor - n_cached, head_size, n_head_kv, // ne - kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (bestla managed) - il * v_size); // offset - *reinterpret_cast(&V->nb[0]) = kv_cache_info.v_layout; // us nb0 for layout - ne_set_name(V, "V"); - - ne_attn_flags_t attn_flags = NE_ATTN_FLAG_NONE; - if (NE_ATTN_PREFER_FP32) attn_flags |= NE_ATTN_FLAG_PREFER_FP32; - if (n_total == 0 || !shift_roped_k) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases - struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags); - struct ne_tensor* KQV_merged_contiguous = - ne_view_2d(ctx0, KQV_Out, head_size * n_head, N, head_size * n_head * ne_element_size(KQV_Out), 0); ne_set_name(KQV_merged_contiguous, "KQV_merged_contiguous"); - // projection (no bias) cur = ne_mul_mat(ctx0, model.layers[il].attn[3], KQV_merged_contiguous); } @@ -356,7 +501,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp } if (n_expert == 0) { if (bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[0]->data, model.layers[il].ffn[1]->data, - model.layers[il].ffn[2]->data, N, cur->ne[0], + model.layers[il].ffn[2]->data, seq_len_sum, cur->ne[0], model.layers[il].ffn[0]->ne[1], model.layers[il].ffn[1]->ne[1])) { cur = ne_ffn_silu(ctx0, model.layers[il].ffn[0], model.layers[il].ffn[1], 
model.layers[il].ffn[2], cur); } else { @@ -367,54 +512,70 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp cur = ne_mul_mat(ctx0, model.layers[il].ffn[1], cur); } } else { - ne_tensor* logits = ne_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] - ne_tensor* probs = ne_soft_max_inplace(ctx0, logits); - ne_tensor* selected_experts = ne_top_k(ctx0, probs, n_expert_used); - ne_tensor* weights = ne_get_rows(ctx0, ne_reshape_3d(ctx0, probs, 1, n_expert, N), selected_experts); - weights = ne_reshape_2d(ctx0, weights, n_expert_used, N); - ne_tensor* weights_sum = ne_sum_rows(ctx0, weights); - weights_sum = ne_repeat(ctx0, weights_sum, weights); - weights = ne_div(ctx0, weights, weights_sum); - ne_tensor* moe_out = nullptr; - - for (int i = 0; i < n_expert_used; ++i) { - ne_tensor* cur_expert; - if (N == 1 && bestla_fusion_FFN_SiLu_f32f32_support( - model.layers[il].ffn_gate_exp[0]->data, model.layers[il].ffn_down_exp[0]->data, - model.layers[il].ffn_up_exp[0]->data, N, cur->ne[0], - model.layers[il].ffn_gate_exp[0]->ne[1], model.layers[il].ffn_down_exp[0]->ne[1])) { - cur_expert = ne_mul_id_ffn_silu(ctx0, model.layers[il].ffn_down_exp, model.layers[il].ffn_gate_exp, - model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur); - } else { - ne_tensor* cur_up = ne_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur); - ne_set_name(cur_up, "ffn_moe_up"); - - ne_tensor* cur_gate = - ne_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur); - ne_set_name(cur_gate, "ffn_moe_gate"); - - cur_gate = ne_silu(ctx0, cur_gate); - ne_set_name(cur_gate, "ffn_moe_silu"); - - cur_expert = ne_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd] - ne_set_name(cur_expert, "ffn_moe_gate_par"); - - cur_expert = ne_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, - cur_expert); // [n_tokens, n_embd] - ne_set_name(cur_expert, "ffn_moe_down"); - } - - cur_expert = - ne_mul(ctx0, cur_expert, - ne_repeat(ctx0, ne_view_2d(ctx0, weights, 1, N, weights->nb[1], i * weights->nb[0]), cur_expert)); - ne_set_name(cur_expert, "ffn_moe_weighted"); - - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ne_add(ctx0, moe_out, cur_expert); - ne_set_name(moe_out, "ffn_moe_out"); + // for-loop MOE (deal with sequence one by one) + struct ne_tensor* moe_out = ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC); + size_t off_sl = 0; + for (int bi = 0; bi < batch_size; ++bi) { + const int moe_sl = n_tokens[bi]; + struct ne_tensor* cur_seq = + ne_view_2d(ctx0, cur, head_size * n_head, moe_sl, head_size * n_head * ne_element_size(cur), + ne_element_size(cur) * off_sl); + std::string suffix = std::to_string(bi); + ne_tensor* logits = ne_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur_seq); // [n_tokens, num_experts] + ne_tensor* probs = ne_soft_max_inplace(ctx0, logits); + ne_tensor* selected_experts = ne_top_k(ctx0, probs, n_expert_used); + ne_tensor* weights = ne_get_rows(ctx0, ne_reshape_3d(ctx0, probs, 1, n_expert, moe_sl), selected_experts); + weights = ne_reshape_2d(ctx0, weights, n_expert_used, moe_sl); + ne_tensor* weights_sum = ne_sum_rows(ctx0, weights); + weights_sum = ne_repeat(ctx0, weights_sum, weights); + weights = ne_div(ctx0, weights, weights_sum); + ne_tensor* moe_out_i = nullptr; + + for (int i = 0; i < n_expert_used; ++i) { + ne_tensor* cur_expert; + if (moe_sl == 1 && bestla_fusion_FFN_SiLu_f32f32_support( + 
model.layers[il].ffn_gate_exp[0]->data, model.layers[il].ffn_down_exp[0]->data, + model.layers[il].ffn_up_exp[0]->data, moe_sl, cur_seq->ne[0], + model.layers[il].ffn_gate_exp[0]->ne[1], model.layers[il].ffn_down_exp[0]->ne[1])) { + cur_expert = ne_mul_id_ffn_silu(ctx0, model.layers[il].ffn_down_exp, model.layers[il].ffn_gate_exp, + model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur_seq); + } else { + ne_tensor* cur_up = + ne_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur_seq); + ne_set_name(cur_up, std::string("ffn_moe_up_" + suffix).c_str()); + + ne_tensor* cur_gate = + ne_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur_seq); + ne_set_name(cur_gate, std::string("ffn_moe_gate_" + suffix).c_str()); + + cur_gate = ne_silu(ctx0, cur_gate); + ne_set_name(cur_gate, std::string("ffn_moe_silu_" + suffix).c_str()); + + cur_expert = ne_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd] + ne_set_name(cur_expert, std::string("ffn_moe_gate_par_" + suffix).c_str()); + + cur_expert = ne_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, + cur_expert); // [n_tokens, n_embd] + ne_set_name(cur_expert, std::string("ffn_moe_down_" + suffix).c_str()); + } + + cur_expert = ne_mul( + ctx0, cur_expert, + ne_repeat(ctx0, ne_view_2d(ctx0, weights, 1, moe_sl, weights->nb[1], i * weights->nb[0]), cur_expert)); + ne_set_name(cur_expert, std::string("ffn_moe_weighted_" + suffix).c_str()); + + if (i == 0) { + moe_out_i = cur_expert; + } else { + moe_out_i = ne_add(ctx0, moe_out_i, cur_expert); + ne_set_name(moe_out_i, std::string("ffn_moe_out_" + suffix).c_str()); + } } + ne_build_forward_expand(&gf, ne_cpy(ctx0, moe_out_i, + ne_view_2d(ctx0, moe_out, head_size * n_head, moe_sl, + head_size * n_head * ne_element_size(moe_out), + ne_element_size(moe_out) * off_sl))); + off_sl += head_size * n_head * moe_sl; } cur = moe_out; @@ -471,13 +632,18 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp auto& logits_out = lctx.logits; if (lctx.logits_all) { - logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), reinterpret_cast(ne_get_data(inpL)), sizeof(float) * n_vocab * N); + logits_out.resize(n_vocab * seq_len_sum); + memcpy(logits_out.data(), reinterpret_cast(ne_get_data(inpL)), sizeof(float) * n_vocab * seq_len_sum); } else { // return result for just the last token - logits_out.resize(n_vocab); - memcpy(logits_out.data(), reinterpret_cast(ne_get_data(inpL)) + (n_vocab * (N - 1)), - sizeof(float) * n_vocab); + logits_out.resize(n_vocab * batch_size); +#pragma omp parallel for + for (int i = 0; i < batch_size; ++i) { + size_t bs_off = std::accumulate(n_tokens.begin(), n_tokens.begin() + i, 0) * n_vocab; + memcpy(logits_out.data() + (i * n_vocab), + reinterpret_cast(ne_get_data(inpL)) + bs_off + (n_vocab * (n_tokens[i] - 1)), + sizeof(float) * n_vocab); + } } } // extract embeddings diff --git a/neural_speed/models/llama/llama.h b/neural_speed/models/llama/llama.h index 2cf7bdd08..5c9f07e58 100644 --- a/neural_speed/models/llama/llama.h +++ b/neural_speed/models/llama/llama.h @@ -26,18 +26,38 @@ enum llama_model { LLAMA_65B, }; -static const model_scratch llama_mem_req(int n_layers) { +static const model_scratch llama_mem_req(int n_layers, float enlarge_scale = 1.0f) { switch (n_layers) { case 32: - return {1024ull * MB, 1024ull * MB, 1608ull * MB}; + return { + static_cast(enlarge_scale * 1024) * MB, + static_cast(enlarge_scale * 1024) * MB, + static_cast(enlarge_scale * 1608) 
* MB, + }; case 40: - return {512ull * MB, 512ull * MB, 1608ull * MB}; + return { + static_cast(enlarge_scale * 512) * MB, + static_cast(enlarge_scale * 512) * MB, + static_cast(enlarge_scale * 1608) * MB, + }; case 48: - return {512ull * MB, 512ull * MB, 2366ull * MB}; + return { + static_cast(enlarge_scale * 512) * MB, + static_cast(enlarge_scale * 512) * MB, + static_cast(enlarge_scale * 2366) * MB, + }; case 60: - return {512ull * MB, 512ull * MB, 3124ull * MB}; + return { + static_cast(enlarge_scale * 512) * MB, + static_cast(enlarge_scale * 512) * MB, + static_cast(enlarge_scale * 3124) * MB, + }; case 80: - return {2048ull * MB, 2048ull * MB, 10240ull * MB}; + return { + static_cast(enlarge_scale * 2048) * MB, + static_cast(enlarge_scale * 2048) * MB, + static_cast(enlarge_scale * 10240) * MB, + }; default: MODEL_ASSERT(false); } diff --git a/neural_speed/models/llama/llama_utils.cpp b/neural_speed/models/llama/llama_utils.cpp index 128f249a9..2bca0673c 100644 --- a/neural_speed/models/llama/llama_utils.cpp +++ b/neural_speed/models/llama/llama_utils.cpp @@ -81,7 +81,7 @@ void Llama::init(const char* path_model, model_context* ctx, int n_gpu_layer_, b n_head = hparams.n_head; n_expert = hparams.n_experts; n_expert_used = hparams.n_experts_used; - scratch = llama_mem_req(n_layer); + scratch = llama_mem_req(n_layer, lctx.model_scratch_enlarge_scale); model.scratchs = scratch; } diff --git a/neural_speed/models/model_utils/model_config.h b/neural_speed/models/model_utils/model_config.h index 16879ec5b..816780e32 100644 --- a/neural_speed/models/model_utils/model_config.h +++ b/neural_speed/models/model_utils/model_config.h @@ -96,7 +96,7 @@ struct gpt_params { int batch_size = 1; // number batch of prompt bool beam_search = false; // use beam_search or not int beam_size = 1; // only valid if use beam search - int cont_batching = false; // whether to use continuous batching (concat multi sequences) + int cont_batching = true; // whether to use continuous batching (concat multi sequences) int max_request_num = MODEL_MAX_REQUEST_NUM; // maximum num of bearable requests in current env uint32_t min_new_tokens = 0; // min new tokens for beam search generation diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h index 8f5bc43f1..2b05b216a 100644 --- a/neural_speed/models/model_utils/model_types.h +++ b/neural_speed/models/model_utils/model_types.h @@ -313,7 +313,7 @@ struct model_context { bool support_bestla_kv = false; // whether the model graph supports bestla-kvcache int beam_size = 1; int kv_n_ctx_block = 1; - bool cont_batching = false; + bool cont_batching = true; generation_config generation_conf; // global generation config std::shared_ptr bs_kv_reorder; std::vector> tensors_name; diff --git a/neural_speed/models/model_utils/model_utils.cpp b/neural_speed/models/model_utils/model_utils.cpp index c89e74f74..896ed4231 100644 --- a/neural_speed/models/model_utils/model_utils.cpp +++ b/neural_speed/models/model_utils/model_utils.cpp @@ -185,7 +185,7 @@ struct model_context_params model_context_default_params() { /*.beam_search =*/false, /*.beam_size =*/1, /*.shift_roped_k =*/false, - /*cont_batching =*/false, + /*cont_batching =*/true, /*.max_request_num =*/1, /*.gen_conf =*/generation_config(), /*model_scratch_enlarge_scale =*/1.0f, diff --git a/scripts/python_api_example_for_model_server.py b/scripts/python_api_example_for_model_server.py index bd3ab34b0..c9889ef70 100644 --- a/scripts/python_api_example_for_model_server.py +++ 
b/scripts/python_api_example_for_model_server.py @@ -1,63 +1,106 @@ import time -import neural_speed.gptj_cpp as cpp +import argparse +from pathlib import Path +from typing import List, Optional +import neural_speed.llama_cpp as cpp from transformers import AutoTokenizer -prompts = [ - "she opened the door and see", - "tell me 10 things about jazz music", - "What is the meaning of life?", - "To be, or not to be, that is the question: Whether 'tis nobler in the mind to suffer"\ - " The slings and arrows of outrageous fortune, "\ - "Or to take arms against a sea of troubles."\ - "And by opposing end them. To die—to sleep,", - "Tell me an interesting fact about llamas.", - "What is the best way to cook a steak?", - "Are you familiar with the Special Theory of Relativity and can you explain it to me?", - "Recommend some interesting books to read.", - "What is the best way to learn a new language?", - "How to get a job at Intel?", - "If you could have any superpower, what would it be?", - "I want to learn how to play the piano.", + +def main(args_in: Optional[List[str]] = None) -> None: + parser = argparse.ArgumentParser(description="example program llm model server") + parser.add_argument("--model_name", type=str, + help="model_name from huggingface or local model path: String", + required=True) + parser.add_argument("--model_path", type=Path, + help="Path to the local neural_speed low-bits model file: String", + required=True) + parser.add_argument("--max_new_tokens", type=int, + help="global query max generation token length: Int", required=False, + default=128) + parser.add_argument("--min_new_tokens", type=int, + help="global min new tokens for generation (only works in beam search): Int", + required=False, default=30) + parser.add_argument("--num_beams", type=int, + help="global num beams for beam search generation: Int", required=False, + default=4) + parser.add_argument("--do_sample", action="store_true", help="do sample for generation") + parser.add_argument("--early_stopping", action="store_true", + help="do early_stopping for beam search generation") + parser.add_argument("--return_prompt", action="store_true", + help="add prompt token ids in generation results") + parser.add_argument("--threads", type=int, help="num threads for model inference: Int", + required=False, default=8) + parser.add_argument("--max_request_num", type=int, + help="maximum number of running requests (or queries) for model inference: Int", + required=False, default=8) + parser.add_argument("--print_log", action="store_true", help="print server running logs") + parser.add_argument("--model_scratch_enlarge_scale", type=float, + help="scale for enlarge memory for model inference: Float", + required=False, default=1.0) + parser.add_argument("--memory_dtype", type=str, help="KV cache memory dtype: String", + required=False, default="auto") + args = parser.parse_args(args_in) + print(args) + + prompts = [ + "she opened the door and see", + "tell me 10 things about jazz music", + "What is the meaning of life?", + "To be, or not to be, that is the question: Whether 'tis nobler in the mind to suffer"\ + " The slings and arrows of outrageous fortune, "\ + "Or to take arms against a sea of troubles."\ + "And by opposing end them. 
To die—to sleep,", + "Tell me an interesting fact about llamas.", + "What is the best way to cook a steak?", + "Are you familiar with the Special Theory of Relativity and can you explain it to me?", + "Recommend some interesting books to read.", + "What is the best way to learn a new language?", + "How to get a job at Intel?", + "If you could have any superpower, what would it be?", + "I want to learn how to play the piano.", ] + tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) -model_name = "EleutherAI/gpt-j-6b" # model_name from huggingface or local model path -tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + res_collect = [] + # response function (deliver generation results and current remain working size in server) + def f_response(res, working): + ret_token_ids = [r.token_ids for r in res] + res_collect.extend(ret_token_ids) + ans = tokenizer.batch_decode(ret_token_ids, skip_special_tokens=True, + clean_up_tokenization_spaces=False) + print(f"working_size: {working}, ans:", flush=True) + for a in ans: + print(a) + print("=====================================") -res_collect = [] -def f_response(res, working): - ret_token_ids = [r.token_ids for r in res] - res_collect.extend(ret_token_ids) - ans = tokenizer.batch_decode(ret_token_ids, skip_special_tokens=True, - clean_up_tokenization_spaces=False) - print(f"working_size: {working}, ans:", flush=True) - for a in ans: - print(a) - print("=====================================") + added_count = 0 + s = cpp.ModelServer(f_response, + str(args.model_path), + max_new_tokens=args.max_new_tokens, + num_beams=args.num_beams, + min_new_tokens=args.min_new_tokens, + early_stopping=args.early_stopping, + do_sample=args.do_sample, + continuous_batching=True, + return_prompt=args.return_prompt, + threads=args.threads, + max_request_num=args.max_request_num, + print_log=args.print_log, + model_scratch_enlarge_scale = args.model_scratch_enlarge_scale, + memory_dtype= args.memory_dtype, + ) + for i in range(len(prompts)): + p_token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids.tolist() + s.issueQuery([cpp.Query(i, p_token_ids)]) + added_count += 1 + time.sleep(2) # adjust query sending time interval -model_path = "gptj-q4.bin" # please set your corresponding local neural_speed low-bits model file -added_count = 0 -s = cpp.ModelServer(f_response, # response function (deliver generation results and current remain working size in server) - model_path, # model_path - max_new_tokens=128, # global query max generation token length - num_beams=4, # global beam search related generation parameters - min_new_tokens=30, # global beam search related generation parameters (default: 0) - early_stopping=True, # global beam search related generation parameters (default: False) - continuous_batching=True, # turn on continuous batching mechanism (default: True) - return_prompt=True, # also return prompt token ids in generation results (default: False) - threads=56, # number of threads in model evaluate process (please bind cores if need) - max_request_num=8, # maximum number of running requests (or queries, default: 8) - print_log=True, # print server running logs (default: False) - model_scratch_enlarge_scale = 1, # model memory scratch enlarge scale (default: 1) - ) -for i in range(len(prompts)): - p_token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids.tolist() - s.issueQuery([cpp.Query(i, p_token_ids)]) - added_count += 1 - time.sleep(2) # adjust query sending time interval + # 
recommend to use time.sleep in while loop to exit program + # let cpp server owns more resources + while (added_count != len(prompts) or not s.Empty()): + time.sleep(1) + del s + print("should finished") -# recommend to use time.sleep in while loop to exit program -# let cpp server owns more resources -while (added_count != len(prompts) or not s.Empty()): - time.sleep(1) -del s -print("should finished") +if __name__ == "__main__": + main() diff --git a/tests/test_model_server.py b/tests/test_model_server.py new file mode 100644 index 000000000..d7e7fd102 --- /dev/null +++ b/tests/test_model_server.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest +import shutil +from neural_speed import Model +import neural_speed.llama_cpp as cpp +from transformers import AutoTokenizer + +class TestModelServer(unittest.TestCase): + + @classmethod + def setUpClass(cls): + pass + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree("./runtime_outs", ignore_errors=True) + + def test_model_server(self): + prompts = [ + "she opened the door and see", + "tell me 10 things about jazz music", + "What is the meaning of life?", + "To be, or not to be, that is the question: Whether 'tis nobler in the mind to suffer"\ + " The slings and arrows of outrageous fortune, "\ + "Or to take arms against a sea of troubles."\ + "And by opposing end them. 
To die—to sleep,",
+            "Tell me an interesting fact about llamas.",
+            "What is the best way to cook a steak?",
+            "Are you familiar with the Special Theory of Relativity and can you explain it to me?",
+            "Recommend some interesting books to read.",
+            "What is the best way to learn a new language?",
+            "How to get a job at Intel?",
+            "If you could have any superpower, what would it be?",
+            "I want to learn how to play the piano.",
+        ]
+        model_name = "/tf_dataset2/models/nlp_toolkit/llama-2-7b-chat/Llama-2-7b-chat-hf"
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        model = Model()
+        # get quantized model
+        model.init(model_name, use_quant=True, weight_dtype="int4", compute_dtype="int8")
+        del model
+        model_path = "./runtime_outs/ne_llama_q_int4_bestla_cint8_g32.bin"
+
+        res_collect = []
+        # response function (delivers generation results and the current remaining working size in the server)
+        def f_response(res, working):
+            ret_token_ids = [r.token_ids for r in res]
+            res_collect.extend(ret_token_ids)
+            ans = tokenizer.batch_decode(ret_token_ids, skip_special_tokens=True,
+                                         clean_up_tokenization_spaces=False)
+            print(f"working_size: {working}, ans:", flush=True)
+            for a in ans:
+                print(a)
+            print("=====================================")
+
+        for md in ["auto", "f16"]:
+            if md == "auto":
+                print("=======MHA MODEL SERVER TESTING=========")
+            else:
+                print("=======NON-MHA MODEL SERVER TESTING=========")
+            added_count = 0
+            s = cpp.ModelServer(f_response,
+                                model_path,
+                                max_new_tokens=128,
+                                num_beams=4,
+                                min_new_tokens=30,
+                                early_stopping=True,
+                                do_sample=False,
+                                continuous_batching=True,
+                                return_prompt=True,
+                                max_request_num=8,
+                                threads=56,
+                                print_log=False,
+                                model_scratch_enlarge_scale=1.0,
+                                memory_dtype=md,
+                                )
+            for i in range(len(prompts)):
+                p_token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids.tolist()
+                s.issueQuery([cpp.Query(i, p_token_ids)])
+                added_count += 1
+                time.sleep(2)  # adjust query sending time interval
+
+            # recommend calling time.sleep in a while loop before exiting the program
+            # so that the cpp server can own more resources
+            while (added_count != len(prompts) or not s.Empty()):
+                time.sleep(1)
+            del s
+            print("should finish")
+
+if __name__ == "__main__":
+    unittest.main()
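
The per-group attention loop added in `llama_model_eval_internal` builds one set of Q/K/V views per `infer_groups` entry, taking a single `attn_sl`, `attn_n_past`, and `attn_block_id` from the group's first member; this suggests sequences are bucketed by matching shapes before the graph is built. Below is a hypothetical Python sketch of such a grouping step; the helper `group_sequences` and its example inputs are assumptions for illustration only and are not part of this patch.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def group_sequences(n_tokens: List[int], n_pasts: List[int]) -> List[List[int]]:
    """Group sequence indices so each group shares one token count and one n_past."""
    groups: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for idx, (sl, past) in enumerate(zip(n_tokens, n_pasts)):
        groups[(sl, past)].append(idx)
    return list(groups.values())

# e.g. three fresh prompts of length 4 plus two decoding requests with one new token each
print(group_sequences(n_tokens=[4, 4, 1, 4, 1], n_pasts=[0, 0, 7, 0, 9]))
# -> [[0, 1, 3], [2], [4]]
```

Each returned group would then drive one iteration of the for-loop attention above, with `attn_bs` equal to the group size.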
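
The logits handling in the patch changes from copying a single last-token row to gathering one row per sequence, using a running prefix sum over `n_tokens` to locate each sequence's slice in the flat output buffer. The following NumPy sketch mirrors that indexing scheme; it is illustration only, and `flat_logits`, the toy `n_vocab`, and the example `n_tokens` values are assumptions rather than code from this patch.

```python
import numpy as np

n_vocab = 8            # toy vocabulary size (assumption for illustration)
n_tokens = [3, 1, 5]   # tokens fed for each sequence in this continuous batch
seq_len_sum = sum(n_tokens)

# flat buffer holding one logits row per input token of every sequence, concatenated
flat_logits = np.arange(seq_len_sum * n_vocab, dtype=np.float32).reshape(seq_len_sum, n_vocab)

last_token_logits = []
offset = 0
for sl in n_tokens:
    # the last row of each sequence's slice is that sequence's next-token logits
    last_token_logits.append(flat_logits[offset + sl - 1])
    offset += sl

last_token_logits = np.stack(last_token_logits)   # shape: (batch_size, n_vocab)
print(last_token_logits.shape)                    # -> (3, 8)
```

Each gathered row corresponds to the `bs_off + n_vocab * (n_tokens[i] - 1)` offset computed in the modified `llama_model_eval_internal`.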