Skip to content

Commit

Permalink
Add C API for adding special tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
Igoorx committed Aug 7, 2023
1 parent 099119f commit d9791bb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
24 changes: 15 additions & 9 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,15 @@ struct llama_vocab {
llama_trie special_token_trie;
std::unordered_map<token, id> special_token_to_id;
size_t max_special_token_length = 0;

void add_special_token(const token & word, id token_id) {
special_token_trie.add(word);
special_token_to_id[word] = token_id;

if (max_special_token_length < word.size()) {
max_special_token_length = word.size();
}
}
};

struct llama_model {
Expand Down Expand Up @@ -624,15 +633,8 @@ struct llama_file_loader {
for (uint32_t i = 0; i < vocab_sp; i++) {
llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
const auto & word = vocab.id_to_token[token_id].tok;
if (word.empty()) {
continue;
}

vocab.special_token_trie.add(word);
vocab.special_token_to_id[word] = token_id;

if (vocab.max_special_token_length < word.size()) {
vocab.max_special_token_length = word.size();
if (!word.empty()) {
vocab.add_special_token(word, token_id);
}
}
}
Expand Down Expand Up @@ -4263,6 +4265,10 @@ llama_token llama_token_nl() {
return 13;
}

void llama_add_special_token(struct llama_model * model, const char * token, llama_token token_id) {
model->vocab.add_special_token(token, token_id);
}

struct llama_timings llama_get_timings(struct llama_context * ctx) {
struct llama_timings result = {
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
Expand Down
5 changes: 5 additions & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,11 @@ extern "C" {
LLAMA_API llama_token llama_token_eos(); // end-of-sentence
LLAMA_API llama_token llama_token_nl(); // next-line

LLAMA_API void llama_add_special_token(
struct llama_model * model,
const char * token,
llama_token token_id);

// Grammar
//
LLAMA_API struct llama_grammar * llama_grammar_init(
Expand Down

0 comments on commit d9791bb

Please sign in to comment.