Skip to content

Commit

Permalink
Update tests/test-sampling.cpp
Browse files Browse the repository at this point in the history
Co-authored-by: slaren <[email protected]>
  • Loading branch information
ggerganov and slaren authored Sep 23, 2024
1 parent 3cb33a8 commit a5a11bf
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions tests/test-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,25 +248,25 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
}

// BENCH(__cnstr, __data, __n_iter): statement-like benchmark macro (wrapped in do { ... } while(0)).
//
// Expansion, in order:
//   1. evaluates __cnstr once and stores the result in the local `cnstr`;
//   2. copies __data into a local vector `cur` and runs one llama_sampler_apply +
//      llama_sampler_reset pass before the first ggml_time_us() reading (not timed);
//   3. repeats copy / apply / reset __n_iter times between two ggml_time_us() readings;
//   4. releases the sampler with llama_sampler_free;
//   5. prints (t_end - t_start) / n_iter with the "us/iter" label, prefixed by the
//      stringified text of the __cnstr argument.
//
// __data is expanded several times. The locals declared here (cnstr, cur, cur_p, t_start,
// n_iter, t_end, i) live in the expansion's block, so an argument expression that uses one
// of those names refers to the macro's local rather than to the caller's variable.
#define BENCH(__cnstr, __data, __n_iter) do { \
auto * cnstr = (__cnstr); \
std::vector<llama_token_data> cur((__data).size()); \
std::copy((__data).begin(), (__data).end(), cur.begin()); \
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \
llama_sampler_apply(cnstr, &cur_p); \
llama_sampler_reset(cnstr); \
const int64_t t_start = ggml_time_us(); \
const int n_iter = (__n_iter); \
for (int i = 0; i < n_iter; i++) { \
std::copy((__data).begin(), (__data).end(), cur.begin()); \
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \
llama_sampler_apply(cnstr, &cur_p); \
llama_sampler_reset(cnstr); \
} \
const int64_t t_end = ggml_time_us(); \
llama_sampler_free(cnstr); \
printf("%-42s: %8.3f us/iter\n", #__cnstr, (t_end - t_start) / (float)n_iter); \
} while(0)
// Benchmarks one sampler.
//
// cnstr      - sampler to measure; it is freed with llama_sampler_free before returning
// cnstr_name - label printed in front of the result
// data       - token data restored into a scratch buffer before every pass
// n_iter     - number of timed passes; also the divisor of the reported average
//
// One untimed pass runs first, then n_iter timed passes. Each pass copies `data` into the
// scratch buffer, applies the sampler to it and resets the sampler.
static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
    std::vector<llama_token_data> scratch(data.size());

    // a single pass: fresh copy of the input -> apply -> reset
    const auto run_pass = [&]() {
        std::copy(data.begin(), data.end(), scratch.begin());
        llama_token_data_array candidates = { scratch.data(), scratch.size(), -1, false };
        llama_sampler_apply(cnstr, &candidates);
        llama_sampler_reset(cnstr);
    };

    // untimed pass before the clock is read
    run_pass();

    const int64_t t_begin = ggml_time_us();
    for (int pass = 0; pass < n_iter; ++pass) {
        run_pass();
    }
    const int64_t t_finish = ggml_time_us();

    llama_sampler_free(cnstr);

    printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_finish - t_begin) / static_cast<float>(n_iter));
}

// Convenience wrapper around bench(): forwards the sampler expression, the source text of
// that expression (stringified with #) as the printed label, the data and the iteration count.
// Parameter names deliberately avoid double underscores, which C++ reserves for the
// implementation; the expansion is unchanged because # stringifies the argument text.
#define BENCH(cnstr_expr, data_expr, n_iter_expr) bench((cnstr_expr), #cnstr_expr, (data_expr), (n_iter_expr))

static void test_perf() {
const int n_vocab = 1 << 17;
Expand Down

0 comments on commit a5a11bf

Please sign in to comment.