1#ifdef NDEBUG
2#undef NDEBUG
3#endif
4
5#include "llama.h"
6
7#include "../src/llama-grammar.h"
8
9#include <cassert>
10#include <stdexcept>
11
12int main()
13{
14 llama_grammar_parser parsed_grammar;
15
16 std::vector<std::pair<std::string, uint32_t>> expected = {
17 {"expr", 2},
18 {"expr_6", 6},
19 {"expr_7", 7},
20 {"ident", 8},
21 {"ident_10", 10},
22 {"num", 9},
23 {"num_11", 11},
24 {"root", 0},
25 {"root_1", 1},
26 {"root_5", 5},
27 {"term", 4},
28 {"ws", 3},
29 {"ws_12", 12},
30 };
31
32 std::vector<std::vector<llama_grammar_element>> expected_rules = {
33 {{.type: LLAMA_GRETYPE_RULE_REF, .value: 5}, {.type: LLAMA_GRETYPE_END, .value: 0}},
34 {
35 {.type: LLAMA_GRETYPE_RULE_REF, .value: 2},
36 {.type: LLAMA_GRETYPE_CHAR, .value: 61},
37 {.type: LLAMA_GRETYPE_RULE_REF, .value: 3},
38 {.type: LLAMA_GRETYPE_RULE_REF, .value: 4},
39 {.type: LLAMA_GRETYPE_CHAR, .value: 10},
40 {.type: LLAMA_GRETYPE_END, .value: 0},
41 },
42 {{.type: LLAMA_GRETYPE_RULE_REF, .value: 4}, {.type: LLAMA_GRETYPE_RULE_REF, .value: 7}, {.type: LLAMA_GRETYPE_END, .value: 0}},
43 {{.type: LLAMA_GRETYPE_RULE_REF, .value: 12}, {.type: LLAMA_GRETYPE_END, .value: 0}},
44 {
45 {.type: LLAMA_GRETYPE_RULE_REF, .value: 8},
46 {.type: LLAMA_GRETYPE_ALT, .value: 0},
47 {.type: LLAMA_GRETYPE_RULE_REF, .value: 9},
48 {.type: LLAMA_GRETYPE_ALT, .value: 0},
49 {.type: LLAMA_GRETYPE_CHAR, .value: 40},
50 {.type: LLAMA_GRETYPE_RULE_REF, .value: 3},
51 {.type: LLAMA_GRETYPE_RULE_REF, .value: 2},
52 {.type: LLAMA_GRETYPE_CHAR, .value: 41},
53 {.type: LLAMA_GRETYPE_RULE_REF, .value: 3},
54 {.type: LLAMA_GRETYPE_END, .value: 0},
55 },
56 {{.type: LLAMA_GRETYPE_RULE_REF, .value: 1}, {.type: LLAMA_GRETYPE_RULE_REF, .value: 5}, {.type: LLAMA_GRETYPE_ALT, .value: 0}, {.type: LLAMA_GRETYPE_RULE_REF, .value: 1}, {.type: LLAMA_GRETYPE_END, .value: 0}},
57 {
58 {.type: LLAMA_GRETYPE_CHAR, .value: 45},
59 {.type: LLAMA_GRETYPE_CHAR_ALT, .value: 43},
60 {.type: LLAMA_GRETYPE_CHAR_ALT, .value: 42},
61 {.type: LLAMA_GRETYPE_CHAR_ALT, .value: 47},
62 {.type: LLAMA_GRETYPE_RULE_REF, .value: 4},
63 {.type: LLAMA_GRETYPE_END, .value: 0},
64 },
65 {{.type: LLAMA_GRETYPE_RULE_REF, .value: 6}, {.type: LLAMA_GRETYPE_RULE_REF, .value: 7}, {.type: LLAMA_GRETYPE_ALT, .value: 0}, {.type: LLAMA_GRETYPE_END, .value: 0}},
66 {
67 {.type: LLAMA_GRETYPE_CHAR, .value: 97},
68 {.type: LLAMA_GRETYPE_CHAR_RNG_UPPER, .value: 122},
69 {.type: LLAMA_GRETYPE_RULE_REF, .value: 10},
70 {.type: LLAMA_GRETYPE_RULE_REF, .value: 3},
71 {.type: LLAMA_GRETYPE_END, .value: 0},
72 },
73 {{.type: LLAMA_GRETYPE_RULE_REF, .value: 11}, {.type: LLAMA_GRETYPE_RULE_REF, .value: 3}, {.type: LLAMA_GRETYPE_END, .value: 0}},
74 {
75 {.type: LLAMA_GRETYPE_CHAR, .value: 97},
76 {.type: LLAMA_GRETYPE_CHAR_RNG_UPPER, .value: 122},
77 {.type: LLAMA_GRETYPE_CHAR_ALT, .value: 48},
78 {.type: LLAMA_GRETYPE_CHAR_RNG_UPPER, .value: 57},
79 {.type: LLAMA_GRETYPE_CHAR_ALT, .value: 95},
80 {.type: LLAMA_GRETYPE_RULE_REF, .value: 10},
81 {.type: LLAMA_GRETYPE_ALT, .value: 0},
82 {.type: LLAMA_GRETYPE_END, .value: 0},
83 },
84 {
85 {.type: LLAMA_GRETYPE_CHAR, .value: 48},
86 {.type: LLAMA_GRETYPE_CHAR_RNG_UPPER, .value: 57},
87 {.type: LLAMA_GRETYPE_RULE_REF, .value: 11},
88 {.type: LLAMA_GRETYPE_ALT, .value: 0},
89 {.type: LLAMA_GRETYPE_CHAR, .value: 48},
90 {.type: LLAMA_GRETYPE_CHAR_RNG_UPPER, .value: 57},
91 {.type: LLAMA_GRETYPE_END, .value: 0},
92 },
93 {
94 {.type: LLAMA_GRETYPE_CHAR, .value: 32},
95 {.type: LLAMA_GRETYPE_CHAR_ALT, .value: 9},
96 {.type: LLAMA_GRETYPE_CHAR_ALT, .value: 10},
97 {.type: LLAMA_GRETYPE_RULE_REF, .value: 12},
98 {.type: LLAMA_GRETYPE_ALT, .value: 0},
99 {.type: LLAMA_GRETYPE_END, .value: 0},
100 },
101 };
102
103 for (auto pair : expected)
104 {
105 parsed_grammar.symbol_ids[pair.first] = pair.second;
106 }
107
108 for (auto rule : expected_rules)
109 {
110 parsed_grammar.rules.emplace_back();
111 for (auto element : rule)
112 {
113 parsed_grammar.rules.back().push_back(x: element);
114 }
115 }
116
117 std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
118
119 llama_grammar * grammar = llama_grammar_init_impl(vocab: nullptr, rules: grammar_rules.data(), n_rules: grammar_rules.size(), start_rule_index: parsed_grammar.symbol_ids.at(k: "root"));
120 if (grammar == nullptr) {
121 throw std::runtime_error("Failed to initialize llama_grammar");
122 }
123
124 std::vector<std::vector<llama_grammar_element>> expected_stacks = {
125 {
126 {.type: LLAMA_GRETYPE_RULE_REF, .value: 5},
127 {.type: LLAMA_GRETYPE_CHAR, .value: 61},
128 {.type: LLAMA_GRETYPE_RULE_REF, .value: 7},
129 {.type: LLAMA_GRETYPE_CHAR, .value: 97},
130 },
131 {
132 {.type: LLAMA_GRETYPE_RULE_REF, .value: 5},
133 {.type: LLAMA_GRETYPE_CHAR, .value: 61},
134 {.type: LLAMA_GRETYPE_RULE_REF, .value: 7},
135 {.type: LLAMA_GRETYPE_RULE_REF, .value: 3},
136 {.type: LLAMA_GRETYPE_CHAR, .value: 48},
137 },
138 {
139 {.type: LLAMA_GRETYPE_RULE_REF, .value: 5},
140 {.type: LLAMA_GRETYPE_CHAR, .value: 61},
141 {.type: LLAMA_GRETYPE_RULE_REF, .value: 7},
142 {.type: LLAMA_GRETYPE_RULE_REF, .value: 3},
143 {.type: LLAMA_GRETYPE_CHAR, .value: 48},
144 },
145 {
146 {.type: LLAMA_GRETYPE_RULE_REF, .value: 5},
147 {.type: LLAMA_GRETYPE_CHAR, .value: 61},
148 {.type: LLAMA_GRETYPE_RULE_REF, .value: 7},
149 {.type: LLAMA_GRETYPE_CHAR, .value: 40},
150 },
151 {
152 {.type: LLAMA_GRETYPE_CHAR, .value: 61},
153 {.type: LLAMA_GRETYPE_RULE_REF, .value: 7},
154 {.type: LLAMA_GRETYPE_CHAR, .value: 97},
155 },
156 {
157 {.type: LLAMA_GRETYPE_CHAR, .value: 61},
158 {.type: LLAMA_GRETYPE_RULE_REF, .value: 7},
159 {.type: LLAMA_GRETYPE_RULE_REF, .value: 3},
160 {.type: LLAMA_GRETYPE_CHAR, .value: 48},
161 },
162 {
163 {.type: LLAMA_GRETYPE_CHAR, .value: 61},
164 {.type: LLAMA_GRETYPE_RULE_REF, .value: 7},
165 {.type: LLAMA_GRETYPE_RULE_REF, .value: 3},
166 {.type: LLAMA_GRETYPE_CHAR, .value: 48},
167 },
168 {
169 {.type: LLAMA_GRETYPE_CHAR, .value: 61},
170 {.type: LLAMA_GRETYPE_RULE_REF, .value: 7},
171 {.type: LLAMA_GRETYPE_CHAR, .value: 40},
172 }};
173
174 auto index = 0;
175 for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
176 {
177 // compare stack to expected_stack
178 for (uint32_t i = 0; i < stack.size(); i++)
179 {
180 const llama_grammar_element * element = stack[i];
181 const llama_grammar_element & expected_element = expected_stacks[index][i];
182
183 // pretty print error message before asserting
184 if (expected_element.type != element->type || expected_element.value != element->value)
185 {
186 fprintf(stderr, format: "index: %d\n", index);
187 fprintf(stderr, format: "expected_element: %d, %u\n", expected_element.type, expected_element.value);
188 fprintf(stderr, format: "actual_element: %d, %u\n", element->type, element->value);
189 fprintf(stderr, format: "expected_element != actual_element\n");
190 }
191
192 assert(expected_element.type == element->type && expected_element.value == element->value);
193 }
194 index++;
195 }
196
197 std::vector<llama_grammar_candidate> next_candidates;
198 next_candidates.resize(new_size: 24);
199
200 for (size_t i = 0; i < 24; ++i)
201 {
202 uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
203 cp[0] = 37 + i;
204 cp[1] = 0;
205 next_candidates[i] = {.index: i, .code_points: cp, .partial_utf8: {}};
206 }
207
208 std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
209 {
210 {0, 37},
211 {1, 38},
212 {2, 39},
213 {3, 40},
214 {4, 41},
215 {5, 42},
216 {6, 43},
217 {7, 44},
218 {8, 45},
219 {9, 46},
220 {10, 47},
221 {11, 48},
222 {12, 49},
223 {13, 50},
224 {14, 51},
225 {15, 52},
226 {16, 53},
227 {17, 54},
228 {18, 55},
229 {19, 56},
230 {20, 57},
231 {21, 58},
232 {22, 59},
233 {23, 60},
234 },
235 {
236 {0, 37},
237 {1, 38},
238 {2, 39},
239 {3, 40},
240 {4, 41},
241 {5, 42},
242 {6, 43},
243 {7, 44},
244 {8, 45},
245 {9, 46},
246 {10, 47},
247 {21, 58},
248 {22, 59},
249 {23, 60},
250 },
251 {
252 {0, 37},
253 {1, 38},
254 {2, 39},
255 {3, 40},
256 {4, 41},
257 {5, 42},
258 {6, 43},
259 {7, 44},
260 {8, 45},
261 {9, 46},
262 {10, 47},
263 {21, 58},
264 {22, 59},
265 {23, 60},
266 },
267 {
268 {0, 37},
269 {1, 38},
270 {2, 39},
271 {4, 41},
272 {5, 42},
273 {6, 43},
274 {7, 44},
275 {8, 45},
276 {9, 46},
277 {10, 47},
278 {11, 48},
279 {12, 49},
280 {13, 50},
281 {14, 51},
282 {15, 52},
283 {16, 53},
284 {17, 54},
285 {18, 55},
286 {19, 56},
287 {20, 57},
288 {21, 58},
289 {22, 59},
290 {23, 60},
291 },
292 {
293 {0, 37},
294 {1, 38},
295 {2, 39},
296 {3, 40},
297 {4, 41},
298 {5, 42},
299 {6, 43},
300 {7, 44},
301 {8, 45},
302 {9, 46},
303 {10, 47},
304 {11, 48},
305 {12, 49},
306 {13, 50},
307 {14, 51},
308 {15, 52},
309 {16, 53},
310 {17, 54},
311 {18, 55},
312 {19, 56},
313 {20, 57},
314 {21, 58},
315 {22, 59},
316 {23, 60},
317 },
318 {
319 {0, 37},
320 {1, 38},
321 {2, 39},
322 {3, 40},
323 {4, 41},
324 {5, 42},
325 {6, 43},
326 {7, 44},
327 {8, 45},
328 {9, 46},
329 {10, 47},
330 {21, 58},
331 {22, 59},
332 {23, 60},
333 },
334 {
335 {0, 37},
336 {1, 38},
337 {2, 39},
338 {3, 40},
339 {4, 41},
340 {5, 42},
341 {6, 43},
342 {7, 44},
343 {8, 45},
344 {9, 46},
345 {10, 47},
346 {21, 58},
347 {22, 59},
348 {23, 60},
349 },
350 {
351 {0, 37},
352 {1, 38},
353 {2, 39},
354 {4, 41},
355 {5, 42},
356 {6, 43},
357 {7, 44},
358 {8, 45},
359 {9, 46},
360 {10, 47},
361 {11, 48},
362 {12, 49},
363 {13, 50},
364 {14, 51},
365 {15, 52},
366 {16, 53},
367 {17, 54},
368 {18, 55},
369 {19, 56},
370 {20, 57},
371 {21, 58},
372 {22, 59},
373 {23, 60},
374 },
375 };
376
377 std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(rules: llama_grammar_get_rules(grammar), stack: llama_grammar_get_stacks(grammar)[0], candidates: next_candidates);
378
379 std::vector<std::vector<llama_grammar_candidate>> all_rejects;
380
381 for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count)
382 {
383 rejects = llama_grammar_reject_candidates_for_stack(rules: llama_grammar_get_rules(grammar), stack: llama_grammar_get_stacks(grammar)[count], candidates: next_candidates);
384 all_rejects.push_back(x: rejects);
385 }
386
387 index = 0;
388 for (auto rej : all_rejects)
389 {
390 for (uint32_t i = 0; i < rej.size(); i++)
391 {
392 auto element = rej[i];
393 auto expected_element = expected_reject[index][i];
394 assert(element.index == expected_element.first && *element.code_points == expected_element.second);
395 }
396 index++;
397 }
398
399 for (auto &candidate : next_candidates)
400 {
401 delete[] candidate.code_points;
402 candidate.code_points = nullptr;
403 }
404
405 llama_grammar_free_impl(grammar);
406
407 return 0;
408}
409