From 52531fdff88764282c1b233174721aab8347252d Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 8 Jan 2024 11:18:32 +0200
Subject: [PATCH] main : add self-extend support (#4815)

* examples : add passkey test

* passkey : better prints

* passkey : select pass key pos from CLI

* passkey : simplify n_past logic

* llama : "self-extend"-like context extension

* passkey : add comment

* main : add Self-Extend support

* llama : add comment about llama_kv_cache_seq_div
---
 common/common.cpp      | 18 +++++++++
 common/common.h        |  2 +
 examples/main/main.cpp | 83 +++++++++++++++++++++++++++++++-----------
 llama.h                |  4 ++
 4 files changed, 85 insertions(+), 22 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index eacaee18e0907..6b4913a656573 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -220,6 +220,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "--grp-attn-n" || arg == "-gan") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+
+            params.grp_attn_n = std::stoi(argv[i]);
+        } else if (arg == "--grp-attn-w" || arg == "-gaw") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+
+            params.grp_attn_w = std::stoi(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -904,6 +918,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        Not recommended since this is both slower and uses more VRAM.\n");
 #endif // GGML_USE_CUBLAS
 #endif
+    printf("  -gan N, --grp-attn-n N\n");
+    printf("                        group-attention factor (default: %d)\n", params.grp_attn_n);
+    printf("  -gaw N, --grp-attn-w N\n");
+    printf("                        group-attention width (default: %.1f)\n", (double)params.grp_attn_w);
     printf("  --verbose-prompt      print prompt before generation\n");
     printf("  -dkvc, --dump-kv-cache\n");
     printf("                        verbose print of the KV cache\n");
diff --git a/common/common.h b/common/common.h
index 9659aa0453ff8..e2bbfc258b646 100644
--- a/common/common.h
+++ b/common/common.h
@@ -62,6 +62,8 @@ struct gpt_params {
     int32_t main_gpu                        = 0;    // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0};  // how split tensors should be distributed across GPUs
     int32_t n_beams                         = 0;    // if non-zero then use beam search of given width.
+    int32_t grp_attn_n                      = 1;    // group-attention factor
+    int32_t grp_attn_w                      = 512;  // group-attention width
     float   rope_freq_base                  = 0.0f; // RoPE base frequency
     float   rope_freq_scale                 = 0.0f; // RoPE frequency scaling factor
     float   yarn_ext_factor                 = -1.0f; // YaRN extrapolation mix factor
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index c096f110b32c5..5ea67051f3654 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -439,6 +439,21 @@ int main(int argc, char ** argv) {
         LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
         LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
         LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+
+    // group-attention state
+    // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
+    int ga_i = 0;
+
+    const int ga_n = params.grp_attn_n;
+    const int ga_w = params.grp_attn_w;
+
+    if (ga_n != 1) {
+        GGML_ASSERT(ga_n > 0                    && "grp_attn_n must be positive");                     // NOLINT
+        GGML_ASSERT(ga_w % ga_n == 0            && "grp_attn_w must be a multiple of grp_attn_n");     // NOLINT
+      //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of grp_attn_w");    // NOLINT
+      //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
+        LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
+    }
     LOG_TEE("\n\n");
 
     if (params.interactive) {
@@ -500,37 +515,61 @@ int main(int argc, char ** argv) {
             fflush(stdout);
         }
 
-        // infinite text generation via context swapping
-        // if we run out of context:
-        // - take the n_keep first tokens from the original prompt (via n_past)
-        // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-        if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) {
-            if (params.n_predict == -2) {
-                LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
-                break;
-            }
+        if (ga_n == 1) {
+            // infinite text generation via context shifting
+            // if we run out of context:
+            // - take the n_keep first tokens from the original prompt (via n_past)
+            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
+            if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) {
+                if (params.n_predict == -2) {
+                    LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                    break;
+                }
 
-            const int n_left    = n_past - params.n_keep - 1;
-            const int n_discard = n_left/2;
+                const int n_left    = n_past - params.n_keep - 1;
+                const int n_discard = n_left/2;
 
-            LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                n_past, n_left, n_ctx, params.n_keep, n_discard);
+                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                    n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-            llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+                llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
-            n_past -= n_discard;
+                n_past -= n_discard;
 
-            if (ctx_guidance) {
-                n_past_guidance -= n_discard;
+                if (ctx_guidance) {
+                    n_past_guidance -= n_discard;
+                }
+
+                LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+
+                LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+
+                LOG("clear session path\n");
+                path_session.clear();
             }
+        } else {
+            // context extension via Self-Extend
+            while (n_past >= ga_i + ga_w) {
+                const int ib = (ga_n*ga_i)/ga_w;
+                const int bd = (ga_w/ga_n)*(ga_n - 1);
+                const int dd = (ga_w/ga_n) - ib*bd - ga_w;
 
-            LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+                LOG("\n");
+                LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd);
+                LOG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
+                LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
 
-            LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+                llama_kv_cache_seq_shift(ctx, 0, ga_i,                n_past,              ib*bd);
+                llama_kv_cache_seq_div  (ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
+                llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
 
-            LOG("clear session path\n");
-            path_session.clear();
+                n_past -= bd;
+
+                ga_i += ga_w/ga_n;
+
+                LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i);
+            }
         }
 
         // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
diff --git a/llama.h b/llama.h
index 5305de90be5c1..869ff0acf525a 100644
--- a/llama.h
+++ b/llama.h
@@ -484,6 +484,10 @@ extern "C" {
                        llama_pos   p1,
                        llama_pos   delta);
 
+    // Integer division of the positions by factor of `d > 1`
+    // If the KV cache is RoPEd, the KV data is updated accordingly
+    // p0 < 0 : [0, p1]
+    // p1 < 0 : [p0, inf)
    LLAMA_API void llama_kv_cache_seq_div(
            struct llama_context * ctx,
                    llama_seq_id   seq_id,