| 1 | #include "../src/unicode.h" |
| 2 | #include "../src/llama-grammar.h" |
| 3 | |
| 4 | #include <cstdio> |
| 5 | #include <cstdlib> |
| 6 | #include <sstream> |
| 7 | #include <fstream> |
| 8 | #include <string> |
| 9 | #include <vector> |
| 10 | |
| 11 | static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { |
| 12 | const auto cpts = unicode_cpts_from_utf8(utf8: input_str); |
| 13 | |
| 14 | auto & stacks_cur = llama_grammar_get_stacks(grammar); |
| 15 | |
| 16 | size_t pos = 0; |
| 17 | for (const auto & cpt : cpts) { |
| 18 | llama_grammar_accept(grammar, chr: cpt); |
| 19 | |
| 20 | if (stacks_cur.empty()) { |
| 21 | error_pos = pos; |
| 22 | error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'" ; |
| 23 | return false; |
| 24 | } |
| 25 | ++pos; |
| 26 | } |
| 27 | |
| 28 | for (const auto & stack : stacks_cur) { |
| 29 | if (stack.empty()) { |
| 30 | return true; |
| 31 | } |
| 32 | } |
| 33 | |
| 34 | error_pos = pos; |
| 35 | error_msg = "Unexpected end of input" ; |
| 36 | return false; |
| 37 | } |
| 38 | |
// Prints a report for a failed validation to stdout.
//
// `error_pos` is interpreted as a byte offset into `input_str`:
//   - the bytes before it are echoed verbatim;
//   - the UTF-8 sequence starting at it (lead byte plus up to three following continuation
//     bytes) is shown in bold red, so the ANSI escape codes never split a multi-byte character;
//   - the remainder is shown in red, then the colour is reset.
// When `error_pos` is at or past the end of the input (e.g. an "Unexpected end of input"
// error), the whole input is echoed without highlighting and the line is terminated.
static void print_error_message(const std::string & input_str, size_t error_pos, const std::string & error_msg) {
    fprintf(stdout, "Input string is invalid according to the grammar.\n");
    fprintf(stdout, "Error: %s at position %zu\n", error_msg.c_str(), error_pos);
    fprintf(stdout, "\n");
    fprintf(stdout, "Input string:\n");

    const size_t size  = input_str.size();
    const size_t split = error_pos < size ? error_pos : size;

    // fwrite instead of "%s" so embedded NUL bytes do not truncate the echoed input.
    fwrite(input_str.data(), 1, split, stdout);

    if (split < size) {
        // Extend the highlight over the continuation bytes (10xxxxxx) of this character.
        size_t len = 1;
        while (len < 4 && split + len < size &&
               (static_cast<unsigned char>(input_str[split + len]) & 0xC0) == 0x80) {
            ++len;
        }

        fprintf(stdout, "\033[1;31m");
        fwrite(input_str.data() + split, 1, len, stdout);
        if (split + len < size) {
            fprintf(stdout, "\033[0;31m");
            fwrite(input_str.data() + split + len, 1, size - split - len, stdout);
        }
        fprintf(stdout, "\033[0m\n");
    } else if (size == 0 || input_str.back() != '\n') {
        // Nothing to highlight; make sure the echoed input ends the line.
        fprintf(stdout, "\n");
    }
}
| 53 | |
| 54 | int main(int argc, char** argv) { |
| 55 | if (argc != 3) { |
| 56 | fprintf(stdout, format: "Usage: %s <grammar_filename> <input_filename>\n" , argv[0]); |
| 57 | return 1; |
| 58 | } |
| 59 | |
| 60 | const std::string grammar_filename = argv[1]; |
| 61 | const std::string input_filename = argv[2]; |
| 62 | |
| 63 | // Read the GBNF grammar file |
| 64 | FILE* grammar_file = fopen(filename: grammar_filename.c_str(), modes: "r" ); |
| 65 | if (!grammar_file) { |
| 66 | fprintf(stdout, format: "Failed to open grammar file: %s\n" , grammar_filename.c_str()); |
| 67 | return 1; |
| 68 | } |
| 69 | |
| 70 | std::string grammar_str; |
| 71 | { |
| 72 | std::ifstream grammar_file(grammar_filename); |
| 73 | GGML_ASSERT(grammar_file.is_open() && "Failed to open grammar file" ); |
| 74 | std::stringstream buffer; |
| 75 | buffer << grammar_file.rdbuf(); |
| 76 | grammar_str = buffer.str(); |
| 77 | } |
| 78 | |
| 79 | llama_grammar * grammar = llama_grammar_init_impl(vocab: nullptr, grammar_str: grammar_str.c_str(), grammar_root: "root" , lazy: false, trigger_patterns: nullptr, num_trigger_patterns: 0, trigger_tokens: nullptr, num_trigger_tokens: 0); |
| 80 | if (grammar == nullptr) { |
| 81 | fprintf(stdout, format: "Failed to initialize llama_grammar\n" ); |
| 82 | return 1; |
| 83 | } |
| 84 | // Read the input file |
| 85 | std::string input_str; |
| 86 | { |
| 87 | std::ifstream input_file(input_filename); |
| 88 | GGML_ASSERT(input_file.is_open() && "Failed to open input file" ); |
| 89 | std::stringstream buffer; |
| 90 | buffer << input_file.rdbuf(); |
| 91 | input_str = buffer.str(); |
| 92 | } |
| 93 | |
| 94 | // Validate the input string against the grammar |
| 95 | size_t error_pos; |
| 96 | std::string error_msg; |
| 97 | bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg); |
| 98 | |
| 99 | if (is_valid) { |
| 100 | fprintf(stdout, format: "Input string is valid according to the grammar.\n" ); |
| 101 | } else { |
| 102 | print_error_message(input_str, error_pos, error_msg); |
| 103 | } |
| 104 | |
| 105 | // Clean up |
| 106 | llama_grammar_free_impl(grammar); |
| 107 | |
| 108 | return 0; |
| 109 | } |
| 110 | |