diff --git a/developer_document.md b/developer_document.md
index 2bc7ac7f6..57ba9a085 100644
--- a/developer_document.md
+++ b/developer_document.md
@@ -1,6 +1,6 @@
 ## Before you start
-ITREX LLM C++ Runtime has already supported some popular models like `LLAMA`,`GPT-J`, `GPT-NEOX`, `DOLLY`, etc.These LLMs have similar architectures and some of them share the same architect (`DOLLY` and `GPT-NEOX`). Before adding a new model, you can checkout its architecture (from Huggingface `config.json`) whether is in our [supported list](./models/model_utils/model_types.h#L68).
+ITREX LLM C++ Runtime (`Neural Speed`) already supports some popular models like `LLAMA`, `GPT-J`, `GPT-NEOX`, `DOLLY`, etc. These LLMs have similar architectures, and some of them share the same architecture (`DOLLY` and `GPT-NEOX`). Before adding a new model, you can check whether its architecture (from the Huggingface `config.json`) is in our [supported list](./neural_speed/models/model_utils/model_types.h#L68).
 However, LLM inference is complicated. A new model may have its own: 1. special tokenizer (or vocab); 2. architecture (or forward pipeline); 3. operators (or kernels). Generally speaking, the first and second points appear frequently for transformer LLMs. I will show you how to run a new model as quickly as possible when it has none of the problems above, or only problem 1. The next sections discuss problem 2; problem 3 is beyond the scope of this document.
@@ -84,7 +84,7 @@ The term **"hyperparameters"** describes a value that is used to configure the be
 - n_vocab: the size of the model's vocabulary
 - n_embd: the size of the model's "embedding layer", which is used during prompt ingestion.
 - n_layer: the number of layers in the model; each layer represents a set of weights.
-Here we will use [convert_gptneox.py](scripts/convert_gptneox.py#L96) as an example,
+Here we will use [convert_gptneox.py](neural_speed/convert/convert_gptneox.py#L96) as an example,
 ```python
 fout.write(struct.pack("i", hparams["num_attention_heads"]))
 fout.write(struct.pack("i", hparams.get("n_head_kv", 0))) # multi-query attention
@@ -96,7 +96,7 @@ The above `fout` is the file we need to get, and the `num_attention`, `n_head_kv
 As the name implies, a model's vocabulary comprises components that are used by the model to generate language (text). However, unlike the vocabulary of a human, which consists of words, the vocabulary of a large language model consists of "tokens". A token can be an entire word, but oftentimes they are word fragments. Just like humans can compose millions of words from just a dozen or two letters, large language models use tokens to express a large number of words from a relatively smaller number of components. Consider a vocabulary with the following tokens: `whi`, `ch`, `le`, `who`, and `a`; this vocabulary can be used to create the English words `"which"`, `"while"`, `"who"`, `"a"`, and `"leach"`. How would the behavior change if the model contained the following tokens: `wh`, `ich`, `ile`, `o`, and `leach`? Choices such as these allow model-creators to tune the behavior and performance of their models. As described above, the model's hyperparameters typically contain a value that specifies the number of tokens in the vocabulary. The vocabulary is encoded as a list of tokens, each of which includes a 32-bit integer that specifies the length of the token.
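+Below is a minimal sketch of that length-prefixed layout (illustrative only, not the exact convert script); `fout` is the output file and `tokenizer` is assumed to be a Hugging Face tokenizer as in the convert scripts:
+```python
+import struct
+
+def write_vocab(fout, tokenizer, n_vocab):
+    # hypothetical helper: each token is written as a 32-bit length followed by its raw bytes
+    reverse_vocab = {idx: tok for tok, idx in tokenizer.get_vocab().items()}
+    for i in range(n_vocab):
+        text = reverse_vocab[i].encode("utf-8")
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+```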
 If your model has a new tokenizer, we suggest using a Python tokenizer from transformers and feeding the input_ids to the model's Python API (see the Python examples in the scripts folder).
-Here we will use [convert_gptneox.py](scripts/convert_gptneox.py#L122) as an example to processed the vocabulary of gptneox and written it into `fout`.
+Here we will use [convert_gptneox.py](neural_speed/convert/convert_gptneox.py#L122) as an example of how the vocabulary of gptneox is processed and written into `fout`.
 ```python
 encoder = tokenizer.vocab
 encoder.update(tokenizer.get_added_vocab())
@@ -105,10 +105,10 @@ byte_decoder = {v:k for k, v in byte_encoder.items()}
 ```
 ## 1.3. Model weights
-Finally, and largest, component of a ITREX GRAPH file is the weights of the LLM that the file represents. Abstractly, a large language model is software that is used to generate language - just like software that is used to generate images can be improved by increasing the number of colors with which images can be rendered, large language models can be improved by increasing the number of weights in the model. The total number of weights in a model is referred to as the "size" of that model. For example, the dolly-v2-3b implementation of the gpt-neox-20b language model architecture is available in several sizes, like 3B and 20B, which stand for 3 billion and 20 billion, respectively. These numbers refer to the total number of weights in that model.
+The final, and largest, component of a `Neural Speed` GRAPH file is the weights of the LLM that the file represents. Abstractly, a large language model is software that is used to generate language - just like software that is used to generate images can be improved by increasing the number of colors with which images can be rendered, large language models can be improved by increasing the number of weights in the model. The total number of weights in a model is referred to as the "size" of that model. For example, the dolly-v2-3b implementation of the gpt-neox-20b language model architecture is available in several sizes, like 3B and 20B, which stand for 3 billion and 20 billion, respectively. These numbers refer to the total number of weights in that model.
 As described in the hyperparameters section, weights are grouped in sets called "layers", which, like hyperparameters, have structures that are uniquely defined by the model architecture; within a layer, weights are grouped in structures called "tensors". So, for instance, both dolly-v2-3B and gpt-neox-20B use layers that comprise the same tensors, but dolly-v2-3B has relatively fewer layers when compared to gpt-neox-20B.
-Here we will use [convert_gptneox.py](scripts/convert_gptneox.py#L149) as an example to convert model weights to `fout`.
+Here we will use [convert_gptneox.py](neural_speed/convert/convert_gptneox.py#L149) as an example of converting the model weights and writing them to `fout`.
 ```python
 fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
 for i in range(n_dims):
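+    # illustrative continuation (not the exact convert script): write the shape in
+    # reverse order, then the encoded tensor name and the raw numpy data
+    fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+fout.write(str)    # `str` is the UTF-8 encoded tensor name used above
+data.tofile(fout)  # `data` is assumed to be a numpy array already cast to the target dtype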
@@ -199,7 +199,7 @@ n_embd = hparams.n_embd;
 n_vocab = hparams.n_vocab;
 n_layer = hparams.n_layer;
 ```
-The weights of the model in the ITREX Graph file will be loaded in [model load function](neural_speed/models/gptneox/gptneox_utils.cpp#L71). Here, we'll re-read some of the parameters and weights of the converted binary,include ffn, attention, and norm weight and bias, We'll use the mapping between the name and the weight to read the weight we need. It is shown below.
+The weights of the model in the `Neural Speed` Graph file will be loaded in the [model load function](neural_speed/models/gptneox/gptneox_utils.cpp#L71). Here, we'll re-read some of the parameters and weights of the converted binary, including the FFN, attention, and norm weights and biases. We'll use the mapping between the name and the weight to read the weights we need, as shown below.
 ```cpp
 model.others[0] = ml->get_tensor("gpt_neox.embed_in.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
 model.others[1] = ml->get_tensor("gpt_neox.final_layer_norm.weight", {n_embd}, NE_BACKEND_CPU);
@@ -227,75 +227,167 @@ The `inpL` in the code above is equivalent to the `hidden_states` in the pytorch
 When enabling a new model, we should implement the `new_model.cpp` for it. Most of our model examples only support single-prompt processing. You need to add a `batch-dim` for tensors and concat the `KV cache` per batch if you want to try multi-batch inference.
+
+We recommend using continuous batching since it avoids padding and can boost throughput in both offline and server scenarios. Here is an [example](https://github.com/intel/neural-speed/pull/145/files#diff-54aa87c707bdc2d4c0145d612079fe976ecefb1dbf1efa75e1e86d14ea396185) of how to modify the `LLAMA` [source cpp file](neural_speed/models/llama/llama.cpp). The important modifications are shown below.
+
 ```diff
-// copy batch inputs -- struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N); -+ struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N * batch_size); - ne_set_name(embd, "embd"); -- memcpy(embd->data, tokens, N * ne_element_size(embd)); +// do not forget to copy all sequences in and out +// 1. analyze model inputs and prepare information for inference (especially for splitting sequences) ++ // continuous batching (no padding) ++ // input shape will be [1, l_sum] ++ if (batch_size > 1) ++ MODEL_ASSERT( ++ ("llama arch only supports continuous batching inference when giving multi prompts.", ++ lctx.cont_batching)); ++ const bool concat_multi_seqs = batch_size > 1 ? true : false; ++ std::vector<int> n_tokens(batch_size); ++ std::vector<int> n_pasts(batch_size); ++ std::vector<int> n_totals(batch_size); ++ const int beam_size = lctx.beam_search ?
lctx.beam_size : 1; ++ std::vector block_ids(batch_size); + for (int i = 0; i < batch_size; ++i) { -+ memcpy(static_cast(embd->data) + i * N, tokens + i * N, N * ne_element_size(embd)); ++ n_tokens[i] = inputs[i].n_tokens; ++ n_pasts[i] = inputs[i].n_past; ++ n_totals[i] = inputs[i].n_total; ++ block_ids[i] = inputs[i].request_idx * beam_size + inputs[i].beam_idx; ++ // enforce that the first token is BOS ++ if (n_totals[i] == 0 && inputs[i].tokens[0] != lctx.vocab.bos_token_id) { ++ fprintf(stderr, "%s: first token must be BOS (token id is %ld) in %dth prompt\n", __func__, ++ lctx.vocab.bos_token_id, i); ++ return false; ++ } + } - -// add batch-dim for tensors -- struct ne_tensor* Qcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, n_embd / n_head, n_head, N, cur->nb[1] / n_head, cur->nb[1], 0 * sizeof(float) * n_embd / n_head)); -- struct ne_tensor* Kcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, n_embd / n_head, n_head, N, cur->nb[1] / n_head, cur->nb[1], 1 * sizeof(float) * n_embd / n_head)); -- struct ne_tensor* Vcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, n_embd / n_head, n_head, N, cur->nb[1] / n_head, cur->nb[1], 2 * sizeof(float) * n_embd / n_head)); -+ struct ne_tensor* Qcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N * batch_size, cur->nb[1] / n_head, cur->nb[1], 0 * sizeof(float) * head_dim)); -+ struct ne_tensor* Kcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N * batch_size, cur->nb[1] / n_head, cur->nb[1], 1 * sizeof(float) * head_dim)); -+ struct ne_tensor* Vcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N * batch_size, cur->nb[1] / n_head, cur->nb[1], 2 * sizeof(float) * head_dim)); - -// concat kv cache per-batch -- struct ne_tensor* k = -- ne_view_1d(ctx0, kv_self.k, N * n_embd, (ne_element_size(kv_self.k) * n_embd) * (il * n_ctx + n_past)); ++ const int seq_len_sum = std::accumulate(n_tokens.begin(), n_tokens.end(), 0); ++ const int infer_bs = 1; ++ const int infer_seq_len = seq_len_sum; +// max batch num for a inference, usually it's larger than num of model inputs (beam search or dynamic batch size inference) ++ const int kv_n_ctx_block = lctx.kv_n_ctx_block; +// divide kv_n_ctx_block into server groups and each of them has same shape inside ++ const std::vector> infer_groups = split_inputs_into_groups(inputs, n_input); + +// 2. 
for-loop RoPE +// reshape 4d for Q K tensors (add batch dimension) +- Qcur = ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_size, n_head, N); +- Kcur = ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head_kv, N); ++ Qcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_size, n_head, ++ infer_seq_len, infer_bs); ++ Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head_kv, ++ infer_seq_len, infer_bs); +// per_request rope ++ for (int gi = 0; gi < infer_groups.size(); ++gi) { ++ const int qk_bs = infer_groups[gi].size(); ++ const int qk_sl = n_tokens[infer_groups[gi].front()]; ++ const int qk_n_past = n_pasts[infer_groups[gi].front()]; ++ struct ne_tensor* Qcur_req = ++ ne_view_4d(ctx0, Qcur, head_size, n_head, qk_sl, qk_bs, ne_element_size(Qcur) * ++ head_size,ne_element_size(Qcur) * head_size * n_head, ne_element_size(Qcur) * head_size ++ * n_head * qk_sl, off_sl * n_head * ne_element_size(Qcur)); ++ ne_build_forward_expand( ++ &gf, ne_rope_inplace(ctx0, Qcur_req, qk_n_past, n_rot, 0, 0, hparams.freq_base, ++ hparams.freq_scale)); ++ struct ne_tensor* Kcur_req = ne_view_4d( ++ ctx0, Kcur, head_size, n_head_kv, qk_sl, qk_bs, ne_element_size(Kcur) * head_size, ++ ne_element_size(Kcur) * head_size * n_head_kv, ne_element_size(Kcur) * head_size * n_head_kv * ++ qk_sl, off_sl * n_head_kv * ne_element_size(Kcur)); ++ ne_build_forward_expand( ++ &gf, ne_rope_inplace(ctx0, Kcur_req, qk_n_past, n_rot, 0, 0, hparams.freq_base, ++ hparams.freq_scale)); ++ off_sl += head_size * qk_bs * qk_sl; ++ } + +// 3. for-loop kv cache concat +- struct ne_tensor* k = ne_view_1d(ctx0, kv_self.k, N * n_embd_gqa, +- (ne_element_size(kv_self.k) * n_embd_gqa) * (il * n_ctx + n_past)); - struct ne_tensor* v = -- ne_view_2d(ctx0, kv_self.v, N, n_embd, (n_ctx)*ne_element_size(kv_self.v), -- (il * n_ctx) * ne_element_size(kv_self.v) * n_embd + n_past * ne_element_size(kv_self.v)); - +- ne_view_2d(ctx0, kv_self.v, N, n_embd_gqa, (n_ctx)*ne_element_size(kv_self.v), +- (il * n_ctx) * ne_element_size(kv_self.v) * n_embd_gqa + n_past * ne_element_size(kv_self.v)); +- // important: storing RoPE-ed version of K in the KV cache! 
- ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur, k)); - ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur, v)); -+ std::vector Kcur_bs(batch_size); -+ std::vector Vcur_bs(batch_size); -+ std::vector k_bs(batch_size); -+ std::vector v_bs(batch_size); ++ struct ne_tensor* const k_cache = ++ ne_view_1d(ctx0, kv_self.k, n_ctx * n_embd_gqa * kv_n_ctx_block, ++ il * n_ctx * ne_element_size(kv_self.k) * n_embd_gqa * kv_n_ctx_block); ++ struct ne_tensor* const v_cache = ++ ne_view_1d(ctx0, kv_self.v, n_ctx * n_embd_gqa * kv_n_ctx_block, ++ il * n_ctx * ne_element_size(kv_self.v) * n_embd_gqa * kv_n_ctx_block); ++ // cache = [tokens, beams, requests, layers], ++ // tokens = [head_dim, head_num, n_ctx] (may different orders) ++ size_t off_N_i = 0; + for (int i = 0; i < batch_size; ++i) { - // batch K -+ Kcur_bs[i] = ne_permute(ctx0, -+ ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim, -+ ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N, -+ i * ne_element_size(Kcur) * n_embd * N), -+ 0, 2, 1, 3); -+ k_bs[i] = ne_view_4d( -+ ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim, -+ ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx, -+ ((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block + -+ i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k))); - - // batch V -+ Vcur_bs[i] = ne_permute(ctx0, -+ ne_reshape_4d(ctx0, -+ ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd, -+ i * ne_element_size(Vcur) * n_embd * N), -+ head_dim, n_head, N, 1), -+ 1, 2, 0, 3); -+ v_bs[i] = -+ ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v), -+ n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd, -+ ((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block + -+ i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v))); -+ ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i])); -+ ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i])); -+ } - -// copy batch output logits out - // return result for just the last token -+ size_t bs_stride = n_vocab * N; -- logits_out.resize(n_vocab); -- memcpy(logits_out.data(), (float*)ne_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); -+ logits_out.resize(n_vocab * batch_size); -+ for (int i = 0; i < batch_size; ++i) { -+ memcpy(logits_out.data() + (i * n_vocab), (float*)ne_get_data(inpL) + (i * bs_stride) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); -+ } ++ const int block_idx = block_ids[i]; ++ const int N_i = n_tokens[i]; ++ const int n_past_i = n_pasts[i]; ++ // batch K ++ struct ne_tensor* Kcur_bs_i = ++ ne_permute(ctx0, ++ ne_view_4d(ctx0, Kcur, head_size, n_head_kv, N_i, 1, ne_element_size(Kcur) * head_size, ++ ne_element_size(Kcur) * n_embd_gqa, ne_element_size(Kcur) * n_embd_gqa * N_i, ++ ne_element_size(Kcur) * off_N_i), ++ 0, 2, 1, 3); ++ struct ne_tensor* k_bs_i = ++ ne_view_4d(ctx0, k_cache, head_size, N_i, n_head_kv, 1, ne_element_size(k_cache) * head_size, ++ ne_element_size(k_cache) * head_size * n_ctx, ne_element_size(k_cache) * n_embd_gqa * n_ctx, ++ block_idx * n_ctx * n_embd_gqa * ne_element_size(k_cache) + ++ head_size * n_past_i * ne_element_size(k_cache)); ++ // batch V ++ struct ne_tensor* Vcur_bs_i = ++ ne_permute(ctx0, ++ ne_reshape_4d(ctx0, ++ ne_view_2d(ctx0, Vcur, n_embd_gqa, N_i, ne_element_size(Vcur) * n_embd_gqa, ++ 
ne_element_size(Vcur) * off_N_i), ++ head_size, n_head_kv, N_i, 1), ++ 1, 2, 0, 3); ++ struct ne_tensor* v_bs_i = ne_view_4d( ++ ctx0, v_cache, N_i, head_size, n_head_kv, 1, n_ctx * ne_element_size(v_cache), ++ n_ctx * ne_element_size(v_cache) * head_size, n_ctx * ne_element_size(v_cache) * n_embd_gqa, ++ block_idx * n_ctx * n_embd_gqa * ne_element_size(v_cache) + n_past_i * ne_element_size(v_cache)); ++ // concat ++ ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs_i, k_bs_i)); ++ ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs_i, v_bs_i)); ++ off_N_i += head_size * n_head_kv * N_i; + +// 4. for-loop attention +// prepare final QKV_merged tensor ++ struct ne_tensor* KQV_merged_contiguous = ++ ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC); +// prepare Q K V tensors for each prompt ++ size_t off_sl = 0; ++ for (int gi = 0; gi < infer_groups.size(); ++gi) { ++ const int attn_bs = infer_groups[gi].size(); ++ const int attn_sl = n_tokens[infer_groups[gi].front()]; ++ const int attn_block_id = block_ids[infer_groups[gi].front()]; ++ const int attn_n_past = n_pasts[infer_groups[gi].front()]; ++ const int attn_n_total = n_totals[infer_groups[gi].front()]; ++ struct ne_tensor* Q = ++ ne_permute(ctx0, ++ ne_view_4d(ctx0, Qcur, head_size, n_head, attn_sl, attn_bs, ne_element_size(Qcur) * head_size, ++ ne_element_size(Qcur) * head_size * n_head, ++ ne_element_size(Qcur) * head_size * n_head * attn_sl, off_sl * ne_element_size(Qcur)), ++ 0, 2, 1, 3); ++ std::string suffix = std::to_string(gi); ++ ne_set_name(Q, std::string("Q_" + suffix).c_str()); ++ const int n_cached_gi = shift_roped_k ? n_cached : attn_n_past + attn_sl; ++ std::vector attn_block_ids(infer_groups[gi].size()); ++ for (int j = 0; j < infer_groups[gi].size(); ++j) { ++ attn_block_ids[j] = block_ids[infer_groups[gi][j]]; ++ } ++ struct ne_tensor* K = ++ model_kv_cache_seq_concat(&gf, &lctx, ctx0, head_size, n_cached_gi, n_head_kv, attn_bs, attn_block_ids, il); ++ // split cached V into n_head heads ++ struct ne_tensor* V = model_kv_cache_seq_concat(&gf, &lctx, ctx0, n_cached_gi, head_size, n_head_kv, attn_bs, ++ attn_block_ids, il, false); ++ ne_set_name(K, std::string("K_" + suffix).c_str()); ++ ne_set_name(V, std::string("V_" + suffix).c_str()); +// compute V * softmax(mask(KQ)) ++ .... +// copy each KQV_merged_i into KQV_merged ++ struct ne_tensor* KQV_merged_gi = ne_permute(ctx0, KQV, 0, 2, 1, 3); ++ ne_set_name(KQV_merged_gi, std::string("KQV_merged_" + suffix).c_str()); + ++ ne_build_forward_expand(&gf, ++ ne_cpy(ctx0, KQV_merged_gi, ++ ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, head_size * n_head * ne_element_size(KQV_merged_contiguous), ne_element_size(KQV_merged_contiguous) * off_sl))); ++ off_sl += head_size * n_head * attn_sl * attn_bs; ``` ## 2.3. Application