| 1 | #pragma once |
| 2 | |
| 3 | #include "llama.h" |
| 4 | |
| 5 | #include <map> |
| 6 | #include <regex> |
| 7 | #include <string> |
| 8 | #include <vector> |
| 9 | |
| 10 | struct llama_vocab; |
| 11 | |
// Kinds of element that can appear inside a grammar rule.
// Every enumerator carries an explicit numeric value.
enum llama_gretype {
    LLAMA_GRETYPE_END            = 0, // end of rule definition
    LLAMA_GRETYPE_ALT            = 1, // start of alternate definition for rule
    LLAMA_GRETYPE_RULE_REF       = 2, // non-terminal element: reference to rule
    LLAMA_GRETYPE_CHAR           = 3, // terminal element: character (code point)
    LLAMA_GRETYPE_CHAR_NOT       = 4, // inverse char(s): [^a], [^a-b], [^abc]

    // turns a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT into
    // an inclusive range ([a-z])
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

    // adds an alternate char to match after a preceding LLAMA_GRETYPE_CHAR
    // or LLAMA_GRETYPE_CHAR_RNG_UPPER ([ab], [a-zA])
    LLAMA_GRETYPE_CHAR_ALT       = 6,

    LLAMA_GRETYPE_CHAR_ANY       = 7, // any character (.)
};
| 40 | |
| 41 | typedef struct llama_grammar_element { |
| 42 | enum llama_gretype type; |
| 43 | uint32_t value; // Unicode code point or rule ID |
| 44 | } llama_grammar_element; |
| 45 | |
// State of a UTF-8 sequence that has only been partially received.
struct llama_partial_utf8 {
    // bit value so far (unshifted)
    uint32_t value;

    // num bytes remaining; -1 indicates invalid sequence
    int n_remain;
};
| 50 | |
| 51 | struct llama_grammar_candidate { |
| 52 | size_t index; |
| 53 | const uint32_t * code_points; |
| 54 | llama_partial_utf8 partial_utf8; |
| 55 | }; |
| 56 | |
| 57 | using llama_grammar_rule = std::vector< llama_grammar_element>; |
| 58 | using llama_grammar_stack = std::vector<const llama_grammar_element *>; |
| 59 | |
| 60 | using llama_grammar_rules = std::vector<llama_grammar_rule>; |
| 61 | using llama_grammar_stacks = std::vector<llama_grammar_stack>; |
| 62 | using llama_grammar_candidates = std::vector<llama_grammar_candidate>; |
| 63 | |
| 64 | // TODO: remove, needed for tests atm |
| 65 | const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); |
| 66 | llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); |
| 67 | |
// Works on the set of possible pushdown stacks of a grammar. Those stacks are
// required to be positioned at a character range (see
// `llama_grammar_advance_stack`). Produces the N possible stacks that result
// if the given char is accepted at those positions.
// NOTE(review): the function returns void, so the result is presumably stored
// back into `grammar` - confirm against the implementation.
void llama_grammar_accept(
        struct llama_grammar * grammar,
        uint32_t               chr);
| 73 | |
| 74 | std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack( |
| 75 | const llama_grammar_rules & rules, |
| 76 | const llama_grammar_stack & stack, |
| 77 | const llama_grammar_candidates & candidates); |
| 78 | |
| 79 | struct llama_grammar_parser { |
| 80 | std::map<std::string, uint32_t> symbol_ids; |
| 81 | |
| 82 | llama_grammar_rules rules; |
| 83 | |
| 84 | llama_grammar_stack c_rules() const; |
| 85 | |
| 86 | uint32_t get_symbol_id(const char * src, size_t len); |
| 87 | uint32_t generate_symbol_id(const std::string & base_name); |
| 88 | |
| 89 | void add_rule(uint32_t rule_id, const llama_grammar_rule & rule); |
| 90 | |
| 91 | const char * parse_alternates( |
| 92 | const char * src, |
| 93 | const std::string & rule_name, |
| 94 | uint32_t rule_id, |
| 95 | bool is_nested); |
| 96 | |
| 97 | const char * parse_sequence( |
| 98 | const char * src, |
| 99 | const std::string & rule_name, |
| 100 | llama_grammar_rule & rule, |
| 101 | bool is_nested); |
| 102 | |
| 103 | const char * parse_rule(const char * src); |
| 104 | |
| 105 | bool parse(const char * src); |
| 106 | void print(FILE * file); |
| 107 | }; |
| 108 | |
// A trigger pattern, stored both as text and as a std::regex object.
struct llama_grammar_trigger_pattern {
    // pattern text
    std::string pattern;

    // regex object
    // NOTE(review): presumably built from `pattern` by whoever fills this struct - confirm
    std::regex regex;
};
| 113 | |
| 114 | struct llama_grammar { |
| 115 | // note: allow null vocab for testing (not great) |
| 116 | const llama_vocab * vocab; |
| 117 | |
| 118 | const llama_grammar_rules rules; // TODO: shared ptr |
| 119 | llama_grammar_stacks stacks; |
| 120 | |
| 121 | // buffer for partially generated UTF-8 sequence from accepted tokens |
| 122 | llama_partial_utf8 partial_utf8; |
| 123 | |
| 124 | // lazy grammars wait for trigger words or tokens before constraining the sampling. |
| 125 | // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens. |
| 126 | // (useful e.g. for tool_choice=required) |
| 127 | bool lazy = false; |
| 128 | bool awaiting_trigger = false; // Initialized to true for lazy grammars only |
| 129 | std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. |
| 130 | std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). |
| 131 | std::vector<llama_grammar_trigger_pattern> |
| 132 | trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated |
| 133 | // string, and the grammar will be given the string from the first match group onwards. |
| 134 | |
| 135 | }; |
| 136 | |
| 137 | // |
| 138 | // internal API |
| 139 | // |
| 140 | |
| 141 | // note: needed for tests (not great) |
| 142 | struct llama_grammar * llama_grammar_init_impl( |
| 143 | const struct llama_vocab * vocab, |
| 144 | const llama_grammar_element ** rules, |
| 145 | size_t n_rules, |
| 146 | size_t start_rule_index); |
| 147 | |
| 148 | struct llama_grammar * llama_grammar_init_impl( |
| 149 | const struct llama_vocab * vocab, |
| 150 | const char * grammar_str, |
| 151 | const char * grammar_root, |
| 152 | bool lazy, |
| 153 | const char ** trigger_patterns, |
| 154 | size_t num_trigger_patterns, |
| 155 | const llama_token * trigger_tokens, |
| 156 | size_t num_trigger_tokens); |
| 157 | |
// NOTE(review): presumably releases a grammar obtained from
// llama_grammar_init_impl / llama_grammar_clone_impl - confirm.
void llama_grammar_free_impl(
        struct llama_grammar * grammar);

// Takes a grammar by const reference and returns a pointer to a grammar.
// NOTE(review): presumably an independent copy owned by the caller - confirm.
struct llama_grammar * llama_grammar_clone_impl(
        const struct llama_grammar & grammar);
| 161 | |
| 162 | // TODO: move the API below as member functions of llama_grammar |
| 163 | void llama_grammar_apply_impl( |
| 164 | const struct llama_grammar & grammar, |
| 165 | llama_token_data_array * cur_p); |
| 166 | |
| 167 | void llama_grammar_accept_impl( |
| 168 | struct llama_grammar & grammar, |
| 169 | llama_token token); |
| 170 | |
| 171 | void llama_grammar_accept_str( |
| 172 | struct llama_grammar & grammar, |
| 173 | const std::string & piece); |
| 174 | |