diff --git a/Makefile b/Makefile index 8a903d7ed5914..88234972f81f2 100644 --- a/Makefile +++ b/Makefile @@ -44,6 +44,7 @@ BUILD_TARGETS = \ # Binaries only useful for tests TEST_TARGETS = \ + tests/test-antiprompts \ tests/test-arg-parser \ tests/test-autorelease \ tests/test-backend-ops \ @@ -1567,6 +1568,11 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \ $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +tests/test-antiprompts: tests/test-antiprompts.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + tests/test-grad0: tests/test-grad0.cpp \ $(OBJ_GGML) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) diff --git a/common/common.h b/common/common.h index cb87c4479ed0a..1a5cfe7b1173b 100644 --- a/common/common.h +++ b/common/common.h @@ -4,9 +4,11 @@ #include "llama.h" +#include #include #include #include +#include #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -134,6 +136,7 @@ struct gpt_sampler_params { }; std::string grammar; // optional BNF-like grammar to constrain sampling + std::vector grammar_trigger_words; // optional trigger words to enable grammar std::vector logit_bias; // logit biases to apply @@ -533,6 +536,201 @@ struct llama_control_vector_load_info { // On error, returns {-1, empty} llama_control_vector_data llama_control_vector_load(const std::vector & load_infos); +// +// Antiprompt utils +// + +class llama_antiprompts { + public: + + struct llama_antiprompt { + std::string value; + bool is_grammar_trigger; + }; + + std::vector stop_words; + std::vector grammar_trigger_words; + +private: + // The Aho–Corasick algorithm allows efficient string matching with multiple patterns. + // See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm + struct TrieNode { + std::unordered_map children; + TrieNode* fail = nullptr; + int output = -1; + size_t depth = 0; + + void clear() { + children.clear(); + fail = nullptr; + output = -1; + depth = 0; + } + }; + + TrieNode root; + std::vector antiprompts; + std::unordered_map stop_tokens; // Single token antiprompts (and their index in antiprompts), if any. + + void build_trie() { + // root = std::unique_ptr(new TrieNode()); + for (size_t i = 0; i < antiprompts.size(); ++i) { + TrieNode* node = &root; + const auto & pattern = antiprompts[i].value; + for (size_t j = 0; j < pattern.length(); ++j) { + char c = pattern[j]; + auto & child = node->children[c]; + if (child.depth == 0) { + child.depth = j + 1; + } + node = &child; + } + node->output = i; + } + } + + void build_failure_and_dict_links() { + std::queue q; + for (auto& child : root.children) { + child.second.fail = &root; + q.push(&child.second); + } + + while (!q.empty()) { + auto node = q.front(); + q.pop(); + + for (auto & pair : node->children) { + auto & c = pair.first; + auto & child = pair.second; + auto f = node->fail; + + while (f != &root && f->children.find(c) == f->children.end()) { + f = f->fail; + } + + child.fail = (f == &root && f->children.find(c) == f->children.end()) + ? 
&root : &f->children[c]; + + if (child.fail->output != -1) { + child.output = child.fail->output; + } + + q.push(&child); + } + } + } + + public: + + bool empty() const { + return antiprompts.empty() && stop_tokens.empty(); + } + void clear() { + root.clear(); + antiprompts.clear(); + stop_tokens.clear(); + } + + void build(const llama_context * ctx, const std::vector & stop_words, const std::vector & grammar_trigger_words) { + build( + [&](const std::string & text) { + return llama_tokenize(ctx, text, /* special= */ true); + }, + stop_words, + grammar_trigger_words + ); + } + + void build(const std::function(const std::string)> & tokenizer, const std::vector & stop_words, const std::vector & grammar_trigger_words) { + clear(); + this->stop_words = stop_words; + this->grammar_trigger_words = grammar_trigger_words; + + for (const std::string & stop_word : stop_words) { + antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false}); + } + for (const std::string & trigger : grammar_trigger_words) { + antiprompts.push_back({trigger, /* is_grammar_trigger= */ true}); + } + + for (size_t i = 0, n = antiprompts.size(); i < n; i++) { + const auto & antiprompt = antiprompts[i]; + std::vector tokens = tokenizer(antiprompt.value); + if (tokens.size() == 1) { + stop_tokens[tokens[0]] = i; + } + } + + build_trie(); + build_failure_and_dict_links(); + } + + struct MatchResult { + size_t pos; + std::string pattern; + bool is_partial; + size_t matchLength; + bool is_grammar_trigger; + + bool operator==(const MatchResult & other) const { + return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger; + } + operator std::string() const { + return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}"; + } + }; + + MatchResult findSingleTokenMatch(llama_token token) const { + auto it = stop_tokens.find(token); + if (it != stop_tokens.end()) { + const auto & antiprompt = antiprompts[it->second]; + return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger}; + } + return {std::string::npos, "", false, 0, false}; + } + + MatchResult findFirstMatch(const std::string& text, size_t offset = 0) { + TrieNode* current = &root; + MatchResult partialMatch{std::string::npos, "", true, 0, false}; + + for (size_t i = offset; i < text.length(); ++i) { + char c = text[i]; + while (current != &root && current->children.find(c) == current->children.end()) { + current = current->fail; + } + auto it = current->children.find(c); + if (it != current->children.end()) { + current = &it->second; + } + if (current->output != -1) { + const auto & antiprompt = antiprompts[current->output]; + return { + i - antiprompt.value.length() + 1, + antiprompt.value, + false, + antiprompt.value.length(), + antiprompt.is_grammar_trigger, + }; + } + // Update partial match if we're at a deeper node + if (current->depth > partialMatch.matchLength) { + partialMatch.pos = i - current->depth + 1; + partialMatch.pattern = ""; // We don't know which pattern it partially matches + partialMatch.matchLength = current->depth; + partialMatch.is_grammar_trigger = false; + } + } + + // If we've found a partial match and haven't returned a full match, return the partial match + if (partialMatch.pos != std::string::npos) { + return partialMatch; + } + 
+ return {std::string::npos, "", false, 0, false}; + } +}; + // // Split utils // diff --git a/common/sampling.cpp b/common/sampling.cpp index 3dc7f112094e6..ac1f8b174f23b 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -139,6 +139,15 @@ std::string gpt_sampler_params::print() const { return std::string(result); } +bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger) { + if (gsmpl->grmr) { + return false; + } + gsmpl->grmr = llama_sampler_init_grammar(model, gsmpl->params.grammar.c_str(), "root"); + llama_sampler_accept_str(gsmpl->grmr, trigger.c_str()); + return true; +} + struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); @@ -146,7 +155,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st auto * result = new gpt_sampler { /* .params = */ params, - /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"), + /* .grmr = */ params.grammar_trigger_words.empty() ? llama_sampler_init_grammar(model, params.grammar.c_str(), "root") : nullptr, /* .chain = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, @@ -226,7 +235,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st void gpt_sampler_free(struct gpt_sampler * gsmpl) { if (gsmpl) { - llama_sampler_free(gsmpl->grmr); + if (gsmpl->grmr) { + llama_sampler_free(gsmpl->grmr); + } llama_sampler_free(gsmpl->chain); diff --git a/common/sampling.h b/common/sampling.h index d0e1a9203e99a..34c52377d6716 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -79,5 +79,7 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr); std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr); +bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger); + std::vector gpt_sampler_types_from_names(const std::vector & names, bool allow_alt_names); std::vector gpt_sampler_types_from_chars(const std::string & chars); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6bbb1e13ed7ac..068d53b390ca6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -36,7 +36,7 @@ static llama_model ** g_model; static gpt_sampler ** g_smpl; static gpt_params * g_params; static std::vector * g_input_tokens; -static std::ostringstream * g_output_ss; +static std::string * g_output_s; static std::vector * g_output_tokens; static bool is_interacting = false; static bool need_insert_eot = false; @@ -115,7 +115,7 @@ static void sigint_handler(int signo) { console::cleanup(); LOG("\n"); gpt_perf_print(*g_ctx, *g_smpl); - write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); + write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, *g_output_s, *g_output_tokens); // make sure all logs are flushed LOG("Interrupted by user\n"); @@ -507,7 +507,8 @@ int main(int argc, char ** argv) { std::vector input_tokens; g_input_tokens = &input_tokens; std::vector output_tokens; g_output_tokens = &output_tokens; - std::ostringstream output_ss; g_output_ss = &output_ss; + std::string output_s; g_output_s = &output_s; + size_t last_partial_stop = std::string::npos; std::ostringstream assistant_ss; // for storing current 
assistant message, used in conversation mode // the first thing we will do is to output the prompt, so set color accordingly @@ -516,13 +517,8 @@ int main(int argc, char ** argv) { std::vector embd; - // tokenized antiprompts - std::vector> antiprompt_ids; - - antiprompt_ids.reserve(params.antiprompt.size()); - for (const std::string & antiprompt : params.antiprompt) { - antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); - } + llama_antiprompts antiprompts; + antiprompts.build(ctx, params.antiprompt, {}); if (llama_model_has_encoder(model)) { int enc_input_size = embd_inp.size(); @@ -727,7 +723,7 @@ int main(int argc, char ** argv) { } else { // Outgoing Generated Tokens output_tokens.push_back(id); - output_ss << token_str; + output_s.append(token_str); } } } @@ -740,44 +736,34 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { - // check for reverse prompt in the last n_prev tokens - if (!params.antiprompt.empty()) { - const int n_prev = 32; - const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev); - + // check for reverse prompt + if (!antiprompts.empty()) { is_antiprompt = false; - // Check if each of the reverse prompts appears at the end of the output. - // If we're not running interactively, the reverse prompt might be tokenized with some following characters - // so we'll compensate for that by widening the search window a bit. - for (std::string & antiprompt : params.antiprompt) { - size_t extra_padding = params.interactive ? 0 : 2; - size_t search_start_pos = last_output.length() > static_cast(antiprompt.length() + extra_padding) - ? last_output.length() - static_cast(antiprompt.length() + extra_padding) - : 0; - - if (last_output.find(antiprompt, search_start_pos) != std::string::npos) { - if (params.interactive) { - is_interacting = true; - } - is_antiprompt = true; - break; - } - } // check for reverse prompt using special tokens llama_token last_token = gpt_sampler_last(smpl); - for (std::vector ids : antiprompt_ids) { - if (ids.size() == 1 && last_token == ids[0]) { - if (params.interactive) { - is_interacting = true; + auto match = antiprompts.findSingleTokenMatch(last_token); + if (match.pos != std::string::npos) { + if (params.interactive) { + is_interacting = true; + } + is_antiprompt = true; + } else { + match = antiprompts.findFirstMatch(output_s, last_partial_stop == std::string::npos ? 
0 : last_partial_stop); + if (match.pos != std::string::npos) { + if (match.is_partial) { + last_partial_stop = match.pos; + } else { + if (params.interactive) { + is_interacting = true; + } + is_antiprompt = true; } - is_antiprompt = true; - break; } } if (is_antiprompt) { - LOG_DBG("found antiprompt: %s\n", last_output.c_str()); + LOG_DBG("found antiprompt: %s\n", match.pattern.c_str()); } } @@ -786,9 +772,9 @@ int main(int argc, char ** argv) { LOG_DBG("found an EOG token\n"); if (params.interactive) { - if (!params.antiprompt.empty()) { + if (!antiprompts.stop_words.empty()) { // tokenize and inject first reverse prompt - const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true); + const auto first_antiprompt = ::llama_tokenize(ctx, antiprompts.stop_words.front(), false, true); embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); is_antiprompt = true; } @@ -882,7 +868,7 @@ int main(int argc, char ** argv) { for (size_t i = original_size; i < embd_inp.size(); ++i) { const llama_token token = embd_inp[i]; output_tokens.push_back(token); - output_ss << llama_token_to_piece(ctx, token); + output_s.append(llama_token_to_piece(ctx, token)); } // reset assistant message @@ -926,7 +912,7 @@ int main(int argc, char ** argv) { LOG("\n\n"); gpt_perf_print(ctx, smpl); - write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); + write_logfile(ctx, params, model, input_tokens, output_s, output_tokens); gpt_sampler_free(smpl); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e5275a5149551..9ac064748ead0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -131,8 +131,6 @@ struct slot_params { int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half int32_t n_predict = -1; // new tokens to predict - std::vector antiprompt; - json input_prefix; json input_suffix; }; @@ -183,6 +181,8 @@ struct server_slot { std::string oaicompat_model; std::string stopping_word; + llama_antiprompts antiprompts; + // sampling json json_schema; @@ -281,34 +281,6 @@ struct server_slot { }; } - size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) { - size_t stop_pos = std::string::npos; - - for (const std::string & word : params.antiprompt) { - size_t pos; - - if (type == STOP_TYPE_FULL) { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? 
text.size() - tmp : 0; - - pos = text.find(word, from_pos); - } else { - pos = find_partial_stop_string(word, text); - } - - if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (type == STOP_TYPE_FULL) { - stopped_word = true; - stopping_word = word; - has_next_token = false; - } - stop_pos = pos; - } - } - - return stop_pos; - } - void print_timings() const { const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; @@ -999,16 +971,26 @@ struct server_context { } { - slot.params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - slot.params.antiprompt.push_back(word); + slot.antiprompts.clear(); + + auto copy_string_array = [&](const json & data, const std::string & key, std::vector & vec) { + const auto & arr = data.find(key); + if (arr != data.end() && arr->is_array()) { + for (const auto & word : *arr) { + if (word.is_string()) { + vec.push_back(word); + } } } - } + }; + + std::vector stop_words; + std::vector grammar_trigger_words; + + copy_string_array(data, "stop", stop_words); + copy_string_array(data, "grammar_trigger_words", grammar_trigger_words); + + slot.antiprompts.build(ctx, stop_words, grammar_trigger_words); } { @@ -1110,6 +1092,18 @@ struct server_context { const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special); slot.sampled = result.tok; + auto match = slot.antiprompts.findSingleTokenMatch(result.tok); + if (match.pos != std::string::npos && !match.is_partial) { + if (match.is_grammar_trigger) { + gpt_sampler_trigger_grammar(model, slot.smpl, llama_token_to_piece(ctx, result.tok, params.special)); + } else { + slot.stopped_word = true; + slot.stopping_word = match.pattern; + slot.has_next_token = false; + return false; + } + } + // search stop word and delete it slot.generated_text += token_str; slot.has_next_token = true; @@ -1139,23 +1133,33 @@ struct server_context { if (!incomplete) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); - const std::string str_test = slot.generated_text.substr(pos); + match = slot.antiprompts.findFirstMatch(slot.generated_text, pos); + bool is_stop_full = false; + bool is_grammar_trigger = false; + size_t length = slot.generated_text.size(); + + // If there is a lazy grammar trigger word at stop_pos, enable the lazy grammar + if (match.is_grammar_trigger && gpt_sampler_trigger_grammar(model, slot.smpl, match.pattern)) { + is_grammar_trigger = true; + length = pos + match.pos + match.matchLength; + } else if (!match.is_grammar_trigger && match.pos != std::string::npos && !match.is_partial) { + slot.stopped_word = true; + slot.stopping_word = match.pattern; + slot.has_next_token = false; - size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL); - if (stop_pos != std::string::npos) { is_stop_full = true; - slot.generated_text.erase( - slot.generated_text.begin() + pos + stop_pos, - slot.generated_text.end()); - pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } else { - is_stop_full = false; - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL); + // length = pos + match.pos; + length = match.pos; } + slot.generated_text.erase( + slot.generated_text.begin() + length, + slot.generated_text.end()); + pos = std::min(slot.n_sent_text, length); + // check if 
there is any token to predict - if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) { + if (match.pos == std::string::npos || (!slot.has_next_token && !is_grammar_trigger && !is_stop_full && match.pos > 0)) { // no send the stop word in the response result.text_to_send = slot.generated_text.substr(pos, std::string::npos); slot.n_sent_text += result.text_to_send.size(); @@ -1243,7 +1247,8 @@ struct server_context { {"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_eta", slot.sparams.mirostat_eta}, {"penalize_nl", slot.sparams.penalize_nl}, - {"stop", slot.params.antiprompt}, + {"stop", slot.antiprompts.stop_words}, + {"grammar_trigger", slot.antiprompts.grammar_trigger_words}, {"max_tokens", slot.params.n_predict}, // User configured n_predict {"n_keep", slot.params.n_keep}, {"n_discard", slot.params.n_discard}, diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index f093f547ff2c1..8cab665014f8c 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -196,20 +196,15 @@ static size_t common_part(const std::string & a, const std::string & b) { return i; } -static bool ends_with(const std::string & str, const std::string & suffix) { - return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); -} - -static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { +static size_t find_partial_stop_string(const std::string & stop, const std::string & text) { if (!text.empty() && !stop.empty()) { - const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { - if (stop[char_index] == text_last_char) { - const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) { - return text.size() - char_index - 1; - } + auto it = std::find(stop.rbegin(), stop.rend(), text.back()); + while (it != stop.rend()) { + size_t length = std::distance(it, stop.rend()); + if (text.length() >= length && 0 == text.compare(text.length() - length, length, stop)) { + return text.length() - length; } + it = std::find(std::next(it), stop.rend(), text.back()); } } diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 74e9f64b393b2..b554fa6943c85 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1121,7 +1121,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token } const std::string & piece = grammar.vocab->cache_token_to_piece.at(token); + llama_grammar_accept_str(grammar, piece); +} +void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; diff --git a/src/llama-grammar.h b/src/llama-grammar.h index f529ce351e416..4a55ff5dac5c5 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -142,3 +142,7 @@ void llama_grammar_apply_impl( void llama_grammar_accept_impl( struct llama_grammar & grammar, llama_token token); + +void llama_grammar_accept_str( + struct llama_grammar & grammar, + const std::string & piece); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index e255a8fc4fd54..0773cd94f00d9 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -193,6 +193,12 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { } } +void llama_sampler_accept_str(struct llama_sampler * smpl, const 
char * piece) { + if (smpl->iface->accept_str) { + smpl->iface->accept_str(smpl, piece); + } +} + void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { GGML_ASSERT(smpl->iface->apply); smpl->iface->apply(smpl, cur_p); @@ -325,6 +331,7 @@ static void llama_sampler_chain_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_chain_i = { /* .name = */ llama_sampler_chain_name, /* .accept = */ llama_sampler_chain_accept, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_chain_apply, /* .reset = */ llama_sampler_chain_reset, /* .clone = */ llama_sampler_chain_clone, @@ -399,6 +406,7 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to static struct llama_sampler_i llama_sampler_greedy_i = { /* .name = */ llama_sampler_greedy_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_greedy_apply, /* .reset = */ nullptr, /* .clone = */ nullptr, @@ -457,6 +465,7 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_dist_i = { /* .name = */ llama_sampler_dist_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_dist_apply, /* .reset = */ llama_sampler_dist_reset, /* .clone = */ llama_sampler_dist_clone, @@ -488,6 +497,7 @@ static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_t static struct llama_sampler_i llama_sampler_softmax_i = { /* .name = */ llama_sampler_softmax_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_softmax_apply, /* .reset = */ nullptr, /* .clone = */ nullptr, @@ -528,6 +538,7 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_top_k_i = { /* .name = */ llama_sampler_top_k_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_top_k_apply, /* .reset = */ nullptr, /* .clone = */ llama_sampler_top_k_clone, @@ -594,6 +605,7 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_top_p_i = { /* .name = */ llama_sampler_top_p_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_top_p_apply, /* .reset = */ nullptr, /* .clone = */ llama_sampler_top_p_clone, @@ -690,6 +702,7 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_min_p_i = { /* .name = */ llama_sampler_min_p_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_min_p_apply, /* .reset = */ nullptr, /* .clone = */ llama_sampler_min_p_clone, @@ -785,6 +798,7 @@ static void llama_sampler_tail_free_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_tail_free_i = { /* .name = */ llama_sampler_tail_free_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_tail_free_apply, /* .reset = */ nullptr, /* .clone = */ llama_sampler_tail_free_clone, @@ -884,6 +898,7 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_typical_i = { /* .name = */ llama_sampler_typical_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_typical_apply, /* .reset = */ nullptr, /* .clone = */ llama_sampler_typical_clone, @@ -929,6 +944,7 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) 
{ static struct llama_sampler_i llama_sampler_temp_i = { /* .name = */ llama_sampler_temp_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_temp_apply, /* .reset = */ nullptr, /* .clone = */ llama_sampler_temp_clone, @@ -1042,6 +1058,7 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_temp_ext_i = { /* .name = */ llama_sampler_temp_ext_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_temp_ext_apply, /* .reset = */ nullptr, /* .clone = */ llama_sampler_temp_ext_clone, @@ -1145,6 +1162,7 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_mirostat_i = { /* .name = */ llama_sampler_mirostat_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_mirostat_apply, /* .reset = */ llama_sampler_mirostat_reset, /* .clone = */ llama_sampler_mirostat_clone, @@ -1244,6 +1262,7 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_mirostat_v2_i = { /* .name = */ llama_sampler_mirostat_v2_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_mirostat_v2_apply, /* .reset = */ llama_sampler_mirostat_v2_reset, /* .clone = */ llama_sampler_mirostat_v2_clone, @@ -1287,6 +1306,13 @@ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama } } +static void llama_sampler_grammar_accept_str(struct llama_sampler * smpl, const char * piece) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (ctx->grammar) { + llama_grammar_accept_str(*ctx->grammar, piece); + } +} + static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_grammar *) smpl->ctx; if (ctx->grammar) { @@ -1339,6 +1365,7 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_grammar_i = { /* .name = */ llama_sampler_grammar_name, /* .accept = */ llama_sampler_grammar_accept_impl, + /* .accept_str = */ llama_sampler_grammar_accept_str, /* .apply = */ llama_sampler_grammar_apply, /* .reset = */ llama_sampler_grammar_reset, /* .clone = */ llama_sampler_grammar_clone, @@ -1522,6 +1549,7 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_penalties_i = { /* .name = */ llama_sampler_penalties_name, /* .accept = */ llama_sampler_penalties_accept, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_penalties_apply, /* .reset = */ llama_sampler_penalties_reset, /* .clone = */ llama_sampler_penalties_clone, @@ -1624,6 +1652,7 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { static struct llama_sampler_i llama_sampler_logit_bias_i = { /* .name = */ llama_sampler_logit_bias_name, /* .accept = */ nullptr, + /* .accept_str = */ nullptr, /* .apply = */ llama_sampler_logit_bias_apply, /* .reset = */ nullptr, /* .clone = */ llama_sampler_logit_bias_clone, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 08ad66b49fdd4..25f2489961b90 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -122,6 +122,7 @@ llama_target_and_test(test-grad0.cpp) llama_target_and_test(test-barrier.cpp) # llama_target_and_test(test-opt.cpp) # SLOW llama_target_and_test(test-backend-ops.cpp) +llama_target_and_test(test-antiprompts.cpp) 
 llama_target_and_test(test-rope.cpp)
diff --git a/tests/test-antiprompts.cpp b/tests/test-antiprompts.cpp
new file mode 100644
index 0000000000000..226c7d24f4f30
--- /dev/null
+++ b/tests/test-antiprompts.cpp
@@ -0,0 +1,88 @@
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include "llama.h"
+#include "common.h"
+
+#include <cassert>
+
+template <class T>
+void assert_equal(const T & actual, const T & expected) {
+    if (expected == actual) return;
+    printf("Expected: %s, Actual: %s\n", ((std::string)expected).c_str(), ((std::string)actual).c_str());
+    assert(expected == actual);
+}
+
+// cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_CURL=1 && cmake --build build -j -t test-jinja -t test-antiprompts && ./build/bin/test-antiprompts
+int main()
+{
+    auto tokenizer = [&](const std::string & text) {
+        std::vector<llama_token> tokens;
+        for (size_t i = 0; i < text.length(); ++i) {
+            tokens.push_back(text[i]);
+        }
+        return tokens;
+    };
+    const std::vector<std::string> stop_words { };
+    const std::vector<std::string> grammar_trigger_words { };
+
+    printf("Testing antiprompts\n");
+
+    llama_antiprompts antiprompts;
+    antiprompts.build(tokenizer, {"abc", "bcd"}, {"bca", "x"});
+
+    assert_equal(antiprompts.findSingleTokenMatch('x'), {
+        .pos = 0,
+        .pattern = "x",
+        .is_partial = false,
+        .matchLength = 1,
+        .is_grammar_trigger = true,
+    });
+    assert_equal(antiprompts.findSingleTokenMatch('a'), {
+        .pos = std::string::npos,
+        .pattern = "",
+        .is_partial = false,
+        .matchLength = 0,
+        .is_grammar_trigger = false,
+    });
+    assert_equal(antiprompts.findFirstMatch(" ab", 0), {
+        .pos = 1,
+        .pattern = "",
+        .is_partial = true,
+        .matchLength = 2,
+        .is_grammar_trigger = false,
+    });
+    assert_equal(antiprompts.findFirstMatch(" abc", 0), {
+        .pos = 1,
+        .pattern = "abc",
+        .is_partial = false,
+        .matchLength = 3,
+        .is_grammar_trigger = false,
+    });
+    assert_equal(antiprompts.findFirstMatch(" bc", 0), {
+        .pos = 1,
+        .pattern = "",
+        .is_partial = true,
+        .matchLength = 2,
+        .is_grammar_trigger = false,
+    });
+    assert_equal(antiprompts.findFirstMatch(" bcd", 0), {
+        .pos = 1,
+        .pattern = "bcd",
+        .is_partial = false,
+        .matchLength = 3,
+        .is_grammar_trigger = false,
+    });
+    assert_equal(antiprompts.findFirstMatch(" bca", 0), {
+        .pos = 1,
+        .pattern = "bca",
+        .is_partial = false,
+        .matchLength = 3,
+        .is_grammar_trigger = true,
+    });
+    printf("OK\n");
+    // llama_antiprompts::MatchResult{0, "a", .is_partial = false, . 1, false});
+
+    return 0;
+}
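
Usage sketch (not part of the patch): the snippet below shows one way the new llama_antiprompts matcher and the lazy grammar trigger added in gpt_sampler_trigger_grammar could be wired into a streaming generation loop, mirroring what main.cpp and server.cpp do above. The concrete stop words, the trigger word, and the next_token() callback are hypothetical placeholders; only llama_antiprompts::build / findSingleTokenMatch / findFirstMatch, gpt_sampler_trigger_grammar, and the existing llama_token_to_piece helper come from this change or the existing common API.

// Sketch: stream tokens, hold back text that might still become a stop word,
// stop on full stop-word matches, and lazily enable the grammar on a trigger word.
// Assumes ctx, model and smpl are already initialized elsewhere.
#include "common.h"
#include "sampling.h"

#include <cstdio>
#include <functional>
#include <string>

static void stream_with_antiprompts(llama_context * ctx, llama_model * model, gpt_sampler * smpl,
                                    const std::function<llama_token()> & next_token) {
    llama_antiprompts antiprompts;
    // Stop words end generation; grammar trigger words switch the lazy grammar on.
    antiprompts.build(ctx, /* stop_words= */ {"User:"}, /* grammar_trigger_words= */ {"<tool_call>"});

    std::string generated;
    size_t n_sent = 0; // how much of `generated` has already been emitted

    while (true) {
        const llama_token tok = next_token();

        // Cheap path: the whole antiprompt is a single token.
        auto m = antiprompts.findSingleTokenMatch(tok);
        if (m.pos != std::string::npos) {
            if (m.is_grammar_trigger) {
                gpt_sampler_trigger_grammar(model, smpl, m.pattern); // constrained from here on
            } else {
                break; // a stop word matched on a single token
            }
        }

        generated += llama_token_to_piece(ctx, tok);

        // Scan only the not-yet-emitted tail of the text for (partial) matches.
        m = antiprompts.findFirstMatch(generated, n_sent);
        if (m.pos != std::string::npos && !m.is_partial) {
            if (m.is_grammar_trigger) {
                gpt_sampler_trigger_grammar(model, smpl, m.pattern);
            } else {
                generated.erase(m.pos); // drop the stop word from the output
                break;
            }
        }

        // Emit everything that can no longer turn into a stop word.
        const size_t safe_end = (m.pos != std::string::npos && m.is_partial) ? m.pos : generated.size();
        if (safe_end > n_sent) {
            printf("%s", generated.substr(n_sent, safe_end - n_sent).c_str());
            n_sent = safe_end;
        }
    }
}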