| 1 | #include "json-partial.h" |
| 2 | |
| 3 | #include "log.h" |
| 4 | |
| 5 | #include <nlohmann/json.hpp> |
| 6 | |
| 7 | #include <string> |
| 8 | #include <regex> |
| 9 | |
| 10 | using json = nlohmann::ordered_json; |
| 11 | |
// Kind of JSON container (or pending object key) tracked on the parse stack
// while locating the point at which a document was truncated.
enum common_json_stack_element_type {
    // An object has been opened and not closed yet.
    COMMON_JSON_STACK_ELEMENT_OBJECT = 0,
    // An object key has been read and its value has not been completed yet.
    COMMON_JSON_STACK_ELEMENT_KEY    = 1,
    // An array has been opened and not closed yet.
    COMMON_JSON_STACK_ELEMENT_ARRAY  = 2,
};
| 17 | |
| 18 | struct common_json_stack_element { |
| 19 | common_json_stack_element_type type; |
| 20 | std::string key; |
| 21 | }; |
| 22 | |
| 23 | bool common_json_parse( |
| 24 | const std::string & input, |
| 25 | const std::string & healing_marker, |
| 26 | common_json & out) |
| 27 | { |
| 28 | std::string::const_iterator it = input.begin(); |
| 29 | const auto end = input.end(); |
| 30 | return common_json_parse(it, end, healing_marker, out); |
| 31 | } |
| 32 | |
| 33 | bool common_json_parse( |
| 34 | std::string::const_iterator & it, |
| 35 | const std::string::const_iterator & end, |
| 36 | const std::string & healing_marker, |
| 37 | common_json & out) |
| 38 | { |
| 39 | // // https://json.nlohmann.me/features/parsing/sax_interface/ |
| 40 | struct json_error_locator : public nlohmann::json_sax<json> { |
| 41 | std::size_t position; |
| 42 | bool found_error; |
| 43 | std::string last_token; |
| 44 | std::string exception_message; |
| 45 | std::vector<common_json_stack_element> stack; |
| 46 | |
| 47 | json_error_locator() : position(0), found_error(false) {} |
| 48 | |
| 49 | bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT |
| 50 | this->position = position - 1; |
| 51 | this->found_error = true; |
| 52 | this->last_token = last_token; |
| 53 | this->exception_message = ex.what(); |
| 54 | return false; |
| 55 | } |
| 56 | void close_value() { |
| 57 | if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) { |
| 58 | stack.pop_back(); |
| 59 | } |
| 60 | } |
| 61 | bool null() override { // NOLINT |
| 62 | close_value(); |
| 63 | return true; |
| 64 | } |
| 65 | bool boolean(bool) override { // NOLINT |
| 66 | close_value(); |
| 67 | return true; |
| 68 | } |
| 69 | bool number_integer(number_integer_t) override { // NOLINT |
| 70 | close_value(); |
| 71 | return true; |
| 72 | } |
| 73 | bool number_unsigned(number_unsigned_t) override { // NOLINT |
| 74 | close_value(); |
| 75 | return true; |
| 76 | } |
| 77 | bool number_float(number_float_t, const string_t &) override { // NOLINT |
| 78 | close_value(); |
| 79 | return true; |
| 80 | } |
| 81 | bool string(string_t &) override { // NOLINT |
| 82 | close_value(); |
| 83 | return true; |
| 84 | } |
| 85 | bool binary(binary_t &) override { // NOLINT |
| 86 | close_value(); |
| 87 | return true; |
| 88 | } |
| 89 | bool start_object(std::size_t) override { // NOLINT |
| 90 | stack.push_back(x: {.type: COMMON_JSON_STACK_ELEMENT_OBJECT, .key: "" }); |
| 91 | return true; |
| 92 | } |
| 93 | bool end_object() override { |
| 94 | GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT); |
| 95 | stack.pop_back(); |
| 96 | close_value(); |
| 97 | return true; |
| 98 | } |
| 99 | bool key(string_t & key) override { // NOLINT |
| 100 | stack.push_back(x: {.type: COMMON_JSON_STACK_ELEMENT_KEY, .key: key}); |
| 101 | return true; |
| 102 | } |
| 103 | bool start_array(std::size_t) override { // NOLINT |
| 104 | stack.push_back(x: {.type: COMMON_JSON_STACK_ELEMENT_ARRAY, .key: "" }); |
| 105 | return true; |
| 106 | } |
| 107 | bool end_array() override { |
| 108 | GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY); |
| 109 | stack.pop_back(); |
| 110 | close_value(); |
| 111 | return true; |
| 112 | } |
| 113 | }; |
| 114 | json_error_locator err_loc; |
| 115 | auto start = it; |
| 116 | json::sax_parse(first: it, last: end, sax: &err_loc); |
| 117 | |
| 118 | if (err_loc.found_error) { |
| 119 | it = start; |
| 120 | auto temptative_end = it + err_loc.position; |
| 121 | // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str()); |
| 122 | |
| 123 | auto input = std::string(it, temptative_end); |
| 124 | try { |
| 125 | out.json = json::parse(i&: input); |
| 126 | // out.json = json::parse(it, temptative_end); |
| 127 | it = temptative_end; |
| 128 | return true; |
| 129 | } catch (const std::exception & ex) { |
| 130 | // No, needs healing. |
| 131 | LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n" , ex.what(), std::string(it, temptative_end).c_str()); |
| 132 | } |
| 133 | auto can_parse = [](const std::string & str) { |
| 134 | try { |
| 135 | auto _ = json::parse(i: str); // NOLINT |
| 136 | return true; |
| 137 | } catch (const std::exception &) { |
| 138 | return false; |
| 139 | } |
| 140 | }; |
| 141 | if (!healing_marker.empty() && !err_loc.stack.empty()) { |
| 142 | std::string str(it, temptative_end); |
| 143 | auto last_non_sp_pos = str.find_last_not_of(s: " \n\r\t" ); |
| 144 | if (last_non_sp_pos == std::string::npos) { |
| 145 | throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location" ); |
| 146 | } |
| 147 | auto last_non_sp_char = str[last_non_sp_pos]; |
| 148 | // Used to detect stops on a number, which may not be complete. |
| 149 | auto was_maybe_number = [&]() { |
| 150 | if (!str.empty() && std::isspace(str.back())) { |
| 151 | return false; |
| 152 | } |
| 153 | return std::isdigit(last_non_sp_char) || |
| 154 | last_non_sp_char == '.' || |
| 155 | last_non_sp_char == 'e' || |
| 156 | last_non_sp_char == 'E' || |
| 157 | last_non_sp_char == '-'; |
| 158 | }; |
| 159 | |
| 160 | std::string closing; |
| 161 | for (size_t i = err_loc.stack.size(); i > 0; i--) { |
| 162 | auto & el = err_loc.stack[i - 1]; |
| 163 | if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) { |
| 164 | closing += "}" ; |
| 165 | } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) { |
| 166 | closing += "]" ; |
| 167 | } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) { |
| 168 | throw std::runtime_error("Unexpected stack element type" ); |
| 169 | } |
| 170 | } |
| 171 | |
| 172 | // Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX |
| 173 | static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)" ); |
| 174 | |
| 175 | auto is_high_surrogate = [&](const std::string & s) { |
| 176 | // Check if a partial of a high surrogate (U+D800-U+DBFF) |
| 177 | return s.length() >= 4 && |
| 178 | s[0] == '\\' && s[1] == 'u' && |
| 179 | std::tolower(c: s[2]) == 'd' && |
| 180 | (s[3] == '8' || s[3] == '9' || std::tolower(c: s[3]) == 'a' || std::tolower(c: s[3]) == 'b'); |
| 181 | }; |
| 182 | |
| 183 | // Initialize the unicode marker to a low surrogate to handle the edge case |
| 184 | // where a high surrogate (U+D800-U+DBFF) is immediately followed by a |
| 185 | // backslash (\) |
| 186 | std::string unicode_marker_padding = "udc00" ; |
| 187 | std::smatch last_unicode_seq; |
| 188 | |
| 189 | if (std::regex_search(s: str, m&: last_unicode_seq, e: partial_unicode_regex)) { |
| 190 | std::smatch second_last_seq; |
| 191 | std::string prelude = str.substr(pos: 0, n: last_unicode_seq.position()); |
| 192 | |
| 193 | // Pad the escape sequence with 0s until it forms a complete sequence of 6 characters |
| 194 | unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0'); |
| 195 | |
| 196 | if (is_high_surrogate(last_unicode_seq.str())) { |
| 197 | // If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF) |
| 198 | unicode_marker_padding += "\\udc00" ; |
| 199 | } else if (std::regex_search(s: prelude, m&: second_last_seq, e: partial_unicode_regex)) { |
| 200 | if (is_high_surrogate(second_last_seq.str())) { |
| 201 | // If this follows a high surrogate, pad it to be a low surrogate |
| 202 | if (last_unicode_seq.length() == 2) { |
| 203 | unicode_marker_padding = "dc00" ; |
| 204 | } else if (last_unicode_seq.length() == 3) { |
| 205 | unicode_marker_padding = "c00" ; |
| 206 | } else { |
| 207 | // The original unicode_marker_padding is already padded with 0s |
| 208 | } |
| 209 | } |
| 210 | } |
| 211 | } |
| 212 | |
| 213 | const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$"; |
| 214 | |
| 215 | if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) { |
| 216 | // We're inside an object value |
| 217 | if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) { |
| 218 | // Was about to create an object value |
| 219 | str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; |
| 220 | } else if (can_parse(str + ": 1" + closing)) { |
| 221 | str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing; |
| 222 | } else if (last_non_sp_char == '{' && can_parse(str + closing)) { |
| 223 | // Was about to create an object |
| 224 | str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; |
| 225 | } else if (can_parse(str + "\"" + closing)) { |
| 226 | // Was inside an object value string |
| 227 | str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; |
| 228 | } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { |
| 229 | // Was inside an object value string after an escape |
| 230 | str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; |
| 231 | } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) { |
| 232 | // Was inside an object value string after a partial unicode escape |
| 233 | str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing; |
| 234 | } else { |
| 235 | // find last : |
| 236 | auto last_pos = str.find_last_of(c: ':'); |
| 237 | if (last_pos == std::string::npos) { |
| 238 | throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location" ); |
| 239 | } |
| 240 | // Cutting back to opening : for object value |
| 241 | str = str.substr(pos: 0, n: last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; |
| 242 | } |
| 243 | } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) { |
| 244 | if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) { |
| 245 | // Was about to create an array value |
| 246 | str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; |
| 247 | } else if (can_parse(str + "\"" + closing)) { |
| 248 | // Was inside an array value string |
| 249 | str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; |
| 250 | } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { |
| 251 | // Was inside an array value string after an escape |
| 252 | str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; |
| 253 | } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) { |
| 254 | // Was inside an array value string after a partial unicode escape |
| 255 | str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing; |
| 256 | } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) { |
| 257 | // Had just finished a value |
| 258 | str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing; |
| 259 | } else { |
| 260 | auto last_pos = str.find_last_of(s: "[," ); |
| 261 | if (last_pos == std::string::npos) { |
| 262 | throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location" ); |
| 263 | } |
| 264 | // Cutting back to last [ or , for array value |
| 265 | str = str.substr(pos: 0, n: last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; |
| 266 | } |
| 267 | } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) { |
| 268 | if ((last_non_sp_char == '{' && can_parse(str + closing)) || |
| 269 | (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) { |
| 270 | // Was about to create an object key+value |
| 271 | str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; |
| 272 | } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) { |
| 273 | // Was about to create an object key+value |
| 274 | str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing; |
| 275 | } else if (can_parse(str + "\": 1" + closing)) { |
| 276 | // Was inside an object key string |
| 277 | str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing; |
| 278 | } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) { |
| 279 | // Was inside an object key string after an escape |
| 280 | str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing; |
| 281 | } else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) { |
| 282 | // Was inside an object key string after a partial unicode escape |
| 283 | str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing; |
| 284 | } else { |
| 285 | auto last_pos = str.find_last_of(c: ':'); |
| 286 | if (last_pos == std::string::npos) { |
| 287 | throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location" ); |
| 288 | } |
| 289 | // fprintf(stderr, "Cutting back to last : for object key+value\n"); |
| 290 | str = str.substr(pos: 0, n: last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; |
| 291 | } |
| 292 | } else { |
| 293 | throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location" ); |
| 294 | } |
| 295 | // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str()); |
| 296 | out.json = json::parse(i&: str); |
| 297 | it = temptative_end; |
| 298 | return true; |
| 299 | } |
| 300 | // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...) |
| 301 | // fprintf(stderr, "Closing: TODO\n"); |
| 302 | return false; |
| 303 | } |
| 304 | out.json = json::parse(first: it, last: end); |
| 305 | it = end; |
| 306 | return true; |
| 307 | } |
| 308 | |