1#pragma once
2
3#include "llama.h"
4
5#include <map>
6#include <regex>
7#include <string>
8#include <vector>
9
10struct llama_vocab;
11
12// grammar element type
// Kind tag for one element of a compiled grammar rule (the `type` half of
// llama_grammar_element). Every enumerator is pinned to an explicit value so the
// numbering cannot drift if entries are added or moved.
enum llama_gretype {
    LLAMA_GRETYPE_END            = 0, // terminates a rule definition
    LLAMA_GRETYPE_ALT            = 1, // begins an alternate definition of the current rule
    LLAMA_GRETYPE_RULE_REF       = 2, // non-terminal: the element's value is the referenced rule ID
    LLAMA_GRETYPE_CHAR           = 3, // terminal: the element's value is one Unicode code point
    LLAMA_GRETYPE_CHAR_NOT       = 4, // inverted char(s): [^a], [^a-b], [^abc]
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, // makes the preceding CHAR or CHAR_ALT an inclusive range: [a-z]
    LLAMA_GRETYPE_CHAR_ALT       = 6, // adds an alternate char after a preceding CHAR or CHAR_RNG_UPPER: [ab], [a-zA]
    LLAMA_GRETYPE_CHAR_ANY       = 7, // any character: .
};
40
41typedef struct llama_grammar_element {
42 enum llama_gretype type;
43 uint32_t value; // Unicode code point or rule ID
44} llama_grammar_element;
45
// Decoder state for a UTF-8 sequence that is only partially available so far
// (e.g. accepted tokens that end in the middle of a multi-byte character).
//
// Both members carry default initializers ({0, 0} = "no bytes pending"), so a
// default-initialized value is never indeterminate. The struct remains an
// aggregate (C++14+), so positional initialization such as
// `llama_partial_utf8{value, n_remain}` behaves exactly as before.
struct llama_partial_utf8 {
    uint32_t value    = 0; // bit value so far (unshifted)
    int      n_remain = 0; // num bytes remaining; -1 indicates invalid sequence
};
50
// A candidate token as seen by the grammar, carried together with its decoded
// code points and any trailing partial UTF-8 state.
// Callers may rely on positional aggregate initialization, so keep the member
// order as is.
struct llama_grammar_candidate {
    size_t index;                    // presumably the candidate's position in the caller's token list — TODO confirm
    const uint32_t * code_points;    // non-owning pointer to decoded code points; length/termination convention not visible here — NOTE(review): confirm
    llama_partial_utf8 partial_utf8; // incomplete UTF-8 sequence left over after decoding
};
56
// A rule is a flat sequence of elements: alternates are introduced by
// LLAMA_GRETYPE_ALT and the definition is closed by LLAMA_GRETYPE_END
// (see the llama_gretype comments).
using llama_grammar_rule = std::vector< llama_grammar_element>;
// A pushdown stack of non-owning pointers to grammar elements; presumably they
// point into the storage of a llama_grammar_rules and must not outlive it.
using llama_grammar_stack = std::vector<const llama_grammar_element *>;

// All rules of a grammar.
using llama_grammar_rules = std::vector<llama_grammar_rule>;
// The set of pushdown stacks the grammar may currently be in.
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
// A batch of candidates to be checked against the grammar.
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
63
// TODO: remove, needed for tests atm
// Accessors exposing a grammar's internal state to tests.
// Read-only view of the compiled rules (presumably `grammar->rules`).
const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
// Mutable access to the current pushdown stacks (presumably `grammar->stacks`).
      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
67
// Advances `grammar` by a single code point `chr`.
// The grammar's current pushdown stacks are required to be positioned at a
// character range (see `llama_grammar_advance_stack`); the stacks that result
// from accepting `chr` at those positions are computed. The function returns
// void and takes a non-const grammar, so the resulting set of stacks is stored
// back into `grammar` (presumably replacing `grammar->stacks` — confirm).
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
73
// Checks `candidates` against a single pushdown `stack` of the grammar `rules`.
// Judging by the name, the returned vector holds the candidates this stack
// rejects — NOTE(review): confirm against the definition.
// All inputs are taken by const reference; the result is returned by value.
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
        const llama_grammar_rules      & rules,
        const llama_grammar_stack      & stack,
        const llama_grammar_candidates & candidates);
78
// Parser for grammar source text. `parse` compiles the text into `rules`, and
// `symbol_ids` gives each named or generated rule a numeric rule ID.
struct llama_grammar_parser {
    // symbol (rule) name -> rule ID
    std::map<std::string, uint32_t> symbol_ids;

    // compiled rules; presumably indexed by rule ID (see add_rule) — confirm
    llama_grammar_rules rules;

    // Returns a vector of element pointers derived from `rules`, presumably one per
    // rule pointing at its first element (the C-style form accepted by the
    // `const llama_grammar_element **` overload of llama_grammar_init_impl).
    // The pointers are non-owning; presumably valid only while `rules` is unmodified.
    llama_grammar_stack c_rules() const;

