1#pragma once
2
3#include "common.h"
4#include "log.h"
5#include "llama.h"
6#include "arg.h" // common_remote_get_content
7#include "base64.hpp"
8#include "mtmd.h"
9#include "mtmd-helper.h"
10#include "chat.h"
11
12// increase max payload length to allow use of larger context size
13#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
14// increase backlog size to avoid connection resets for >> 1 slots
15#define CPPHTTPLIB_LISTEN_BACKLOG 512
16// increase max URI length to handle longer prompts in query string
17#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
18// disable Nagle's algorithm
19#define CPPHTTPLIB_TCP_NODELAY true
20#include <cpp-httplib/httplib.h>
21
22#define JSON_ASSERT GGML_ASSERT
23#include <nlohmann/json.hpp>
24
25#include <random>
26#include <sstream>
27#include <string>
28#include <vector>
29#include <memory>
30#include <cinttypes>
31
32#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
33
34using json = nlohmann::ordered_json;
35
36#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
37#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
38#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
39#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
40
41#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
42#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
43#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
44#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
45
46#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
47#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
48#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
49#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
50
51using raw_buffer = std::vector<uint8_t>;
52
53template <typename T>
54static T json_value(const json & body, const std::string & key, const T & default_value) {
55 // Fallback null to default value
56 if (body.contains(key) && !body.at(key).is_null()) {
57 try {
58 return body.at(key);
59 } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) {
60 LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what());
61 return default_value;
62 }
63 } else {
64 return default_value;
65 }
66}
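// usage sketch (illustrative, not called anywhere in the server): json_value() falls back to the
// default both when the key is missing/null and when the stored type does not match the requested one
//
//     json body = json::parse(R"({"top_k": 40, "temp": "hot"})");
//     int   top_k = json_value(body, "top_k", 40);    // -> 40
//     float temp  = json_value(body, "temp",  0.8f);  // wrong type -> a warning is logged, returns 0.8f
//     int   n_kp  = json_value(body, "n_keep", 0);    // missing key -> returns 0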
67
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
69
70// thin wrapper around common_grammar_trigger with (de)serialization functions
71struct server_grammar_trigger {
72 common_grammar_trigger value;
73
74 server_grammar_trigger() = default;
75 server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
    server_grammar_trigger(const json & in) {
        value.type = (common_grammar_trigger_type) in.at("type").get<int>();
        value.value = in.at("value").get<std::string>();
        if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
            value.token = (llama_token) in.at("token").get<int>();
        }
    }
83
84 json to_json() const {
85 json out {
86 {"type", (int) value.type},
87 {"value", value.value},
88 };
89 if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
90 out["token"] = (int) value.token;
91 }
92 return out;
93 }
94};
95
96//
97// tokenizer and input processing utils
98//
99
100static bool json_is_array_of_numbers(const json & data) {
101 if (data.is_array()) {
102 for (const auto & e : data) {
103 if (!e.is_number_integer()) {
104 return false;
105 }
106 }
107 return true;
108 }
109 return false;
110}
111
// does the array contain BOTH numbers and strings?
113static bool json_is_array_of_mixed_numbers_strings(const json & data) {
114 bool seen_string = false;
115 bool seen_number = false;
116 if (data.is_array()) {
117 for (const auto & e : data) {
118 seen_string |= e.is_string();
119 seen_number |= e.is_number_integer();
120 if (seen_number && seen_string) {
121 return true;
122 }
123 }
124 }
125 return false;
126}
127
128// does array have any individual integers/tokens?
129static bool json_is_array_and_contains_numbers(const json & data) {
130 if (data.is_array()) {
131 for (const auto & e : data) {
132 if (e.is_number_integer()) {
133 return true;
134 }
135 }
136 return false;
137 }
138 return false;
139}
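// illustrative examples of the three predicates above (inputs are assumptions, shown for clarity):
//
//     json_is_array_of_numbers(json::parse("[1, 2, 3]"));                       // true
//     json_is_array_of_numbers(json::parse(R"([1, "a"])"));                     // false
//     json_is_array_of_mixed_numbers_strings(json::parse(R"([12, "a", 34])"));  // true
//     json_is_array_and_contains_numbers(json::parse(R"(["a", "b", 7])"));      // true
//     json_is_array_and_contains_numbers(json::parse(R"(["a", "b"])"));         // false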
140
// get values by slash-separated paths (e.g. "key1/key2")
static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
    json result = json::object();

    for (const std::string & path : paths) {
        json current = js;
        const auto keys = string_split<std::string>(path, /*separator*/ '/');
        bool valid_path = true;
        for (const std::string & k : keys) {
            if (valid_path && current.is_object() && current.contains(k)) {
                current = current[k];
            } else {
                valid_path = false;
            }
        }
        if (valid_path) {
            result[path] = current;
        }
    }
    return result;
}
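// example (hypothetical request body, for illustration only): extracting nested fields by path
//
//     json js = json::parse(R"({"model": {"name": "x"}, "n": 1})");
//     json v  = json_get_nested_values({"model/name", "missing/key"}, js);
//     // v == {"model/name": "x"} - invalid paths are silently skipped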
162
163/**
164 * this handles 2 cases:
165 * - only string, example: "string"
166 * - mixed string and tokens, example: [12, 34, "string", 56, 78]
167 */
static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
    // if `add_special` is true, special tokens (e.g. BOS) are only added when json_prompt is a string,
    // or when the first element of the json_prompt array is a string
    llama_tokens prompt_tokens;

    if (json_prompt.is_array()) {
        bool first = true;
        for (const auto & p : json_prompt) {
            if (p.is_string()) {
                auto s = p.template get<std::string>();

                llama_tokens ids;
                if (first) {
                    ids = common_tokenize(vocab, s, add_special, parse_special);
                    first = false;
                } else {
                    ids = common_tokenize(vocab, s, false, parse_special);
                }

                prompt_tokens.insert(prompt_tokens.end(), ids.begin(), ids.end());
            } else {
                if (first) {
                    first = false;
                }

                prompt_tokens.push_back(p.template get<llama_token>());
            }
        }
    } else {
        auto s = json_prompt.template get<std::string>();
        prompt_tokens = common_tokenize(vocab, s, add_special, parse_special);
    }

    return prompt_tokens;
}
203
// return the number of leading bytes that can be safely emitted as a string
// if the last multi-byte character is potentially cut in half, return the index just before the cut
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
207static size_t validate_utf8(const std::string& text) {
208 size_t len = text.size();
209 if (len == 0) return 0;
210
211 // Check the last few bytes to see if a multi-byte character is cut off
212 for (size_t i = 1; i <= 4 && i <= len; ++i) {
213 unsigned char c = text[len - i];
214 // Check for start of a multi-byte sequence from the end
215 if ((c & 0xE0) == 0xC0) {
216 // 2-byte character start: 110xxxxx
217 // Needs at least 2 bytes
218 if (i < 2) return len - i;
219 } else if ((c & 0xF0) == 0xE0) {
220 // 3-byte character start: 1110xxxx
221 // Needs at least 3 bytes
222 if (i < 3) return len - i;
223 } else if ((c & 0xF8) == 0xF0) {
224 // 4-byte character start: 11110xxx
225 // Needs at least 4 bytes
226 if (i < 4) return len - i;
227 }
228 }
229
230 // If no cut-off multi-byte character is found, return full length
231 return len;
232}
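// sketch of the intended use when streaming partial detokenized text (variable names are illustrative):
// only send the prefix that is known to be complete UTF-8 and keep the cut-off bytes for the next chunk
//
//     std::string pending = generated_text;              // may end mid multi-byte character
//     size_t      n_valid = validate_utf8(pending);
//     std::string to_send = pending.substr(0, n_valid);  // safe to emit
//     pending             = pending.substr(n_valid);     // retry once more bytes arrive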
233
234//
235// template utils
236//
237
238// format infill task
239static llama_tokens format_infill(
240 const llama_vocab * vocab,
241 const json & input_prefix,
242 const json & input_suffix,
243 const json & input_extra,
244 const int n_batch,
245 const int n_predict,
246 const int n_ctx,
247 const bool spm_infill,
248 const llama_tokens & tokens_prompt
249 ) {
250 // TODO: optimize this block by reducing memory allocations and movement
251
252 // use FIM repo-level pattern:
253 // ref: https://arxiv.org/pdf/2409.12186
254 //
255 // [FIM_REP]myproject
256 // [FIM_SEP]filename0
257 // extra chunk 0
258 // [FIM_SEP]filename1
259 // extra chunk 1
260 // ...
261 // [FIM_SEP]filename
262 // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
263 //
264 llama_tokens extra_tokens;
265 extra_tokens.reserve(n: n_ctx);
266
267 auto tokens_prefix = tokenize_mixed(vocab, json_prompt: input_prefix, add_special: false, parse_special: false);
268 auto tokens_suffix = tokenize_mixed(vocab, json_prompt: input_suffix, add_special: false, parse_special: false);
269
270 if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) {
271 // TODO: make project name an input
272 static const auto k_fim_repo = common_tokenize(vocab, text: "myproject\n", add_special: false, parse_special: false);
273
274 extra_tokens.push_back(x: llama_vocab_fim_rep(vocab));
275 extra_tokens.insert(position: extra_tokens.end(), first: k_fim_repo.begin(), last: k_fim_repo.end());
276 }
277 for (const auto & chunk : input_extra) {
278 // { "text": string, "filename": string }
279 const std::string text = json_value(body: chunk, key: "text", default_value: std::string());
280 const std::string filename = json_value(body: chunk, key: "filename", default_value: std::string("tmp"));
281
282 if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
283 const auto k_fim_file = common_tokenize(vocab, text: filename + "\n", add_special: false, parse_special: false);
284
285 extra_tokens.insert(position: extra_tokens.end(), x: llama_vocab_fim_sep(vocab));
286 extra_tokens.insert(position: extra_tokens.end(), first: k_fim_file.begin(), last: k_fim_file.end());
287 } else {
288 // chunk separator in binary form to avoid confusing the AI
289 static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
290 static const auto k_chunk_prefix_tokens = common_tokenize(vocab, text: k_chunk_prefix_str, add_special: false, parse_special: false);
291
292 extra_tokens.insert(position: extra_tokens.end(), first: k_chunk_prefix_tokens.begin(), last: k_chunk_prefix_tokens.end());
293 }
294
295 const auto chunk_tokens = common_tokenize(vocab, text, add_special: false, parse_special: false);
296 extra_tokens.insert(position: extra_tokens.end(), first: chunk_tokens.begin(), last: chunk_tokens.end());
297 }
298
299 if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
300 // TODO: current filename
301 static const auto k_fim_file = common_tokenize(vocab, text: "filename\n", add_special: false, parse_special: false);
302
303 extra_tokens.insert(position: extra_tokens.end(), x: llama_vocab_fim_sep(vocab));
304 extra_tokens.insert(position: extra_tokens.end(), first: k_fim_file.begin(), last: k_fim_file.end());
305 }
306
307 // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
308 const int n_prefix_take = std::min<int>(a: tokens_prefix.size(), b: 3*(n_batch/4));
309 const int n_suffix_take = std::min<int>(a: tokens_suffix.size(), b: std::max<int>(a: 0, b: (n_batch/4) - (2 + tokens_prompt.size())));
310
311 SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take));
312
313 // fill the rest of the context with extra chunks
314 const int n_extra_take = std::min<int>(a: std::max<int>(a: 0, b: n_ctx - (n_batch) - 2*n_predict), b: extra_tokens.size());
315
316 tokens_prefix.erase(first: tokens_prefix.begin(), last: tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
317 tokens_suffix.resize(new_size: n_suffix_take);
318
319 tokens_prefix.insert(position: tokens_prefix.begin(), x: llama_vocab_fim_pre(vocab));
320 tokens_prefix.insert(position: tokens_prefix.end(), first: tokens_prompt.begin(), last: tokens_prompt.end());
321 tokens_suffix.insert(position: tokens_suffix.begin(), x: llama_vocab_fim_suf(vocab));
322
323 auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
324 auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
325
326 if (llama_vocab_get_add_bos(vocab)) {
327 embd_inp.insert(position: embd_inp.begin(), x: llama_vocab_bos(vocab));
328 }
329
330 SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
331
332 // put the extra context before the FIM prefix
333 embd_inp.insert(position: embd_inp.begin(), first: extra_tokens.end() - n_extra_take, last: extra_tokens.end());
334
335 embd_inp.insert(position: embd_inp.end(), first: embd_end.begin(), last: embd_end.end());
336 embd_inp.push_back(x: llama_vocab_fim_mid(vocab));
337
338 return embd_inp;
339}
340
341//
342// base64 utils (TODO: move to common in the future)
343//
344
345static const std::string base64_chars =
346 "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
347 "abcdefghijklmnopqrstuvwxyz"
348 "0123456789+/";
349
350static inline bool is_base64(uint8_t c) {
351 return (isalnum(c) || (c == '+') || (c == '/'));
352}
353
354static inline raw_buffer base64_decode(const std::string & encoded_string) {
355 int i = 0;
356 int j = 0;
357 int in_ = 0;
358
359 int in_len = encoded_string.size();
360
361 uint8_t char_array_4[4];
362 uint8_t char_array_3[3];
363
364 raw_buffer ret;
365
366 while (in_len-- && (encoded_string[in_] != '=') && is_base64(c: encoded_string[in_])) {
367 char_array_4[i++] = encoded_string[in_]; in_++;
368 if (i == 4) {
369 for (i = 0; i < 4; i++) {
370 char_array_4[i] = base64_chars.find(c: char_array_4[i]);
371 }
372
373 char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
374 char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
375 char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
376
377 for (i = 0; (i < 3); i++) {
378 ret.push_back(x: char_array_3[i]);
379 }
380
381 i = 0;
382 }
383 }
384
385 if (i) {
386 for (j = i; j < 4; j++) {
387 char_array_4[j] = 0;
388 }
389
390 for (j = 0; j < 4; j++) {
391 char_array_4[j] = base64_chars.find(c: char_array_4[j]);
392 }
393
394 char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
395 char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
396 char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
397
398 for (j = 0; j < i - 1; j++) {
399 ret.push_back(x: char_array_3[j]);
400 }
401 }
402
403 return ret;
404}
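// example (assumed data-URI style input, as produced by OAI-compatible clients): the payload after
// the comma is plain base64, which base64_decode() turns into raw bytes suitable for out_files
//
//     std::string url   = "data:image/png;base64,iVBORw0KGgo=";
//     auto        parts = string_split<std::string>(url, ',');  // [0] = header, [1] = payload
//     raw_buffer  bytes = base64_decode(parts[1]);              // decoded PNG bytes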
405
406//
407// random string / id
408//
409
410static std::string random_string() {
411 static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
412
413 std::random_device rd;
414 std::mt19937 generator(rd());
415
416 std::string result(32, ' ');
417
418 for (int i = 0; i < 32; ++i) {
419 result[i] = str[generator() % str.size()];
420 }
421
422 return result;
423}
424
425static std::string gen_chatcmplid() {
426 return "chatcmpl-" + random_string();
427}
428
429static std::string gen_tool_call_id() {
430 return random_string();
431}
432
433//
434// other common utils
435//
436
437// TODO: reuse llama_detokenize
438template <class Iter>
439static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
440 std::string ret;
441 for (; begin != end; ++begin) {
442 ret += common_token_to_piece(ctx, *begin);
443 }
444
445 return ret;
446}
447
448// format incomplete utf-8 multibyte character for output
449static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
450 std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token);
451
    // if the size is 1 and the first bit is 1, this is a partial character
    // (size > 1 means it is already a known token)
454 if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
455 std::stringstream ss;
456 ss << std::hex << (out[0] & 0xff);
457 std::string res(ss.str());
458 out = "byte: \\x" + res;
459 }
460
461 return out;
462}
463
static bool server_sent_event(httplib::DataSink & sink, const json & data) {
    const std::string str =
        "data: " +
        data.dump(-1, ' ', false, json::error_handler_t::replace) +
        "\n\n"; // required by RFC 8895 - a message is terminated by a blank line (two line terminators in a row)

    LOG_DBG("data stream, to_send: %s", str.c_str());

    return sink.write(str.c_str(), str.size());
}
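// minimal sketch of how this pairs with cpp-httplib's chunked content provider (simplified and
// illustrative; the real server wires this through its slot/task queue):
//
//     res.set_chunked_content_provider("text/event-stream", [](size_t, httplib::DataSink & sink) {
//         json chunk = {{"content", "hello"}};
//         if (!server_sent_event(sink, chunk)) {
//             return false; // write failed (client disconnected) -> cancel
//         }
//         sink.done(); // no more events
//         return true;
//     });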
474
475//
476// OAI utils
477//
478
479// used by /completions endpoint
static json oaicompat_completion_params_parse(const json & body) {
    json llama_params;

    if (!body.contains("prompt")) {
        throw std::runtime_error("\"prompt\" is required");
    }

    // Handle "stop" field
    if (body.contains("stop") && body.at("stop").is_string()) {
        llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
    } else {
        llama_params["stop"] = json_value(body, "stop", json::array());
    }

    // Handle "n" field
    int n_choices = json_value(body, "n", 1);
    if (n_choices != 1) {
        throw std::runtime_error("Only one completion choice is allowed");
    }

    // Handle "echo" field
    if (json_value(body, "echo", false)) {
        throw std::runtime_error("Only no echo is supported");
    }

    // Params supported by OAI but unsupported by llama.cpp
    static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
    for (const auto & param : unsupported_params) {
        if (body.contains(param)) {
            throw std::runtime_error("Unsupported param: " + param);
        }
    }

    // Copy remaining properties to llama_params
    for (const auto & item : body.items()) {
        // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
            llama_params[item.key()] = item.value();
        }
    }

    return llama_params;
}
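// example of an accepted /completions body and the resulting mapping (illustrative values):
//
//     json body = json::parse(R"({"prompt": "Hello", "stop": "###", "n_predict": 32})");
//     json params = oaicompat_completion_params_parse(body);
//     // params["stop"] == ["###"]   (a single stop string is wrapped into an array)
//     // "prompt", "n_predict", and other llama.cpp-specific fields are copied through unchanged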
523
524struct oaicompat_parser_options {
525 bool use_jinja;
526 bool prefill_assistant;
527 common_reasoning_format reasoning_format;
528 std::map<std::string,std::string> chat_template_kwargs;
529 common_chat_templates * tmpls;
530 bool allow_image;
531 bool allow_audio;
532 bool enable_thinking = true;
533};
534
535// used by /chat/completions endpoint
536static json oaicompat_chat_params_parse(
537 json & body, /* openai api json semantics */
538 const oaicompat_parser_options & opt,
539 std::vector<raw_buffer> & out_files)
540{
541 json llama_params;
542
543 auto tools = json_value(body, key: "tools", default_value: json());
544 auto has_tools = tools.is_array() && !tools.empty();
545 auto stream = json_value(body, key: "stream", default_value: false);
546 auto tool_choice = json_value(body, key: "tool_choice", default_value: std::string("auto"));
547
548 if (!opt.use_jinja) {
549 if (has_tools) {
550 throw std::runtime_error("tools param requires --jinja flag");
551 }
552 if (tool_choice != "auto") {
553 throw std::runtime_error("tool_choice param requires --jinja flag");
554 }
555 }
556
557 // Handle "stop" field
558 if (body.contains(key: "stop") && body.at(key: "stop").is_string()) {
559 llama_params["stop"] = json::array(init: {body.at(key: "stop").get<std::string>()});
560 } else {
561 llama_params["stop"] = json_value(body, key: "stop", default_value: json::array());
562 }
563
564 auto json_schema = json_value(body, key: "json_schema", default_value: json());
565 auto grammar = json_value(body, key: "grammar", default_value: std::string());
566 if (!json_schema.is_null() && !grammar.empty()) {
567 throw std::runtime_error("Cannot use both json_schema and grammar");
568 }
569
570 // Handle "response_format" field
571 if (body.contains(key: "response_format")) {
572 json response_format = json_value(body, key: "response_format", default_value: json::object());
573 std::string response_type = json_value(body: response_format, key: "type", default_value: std::string());
574 if (response_type == "json_object") {
575 json_schema = json_value(body: response_format, key: "schema", default_value: json::object());
576 } else if (response_type == "json_schema") {
577 auto schema_wrapper = json_value(body: response_format, key: "json_schema", default_value: json::object());
578 json_schema = json_value(body: schema_wrapper, key: "schema", default_value: json::object());
579 } else if (!response_type.empty() && response_type != "text") {
            throw std::runtime_error("response_format type must be one of \"text\", \"json_object\" or \"json_schema\", but got: " + response_type);
581 }
582 }
583
584 // get input files
585 if (!body.contains(key: "messages")) {
586 throw std::runtime_error("'messages' is required");
587 }
588 json & messages = body.at(key: "messages");
589 if (!messages.is_array()) {
590 throw std::runtime_error("Expected 'messages' to be an array");
591 }
592 for (auto & msg : messages) {
593 std::string role = json_value(body: msg, key: "role", default_value: std::string());
594 if (role != "assistant" && !msg.contains(key: "content")) {
595 throw std::runtime_error("All non-assistant messages must contain 'content'");
596 }
597 if (role == "assistant") {
598 if (!msg.contains(key: "content") && !msg.contains(key: "tool_calls")) {
599 throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!");
600 }
601 if (!msg.contains(key: "content")) {
602 continue; // avoid errors with no content
603 }
604 }
605 json & content = msg.at(key: "content");
606 if (content.is_string() || content.is_null()) {
607 continue;
608 }
609
610 if (!content.is_array()) {
611 throw std::runtime_error("Expected 'content' to be a string or an array");
612 }
613
614 for (auto & p : content) {
615 std::string type = json_value(body: p, key: "type", default_value: std::string());
616 if (type == "image_url") {
617 if (!opt.allow_image) {
618 throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
619 }
620
621 json image_url = json_value(body: p, key: "image_url", default_value: json::object());
622 std::string url = json_value(body: image_url, key: "url", default_value: std::string());
623 if (string_starts_with(str: url, prefix: "http")) {
624 // download remote image
625 // TODO @ngxson : maybe make these params configurable
626 common_remote_params params;
627 params.headers.push_back(x: "User-Agent: llama.cpp/" + build_info);
628 params.max_size = 1024 * 1024 * 10; // 10MB
629 params.timeout = 10; // seconds
630 SRV_INF("downloading image from '%s'\n", url.c_str());
631 auto res = common_remote_get_content(url, params);
632 if (200 <= res.first && res.first < 300) {
633 SRV_INF("downloaded %ld bytes\n", res.second.size());
634 raw_buffer data;
635 data.insert(position: data.end(), first: res.second.begin(), last: res.second.end());
636 out_files.push_back(x: data);
637 } else {
638 throw std::runtime_error("Failed to download image");
639 }
640
641 } else {
642 // try to decode base64 image
643 std::vector<std::string> parts = string_split<std::string>(input: url, /*separator*/ separator: ',');
644 if (parts.size() != 2) {
645 throw std::runtime_error("Invalid image_url.url value");
646 } else if (!string_starts_with(str: parts[0], prefix: "data:image/")) {
647 throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
648 } else if (!string_ends_with(str: parts[0], suffix: "base64")) {
649 throw std::runtime_error("image_url.url must be base64 encoded");
650 } else {
651 auto base64_data = parts[1];
652 auto decoded_data = base64_decode(encoded_string: base64_data);
653 out_files.push_back(x: decoded_data);
654 }
655 }
656
657 // replace this chunk with a marker
658 p["type"] = "text";
659 p["text"] = mtmd_default_marker();
660 p.erase(key: "image_url");
661
662 } else if (type == "input_audio") {
663 if (!opt.allow_audio) {
664 throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
665 }
666
667 json input_audio = json_value(body: p, key: "input_audio", default_value: json::object());
668 std::string data = json_value(body: input_audio, key: "data", default_value: std::string());
669 std::string format = json_value(body: input_audio, key: "format", default_value: std::string());
                // while we also support flac, we don't allow it here so that it matches the OAI spec
671 if (format != "wav" && format != "mp3") {
672 throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
673 }
674 auto decoded_data = base64_decode(encoded_string: data); // expected to be base64 encoded
675 out_files.push_back(x: decoded_data);
676
677 // replace this chunk with a marker
678 p["type"] = "text";
679 p["text"] = mtmd_default_marker();
680 p.erase(key: "input_audio");
681
682 } else if (type != "text") {
683 throw std::runtime_error("unsupported content[].type");
684 }
685 }
686 }
687
688 common_chat_templates_inputs inputs;
689 inputs.messages = common_chat_msgs_parse_oaicompat(messages);
690 inputs.tools = common_chat_tools_parse_oaicompat(tools);
691 inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
692 inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
693 inputs.grammar = grammar;
694 inputs.use_jinja = opt.use_jinja;
695 inputs.parallel_tool_calls = json_value(body, key: "parallel_tool_calls", default_value: false);
696 inputs.add_generation_prompt = json_value(body, key: "add_generation_prompt", default_value: true);
697 inputs.reasoning_format = opt.reasoning_format;
698 inputs.enable_thinking = opt.enable_thinking;
699 if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
700 if (body.contains(key: "grammar")) {
701 throw std::runtime_error("Cannot use custom grammar constraints with tools.");
702 }
703 llama_params["parse_tool_calls"] = true;
704 }
705
706 // merge the template args provided from command line with the args provided in the user request
707 auto chat_template_kwargs_object = json_value(body, key: "chat_template_kwargs", default_value: json::object());
708 inputs.chat_template_kwargs = opt.chat_template_kwargs;
709 for (const auto & item : chat_template_kwargs_object.items()) {
710 inputs.chat_template_kwargs[item.key()] = item.value().dump();
711 }
712
713 // parse the "enable_thinking" kwarg to override the default value
714 auto enable_thinking_kwarg = json_value(body: inputs.chat_template_kwargs, key: "enable_thinking", default_value: std::string(""));
715 if (enable_thinking_kwarg == "true") {
716 inputs.enable_thinking = true;
717 } else if (enable_thinking_kwarg == "false") {
718 inputs.enable_thinking = false;
719 } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') {
720 throw std::runtime_error("invalid type for \"enable_thinking\" (expected boolean, got string)");
721 }
722
723 // if the assistant message appears at the end of list, we do not add end-of-turn token
724 // for ex. this can be useful to modify the reasoning process in reasoning models
725 bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant;
726 common_chat_msg last_message;
727 if (prefill_assistant_message) {
728 last_message = inputs.messages.back();
729 inputs.messages.pop_back();
730
731 /* sanity check, max one assistant message at the end of the list */
732 if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){
733 throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
734 }
735
736 /* TODO: test this properly */
737 inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
738
739 if ( inputs.enable_thinking ) {
740 throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking.");
741 }
742
743 inputs.add_generation_prompt = true;
744 }
745
746 // Apply chat template to the list of messages
747 auto chat_params = common_chat_templates_apply(tmpls: opt.tmpls, inputs);
748
749 /* Append assistant prefilled message */
750 if (prefill_assistant_message) {
751 if (!last_message.content_parts.empty()) {
752 for (auto & p : last_message.content_parts) {
753 chat_params.prompt += p.text;
754 }
755 } else {
756 chat_params.prompt += last_message.content;
757 }
758 }
759
760 llama_params["chat_format"] = static_cast<int>(chat_params.format);
761 llama_params["prompt"] = chat_params.prompt;
762 if (!chat_params.grammar.empty()) {
763 llama_params["grammar"] = chat_params.grammar;
764 }
765 llama_params["grammar_lazy"] = chat_params.grammar_lazy;
766 auto grammar_triggers = json::array();
767 for (const auto & trigger : chat_params.grammar_triggers) {
768 server_grammar_trigger ct(trigger);
769 grammar_triggers.push_back(val: ct.to_json());
770 }
771 llama_params["grammar_triggers"] = grammar_triggers;
772 llama_params["preserved_tokens"] = chat_params.preserved_tokens;
773 llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
774 for (const auto & stop : chat_params.additional_stops) {
775 llama_params["stop"].push_back(val: stop);
776 }
777
778 // Handle "n" field
779 int n_choices = json_value(body, key: "n", default_value: 1);
780 if (n_choices != 1) {
781 throw std::runtime_error("Only one completion choice is allowed");
782 }
783
784 // Handle "logprobs" field
    // TODO: the response format of this option is not yet OAI-compatible, but it seems like no one really uses it; we may need to fix it in the future
786 if (json_value(body, key: "logprobs", default_value: false)) {
787 if (has_tools && stream) {
788 throw std::runtime_error("logprobs is not supported with tools + stream");
789 }
790 llama_params["n_probs"] = json_value(body, key: "top_logprobs", default_value: 20);
791 } else if (body.contains(key: "top_logprobs") && !body.at(key: "top_logprobs").is_null()) {
792 throw std::runtime_error("top_logprobs requires logprobs to be set to true");
793 }
794
795 // Copy remaining properties to llama_params
796 // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
797 // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
798 for (const auto & item : body.items()) {
799 // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
800 if (!llama_params.contains(key: item.key()) || item.key() == "n_predict") {
801 llama_params[item.key()] = item.value();
802 }
803 }
804
805 return llama_params;
806}
807
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
    json data = json::array();
    int32_t n_tokens = 0;
    int i = 0;
    for (const auto & elem : embeddings) {
        json embedding_obj;

        if (use_base64) {
            const auto & vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
            const char * data_ptr = reinterpret_cast<const char *>(vec.data());
            size_t data_size = vec.size() * sizeof(float);
            embedding_obj = {
                {"embedding", base64::encode(data_ptr, data_size)},
                {"index", i++},
                {"object", "embedding"},
                {"encoding_format", "base64"}
            };
        } else {
            embedding_obj = {
                {"embedding", json_value(elem, "embedding", json::array())},
                {"index", i++},
                {"object", "embedding"}
            };
        }
        data.push_back(embedding_obj);

        n_tokens += json_value(elem, "tokens_evaluated", 0);
    }

    json res = json {
        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
        {"object", "list"},
        {"usage", json {
            {"prompt_tokens", n_tokens},
            {"total_tokens", n_tokens}
        }},
        {"data", data}
    };

    return res;
}
849
850static json format_response_rerank(
851 const json & request,
852 const json & ranks,
853 bool is_tei_format,
854 std::vector<std::string> & texts,
855 int top_n) {
856 int32_t n_tokens = 0;
857 bool return_text = is_tei_format && json_value(body: request, key: "return_text", default_value: false);
858 std::vector<json> elements; // Temporary vector to hold unsorted elements
859 std::string score_label = is_tei_format ? "score" : "relevance_score";
860 for (const auto & rank : ranks) {
861 int index = json_value(body: rank, key: "index", default_value: 0);
862 json elem = json{
863 {"index", index},
864 {score_label, json_value(body: rank, key: "score", default_value: 0.0)},
865 };
866 n_tokens += json_value(body: rank, key: "tokens_evaluated", default_value: 0);
867 if (return_text) {
868 elem["text"] = std::move(texts[index]);
869 }
870 elements.push_back(x: elem);
871 }
872
873 std::sort(first: elements.begin(), last: elements.end(), comp: [score_label](const json& a, const json& b) {
874 return json_value(body: a, key: score_label, default_value: 0.0) > json_value(body: b, key: score_label, default_value: 0.0);
875 });
876
877 elements.resize(new_size: std::min(a: top_n, b: (int)elements.size()));
878 json results = elements;
879
880 if (is_tei_format) return results;
881
882 json res = json{
883 {"model", json_value(body: request, key: "model", default_value: std::string(DEFAULT_OAICOMPAT_MODEL))},
884 {"object", "list"},
885 {"usage", json{
886 {"prompt_tokens", n_tokens},
887 {"total_tokens", n_tokens}
888 }},
889 {"results", results}
890 };
891
892 return res;
893}
894
895static bool is_valid_utf8(const std::string & str) {
896 const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
897 const unsigned char* end = bytes + str.length();
898
899 while (bytes < end) {
900 if (*bytes <= 0x7F) {
901 // 1-byte sequence (0xxxxxxx)
902 bytes++;
903 } else if ((*bytes & 0xE0) == 0xC0) {
904 // 2-byte sequence (110xxxxx 10xxxxxx)
905 if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80)
906 return false;
907 bytes += 2;
908 } else if ((*bytes & 0xF0) == 0xE0) {
909 // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx)
910 if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80)
911 return false;
912 bytes += 3;
913 } else if ((*bytes & 0xF8) == 0xF0) {
914 // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx)
915 if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 ||
916 (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80)
917 return false;
918 bytes += 4;
919 } else {
920 // Invalid UTF-8 lead byte
921 return false;
922 }
923 }
924
925 return true;
926}
927
928static json format_tokenizer_response(const json & tokens) {
929 return json {
930 {"tokens", tokens}
931 };
932}
933
934static json format_detokenized_response(const std::string & content) {
935 return json {
936 {"content", content}
937 };
938}
939
940static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
941 json data = json::array();
942 for (const auto & lb : logit_bias) {
943 data.push_back(val: json{
944 {"bias", lb.bias},
945 {"token", lb.token},
946 });
947 }
948 return data;
949}
950
static std::string safe_json_to_str(const json & data) {
    return data.dump(-1, ' ', false, json::error_handler_t::replace);
}
954
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
    std::vector<llama_token_data> cur;
    const auto * logits = llama_get_logits_ith(ctx, idx);

    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const int n_vocab = llama_vocab_n_tokens(vocab);

    cur.resize(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
    }

    // sort tokens by logits
    std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
        return a.logit > b.logit;
    });

    // apply softmax
    float max_l = cur[0].logit;
    float cum_sum = 0.0f;
    for (size_t i = 0; i < cur.size(); ++i) {
        float p = expf(cur[i].logit - max_l);
        cur[i].p = p;
        cum_sum += p;
    }
    for (size_t i = 0; i < cur.size(); ++i) {
        cur[i].p /= cum_sum;
    }

    return cur;
}
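// usage sketch (assumes a llama_context that has just decoded a batch; variable names are illustrative):
//
//     // after llama_decode(ctx, batch) with logits requested for the last token at index i_last:
//     std::vector<llama_token_data> probs = get_token_probabilities(ctx, i_last);
//     // probs is sorted by logit; probs[0].p is the probability of the most likely token
//     for (int k = 0; k < 5 && k < (int) probs.size(); k++) {
//         LOG_DBG("top-%d: token %d, p = %.3f\n", k + 1, probs[k].id, probs[k].p);
//     }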
988
989static bool are_lora_equal(
990 const std::vector<common_adapter_lora_info> & l1,
991 const std::vector<common_adapter_lora_info> & l2) {
992 if (l1.size() != l2.size()) {
993 return false;
994 }
995 for (size_t i = 0; i < l1.size(); ++i) {
996 // we don't check lora.path to reduce the time complexity
997 if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) {
998 return false;
999 }
1000 }
1001 return true;
1002}
1003
1004// get the ids of all enabled loras
1005static std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras) {
1006 std::vector<size_t> enabled_ids;
1007 for (size_t i = 0; i < loras.size(); ++i) {
1008 if (loras[i].scale > 0) {
1009 enabled_ids.push_back(x: i);
1010 }
1011 }
1012 return enabled_ids;
1013}
1014
1015// check whether the given lora set has only aloras activated (empty => false)
1016static bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras) {
1017 bool found_alora = false;
1018 for (const auto & lora : loras) {
1019 if (lora.scale != 0) {
1020 if (llama_adapter_get_alora_n_invocation_tokens(adapter: lora.ptr) == 0) {
1021 return false;
1022 }
1023 found_alora = true;
1024 }
1025 }
1026 return found_alora;
1027}
1028
1029// if the two sets of loras are different, they require a cache clear unless the
1030// change is only from aloras to aloras.
1031static bool lora_should_clear_cache(
1032 const std::vector<common_adapter_lora_info> & current,
1033 const std::vector<common_adapter_lora_info> & next) {
1034
1035 // This should always be called after determining that the two sets are
1036 // _not_ equal. This assert is therefore some slightly wasted work and
1037 // should be safe to remove as long as this method is called correctly.
1038 GGML_ASSERT(!are_lora_equal(current, next));
1039
1040 return (
1041 !(lora_get_enabled_ids(loras: current).empty() || lora_all_alora(loras: current)) ||
1042 !lora_all_alora(loras: next));
1043}
1044
// parse lora config from JSON request, returns a copy of lora_base with updated scale
static std::vector<common_adapter_lora_info> parse_lora_request(
        const std::vector<common_adapter_lora_info> & lora_base,
        const json & data) {
    std::vector<common_adapter_lora_info> lora(lora_base);
    int max_idx = lora.size();

    // clear existing values
    for (auto & entry : lora) {
        entry.scale = 0.0f;
    }

    // set value
    for (const auto & entry : data) {
        int id = json_value(entry, "id", -1);
        float scale = json_value(entry, "scale", 0.0f);
        if (0 <= id && id < max_idx) {
            lora[id].scale = scale;
        } else {
            throw std::runtime_error("invalid adapter id");
        }
    }

    return lora;
}
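// example request payload accepted by parse_lora_request() (illustrative; `lora_base` stands for the
// adapters loaded at startup): adapters not listed keep scale 0.0 and are therefore disabled
//
//     json req = json::parse(R"([ {"id": 0, "scale": 1.0}, {"id": 1, "scale": 0.5} ])");
//     std::vector<common_adapter_lora_info> lora = parse_lora_request(lora_base, req);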
1070
1071//
1072// utils for interacting with libmtmd
1073// (may need to refactor in near future)
1074//
1075
1076/**
1077 * server_tokens is a helper to manage the input tokens and image for the server.
1078 * it is made this way to simplify the logic of KV cache management.
1079 */
1080struct server_tokens {
1081 bool has_mtmd = false;
1082
1083private: // disallow accessing these members directly, risking out-of-sync
1084
    // map a **start** index in tokens to the image chunk
    // note: the order needs to be in-sync with tokens
1087 std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
1088
1089 // list of tokens
1090 // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
1091 // otherwise, it is a normal text token
1092 // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
1093 // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
1094 llama_tokens tokens;
1095
1096 // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
1097 // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
1098 // idx 0 1 2 3 4 5 6 7 8 9 10
1099 // pos 0 1 2 3 4 5 5 5 7 7 7
1100 // map_idx_to_media will contain: {5, img0}, {8, img1}
1101
1102public:
1103 server_tokens() = default;
1104 ~server_tokens() = default;
1105
1106 // Prevent copying
1107 // TODO: server_tokens should be copyable - remove this:
1108 server_tokens(const server_tokens&) = delete;
1109 server_tokens& operator=(const server_tokens&) = delete;
1110
1111 // Allow moving (usually implicitly generated if members are movable)
1112 server_tokens(server_tokens&&) = default;
1113 server_tokens& operator=(server_tokens&&) = default;
1114
1115 // Allow accessing elements using [] operator
1116 llama_token operator[](size_t index) { return tokens[index]; }
1117 const llama_token& operator[](size_t index) const { return tokens[index]; }
1118
1119 server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) {
1120 for (size_t i = 0; i < mtmd_chunks.size(); ++i) {
1121 push_back(chunk: mtmd_chunks[i]);
1122 }
1123 }
1124
1125 server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
1126 }
1127
1128 llama_pos pos_next() const {
1129 if (!has_mtmd) {
1130 return tokens.size();
1131 }
1132
1133 llama_pos res = tokens.size();
1134
1135 for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
1136 const auto & chunk = it->second;
1137 res += mtmd_input_chunk_get_n_pos(chunk: chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk: chunk.get());
1138 }
1139
1140 return res;
1141 }
1142
1143 // for debugging
1144 std::string str() const {
1145 std::ostringstream oss;
1146 oss << "tokens: ";
1147 for (size_t idx = 0; idx < tokens.size(); ++idx) {
1148 llama_token t = tokens[idx];
1149 oss << "idx:" << idx << " ";
1150 if (t == LLAMA_TOKEN_NULL) {
1151 oss << "<embd> ";
1152 } else {
1153 oss << t << " ";
1154 }
1155 }
1156 oss << "\n";
1157 oss << "image idx: ";
1158 for (const auto & it : map_idx_to_media) {
1159 oss << it.first << ", ";
1160 }
1161 return oss.str();
1162 }
1163
1164 const mtmd::input_chunk_ptr & find_chunk(size_t idx) const {
1165 auto it = map_idx_to_media.find(x: idx);
1166 if (it != map_idx_to_media.end()) {
1167 return it->second;
1168 }
1169 throw std::runtime_error("Chunk not found");
1170 }
1171
1172 void push_back(llama_token tok) {
1173 if (tok == LLAMA_TOKEN_NULL) {
1174 throw std::runtime_error("Invalid token");
1175 }
1176 tokens.emplace_back(args&: tok);
1177 }
1178
1179 // will create a copy of the chunk if it contains non-text data
1180 void push_back(const mtmd_input_chunk * chunk) {
1181 auto type = mtmd_input_chunk_get_type(chunk);
1182 if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
1183 GGML_ASSERT(has_mtmd);
1184 const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
1185 size_t start_idx = tokens.size();
1186 for (size_t i = 0; i < n_tokens; ++i) {
1187 tokens.emplace_back(LLAMA_TOKEN_NULL);
1188 }
1189 mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
1190 map_idx_to_media[start_idx] = std::move(new_chunk);
1191 } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
1192 size_t n_tokens;
1193 const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, n_tokens_output: &n_tokens);
1194 for (size_t i = 0; i < n_tokens; ++i) {
1195 push_back(tok: text_tokens[i]);
1196 }
1197 } else {
1198 GGML_ABORT("Invalid chunk type");
1199 }
1200 }
1201
1202 // appends server tokens, updates the media map. copies media chunks.
1203 void push_back(server_tokens & tokens) {
1204 size_t start_idx = size();
1205 for (size_t i = 0; i < tokens.size(); i++) {
1206 push_back(tok: tokens[i]);
1207 }
1208 if (tokens.has_mtmd) {
1209 // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
1210 // We could also just check, but this will prevent silently dropping MTMD data.
1211 GGML_ASSERT(has_mtmd);
            for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ++it) {
                auto * chunk = tokens.map_idx_to_media[it->first].get();
                mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
                map_idx_to_media[start_idx + it->first] = std::move(new_chunk);
            }
1217 }
1218 }
1219
1220 // for compatibility with context shift and prompt truncation
1221 void insert(const llama_tokens & inp_tokens) {
1222 GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
1223 tokens.insert(position: tokens.end(), first: inp_tokens.begin(), last: inp_tokens.end());
1224 }
1225
1226 // for compatibility with speculative decoding, ctx shift, slot save/load
1227 const llama_tokens & get_text_tokens() const {
1228 GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
1229 return tokens;
1230 }
1231
1232 // for compatibility with speculative decoding
1233 void set_token(llama_pos pos, llama_token id) {
1234 GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
1235 tokens[pos] = id;
1236 }
1237
1238 size_t size() const {
1239 return tokens.size();
1240 }
1241
1242 bool empty() const {
1243 return tokens.empty();
1244 }
1245
1246 void clear() {
1247 map_idx_to_media.clear();
1248 tokens.clear();
1249 }
1250
1251 void keep_first(size_t n) {
1252 GGML_ASSERT(n <= tokens.size());
1253 if (has_mtmd) {
1254 if (n == tokens.size()) {
1255 return; // nothing to do
1256 }
1257 // we throw an error if we try to remove a token in the middle of an image
1258 // for ex. with input of 5 text tokens and 2 images:
1259 // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
1260 // n 1 2 3 4 5 6 7 8 9 10
1261 // allowed to resize ^ ^
1262 // disallowed to resize ^ ^ ^
1263 if (n > 0) {
1264 // make sure we never remove tokens in the middle of an image
1265 // note that the case where we keep a full image at the end is allowed:
1266 // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL
1267 if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) {
1268 find_chunk(idx: n - 1); // will throw an error if the token is not begin-of-chunk
1269 }
1270 }
1271 // remove all image chunks that are not used anymore
1272 for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
1273 size_t idx = it->first;
1274 if (idx >= n) {
1275 it = map_idx_to_media.erase(position: it);
1276 } else {
1277 ++it;
1278 }
1279 }
1280 }
1281 tokens.resize(new_size: n);
1282 }
1283
1284 std::string detokenize(const llama_context * ctx, bool special) const {
1285 llama_tokens text_tokens;
1286 text_tokens.reserve(n: tokens.size());
1287 for (const auto & t : tokens) {
1288 if (t != LLAMA_TOKEN_NULL) {
1289 text_tokens.push_back(x: t);
1290 }
1291 }
1292 return common_detokenize(ctx, tokens: text_tokens, special);
1293 }
1294
1295 size_t get_common_prefix(const server_tokens & b) const {
1296 const size_t max_idx = std::min(a: tokens.size(), b: b.tokens.size());
1297
1298 if (!has_mtmd) {
1299 for (size_t i = 0; i < max_idx; ++i) {
1300 if (tokens[i] == b.tokens[i]) {
1301 continue;
1302 }
1303
1304 return i;
1305 }
1306
1307 return max_idx;
1308 }
1309
1310 for (size_t i = 0; i < max_idx; ++i) {
1311 const llama_token ai = tokens[i];
1312 const llama_token bi = b.tokens[i];
1313
1314 if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
1315 const auto & a_chunk = find_chunk(idx: i);
1316 const auto & b_chunk = b.find_chunk(idx: i);
1317
1318 GGML_ASSERT(a_chunk && b_chunk);
1319
1320 const std::string id_ai = mtmd_input_chunk_get_id(chunk: a_chunk.get());
1321 const std::string id_bi = mtmd_input_chunk_get_id(chunk: b_chunk.get());
1322
1323 const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(chunk: a_chunk.get());
1324 const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(chunk: b_chunk.get());
1325
1326 if (id_ai == id_bi && n_tok_a == n_tok_b) {
1327 GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
1328 i += n_tok_a - 1; // will be +1 by the for loop
1329 continue;
1330 }
1331
1332 return i;
1333 }
1334
1335 if (ai == bi) {
1336 continue;
1337 }
1338
1339 return i;
1340 }
1341
1342 return max_idx; // all tokens are equal
1343 }
1344
1345 // make sure all text tokens are within the vocab range
1346 bool validate(const struct llama_context * ctx) const {
1347 const llama_model * model = llama_get_model(ctx);
1348 const llama_vocab * vocab = llama_model_get_vocab(model);
1349 const int32_t n_vocab = llama_vocab_n_tokens(vocab);
1350
1351 for (size_t i = 0; i < tokens.size(); ++i) {
1352 const auto & t = tokens[i];
1353 if (t == LLAMA_TOKEN_NULL) {
1354 try {
1355 const auto & chunk = find_chunk(idx: i);
1356 size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk: chunk.get());
1357 i += n_tokens - 1; // will be +1 by the for loop
1358 } catch (const std::exception & e) {
1359 return false;
1360 }
1361 } else if (t < 0 || t >= n_vocab) {
1362 return false;
1363 }
1364 }
1365 return true;
1366 }
1367
1368 // encode and decode the image chunk
    int32_t process_chunk(
            llama_context * ctx,
            mtmd_context * mctx,
            size_t idx,
            llama_pos pos,
            int32_t seq_id,
            size_t & n_tokens_out) const {
        const auto & chunk = find_chunk(idx);
        const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
            ? "image" : "audio";
        SRV_INF("processing %s...\n", name);
        int32_t n_batch = llama_n_batch(ctx);
        int64_t t0 = ggml_time_ms();
        llama_pos new_n_past; // unused for now
        int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
            chunk.get(),
            pos,
            seq_id,
            n_batch,
            true, // logits last
            &new_n_past);
        SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
        if (result != 0) {
            LOG_ERR("mtmd_helper_eval failed with status %d", result);
            n_tokens_out = 0;
            return result;
        }
        n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
        return 0;
    }
1399};
1400
1401// Computes FNV-1a hash of the data
1402static std::string fnv_hash(const uint8_t * data, size_t len) {
1403 const uint64_t fnv_prime = 0x100000001b3ULL;
1404 uint64_t hash = 0xcbf29ce484222325ULL;
1405
1406 for (size_t i = 0; i < len; ++i) {
1407 hash ^= data[i];
1408 hash *= fnv_prime;
1409 }
    return std::to_string(hash);
1411}
1412
static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
    mtmd::bitmaps bitmaps;
    for (auto & file : files) {
        mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size()));
        if (!bmp.ptr) {
            throw std::runtime_error("Failed to load image or audio file");
        }
        // calculate bitmap hash (for KV caching)
        std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
        bmp.set_id(hash.c_str());
        bitmaps.entries.push_back(std::move(bmp));
    }
    // process prompt
    std::vector<server_tokens> inputs;
    // multimodal
    mtmd_input_text inp_txt = {
        prompt.c_str(),
        /* add_special   */ true,
        /* parse_special */ true,
    };
    mtmd::input_chunks chunks(mtmd_input_chunks_init());
    auto bitmaps_c_ptr = bitmaps.c_ptr();
    int32_t tokenized = mtmd_tokenize(mctx,
                                      chunks.ptr.get(),
                                      &inp_txt,
                                      bitmaps_c_ptr.data(),
                                      bitmaps_c_ptr.size());
    if (tokenized != 0) {
        throw std::runtime_error("Failed to tokenize prompt");
    }
    auto result = server_tokens(chunks, true);
    return result;
}
1446
1447/**
 * break the input "prompt" object into multiple prompts if needed, then tokenize them
 * use tokenize_input_prompts() instead if the input could be an array of prompts
1450 * this supports these cases:
1451 * - "prompt": "string"
1452 * - "prompt": [12, 34, 56]
1453 * - "prompt": [12, 34, "string", 56, 78]
1454 * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
1455 */
1456static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
1457 constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string";
1458 constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data";
1459 const bool has_mtmd = mctx != nullptr;
1460 if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(data: json_prompt)) {
1461 // string or mixed
1462 llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special);
1463 return server_tokens(tmp, false);
1464 } else if (json_is_array_of_numbers(data: json_prompt)) {
1465 // array of tokens
1466 llama_tokens tmp = json_prompt.get<llama_tokens>();
1467 return server_tokens(tmp, false);
1468 } else if (json_prompt.contains(key: JSON_STRING_PROMPT_KEY)) {
1469 // JSON object with prompt key.
1470 if (json_prompt.contains(key: JSON_MTMD_DATA_KEY)) {
1471 if (!has_mtmd)
1472 throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests.");
1473
1474 // JSON object with prompt and multimodal key.
1475 std::vector<raw_buffer> files;
1476 for (const auto & entry : json_prompt.at(key: JSON_MTMD_DATA_KEY)) {
1477 files.push_back(x: base64_decode(encoded_string: entry));
1478 }
1479 return process_mtmd_prompt(mctx, prompt: json_prompt.at(key: JSON_STRING_PROMPT_KEY), files);
1480 } else {
1481 // Not multimodal, but contains a subobject.
1482 llama_tokens tmp = tokenize_mixed(vocab, json_prompt: json_prompt.at(key: JSON_STRING_PROMPT_KEY), add_special, parse_special);
1483 return server_tokens(tmp, false);
1484 }
1485 } else {
1486 throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens.");
1487 }
1488}
1489
1490/**
 * break the input "prompt" object into multiple prompts if needed, then tokenize them
1492 * this supports these cases:
1493 * - "prompt": "string"
1494 * - "prompt": [12, 34, 56]
1495 * - "prompt": [12, 34, "string", 56, 78]
1496 * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] }
1497 * and multiple prompts (multi-tasks):
1498 * - "prompt": ["string1", "string2"]
1499 * - "prompt": ["string1", [12, 34, 56]]
1500 * - "prompt": [[12, 34, 56], [78, 90, 12]]
1501 * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}]
1502 */
static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
    std::vector<server_tokens> result;
    if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) {
        result.reserve(json_prompt.size());
        for (const auto & p : json_prompt) {
            result.push_back(tokenize_input_subprompt(vocab, mctx, p, add_special, parse_special));
        }
    } else {
        result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special));
    }
    if (result.empty()) {
        throw std::runtime_error("\"prompt\" must not be empty");
    }
    return result;
}
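// illustrative call covering the shapes listed above (mctx may be nullptr when no mmproj is loaded):
//
//     json prompt = json::parse(R"(["first prompt", [12, 34, 56]])");
//     std::vector<server_tokens> inputs = tokenize_input_prompts(vocab, mctx, prompt, true, true);
//     // inputs.size() == 2 - one server_tokens per sub-prompt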
1518
1519// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
static server_tokens format_rerank(const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, const std::string & query, const std::string & doc) {
    server_tokens result = {};

    const char * rerank_prompt = llama_model_chat_template(model, "rerank");

    if (rerank_prompt != nullptr) {
        std::string prompt = rerank_prompt;
        string_replace_all(prompt, "{query}",    query);
        string_replace_all(prompt, "{document}", doc);
        server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true);
        result.push_back(tokens);
    } else {
        // Get EOS token - use SEP token as fallback if EOS is not available
        server_tokens query_tokens = tokenize_input_subprompt(vocab, mctx, query, false, false);
        server_tokens doc_tokens   = tokenize_input_subprompt(vocab, mctx, doc,   false, false);
        llama_token eos_token = llama_vocab_eos(vocab);
        if (eos_token == LLAMA_TOKEN_NULL) {
            eos_token = llama_vocab_sep(vocab);
        }

        if (llama_vocab_get_add_bos(vocab)) {
            result.push_back(llama_vocab_bos(vocab));
        }
        result.push_back(query_tokens);
        if (llama_vocab_get_add_eos(vocab)) {
            result.push_back(eos_token);
        }
        if (llama_vocab_get_add_sep(vocab)) {
            result.push_back(llama_vocab_sep(vocab));
        }
        result.push_back(doc_tokens);
        if (llama_vocab_get_add_eos(vocab)) {
            result.push_back(eos_token);
        }
    }

    return result;
}
1558