Skip to content

Commit

Permalink
grammar: trigger words + refactor of antiprompts
Browse files Browse the repository at this point in the history
  • Loading branch information
ochafik committed Sep 25, 2024
1 parent 70392f1 commit 5b6d504
Show file tree
Hide file tree
Showing 12 changed files with 436 additions and 108 deletions.
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ BUILD_TARGETS = \

# Binaries only useful for tests
TEST_TARGETS = \
tests/test-antiprompts \
tests/test-arg-parser \
tests/test-autorelease \
tests/test-backend-ops \
Expand Down Expand Up @@ -1567,6 +1568,11 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

# Unit test for the antiprompt / grammar-trigger matcher (llama_antiprompts).
# Links the full object set and adds the server example dir to the include
# path, mirroring the test-json-schema-to-grammar rule above.
tests/test-antiprompts: tests/test-antiprompts.cpp \
	$(OBJ_ALL)
	$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tests/test-grad0: tests/test-grad0.cpp \
$(OBJ_GGML)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
Expand Down
198 changes: 198 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

#include "llama.h"

#include <functional>
#include <queue>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
Expand Down Expand Up @@ -134,6 +136,7 @@ struct gpt_sampler_params {
};

std::string grammar; // optional BNF-like grammar to constrain sampling
std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar

std::vector<llama_logit_bias> logit_bias; // logit biases to apply

Expand Down Expand Up @@ -533,6 +536,201 @@ struct llama_control_vector_load_info {
// On error, returns {-1, empty}
llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);

//
// Antiprompt utils
//

class llama_antiprompts {
public:

struct llama_antiprompt {
std::string value;
bool is_grammar_trigger;
};

std::vector<std::string> stop_words;
std::vector<std::string> grammar_trigger_words;

private:
// The Aho–Corasick algorithm allows efficient string matching with multiple patterns.
// See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
struct TrieNode {
std::unordered_map<char, TrieNode> children;
TrieNode* fail = nullptr;
int output = -1;
size_t depth = 0;

void clear() {
children.clear();
fail = nullptr;
output = -1;
depth = 0;
}
};

TrieNode root;
std::vector<llama_antiprompt> antiprompts;
std::unordered_map<llama_token, size_t> stop_tokens; // Single token antiprompts (and their index in antiprompts), if any.

void build_trie() {
// root = std::unique_ptr<TrieNode>(new TrieNode());
for (size_t i = 0; i < antiprompts.size(); ++i) {
TrieNode* node = &root;
const auto & pattern = antiprompts[i].value;
for (size_t j = 0; j < pattern.length(); ++j) {
char c = pattern[j];
auto & child = node->children[c];
if (child.depth == 0) {
child.depth = j + 1;
}
node = &child;
}
node->output = i;
}
}

void build_failure_and_dict_links() {
std::queue<TrieNode*> q;
for (auto& child : root.children) {
child.second.fail = &root;
q.push(&child.second);
}

while (!q.empty()) {
auto node = q.front();
q.pop();

for (auto & pair : node->children) {
auto & c = pair.first;
auto & child = pair.second;
auto f = node->fail;

while (f != &root && f->children.find(c) == f->children.end()) {
f = f->fail;
}

child.fail = (f == &root && f->children.find(c) == f->children.end())
? &root : &f->children[c];

if (child.fail->output != -1) {
child.output = child.fail->output;
}

q.push(&child);
}
}
}

public:

bool empty() const {
return antiprompts.empty() && stop_tokens.empty();
}
void clear() {
root.clear();
antiprompts.clear();
stop_tokens.clear();
}

void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
build(
[&](const std::string & text) {
return llama_tokenize(ctx, text, /* special= */ true);
},
stop_words,
grammar_trigger_words
);
}

void build(const std::function<std::vector<llama_token>(const std::string)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
clear();
this->stop_words = stop_words;
this->grammar_trigger_words = grammar_trigger_words;

for (const std::string & stop_word : stop_words) {
antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
}
for (const std::string & trigger : grammar_trigger_words) {
antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
}

for (size_t i = 0, n = antiprompts.size(); i < n; i++) {
const auto & antiprompt = antiprompts[i];
std::vector<llama_token> tokens = tokenizer(antiprompt.value);
if (tokens.size() == 1) {
stop_tokens[tokens[0]] = i;
}
}

build_trie();
build_failure_and_dict_links();
}

struct MatchResult {
size_t pos;
std::string pattern;
bool is_partial;
size_t matchLength;
bool is_grammar_trigger;

bool operator==(const MatchResult & other) const {
return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger;
}
operator std::string() const {
return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}";
}
};

MatchResult findSingleTokenMatch(llama_token token) const {
auto it = stop_tokens.find(token);
if (it != stop_tokens.end()) {
const auto & antiprompt = antiprompts[it->second];
return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger};
}
return {std::string::npos, "", false, 0, false};
}

MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
TrieNode* current = &root;
MatchResult partialMatch{std::string::npos, "", true, 0, false};

for (size_t i = offset; i < text.length(); ++i) {
char c = text[i];
while (current != &root && current->children.find(c) == current->children.end()) {
current = current->fail;
}
auto it = current->children.find(c);
if (it != current->children.end()) {
current = &it->second;
}
if (current->output != -1) {
const auto & antiprompt = antiprompts[current->output];
return {
i - antiprompt.value.length() + 1,
antiprompt.value,
false,
antiprompt.value.length(),
antiprompt.is_grammar_trigger,
};
}
// Update partial match if we're at a deeper node
if (current->depth > partialMatch.matchLength) {
partialMatch.pos = i - current->depth + 1;
partialMatch.pattern = ""; // We don't know which pattern it partially matches
partialMatch.matchLength = current->depth;
partialMatch.is_grammar_trigger = false;
}
}

// If we've found a partial match and haven't returned a full match, return the partial match
if (partialMatch.pos != std::string::npos) {
return partialMatch;
}

return {std::string::npos, "", false, 0, false};
}
};

//
// Split utils
//
Expand Down
15 changes: 13 additions & 2 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,23 @@ std::string gpt_sampler_params::print() const {
return std::string(result);
}

// Lazily activates the grammar sampler when a trigger word has been matched.
// Returns false if the grammar sampler already exists (gsmpl->grmr non-null),
// true if this call created it.
// NOTE(review): the trigger string itself is fed into the freshly created
// grammar via llama_sampler_accept_str — this presumably means the grammar is
// expected to match starting at (and including) the trigger word; confirm
// against the grammar sampler's contract.
bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger) {
    if (gsmpl->grmr) {
        // Grammar already active — nothing to do.
        return false;
    }
    gsmpl->grmr = llama_sampler_init_grammar(model, gsmpl->params.grammar.c_str(), "root");
    llama_sampler_accept_str(gsmpl->grmr, trigger.c_str());
    return true;
}

struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

lparams.no_perf = params.no_perf;

auto * result = new gpt_sampler {
/* .params = */ params,
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
/* .grmr = */ params.grammar_trigger_words.empty() ? llama_sampler_init_grammar(model, params.grammar.c_str(), "root") : nullptr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
Expand Down Expand Up @@ -226,7 +235,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st

void gpt_sampler_free(struct gpt_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
if (gsmpl->grmr) {
llama_sampler_free(gsmpl->grmr);
}

llama_sampler_free(gsmpl->chain);

Expand Down
2 changes: 2 additions & 0 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,7 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n
char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);

bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger);

std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);
Loading

0 comments on commit 5b6d504

Please sign in to comment.