Commit 6c55fe1: Code cleanup

Igoorx committed Jun 19, 2023
Parent: a91c122

Showing 2 changed files with 15 additions and 25 deletions.

llama.cpp (38 changes: 15 additions & 23 deletions)
@@ -248,7 +248,6 @@ struct llama_vocab {

     llama_trie special_token_trie;
     std::unordered_map<token, id> special_token_to_id;
-    std::vector<id> special_tokens;
     size_t max_special_token_length;
 };

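Side note on the dropped member: `special_tokens` duplicated information already held by `special_token_to_id` (its values are exactly the special ids), so removing it leaves a single source of truth. Lookups by token text stay O(1) on average; an id-based "is this special?" test now scans the map's values, which is no worse asymptotically than the `std::find` over the vector it replaces (see the removal of `llama_is_special_token` below). A self-contained sketch of both checks, with `std::unordered_map<std::string, int32_t>` standing in for the `token`/`id` typedefs:

    #include <cstdint>
    #include <string>
    #include <unordered_map>

    using special_map = std::unordered_map<std::string, int32_t>;

    // O(1) average: is this exact text a registered special token?
    static bool is_special_text(const special_map & m, const std::string & text) {
        return m.find(text) != m.end();
    }

    // O(n): is this id one of the special tokens? The removed vector made
    // this a linear scan as well, so the cleanup costs nothing asymptotically.
    static bool is_special_id(const special_map & m, int32_t id) {
        for (const auto & kv : m) {
            if (kv.second == id) {
                return true;
            }
        }
        return false;
    }
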
@@ -539,14 +538,13 @@ struct llama_file_loader

         for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
             uint32_t token_id = file.read_u32();
-            const auto & token = vocab.id_to_token[token_id].tok;
+            const auto & word = vocab.id_to_token[token_id].tok;

-            vocab.special_token_trie.add(token);
-            vocab.special_tokens.push_back(token_id);
-            vocab.special_token_to_id[token] = token_id;
+            vocab.special_token_trie.add(word);
+            vocab.special_token_to_id[word] = token_id;

-            if (vocab.max_special_token_length < token.size()) {
-                vocab.max_special_token_length = token.size();
+            if (vocab.max_special_token_length < word.size()) {
+                vocab.max_special_token_length = word.size();
             }
         }
     }
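
For orientation: the running max kept here feeds the tokenizer change further down, where any text segment longer than the longest special token skips the map lookup entirely. A self-contained sketch of the same pattern; the token spellings and ids are made up for illustration, not taken from any model file:

    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <utility>

    int main() {
        // Hypothetical special tokens standing in for entries read from disk.
        const std::pair<std::string, uint32_t> entries[] = {
            { "<s>",           1     },
            { "<|endoftext|>", 50256 },
        };

        std::unordered_map<std::string, uint32_t> special_token_to_id;
        size_t max_special_token_length = 0;

        // Same shape as the loader loop above: register, then track the max.
        for (const auto & e : entries) {
            special_token_to_id[e.first] = e.second;
            if (max_special_token_length < e.first.size()) {
                max_special_token_length = e.first.size();
            }
        }

        std::printf("max_special_token_length = %zu\n", max_special_token_length); // 13
        return 0;
    }
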
@@ -641,9 +639,8 @@ struct llama_file_saver
             file.write_raw(token_score.tok.data(), token_score.tok.size());
             file.write_raw(&token_score.score, sizeof(token_score.score));
         }
-        uint32_t n_vocab_sp = any_file_loader->hparams.n_vocab_sp;
-        for (uint32_t i = 0; i < n_vocab; i++) {
-            file.write_u32(any_file_loader->vocab.special_tokens[i]);
+        for (const auto & pair : any_file_loader->vocab.special_token_to_id) {
+            file.write_u32(pair.second);
         }
     }
     void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
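
One nuance of the rewritten save loop: `std::unordered_map` iteration order is unspecified, so the ids are no longer written in their original load order. That is harmless here, since the loader above only rebuilds a trie and a map from whatever order it reads. If a byte-stable output file were ever needed, sorting first would restore determinism; a hedged sketch of that alternative, not something this commit does:

    #include <algorithm>
    #include <cstdint>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Collect and sort the ids so the on-disk order stops depending on
    // unordered_map iteration order.
    static std::vector<uint32_t> sorted_special_ids(
            const std::unordered_map<std::string, uint32_t> & special_token_to_id) {
        std::vector<uint32_t> ids;
        ids.reserve(special_token_to_id.size());
        for (const auto & kv : special_token_to_id) {
            ids.push_back(kv.second);
        }
        std::sort(ids.begin(), ids.end());
        return ids;
    }
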
@@ -1964,24 +1961,23 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         return output;
     }

-    auto offsets = vocab.special_token_trie.split(text);
+    std::vector<int> offsets = vocab.special_token_trie.split(text);
     int start = 0;
     for (int end : offsets) {
         if (start >= end) {
             continue;
         }

-        size_t part_length = end - start;
-        //printf("\"%.*s\"\n", (int) part_length, text.c_str() + start);
-
-        if (vocab.max_special_token_length < part_length) {
-            tokenizer.tokenize(text.c_str() + start, part_length, output);
-        } else {
-            auto token_it = vocab.special_token_to_id.find(std::string(text.c_str() + start, part_length));
+        const char *part = text.c_str() + start;
+        size_t part_len = end - start;
+        if (vocab.max_special_token_length < part_len) {
+            tokenizer.tokenize(part, part_len, output);
+        } else {
+            auto token_it = vocab.special_token_to_id.find(std::string(part, part_len));
             if (token_it != vocab.special_token_to_id.end()) {
                 output.push_back(token_it->second);
             } else {
-                tokenizer.tokenize(text.c_str() + start, part_length, output);
+                tokenizer.tokenize(part, part_len, output);
             }
         }
         start = end;
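
Reading the loop: it assumes `split` returns end offsets that partition `text` at special-token boundaries, and the `start >= end` guard tolerates degenerate entries. A self-contained illustration of that consumption pattern; the offsets below are hand-picked to mimic what a split around a hypothetical `"<s>"` token would produce, since the `llama_trie` implementation itself is not shown in this diff:

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
        const std::string text = "Hi<s>there";
        // Stand-in for special_token_trie.split(text): each value is the end
        // offset of one segment, cutting around the special token "<s>".
        const std::vector<int> offsets = { 2, 5, 10 };

        int start = 0;
        for (int end : offsets) {
            if (start >= end) {
                continue;
            }
            // [start, end) is either a special token verbatim or plain text;
            // the real loop looks the segment up in special_token_to_id.
            std::printf("segment: \"%.*s\"\n", end - start, text.c_str() + start);
            start = end;
        }
        // prints: "Hi", "<s>", "there"
        return 0;
    }
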
@@ -3515,10 +3511,6 @@ llama_token llama_token_nl() {
     return 13;
 }

-bool llama_is_special_token(const struct llama_context *ctx, llama_token token) {
-    return std::find(ctx->vocab.special_tokens.begin(), ctx->vocab.special_tokens.end(), token) != ctx->vocab.special_tokens.end();
-}
-

 void llama_print_timings(struct llama_context * ctx) {
     const int64_t t_end_us = ggml_time_us();

llama.h (2 changes: 0 additions & 2 deletions)

@@ -248,8 +248,6 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos(); // end-of-sentence
     LLAMA_API llama_token llama_token_nl();  // next-line

-    LLAMA_API bool llama_is_special_token(const struct llama_context * ctx, llama_token token);
-
     // Sampling functions

     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
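
With `llama_is_special_token` gone from the public C API, callers that still need the test have to track special ids on their own, since `llama_context` internals are not reachable through the header. A hedged sketch of one caller-side replacement; the type and the way the ids get populated are hypothetical:

    #include "llama.h"        // for llama_token
    #include <unordered_set>

    // Hypothetical application-side registry: the caller records whichever
    // special ids it knows about (e.g. from its prompt-format config) and
    // queries its own set instead of the removed API.
    struct special_token_set {
        std::unordered_set<llama_token> ids;

        bool is_special(llama_token token) const {
            return ids.count(token) > 0;
        }
    };
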