    // Returns the rule ID for the symbol name src[0, len); presumably assigns a new
    // ID the first time a name is seen.
    uint32_t get_symbol_id(const char * src, size_t len);
    // Returns a fresh rule ID for a synthesized rule whose name is derived from
    // `base_name` (naming scheme not visible here — confirm in the definition).
    uint32_t generate_symbol_id(const std::string & base_name);

    // Records `rule` as the definition of rule `rule_id`.
    void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);

    // The parse_* helpers read grammar text starting at `src` and return a pointer
    // into that text where parsing stopped (assumed: just past the consumed input).
    // `is_nested` presumably indicates parsing inside a nested group — confirm.

    // Parses the alternatives of rule `rule_id` (`rule_name` is its name).
    const char * parse_alternates(
            const char        * src,
            const std::string & rule_name,
            uint32_t            rule_id,
            bool                is_nested);

    // Parses one sequence, writing its elements into `rule` (presumably appending).
    const char * parse_sequence(
            const char         * src,
            const std::string  & rule_name,
            llama_grammar_rule & rule,
            bool                 is_nested);

    // Parses a single rule definition.
    const char * parse_rule(const char * src);

    // Parses an entire grammar text; presumably returns true on success.
    bool parse(const char * src);
    // Writes the parsed rules to `file` (presumably a human-readable dump).
    void print(FILE * file);
};
108
// A trigger pattern for lazy grammars: the pattern text kept alongside a std::regex.
struct llama_grammar_trigger_pattern {
    std::string pattern; // source text of the regular expression
    std::regex regex;    // regex object, presumably compiled from `pattern` — confirm at construction site
};
113
114struct llama_grammar {
115 // note: allow null vocab for testing (not great)
116 const llama_vocab * vocab;
117
118 const llama_grammar_rules rules; // TODO: shared ptr
119 llama_grammar_stacks stacks;
120
121 // buffer for partially generated UTF-8 sequence from accepted tokens
122 llama_partial_utf8 partial_utf8;
123
124 // lazy grammars wait for trigger words or tokens before constraining the sampling.
125 // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
126 // (useful e.g. for tool_choice=required)
127 bool lazy = false;
128 bool awaiting_trigger = false; // Initialized to true for lazy grammars only
129 std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
130 std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
131 std::vector<llama_grammar_trigger_pattern>
132 trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
133 // string, and the grammar will be given the string from the first match group onwards.
134
135};
136
137//
138// internal API
139//
140
// note: needed for tests (not great)
// Creates a grammar from already-compiled rules given as a C-style array of
// `n_rules` element pointers, one per rule; parsing starts at rule
// `start_rule_index`. `vocab` may be null (testing only, per llama_grammar::vocab).
// Ownership of the returned grammar presumably passes to the caller, to be
// released with llama_grammar_free_impl — confirm; failure behavior not visible here.
struct llama_grammar * llama_grammar_init_impl(
        const struct llama_vocab * vocab,
        const llama_grammar_element ** rules,
        size_t n_rules,
        size_t start_rule_index);
147
// Creates a grammar by parsing the grammar source text `grammar_str`, with
// `grammar_root` naming the start rule.
//   lazy                 - if true, constraints only apply once a trigger is seen
//                          (see llama_grammar::lazy / awaiting_trigger)
//   trigger_patterns     - array of `num_trigger_patterns` regex strings
//   trigger_tokens       - array of `num_trigger_tokens` token IDs
// Ownership of the returned grammar presumably passes to the caller, to be
// released with llama_grammar_free_impl — confirm; presumably nullptr on parse failure.
struct llama_grammar * llama_grammar_init_impl(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                              bool lazy,
                     const char ** trigger_patterns,
                            size_t num_trigger_patterns,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens);
157
// Releases a grammar obtained from llama_grammar_init_impl / llama_grammar_clone_impl
// (presumed pairing — confirm whether a null `grammar` is accepted).
void llama_grammar_free_impl(struct llama_grammar * grammar);
159
// Returns a new grammar copied from `grammar` (the source is taken by const
// reference and left unchanged); presumably released with llama_grammar_free_impl.
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
161
// TODO: move the API below as member functions of llama_grammar
// Applies the grammar's constraints to the candidate tokens in `cur_p`.
// The grammar is read-only here; only `cur_p` is modified (presumably by
// masking out tokens the grammar cannot accept — confirm).
void llama_grammar_apply_impl(
        const struct llama_grammar & grammar,
            llama_token_data_array * cur_p);
166
// Advances `grammar` past the sampled `token`. The grammar is mutated in place
// (presumably its stacks, partial UTF-8 state, and lazy-trigger state).
void llama_grammar_accept_impl(
              struct llama_grammar & grammar,
                       llama_token   token);
170
// Advances `grammar` over the text `piece` (presumably decoding it and feeding
// each code point through llama_grammar_accept — confirm). Mutates `grammar` in place.
void llama_grammar_accept_str(
              struct llama_grammar & grammar,
                 const std::string & piece);
174