diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 60a7aef5bd355..d693bfb6cdeb4 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -929,6 +929,13 @@ struct test { return ts; } + std::vector get_ttfb() const { + int n_tokens = n_prompt + n_gen; + std::vector ts; + std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts), [n_tokens](uint64_t t) { return t/1e6; }); + return ts; + } + double avg_ts() const { return ::avg(get_ts()); } @@ -937,6 +944,14 @@ struct test { return ::stdev(get_ts()); } + double avg_ttfb() const { + return ::avg(get_ttfb()); + } + + double stdev_ttfb() const { + return ::stdev(get_ttfb()); + } + static std::string get_backend() { if (cuda) { return GGML_CUDA_NAME; @@ -1187,6 +1202,9 @@ struct markdown_printer : public printer { if (field == "model") { return -30; } + if (field == "ttfb") { + return 30; + } if (field == "t/s") { return 20; } @@ -1314,6 +1332,7 @@ struct markdown_printer : public printer { } fields.emplace_back("test"); fields.emplace_back("t/s"); + fields.emplace_back("ttfb"); fprintf(fout, "|"); for (const auto & field : fields) { @@ -1368,6 +1387,9 @@ struct markdown_printer : public printer { } else if (field == "t/s") { snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts()); value = buf; + } else if (field == "ttfb") { + snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ttfb(), t.stdev_ttfb()); + value = buf; } else if (vmap.find(field) != vmap.end()) { value = vmap.at(field); } else { @@ -1376,7 +1398,7 @@ struct markdown_printer : public printer { } int width = get_field_width(field); - if (field == "t/s") { + if (field == "t/s" || field == "ttfb") { // HACK: the utf-8 character is 2 bytes width += 1; }