diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 5a3d30b888d03..5687476cdcf92 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -248,7 +248,7 @@ static void process_logits( } } -static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { +static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) { const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); const int n_ctx = llama_n_ctx(ctx); @@ -269,10 +269,12 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { } std::vector logit_history; - logit_history.resize(tokens.size()); - std::vector prob_history; - prob_history.resize(tokens.size()); + + if (compute_ppl) { + logit_history.resize(tokens.size()); + prob_history.resize(tokens.size()); + } const int n_chunk_max = tokens.size() / n_ctx; @@ -288,12 +290,17 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { std::vector workers(std::thread::hardware_concurrency() - 1); + const int num_batches = (n_ctx + n_batch - 1) / n_batch; + + std::vector logits; + if (compute_ppl && num_batches > 1) { + logits.reserve((size_t)n_ctx * n_vocab); + } + for (int i = 0; i < n_chunk; ++i) { const int start = i * n_ctx; const int end = start + n_ctx; - const int num_batches = (n_ctx + n_batch - 1) / n_batch; - std::vector logits; const auto t_start = std::chrono::high_resolution_clock::now(); @@ -321,8 +328,10 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { // restore the original token in case it was set to BOS tokens[batch_start] = token_org; - const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + if (compute_ppl && num_batches > 1) { + const auto * batch_logits = llama_get_logits(ctx); + logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + } } const auto t_end = std::chrono::high_resolution_clock::now(); @@ -338,25 +347,32 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); } - const int first = n_ctx/2; - process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, - workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); - count += n_ctx - first - 1; + if (compute_ppl) { + const int first = n_ctx/2; + const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); + process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, + workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); + count += n_ctx - first - 1; + + printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + fflush(stdout); - printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); - fflush(stdout); + logits.clear(); + } } printf("\n"); - nll2 /= count; - nll /= count; - const double ppl = exp(nll); - nll2 -= nll * nll; - if (nll2 > 0) { - nll2 = sqrt(nll2/(count-1)); - printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl); - } else { - printf("Unexpected negative standard deviation of log(prob)\n"); + if (compute_ppl) { + nll2 /= count; + nll /= count; + const double ppl = exp(nll); + nll2 -= nll * nll; + if (nll2 > 0) { + nll2 = sqrt(nll2/(count-1)); + printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl); + } else { + printf("Unexpected negative standard deviation of log(prob)\n"); + } } return true; @@ -365,6 +381,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { int main(int argc, char ** argv) { StatParams sparams; + bool compute_ppl = true; std::vector args; args.push_back(argv[0]); int iarg = 1; @@ -381,12 +398,19 @@ int main(int argc, char ** argv) { } else if (arg == "--verbosity") { sparams.verbosity = std::stoi(argv[++iarg]); + } else if (arg == "--no-ppl") { + compute_ppl = false; } else { args.push_back(argv[iarg]); } } if (iarg < argc) { - args.push_back(argv[iarg]); + std::string arg{argv[iarg]}; + if (arg == "--no-ppl") { + compute_ppl = false; + } else { + args.push_back(argv[iarg]); + } } gpt_params params; @@ -448,7 +472,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s\n", get_system_info(params).c_str()); } - bool OK = compute_imatrix(ctx, params); + bool OK = compute_imatrix(ctx, params, compute_ppl); if (!OK) { return 1; }