| 1 | #pragma once |
| 2 | |
| 3 | #include "common.h" |
| 4 | #include "log.h" |
| 5 | #include "llama.h" |
| 6 | #include "arg.h" // common_remote_get_content |
| 7 | #include "base64.hpp" |
| 8 | #include "mtmd.h" |
| 9 | #include "mtmd-helper.h" |
| 10 | #include "chat.h" |
| 11 | |
| 12 | // increase max payload length to allow use of larger context size |
| 13 | #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 |
| 14 | // increase backlog size to avoid connection resets for >> 1 slots |
| 15 | #define CPPHTTPLIB_LISTEN_BACKLOG 512 |
| 16 | // increase max URI length to handle longer prompts in query string |
| 17 | #define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768 |
| 18 | // disable Nagle's algorithm |
| 19 | #define CPPHTTPLIB_TCP_NODELAY true |
| 20 | #include <cpp-httplib/httplib.h> |
| 21 | |
| 22 | #define JSON_ASSERT GGML_ASSERT |
| 23 | #include <nlohmann/json.hpp> |
| 24 | |
| 25 | #include <random> |
| 26 | #include <sstream> |
| 27 | #include <string> |
| 28 | #include <vector> |
| 29 | #include <memory> |
| 30 | #include <cinttypes> |
| 31 | |
| 32 | #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" |
| 33 | |
| 34 | using json = nlohmann::ordered_json; |
| 35 | |
| 36 | #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) |
| 37 | #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) |
| 38 | #define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) |
| 39 | #define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) |
| 40 | |
| 41 | #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) |
| 42 | #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) |
| 43 | #define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) |
| 44 | #define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) |
| 45 | |
| 46 | #define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) |
| 47 | #define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) |
| 48 | #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) |
| 49 | #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) |
| 50 | |
| 51 | using raw_buffer = std::vector<uint8_t>; |
| 52 | |
| 53 | template <typename T> |
| 54 | static T json_value(const json & body, const std::string & key, const T & default_value) { |
| 55 | // Fallback null to default value |
| 56 | if (body.contains(key) && !body.at(key).is_null()) { |
| 57 | try { |
| 58 | return body.at(key); |
| 59 | } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) { |
| 60 | LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n" , key.c_str(), json(default_value).type_name(), err.what()); |
| 61 | return default_value; |
| 62 | } |
| 63 | } else { |
| 64 | return default_value; |
| 65 | } |
| 66 | } |
| 67 | |
| 68 | const static std::string build_info("b" + std::to_string(val: LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); |
| 69 | |
| 70 | // thin wrapper around common_grammar_trigger with (de)serialization functions |
| 71 | struct server_grammar_trigger { |
| 72 | common_grammar_trigger value; |
| 73 | |
| 74 | server_grammar_trigger() = default; |
| 75 | server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} |
| 76 | server_grammar_trigger(const json & in) { |
| 77 | value.type = (common_grammar_trigger_type) in.at(key: "type" ).get<int>(); |
| 78 | value.value = in.at(key: "value" ).get<std::string>(); |
| 79 | if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { |
| 80 | value.token = (llama_token) in.at(key: "token" ).get<int>(); |
| 81 | } |
| 82 | } |
| 83 | |
| 84 | json to_json() const { |
| 85 | json out { |
| 86 | {"type" , (int) value.type}, |
| 87 | {"value" , value.value}, |
| 88 | }; |
| 89 | if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { |
| 90 | out["token" ] = (int) value.token; |
| 91 | } |
| 92 | return out; |
| 93 | } |
| 94 | }; |
| 95 | |
| 96 | // |
| 97 | // tokenizer and input processing utils |
| 98 | // |
| 99 | |
| 100 | static bool json_is_array_of_numbers(const json & data) { |
| 101 | if (data.is_array()) { |
| 102 | for (const auto & e : data) { |
| 103 | if (!e.is_number_integer()) { |
| 104 | return false; |
| 105 | } |
| 106 | } |
| 107 | return true; |
| 108 | } |
| 109 | return false; |
| 110 | } |
| 111 | |
| 112 | // is array having BOTH numbers & strings? |
| 113 | static bool json_is_array_of_mixed_numbers_strings(const json & data) { |
| 114 | bool seen_string = false; |
| 115 | bool seen_number = false; |
| 116 | if (data.is_array()) { |
| 117 | for (const auto & e : data) { |
| 118 | seen_string |= e.is_string(); |
| 119 | seen_number |= e.is_number_integer(); |
| 120 | if (seen_number && seen_string) { |
| 121 | return true; |
| 122 | } |
| 123 | } |
| 124 | } |
| 125 | return false; |
| 126 | } |
| 127 | |
| 128 | // does array have any individual integers/tokens? |
| 129 | static bool json_is_array_and_contains_numbers(const json & data) { |
| 130 | if (data.is_array()) { |
| 131 | for (const auto & e : data) { |
| 132 | if (e.is_number_integer()) { |
| 133 | return true; |
| 134 | } |
| 135 | } |
| 136 | return false; |
| 137 | } |
| 138 | return false; |
| 139 | } |
| 140 | |
| 141 | // get value by path(key1 / key2) |
| 142 | static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) { |
| 143 | json result = json::object(); |
| 144 | |
| 145 | for (const std::string & path : paths) { |
| 146 | json current = js; |
| 147 | const auto keys = string_split<std::string>(input: path, /*separator*/ separator: '/'); |
| 148 | bool valid_path = true; |
| 149 | for (const std::string & k : keys) { |
| 150 | if (valid_path && current.is_object() && current.contains(key: k)) { |
| 151 | current = current[k]; |
| 152 | } else { |
| 153 | valid_path = false; |
| 154 | } |
| 155 | } |
| 156 | if (valid_path) { |
| 157 | result[path] = current; |
| 158 | } |
| 159 | } |
| 160 | return result; |
| 161 | } |
| 162 | |
| 163 | /** |
| 164 | * this handles 2 cases: |
| 165 | * - only string, example: "string" |
| 166 | * - mixed string and tokens, example: [12, 34, "string", 56, 78] |
| 167 | */ |
| 168 | static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { |
| 169 | // If `add_bos` is true, we only add BOS, when json_prompt is a string, |
| 170 | // or the first element of the json_prompt array is a string. |
| 171 | llama_tokens prompt_tokens; |
| 172 | |
| 173 | if (json_prompt.is_array()) { |
| 174 | bool first = true; |
| 175 | for (const auto & p : json_prompt) { |
| 176 | if (p.is_string()) { |
| 177 | auto s = p.template get<std::string>(); |
| 178 | |
| 179 | llama_tokens p; |
| 180 | if (first) { |
| 181 | p = common_tokenize(vocab, text: s, add_special, parse_special); |
| 182 | first = false; |
| 183 | } else { |
| 184 | p = common_tokenize(vocab, text: s, add_special: false, parse_special); |
| 185 | } |
| 186 | |
| 187 | prompt_tokens.insert(position: prompt_tokens.end(), first: p.begin(), last: p.end()); |
| 188 | } else { |
| 189 | if (first) { |
| 190 | first = false; |
| 191 | } |
| 192 | |
| 193 | prompt_tokens.push_back(x: p.template get<llama_token>()); |
| 194 | } |
| 195 | } |
| 196 | } else { |
| 197 | auto s = json_prompt.template get<std::string>(); |
| 198 | prompt_tokens = common_tokenize(vocab, text: s, add_special, parse_special); |
| 199 | } |
| 200 | |
| 201 | return prompt_tokens; |
| 202 | } |
| 203 | |
| 204 | // return the last index of character that can form a valid string |
| 205 | // if the last character is potentially cut in half, return the index before the cut |
| 206 | // if validate_utf8(text) == text.size(), then the whole text is valid utf8 |
| 207 | static size_t validate_utf8(const std::string& text) { |
| 208 | size_t len = text.size(); |
| 209 | if (len == 0) return 0; |
| 210 | |
| 211 | // Check the last few bytes to see if a multi-byte character is cut off |
| 212 | for (size_t i = 1; i <= 4 && i <= len; ++i) { |
| 213 | unsigned char c = text[len - i]; |
| 214 | // Check for start of a multi-byte sequence from the end |
| 215 | if ((c & 0xE0) == 0xC0) { |
| 216 | // 2-byte character start: 110xxxxx |
| 217 | // Needs at least 2 bytes |
| 218 | if (i < 2) return len - i; |
| 219 | } else if ((c & 0xF0) == 0xE0) { |
| 220 | // 3-byte character start: 1110xxxx |
| 221 | // Needs at least 3 bytes |
| 222 | if (i < 3) return len - i; |
| 223 | } else if ((c & 0xF8) == 0xF0) { |
| 224 | // 4-byte character start: 11110xxx |
| 225 | // Needs at least 4 bytes |
| 226 | if (i < 4) return len - i; |
| 227 | } |
| 228 | } |
| 229 | |
| 230 | // If no cut-off multi-byte character is found, return full length |
| 231 | return len; |
| 232 | } |
| 233 | |
| 234 | // |
| 235 | // template utils |
| 236 | // |
| 237 | |
| 238 | // format infill task |
| 239 | static llama_tokens format_infill( |
| 240 | const llama_vocab * vocab, |
| 241 | const json & input_prefix, |
| 242 | const json & input_suffix, |
| 243 | const json & , |
| 244 | const int n_batch, |
| 245 | const int n_predict, |
| 246 | const int n_ctx, |
| 247 | const bool spm_infill, |
| 248 | const llama_tokens & tokens_prompt |
| 249 | ) { |
| 250 | // TODO: optimize this block by reducing memory allocations and movement |
| 251 | |
| 252 | // use FIM repo-level pattern: |
| 253 | // ref: https://arxiv.org/pdf/2409.12186 |
| 254 | // |
| 255 | // [FIM_REP]myproject |
| 256 | // [FIM_SEP]filename0 |
| 257 | // extra chunk 0 |
| 258 | // [FIM_SEP]filename1 |
| 259 | // extra chunk 1 |
| 260 | // ... |
| 261 | // [FIM_SEP]filename |
| 262 | // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt |
| 263 | // |
| 264 | llama_tokens ; |
| 265 | extra_tokens.reserve(n: n_ctx); |
| 266 | |
| 267 | auto tokens_prefix = tokenize_mixed(vocab, json_prompt: input_prefix, add_special: false, parse_special: false); |
| 268 | auto tokens_suffix = tokenize_mixed(vocab, json_prompt: input_suffix, add_special: false, parse_special: false); |
| 269 | |
| 270 | if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { |
| 271 | // TODO: make project name an input |
| 272 | static const auto k_fim_repo = common_tokenize(vocab, text: "myproject\n" , add_special: false, parse_special: false); |
| 273 | |
| 274 | extra_tokens.push_back(x: llama_vocab_fim_rep(vocab)); |
| 275 | extra_tokens.insert(position: extra_tokens.end(), first: k_fim_repo.begin(), last: k_fim_repo.end()); |
| 276 | } |
| 277 | for (const auto & chunk : input_extra) { |
| 278 | // { "text": string, "filename": string } |
| 279 | const std::string text = json_value(body: chunk, key: "text" , default_value: std::string()); |
| 280 | const std::string filename = json_value(body: chunk, key: "filename" , default_value: std::string("tmp" )); |
| 281 | |
| 282 | if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { |
| 283 | const auto k_fim_file = common_tokenize(vocab, text: filename + "\n" , add_special: false, parse_special: false); |
| 284 | |
| 285 | extra_tokens.insert(position: extra_tokens.end(), x: llama_vocab_fim_sep(vocab)); |
| 286 | extra_tokens.insert(position: extra_tokens.end(), first: k_fim_file.begin(), last: k_fim_file.end()); |
| 287 | } else { |
| 288 | // chunk separator in binary form to avoid confusing the AI |
| 289 | static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; |
| 290 | static const auto k_chunk_prefix_tokens = common_tokenize(vocab, text: k_chunk_prefix_str, add_special: false, parse_special: false); |
| 291 | |
| 292 | extra_tokens.insert(position: extra_tokens.end(), first: k_chunk_prefix_tokens.begin(), last: k_chunk_prefix_tokens.end()); |
| 293 | } |
| 294 | |
| 295 | const auto chunk_tokens = common_tokenize(vocab, text, add_special: false, parse_special: false); |
| 296 | extra_tokens.insert(position: extra_tokens.end(), first: chunk_tokens.begin(), last: chunk_tokens.end()); |
| 297 | } |
| 298 | |
| 299 | if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { |
| 300 | // TODO: current filename |
| 301 | static const auto k_fim_file = common_tokenize(vocab, text: "filename\n" , add_special: false, parse_special: false); |
| 302 | |
| 303 | extra_tokens.insert(position: extra_tokens.end(), x: llama_vocab_fim_sep(vocab)); |
| 304 | extra_tokens.insert(position: extra_tokens.end(), first: k_fim_file.begin(), last: k_fim_file.end()); |
| 305 | } |
| 306 | |
| 307 | // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) |
| 308 | const int n_prefix_take = std::min<int>(a: tokens_prefix.size(), b: 3*(n_batch/4)); |
| 309 | const int n_suffix_take = std::min<int>(a: tokens_suffix.size(), b: std::max<int>(a: 0, b: (n_batch/4) - (2 + tokens_prompt.size()))); |
| 310 | |
| 311 | SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n" , n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); |
| 312 | |
| 313 | // fill the rest of the context with extra chunks |
| 314 | const int = std::min<int>(a: std::max<int>(a: 0, b: n_ctx - (n_batch) - 2*n_predict), b: extra_tokens.size()); |
| 315 | |
| 316 | tokens_prefix.erase(first: tokens_prefix.begin(), last: tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); |
| 317 | tokens_suffix.resize(new_size: n_suffix_take); |
| 318 | |
| 319 | tokens_prefix.insert(position: tokens_prefix.begin(), x: llama_vocab_fim_pre(vocab)); |
| 320 | tokens_prefix.insert(position: tokens_prefix.end(), first: tokens_prompt.begin(), last: tokens_prompt.end()); |
| 321 | tokens_suffix.insert(position: tokens_suffix.begin(), x: llama_vocab_fim_suf(vocab)); |
| 322 | |
| 323 | auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; |
| 324 | auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; |
| 325 | |
| 326 | if (llama_vocab_get_add_bos(vocab)) { |
| 327 | embd_inp.insert(position: embd_inp.begin(), x: llama_vocab_bos(vocab)); |
| 328 | } |
| 329 | |
| 330 | SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n" , n_ctx, n_extra_take, (int) extra_tokens.size()); |
| 331 | |
| 332 | // put the extra context before the FIM prefix |
| 333 | embd_inp.insert(position: embd_inp.begin(), first: extra_tokens.end() - n_extra_take, last: extra_tokens.end()); |
| 334 | |
| 335 | embd_inp.insert(position: embd_inp.end(), first: embd_end.begin(), last: embd_end.end()); |
| 336 | embd_inp.push_back(x: llama_vocab_fim_mid(vocab)); |
| 337 | |
| 338 | return embd_inp; |
| 339 | } |
| 340 | |
| 341 | // |
| 342 | // base64 utils (TODO: move to common in the future) |
| 343 | // |
| 344 | |
| 345 | static const std::string base64_chars = |
| 346 | "ABCDEFGHIJKLMNOPQRSTUVWXYZ" |
| 347 | "abcdefghijklmnopqrstuvwxyz" |
| 348 | "0123456789+/" ; |
| 349 | |
| 350 | static inline bool is_base64(uint8_t c) { |
| 351 | return (isalnum(c) || (c == '+') || (c == '/')); |
| 352 | } |
| 353 | |
| 354 | static inline raw_buffer base64_decode(const std::string & encoded_string) { |
| 355 | int i = 0; |
| 356 | int j = 0; |
| 357 | int in_ = 0; |
| 358 | |
| 359 | int in_len = encoded_string.size(); |
| 360 | |
| 361 | uint8_t char_array_4[4]; |
| 362 | uint8_t char_array_3[3]; |
| 363 | |
| 364 | raw_buffer ret; |
| 365 | |
| 366 | while (in_len-- && (encoded_string[in_] != '=') && is_base64(c: encoded_string[in_])) { |
| 367 | char_array_4[i++] = encoded_string[in_]; in_++; |
| 368 | if (i == 4) { |
| 369 | for (i = 0; i < 4; i++) { |
| 370 | char_array_4[i] = base64_chars.find(c: char_array_4[i]); |
| 371 | } |
| 372 | |
| 373 | char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); |
| 374 | char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); |
| 375 | char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; |
| 376 | |
| 377 | for (i = 0; (i < 3); i++) { |
| 378 | ret.push_back(x: char_array_3[i]); |
| 379 | } |
| 380 | |
| 381 | i = 0; |
| 382 | } |
| 383 | } |
| 384 | |
| 385 | if (i) { |
| 386 | for (j = i; j < 4; j++) { |
| 387 | char_array_4[j] = 0; |
| 388 | } |
| 389 | |
| 390 | for (j = 0; j < 4; j++) { |
| 391 | char_array_4[j] = base64_chars.find(c: char_array_4[j]); |
| 392 | } |
| 393 | |
| 394 | char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); |
| 395 | char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); |
| 396 | char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; |
| 397 | |
| 398 | for (j = 0; j < i - 1; j++) { |
| 399 | ret.push_back(x: char_array_3[j]); |
| 400 | } |
| 401 | } |
| 402 | |
| 403 | return ret; |
| 404 | } |
| 405 | |
| 406 | // |
| 407 | // random string / id |
| 408 | // |
| 409 | |
| 410 | static std::string random_string() { |
| 411 | static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" ); |
| 412 | |
| 413 | std::random_device rd; |
| 414 | std::mt19937 generator(rd()); |
| 415 | |
| 416 | std::string result(32, ' '); |
| 417 | |
| 418 | for (int i = 0; i < 32; ++i) { |
| 419 | result[i] = str[generator() % str.size()]; |
| 420 | } |
| 421 | |
| 422 | return result; |
| 423 | } |
| 424 | |
| 425 | static std::string gen_chatcmplid() { |
| 426 | return "chatcmpl-" + random_string(); |
| 427 | } |
| 428 | |
| 429 | static std::string gen_tool_call_id() { |
| 430 | return random_string(); |
| 431 | } |
| 432 | |
| 433 | // |
| 434 | // other common utils |
| 435 | // |
| 436 | |
| 437 | // TODO: reuse llama_detokenize |
| 438 | template <class Iter> |
| 439 | static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { |
| 440 | std::string ret; |
| 441 | for (; begin != end; ++begin) { |
| 442 | ret += common_token_to_piece(ctx, *begin); |
| 443 | } |
| 444 | |
| 445 | return ret; |
| 446 | } |
| 447 | |
| 448 | // format incomplete utf-8 multibyte character for output |
| 449 | static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { |
| 450 | std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); |
| 451 | |
| 452 | // if the size is 1 and first bit is 1, meaning it's a partial character |
| 453 | // (size > 1 meaning it's already a known token) |
| 454 | if (out.size() == 1 && (out[0] & 0x80) == 0x80) { |
| 455 | std::stringstream ss; |
| 456 | ss << std::hex << (out[0] & 0xff); |
| 457 | std::string res(ss.str()); |
| 458 | out = "byte: \\x" + res; |
| 459 | } |
| 460 | |
| 461 | return out; |
| 462 | } |
| 463 | |
| 464 | static bool server_sent_event(httplib::DataSink & sink, const json & data) { |
| 465 | const std::string str = |
| 466 | "data: " + |
| 467 | data.dump(indent: -1, indent_char: ' ', ensure_ascii: false, error_handler: json::error_handler_t::replace) + |
| 468 | "\n\n" ; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). |
| 469 | |
| 470 | LOG_DBG("data stream, to_send: %s" , str.c_str()); |
| 471 | |
| 472 | return sink.write(str.c_str(), str.size()); |
| 473 | } |
| 474 | |
| 475 | // |
| 476 | // OAI utils |
| 477 | // |
| 478 | |
| 479 | // used by /completions endpoint |
| 480 | static json oaicompat_completion_params_parse(const json & body) { |
| 481 | json llama_params; |
| 482 | |
| 483 | if (!body.contains(key: "prompt" )) { |
| 484 | throw std::runtime_error("\"prompt\" is required" ); |
| 485 | } |
| 486 | |
| 487 | // Handle "stop" field |
| 488 | if (body.contains(key: "stop" ) && body.at(key: "stop" ).is_string()) { |
| 489 | llama_params["stop" ] = json::array(init: {body.at(key: "stop" ).get<std::string>()}); |
| 490 | } else { |
| 491 | llama_params["stop" ] = json_value(body, key: "stop" , default_value: json::array()); |
| 492 | } |
| 493 | |
| 494 | // Handle "n" field |
| 495 | int n_choices = json_value(body, key: "n" , default_value: 1); |
| 496 | if (n_choices != 1) { |
| 497 | throw std::runtime_error("Only one completion choice is allowed" ); |
| 498 | } |
| 499 | |
| 500 | // Handle "echo" field |
| 501 | if (json_value(body, key: "echo" , default_value: false)) { |
| 502 | throw std::runtime_error("Only no echo is supported" ); |
| 503 | } |
| 504 | |
| 505 | // Params supported by OAI but unsupported by llama.cpp |
| 506 | static const std::vector<std::string> unsupported_params { "best_of" , "suffix" }; |
| 507 | for (const auto & param : unsupported_params) { |
| 508 | if (body.contains(key: param)) { |
| 509 | throw std::runtime_error("Unsupported param: " + param); |
| 510 | } |
| 511 | } |
| 512 | |
| 513 | // Copy remaining properties to llama_params |
| 514 | for (const auto & item : body.items()) { |
| 515 | // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" |
| 516 | if (!llama_params.contains(key: item.key()) || item.key() == "n_predict" ) { |
| 517 | llama_params[item.key()] = item.value(); |
| 518 | } |
| 519 | } |
| 520 | |
| 521 | return llama_params; |
| 522 | } |
| 523 | |
| 524 | struct oaicompat_parser_options { |
| 525 | bool use_jinja; |
| 526 | bool prefill_assistant; |
| 527 | common_reasoning_format reasoning_format; |
| 528 | std::map<std::string,std::string> chat_template_kwargs; |
| 529 | common_chat_templates * tmpls; |
| 530 | bool allow_image; |
| 531 | bool allow_audio; |
| 532 | bool enable_thinking = true; |
| 533 | }; |
| 534 | |
| 535 | // used by /chat/completions endpoint |
| 536 | static json oaicompat_chat_params_parse( |
| 537 | json & body, /* openai api json semantics */ |
| 538 | const oaicompat_parser_options & opt, |
| 539 | std::vector<raw_buffer> & out_files) |
| 540 | { |
| 541 | json llama_params; |
| 542 | |
| 543 | auto tools = json_value(body, key: "tools" , default_value: json()); |
| 544 | auto has_tools = tools.is_array() && !tools.empty(); |
| 545 | auto stream = json_value(body, key: "stream" , default_value: false); |
| 546 | auto tool_choice = json_value(body, key: "tool_choice" , default_value: std::string("auto" )); |
| 547 | |
| 548 | if (!opt.use_jinja) { |
| 549 | if (has_tools) { |
| 550 | throw std::runtime_error("tools param requires --jinja flag" ); |
| 551 | } |
| 552 | if (tool_choice != "auto" ) { |
| 553 | throw std::runtime_error("tool_choice param requires --jinja flag" ); |
| 554 | } |
| 555 | } |
| 556 | |
| 557 | // Handle "stop" field |
| 558 | if (body.contains(key: "stop" ) && body.at(key: "stop" ).is_string()) { |
| 559 | llama_params["stop" ] = json::array(init: {body.at(key: "stop" ).get<std::string>()}); |
| 560 | } else { |
| 561 | llama_params["stop" ] = json_value(body, key: "stop" , default_value: json::array()); |
| 562 | } |
| 563 | |
| 564 | auto json_schema = json_value(body, key: "json_schema" , default_value: json()); |
| 565 | auto grammar = json_value(body, key: "grammar" , default_value: std::string()); |
| 566 | if (!json_schema.is_null() && !grammar.empty()) { |
| 567 | throw std::runtime_error("Cannot use both json_schema and grammar" ); |
| 568 | } |
| 569 | |
| 570 | // Handle "response_format" field |
| 571 | if (body.contains(key: "response_format" )) { |
| 572 | json response_format = json_value(body, key: "response_format" , default_value: json::object()); |
| 573 | std::string response_type = json_value(body: response_format, key: "type" , default_value: std::string()); |
| 574 | if (response_type == "json_object" ) { |
| 575 | json_schema = json_value(body: response_format, key: "schema" , default_value: json::object()); |
| 576 | } else if (response_type == "json_schema" ) { |
| 577 | auto schema_wrapper = json_value(body: response_format, key: "json_schema" , default_value: json::object()); |
| 578 | json_schema = json_value(body: schema_wrapper, key: "schema" , default_value: json::object()); |
| 579 | } else if (!response_type.empty() && response_type != "text" ) { |
| 580 | throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); |
| 581 | } |
| 582 | } |
| 583 | |
| 584 | // get input files |
| 585 | if (!body.contains(key: "messages" )) { |
| 586 | throw std::runtime_error("'messages' is required" ); |
| 587 | } |
| 588 | json & messages = body.at(key: "messages" ); |
| 589 | if (!messages.is_array()) { |
| 590 | throw std::runtime_error("Expected 'messages' to be an array" ); |
| 591 | } |
| 592 | for (auto & msg : messages) { |
| 593 | std::string role = json_value(body: msg, key: "role" , default_value: std::string()); |
| 594 | if (role != "assistant" && !msg.contains(key: "content" )) { |
| 595 | throw std::runtime_error("All non-assistant messages must contain 'content'" ); |
| 596 | } |
| 597 | if (role == "assistant" ) { |
| 598 | if (!msg.contains(key: "content" ) && !msg.contains(key: "tool_calls" )) { |
| 599 | throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!" ); |
| 600 | } |
| 601 | if (!msg.contains(key: "content" )) { |
| 602 | continue; // avoid errors with no content |
| 603 | } |
| 604 | } |
| 605 | json & content = msg.at(key: "content" ); |
| 606 | if (content.is_string() || content.is_null()) { |
| 607 | continue; |
| 608 | } |
| 609 | |
| 610 | if (!content.is_array()) { |
| 611 | throw std::runtime_error("Expected 'content' to be a string or an array" ); |
| 612 | } |
| 613 | |
| 614 | for (auto & p : content) { |
| 615 | std::string type = json_value(body: p, key: "type" , default_value: std::string()); |
| 616 | if (type == "image_url" ) { |
| 617 | if (!opt.allow_image) { |
| 618 | throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj" ); |
| 619 | } |
| 620 | |
| 621 | json image_url = json_value(body: p, key: "image_url" , default_value: json::object()); |
| 622 | std::string url = json_value(body: image_url, key: "url" , default_value: std::string()); |
| 623 | if (string_starts_with(str: url, prefix: "http" )) { |
| 624 | // download remote image |
| 625 | // TODO @ngxson : maybe make these params configurable |
| 626 | common_remote_params params; |
| 627 | params.headers.push_back(x: "User-Agent: llama.cpp/" + build_info); |
| 628 | params.max_size = 1024 * 1024 * 10; // 10MB |
| 629 | params.timeout = 10; // seconds |
| 630 | SRV_INF("downloading image from '%s'\n" , url.c_str()); |
| 631 | auto res = common_remote_get_content(url, params); |
| 632 | if (200 <= res.first && res.first < 300) { |
| 633 | SRV_INF("downloaded %ld bytes\n" , res.second.size()); |
| 634 | raw_buffer data; |
| 635 | data.insert(position: data.end(), first: res.second.begin(), last: res.second.end()); |
| 636 | out_files.push_back(x: data); |
| 637 | } else { |
| 638 | throw std::runtime_error("Failed to download image" ); |
| 639 | } |
| 640 | |
| 641 | } else { |
| 642 | // try to decode base64 image |
| 643 | std::vector<std::string> parts = string_split<std::string>(input: url, /*separator*/ separator: ','); |
| 644 | if (parts.size() != 2) { |
| 645 | throw std::runtime_error("Invalid image_url.url value" ); |
| 646 | } else if (!string_starts_with(str: parts[0], prefix: "data:image/" )) { |
| 647 | throw std::runtime_error("Invalid image_url.url format: " + parts[0]); |
| 648 | } else if (!string_ends_with(str: parts[0], suffix: "base64" )) { |
| 649 | throw std::runtime_error("image_url.url must be base64 encoded" ); |
| 650 | } else { |
| 651 | auto base64_data = parts[1]; |
| 652 | auto decoded_data = base64_decode(encoded_string: base64_data); |
| 653 | out_files.push_back(x: decoded_data); |
| 654 | } |
| 655 | } |
| 656 | |
| 657 | // replace this chunk with a marker |
| 658 | p["type" ] = "text" ; |
| 659 | p["text" ] = mtmd_default_marker(); |
| 660 | p.erase(key: "image_url" ); |
| 661 | |
| 662 | } else if (type == "input_audio" ) { |
| 663 | if (!opt.allow_audio) { |
| 664 | throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj" ); |
| 665 | } |
| 666 | |
| 667 | json input_audio = json_value(body: p, key: "input_audio" , default_value: json::object()); |
| 668 | std::string data = json_value(body: input_audio, key: "data" , default_value: std::string()); |
| 669 | std::string format = json_value(body: input_audio, key: "format" , default_value: std::string()); |
| 670 | // while we also support flac, we don't allow it here so we matches the OAI spec |
| 671 | if (format != "wav" && format != "mp3" ) { |
| 672 | throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'" ); |
| 673 | } |
| 674 | auto decoded_data = base64_decode(encoded_string: data); // expected to be base64 encoded |
| 675 | out_files.push_back(x: decoded_data); |
| 676 | |
| 677 | // replace this chunk with a marker |
| 678 | p["type" ] = "text" ; |
| 679 | p["text" ] = mtmd_default_marker(); |
| 680 | p.erase(key: "input_audio" ); |
| 681 | |
| 682 | } else if (type != "text" ) { |
| 683 | throw std::runtime_error("unsupported content[].type" ); |
| 684 | } |
| 685 | } |
| 686 | } |
| 687 | |
| 688 | common_chat_templates_inputs inputs; |
| 689 | inputs.messages = common_chat_msgs_parse_oaicompat(messages); |
| 690 | inputs.tools = common_chat_tools_parse_oaicompat(tools); |
| 691 | inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice); |
| 692 | inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); |
| 693 | inputs.grammar = grammar; |
| 694 | inputs.use_jinja = opt.use_jinja; |
| 695 | inputs.parallel_tool_calls = json_value(body, key: "parallel_tool_calls" , default_value: false); |
| 696 | inputs.add_generation_prompt = json_value(body, key: "add_generation_prompt" , default_value: true); |
| 697 | inputs.reasoning_format = opt.reasoning_format; |
| 698 | inputs.enable_thinking = opt.enable_thinking; |
| 699 | if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { |
| 700 | if (body.contains(key: "grammar" )) { |
| 701 | throw std::runtime_error("Cannot use custom grammar constraints with tools." ); |
| 702 | } |
| 703 | llama_params["parse_tool_calls" ] = true; |
| 704 | } |
| 705 | |
| 706 | // merge the template args provided from command line with the args provided in the user request |
| 707 | auto chat_template_kwargs_object = json_value(body, key: "chat_template_kwargs" , default_value: json::object()); |
| 708 | inputs.chat_template_kwargs = opt.chat_template_kwargs; |
| 709 | for (const auto & item : chat_template_kwargs_object.items()) { |
| 710 | inputs.chat_template_kwargs[item.key()] = item.value().dump(); |
| 711 | } |
| 712 | |
| 713 | // parse the "enable_thinking" kwarg to override the default value |
| 714 | auto enable_thinking_kwarg = json_value(body: inputs.chat_template_kwargs, key: "enable_thinking" , default_value: std::string("" )); |
| 715 | if (enable_thinking_kwarg == "true" ) { |
| 716 | inputs.enable_thinking = true; |
| 717 | } else if (enable_thinking_kwarg == "false" ) { |
| 718 | inputs.enable_thinking = false; |
| 719 | } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') { |
| 720 | throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)" ); |
| 721 | } |
| 722 | |
| 723 | // if the assistant message appears at the end of list, we do not add end-of-turn token |
| 724 | // for ex. this can be useful to modify the reasoning process in reasoning models |
| 725 | bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; |
| 726 | common_chat_msg last_message; |
| 727 | if (prefill_assistant_message) { |
| 728 | last_message = inputs.messages.back(); |
| 729 | inputs.messages.pop_back(); |
| 730 | |
| 731 | /* sanity check, max one assistant message at the end of the list */ |
| 732 | if (!inputs.messages.empty() && inputs.messages.back().role == "assistant" ){ |
| 733 | throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list." ); |
| 734 | } |
| 735 | |
| 736 | /* TODO: test this properly */ |
| 737 | inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; |
| 738 | |
| 739 | if ( inputs.enable_thinking ) { |
| 740 | throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking." ); |
| 741 | } |
| 742 | |
| 743 | inputs.add_generation_prompt = true; |
| 744 | } |
| 745 | |
| 746 | // Apply chat template to the list of messages |
| 747 | auto chat_params = common_chat_templates_apply(tmpls: opt.tmpls, inputs); |
| 748 | |
| 749 | /* Append assistant prefilled message */ |
| 750 | if (prefill_assistant_message) { |
| 751 | if (!last_message.content_parts.empty()) { |
| 752 | for (auto & p : last_message.content_parts) { |
| 753 | chat_params.prompt += p.text; |
| 754 | } |
| 755 | } else { |
| 756 | chat_params.prompt += last_message.content; |
| 757 | } |
| 758 | } |
| 759 | |
| 760 | llama_params["chat_format" ] = static_cast<int>(chat_params.format); |
| 761 | llama_params["prompt" ] = chat_params.prompt; |
| 762 | if (!chat_params.grammar.empty()) { |
| 763 | llama_params["grammar" ] = chat_params.grammar; |
| 764 | } |
| 765 | llama_params["grammar_lazy" ] = chat_params.grammar_lazy; |
| 766 | auto grammar_triggers = json::array(); |
| 767 | for (const auto & trigger : chat_params.grammar_triggers) { |
| 768 | server_grammar_trigger ct(trigger); |
| 769 | grammar_triggers.push_back(val: ct.to_json()); |
| 770 | } |
| 771 | llama_params["grammar_triggers" ] = grammar_triggers; |
| 772 | llama_params["preserved_tokens" ] = chat_params.preserved_tokens; |
| 773 | llama_params["thinking_forced_open" ] = chat_params.thinking_forced_open; |
| 774 | for (const auto & stop : chat_params.additional_stops) { |
| 775 | llama_params["stop" ].push_back(val: stop); |
| 776 | } |
| 777 | |
| 778 | // Handle "n" field |
| 779 | int n_choices = json_value(body, key: "n" , default_value: 1); |
| 780 | if (n_choices != 1) { |
| 781 | throw std::runtime_error("Only one completion choice is allowed" ); |
| 782 | } |
| 783 | |
| 784 | // Handle "logprobs" field |
| 785 | // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future |
| 786 | if (json_value(body, key: "logprobs" , default_value: false)) { |
| 787 | if (has_tools && stream) { |
| 788 | throw std::runtime_error("logprobs is not supported with tools + stream" ); |
| 789 | } |
| 790 | llama_params["n_probs" ] = json_value(body, key: "top_logprobs" , default_value: 20); |
| 791 | } else if (body.contains(key: "top_logprobs" ) && !body.at(key: "top_logprobs" ).is_null()) { |
| 792 | throw std::runtime_error("top_logprobs requires logprobs to be set to true" ); |
| 793 | } |
| 794 | |
| 795 | // Copy remaining properties to llama_params |
| 796 | // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. |
| 797 | // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp |
| 798 | for (const auto & item : body.items()) { |
| 799 | // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" |
| 800 | if (!llama_params.contains(key: item.key()) || item.key() == "n_predict" ) { |
| 801 | llama_params[item.key()] = item.value(); |
| 802 | } |
| 803 | } |
| 804 | |
| 805 | return llama_params; |
| 806 | } |
| 807 | |
| 808 | static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { |
| 809 | json data = json::array(); |
| 810 | int32_t n_tokens = 0; |
| 811 | int i = 0; |
| 812 | for (const auto & elem : embeddings) { |
| 813 | json embedding_obj; |
| 814 | |
| 815 | if (use_base64) { |
| 816 | const auto& vec = json_value(body: elem, key: "embedding" , default_value: json::array()).get<std::vector<float>>(); |
| 817 | const char* data_ptr = reinterpret_cast<const char*>(vec.data()); |
| 818 | size_t data_size = vec.size() * sizeof(float); |
| 819 | embedding_obj = { |
| 820 | {"embedding" , base64::encode(buffer: data_ptr, size: data_size)}, |
| 821 | {"index" , i++}, |
| 822 | {"object" , "embedding" }, |
| 823 | {"encoding_format" , "base64" } |
| 824 | }; |
| 825 | } else { |
| 826 | embedding_obj = { |
| 827 | {"embedding" , json_value(body: elem, key: "embedding" , default_value: json::array())}, |
| 828 | {"index" , i++}, |
| 829 | {"object" , "embedding" } |
| 830 | }; |
| 831 | } |
| 832 | data.push_back(val: embedding_obj); |
| 833 | |
| 834 | n_tokens += json_value(body: elem, key: "tokens_evaluated" , default_value: 0); |
| 835 | } |
| 836 | |
| 837 | json res = json { |
| 838 | {"model" , json_value(body: request, key: "model" , default_value: std::string(DEFAULT_OAICOMPAT_MODEL))}, |
| 839 | {"object" , "list" }, |
| 840 | {"usage" , json { |
| 841 | {"prompt_tokens" , n_tokens}, |
| 842 | {"total_tokens" , n_tokens} |
| 843 | }}, |
| 844 | {"data" , data} |
| 845 | }; |
| 846 | |
| 847 | return res; |
| 848 | } |
| 849 | |
| 850 | static json format_response_rerank( |
| 851 | const json & request, |
| 852 | const json & ranks, |
| 853 | bool is_tei_format, |
| 854 | std::vector<std::string> & texts, |
| 855 | int top_n) { |
| 856 | int32_t n_tokens = 0; |
| 857 | bool return_text = is_tei_format && json_value(body: request, key: "return_text" , default_value: false); |
| 858 | std::vector<json> elements; // Temporary vector to hold unsorted elements |
| 859 | std::string score_label = is_tei_format ? "score" : "relevance_score" ; |
| 860 | for (const auto & rank : ranks) { |
| 861 | int index = json_value(body: rank, key: "index" , default_value: 0); |
| 862 | json elem = json{ |
| 863 | {"index" , index}, |
| 864 | {score_label, json_value(body: rank, key: "score" , default_value: 0.0)}, |
| 865 | }; |
| 866 | n_tokens += json_value(body: rank, key: "tokens_evaluated" , default_value: 0); |
| 867 | if (return_text) { |
| 868 | elem["text" ] = std::move(texts[index]); |
| 869 | } |
| 870 | elements.push_back(x: elem); |
| 871 | } |
| 872 | |
| 873 | std::sort(first: elements.begin(), last: elements.end(), comp: [score_label](const json& a, const json& b) { |
| 874 | return json_value(body: a, key: score_label, default_value: 0.0) > json_value(body: b, key: score_label, default_value: 0.0); |
| 875 | }); |
| 876 | |
| 877 | elements.resize(new_size: std::min(a: top_n, b: (int)elements.size())); |
| 878 | json results = elements; |
| 879 | |
| 880 | if (is_tei_format) return results; |
| 881 | |
| 882 | json res = json{ |
| 883 | {"model" , json_value(body: request, key: "model" , default_value: std::string(DEFAULT_OAICOMPAT_MODEL))}, |
| 884 | {"object" , "list" }, |
| 885 | {"usage" , json{ |
| 886 | {"prompt_tokens" , n_tokens}, |
| 887 | {"total_tokens" , n_tokens} |
| 888 | }}, |
| 889 | {"results" , results} |
| 890 | }; |
| 891 | |
| 892 | return res; |
| 893 | } |
| 894 | |
| 895 | static bool is_valid_utf8(const std::string & str) { |
| 896 | const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data()); |
| 897 | const unsigned char* end = bytes + str.length(); |
| 898 | |
| 899 | while (bytes < end) { |
| 900 | if (*bytes <= 0x7F) { |
| 901 | // 1-byte sequence (0xxxxxxx) |
| 902 | bytes++; |
| 903 | } else if ((*bytes & 0xE0) == 0xC0) { |
| 904 | // 2-byte sequence (110xxxxx 10xxxxxx) |
| 905 | if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) |
| 906 | return false; |
| 907 | bytes += 2; |
| 908 | } else if ((*bytes & 0xF0) == 0xE0) { |
| 909 | // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) |
| 910 | if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) |
| 911 | return false; |
| 912 | bytes += 3; |
| 913 | } else if ((*bytes & 0xF8) == 0xF0) { |
| 914 | // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) |
| 915 | if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || |
| 916 | (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) |
| 917 | return false; |
| 918 | bytes += 4; |
| 919 | } else { |
| 920 | // Invalid UTF-8 lead byte |
| 921 | return false; |
| 922 | } |
| 923 | } |
| 924 | |
| 925 | return true; |
| 926 | } |
| 927 | |
| 928 | static json format_tokenizer_response(const json & tokens) { |
| 929 | return json { |
| 930 | {"tokens" , tokens} |
| 931 | }; |
| 932 | } |
| 933 | |
| 934 | static json format_detokenized_response(const std::string & content) { |
| 935 | return json { |
| 936 | {"content" , content} |
| 937 | }; |
| 938 | } |
| 939 | |
| 940 | static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) { |
| 941 | json data = json::array(); |
| 942 | for (const auto & lb : logit_bias) { |
| 943 | data.push_back(val: json{ |
| 944 | {"bias" , lb.bias}, |
| 945 | {"token" , lb.token}, |
| 946 | }); |
| 947 | } |
| 948 | return data; |
| 949 | } |
| 950 | |
| 951 | static std::string safe_json_to_str(const json & data) { |
| 952 | return data.dump(indent: -1, indent_char: ' ', ensure_ascii: false, error_handler: json::error_handler_t::replace); |
| 953 | } |
| 954 | |
| 955 | static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) { |
| 956 | std::vector<llama_token_data> cur; |
| 957 | const auto * logits = llama_get_logits_ith(ctx, i: idx); |
| 958 | |
| 959 | const llama_model * model = llama_get_model(ctx); |
| 960 | const llama_vocab * vocab = llama_model_get_vocab(model); |
| 961 | |
| 962 | const int n_vocab = llama_vocab_n_tokens(vocab); |
| 963 | |
| 964 | cur.resize(new_size: n_vocab); |
| 965 | for (llama_token token_id = 0; token_id < n_vocab; token_id++) { |
| 966 | cur[token_id] = llama_token_data{.id: token_id, .logit: logits[token_id], .p: 0.0f}; |
| 967 | } |
| 968 | |
| 969 | // sort tokens by logits |
| 970 | std::sort(first: cur.begin(), last: cur.end(), comp: [](const llama_token_data & a, const llama_token_data & b) { |
| 971 | return a.logit > b.logit; |
| 972 | }); |
| 973 | |
| 974 | // apply softmax |
| 975 | float max_l = cur[0].logit; |
| 976 | float cum_sum = 0.0f; |
| 977 | for (size_t i = 0; i < cur.size(); ++i) { |
| 978 | float p = expf(x: cur[i].logit - max_l); |
| 979 | cur[i].p = p; |
| 980 | cum_sum += p; |
| 981 | } |
| 982 | for (size_t i = 0; i < cur.size(); ++i) { |
| 983 | cur[i].p /= cum_sum; |
| 984 | } |
| 985 | |
| 986 | return cur; |
| 987 | } |
| 988 | |
| 989 | static bool are_lora_equal( |
| 990 | const std::vector<common_adapter_lora_info> & l1, |
| 991 | const std::vector<common_adapter_lora_info> & l2) { |
| 992 | if (l1.size() != l2.size()) { |
| 993 | return false; |
| 994 | } |
| 995 | for (size_t i = 0; i < l1.size(); ++i) { |
| 996 | // we don't check lora.path to reduce the time complexity |
| 997 | if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { |
| 998 | return false; |
| 999 | } |
| 1000 | } |
| 1001 | return true; |
| 1002 | } |
| 1003 | |
| 1004 | // get the ids of all enabled loras |
| 1005 | static std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras) { |
| 1006 | std::vector<size_t> enabled_ids; |
| 1007 | for (size_t i = 0; i < loras.size(); ++i) { |
| 1008 | if (loras[i].scale > 0) { |
| 1009 | enabled_ids.push_back(x: i); |
| 1010 | } |
| 1011 | } |
| 1012 | return enabled_ids; |
| 1013 | } |
| 1014 | |
| 1015 | // check whether the given lora set has only aloras activated (empty => false) |
| 1016 | static bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras) { |
| 1017 | bool found_alora = false; |
| 1018 | for (const auto & lora : loras) { |
| 1019 | if (lora.scale != 0) { |
| 1020 | if (llama_adapter_get_alora_n_invocation_tokens(adapter: lora.ptr) == 0) { |
| 1021 | return false; |
| 1022 | } |
| 1023 | found_alora = true; |
| 1024 | } |
| 1025 | } |
| 1026 | return found_alora; |
| 1027 | } |
| 1028 | |
| 1029 | // if the two sets of loras are different, they require a cache clear unless the |
| 1030 | // change is only from aloras to aloras. |
| 1031 | static bool lora_should_clear_cache( |
| 1032 | const std::vector<common_adapter_lora_info> & current, |
| 1033 | const std::vector<common_adapter_lora_info> & next) { |
| 1034 | |
| 1035 | // This should always be called after determining that the two sets are |
| 1036 | // _not_ equal. This assert is therefore some slightly wasted work and |
| 1037 | // should be safe to remove as long as this method is called correctly. |
| 1038 | GGML_ASSERT(!are_lora_equal(current, next)); |
| 1039 | |
| 1040 | return ( |
| 1041 | !(lora_get_enabled_ids(loras: current).empty() || lora_all_alora(loras: current)) || |
| 1042 | !lora_all_alora(loras: next)); |
| 1043 | } |
| 1044 | |
| 1045 | // parse lora config from JSON request, returned a copy of lora_base with updated scale |
| 1046 | static std::vector<common_adapter_lora_info> parse_lora_request( |
| 1047 | const std::vector<common_adapter_lora_info> & lora_base, |
| 1048 | const json & data) { |
| 1049 | std::vector<common_adapter_lora_info> lora(lora_base); |
| 1050 | int max_idx = lora.size(); |
| 1051 | |
| 1052 | // clear existing value |
| 1053 | for (auto & entry : lora) { |
| 1054 | entry.scale = 0.0f; |
| 1055 | } |
| 1056 | |
| 1057 | // set value |
| 1058 | for (const auto & entry : data) { |
| 1059 | int id = json_value(body: entry, key: "id" , default_value: -1); |
| 1060 | float scale = json_value(body: entry, key: "scale" , default_value: 0.0f); |
| 1061 | if (0 <= id && id < max_idx) { |
| 1062 | lora[id].scale = scale; |
| 1063 | } else { |
| 1064 | throw std::runtime_error("invalid adapter id" ); |
| 1065 | } |
| 1066 | } |
| 1067 | |
| 1068 | return lora; |
| 1069 | } |
| 1070 | |
| 1071 | // |
| 1072 | // utils for interacting with libmtmd |
| 1073 | // (may need to refactor in near future) |
| 1074 | // |
| 1075 | |
| 1076 | /** |
| 1077 | * server_tokens is a helper to manage the input tokens and image for the server. |
| 1078 | * it is made this way to simplify the logic of KV cache management. |
| 1079 | */ |
| 1080 | struct server_tokens { |
| 1081 | bool has_mtmd = false; |
| 1082 | |
| 1083 | private: // disallow accessing these members directly, risking out-of-sync |
| 1084 | |
| 1085 | // map a **start** index in tokens to the image chunk |
| 1086 | // note: the order need to be in-sync with tokens |
| 1087 | std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media; |
| 1088 | |
| 1089 | // list of tokens |
| 1090 | // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk |
| 1091 | // otherwise, it is a normal text token |
| 1092 | // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list |
| 1093 | // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos |
| 1094 | llama_tokens tokens; |
| 1095 | |
| 1096 | // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): |
| 1097 | // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] |
| 1098 | // idx 0 1 2 3 4 5 6 7 8 9 10 |
| 1099 | // pos 0 1 2 3 4 5 5 5 7 7 7 |
| 1100 | // map_idx_to_media will contain: {5, img0}, {8, img1} |
| 1101 | |
| 1102 | public: |
| 1103 | server_tokens() = default; |
| 1104 | ~server_tokens() = default; |
| 1105 | |
| 1106 | // Prevent copying |
| 1107 | // TODO: server_tokens should be copyable - remove this: |
| 1108 | server_tokens(const server_tokens&) = delete; |
| 1109 | server_tokens& operator=(const server_tokens&) = delete; |
| 1110 | |
| 1111 | // Allow moving (usually implicitly generated if members are movable) |
| 1112 | server_tokens(server_tokens&&) = default; |
| 1113 | server_tokens& operator=(server_tokens&&) = default; |
| 1114 | |
| 1115 | // Allow accessing elements using [] operator |
| 1116 | llama_token operator[](size_t index) { return tokens[index]; } |
| 1117 | const llama_token& operator[](size_t index) const { return tokens[index]; } |
| 1118 | |
| 1119 | server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { |
| 1120 | for (size_t i = 0; i < mtmd_chunks.size(); ++i) { |
| 1121 | push_back(chunk: mtmd_chunks[i]); |
| 1122 | } |
| 1123 | } |
| 1124 | |
| 1125 | server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { |
| 1126 | } |
| 1127 | |
| 1128 | llama_pos pos_next() const { |
| 1129 | if (!has_mtmd) { |
| 1130 | return tokens.size(); |
| 1131 | } |
| 1132 | |
| 1133 | llama_pos res = tokens.size(); |
| 1134 | |
| 1135 | for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { |
| 1136 | const auto & chunk = it->second; |
| 1137 | res += mtmd_input_chunk_get_n_pos(chunk: chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk: chunk.get()); |
| 1138 | } |
| 1139 | |
| 1140 | return res; |
| 1141 | } |
| 1142 | |
| 1143 | // for debugging |
| 1144 | std::string str() const { |
| 1145 | std::ostringstream oss; |
| 1146 | oss << "tokens: " ; |
| 1147 | for (size_t idx = 0; idx < tokens.size(); ++idx) { |
| 1148 | llama_token t = tokens[idx]; |
| 1149 | oss << "idx:" << idx << " " ; |
| 1150 | if (t == LLAMA_TOKEN_NULL) { |
| 1151 | oss << "<embd> " ; |
| 1152 | } else { |
| 1153 | oss << t << " " ; |
| 1154 | } |
| 1155 | } |
| 1156 | oss << "\n" ; |
| 1157 | oss << "image idx: " ; |
| 1158 | for (const auto & it : map_idx_to_media) { |
| 1159 | oss << it.first << ", " ; |
| 1160 | } |
| 1161 | return oss.str(); |
| 1162 | } |
| 1163 | |
| 1164 | const mtmd::input_chunk_ptr & find_chunk(size_t idx) const { |
| 1165 | auto it = map_idx_to_media.find(x: idx); |
| 1166 | if (it != map_idx_to_media.end()) { |
| 1167 | return it->second; |
| 1168 | } |
| 1169 | throw std::runtime_error("Chunk not found" ); |
| 1170 | } |
| 1171 | |
| 1172 | void push_back(llama_token tok) { |
| 1173 | if (tok == LLAMA_TOKEN_NULL) { |
| 1174 | throw std::runtime_error("Invalid token" ); |
| 1175 | } |
| 1176 | tokens.emplace_back(args&: tok); |
| 1177 | } |
| 1178 | |
| 1179 | // will create a copy of the chunk if it contains non-text data |
| 1180 | void push_back(const mtmd_input_chunk * chunk) { |
| 1181 | auto type = mtmd_input_chunk_get_type(chunk); |
| 1182 | if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { |
| 1183 | GGML_ASSERT(has_mtmd); |
| 1184 | const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); |
| 1185 | size_t start_idx = tokens.size(); |
| 1186 | for (size_t i = 0; i < n_tokens; ++i) { |
| 1187 | tokens.emplace_back(LLAMA_TOKEN_NULL); |
| 1188 | } |
| 1189 | mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); |
| 1190 | map_idx_to_media[start_idx] = std::move(new_chunk); |
| 1191 | } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { |
| 1192 | size_t n_tokens; |
| 1193 | const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, n_tokens_output: &n_tokens); |
| 1194 | for (size_t i = 0; i < n_tokens; ++i) { |
| 1195 | push_back(tok: text_tokens[i]); |
| 1196 | } |
| 1197 | } else { |
| 1198 | GGML_ABORT("Invalid chunk type" ); |
| 1199 | } |
| 1200 | } |
| 1201 | |
| 1202 | // appends server tokens, updates the media map. copies media chunks. |
| 1203 | void push_back(server_tokens & tokens) { |
| 1204 | size_t start_idx = size(); |
| 1205 | for (size_t i = 0; i < tokens.size(); i++) { |
| 1206 | push_back(tok: tokens[i]); |
| 1207 | } |
| 1208 | if (tokens.has_mtmd) { |
| 1209 | // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. |
| 1210 | // We could also just check, but this will prevent silently dropping MTMD data. |
| 1211 | GGML_ASSERT(has_mtmd); |
| 1212 | for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { |
| 1213 | auto * chunk = tokens.map_idx_to_media[it->first].get(); |
| 1214 | mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); |
| 1215 | map_idx_to_media[start_idx + it->first] = std::move(new_chunk); |
| 1216 | } |
| 1217 | } |
| 1218 | } |
| 1219 | |
| 1220 | // for compatibility with context shift and prompt truncation |
| 1221 | void insert(const llama_tokens & inp_tokens) { |
| 1222 | GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled |
| 1223 | tokens.insert(position: tokens.end(), first: inp_tokens.begin(), last: inp_tokens.end()); |
| 1224 | } |
| 1225 | |
| 1226 | // for compatibility with speculative decoding, ctx shift, slot save/load |
| 1227 | const llama_tokens & get_text_tokens() const { |
| 1228 | GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled |
| 1229 | return tokens; |
| 1230 | } |
| 1231 | |
| 1232 | // for compatibility with speculative decoding |
| 1233 | void set_token(llama_pos pos, llama_token id) { |
| 1234 | GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled |
| 1235 | tokens[pos] = id; |
| 1236 | } |
| 1237 | |
| 1238 | size_t size() const { |
| 1239 | return tokens.size(); |
| 1240 | } |
| 1241 | |
| 1242 | bool empty() const { |
| 1243 | return tokens.empty(); |
| 1244 | } |
| 1245 | |
| 1246 | void clear() { |
| 1247 | map_idx_to_media.clear(); |
| 1248 | tokens.clear(); |
| 1249 | } |
| 1250 | |
| 1251 | void keep_first(size_t n) { |
| 1252 | GGML_ASSERT(n <= tokens.size()); |
| 1253 | if (has_mtmd) { |
| 1254 | if (n == tokens.size()) { |
| 1255 | return; // nothing to do |
| 1256 | } |
| 1257 | // we throw an error if we try to remove a token in the middle of an image |
| 1258 | // for ex. with input of 5 text tokens and 2 images: |
| 1259 | // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] |
| 1260 | // n 1 2 3 4 5 6 7 8 9 10 |
| 1261 | // allowed to resize ^ ^ |
| 1262 | // disallowed to resize ^ ^ ^ |
| 1263 | if (n > 0) { |
| 1264 | // make sure we never remove tokens in the middle of an image |
| 1265 | // note that the case where we keep a full image at the end is allowed: |
| 1266 | // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL |
| 1267 | if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) { |
| 1268 | find_chunk(idx: n - 1); // will throw an error if the token is not begin-of-chunk |
| 1269 | } |
| 1270 | } |
| 1271 | // remove all image chunks that are not used anymore |
| 1272 | for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { |
| 1273 | size_t idx = it->first; |
| 1274 | if (idx >= n) { |
| 1275 | it = map_idx_to_media.erase(position: it); |
| 1276 | } else { |
| 1277 | ++it; |
| 1278 | } |
| 1279 | } |
| 1280 | } |
| 1281 | tokens.resize(new_size: n); |
| 1282 | } |
| 1283 | |
| 1284 | std::string detokenize(const llama_context * ctx, bool special) const { |
| 1285 | llama_tokens text_tokens; |
| 1286 | text_tokens.reserve(n: tokens.size()); |
| 1287 | for (const auto & t : tokens) { |
| 1288 | if (t != LLAMA_TOKEN_NULL) { |
| 1289 | text_tokens.push_back(x: t); |
| 1290 | } |
| 1291 | } |
| 1292 | return common_detokenize(ctx, tokens: text_tokens, special); |
| 1293 | } |

    size_t get_common_prefix(const server_tokens & b) const {
        const size_t max_idx = std::min(tokens.size(), b.tokens.size());

        if (!has_mtmd) {
            for (size_t i = 0; i < max_idx; ++i) {
                if (tokens[i] == b.tokens[i]) {
                    continue;
                }

                return i;
            }

            return max_idx;
        }

        for (size_t i = 0; i < max_idx; ++i) {
            const llama_token ai = tokens[i];
            const llama_token bi = b.tokens[i];

            if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
                const auto & a_chunk = find_chunk(i);
                const auto & b_chunk = b.find_chunk(i);

                GGML_ASSERT(a_chunk && b_chunk);

                const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
                const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());

                const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
                const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());

                if (id_ai == id_bi && n_tok_a == n_tok_b) {
                    GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
                    i += n_tok_a - 1; // will be +1 by the for loop
                    continue;
                }

                return i;
            }

            if (ai == bi) {
                continue;
            }

            return i;
        }

        return max_idx; // all tokens are equal
    }

    // make sure all text tokens are within the vocab range
    bool validate(const struct llama_context * ctx) const {
        const llama_model * model = llama_get_model(ctx);
        const llama_vocab * vocab = llama_model_get_vocab(model);
        const int32_t n_vocab = llama_vocab_n_tokens(vocab);

        for (size_t i = 0; i < tokens.size(); ++i) {
            const auto & t = tokens[i];
            if (t == LLAMA_TOKEN_NULL) {
                try {
                    const auto & chunk = find_chunk(i);
                    size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
                    i += n_tokens - 1; // will be +1 by the for loop
                } catch (const std::exception & e) {
                    return false;
                }
            } else if (t < 0 || t >= n_vocab) {
                return false;
            }
        }
        return true;
    }

    // encode and decode the media chunk (image or audio) at index idx
    int32_t process_chunk(
            llama_context * ctx,
            mtmd_context * mctx,
            size_t idx,
            llama_pos pos,
            int32_t seq_id,
            size_t & n_tokens_out) const {
        const auto & chunk = find_chunk(idx);
        const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
            ? "image" : "audio";
        SRV_INF("processing %s...\n", name);
        int32_t n_batch = llama_n_batch(ctx);
        int64_t t0 = ggml_time_ms();
        llama_pos new_n_past; // unused for now
        int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
            chunk.get(),
            pos,
            seq_id,
            n_batch,
            true, // logits last
            &new_n_past);
        SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
        if (result != 0) {
            LOG_ERR("mtmd_helper_eval failed with status %d", result);
            n_tokens_out = 0;
            return result;
        }
        n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
        return 0;
    }
};
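
// Illustrative usage sketch (not upstream code): when a slot reuses its cached context for a new
// request, the server keeps only the shared prefix and re-evaluates the rest. Assuming `cached`
// holds the slot's previous tokens and `prompt` holds the new request's tokens:
//
//     size_t n_keep = cached.get_common_prefix(prompt); // media chunks only match if id and size are equal
//     cached.keep_first(n_keep);                        // throws (via find_chunk) if the cut falls inside a media chunk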

// Computes FNV-1a hash of the data
static std::string fnv_hash(const uint8_t * data, size_t len) {
    const uint64_t fnv_prime = 0x100000001b3ULL;
    uint64_t hash = 0xcbf29ce484222325ULL;

    for (size_t i = 0; i < len; ++i) {
        hash ^= data[i];
        hash *= fnv_prime;
    }
    return std::to_string(hash);
}
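
// Worked example (illustrative): with an empty buffer the loop never runs, so the result is the
// FNV-1a 64-bit offset basis:
//
//     fnv_hash(nullptr, 0) == "14695981039346656037"   // 0xcbf29ce484222325
//
// Below, the hash of the raw media bytes is used as the bitmap/chunk id, so identical uploads get
// the same id and can be matched by the prompt-cache prefix comparison above.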

static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
    mtmd::bitmaps bitmaps;
    for (auto & file : files) {
        mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size()));
        if (!bmp.ptr) {
            throw std::runtime_error("Failed to load image or audio file");
        }
        // calculate bitmap hash (for KV caching)
        std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
        bmp.set_id(hash.c_str());
        bitmaps.entries.push_back(std::move(bmp));
    }
    // process prompt
    std::vector<server_tokens> inputs;
    // multimodal
    mtmd_input_text inp_txt = {
        prompt.c_str(),
        /* add_special   */ true,
        /* parse_special */ true,
    };
    mtmd::input_chunks chunks(mtmd_input_chunks_init());
    auto bitmaps_c_ptr = bitmaps.c_ptr();
    int32_t tokenized = mtmd_tokenize(mctx,
                                      chunks.ptr.get(),
                                      &inp_txt,
                                      bitmaps_c_ptr.data(),
                                      bitmaps_c_ptr.size());
    if (tokenized != 0) {
        throw std::runtime_error("Failed to tokenize prompt");
    }
    auto result = server_tokens(chunks, true);
    return result;
}
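
// Illustrative call (hypothetical variable names): `files` holds the already base64-decoded bytes
// of each attached image/audio file, and the prompt text is expected to contain one media marker
// per file (the marker string used by the active mtmd context), in the same order:
//
//     std::vector<raw_buffer> files = { base64_decode(b64_image) };
//     server_tokens toks = process_mtmd_prompt(mctx, "Describe this image: " + media_marker, files);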

/**
 * tokenize a single "prompt" element; use tokenize_input_prompts() instead if the input could be an array of prompts.
 * this supports these cases:
 * - "prompt": "string"
 * - "prompt": [12, 34, 56]
 * - "prompt": [12, 34, "string", 56, 78]
 * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
 */
static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
    constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string";
    constexpr char JSON_MTMD_DATA_KEY[]     = "multimodal_data";
    const bool has_mtmd = mctx != nullptr;
    if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
        // string or mixed
        llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special);
        return server_tokens(tmp, false);
    } else if (json_is_array_of_numbers(json_prompt)) {
        // array of tokens
        llama_tokens tmp = json_prompt.get<llama_tokens>();
        return server_tokens(tmp, false);
    } else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) {
        // JSON object with prompt key.
        if (json_prompt.contains(JSON_MTMD_DATA_KEY)) {
            if (!has_mtmd) {
                throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests.");
            }

            // JSON object with prompt and multimodal key.
            std::vector<raw_buffer> files;
            for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) {
                files.push_back(base64_decode(entry));
            }
            return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files);
        } else {
            // Not multimodal, but contains a subobject.
            llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special);
            return server_tokens(tmp, false);
        }
    } else {
        throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens.");
    }
}
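
// Illustrative request fragment for the object form handled above (all values are placeholders):
//
//     {
//         "prompt_string":   "What is in this image? <media marker>",
//         "multimodal_data": [ "<base64-encoded image or audio bytes>" ]
//     }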

/**
 * break the input "prompt" object into multiple prompts if needed, then tokenize them
 * this supports these cases:
 * - "prompt": "string"
 * - "prompt": [12, 34, 56]
 * - "prompt": [12, 34, "string", 56, 78]
 * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
 * and multiple prompts (multi-tasks):
 * - "prompt": ["string1", "string2"]
 * - "prompt": ["string1", [12, 34, 56]]
 * - "prompt": [[12, 34, 56], [78, 90, 12]]
 * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ] }]
 */
static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
    std::vector<server_tokens> result;
    if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) {
        result.reserve(json_prompt.size());
        for (const auto & p : json_prompt) {
            result.push_back(tokenize_input_subprompt(vocab, mctx, p, add_special, parse_special));
        }
    } else {
        result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special));
    }
    if (result.empty()) {
        throw std::runtime_error("\"prompt\" must not be empty");
    }
    return result;
}
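
// Usage sketch (illustrative): the request's "prompt" field is passed in as-is and each returned
// element becomes its own task. For example:
//
//     json prompt = json::parse(R"(["hello world", [12, 34, 56]])");
//     auto inputs = tokenize_input_prompts(vocab, mctx, prompt, /* add_special */ true, /* parse_special */ true);
//     // inputs.size() == 2: one tokenized string and one pre-tokenized prompt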

// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
static server_tokens format_rerank(const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, const std::string & query, const std::string & doc) {
    server_tokens result = {};

    const char * rerank_prompt = llama_model_chat_template(model, "rerank");

    if (rerank_prompt != nullptr) {
        std::string prompt = rerank_prompt;
        string_replace_all(prompt, "{query}",    query);
        string_replace_all(prompt, "{document}", doc);
        server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true);
        result.push_back(tokens);
    } else {
        // Get EOS token - use SEP token as fallback if EOS is not available
        server_tokens query_tokens = tokenize_input_subprompt(vocab, mctx, query, false, false);
        server_tokens doc_tokens   = tokenize_input_subprompt(vocab, mctx, doc,   false, false);
        llama_token eos_token = llama_vocab_eos(vocab);
        if (eos_token == LLAMA_TOKEN_NULL) {
            eos_token = llama_vocab_sep(vocab);
        }

        if (llama_vocab_get_add_bos(vocab)) {
            result.push_back(llama_vocab_bos(vocab));
        }
        result.push_back(query_tokens);
        if (llama_vocab_get_add_eos(vocab)) {
            result.push_back(eos_token);
        }
        if (llama_vocab_get_add_sep(vocab)) {
            result.push_back(llama_vocab_sep(vocab));
        }
        result.push_back(doc_tokens);
        if (llama_vocab_get_add_eos(vocab)) {
            result.push_back(eos_token);
        }
    }

    return result;
}
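
// Example (illustrative): for a vocab where add_bos, add_eos and add_sep are all enabled and the
// model ships no "rerank" chat template, the fallback branch yields
//
//     [BOS] <query tokens> [EOS] [SEP] <doc tokens> [EOS]
//
// matching the layout documented above. If a "rerank" template is present, e.g. a hypothetical
// "Query: {query}\nDocument: {document}", the placeholders are substituted and the rendered string
// is tokenized with special-token parsing enabled instead.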