1#include "json-partial.h"
2
3#include "log.h"
4
5#include <nlohmann/json.hpp>
6
7#include <string>
8#include <regex>
9
10using json = nlohmann::ordered_json;
11
// Kinds of entries kept on the stack of still-open JSON constructs while scanning a document.
enum common_json_stack_element_type {
    COMMON_JSON_STACK_ELEMENT_OBJECT = 0, // an object whose closing '}' has not been seen yet
    COMMON_JSON_STACK_ELEMENT_KEY    = 1, // an object key whose value has not been completed yet
    COMMON_JSON_STACK_ELEMENT_ARRAY  = 2, // an array whose closing ']' has not been seen yet
};

// One entry of the open-construct stack.
struct common_json_stack_element {
    common_json_stack_element_type type; // which kind of construct this entry represents
    std::string key;                      // key text for KEY entries (this file stores "" for the others)
};
22
23bool common_json_parse(
24 const std::string & input,
25 const std::string & healing_marker,
26 common_json & out)
27{
28 std::string::const_iterator it = input.begin();
29 const auto end = input.end();
30 return common_json_parse(it, end, healing_marker, out);
31}
32
33bool common_json_parse(
34 std::string::const_iterator & it,
35 const std::string::const_iterator & end,
36 const std::string & healing_marker,
37 common_json & out)
38{
39 // // https://json.nlohmann.me/features/parsing/sax_interface/
40 struct json_error_locator : public nlohmann::json_sax<json> {
41 std::size_t position;
42 bool found_error;
43 std::string last_token;
44 std::string exception_message;
45 std::vector<common_json_stack_element> stack;
46
47 json_error_locator() : position(0), found_error(false) {}
48
49 bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
50 this->position = position - 1;
51 this->found_error = true;
52 this->last_token = last_token;
53 this->exception_message = ex.what();
54 return false;
55 }
56 void close_value() {
57 if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
58 stack.pop_back();
59 }
60 }
61 bool null() override { // NOLINT
62 close_value();
63 return true;
64 }
65 bool boolean(bool) override { // NOLINT
66 close_value();
67 return true;
68 }
69 bool number_integer(number_integer_t) override { // NOLINT
70 close_value();
71 return true;
72 }
73 bool number_unsigned(number_unsigned_t) override { // NOLINT
74 close_value();
75 return true;
76 }
77 bool number_float(number_float_t, const string_t &) override { // NOLINT
78 close_value();
79 return true;
80 }
81 bool string(string_t &) override { // NOLINT
82 close_value();
83 return true;
84 }
85 bool binary(binary_t &) override { // NOLINT
86 close_value();
87 return true;
88 }
89 bool start_object(std::size_t) override { // NOLINT
90 stack.push_back(x: {.type: COMMON_JSON_STACK_ELEMENT_OBJECT, .key: ""});
91 return true;
92 }
93 bool end_object() override {
94 GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
95 stack.pop_back();
96 close_value();
97 return true;
98 }
99 bool key(string_t & key) override { // NOLINT
100 stack.push_back(x: {.type: COMMON_JSON_STACK_ELEMENT_KEY, .key: key});
101 return true;
102 }
103 bool start_array(std::size_t) override { // NOLINT
104 stack.push_back(x: {.type: COMMON_JSON_STACK_ELEMENT_ARRAY, .key: ""});
105 return true;
106 }
107 bool end_array() override {
108 GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
109 stack.pop_back();
110 close_value();
111 return true;
112 }
113 };
114 json_error_locator err_loc;
115 auto start = it;
116 json::sax_parse(first: it, last: end, sax: &err_loc);
117
118 if (err_loc.found_error) {
119 it = start;
120 auto temptative_end = it + err_loc.position;
121 // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
122
123 auto input = std::string(it, temptative_end);
124 try {
125 out.json = json::parse(i&: input);
126 // out.json = json::parse(it, temptative_end);
127 it = temptative_end;
128 return true;
129 } catch (const std::exception & ex) {
130 // No, needs healing.
131 LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
132 }
133 auto can_parse = [](const std::string & str) {
134 try {
135 auto _ = json::parse(i: str); // NOLINT
136 return true;
137 } catch (const std::exception &) {
138 return false;
139 }
140 };
141 if (!healing_marker.empty() && !err_loc.stack.empty()) {
142 std::string str(it, temptative_end);
143 auto last_non_sp_pos = str.find_last_not_of(s: " \n\r\t");
144 if (last_non_sp_pos == std::string::npos) {
145 throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
146 }
147 auto last_non_sp_char = str[last_non_sp_pos];
148 // Used to detect stops on a number, which may not be complete.
149 auto was_maybe_number = [&]() {
150 if (!str.empty() && std::isspace(str.back())) {
151 return false;
152 }
153 return std::isdigit(last_non_sp_char) ||
154 last_non_sp_char == '.' ||
155 last_non_sp_char == 'e' ||
156 last_non_sp_char == 'E' ||
157 last_non_sp_char == '-';
158 };
159
160 std::string closing;
161 for (size_t i = err_loc.stack.size(); i > 0; i--) {
162 auto & el = err_loc.stack[i - 1];
163 if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
164 closing += "}";
165 } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
166 closing += "]";
167 } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
168 throw std::runtime_error("Unexpected stack element type");
169 }
170 }
171
172 // Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
173 static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
174
175 auto is_high_surrogate = [&](const std::string & s) {
176 // Check if a partial of a high surrogate (U+D800-U+DBFF)
177 return s.length() >= 4 &&
178 s[0] == '\\' && s[1] == 'u' &&
179 std::tolower(c: s[2]) == 'd' &&
180 (s[3] == '8' || s[3] == '9' || std::tolower(c: s[3]) == 'a' || std::tolower(c: s[3]) == 'b');
181 };
182
183 // Initialize the unicode marker to a low surrogate to handle the edge case
184 // where a high surrogate (U+D800-U+DBFF) is immediately followed by a
185 // backslash (\)
186 std::string unicode_marker_padding = "udc00";
187 std::smatch last_unicode_seq;
188
189 if (std::regex_search(s: str, m&: last_unicode_seq, e: partial_unicode_regex)) {
190 std::smatch second_last_seq;
191 std::string prelude = str.substr(pos: 0, n: last_unicode_seq.position());
192
193 // Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
194 unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
195
196 if (is_high_surrogate(last_unicode_seq.str())) {
197 // If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
198 unicode_marker_padding += "\\udc00";
199 } else if (std::regex_search(s: prelude, m&: second_last_seq, e: partial_unicode_regex)) {
200 if (is_high_surrogate(second_last_seq.str())) {
201 // If this follows a high surrogate, pad it to be a low surrogate
202 if (last_unicode_seq.length() == 2) {
203 unicode_marker_padding = "dc00";
204 } else if (last_unicode_seq.length() == 3) {
205 unicode_marker_padding = "c00";
206 } else {
207 // The original unicode_marker_padding is already padded with 0s
208 }
209 }
210 }
211 }
212
213 const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
214
215 if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
216 // We're inside an object value
217 if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
218 // Was about to create an object value
219 str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
220 } else if (can_parse(str + ": 1" + closing)) {
221 str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
222 } else if (last_non_sp_char == '{' && can_parse(str + closing)) {
223 // Was about to create an object
224 str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
225 } else if (can_parse(str + "\"" + closing)) {
226 // Was inside an object value string
227 str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
228 } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
229 // Was inside an object value string after an escape
230 str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
231 } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
232 // Was inside an object value string after a partial unicode escape
233 str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
234 } else {
235 // find last :
236 auto last_pos = str.find_last_of(c: ':');
237 if (last_pos == std::string::npos) {
238 throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
239 }
240 // Cutting back to opening : for object value
241 str = str.substr(pos: 0, n: last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
242 }
243 } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
244 if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
245 // Was about to create an array value
246 str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
247 } else if (can_parse(str + "\"" + closing)) {
248 // Was inside an array value string
249 str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
250 } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
251 // Was inside an array value string after an escape
252 str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
253 } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
254 // Was inside an array value string after a partial unicode escape
255 str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
256 } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
257 // Had just finished a value
258 str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
259 } else {
260 auto last_pos = str.find_last_of(s: "[,");
261 if (last_pos == std::string::npos) {
262 throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
263 }
264 // Cutting back to last [ or , for array value
265 str = str.substr(pos: 0, n: last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
266 }
267 } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
268 if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
269 (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
270 // Was about to create an object key+value
271 str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
272 } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
273 // Was about to create an object key+value
274 str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
275 } else if (can_parse(str + "\": 1" + closing)) {
276 // Was inside an object key string
277 str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
278 } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
279 // Was inside an object key string after an escape
280 str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
281 } else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
282 // Was inside an object key string after a partial unicode escape
283 str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
284 } else {
285 auto last_pos = str.find_last_of(c: ':');
286 if (last_pos == std::string::npos) {
287 throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
288 }
289 // fprintf(stderr, "Cutting back to last : for object key+value\n");
290 str = str.substr(pos: 0, n: last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
291 }
292 } else {
293 throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
294 }
295 // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
296 out.json = json::parse(i&: str);
297 it = temptative_end;
298 return true;
299 }
300 // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
301 // fprintf(stderr, "Closing: TODO\n");
302 return false;
303 }
304 out.json = json::parse(first: it, last: end);
305 it = end;
306 return true;
307}
308