From 6c55fe1b960e1fc8e452b231dfe07e8049234314 Mon Sep 17 00:00:00 2001
From: Igor Pissolati
Date: Mon, 19 Jun 2023 14:52:57 -0300
Subject: [PATCH] Code cleanup

---
 llama.cpp | 38 +++++++++++++++-----------------------
 llama.h   |  2 --
 2 files changed, 15 insertions(+), 25 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 4a6329f3103237..e41ef098499a60 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -248,7 +248,6 @@ struct llama_vocab {
 
     llama_trie special_token_trie;
     std::unordered_map<token, id> special_token_to_id;
-    std::vector<id> special_tokens;
     size_t max_special_token_length;
 };
 
@@ -539,14 +538,13 @@ struct llama_file_loader {
 
        for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
            uint32_t token_id = file.read_u32();
-           const auto & token = vocab.id_to_token[token_id].tok;
+           const auto & word = vocab.id_to_token[token_id].tok;
 
-           vocab.special_token_trie.add(token);
-           vocab.special_tokens.push_back(token_id);
-           vocab.special_token_to_id[token] = token_id;
+           vocab.special_token_trie.add(word);
+           vocab.special_token_to_id[word] = token_id;
 
-           if (vocab.max_special_token_length < token.size()) {
-               vocab.max_special_token_length = token.size();
+           if (vocab.max_special_token_length < word.size()) {
+               vocab.max_special_token_length = word.size();
            }
        }
    }
@@ -641,9 +639,8 @@ struct llama_file_saver {
            file.write_raw(token_score.tok.data(), token_score.tok.size());
            file.write_raw(&token_score.score, sizeof(token_score.score));
        }
-       uint32_t n_vocab_sp = any_file_loader->hparams.n_vocab_sp;
-       for (uint32_t i = 0; i < n_vocab; i++) {
-           file.write_u32(any_file_loader->vocab.special_tokens[i]);
+       for (const auto & pair : any_file_loader->vocab.special_token_to_id) {
+           file.write_u32(pair.second);
        }
    }
    void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
@@ -1964,24 +1961,23 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
        return output;
    }
 
-   auto offsets = vocab.special_token_trie.split(text);
+   std::vector<int> offsets = vocab.special_token_trie.split(text);
    int start = 0;
    for (int end : offsets) {
        if (start >= end) {
            continue;
        }
-       size_t part_length = end - start;
-       //printf("\"%.*s\"\n", (int) part_length, text.c_str() + start);
-
-       if (vocab.max_special_token_length < part_length) {
-           tokenizer.tokenize(text.c_str() + start, part_length, output);
-       } else {
-           auto token_it = vocab.special_token_to_id.find(std::string(text.c_str() + start, part_length));
+       const char *part = text.c_str() + start;
+       size_t part_len = end - start;
+       if (vocab.max_special_token_length < part_len) {
+           tokenizer.tokenize(part, part_len, output);
+       } else {
+           auto token_it = vocab.special_token_to_id.find(std::string(part, part_len));
            if (token_it != vocab.special_token_to_id.end()) {
                output.push_back(token_it->second);
            } else {
-               tokenizer.tokenize(text.c_str() + start, part_length, output);
+               tokenizer.tokenize(part, part_len, output);
            }
        }
        start = end;
    }
@@ -3515,10 +3511,6 @@ llama_token llama_token_nl() {
    return 13;
 }
 
-bool llama_is_special_token(const struct llama_context *ctx, llama_token token) {
-   return std::find(ctx->vocab.special_tokens.begin(), ctx->vocab.special_tokens.end(), token) != ctx->vocab.special_tokens.end();
-}
-
 void llama_print_timings(struct llama_context * ctx) {
    const int64_t t_end_us = ggml_time_us();
 
diff --git a/llama.h b/llama.h
index 9d09d8133baef9..26121536cff601 100644
--- a/llama.h
+++ b/llama.h
@@ -248,8 +248,6 @@ extern "C" {
    LLAMA_API llama_token llama_token_eos(); // end-of-sentence
    LLAMA_API llama_token llama_token_nl();  // next-line
 
-   LLAMA_API bool llama_is_special_token(const struct llama_context * ctx, llama_token token);
-
    // Sampling functions
 
    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
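
Note for reviewers: below is a minimal self-contained sketch of the split-then-lookup flow that the llama_tokenize() hunk above preserves. mock_vocab, fallback_tokenize() and tokenize_with_specials() are hypothetical stand-ins for the real trie/tokenizer types, not code from this patch; the offsets are assumed to come from a trie split as in the patch.

    // Reviewer's sketch, not part of the patch: illustrates the
    // lookup-with-fallback loop that llama_tokenize() uses after this change.
    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    using llama_token = int;

    struct mock_vocab {
        std::unordered_map<std::string, llama_token> special_token_to_id;
        size_t max_special_token_length = 0;
    };

    // Stand-in for tokenizer.tokenize(): emits one token per byte.
    static void fallback_tokenize(const char * part, size_t len, std::vector<llama_token> & out) {
        for (size_t i = 0; i < len; i++) {
            out.push_back((unsigned char) part[i]);
        }
    }

    // Mirrors the patched loop: each offset marks the end of a part produced
    // by the trie split; a part short enough to be a special token is looked
    // up in the map first, everything else falls back to the normal tokenizer.
    static std::vector<llama_token> tokenize_with_specials(const mock_vocab & vocab,
                                                           const std::string & text,
                                                           const std::vector<int> & offsets) {
        std::vector<llama_token> output;
        int start = 0;
        for (int end : offsets) {
            if (start >= end) {
                continue;
            }
            const char *part = text.c_str() + start;
            size_t part_len = end - start;
            if (vocab.max_special_token_length < part_len) {
                fallback_tokenize(part, part_len, output);       // too long to be special
            } else {
                auto it = vocab.special_token_to_id.find(std::string(part, part_len));
                if (it != vocab.special_token_to_id.end()) {
                    output.push_back(it->second);                // exact special-token match
                } else {
                    fallback_tokenize(part, part_len, output);   // short, but not special
                }
            }
            start = end;
        }
        return output;
    }

    int main() {
        mock_vocab vocab;
        vocab.special_token_to_id["<s>"] = 1;
        vocab.max_special_token_length = 3;

        // Offsets a trie split of "<s>hi" would produce: "<s>" ends at 3, "hi" at 5.
        for (llama_token t : tokenize_with_specials(vocab, "<s>hi", {3, 5})) {
            printf("%d ", t);  // prints: 1 104 105
        }
        printf("\n");
        return 0;
    }

Compiled with any C++11 compiler this prints "1 104 105": the special token "<s>" maps straight to its id via the map, while "hi" takes the fallback path.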