#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"

#include <algorithm>
#include <cctype>
#include <functional>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
13
14using json = nlohmann::ordered_json;
15
16common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
17 : input_(input), is_partial_(is_partial), syntax_(syntax)
18{
19 result_.role = "assistant";
20
21 while (true) {
22 std::string id = std::to_string(val: std::rand());
23 if (input.find(str: id) == std::string::npos) {
24 healing_marker_ = id;
25 break;
26 }
27 }
28}
29
30std::string common_chat_msg_parser::str(const common_string_range & rng) const {
31 GGML_ASSERT(rng.begin <= rng.end);
32 return input_.substr(pos: rng.begin, n: rng.end - rng.begin);
33}
34
35void common_chat_msg_parser::add_content(const std::string &content) {
36 result_.content += content;
37}
38
39void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
40 result_.reasoning_content += reasoning_content;
41}
42
43bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
44 if (name.empty()) {
45 return false;
46 }
47
48 common_chat_tool_call tool_call;
49 tool_call.name = name;
50 tool_call.arguments = arguments;
51 tool_call.id = id;
52
53 // LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
54 result_.tool_calls.emplace_back(args&: tool_call);
55
56 return true;
57}
58bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
59 std::string name = tool_call.contains(key: "name") ? tool_call.at(key: "name") : "";
60 std::string id = tool_call.contains(key: "id") ? tool_call.at(key: "id") : "";
61 std::string arguments = "";
62 if (tool_call.contains(key: "arguments")) {
63 if (tool_call.at(key: "arguments").is_object()) {
64 arguments = tool_call.at(key: "arguments").dump();
65 } else {
66 arguments = tool_call.at(key: "arguments");
67 }
68 }
69
70 return add_tool_call(name, id, arguments);
71}
72
73bool common_chat_msg_parser::add_tool_calls(const json & arr) {
74 for (const auto & item : arr) {
75 if (!add_tool_call(tool_call: item)) {
76 return false;
77 }
78 }
79 return true;
80}
81
82bool common_chat_msg_parser::add_tool_call_short_form(const json & tool_call) {
83 if (!tool_call.is_object() || tool_call.size() != 1) {
84 return false;
85 }
86
87 // Get the tool name (the single key in the object)
88 auto it = tool_call.begin();
89 std::string name = it.key();
90
91 if (name.empty()) {
92 return false;
93 }
94
95 // Get the arguments (the nested object)
96 const json & args_json = it.value();
97 std::string arguments = "";
98
99 if (args_json.is_object()) {
100 arguments = args_json.dump();
101 } else if (args_json.is_string()) {
102 arguments = args_json;
103 } else if (!args_json.is_null()) {
104 // For other types, convert to string representation
105 arguments = args_json.dump();
106 }
107
108 return add_tool_call(name, id: "", arguments);
109}
110void common_chat_msg_parser::finish() {
111 if (!is_partial_ && pos_ != input_.size()) {
112 throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
113 }
114}
115
116bool common_chat_msg_parser::consume_spaces() {
117 const auto length = input_.size();
118 auto consumed = false;
119 while (pos_ < length && std::isspace(input_[pos_])) {
120 ++pos_;
121 consumed = true;
122 }
123 return consumed;
124}
125
126bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
127 auto pos = pos_;
128 for (auto i = 0u; i < literal.size(); ++i) {
129 if (pos >= input_.size()) {
130 return false;
131 }
132 if (input_[pos] != literal[i]) {
133 return false;
134 }
135 ++pos;
136 }
137 pos_ = pos;
138 return true;
139}
140
141std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
142 auto idx = input_.find(str: literal, pos: pos_);
143 if (idx != std::string::npos) {
144 find_regex_result res;
145 res.prelude = input_.substr(pos: pos_, n: idx - pos_);
146 auto end = idx + literal.size();
147 res.groups.emplace_back(args: common_string_range{idx, end});
148 move_to(pos: end);
149 return res;
150 }
151 if (is_partial_) {
152 idx = string_find_partial_stop(str: input_, stop: literal);
153 if (idx != std::string::npos && idx >= pos_) {
154 find_regex_result res;
155 res.prelude = input_.substr(pos: pos_, n: idx - pos_);
156 auto end = input_.size();
157 res.groups.emplace_back(args: common_string_range{idx, end});
158 move_to(pos: end);
159 return res;
160 }
161 }
162 return std::nullopt;
163}
164
165void common_chat_msg_parser::consume_literal(const std::string & literal) {
166 if (!try_consume_literal(literal)) {
167 throw common_chat_msg_partial_exception(literal);
168 }
169}
170
171bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
172 std::string pending_reasoning_prefix;
173
174 if (syntax_.reasoning_format == COMMON_REASONING_FORMAT_NONE) {
175 return false;
176 }
177
178 auto set_reasoning_prefix = [&](size_t prefix_pos) {
179 if (!syntax_.thinking_forced_open || syntax_.reasoning_in_content) {
180 return;
181 }
182 if (prefix_pos + start_think.size() > input_.size()) {
183 pending_reasoning_prefix.clear();
184 return;
185 }
186 // Capture the exact literal that opened the reasoning section so we can
187 // surface it back to callers. This ensures formats that force the
188 // reasoning tag open (e.g. DeepSeek R1) retain their original prefix
189 // instead of dropping it during parsing.
190 pending_reasoning_prefix = input_.substr(pos: prefix_pos, n: start_think.size());
191 };
192
193 auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
194 auto stripped_reasoning = string_strip(str: reasoning);
195 if (stripped_reasoning.empty()) {
196 return;
197 }
198 if (syntax_.reasoning_in_content) {
199 add_content(content: syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
200 add_content(content: stripped_reasoning);
201 if (closed) {
202 add_content(content: syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
203 }
204 } else {
205 if (!pending_reasoning_prefix.empty()) {
206 add_reasoning_content(reasoning_content: pending_reasoning_prefix);
207 pending_reasoning_prefix.clear();
208 }
209 add_reasoning_content(reasoning_content: stripped_reasoning);
210 }
211 };
212
213 const size_t saved_pos = pos_;
214 const size_t saved_content_size = result_.content.size();
215 const size_t saved_reasoning_size = result_.reasoning_content.size();
216
217 auto restore_state = [&]() {
218 move_to(pos: saved_pos);
219 result_.content.resize(n: saved_content_size);
220 result_.reasoning_content.resize(n: saved_reasoning_size);
221 };
222
223 // Allow leading whitespace to be preserved as content when reasoning is present at the start
224 size_t cursor = pos_;
225 size_t whitespace_end = cursor;
226 while (whitespace_end < input_.size() && std::isspace(static_cast<unsigned char>(input_[whitespace_end]))) {
227 ++whitespace_end;
228 }
229
230 if (whitespace_end >= input_.size()) {
231 restore_state();
232 if (syntax_.thinking_forced_open) {
233 auto rest = input_.substr(pos: saved_pos);
234 if (!rest.empty()) {
235 handle_reasoning(rest, /* closed */ !is_partial());
236 }
237 move_to(pos: input_.size());
238 return true;
239 }
240 return false;
241 }
242
243 cursor = whitespace_end;
244 const size_t remaining = input_.size() - cursor;
245 const size_t start_prefix = std::min(a: start_think.size(), b: remaining);
246 const bool has_start_tag = input_.compare(pos1: cursor, n1: start_prefix, str: start_think, pos2: 0, n2: start_prefix) == 0;
247
248 if (has_start_tag && start_prefix < start_think.size()) {
249 move_to(pos: input_.size());
250 return true;
251 }
252
253 if (has_start_tag) {
254 if (whitespace_end > pos_) {
255 add_content(content: input_.substr(pos: pos_, n: whitespace_end - pos_));
256 }
257 set_reasoning_prefix(cursor);
258 cursor += start_think.size();
259 } else if (syntax_.thinking_forced_open) {
260 cursor = whitespace_end;
261 } else {
262 restore_state();
263 return false;
264 }
265 while (true) {
266 if (cursor >= input_.size()) {
267 move_to(pos: input_.size());
268 return true;
269 }
270
271 size_t end_pos = input_.find(str: end_think, pos: cursor);
272 if (end_pos == std::string::npos) {
273 std::string_view remaining_view(input_.data() + cursor, input_.size() - cursor);
274 size_t partial_off = string_find_partial_stop(str: remaining_view, stop: end_think);
275 size_t reasoning_end = partial_off == std::string::npos ? input_.size() : cursor + partial_off;
276 if (reasoning_end > cursor) {
277 handle_reasoning(input_.substr(pos: cursor, n: reasoning_end - cursor), /* closed */ partial_off == std::string::npos && !is_partial());
278 }
279 move_to(pos: input_.size());
280 return true;
281 }
282
283 if (end_pos > cursor) {
284 handle_reasoning(input_.substr(pos: cursor, n: end_pos - cursor), /* closed */ true);
285 } else {
286 handle_reasoning("", /* closed */ true);
287 }
288
289 cursor = end_pos + end_think.size();
290
291 while (cursor < input_.size() && std::isspace(static_cast<unsigned char>(input_[cursor]))) {
292 ++cursor;
293 }
294
295 const size_t next_remaining = input_.size() - cursor;
296 if (next_remaining == 0) {
297 move_to(pos: cursor);
298 return true;
299 }
300
301 const size_t next_prefix = std::min(a: start_think.size(), b: next_remaining);
302 if (input_.compare(pos1: cursor, n1: next_prefix, str: start_think, pos2: 0, n2: next_prefix) == 0) {
303 if (next_prefix < start_think.size()) {
304 move_to(pos: input_.size());
305 return true;
306 }
307 set_reasoning_prefix(cursor);
308 cursor += start_think.size();
309 continue;
310 }
311
312 move_to(pos: cursor);
313 return true;
314 }
315}
316
317std::string common_chat_msg_parser::consume_rest() {
318 auto rest = input_.substr(pos: pos_);
319 pos_ = input_.size();
320 return rest;
321}
322
323// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
324std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
325 auto m = regex.search(input: input_, pos: from == std::string::npos ? pos_ : from);
326 if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
327 return std::nullopt;
328 }
329 auto prelude = input_.substr(pos: pos_, n: m.groups[0].begin - pos_);
330 pos_ = m.groups[0].end;
331
332 if (add_prelude_to_content) {
333 add_content(content: prelude);
334 }
335 if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
336 if (is_partial()) {
337 throw common_chat_msg_partial_exception(regex.str());
338 }
339 return std::nullopt;
340 }
341 return find_regex_result{.prelude: prelude, .groups: m.groups};
342}
343
344common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
345 if (auto result = try_consume_regex(regex)) {
346 return *result;
347 }
348 throw common_chat_msg_partial_exception(regex.str());
349}
350
351std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
352 auto m = regex.search(input: input_, pos: pos_);
353 if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
354 return std::nullopt;
355 }
356 if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
357 if (is_partial()) {
358 throw common_chat_msg_partial_exception(regex.str());
359 }
360 return std::nullopt;
361 }
362 if (m.groups[0].begin != pos_) {
363 // Didn't match at the current position.
364 return std::nullopt;
365 }
366 pos_ = m.groups[0].end;
367
368 return find_regex_result {
369 /* .prelude = */ "",
370 .groups: m.groups,
371 };
372}
373
374std::optional<common_json> common_chat_msg_parser::try_consume_json() {
375 auto it = input_.cbegin() + pos_;
376 const auto end = input_.cend();
377 common_json result;
378 if (!common_json_parse(it, end, healing_marker: healing_marker_, out&: result)) {
379 return std::nullopt;
380 }
381 pos_ = std::distance(first: input_.cbegin(), last: it);
382 if (result.healing_marker.marker.empty()) {
383 // No healing marker, just return the parsed json
384 return result;
385 }
386 if (!is_partial()) {
387 throw common_chat_msg_partial_exception("JSON");
388 }
389 return result;
390}
391
392common_json common_chat_msg_parser::consume_json() {
393 if (auto result = try_consume_json()) {
394 return *result;
395 }
396 throw common_chat_msg_partial_exception("JSON");
397}
398
399common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
400 const std::vector<std::vector<std::string>> & args_paths,
401 const std::vector<std::vector<std::string>> & content_paths
402) {
403 if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
404 return *result;
405 }
406 throw common_chat_msg_partial_exception("JSON");
407}
408
409std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
410 const std::vector<std::vector<std::string>> & args_paths,
411 const std::vector<std::vector<std::string>> & content_paths
412) {
413 auto partial = try_consume_json();
414 if (!partial) {
415 return std::nullopt;
416 }
417 auto is_arguments_path = [&](const std::vector<std::string> & path) {
418 return std::find(first: args_paths.begin(), last: args_paths.end(), val: path) != args_paths.end();
419 };
420 auto is_content_path = [&](const std::vector<std::string> & path) {
421 return std::find(first: content_paths.begin(), last: content_paths.end(), val: path) != content_paths.end();
422 };
423
424 if (partial->healing_marker.marker.empty()) {
425 if (args_paths.empty()) {
426 // No arguments to dump, and JSON was parsed fully.
427 return consume_json_result {
428 .value: partial->json,
429 /* .is_partial = */ false,
430 };
431 }
432 if (is_arguments_path({})) {
433 // Entire JSON is the arguments and was parsed fully.
434 return consume_json_result {
435 .value: partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
436 /* .is_partial = */ false,
437 };
438 }
439 }
440
441 LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
442
443 auto found_healing_marker = false;
444 std::vector<std::string> path;
445 std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
446 if (is_arguments_path(path)) {
447 auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
448 if (is_partial() && !partial->healing_marker.marker.empty()) {
449 auto idx = arguments.find(str: partial->healing_marker.json_dump_marker);
450 if (idx != std::string::npos) {
451 arguments.resize(n: idx);
452 found_healing_marker = true;
453 }
454 if (arguments == "\"") {
455 // This happens because of completing `:"$magic` after `"arguments"`
456 arguments = "";
457 }
458 }
459 return arguments;
460 }
461 if (is_content_path(path)) {
462 if (!j.is_string()) {
463 throw std::runtime_error("Content path must be a string");
464 }
465 std::string str = j;
466 auto idx = str.find(str: partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
467 if (idx != std::string::npos) {
468 str.resize(n: idx);
469 found_healing_marker = true;
470 }
471 return str;
472 }
473 if (j.is_object()) {
474 auto obj = json::object();
475 for (const auto & p : j.items()) {
476 const auto & key = p.key();
477 const auto & value = p.value();
478 const std::string key_str = key; // NOLINT
479 auto idx = key_str.find(str: healing_marker_);
480 if (idx != std::string::npos) {
481 found_healing_marker = true;
482 break;
483 }
484 path.push_back(x: key_str);
485 if (value.is_string()) {
486 const std::string value_str = value;
487 if (value_str.find(str: healing_marker_) != std::string::npos) {
488 found_healing_marker = true;
489 if (is_content_path(path)) {
490 if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
491 // The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
492 obj[key] = remove_unsupported_healings_and_dump_args(value);
493 }
494 }
495 break;
496 }
497 obj[key] = value;
498 } else {
499 obj[key] = remove_unsupported_healings_and_dump_args(value);
500 }
501 path.pop_back();
502 }
503 return obj;
504 }
505 if (j.is_array()) {
506 auto arr = json::array();
507 for (const auto & value : j) {
508 if (value.is_string()) {
509 std::string str = value;
510 auto idx = str.find(str: healing_marker_);
511 if (idx != std::string::npos) {
512 // Don't heal array values that aren't in the arguments.
513 found_healing_marker = true;
514 break;
515 }
516 }
517 arr.push_back(val: remove_unsupported_healings_and_dump_args(value));
518 }
519 return arr;
520 }
521 return j;
522 };
523
524 auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
525 LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
526 return consume_json_result {
527 .value: cleaned,
528 /* .is_partial = */ found_healing_marker,
529 };
530}
531
532void common_chat_msg_parser::clear_tools() {
533 result_.tool_calls.clear();
534}
535