From a5a11bfbc387030d7cc8d80049061245e1e6534b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 23 Sep 2024 17:18:12 +0300
Subject: [PATCH] Update tests/test-sampling.cpp

Co-authored-by: slaren
---
 tests/test-sampling.cpp | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 2c79ec4720d60..6e021c4c70357 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -248,25 +248,25 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
            samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
 }
 
-#define BENCH(__cnstr, __data, __n_iter) do { \
-    auto * cnstr = (__cnstr); \
-    std::vector<llama_token_data> cur((__data).size()); \
-    std::copy((__data).begin(), (__data).end(), cur.begin()); \
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \
-    llama_sampler_apply(cnstr, &cur_p); \
-    llama_sampler_reset(cnstr); \
-    const int64_t t_start = ggml_time_us(); \
-    const int n_iter = (__n_iter); \
-    for (int i = 0; i < n_iter; i++) { \
-        std::copy((__data).begin(), (__data).end(), cur.begin()); \
-        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \
-        llama_sampler_apply(cnstr, &cur_p); \
-        llama_sampler_reset(cnstr); \
-    } \
-    const int64_t t_end = ggml_time_us(); \
-    llama_sampler_free(cnstr); \
-    printf("%-42s: %8.3f us/iter\n", #__cnstr, (t_end - t_start) / (float)n_iter); \
-} while(0)
+static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
+    std::vector<llama_token_data> cur(data.size());
+    std::copy(data.begin(), data.end(), cur.begin());
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    llama_sampler_apply(cnstr, &cur_p);
+    llama_sampler_reset(cnstr);
+    const int64_t t_start = ggml_time_us();
+    for (int i = 0; i < n_iter; i++) {
+        std::copy(data.begin(), data.end(), cur.begin());
+        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+        llama_sampler_apply(cnstr, &cur_p);
+        llama_sampler_reset(cnstr);
+    }
+    const int64_t t_end = ggml_time_us();
+    llama_sampler_free(cnstr);
+    printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
+}
+
+#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
 
 static void test_perf() {
     const int n_vocab = 1 << 17;