-
Notifications
You must be signed in to change notification settings - Fork 10.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
examples : add GBNF validator program (#5948)
* Revising GBNF validator program to be much simpler. * Changing from streams to using cstdio * Adding final newline character.
- Loading branch information
Showing
5 changed files
with
171 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
set(TARGET gbnf-validator) | ||
add_executable(${TARGET} gbnf-validator.cpp) | ||
install(TARGETS ${TARGET} RUNTIME) | ||
target_link_libraries(${TARGET} PRIVATE common grammar-parser llama ${CMAKE_THREAD_LIBS_INIT}) | ||
target_compile_features(${TARGET} PRIVATE cxx_std_11) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
#define LLAMA_API_INTERNAL | ||
|
||
#include "grammar-parser.h" | ||
#include "ggml.h" | ||
#include "llama.h" | ||
#include "unicode.h" | ||
|
||
#include <cstdio> | ||
#include <cstdlib> | ||
#include <string> | ||
#include <vector> | ||
|
||
static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { | ||
auto decoded = decode_utf8(input_str, {}); | ||
const auto & code_points = decoded.first; | ||
|
||
size_t pos = 0; | ||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { | ||
auto prev_stacks = grammar->stacks; | ||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); | ||
if (grammar->stacks.empty()) { | ||
error_pos = pos; | ||
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'"; | ||
grammar->stacks = prev_stacks; | ||
return false; | ||
} | ||
++pos; | ||
} | ||
|
||
for (const auto & stack : grammar->stacks) { | ||
if (stack.empty()) { | ||
return true; | ||
} | ||
} | ||
|
||
error_pos = pos; | ||
error_msg = "Unexpected end of input"; | ||
return false; | ||
} | ||
|
||
static void print_error_message(const std::string & input_str, size_t error_pos, const std::string & error_msg) { | ||
fprintf(stdout, "Input string is invalid according to the grammar.\n"); | ||
fprintf(stdout, "Error: %s at position %zu\n", error_msg.c_str(), error_pos); | ||
fprintf(stdout, "\n"); | ||
fprintf(stdout, "Input string:\n"); | ||
fprintf(stdout, "%s", input_str.substr(0, error_pos).c_str()); | ||
if (error_pos < input_str.size()) { | ||
fprintf(stdout, "\033[1;31m%c", input_str[error_pos]); | ||
if (error_pos+1 < input_str.size()) { | ||
fprintf(stdout, "\033[0;31m%s", input_str.substr(error_pos+1).c_str()); | ||
} | ||
fprintf(stdout, "\033[0m\n"); | ||
} | ||
} | ||
|
||
int main(int argc, char** argv) { | ||
if (argc != 3) { | ||
fprintf(stdout, "Usage: %s <grammar_filename> <input_filename>\n", argv[0]); | ||
return 1; | ||
} | ||
|
||
const std::string grammar_filename = argv[1]; | ||
const std::string input_filename = argv[2]; | ||
|
||
// Read the GBNF grammar file | ||
FILE* grammar_file = fopen(grammar_filename.c_str(), "r"); | ||
if (!grammar_file) { | ||
fprintf(stdout, "Failed to open grammar file: %s\n", grammar_filename.c_str()); | ||
return 1; | ||
} | ||
|
||
fseek(grammar_file, 0, SEEK_END); | ||
size_t grammar_size = ftell(grammar_file); | ||
fseek(grammar_file, 0, SEEK_SET); | ||
|
||
std::string grammar_str(grammar_size, ' '); | ||
fread(&grammar_str[0], 1, grammar_size, grammar_file); | ||
fclose(grammar_file); | ||
|
||
// Parse the GBNF grammar | ||
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); | ||
|
||
// will be empty (default) if there are parse errors | ||
if (parsed_grammar.rules.empty()) { | ||
fprintf(stdout, "%s: failed to parse grammar\n", __func__); | ||
return 1; | ||
} | ||
|
||
// Ensure that there is a "root" node. | ||
if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) { | ||
fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__); | ||
return 1; | ||
} | ||
|
||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); | ||
|
||
// Create the LLAMA grammar | ||
auto grammar = llama_grammar_init( | ||
grammar_rules.data(), | ||
grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); | ||
|
||
// Read the input file | ||
FILE* input_file = fopen(input_filename.c_str(), "r"); | ||
if (!input_file) { | ||
fprintf(stdout, "Failed to open input file: %s\n", input_filename.c_str()); | ||
return 1; | ||
} | ||
|
||
fseek(input_file, 0, SEEK_END); | ||
size_t input_size = ftell(input_file); | ||
fseek(input_file, 0, SEEK_SET); | ||
|
||
std::string input_str(input_size, ' '); | ||
fread(&input_str[0], 1, input_size, input_file); | ||
fclose(input_file); | ||
|
||
// Validate the input string against the grammar | ||
size_t error_pos; | ||
std::string error_msg; | ||
bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg); | ||
|
||
if (is_valid) { | ||
fprintf(stdout, "Input string is valid according to the grammar.\n"); | ||
} else { | ||
print_error_message(input_str, error_pos, error_msg); | ||
} | ||
|
||
// Clean up | ||
llama_grammar_free(grammar); | ||
|
||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters