// Tests common_regex, especially its support for partial matches at the end of the input.
2
3#include "common.h"
4#include "regex-partial.h"
5
6#include <sstream>
7#include <iostream>
8#include <optional>
9
// Verifies that `actual` compares equal to `expected` (using operator!=).
// On a mismatch both values are written to std::cerr and a
// std::runtime_error("Test failed") is thrown; otherwise nothing happens.
template <class T> static void assert_equals(const T & expected, const T & actual) {
    if (!(expected != actual)) {
        return;
    }
    std::cerr << "Expected: " << expected << std::endl
              << " Actual: " << actual << std::endl
              << std::flush;
    throw std::runtime_error("Test failed");
}
18
19struct test_case {
20 std::string pattern;
21 struct input_output {
22 std::string input;
23 common_regex_match output;
24 };
25 std::vector<input_output> inputs_outputs;
26};
27
28static std::string common_regex_match_type_name(common_regex_match_type type) {
29 switch (type) {
30 case COMMON_REGEX_MATCH_TYPE_NONE:
31 return "COMMON_REGEX_MATCH_TYPE_NONE";
32 case COMMON_REGEX_MATCH_TYPE_PARTIAL:
33 return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
34 case COMMON_REGEX_MATCH_TYPE_FULL:
35 return "COMMON_REGEX_MATCH_TYPE_FULL";
36 }
37 return "?";
38}
39
40static void test_regex() {
41 printf(format: "[%s]\n", __func__);
42 auto test = [](const test_case & test_case) {
43 common_regex cr(test_case.pattern);
44 std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
45 // std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
46 for (const auto & input_output : test_case.inputs_outputs) {
47 std::cout << " Input: " << input_output.input << '\n';
48 auto m = cr.search(input: input_output.input, pos: 0);
49 if (m != input_output.output) {
50 auto match_to_str = [&](const std::optional<common_regex_match> & m) {
51 std::ostringstream ss;
52 if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
53 ss << "<no match>";
54 } else {
55 GGML_ASSERT(!input_output.output.groups.empty());
56 std::vector<std::string> parts;
57 for (const auto & g : m->groups) {
58 parts.push_back(x: "{" + std::to_string(val: g.begin) + ", " + std::to_string(val: g.end) + "}");
59 }
60 ss << "{" << common_regex_match_type_name(type: m->type) << ", {" << string_join(values: parts, separator: ", ") << "}}";
61 }
62 return ss.str();
63 };
64 std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
65 std::cout << " Got: " << match_to_str(m) << '\n';
66 std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(pattern: test_case.pattern) << "/\n";
67
68 throw std::runtime_error("Test failed");
69 }
70 }
71 };
72 test({
73 .pattern: "a",
74 .inputs_outputs: {
75 {.input: "a", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 1}}}},
76 {.input: "b", .output: {.type: COMMON_REGEX_MATCH_TYPE_NONE, .groups: {}}},
77 {.input: "ab", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 1}}}},
78 {.input: "ba", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{1, 2}}}},
79 }
80 });
81 test({
82 .pattern: "abcd",
83 .inputs_outputs: {
84 {.input: "abcd", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 4}}}},
85 {.input: "abcde", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 4}}}},
86 {.input: "abc", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 3}}}},
87 {.input: "ab", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 2}}}},
88 {.input: "a", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 1}}}},
89 {.input: "d", .output: {}},
90 {.input: "bcd", .output: {}},
91 {.input: "cde", .output: {}},
92 {.input: "cd", .output: {}},
93 {.input: "yeah ab", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{5, 7}}}},
94 {.input: "abbie", .output: {}},
95 {.input: "", .output: {}},
96 }
97 });
98 test({
99 .pattern: ".*?ab",
100 .inputs_outputs: {
101 {.input: "ab", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 2}}}},
102 {.input: "abc", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 2}}}},
103 {.input: "dab", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 3}}}},
104 {.input: "dabc", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 3}}}},
105 {.input: "da", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 2}}}},
106 {.input: "d", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 1}}}},
107 }
108 });
109 test({
110 .pattern: "a.*?b",
111 .inputs_outputs: {
112 {.input: "ab", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 2}}}},
113 {.input: "abc", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 2}}}},
114 {.input: "a b", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 3}}}},
115 {.input: "a", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 1}}}},
116 {.input: "argh", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 4}}}},
117 {.input: "d", .output: {}},
118 {.input: "b", .output: {}},
119 }
120 });
121 test({
122 .pattern: "ab(?:cd){2,4}ef",
123 .inputs_outputs: {
124 // {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
125 {.input: "ab", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 2}}}},
126 {.input: "abcd", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 4}}}},
127 {.input: "abcde", .output: {}},
128 {.input: "abcdef", .output: {}},
129 {.input: "abcdcd", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 6}}}},
130 {.input: "abcdcde", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 7}}}},
131 {.input: "abcdcdef", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 8}}}},
132 {.input: "abcdcdcdcdef", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 12}}}},
133 {.input: "abcdcdcdcdcdef", .output: {}},
134 {.input: "abcde", .output: {}},
135 {.input: "yea", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{2, 3}}}},
136 }
137 });
138 test({
139 .pattern: "a(?:rte| pure )fact",
140 .inputs_outputs: {
141 {.input: "a", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 1}}}},
142 {.input: "art", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 3}}}},
143 {.input: "artefa", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 6}}}},
144 {.input: "fact", .output: {}},
145 {.input: "an arte", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{3, 7}}}},
146 {.input: "artefact", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 8}}}},
147 {.input: "an artefact", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{3, 11}}}},
148 {.input: "a pure", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 6}}}},
149 {.input: "a pure fact", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 11}}}},
150 {.input: "it's a pure fact", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{5, 16}}}},
151 {.input: "" , .output: {}},
152 {.input: "pure", .output: {}},
153 {.input: "pure fact", .output: {}},
154 }
155 });
156 test({
157 .pattern: "abc",
158 .inputs_outputs: {
159 {.input: " abcc", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{1, 4}}}},
160 {.input: "ab", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 2}}}},
161 {.input: "abc", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 3}}}},
162 {.input: " ab", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{1, 3}}}},
163 {.input: "a", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 1}}}},
164 {.input: "b", .output: {}},
165 {.input: "c", .output: {}},
166 {.input: "", .output: {}},
167 }
168 });
169
170 test({
171 .pattern: "(?:abc)?\\s*def",
172 .inputs_outputs: {
173 {.input: "ab", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 2}}}},
174 {.input: "abc", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 3}}}},
175 {.input: "abc ", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 4}}}},
176 {.input: "abc d", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 5}}}},
177 {.input: "abc de", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 6}}}},
178 {.input: "abc def", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 7}}}},
179 {.input: "abc defg", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 7}}}},
180 {.input: "abc defgh", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 7}}}},
181 {.input: "abcde", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 5}}}},
182 {.input: "abcdefgh", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 6}}}},
183 {.input: " d", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 2}}}},
184 {.input: "def", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 3}}}},
185 }
186 });
187
188 test({
189 .pattern: "a+b",
190 .inputs_outputs: {
191 {.input: "aaab", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 4}}}},
192 {.input: "aaa", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 3}}}},
193 {.input: "ab", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 2}}}},
194 }
195 });
196
197 test({
198 .pattern: "(?:"
199 "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
200 "(" // match 2 (open_tag)
201 "<tool_call>"
202 "|<function_call>"
203 "|<tool>"
204 "|<tools>"
205 "|<response>"
206 "|<json>"
207 "|<xml>"
208 "|<JSON>"
209 ")?"
210 "(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
211 ")"
212 "|<function=([^>]+)>" // match 4 (function name)
213 "|<function name=\"([^\"]+)\">", // match 5 (function name again)
214 .inputs_outputs: {
215 {.input: "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
216 {.input: "<tool_call> {\"name", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 18}}}},
217 {.input: "<tool_call>{\"name", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 17}}}},
218 {.input: "Let's call something\n<tool_call>{\"name", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{21, 38}}}},
219 {.input: "Ok then<tool_call>{\"name", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{7, 24}}}},
220 {.input: "{\"name", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{0, 6}}}},
221 {.input: "Ok then{\"name", .output: {.type: COMMON_REGEX_MATCH_TYPE_PARTIAL, .groups: {{7, 13}}}},
222 {.input: "<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
223 {.input: "<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
224 {.input: "<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
225 {.input: "<function=all>", .output: {.type: COMMON_REGEX_MATCH_TYPE_FULL, .groups: {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
226
227 }
228 });
229}
230
231static void test_regex_to_reversed_partial_regex() {
232 printf(format: "[%s]\n", __func__);
233
234 assert_equals<std::string>(
235 expected: "((?:(?:c)?b)?a)[\\s\\S]*",
236 actual: regex_to_reversed_partial_regex(pattern: "abc"));
237
238 assert_equals<std::string>(
239 expected: "(a+)[\\s\\S]*",
240 actual: regex_to_reversed_partial_regex(pattern: "a+"));
241
242 assert_equals<std::string>(
243 expected: "(a*)[\\s\\S]*",
244 actual: regex_to_reversed_partial_regex(pattern: "a*"));
245
246 assert_equals<std::string>(
247 expected: "(a?)[\\s\\S]*",
248 actual: regex_to_reversed_partial_regex(pattern: "a?"));
249
250 assert_equals<std::string>(
251 expected: "([a-z])[\\s\\S]*",
252 actual: regex_to_reversed_partial_regex(pattern: "[a-z]"));
253
254 assert_equals<std::string>(
255 expected: "((?:\\w+)?[a-z])[\\s\\S]*",
256 actual: regex_to_reversed_partial_regex(pattern: "[a-z]\\w+"));
257
258 assert_equals<std::string>(
259 expected: "((?:a|b))[\\s\\S]*",
260 actual: regex_to_reversed_partial_regex(pattern: "(?:a|b)"));
261 assert_equals<std::string>(
262 expected: "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
263 actual: regex_to_reversed_partial_regex(pattern: "abcd"));
264 assert_equals<std::string>(
265 expected: "((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
266 actual: regex_to_reversed_partial_regex(pattern: "a*b"));
267 assert_equals<std::string>(
268 expected: "((?:(?:b)?a)?.*)[\\s\\S]*",
269 actual: regex_to_reversed_partial_regex(pattern: ".*?ab"));
270 assert_equals<std::string>(
271 expected: "((?:(?:b)?.*)?a)[\\s\\S]*",
272 actual: regex_to_reversed_partial_regex(pattern: "a.*?b"));
273 assert_equals<std::string>(
274 expected: "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
275 actual: regex_to_reversed_partial_regex(pattern: "a(bc)d"));
276 assert_equals<std::string>(
277 expected: "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
278 actual: regex_to_reversed_partial_regex(pattern: "a(bc|de)"));
279 assert_equals<std::string>(
280 expected: "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
281 actual: regex_to_reversed_partial_regex(pattern: "ab{2,4}c"));
282}
283
284int main() {
285 test_regex_to_reversed_partial_regex();
286 test_regex();
287 std::cout << "All tests passed.\n";
288}
289