#include "chat.h"
#include "utils.hpp"

#include "arg.h"
#include "common.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
#include "mtmd.h"

// mime type for sending response
#define MIMETYPE_JSON "application/json; charset=utf-8"

// auto generated files (see README.md for details)
#include "index.html.gz.hpp"
#include "loading.html.hpp"

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cinttypes>
#include <deque>
#include <memory>
#include <mutex>
#include <signal.h>
#include <thread>
#include <unordered_map>
#include <unordered_set>

using json = nlohmann::ordered_json;

constexpr int HTTP_POLLING_SECONDS = 1;
36
37enum stop_type {
38 STOP_TYPE_NONE,
39 STOP_TYPE_EOS,
40 STOP_TYPE_WORD,
41 STOP_TYPE_LIMIT,
42};
43
44// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
45enum slot_state {
46 SLOT_STATE_IDLE,
47 SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
48 SLOT_STATE_PROCESSING_PROMPT,
49 SLOT_STATE_DONE_PROMPT,
50 SLOT_STATE_GENERATING,
51};
52
53enum server_state {
54 SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
55 SERVER_STATE_READY, // Server is ready and model is loaded
56};
57
58enum server_task_type {
59 SERVER_TASK_TYPE_COMPLETION,
60 SERVER_TASK_TYPE_EMBEDDING,
61 SERVER_TASK_TYPE_RERANK,
62 SERVER_TASK_TYPE_INFILL,
63 SERVER_TASK_TYPE_CANCEL,
64 SERVER_TASK_TYPE_NEXT_RESPONSE,
65 SERVER_TASK_TYPE_METRICS,
66 SERVER_TASK_TYPE_SLOT_SAVE,
67 SERVER_TASK_TYPE_SLOT_RESTORE,
68 SERVER_TASK_TYPE_SLOT_ERASE,
69 SERVER_TASK_TYPE_SET_LORA,
70};
71
72enum oaicompat_type {
73 OAICOMPAT_TYPE_NONE,
74 OAICOMPAT_TYPE_CHAT,
75 OAICOMPAT_TYPE_COMPLETION,
76 OAICOMPAT_TYPE_EMBEDDING,
77};
78
79// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
80enum error_type {
81 ERROR_TYPE_INVALID_REQUEST,
82 ERROR_TYPE_AUTHENTICATION,
83 ERROR_TYPE_SERVER,
84 ERROR_TYPE_NOT_FOUND,
85 ERROR_TYPE_PERMISSION,
86 ERROR_TYPE_UNAVAILABLE, // custom error
87 ERROR_TYPE_NOT_SUPPORTED, // custom error
88 ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
89};
90
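// task types that produce pooled embedding output (embedding, rerank) rather than sampled tokens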
91static bool server_task_type_need_embd(server_task_type task_type) {
92 switch (task_type) {
93 case SERVER_TASK_TYPE_EMBEDDING:
94 case SERVER_TASK_TYPE_RERANK:
95 return true;
96 default:
97 return false;
98 }
99}
100
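// task types that require output logits for sampling (completion, infill)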
101static bool server_task_type_need_logits(server_task_type task_type) {
102 switch (task_type) {
103 case SERVER_TASK_TYPE_COMPLETION:
104 case SERVER_TASK_TYPE_INFILL:
105 return true;
106 default:
107 return false;
108 }
109}
110
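// per-request generation parameters; one instance is attached to each inference task and its slot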
111struct slot_params {
112 bool stream = true;
113 bool include_usage = false;
    bool cache_prompt   = true;  // remember the prompt to avoid reprocessing the entire prompt
115 bool return_tokens = false;
116 bool return_progress = false;
117
118 int32_t n_keep = 0; // number of tokens to keep from initial prompt
119 int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
120 int32_t n_predict = -1; // new tokens to predict
121 int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
122
123 int64_t t_max_prompt_ms = -1; // TODO: implement
124 int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
125
126 std::vector<common_adapter_lora_info> lora;
127
128 std::vector<std::string> antiprompt;
129 std::vector<std::string> response_fields;
130 bool timings_per_token = false;
131 bool post_sampling_probs = false;
132
133 struct common_params_sampling sampling;
134 struct common_params_speculative speculative;
135
136 // OAI-compat fields
137 bool verbose = false;
138 oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
139 std::string oaicompat_model;
140 std::string oaicompat_cmpl_id;
141 common_chat_syntax oaicompat_chat_syntax;
142
143 // Embeddings
144 int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
145
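    // serialize the effective settings (used e.g. for the "generation_settings" field of responses);
    // only_metrics omits potentially large entries such as stop strings, grammar and logit_bias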
    json to_json(bool only_metrics = false) const {
        std::vector<std::string> samplers;
        samplers.reserve(sampling.samplers.size());
        for (const auto & sampler : sampling.samplers) {
            samplers.emplace_back(common_sampler_type_to_str(sampler));
        }

        json lora = json::array();
        for (size_t i = 0; i < this->lora.size(); ++i) {
            lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
        }
157
158 if (only_metrics) {
159 return json {
160 {"seed", sampling.seed},
161 {"temperature", sampling.temp},
162 {"dynatemp_range", sampling.dynatemp_range},
163 {"dynatemp_exponent", sampling.dynatemp_exponent},
164 {"top_k", sampling.top_k},
165 {"top_p", sampling.top_p},
166 {"min_p", sampling.min_p},
167 {"top_n_sigma", sampling.top_n_sigma},
168 {"xtc_probability", sampling.xtc_probability},
169 {"xtc_threshold", sampling.xtc_threshold},
170 {"typical_p", sampling.typ_p},
171 {"repeat_last_n", sampling.penalty_last_n},
172 {"repeat_penalty", sampling.penalty_repeat},
173 {"presence_penalty", sampling.penalty_present},
174 {"frequency_penalty", sampling.penalty_freq},
175 {"dry_multiplier", sampling.dry_multiplier},
176 {"dry_base", sampling.dry_base},
177 {"dry_allowed_length", sampling.dry_allowed_length},
178 {"dry_penalty_last_n", sampling.dry_penalty_last_n},
179 {"mirostat", sampling.mirostat},
180 {"mirostat_tau", sampling.mirostat_tau},
181 {"mirostat_eta", sampling.mirostat_eta},
182 {"max_tokens", n_predict},
183 {"n_predict", n_predict}, // TODO: deduplicate?
184 {"n_keep", n_keep},
185 {"n_discard", n_discard},
186 {"ignore_eos", sampling.ignore_eos},
187 {"stream", stream},
188 {"n_probs", sampling.n_probs},
189 {"min_keep", sampling.min_keep},
                {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
                {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
192 {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
193 {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
194 {"samplers", samplers},
195 {"speculative.n_max", speculative.n_max},
196 {"speculative.n_min", speculative.n_min},
197 {"speculative.p_min", speculative.p_min},
198 {"timings_per_token", timings_per_token},
199 {"post_sampling_probs", post_sampling_probs},
200 {"lora", lora},
201 };
202 }
203
        auto grammar_triggers = json::array();
        for (const auto & trigger : sampling.grammar_triggers) {
            server_grammar_trigger ct(trigger);
            grammar_triggers.push_back(ct.to_json());
        }
209
210 return json {
211 {"seed", sampling.seed},
212 {"temperature", sampling.temp},
213 {"dynatemp_range", sampling.dynatemp_range},
214 {"dynatemp_exponent", sampling.dynatemp_exponent},
215 {"top_k", sampling.top_k},
216 {"top_p", sampling.top_p},
217 {"min_p", sampling.min_p},
218 {"top_n_sigma", sampling.top_n_sigma},
219 {"xtc_probability", sampling.xtc_probability},
220 {"xtc_threshold", sampling.xtc_threshold},
221 {"typical_p", sampling.typ_p},
222 {"repeat_last_n", sampling.penalty_last_n},
223 {"repeat_penalty", sampling.penalty_repeat},
224 {"presence_penalty", sampling.penalty_present},
225 {"frequency_penalty", sampling.penalty_freq},
226 {"dry_multiplier", sampling.dry_multiplier},
227 {"dry_base", sampling.dry_base},
228 {"dry_allowed_length", sampling.dry_allowed_length},
229 {"dry_penalty_last_n", sampling.dry_penalty_last_n},
230 {"dry_sequence_breakers", sampling.dry_sequence_breakers},
231 {"mirostat", sampling.mirostat},
232 {"mirostat_tau", sampling.mirostat_tau},
233 {"mirostat_eta", sampling.mirostat_eta},
234 {"stop", antiprompt},
235 {"max_tokens", n_predict},
236 {"n_predict", n_predict}, // TODO: deduplicate?
237 {"n_keep", n_keep},
238 {"n_discard", n_discard},
239 {"ignore_eos", sampling.ignore_eos},
240 {"stream", stream},
            {"logit_bias", format_logit_bias(sampling.logit_bias)},
242 {"n_probs", sampling.n_probs},
243 {"min_keep", sampling.min_keep},
244 {"grammar", sampling.grammar},
245 {"grammar_lazy", sampling.grammar_lazy},
246 {"grammar_triggers", grammar_triggers},
247 {"preserved_tokens", sampling.preserved_tokens},
            {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
            {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
250 {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
251 {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
252 {"samplers", samplers},
253 {"speculative.n_max", speculative.n_max},
254 {"speculative.n_min", speculative.n_min},
255 {"speculative.p_min", speculative.p_min},
256 {"timings_per_token", timings_per_token},
257 {"post_sampling_probs", post_sampling_probs},
258 {"lora", lora},
259 };
260 }
261};
262
263struct server_task {
264 int id = -1; // to be filled by server_queue
265 int index = -1; // used when there are multiple prompts (batch request)
266
267 // used by SERVER_TASK_TYPE_CANCEL
268 int id_target = -1;
269 int id_slot = -1;
270
271 // used by SERVER_TASK_TYPE_INFERENCE
272 slot_params params;
273 server_tokens tokens;
274
275 server_task_type type;
276
277 // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
278 struct slot_action {
279 int slot_id;
280 std::string filename;
281 std::string filepath;
282 };
283 slot_action slot_action;
284
285 // used by SERVER_TASK_TYPE_METRICS
286 bool metrics_reset_bucket = false;
287
288 // used by SERVER_TASK_TYPE_SET_LORA
289 std::vector<common_adapter_lora_info> set_lora;
290
291 server_task() = default;
292
293 server_task(server_task_type type) : type(type) {}
294
295 int32_t n_tokens() const {
296 return tokens.size();
297 }
298
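    // build slot_params from the JSON body of a completion request, falling back to the
    // server-wide defaults in params_base for anything the request does not specify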
299 static slot_params params_from_json_cmpl(
300 const llama_context * ctx,
301 const common_params & params_base,
302 const json & data) {
303 const llama_model * model = llama_get_model(ctx);
304 const llama_vocab * vocab = llama_model_get_vocab(model);
305
306 slot_params params;
307
308 // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
309 slot_params defaults;
310 defaults.sampling = params_base.sampling;
311 defaults.speculative = params_base.speculative;
312 defaults.n_keep = params_base.n_keep;
313 defaults.n_predict = params_base.n_predict;
314 defaults.antiprompt = params_base.antiprompt;
315
316 // enabling this will output extra debug information in the HTTP responses from the server
317 params.verbose = params_base.verbosity > 9;
        params.timings_per_token = json_value(data, "timings_per_token", false);

        params.stream           = json_value(data, "stream", false);
        auto stream_opt         = json_value(data, "stream_options", json::object());
        params.include_usage    = json_value(stream_opt, "include_usage", false);
        params.cache_prompt     = json_value(data, "cache_prompt", true);
        params.return_tokens    = json_value(data, "return_tokens", false);
        params.return_progress  = json_value(data, "return_progress", false);
        params.n_predict        = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
        params.n_indent         = json_value(data, "n_indent", defaults.n_indent);
        params.n_keep           = json_value(data, "n_keep", defaults.n_keep);
        params.n_discard        = json_value(data, "n_discard", defaults.n_discard);
      //params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
        params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
        params.response_fields  = json_value(data, "response_fields", std::vector<std::string>());
333
        params.sampling.top_k              = json_value(data, "top_k",              defaults.sampling.top_k);
        params.sampling.top_p              = json_value(data, "top_p",              defaults.sampling.top_p);
        params.sampling.min_p              = json_value(data, "min_p",              defaults.sampling.min_p);
        params.sampling.top_n_sigma        = json_value(data, "top_n_sigma",        defaults.sampling.top_n_sigma);
        params.sampling.xtc_probability    = json_value(data, "xtc_probability",    defaults.sampling.xtc_probability);
        params.sampling.xtc_threshold      = json_value(data, "xtc_threshold",      defaults.sampling.xtc_threshold);
        params.sampling.typ_p              = json_value(data, "typical_p",          defaults.sampling.typ_p);
        params.sampling.temp               = json_value(data, "temperature",        defaults.sampling.temp);
        params.sampling.dynatemp_range     = json_value(data, "dynatemp_range",     defaults.sampling.dynatemp_range);
        params.sampling.dynatemp_exponent  = json_value(data, "dynatemp_exponent",  defaults.sampling.dynatemp_exponent);
        params.sampling.penalty_last_n     = json_value(data, "repeat_last_n",      defaults.sampling.penalty_last_n);
        params.sampling.penalty_repeat     = json_value(data, "repeat_penalty",     defaults.sampling.penalty_repeat);
        params.sampling.penalty_freq       = json_value(data, "frequency_penalty",  defaults.sampling.penalty_freq);
        params.sampling.penalty_present    = json_value(data, "presence_penalty",   defaults.sampling.penalty_present);
        params.sampling.dry_multiplier     = json_value(data, "dry_multiplier",     defaults.sampling.dry_multiplier);
        params.sampling.dry_base           = json_value(data, "dry_base",           defaults.sampling.dry_base);
        params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
        params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
        params.sampling.mirostat           = json_value(data, "mirostat",           defaults.sampling.mirostat);
        params.sampling.mirostat_tau       = json_value(data, "mirostat_tau",       defaults.sampling.mirostat_tau);
        params.sampling.mirostat_eta       = json_value(data, "mirostat_eta",       defaults.sampling.mirostat_eta);
        params.sampling.seed               = json_value(data, "seed",               defaults.sampling.seed);
        params.sampling.n_probs            = json_value(data, "n_probs",            defaults.sampling.n_probs);
        params.sampling.min_keep           = json_value(data, "min_keep",           defaults.sampling.min_keep);
        params.post_sampling_probs         = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);

        params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
        params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
        params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);

        params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
        params.speculative.n_min = std::max(params.speculative.n_min, 0);
        params.speculative.n_max = std::max(params.speculative.n_max, 0);

        // Use OpenAI API logprobs only if n_probs wasn't provided
        if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs) {
            params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
        }
372
        if (data.contains("lora")) {
            if (data.at("lora").is_array()) {
                params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
            } else {
                throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
            }
        } else {
            params.lora = params_base.lora_adapters;
        }
382
383 // TODO: add more sanity checks for the input parameters
384
385 if (params.sampling.penalty_last_n < -1) {
386 throw std::runtime_error("Error: repeat_last_n must be >= -1");
387 }
388
389 if (params.sampling.dry_penalty_last_n < -1) {
390 throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
391 }
392
393 if (params.sampling.penalty_last_n == -1) {
394 // note: should be the slot's context and not the full context, but it's ok
395 params.sampling.penalty_last_n = llama_n_ctx(ctx);
396 }
397
398 if (params.sampling.dry_penalty_last_n == -1) {
399 params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
400 }
401
402 if (params.sampling.dry_base < 1.0f) {
403 params.sampling.dry_base = defaults.sampling.dry_base;
404 }
405
406 // sequence breakers for DRY
407 {
408 // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
409 // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
410
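            // accepted format example (hypothetical values): "dry_sequence_breakers": ["\n", ":", "\"", "*"]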
            if (data.contains("dry_sequence_breakers")) {
                params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
                if (params.sampling.dry_sequence_breakers.empty()) {
                    throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
                }
416 }
417 }
418
        // process "json_schema" and "grammar"
        if (data.contains("json_schema") && !data.contains("grammar")) {
            try {
                auto schema = json_value(data, "json_schema", json::object());
                SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
                params.sampling.grammar = json_schema_to_grammar(schema);
                SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
            } catch (const std::exception & e) {
                throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
            }
        } else {
            params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
            SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
            params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
            SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
        }
435
        {
            auto it = data.find("chat_format");
            if (it != data.end()) {
                params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
                SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format));
            } else {
                params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
            }
            common_reasoning_format reasoning_format = params_base.reasoning_format;
            if (data.contains("reasoning_format")) {
                reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
            }
            params.oaicompat_chat_syntax.reasoning_format = reasoning_format;
            params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
            params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
            params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
        }
453
454 {
            const auto preserved_tokens = data.find("preserved_tokens");
            if (preserved_tokens != data.end()) {
                for (const auto & t : *preserved_tokens) {
                    auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
                    if (ids.size() == 1) {
                        SRV_DBG("Preserved token: %d\n", ids[0]);
                        params.sampling.preserved_tokens.insert(ids[0]);
462 } else {
463 // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
464 SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
465 }
466 }
467 }
            const auto grammar_triggers = data.find("grammar_triggers");
            if (grammar_triggers != data.end()) {
                for (const auto & t : *grammar_triggers) {
                    server_grammar_trigger ct(t);
                    if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
                        const auto & word = ct.value.value;
                        auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
                        if (ids.size() == 1) {
                            auto token = ids[0];
                            if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
                                throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
                            }
                            SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
                            common_grammar_trigger trigger;
                            trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
                            trigger.value = word;
                            trigger.token = token;
                            params.sampling.grammar_triggers.push_back(std::move(trigger));
                        } else {
                            SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
                            params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
                        }
                    } else {
                        if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
                            SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
                        } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
                            SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
                        } else {
                            throw std::runtime_error("Unknown grammar trigger type");
                        }
                        params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
499 }
500 }
501 }
502 if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
503 throw std::runtime_error("Error: no triggers set for lazy grammar!");
504 }
505 }
506
507 {
508 params.sampling.logit_bias.clear();
509
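            // accepted request formats (tokens may be given as ids or as text; `false` means -inf), example values:
            //   "logit_bias": [[15043, 1.0], ["Hello", -0.5], [42, false]]
            //   "logit_bias": {"15043": 1.0, "Hello": false}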
            const auto & logit_bias = data.find("logit_bias");
            if (logit_bias != data.end() && logit_bias->is_array()) {
512 const int n_vocab = llama_vocab_n_tokens(vocab);
513 for (const auto & el : *logit_bias) {
514 // TODO: we may want to throw errors here, in case "el" is incorrect
515 if (el.is_array() && el.size() == 2) {
516 float bias;
517 if (el[1].is_number()) {
518 bias = el[1].get<float>();
519 } else if (el[1].is_boolean() && !el[1].get<bool>()) {
520 bias = -INFINITY;
521 } else {
522 continue;
523 }
524
                        if (el[0].is_number_integer()) {
                            llama_token tok = el[0].get<llama_token>();
                            if (tok >= 0 && tok < n_vocab) {
                                params.sampling.logit_bias.push_back({tok, bias});
                            }
                        } else if (el[0].is_string()) {
                            auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
                            for (auto tok : toks) {
                                params.sampling.logit_bias.push_back({tok, bias});
                            }
                        }
536 }
537 }
538 } else if (logit_bias != data.end() && logit_bias->is_object()) {
539 const int n_vocab = llama_vocab_n_tokens(vocab);
540 for (const auto & el : logit_bias->items()) {
541 float bias;
542 const auto & key = el.key();
543 const auto & value = el.value();
544 if (value.is_number()) {
545 bias = value.get<float>();
546 } else if (value.is_boolean() && !value.get<bool>()) {
547 bias = -INFINITY;
548 } else {
549 continue;
550 }
551
                    char * end;
                    llama_token tok = strtol(key.c_str(), &end, 10);
                    if (*end == 0) {
                        if (tok >= 0 && tok < n_vocab) {
                            params.sampling.logit_bias.push_back({tok, bias});
                        }
                    } else {
                        auto toks = common_tokenize(vocab, key, false);
                        for (auto tok : toks) {
                            params.sampling.logit_bias.push_back({tok, bias});
                        }
563 }
564 }
565 }
566
            params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
            if (params.sampling.ignore_eos) {
                params.sampling.logit_bias.insert(
                        params.sampling.logit_bias.end(),
                        defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
572 }
573 }
574
575 {
576 params.antiprompt.clear();
577
            const auto & stop = data.find("stop");
            if (stop != data.end() && stop->is_array()) {
                for (const auto & word : *stop) {
                    if (!word.empty()) {
                        params.antiprompt.push_back(word);
                    }
                }
585 }
586 // set reverse prompt from cli args if not set in the request
587 if (params.antiprompt.empty()) {
588 params.antiprompt = defaults.antiprompt;
589 }
590 }
591
592 {
            const auto samplers = data.find("samplers");
            if (samplers != data.end()) {
                if (samplers->is_array()) {
                    params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
                } else if (samplers->is_string()) {
                    params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
                }
600 } else {
601 params.sampling.samplers = defaults.sampling.samplers;
602 }
603 }
604
        std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
        params.oaicompat_model = json_value(data, "model", model_name);
607
608 return params;
609 }
610
611 // utility function
    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
        std::unordered_set<int> ids(tasks.size());
        for (size_t i = 0; i < tasks.size(); i++) {
            ids.insert(tasks[i].id);
        }
        return ids;
    }
619};
620
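// timing information attached to completion results (exposed as the "timings" field of responses)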
621struct result_timings {
622 int32_t cache_n = -1;
623
624 int32_t prompt_n = -1;
625 double prompt_ms;
626 double prompt_per_token_ms;
627 double prompt_per_second;
628
629 int32_t predicted_n = -1;
630 double predicted_ms;
631 double predicted_per_token_ms;
632 double predicted_per_second;
633
634 // Optional speculative metrics - only included when > 0
635 int32_t draft_n = 0;
636 int32_t draft_n_accepted = 0;
637
638 json to_json() const {
639 json base = {
640 {"cache_n", cache_n},
641
642 {"prompt_n", prompt_n},
643 {"prompt_ms", prompt_ms},
644 {"prompt_per_token_ms", prompt_per_token_ms},
645 {"prompt_per_second", prompt_per_second},
646
647 {"predicted_n", predicted_n},
648 {"predicted_ms", predicted_ms},
649 {"predicted_per_token_ms", predicted_per_token_ms},
650 {"predicted_per_second", predicted_per_second},
651 };
652
653 if (draft_n > 0) {
654 base["draft_n"] = draft_n;
655 base["draft_n_accepted"] = draft_n_accepted;
656 }
657
658 return base;
659 }
660};
661
662struct result_prompt_progress {
663 int32_t total = 0;
664 int32_t cache = 0;
665 int32_t processed = 0;
666 int64_t time_ms = 0;
667
668 json to_json() const {
669 return json {
670 {"total", total},
671 {"cache", cache},
672 {"processed", processed},
673 {"time_ms", time_ms},
674 };
675 }
676};
677
678struct server_task_result {
679 int id = -1;
680 int id_slot = -1;
681 virtual bool is_error() {
682 // only used by server_task_result_error
683 return false;
684 }
685 virtual bool is_stop() {
686 // only used by server_task_result_cmpl_*
687 return false;
688 }
689 virtual int get_index() {
690 return -1;
691 }
692 virtual json to_json() = 0;
693 virtual ~server_task_result() = default;
694};
695
// using unique_ptr for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
698
699static inline std::string stop_type_to_str(stop_type type) {
700 switch (type) {
701 case STOP_TYPE_EOS: return "eos";
702 case STOP_TYPE_WORD: return "word";
703 case STOP_TYPE_LIMIT: return "limit";
704 default: return "none";
705 }
706}
707
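// a single generated token together with its probability and the top candidate probabilities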
708struct completion_token_output {
709 llama_token tok;
710 float prob;
711 std::string text_to_send;
712 struct prob_info {
713 llama_token tok;
714 std::string txt;
715 float prob;
716 };
717 std::vector<prob_info> probs;
718
719 json to_json(bool post_sampling_probs) const {
720 json probs_for_token = json::array();
721 for (const auto & p : probs) {
            std::string txt(p.txt);
            txt.resize(validate_utf8(txt));
            probs_for_token.push_back(json {
                {"id",    p.tok},
                {"token", txt},
                {"bytes", str_to_bytes(p.txt)},
                {
                    post_sampling_probs ? "prob" : "logprob",
                    post_sampling_probs ? p.prob : logarithm(p.prob)
                },
732 });
733 }
734 return probs_for_token;
735 }
736
737 static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
738 json out = json::array();
739 for (const auto & p : probs) {
            std::string txt(p.text_to_send);
            txt.resize(validate_utf8(txt));
            out.push_back(json {
                {"id",    p.tok},
                {"token", txt},
                {"bytes", str_to_bytes(p.text_to_send)},
                {
                    post_sampling_probs ? "prob" : "logprob",
                    post_sampling_probs ? p.prob : logarithm(p.prob)
                },
                {
                    post_sampling_probs ? "top_probs" : "top_logprobs",
                    p.to_json(post_sampling_probs)
753 },
754 });
755 }
756 return out;
757 }
758
    static float logarithm(float x) {
        // nlohmann::json converts -inf to null, so we need to prevent that
        return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
    }

    static std::vector<unsigned char> str_to_bytes(const std::string & str) {
        std::vector<unsigned char> bytes;
        for (unsigned char c : str) {
            bytes.push_back(c);
768 }
769 return bytes;
770 }
771};
772
773struct server_task_result_cmpl_final : server_task_result {
774 int index = 0;
775
776 std::string content;
777 llama_tokens tokens;
778
779 bool stream;
780 bool include_usage;
781 result_timings timings;
782 std::string prompt;
783
784 bool truncated;
785 int32_t n_decoded;
786 int32_t n_prompt_tokens;
787 int32_t n_tokens_cached;
788 bool has_new_line;
789 std::string stopping_word;
790 stop_type stop = STOP_TYPE_NONE;
791
792 bool post_sampling_probs;
793 std::vector<completion_token_output> probs_output;
794 std::vector<std::string> response_fields;
795
796 slot_params generation_params;
797
798 // OAI-compat fields
799 bool verbose = false;
800 oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
801 std::string oaicompat_model;
802 std::string oaicompat_cmpl_id;
803 common_chat_msg oaicompat_msg;
804
805 std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
806
807 virtual int get_index() override {
808 return index;
809 }
810
811 virtual bool is_stop() override {
812 return true; // in stream mode, final responses are considered stop
813 }
814
815 virtual json to_json() override {
816 switch (oaicompat) {
817 case OAICOMPAT_TYPE_NONE:
818 return to_json_non_oaicompat();
819 case OAICOMPAT_TYPE_COMPLETION:
820 return to_json_oaicompat();
821 case OAICOMPAT_TYPE_CHAT:
822 return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
823 default:
824 GGML_ASSERT(false && "Invalid oaicompat_type");
825 }
826 }
827
828 json to_json_non_oaicompat() {
829 json res = json {
830 {"index", index},
831 {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
832 {"tokens", stream ? llama_tokens {} : tokens},
833 {"id_slot", id_slot},
834 {"stop", true},
835 {"model", oaicompat_model},
836 {"tokens_predicted", n_decoded},
837 {"tokens_evaluated", n_prompt_tokens},
838 {"generation_settings", generation_params.to_json()},
839 {"prompt", prompt},
840 {"has_new_line", has_new_line},
841 {"truncated", truncated},
            {"stop_type", stop_type_to_str(stop)},
843 {"stopping_word", stopping_word},
844 {"tokens_cached", n_tokens_cached},
845 {"timings", timings.to_json()},
846 };
        if (!stream && !probs_output.empty()) {
            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
        }
        return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
851 }
852
    json to_json_oaicompat() {
        std::time_t t = std::time(0);
        json logprobs = json(nullptr); // OAI default to null
        if (!stream && probs_output.size() > 0) {
            logprobs = json{
                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
859 };
860 }
861 json finish_reason = "length";
862 if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
863 finish_reason = "stop";
864 }
865 json res = json {
            {"choices", json::array({
867 json{
868 {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
869 {"index", index},
870 {"logprobs", logprobs},
871 {"finish_reason", finish_reason},
872 }
873 })},
874 {"created", t},
875 {"model", oaicompat_model},
876 {"system_fingerprint", build_info},
877 {"object", "text_completion"},
878 {"usage", json {
879 {"completion_tokens", n_decoded},
880 {"prompt_tokens", n_prompt_tokens},
881 {"total_tokens", n_decoded + n_prompt_tokens}
882 }},
883 {"id", oaicompat_cmpl_id}
884 };
885
886 // extra fields for debugging purposes
887 if (verbose) {
888 res["__verbose"] = to_json_non_oaicompat();
889 }
890 if (timings.prompt_n >= 0) {
            res.push_back({"timings", timings.to_json()});
892 }
893
894 return res;
895 }
896
897 json to_json_oaicompat_chat() {
898 std::string finish_reason = "length";
899 common_chat_msg msg;
900 if (!oaicompat_msg.empty()) {
901 msg = oaicompat_msg;
902 } else {
903 msg.role = "assistant";
904 msg.content = content;
905 }
906 if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
907 finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
908 }
909
910 json choice {
911 {"finish_reason", finish_reason},
912 {"index", 0},
913 {"message", msg.to_json_oaicompat<json>()},
914 };
915
        if (!stream && probs_output.size() > 0) {
            choice["logprobs"] = json{
                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
919 };
920 }
921
        std::time_t t = std::time(0);

        json res = json {
            {"choices", json::array({choice})},
926 {"created", t},
927 {"model", oaicompat_model},
928 {"system_fingerprint", build_info},
929 {"object", "chat.completion"},
930 {"usage", json {
931 {"completion_tokens", n_decoded},
932 {"prompt_tokens", n_prompt_tokens},
933 {"total_tokens", n_decoded + n_prompt_tokens}
934 }},
935 {"id", oaicompat_cmpl_id}
936 };
937
938 // extra fields for debugging purposes
939 if (verbose) {
940 res["__verbose"] = to_json_non_oaicompat();
941 }
942 if (timings.prompt_n >= 0) {
            res.push_back({"timings", timings.to_json()});
944 }
945
946 return res;
947 }
948
949 json to_json_oaicompat_chat_stream() {
        std::time_t t = std::time(0);
951 std::string finish_reason = "length";
952 if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
953 finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
954 }
955
        json deltas = json::array();
        for (const auto & diff : oaicompat_msg_diffs) {
            deltas.push_back({
                {"choices", json::array({
960 json {
961 {"finish_reason", nullptr},
962 {"index", 0},
963 {"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
964 },
965 })},
966 {"created", t},
967 {"id", oaicompat_cmpl_id},
968 {"model", oaicompat_model},
969 {"system_fingerprint", build_info},
970 {"object", "chat.completion.chunk"},
971 });
972 }
973
        deltas.push_back({
            {"choices", json::array({
976 json {
977 {"finish_reason", finish_reason},
978 {"index", 0},
979 {"delta", json::object()},
980 },
981 })},
982 {"created", t},
983 {"id", oaicompat_cmpl_id},
984 {"model", oaicompat_model},
985 {"system_fingerprint", build_info},
986 {"object", "chat.completion.chunk"},
987 });
988
989 if (include_usage) {
990 // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
991 // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
            deltas.push_back({
                {"choices", json::array()},
994 {"created", t},
995 {"id", oaicompat_cmpl_id},
996 {"model", oaicompat_model},
997 {"system_fingerprint", build_info},
998 {"object", "chat.completion.chunk"},
999 {"usage", json {
1000 {"completion_tokens", n_decoded},
1001 {"prompt_tokens", n_prompt_tokens},
1002 {"total_tokens", n_decoded + n_prompt_tokens},
1003 }},
1004 });
1005 }
1006
1007 if (timings.prompt_n >= 0) {
            deltas.back().push_back({"timings", timings.to_json()});
1009 }
1010
1011 // extra fields for debugging purposes
1012 if (verbose && !deltas.empty()) {
1013 deltas.front()["__verbose"] = to_json_non_oaicompat();
1014 }
1015
1016 return deltas;
1017 }
1018};
1019
1020struct server_task_result_cmpl_partial : server_task_result {
1021 int index = 0;
1022
1023 std::string content;
1024 llama_tokens tokens;
1025
1026 int32_t n_decoded;
1027 int32_t n_prompt_tokens;
1028
1029 bool post_sampling_probs;
1030 bool is_progress = false;
1031 completion_token_output prob_output;
1032 result_timings timings;
1033 result_prompt_progress progress;
1034
1035 // OAI-compat fields
1036 bool verbose = false;
1037 oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
1038 std::string oaicompat_model;
1039 std::string oaicompat_cmpl_id;
1040 std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
1041
1042 virtual int get_index() override {
1043 return index;
1044 }
1045
1046 virtual bool is_stop() override {
1047 return false; // in stream mode, partial responses are not considered stop
1048 }
1049
1050 virtual json to_json() override {
1051 switch (oaicompat) {
1052 case OAICOMPAT_TYPE_NONE:
1053 return to_json_non_oaicompat();
1054 case OAICOMPAT_TYPE_COMPLETION:
1055 return to_json_oaicompat();
1056 case OAICOMPAT_TYPE_CHAT:
1057 return to_json_oaicompat_chat();
1058 default:
1059 GGML_ASSERT(false && "Invalid oaicompat_type");
1060 }
1061 }
1062
1063 json to_json_non_oaicompat() {
1064 // non-OAI-compat JSON
1065 json res = json {
1066 {"index", index},
1067 {"content", content},
1068 {"tokens", tokens},
1069 {"stop", false},
1070 {"id_slot", id_slot},
1071 {"tokens_predicted", n_decoded},
1072 {"tokens_evaluated", n_prompt_tokens},
1073 };
1074 // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
        if (timings.prompt_n > 0) {
            res.push_back({"timings", timings.to_json()});
        }
        if (is_progress) {
            res.push_back({"prompt_progress", progress.to_json()});
        }
        if (!prob_output.probs.empty()) {
            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
1083 }
1084 return res;
1085 }
1086
    json to_json_oaicompat() {
        std::time_t t = std::time(0);
        json logprobs = json(nullptr); // OAI default to null
        if (prob_output.probs.size() > 0) {
            logprobs = json{
                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
            };
        }
        json res = json {
            {"choices", json::array({
1097 json{
1098 {"text", content},
1099 {"index", index},
1100 {"logprobs", logprobs},
1101 {"finish_reason", nullptr},
1102 }
1103 })},
1104 {"created", t},
1105 {"model", oaicompat_model},
1106 {"system_fingerprint", build_info},
1107 {"object", "text_completion"},
1108 {"id", oaicompat_cmpl_id}
1109 };
1110
1111 // extra fields for debugging purposes
1112 if (verbose) {
1113 res["__verbose"] = to_json_non_oaicompat();
1114 }
        if (timings.prompt_n >= 0) {
            res.push_back({"timings", timings.to_json()});
        }
        if (is_progress) {
            res.push_back({"prompt_progress", progress.to_json()});
        }
1121
1122 return res;
1123 }
1124
1125 json to_json_oaicompat_chat() {
        bool first = n_decoded == 1;
        std::time_t t = std::time(0);
        json choices;

        std::vector<json> deltas;
        auto add_delta = [&](const json & delta) {
            deltas.push_back({
                {"choices", json::array({
1134 json {
1135 {"finish_reason", nullptr},
1136 {"index", 0},
1137 {"delta", delta},
1138 },
1139 })},
1140 {"created", t},
1141 {"id", oaicompat_cmpl_id},
1142 {"model", oaicompat_model},
1143 {"system_fingerprint", build_info},
1144 {"object", "chat.completion.chunk"},
1145 });
1146 };
1147 // We have to send an initial update to conform to openai behavior
1148 if (first || is_progress) {
1149 add_delta({
1150 {"role", "assistant"},
1151 {"content", nullptr},
1152 });
1153 }
1154
1155 for (const auto & diff : oaicompat_msg_diffs) {
1156 add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
1157 }
1158
1159 if (!deltas.empty()) {
1160 auto & last_json = deltas[deltas.size() - 1];
1161 GGML_ASSERT(last_json.at("choices").size() >= 1);
1162
            if (prob_output.probs.size() > 0) {
                last_json.at("choices").at(0)["logprobs"] = json {
                    {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
                };
            }

            if (timings.prompt_n >= 0) {
                last_json.push_back({"timings", timings.to_json()});
            }
            if (is_progress) {
                last_json.push_back({"prompt_progress", progress.to_json()});
1174 }
1175 }
1176
1177 return deltas;
1178 }
1179};
1180
1181struct server_task_result_embd : server_task_result {
1182 int index = 0;
1183 std::vector<std::vector<float>> embedding;
1184
1185 int32_t n_tokens;
1186
1187 // OAI-compat fields
1188 oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
1189
1190 virtual int get_index() override {
1191 return index;
1192 }
1193
1194 virtual json to_json() override {
1195 return oaicompat == OAICOMPAT_TYPE_EMBEDDING
1196 ? to_json_oaicompat()
1197 : to_json_non_oaicompat();
1198 }
1199
1200 json to_json_non_oaicompat() {
1201 return json {
1202 {"index", index},
1203 {"embedding", embedding},
1204 };
1205 }
1206
1207 json to_json_oaicompat() {
1208 return json {
1209 {"index", index},
1210 {"embedding", embedding[0]},
1211 {"tokens_evaluated", n_tokens},
1212 };
1213 }
1214};
1215
1216struct server_task_result_rerank : server_task_result {
1217 int index = 0;
1218 float score = -1e6;
1219
1220 int32_t n_tokens;
1221
1222 virtual int get_index() override {
1223 return index;
1224 }
1225
1226 virtual json to_json() override {
1227 return json {
1228 {"index", index},
1229 {"score", score},
1230 {"tokens_evaluated", n_tokens},
1231 };
1232 }
1233};
1234
// this function may also be used outside of server_task_result_error
1236static json format_error_response(const std::string & message, const enum error_type type) {
1237 std::string type_str;
1238 int code = 500;
1239 switch (type) {
1240 case ERROR_TYPE_INVALID_REQUEST:
1241 type_str = "invalid_request_error";
1242 code = 400;
1243 break;
1244 case ERROR_TYPE_AUTHENTICATION:
1245 type_str = "authentication_error";
1246 code = 401;
1247 break;
1248 case ERROR_TYPE_NOT_FOUND:
1249 type_str = "not_found_error";
1250 code = 404;
1251 break;
1252 case ERROR_TYPE_SERVER:
1253 type_str = "server_error";
1254 code = 500;
1255 break;
1256 case ERROR_TYPE_PERMISSION:
1257 type_str = "permission_error";
1258 code = 403;
1259 break;
1260 case ERROR_TYPE_NOT_SUPPORTED:
1261 type_str = "not_supported_error";
1262 code = 501;
1263 break;
1264 case ERROR_TYPE_UNAVAILABLE:
1265 type_str = "unavailable_error";
1266 code = 503;
1267 break;
1268 case ERROR_TYPE_EXCEED_CONTEXT_SIZE:
1269 type_str = "exceed_context_size_error";
1270 code = 400;
1271 break;
1272 }
1273 return json {
1274 {"code", code},
1275 {"message", message},
1276 {"type", type_str},
1277 };
1278}
1279
1280struct server_task_result_error : server_task_result {
1281 int index = 0;
1282 error_type err_type = ERROR_TYPE_SERVER;
1283 std::string err_msg;
1284
1285 // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
1286 int32_t n_prompt_tokens = 0;
1287 int32_t n_ctx = 0;
1288
1289 virtual bool is_error() override {
1290 return true;
1291 }
1292
1293 virtual json to_json() override {
        json res = format_error_response(err_msg, err_type);
1295 if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
1296 res["n_prompt_tokens"] = n_prompt_tokens;
1297 res["n_ctx"] = n_ctx;
1298 }
1299 return res;
1300 }
1301};
1302
1303struct server_task_result_metrics : server_task_result {
1304 int n_idle_slots;
1305 int n_processing_slots;
1306 int n_tasks_deferred;
1307 int64_t t_start;
1308
1309 // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
1310 uint64_t n_prompt_tokens_processed_total = 0;
1311 uint64_t t_prompt_processing_total = 0;
1312 uint64_t n_tokens_predicted_total = 0;
1313 uint64_t t_tokens_generation_total = 0;
1314
1315 uint64_t n_tokens_max = 0;
1316
1317 uint64_t n_prompt_tokens_processed = 0;
1318 uint64_t t_prompt_processing = 0;
1319
1320 uint64_t n_tokens_predicted = 0;
1321 uint64_t t_tokens_generation = 0;
1322
1323 uint64_t n_decode_total = 0;
1324 uint64_t n_busy_slots_total = 0;
1325
1326 // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
1327 // therefore, we use json to temporarily store the slot.to_json() result
1328 json slots_data = json::array();
1329
1330 virtual json to_json() override {
1331 return json {
1332 { "idle", n_idle_slots },
1333 { "processing", n_processing_slots },
1334 { "deferred", n_tasks_deferred },
1335 { "t_start", t_start },
1336
1337 { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
1338 { "t_tokens_generation_total", t_tokens_generation_total },
1339 { "n_tokens_predicted_total", n_tokens_predicted_total },
1340 { "t_prompt_processing_total", t_prompt_processing_total },
1341
1342 { "n_tokens_max", n_tokens_max },
1343
1344 { "n_prompt_tokens_processed", n_prompt_tokens_processed },
1345 { "t_prompt_processing", t_prompt_processing },
1346 { "n_tokens_predicted", n_tokens_predicted },
1347 { "t_tokens_generation", t_tokens_generation },
1348
1349 { "n_decode_total", n_decode_total },
1350 { "n_busy_slots_total", n_busy_slots_total },
1351
1352 { "slots", slots_data },
1353 };
1354 }
1355};
1356
1357struct server_task_result_slot_save_load : server_task_result {
1358 std::string filename;
1359 bool is_save; // true = save, false = load
1360
1361 size_t n_tokens;
1362 size_t n_bytes;
1363 double t_ms;
1364
1365 virtual json to_json() override {
1366 if (is_save) {
1367 return json {
1368 { "id_slot", id_slot },
1369 { "filename", filename },
1370 { "n_saved", n_tokens },
1371 { "n_written", n_bytes },
1372 { "timings", {
1373 { "save_ms", t_ms }
1374 }},
1375 };
1376 }
1377
1378 return json {
1379 { "id_slot", id_slot },
1380 { "filename", filename },
1381 { "n_restored", n_tokens },
1382 { "n_read", n_bytes },
1383 { "timings", {
1384 { "restore_ms", t_ms }
1385 }},
1386 };
1387 }
1388};
1389
1390struct server_task_result_slot_erase : server_task_result {
1391 size_t n_erased;
1392
1393 virtual json to_json() override {
1394 return json {
1395 { "id_slot", id_slot },
1396 { "n_erased", n_erased },
1397 };
1398 }
1399};
1400
1401struct server_task_result_apply_lora : server_task_result {
1402 virtual json to_json() override {
1403 return json {{ "success", true }};
1404 }
1405};
1406
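// a snapshot of a slot's state covering positions [pos_min, pos_max], used to roll a cached prompt back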
1407struct server_prompt_checkpoint {
1408 llama_pos pos_min;
1409 llama_pos pos_max;
1410
1411 std::vector<uint8_t> data;
1412
1413 size_t size() const {
1414 return data.size();
1415 }
1416};
1417
1418struct server_prompt {
1419 server_tokens tokens;
1420
1421 std::vector<uint8_t> data;
1422
1423 std::list<server_prompt_checkpoint> checkpoints;
1424
1425 size_t size() const {
1426 size_t res = data.size();
1427
1428 for (const auto & checkpoint : checkpoints) {
1429 res += checkpoint.size();
1430 }
1431
1432 return res;
1433 }
1434
1435 int n_tokens() const {
1436 return tokens.size();
1437 }
1438};
1439
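// host-memory cache of processed prompts and their saved states, bounded by an optional
// size limit (MiB) and an optional token limit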
1440struct server_prompt_cache {
1441 server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) {
1442 this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib);
1443 this->limit_tokens = limit_tokens;
1444 }
1445
1446 std::list<server_prompt> states;
1447
1448 // in bytes, 0 = no limit
1449 size_t limit_size = 0;
1450
1451 // in tokens, 0 = no limit
1452 size_t limit_tokens = 0;
1453
1454 size_t size() const {
1455 size_t res = 0;
1456
1457 for (const auto & state : states) {
1458 res += state.size();
1459 }
1460
1461 return res;
1462 }
1463
1464 size_t n_tokens() const {
1465 size_t res = 0;
1466
1467 for (const auto & state : states) {
1468 res += state.n_tokens();
1469 }
1470
1471 return res;
1472 }
1473
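    // make room for a new cached prompt: skip if it is already covered by an existing entry,
    // drop entries fully contained in it, and allocate a buffer of state_size bytes
    // (returns nullptr if the prompt is already cached or the allocation fails)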
1474 server_prompt * alloc(const server_prompt & prompt, size_t state_size) {
1475 // first check if the current state is contained fully in the cache
1476 for (auto it = states.begin(); it != states.end(); ++it) {
            const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
1478
1479 if (cur_lcp_len == (int) prompt.tokens.size()) {
1480 SRV_WRN("%s", " - prompt is already in the cache, skipping\n");
1481 return nullptr;
1482 }
1483 }
1484
1485 // next, remove any cached prompts that are fully contained in the current prompt
        for (auto it = states.begin(); it != states.end();) {
            const int len = it->tokens.get_common_prefix(prompt.tokens);

            if (len == (int) it->tokens.size()) {
                SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);

                it = states.erase(it);
1493 } else {
1494 ++it;
1495 }
1496 }
1497
1498 std::vector<uint8_t> state_data;
1499
1500 // check if we can allocate enough memory for the new state
        try {
            state_data.resize(state_size);
        } catch (const std::bad_alloc & e) {
            SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());

            limit_size = std::max<size_t>(1, 0.4*size());
1507
1508 SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0));
1509
1510 update();
1511
1512 return nullptr;
1513 }
1514
1515 // TODO: for some reason we can't copy server_tokens, so we have to do this workaround
1516 auto & cur = states.emplace_back();
1517 cur = {
1518 /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
1519 /*.data =*/ std::move(state_data),
1520 /*.checkpoints =*/ prompt.checkpoints,
1521 };
1522
1523 return &cur;
1524 }
1525
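    // try to find a cached prompt that shares a longer common prefix with tokens_new than the
    // slot's current prompt; on success the cached state is restored into the slot's sequence
    // and the entry is moved out of the cache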
1526 bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
        const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
1528
1529 float f_keep_best = float(lcp_best) / prompt.tokens.size();
1530 float sim_best = float(lcp_best) / tokens_new.size();
1531
1532 SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
1533
1534 auto it_best = states.end();
1535
1536 // find the most similar cached prompt, that would also preserve the most context
1537 for (auto it = states.begin(); it != states.end(); ++it) {
            const int lcp_cur = it->tokens.get_common_prefix(tokens_new);
1539
1540 const float f_keep_cur = float(lcp_cur) / it->tokens.size();
1541 const float sim_cur = float(lcp_cur) / tokens_new.size();
1542
1543 // don't trash large prompts
1544 if (f_keep_cur < 0.25f) {
1545 continue;
1546 }
1547
1548 if (f_keep_best < f_keep_cur && sim_best < sim_cur) {
1549 f_keep_best = f_keep_cur;
1550 sim_best = sim_cur;
1551
1552 it_best = it;
1553 }
1554 }
1555
1556 if (it_best != states.end()) {
1557 SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
1558
1559 const size_t size = it_best->data.size();
            const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
1561 if (n != size) {
1562 SRV_WRN("failed to restore state with size %zu\n", size);
1563
1564 return false;
1565 }
1566
1567 it_best->data.clear();
1568 it_best->data.shrink_to_fit();
1569
1570 prompt = std::move(*it_best);
1571
            states.erase(it_best);
1573 }
1574
1575 return true;
1576 }
1577
1578 void update() {
1579 if (limit_size > 0) {
1580 // always keep at least one state, regardless of the limits
1581 while (states.size() > 1 && size() > limit_size) {
1582 if (states.empty()) {
1583 break;
1584 }
1585
1586 SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
1587
1588 states.pop_front();
1589 }
1590 }
1591
        // average size per token
        const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));

        // dynamically increase the token limit if it can fit in the memory limit
        const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size/size_per_token) : limit_tokens;
1597
1598 if (limit_tokens > 0) {
1599 while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
1600 if (states.empty()) {
1601 break;
1602 }
1603
1604 SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
1605 limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
1606
1607 states.pop_front();
1608 }
1609 }
1610
1611 SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
1612 states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
1613
1614 for (const auto & state : states) {
1615 SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
1616 (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
1617 }
1618 }
1619};
1620
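// state of a single parallel decoding slot: its context, current task, cached prompt and sampling state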
1621struct server_slot {
1622 int id;
1623
1624 llama_batch batch_spec = {};
1625
1626 // TODO: change to unique_ptrs for consistency:
1627 llama_context * ctx = nullptr;
1628 llama_context * ctx_dft = nullptr;
1629
1630 // multimodal
1631 mtmd_context * mctx = nullptr;
1632
1633 common_speculative * spec = nullptr;
1634
1635 std::unique_ptr<const server_task> task;
1636 std::unique_ptr<const server_task> task_prev; // used for debugging
1637
1638 // used to determine the slot that has been used the longest
1639 int64_t t_last_used = -1;
1640
1641 // generation props
1642 int32_t n_ctx = 0; // context size per slot
1643 int32_t n_keep = 0;
1644 int32_t n_decoded = 0;
1645 int32_t n_remaining = -1;
1646 int32_t i_batch = -1;
1647
1648 int32_t n_prompt_tokens_cache = 0;
1649 int32_t n_prompt_tokens_processed = 0;
1650
1651 size_t last_nl_pos = 0;
1652
1653 std::string generated_text;
1654 llama_tokens generated_tokens;
1655
1656 common_chat_msg chat_msg;
1657
1658 std::vector<completion_token_output> generated_token_probs;
1659
1660 bool has_next_token = true;
1661 bool has_new_line = false;
1662 bool truncated = false;
1663
1664 stop_type stop;
1665
1666 std::string stopping_word;
1667
1668 // state
1669 slot_state state = SLOT_STATE_IDLE;
1670
1671 server_prompt prompt;
1672
1673 void prompt_save(server_prompt_cache & prompt_cache) const {
1674 assert(prompt.data.size() == 0);
1675
        const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
1677
1678 SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
1679 (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
1680
        auto * cur = prompt_cache.alloc(prompt, cur_size);
1682 if (cur == nullptr) {
1683 return;
1684 }
1685
        llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
1687 }
1688
1689 void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
        bool res = prompt_cache.load(prompt, tokens, ctx, id);
1691 if (!res) {
1692 SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
1693 }
1694 }
1695
1696 std::vector<common_adapter_lora_info> lora;
1697 int32_t alora_invocation_start = -1;
1698
1699 // sampling
1700 json json_schema;
1701
1702 struct common_sampler * smpl = nullptr;
1703
1704 llama_token sampled;
1705
1706 common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
1707 std::vector<std::string> generated_tool_call_ids;
1708
1709 // stats
    size_t n_sent_text = 0; // number of sent text characters
1711
1712 int64_t t_start_process_prompt;
1713 int64_t t_start_generation;
1714
1715 double t_prompt_processing; // ms
1716 double t_token_generation; // ms
1717
1718 std::function<void(int)> callback_on_release;
1719
1720 // Speculative decoding stats
1721 int32_t n_draft_total = 0; // Total draft tokens generated
1722 int32_t n_draft_accepted = 0; // Draft tokens actually accepted
1723
1724 void reset() {
1725 SLT_DBG(*this, "%s", "\n");
1726
1727 n_prompt_tokens_cache = 0;
1728
1729 last_nl_pos = 0;
1730 generated_text = "";
1731 has_new_line = false;
1732 truncated = false;
1733 stop = STOP_TYPE_NONE;
1734 stopping_word = "";
1735 n_sent_text = 0;
1736 chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
1737
1738 generated_tokens.clear();
1739 generated_token_probs.clear();
1740 chat_msg = {};
1741 json_schema = json();
1742 generated_tool_call_ids.clear();
1743
1744 // clear speculative decoding stats
1745 n_draft_total = 0;
1746 n_draft_accepted = 0;
1747
1748 task.reset();
1749 task_prev.reset();
1750
1751 // clear alora start
1752 alora_invocation_start = -1;
1753 }
1754
1755 bool need_embd() const {
1756 GGML_ASSERT(task);
1757
1758 return server_task_type_need_embd(task->type);
1759 }
1760
1761 bool need_logits() const {
1762 GGML_ASSERT(task);
1763
1764 return server_task_type_need_logits(task->type);
1765 }
1766
1767 // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
1768 // also we cannot split if the pooling would require any past tokens
1769 bool can_split() const {
1770 return
1771 !need_embd() ||
1772 (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
1773 }
1774
1775 bool can_batch_with(server_slot & other_slot) const {
1776 GGML_ASSERT(task);
1777
1778 return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora);
1779 }
1780
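 // check whether the slot still has generation budget left; the per-task n_predict takes precedence over the global n_predict, and -1 means unlimited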
1781 bool has_budget(const common_params & global_params) {
1782 GGML_ASSERT(task);
1783
1784 if (task->params.n_predict == -1 && global_params.n_predict == -1) {
1785 return true; // limitless
1786 }
1787
1788 n_remaining = -1;
1789
1790 if (task->params.n_predict != -1) {
1791 n_remaining = task->params.n_predict - n_decoded;
1792 } else if (global_params.n_predict != -1) {
1793 n_remaining = global_params.n_predict - n_decoded;
1794 }
1795
1796 return n_remaining > 0; // true while there is still budget left
1797 }
1798
1799 bool is_processing() const {
1800 return state != SLOT_STATE_IDLE;
1801 }
1802
1803 bool can_speculate() const {
1804 return ctx_dft;
1805 }
1806
1807 void add_token(const completion_token_output & token) {
1808 if (!is_processing()) {
1809 SLT_WRN(*this, "%s", "slot is not processing\n");
1810 return;
1811 }
1812 generated_token_probs.push_back(token);
1813 }
1814
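 // finish the current task: record timings, mark the slot idle and notify the queue so that a deferred task can be scheduled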
1815 void release() {
1816 if (is_processing()) {
1817 GGML_ASSERT(task);
1818
1819 SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
1820
1821 t_last_used = ggml_time_us();
1822 t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
1823 state = SLOT_STATE_IDLE;
1824
1825 task_prev = std::move(task);
1826 task.reset();
1827
1828 callback_on_release(id);
1829 }
1830 }
1831
1832 result_timings get_timings() const {
1833 result_timings timings;
1834 timings.cache_n = n_prompt_tokens_cache;
1835
1836 timings.prompt_n = n_prompt_tokens_processed;
1837 timings.prompt_ms = t_prompt_processing;
1838 timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
1839 timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
1840
1841 timings.predicted_n = n_decoded;
1842 timings.predicted_ms = t_token_generation;
1843 timings.predicted_per_token_ms = t_token_generation / n_decoded;
1844 timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
1845
1846 // Add speculative metrics
1847 if (n_draft_total > 0) {
1848 timings.draft_n = n_draft_total;
1849 timings.draft_n_accepted = n_draft_accepted;
1850 }
1851
1852 return timings;
1853 }
1854
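 // re-parse the full generated text into a structured chat message (content, tool calls, etc.) and report the diff vs the previous parse - used to build the OAI-compatible message diffs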
1855 const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
1856 GGML_ASSERT(task);
1857
1858 auto previous_msg = chat_msg;
1859 SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
1860 auto new_msg = common_chat_parse(
1861 generated_text,
1862 /* is_partial= */ stop != STOP_TYPE_EOS,
1863 task->params.oaicompat_chat_syntax);
1864 if (!new_msg.empty()) {
1865 new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
1866 chat_msg = new_msg;
1867 diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
1868 }
1869 return chat_msg;
1870 }
1871
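 // scan the generated text for the configured stop strings; a full stop only searches the tail that could contain a new match, while a partial stop detects a stop-string prefix at the end of the text so that it is not streamed out prematurely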
1872 size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
1873 GGML_ASSERT(task);
1874
1875 size_t stop_pos = std::string::npos;
1876
1877 for (const std::string & word : task->params.antiprompt) {
1878 size_t pos;
1879
1880 if (is_full_stop) {
1881 const size_t tmp = word.size() + last_token_size;
1882 const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
1883
1884 pos = text.find(word, from_pos);
1885 } else {
1886 // otherwise, partial stop
1887 pos = string_find_partial_stop(text, word);
1888 }
1889
1890 if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
1891 if (is_full_stop) {
1892 stop = STOP_TYPE_WORD;
1893 stopping_word = word;
1894 has_next_token = false;
1895 }
1896 stop_pos = pos;
1897 }
1898 }
1899
1900 return stop_pos;
1901 }
1902
1903 void print_timings() const {
1904 const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
1905 const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
1906
1907 const double t_gen = t_token_generation / n_decoded;
1908 const double n_gen_second = 1e3 / t_token_generation * n_decoded;
1909
1910 SLT_INF(*this,
1911 "\n"
1912 "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
1913 " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
1914 " total time = %10.2f ms / %5d tokens\n",
1915 t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
1916 t_token_generation, n_decoded, t_gen, n_gen_second,
1917 t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
1918
1919 if (n_draft_total > 0) {
1920 const float draft_ratio = (float) n_draft_accepted / n_draft_total;
1921 SLT_INF(*this,
1922 "\n"
1923 "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
1924 draft_ratio, n_draft_accepted, n_draft_total
1925 );
1926 }
1927 }
1928
1929 json to_json(bool only_metrics = false) const {
1930 json res;
1931
1932 res = {
1933 {"id", id},
1934 {"n_ctx", n_ctx},
1935 {"speculative", can_speculate()},
1936 {"is_processing", is_processing()},
1937 };
1938
1939 const auto & ptask = task ? task : task_prev;
1940
1941 if (ptask) {
1942 res["id_task"] = ptask->id;
1943 res["params"] = ptask->params.to_json(only_metrics);
1944 res["next_token"] = {
1945 {
1946 {"has_next_token", has_next_token},
1947 {"has_new_line", has_new_line},
1948 {"n_remain", n_remaining},
1949 {"n_decoded", n_decoded},
1950 }
1951 };
1952
1953 if (!only_metrics) {
1954 res["prompt"] = ptask->tokens.detokenize(ctx, special: true);
1955 res["generated"] = generated_text;
1956 }
1957 }
1958
1959 return res;
1960 }
1961};
1962
1963struct server_metrics {
1964 int64_t t_start = 0;
1965
1966 uint64_t n_prompt_tokens_processed_total = 0;
1967 uint64_t t_prompt_processing_total = 0;
1968 uint64_t n_tokens_predicted_total = 0;
1969 uint64_t t_tokens_generation_total = 0;
1970
1971 uint64_t n_tokens_max = 0;
1972
1973 uint64_t n_prompt_tokens_processed = 0;
1974 uint64_t t_prompt_processing = 0;
1975
1976 uint64_t n_tokens_predicted = 0;
1977 uint64_t t_tokens_generation = 0;
1978
1979 uint64_t n_decode_total = 0;
1980 uint64_t n_busy_slots_total = 0;
1981
1982 void init() {
1983 t_start = ggml_time_us();
1984 }
1985
1986 void on_prompt_eval(const server_slot & slot) {
1987 n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
1988 n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
1989 t_prompt_processing += slot.t_prompt_processing;
1990 t_prompt_processing_total += slot.t_prompt_processing;
1991
1992 n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
1993 }
1994
1995 void on_prediction(const server_slot & slot) {
1996 n_tokens_predicted_total += slot.n_decoded;
1997 n_tokens_predicted += slot.n_decoded;
1998 t_tokens_generation += slot.t_token_generation;
1999 t_tokens_generation_total += slot.t_token_generation;
2000 }
2001
2002 void on_decoded(const std::vector<server_slot> & slots) {
2003 n_decode_total++;
2004 for (const auto & slot : slots) {
2005 if (slot.is_processing()) {
2006 n_busy_slots_total++;
2007 }
2008 n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
2009 }
2010 }
2011
2012 void reset_bucket() {
2013 n_prompt_tokens_processed = 0;
2014 t_prompt_processing = 0;
2015 n_tokens_predicted = 0;
2016 t_tokens_generation = 0;
2017 }
2018};
2019
2020struct server_queue {
2021 int id = 0;
2022 bool running;
2023
2024 // queues
2025 std::deque<server_task> queue_tasks;
2026 std::deque<server_task> queue_tasks_deferred;
2027
2028 std::mutex mutex_tasks;
2029 std::condition_variable condition_tasks;
2030
2031 // callback functions
2032 std::function<void(server_task &&)> callback_new_task;
2033 std::function<void(void)> callback_update_slots;
2034
2035 // Add a new task to the queue (to the front if front == true); returns the task id
2036 int post(server_task && task, bool front = false) {
2037 std::unique_lock<std::mutex> lock(mutex_tasks);
2038 GGML_ASSERT(task.id != -1);
2039 // if this is a cancel task, make sure to clean up pending tasks
2040 if (task.type == SERVER_TASK_TYPE_CANCEL) {
2041 cleanup_pending_task(task.id_target);
2042 }
2043 const int task_id = task.id;
2044 QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
2045 if (front) {
2046 queue_tasks.push_front(std::move(task));
2047 } else {
2048 queue_tasks.push_back(std::move(task));
2049 }
2050 condition_tasks.notify_one();
2051 return task_id;
2052 }
2053
2054 // multi-task version of post()
2055 int post(std::vector<server_task> && tasks, bool front = false) {
2056 std::unique_lock<std::mutex> lock(mutex_tasks);
2057 for (auto & task : tasks) {
2058 if (task.id == -1) {
2059 task.id = id++;
2060 }
2061 // if this is a cancel task, make sure to clean up pending tasks
2062 if (task.type == SERVER_TASK_TYPE_CANCEL) {
2063 cleanup_pending_task(task.id_target);
2064 }
2065 QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
2066 if (front) {
2067 queue_tasks.push_front(std::move(task));
2068 } else {
2069 queue_tasks.push_back(std::move(task));
2070 }
2071 }
2072 condition_tasks.notify_one();
2073 return 0;
2074 }
2075
2076 // Add a new task, but defer until one slot is available
2077 void defer(server_task && task) {
2078 std::unique_lock<std::mutex> lock(mutex_tasks);
2079 QUE_DBG("defer task, id = %d\n", task.id);
2080 queue_tasks_deferred.push_back(std::move(task));
2081 condition_tasks.notify_one();
2082 }
2083
2084 // Get the next id for creating a new task
2085 int get_new_id() {
2086 std::unique_lock<std::mutex> lock(mutex_tasks);
2087 int new_id = id++;
2088 return new_id;
2089 }
2090
2091 // Register function to process a new task
2092 void on_new_task(std::function<void(server_task &&)> callback) {
2093 callback_new_task = std::move(callback);
2094 }
2095
2096 // Register the function to be called when all slots data is ready to be processed
2097 void on_update_slots(std::function<void(void)> callback) {
2098 callback_update_slots = std::move(callback);
2099 }
2100
2101 // Call when the state of a slot changes; this moves one task from the deferred queue to the main queue
2102 void pop_deferred_task() {
2103 std::unique_lock<std::mutex> lock(mutex_tasks);
2104 if (!queue_tasks_deferred.empty()) {
2105 queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
2106 queue_tasks_deferred.pop_front();
2107 }
2108 condition_tasks.notify_one();
2109 }
2110
2111 // end the start_loop routine
2112 void terminate() {
2113 std::unique_lock<std::mutex> lock(mutex_tasks);
2114 running = false;
2115 condition_tasks.notify_all();
2116 }
2117
2118 /**
2119 * Main loop consists of these steps:
2120 * - Wait until a new task arrives
2121 * - Process the task (i.e. maybe copy data into slot)
2122 * - Check if multitask is finished
2123 * - Update all slots
2124 */
2125 void start_loop() {
2126 running = true;
2127
2128 while (true) {
2129 QUE_DBG("%s", "processing new tasks\n");
2130
2131 while (true) {
2132 std::unique_lock<std::mutex> lock(mutex_tasks);
2133 if (!running) {
2134 QUE_DBG("%s", "terminate\n");
2135 return;
2136 }
2137 if (queue_tasks.empty()) {
2138 lock.unlock();
2139 break;
2140 }
2141 server_task task = std::move(queue_tasks.front());
2142 queue_tasks.pop_front();
2143 lock.unlock();
2144
2145 QUE_DBG("processing task, id = %d\n", task.id);
2146 callback_new_task(std::move(task));
2147 }
2148
2149 // all tasks in the current loop have been processed, slots data is now ready
2150 QUE_DBG("%s", "update slots\n");
2151
2152 callback_update_slots();
2153
2154 QUE_DBG("%s", "waiting for new tasks\n");
2155 {
2156 std::unique_lock<std::mutex> lock(mutex_tasks);
2157 if (!running) {
2158 QUE_DBG("%s", "terminate\n");
2159 return;
2160 }
2161 if (queue_tasks.empty()) {
2162 condition_tasks.wait(lock, [&]{
2163 return (!queue_tasks.empty() || !running);
2164 });
2165 }
2166 }
2167 }
2168 }
2169
2170private:
2171 void cleanup_pending_task(int id_target) {
2172 // no lock needed because this is called exclusively by post()
2173 auto rm_func = [id_target](const server_task & task) {
2174 return task.id == id_target;
2175 };
2176 queue_tasks.erase(
2177 std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
2178 queue_tasks.end());
2179 queue_tasks_deferred.erase(
2180 std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
2181 queue_tasks_deferred.end());
2182 }
2183};
2184
2185struct server_response {
2186 bool running = true;
2187
2188 // for keeping track of all tasks waiting for the result
2189 std::unordered_set<int> waiting_task_ids;
2190
2191 // the main result queue (using ptr for polymorphism)
2192 std::vector<server_task_result_ptr> queue_results;
2193
2194 std::mutex mutex_results;
2195 std::condition_variable condition_results;
2196
2197 // add the id_task to the list of tasks waiting for response
2198 void add_waiting_task_id(int id_task) {
2199 SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
2200
2201 std::unique_lock<std::mutex> lock(mutex_results);
2202 waiting_task_ids.insert(id_task);
2203 }
2204
2205 void add_waiting_tasks(const std::vector<server_task> & tasks) {
2206 std::unique_lock<std::mutex> lock(mutex_results);
2207
2208 for (const auto & task : tasks) {
2209 SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
2210 waiting_task_ids.insert(task.id);
2211 }
2212 }
2213
2214 // when the request is finished, we can remove the task associated with it
2215 void remove_waiting_task_id(int id_task) {
2216 SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
2217
2218 std::unique_lock<std::mutex> lock(mutex_results);
2219 waiting_task_ids.erase(id_task);
2220 // make sure to clean up all pending results
2221 queue_results.erase(
2222 std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
2223 return res->id == id_task;
2224 }),
2225 queue_results.end());
2226 }
2227
2228 void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
2229 std::unique_lock<std::mutex> lock(mutex_results);
2230
2231 for (const auto & id_task : id_tasks) {
2232 SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
2233 waiting_task_ids.erase(id_task);
2234 }
2235 }
2236
2237 // This function blocks the thread until there is a response for one of the id_tasks
2238 server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
2239 while (true) {
2240 std::unique_lock<std::mutex> lock(mutex_results);
2241 condition_results.wait(lock, [&]{
2242 if (!running) {
2243 SRV_DBG("%s : queue result stop\n", __func__);
2244 std::terminate(); // we cannot return here since the caller is HTTP code
2245 }
2246 return !queue_results.empty();
2247 });
2248
2249 for (size_t i = 0; i < queue_results.size(); i++) {
2250 if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
2251 server_task_result_ptr res = std::move(queue_results[i]);
2252 queue_results.erase(queue_results.begin() + i);
2253 return res;
2254 }
2255 }
2256 }
2257
2258 // should never reach here
2259 }
2260
2261 // same as recv(), but with a timeout in seconds
2262 // if timeout is reached, nullptr is returned
2263 server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
2264 while (true) {
2265 std::unique_lock<std::mutex> lock(mutex_results);
2266
2267 for (int i = 0; i < (int) queue_results.size(); i++) {
2268 if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
2269 server_task_result_ptr res = std::move(queue_results[i]);
2270 queue_results.erase(queue_results.begin() + i);
2271 return res;
2272 }
2273 }
2274
2275 std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
2276 if (!running) {
2277 SRV_DBG("%s : queue result stop\n", __func__);
2278 std::terminate(); // we cannot return here since the caller is HTTP code
2279 }
2280 if (cr_res == std::cv_status::timeout) {
2281 return nullptr;
2282 }
2283 }
2284
2285 // should never reach here
2286 }
2287
2288 // single-task version of recv()
2289 server_task_result_ptr recv(int id_task) {
2290 std::unordered_set<int> id_tasks = {id_task};
2291 return recv(id_tasks);
2292 }
2293
2294 // Send a new result to a waiting id_task
2295 void send(server_task_result_ptr && result) {
2296 SRV_DBG("sending result for task id = %d\n", result->id);
2297
2298 std::unique_lock<std::mutex> lock(mutex_results);
2299 for (const auto & id_task : waiting_task_ids) {
2300 if (result->id == id_task) {
2301 SRV_DBG("task id = %d pushed to result queue\n", result->id);
2302
2303 queue_results.emplace_back(std::move(result));
2304 condition_results.notify_all();
2305 return;
2306 }
2307 }
2308 }
2309
2310 // terminate the waiting loop
2311 void terminate() {
2312 running = false;
2313 condition_results.notify_all();
2314 }
2315};
2316
2317struct server_context {
2318 common_params params_base;
2319
2320 // note: keep these alive - they determine the lifetime of the model, context, etc.
2321 common_init_result llama_init;
2322 common_init_result llama_init_dft;
2323
2324 llama_model * model = nullptr;
2325 llama_context * ctx = nullptr;
2326
2327 // multimodal
2328 mtmd_context * mctx = nullptr;
2329
2330 const llama_vocab * vocab = nullptr;
2331 bool vocab_dft_compatible = true;
2332
2333 llama_model * model_dft = nullptr;
2334
2335 llama_context_params cparams_dft;
2336
2337 llama_batch batch {};
2338
2339 bool clean_kv_cache = true;
2340 bool add_bos_token = true;
2341
2342 int32_t n_ctx; // total context for all clients / slots
2343
2344 // slots / clients
2345 std::vector<server_slot> slots;
2346
2347 int slots_debug = 0;
2348
2349 server_queue queue_tasks;
2350 server_response queue_results;
2351
2352 std::unique_ptr<server_prompt_cache> prompt_cache;
2353
2354 server_metrics metrics;
2355
2356 // minimum prompt similarity required to reuse a slot's cached prompt during slot selection
2357 float slot_prompt_similarity = 0.0f;
2358
2359 common_chat_templates_ptr chat_templates;
2360 oaicompat_parser_options oai_parser_opt;
2361
2362 ~server_context() {
2363 mtmd_free(mctx);
2364
2365 // Clear any sampling context
2366 for (server_slot & slot : slots) {
2367 common_sampler_free(slot.smpl);
2368 slot.smpl = nullptr;
2369
2370 llama_free(slot.ctx_dft);
2371 slot.ctx_dft = nullptr;
2372
2373 common_speculative_free(slot.spec);
2374 slot.spec = nullptr;
2375
2376 llama_batch_free(slot.batch_spec);
2377 }
2378
2379 llama_batch_free(batch);
2380 }
2381
2382 bool load_model(const common_params & params) {
2383 SRV_INF("loading model '%s'\n", params.model.path.c_str());
2384
2385 params_base = params;
2386
2387 llama_init = common_init_from_params(params_base);
2388
2389 model = llama_init.model.get();
2390 ctx = llama_init.context.get();
2391
2392 if (model == nullptr) {
2393 SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
2394 return false;
2395 }
2396
2397 vocab = llama_model_get_vocab(model);
2398
2399 n_ctx = llama_n_ctx(ctx);
2400
2401 add_bos_token = llama_vocab_get_add_bos(vocab);
2402
2403 if (params_base.has_speculative()) {
2404 SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
2405
2406 auto params_dft = params_base;
2407
2408 params_dft.devices = params_base.speculative.devices;
2409 params_dft.model = params_base.speculative.model;
2410 params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
2411 params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
2412 params_dft.n_parallel = 1;
2413 params_dft.cache_type_k = params_base.speculative.cache_type_k;
2414 params_dft.cache_type_v = params_base.speculative.cache_type_v;
2415
2416 params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads;
2417 params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
2418 params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
2419
2420 llama_init_dft = common_init_from_params(params_dft);
2421
2422 model_dft = llama_init_dft.model.get();
2423
2424 if (model_dft == nullptr) {
2425 SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
2426 return false;
2427 }
2428
2429 vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
2430 if (!vocab_dft_compatible) {
2431 SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
2432 }
2433
2434 const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
2435
2436 cparams_dft = common_context_params_to_llama(params_dft);
2437 cparams_dft.n_batch = n_ctx_dft;
2438
2439 // the context is not needed - we will create one for each slot
2440 llama_init_dft.context.reset();
2441 }
2442
2443 chat_templates = common_chat_templates_init(model, params_base.chat_template);
2444 try {
2445 common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs);
2446 } catch (const std::exception & e) {
2447 SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
2448 SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
2449 chat_templates = common_chat_templates_init(model, "chatml");
2450 }
2451
2452 std::string & mmproj_path = params_base.mmproj.path;
2453 if (!mmproj_path.empty()) {
2454 mtmd_context_params mparams = mtmd_context_params_default();
2455 mparams.use_gpu = params_base.mmproj_use_gpu;
2456 mparams.print_timings = false;
2457 mparams.n_threads = params_base.cpuparams.n_threads;
2458 mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
2459 mparams.flash_attn_type = params_base.flash_attn_type;
2460 mparams.image_min_tokens = params_base.image_min_tokens;
2461 mparams.image_max_tokens = params_base.image_max_tokens;
2462 mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
2463 if (mctx == nullptr) {
2464 SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
2465 return false;
2466 }
2467 SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str());
2468
2469 if (params_base.ctx_shift) {
2470 params_base.ctx_shift = false;
2471 SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
2472 }
2473
2474 if (params_base.n_cache_reuse) {
2475 params_base.n_cache_reuse = 0;
2476 SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
2477 }
2478
2479 if (params_base.has_speculative()) {
2480 SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
2481 return false;
2482 }
2483 }
2484
2485 if (!llama_memory_can_shift(llama_get_memory(ctx))) {
2486 if (params_base.ctx_shift) {
2487 params_base.ctx_shift = false;
2488 SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
2489 }
2490
2491 if (params_base.n_cache_reuse) {
2492 params_base.n_cache_reuse = 0;
2493 SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
2494 }
2495 }
2496
2497 return true;
2498 }
2499
2500 void init() {
2501 SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
2502
2503 const int n_ctx_train = llama_model_n_ctx_train(model);
2504
2505 int n_ctx_slot = llama_n_ctx_seq(ctx);
2506 if (n_ctx_slot > n_ctx_train) {
2507 SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
2508 n_ctx_slot = n_ctx_train;
2509 }
2510
2511 for (int i = 0; i < params_base.n_parallel; i++) {
2512 server_slot slot;
2513
2514 slot.id = i;
2515 slot.ctx = ctx;
2516 slot.n_ctx = n_ctx_slot;
2517 slot.mctx = mctx;
2518 slot.prompt.tokens.has_mtmd = mctx != nullptr;
2519
2520 if (model_dft) {
2521 slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
2522
2523 // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
2524 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
2525 if (slot.ctx_dft == nullptr) {
2526 SRV_ERR("%s", "failed to create draft context\n");
2527 return;
2528 }
2529
2530 slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
2531 if (slot.spec == nullptr) {
2532 SRV_ERR("%s", "failed to create speculator\n");
2533 return;
2534 }
2535 for (auto & pair : params_base.speculative.replacements) {
2536 common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
2537 }
2538 }
2539
2540 SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
2541
2542 slot.callback_on_release = [this](int) {
2543 queue_tasks.pop_deferred_task();
2544 };
2545
2546 slot.reset();
2547
2548 slots.push_back(std::move(slot));
2549 }
2550
2551 {
2552 const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG");
2553 slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0;
2554
2555 if (slots_debug) {
2556 SRV_WRN("slots debug = %d\n", slots_debug);
2557 }
2558 }
2559
2560 // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
2561 // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
2562 {
2563 const int32_t n_batch = llama_n_batch(ctx);
2564 batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
2565 }
2566
2567 metrics.init();
2568
2569 if (params_base.cache_ram_mib != 0) {
2570 if (params_base.cache_ram_mib < 0) {
2571 SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
2572 } else {
2573 SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib);
2574 }
2575 SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n");
2576
2577 prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx);
2578 } else {
2579 SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n");
2580 }
2581 SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
2582
2583 // thinking is enabled if:
2584 // 1. It's not explicitly disabled (reasoning_budget == 0)
2585 // 2. The chat template supports it (this also requires jinja templates to be enabled)
2586 const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
2587 SRV_INF("thinking = %d\n", enable_thinking);
2588
2589 oai_parser_opt = {
2590 /* use_jinja */ params_base.use_jinja,
2591 /* prefill_assistant */ params_base.prefill_assistant,
2592 /* reasoning_format */ params_base.reasoning_format,
2593 /* chat_template_kwargs */ params_base.default_template_kwargs,
2594 /* common_chat_templates */ chat_templates.get(),
2595 /* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
2596 /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
2597 /* enable_thinking */ enable_thinking,
2598 };
2599 }
2600
2601 server_slot * get_slot_by_id(int id) {
2602 for (server_slot & slot : slots) {
2603 if (slot.id == id) {
2604 return &slot;
2605 }
2606 }
2607
2608 return nullptr;
2609 }
2610
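 // pick a slot for the task: prefer the idle slot whose cached prompt shares the longest common prefix with the new prompt (if above the similarity threshold), otherwise fall back to the least recently used idle slot; optionally refresh the prompt cache when reusing the slot would discard most of its context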
2611 server_slot * get_available_slot(const server_task & task) {
2612 server_slot * ret = nullptr;
2613
2614 bool update_cache = false;
2615
2616 // find the slot that has at least n% prompt similarity
2617 if (ret == nullptr && slot_prompt_similarity != 0.0f) {
2618 float sim_best = 0;
2619
2620 for (server_slot & slot : slots) {
2621 // skip the slot if it is not available
2622 if (slot.is_processing()) {
2623 continue;
2624 }
2625
2626 const auto & tokens = slot.prompt.tokens;
2627
2628 // skip the slot if it does not contain cached tokens
2629 if (tokens.empty()) {
2630 continue;
2631 }
2632
2633 // fraction of the Longest Common Prefix length with respect to the input prompt length
2634 const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size();
2635
2636 // select the current slot if the criteria match
2637 if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) {
2638 sim_best = sim_cur;
2639
2640 ret = &slot;
2641 }
2642 }
2643
2644 if (ret != nullptr) {
2645 const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size();
2646
2647 SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n",
2648 sim_best, slot_prompt_similarity, f_keep);
2649
2650 // if we are about to lose a large portion of the existing context - save it in the prompt cache
2651 if (f_keep < 0.5f) {
2652 update_cache = true;
2653 }
2654 }
2655 }
2656
2657 // find the slot that has been least recently used
2658 if (ret == nullptr) {
2659 int64_t t_last = -1;
2660
2661 for (server_slot & slot : slots) {
2662 // skip the slot if it is not available
2663 if (slot.is_processing()) {
2664 continue;
2665 }
2666
2667 // select the current slot if the criteria match
2668 if (!ret || slot.t_last_used <= t_last) {
2669 t_last = slot.t_last_used;
2670 ret = &slot;
2671 }
2672 }
2673
2674 if (ret != nullptr) {
2675 SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last);
2676
2677 update_cache = true;
2678 }
2679 }
2680
2681 if (ret) {
2682 const auto & tokens = ret->prompt.tokens;
2683
2684 update_cache = update_cache && prompt_cache;
2685
2686 // cache prompts only for completion tasks
2687 update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
2688
2689 // don't update the cache if the slot's context is empty
2690 update_cache = update_cache && tokens.size() > 0;
2691
2692 // TODO: mtmd does not support prompt cache
2693 update_cache = update_cache && (ret->mctx == nullptr);
2694
2695 if (update_cache) {
2696 SRV_WRN("%s", "updating prompt cache\n");
2697
2698 const int64_t t_start = ggml_time_us();
2699
2700 ret->prompt_save(*prompt_cache);
2701 ret->prompt_load(*prompt_cache, task.tokens);
2702
2703 prompt_cache->update();
2704
2705 SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
2706 }
2707 }
2708
2709 return ret;
2710 }
2711
2712 // return true if at least one slot has been purged
2713 // TODO: improve logic
2714 // - smarter decision which slot to purge (LRU or longest prompt?)
2715 // - move slot to level 2 cache instead of removing?
2716 // - instead of purging, try to store and resume later?
2717 bool try_purge_idle_slots() {
2718 bool res = false;
2719
2720 if (!params_base.kv_unified) {
2721 return res;
2722 }
2723
2724 for (auto & slot : slots) {
2725 if (slot.is_processing()) {
2726 continue;
2727 }
2728
2729 if (slot.prompt.n_tokens() > 0) {
2730 SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
2731
2732 llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
2733 slot.prompt.tokens.clear();
2734
2735 res = true;
2736
2737 // purge slots one by one
2738 break;
2739 }
2740 }
2741
2742 return res;
2743 }
2744
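 // bind a task to a slot: apply LoRA changes, detect an optional aLoRA invocation in the prompt, validate the prompt tokens and (re)initialize the sampler and the draft batch before processing starts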
2745 bool launch_slot_with_task(server_slot & slot, server_task && task) {
2746 slot.reset();
2747
2748 if (!are_lora_equal(task.params.lora, slot.lora)) {
2749 // if lora has changed, check to see if the cache should be cleared
2750 if (lora_should_clear_cache(slot.lora, task.params.lora)) {
2751 SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size());
2752 slot.prompt.tokens.clear();
2753 } else {
2754 SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size());
2755 }
2756 slot.lora = task.params.lora;
2757 }
2758
2759 // if using alora, make sure it's only a single one requested and active
2760 size_t alora_invocation_start = task.tokens.size();
2761 if (lora_all_alora(slot.lora)) {
2762 const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
2763 // TODO: This will error out if a user requests two aloras, but only
2764 // provides the activation string for one. We could instead search
2765 // for all requested alora activation strings and then either keep
2766 // only the last one, or reject if multiple are found.
2767 if (enabled_ids.size() != 1) {
2768 send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
2769 return false;
2770 }
2771 const auto & lora = slot.lora[enabled_ids[0]].ptr;
2772
2773 // get the pointer and count for the invocation tokens
2774 const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
2775 const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora);
2776
2777 // scan backwards through the prompt tokens to find the last
2778 // occurrence of the invocation sequence
2779 int match_idx = static_cast<int>(n_invocation_tokens) - 1;
2780 for (int i = task.tokens.size() - 1; i >= 0; --i) {
2781 // the token in this position matches the next token to find in
2782 // the invocation sequence
2783 if (task.tokens[i] == invocation_tokens[match_idx]) {
2784 // if it's a full match, we've found the start
2785 if (match_idx == 0) {
2786 alora_invocation_start = i;
2787 break;
2788 }
2789 // otherwise, check the next token in the sequence
2790 --match_idx;
2791 } else {
2792 // no match in this position, so start looking over again
2793 match_idx = static_cast<int>(n_invocation_tokens) - 1;
2794 }
2795 }
2796
2797 // if the activation string is not found, disable the alora
2798 if (alora_invocation_start == task.tokens.size()) {
2799 SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
2800 slot.lora[enabled_ids[0]].scale = 0.0f;
2801 } else {
2802 SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
2803 slot.alora_invocation_start = alora_invocation_start;
2804 }
2805 }
2806
2807 if (!task.tokens.validate(ctx)) {
2808 send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
2809 return false;
2810 }
2811
2812 SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
2813
2814 // initialize samplers
2815 {
2816 if (slot.smpl != nullptr) {
2817 common_sampler_free(slot.smpl);
2818 }
2819
2820 slot.smpl = common_sampler_init(model, task.params.sampling);
2821 if (slot.smpl == nullptr) {
2822 // for now, the only error that may happen here is invalid grammar
2823 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
2824 return false;
2825 }
2826
2827 SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
2828 }
2829
2830 // initialize draft batch
2831 // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
2832 if (slot.ctx_dft) {
2833 llama_batch_free(slot.batch_spec);
2834
2835 slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
2836 }
2837
2838 slot.task = std::make_unique<const server_task>(std::move(task));
2839
2840 slot.state = SLOT_STATE_STARTED;
2841
2842 SLT_INF(slot, "%s", "processing task\n");
2843
2844 return true;
2845 }
2846
2847 void kv_cache_clear() {
2848 SRV_DBG("%s", "clearing KV cache\n");
2849
2850 // clear the entire KV cache
2851 llama_memory_clear(llama_get_memory(ctx), true);
2852 clean_kv_cache = false;
2853 }
2854
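 // handle one sampled token: append it to the generated text, deal with incomplete UTF-8 sequences and stop strings, stream a partial response if requested, and apply the stop conditions; returns true while generation should continue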
2855 bool process_token(completion_token_output & result, server_slot & slot) {
2856 // remember which tokens were sampled - used for repetition penalties during sampling
2857 const std::string token_str = result.text_to_send;
2858 slot.sampled = result.tok;
2859
2860 slot.generated_text += token_str;
2861 if (slot.task->params.return_tokens) {
2862 slot.generated_tokens.push_back(result.tok);
2863 }
2864 slot.has_next_token = true;
2865
2866 // check if there is an incomplete UTF-8 character at the end
2867 bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
2868
2869 // search stop word and delete it
2870 if (!incomplete) {
2871 size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
2872
2873 const std::string str_test = slot.generated_text.substr(pos);
2874 bool send_text = true;
2875
2876 size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
2877 if (stop_pos != std::string::npos) {
2878 slot.generated_text.erase(
2879 slot.generated_text.begin() + pos + stop_pos,
2880 slot.generated_text.end());
2881 pos = std::min(slot.n_sent_text, slot.generated_text.size());
2882 } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) {
2883 stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
2884 send_text = stop_pos == std::string::npos;
2885 }
2886
2887 // check if there is any token to predict
2888 if (send_text) {
2889 // do not send the stop word in the response
2890 result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
2891 slot.n_sent_text += result.text_to_send.size();
2892 // add the token to slot queue and cache
2893 } else {
2894 result.text_to_send = "";
2895 }
2896
2897 slot.add_token(result);
2898 if (slot.task->params.stream) {
2899 send_partial_response(slot, result, false);
2900 }
2901 }
2902
2903 if (incomplete) {
2904 slot.has_next_token = true;
2905 }
2906
2907 // if context shifting is disabled, make sure that we don't run out of context
2908 if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
2909 slot.truncated = true;
2910 slot.stop = STOP_TYPE_LIMIT;
2911 slot.has_next_token = false;
2912
2913 SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
2914 slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
2915 }
2916
2917 // check the limits
2918 if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
2919 slot.stop = STOP_TYPE_LIMIT;
2920 slot.has_next_token = false;
2921
2922 SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict);
2923 }
2924
2925 if (slot.has_new_line) {
2926 // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
2927 if (slot.task->params.n_indent > 0) {
2928 // check the current indentation
2929 // TODO: improve by not doing it more than once for each new line
2930 if (slot.last_nl_pos > 0) {
2931 size_t pos = slot.last_nl_pos;
2932
2933 int n_indent = 0;
2934 while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
2935 n_indent++;
2936 pos++;
2937 }
2938
2939 if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) {
2940 slot.stop = STOP_TYPE_LIMIT;
2941 slot.has_next_token = false;
2942
2943 // cut the last line
2944 slot.generated_text.erase(pos, std::string::npos);
2945
2946 SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
2947 }
2948 }
2949
2950 // find the next new line
2951 {
2952 const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
2953
2954 if (pos != std::string::npos) {
2955 slot.last_nl_pos = pos + 1;
2956 }
2957 }
2958 }
2959 }
2960
2961 // check if there is a new line in the generated text
2962 if (result.text_to_send.find('\n') != std::string::npos) {
2963 slot.has_new_line = true;
2964
2965 // if we have seen a new line, we stop after a certain time limit, but only upon another new line
2966 if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) {
2967 slot.stop = STOP_TYPE_LIMIT;
2968 slot.has_next_token = false;
2969
2970 SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms);
2971 }
2972 }
2973
2974 if (llama_vocab_is_eog(vocab, result.tok)) {
2975 slot.stop = STOP_TYPE_EOS;
2976 slot.has_next_token = false;
2977
2978 SLT_DBG(slot, "%s", "stopped by EOS\n");
2979 }
2980
2981 SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
2982
2983 return slot.has_next_token; // continue
2984 }
2985
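 // fill in the per-token probabilities: with post_sampling the probabilities come from the sampler's sorted candidate list, otherwise they are computed from the raw logits at the given batch index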
2986 void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
2987 size_t n_probs = slot.task->params.sampling.n_probs;
2988 size_t n_vocab = llama_vocab_n_tokens(vocab);
2989
2990 if (post_sampling) {
2991 const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
2992 const size_t max_probs = cur_p->size;
2993
2994 // set probability for sampled token
2995 for (size_t i = 0; i < max_probs; i++) {
2996 if (cur_p->data[i].id == result.tok) {
2997 result.prob = cur_p->data[i].p;
2998 break;
2999 }
3000 }
3001
3002 // set probability for top n_probs tokens
3003 result.probs.reserve(max_probs);
3004 for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
3005 result.probs.push_back({
3006 cur_p->data[i].id,
3007 common_token_to_piece(ctx, cur_p->data[i].id, special),
3008 cur_p->data[i].p
3009 });
3010 }
3011 } else {
3012 // TODO: optimize this with min-p optimization
3013 std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
3014
3015 // set probability for sampled token
3016 for (size_t i = 0; i < n_vocab; i++) {
3018 if (cur[i].id == result.tok) {
3019 result.prob = cur[i].p;
3020 break;
3021 }
3022 }
3023
3024 // set probability for top n_probs tokens
3025 result.probs.reserve(n_probs);
3026 for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
3027 result.probs.push_back({
3028 cur[i].id,
3029 common_token_to_piece(ctx, cur[i].id, special),
3030 cur[i].p
3031 });
3032 }
3033 }
3034 }
3035
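 // error results are delivered through the same result queue as regular responses; for ERROR_TYPE_EXCEED_CONTEXT_SIZE the prompt token count and the slot context size are attached as well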
3036 void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
3037 send_error(task.id, error, type);
3038 }
3039
3040 void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
3041 send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
3042 }
3043
3044 void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
3045 SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
3046
3047 if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
3048 GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
3049 }
3050
3051 auto res = std::make_unique<server_task_result_error>();
3052 res->id = id_task;
3053 res->err_type = type;
3054 res->err_msg = error;
3055 res->n_prompt_tokens = n_prompt_tokens;
3056 res->n_ctx = n_ctx;
3057
3058 queue_results.send(std::move(res));
3059 }
3060
3061 // if multimodal is enabled, send an error and return false
3062 bool check_no_mtmd(const int id_task) {
3063 if (mctx) {
3064 send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
3065 return false;
3066 }
3067 return true;
3068 }
3069
3070 void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
3071 auto res = std::make_unique<server_task_result_cmpl_partial>();
3072
3073 res->id = slot.task->id;
3074 res->index = slot.task->index;
3075
3076 if (is_progress) {
3077 res->is_progress = true;
3078 res->progress.total = slot.task->n_tokens();
3079 res->progress.cache = slot.n_prompt_tokens_cache;
3080 res->progress.processed = slot.prompt.tokens.size();
3081 res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000;
3082 } else {
3083 res->content = tkn.text_to_send;
3084 res->tokens = { tkn.tok };
3085
3086 slot.update_chat_msg(res->oaicompat_msg_diffs);
3087 }
3088
3089 res->n_decoded = slot.n_decoded;
3090 res->n_prompt_tokens = slot.task->n_tokens();
3091 res->post_sampling_probs = slot.task->params.post_sampling_probs;
3092
3093 res->verbose = slot.task->params.verbose;
3094 res->oaicompat = slot.task->params.oaicompat;
3095 res->oaicompat_model = slot.task->params.oaicompat_model;
3096 res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
3097
3098 // populate res.probs_output
3099 if (slot.task->params.sampling.n_probs > 0) {
3100 res->prob_output = tkn; // copy the token probs
3101 }
3102
3103 // populate timings if this is final response or timings_per_token is enabled
3104 if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) {
3105 res->timings = slot.get_timings();
3106 }
3107
3108 queue_results.send(std::move(res));
3109 }
3110
3111 void send_final_response(server_slot & slot) {
3112 auto res = std::make_unique<server_task_result_cmpl_final>();
3113
3114 res->id = slot.task->id;
3115 res->id_slot = slot.id;
3116
3117 res->index = slot.task->index;
3118 res->content = slot.generated_text;
3119 res->tokens = std::move(slot.generated_tokens);
3120 res->timings = slot.get_timings();
3121 res->prompt = slot.task->tokens.detokenize(ctx, true);
3122 res->response_fields = std::move(slot.task->params.response_fields);
3123
3124 res->truncated = slot.truncated;
3125 res->n_decoded = slot.n_decoded;
3126 res->n_prompt_tokens = slot.task->n_tokens();
3127 res->n_tokens_cached = slot.prompt.n_tokens();
3128 res->has_new_line = slot.has_new_line;
3129 res->stopping_word = slot.stopping_word;
3130 res->stop = slot.stop;
3131 res->post_sampling_probs = slot.task->params.post_sampling_probs;
3132
3133 res->verbose = slot.task->params.verbose;
3134 res->stream = slot.task->params.stream;
3135 res->include_usage = slot.task->params.include_usage;
3136 res->oaicompat = slot.task->params.oaicompat;
3137 res->oaicompat_model = slot.task->params.oaicompat_model;
3138 res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
3139 res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
3140
3141 // populate res.probs_output
3142 if (slot.task->params.sampling.n_probs > 0) {
3143 if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) {
3144 const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
3145
3146 size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
3147 res->probs_output = std::vector<completion_token_output>(
3148 slot.generated_token_probs.begin(),
3149 slot.generated_token_probs.end() - safe_offset);
3150 } else {
3151 res->probs_output = std::vector<completion_token_output>(
3152 slot.generated_token_probs.begin(),
3153 slot.generated_token_probs.end());
3154 }
3155 }
3156
3157 res->generation_params = slot.task->params; // copy the parameters
3158
3159 queue_results.send(std::move(res));
3160 }
3161
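 // build the embedding result for this slot from the current batch: with pooling enabled a single normalized sequence embedding is returned, otherwise one unnormalized embedding per token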
3162 void send_embedding(const server_slot & slot, const llama_batch & batch) {
3163 auto res = std::make_unique<server_task_result_embd>();
3164 res->id = slot.task->id;
3165 res->index = slot.task->index;
3166 res->n_tokens = slot.task->n_tokens();
3167 res->oaicompat = slot.task->params.oaicompat;
3168
3169 const int n_embd = llama_model_n_embd(model);
3170
3171 std::vector<float> embd_res(n_embd, 0.0f);
3172
3173 for (int i = 0; i < batch.n_tokens; ++i) {
3174 if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
3175 continue;
3176 }
3177
3178 const float * embd = nullptr;
3179 if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
3180 embd = llama_get_embeddings_ith(ctx, i);
3181 } else {
3182 embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
3183 }
3184
3185 if (embd == nullptr) {
3186 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
3187
3188 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
3189 continue;
3190 }
3191
3192 // normalize only when there is pooling
3193 if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
3194 common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
3195 res->embedding.push_back(embd_res);
3196 break;
3197 }
3198
3199 res->embedding.emplace_back(embd, embd + n_embd);
3200 }
3201
3202 SLT_DBG(slot, "%s", "sending embeddings\n");
3203
3204 queue_results.send(std::move(res));
3205 }
3206
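 // the rerank score is taken from the first component of the pooled embedding of the slot's sequence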
3207 void send_rerank(const server_slot & slot, const llama_batch & batch) {
3208 auto res = std::make_unique<server_task_result_rerank>();
3209 res->id = slot.task->id;
3210 res->index = slot.task->index;
3211 res->n_tokens = slot.task->n_tokens();
3212
3213 for (int i = 0; i < batch.n_tokens; ++i) {
3214 if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
3215 continue;
3216 }
3217
3218 const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
3219 if (embd == NULL) {
3220 embd = llama_get_embeddings_ith(ctx, i);
3221 }
3222
3223 if (embd == NULL) {
3224 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
3225
3226 res->score = -1e6;
3227 continue;
3228 }
3229
3230 res->score = embd[0];
3231 }
3232
3233 SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);
3234
3235 queue_results.send(std::move(res));
3236 }
3237
3238 //
3239 // Functions to create new task(s) and receive result(s)
3240 //
3241
3242 void cancel_tasks(const std::unordered_set<int> & id_tasks) {
3243 std::vector<server_task> cancel_tasks;
3244 cancel_tasks.reserve(id_tasks.size());
3245 for (const auto & id_task : id_tasks) {
3246 SRV_WRN("cancel task, id_task = %d\n", id_task);
3247
3248 server_task task(SERVER_TASK_TYPE_CANCEL);
3249 task.id_target = id_task;
3250 queue_results.remove_waiting_task_id(id_task);
3251 cancel_tasks.push_back(std::move(task));
3252 }
3253 // push to beginning of the queue, so it has highest priority
3254 queue_tasks.post(std::move(cancel_tasks), true);
3255 }
3256
3257 // receive the results from task(s)
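 // note: polls with a short timeout so that a closed client connection can be detected while waiting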
3258 void receive_multi_results(
3259 const std::unordered_set<int> & id_tasks,
3260 const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
3261 const std::function<void(json)> & error_handler,
3262 const std::function<bool()> & is_connection_closed) {
3263 std::vector<server_task_result_ptr> results(id_tasks.size());
3264 for (int i = 0; i < (int)id_tasks.size(); i++) {
3265 server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
3266
3267 if (is_connection_closed()) {
3268 cancel_tasks(id_tasks);
3269 return;
3270 }
3271
3272 if (result == nullptr) {
3273 i--; // retry
3274 continue;
3275 }
3276
3277 if (result->is_error()) {
3278 error_handler(result->to_json());
3279 cancel_tasks(id_tasks);
3280 return;
3281 }
3282
3283 GGML_ASSERT(
3284 dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
3285 || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
3286 || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
3287 );
3288 const size_t idx = result->get_index();
3289 GGML_ASSERT(idx < results.size() && "index out of range");
3290 results[idx] = std::move(result);
3291 }
3292 result_handler(results);
3293 }
3294
3295 // receive the results from task(s), in stream mode
3296 void receive_cmpl_results_stream(
3297 const std::unordered_set<int> & id_tasks,
3298 const std::function<bool(server_task_result_ptr&)> & result_handler,
3299 const std::function<void(json)> & error_handler,
3300 const std::function<bool()> & is_connection_closed) {
3301 size_t n_finished = 0;
3302 while (true) {
3303 server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
3304
3305 if (is_connection_closed()) {
3306 cancel_tasks(id_tasks);
3307 return;
3308 }
3309
3310 if (result == nullptr) {
3311 continue; // retry
3312 }
3313
3314 if (result->is_error()) {
3315 error_handler(result->to_json());
3316 cancel_tasks(id_tasks);
3317 return;
3318 }
3319
3320 GGML_ASSERT(
3321 dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
3322 || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
3323 );
3324 if (!result_handler(result)) {
3325 cancel_tasks(id_tasks);
3326 break;
3327 }
3328
3329 if (result->is_stop()) {
3330 if (++n_finished == id_tasks.size()) {
3331 break;
3332 }
3333 }
3334 }
3335 }
3336
3337 //
3338 // Functions to process the task
3339 //
3340
3341 void process_single_task(server_task && task) {
3342 switch (task.type) {
3343 case SERVER_TASK_TYPE_COMPLETION:
3344 case SERVER_TASK_TYPE_INFILL:
3345 case SERVER_TASK_TYPE_EMBEDDING:
3346 case SERVER_TASK_TYPE_RERANK:
3347 {
3348 const int id_slot = task.id_slot;
3349
3350 server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
3351
3352 if (slot == nullptr) {
3353 // if no slot is available, we defer this task for processing later
3354 SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
3355 queue_tasks.defer(std::move(task));
3356 break;
3357 }
3358
3359 if (slot->is_processing()) {
3360 // if requested slot is unavailable, we defer this task for processing later
3361 SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
3362 queue_tasks.defer(std::move(task));
3363 break;
3364 }
3365
3366 if (!launch_slot_with_task(*slot, std::move(task))) {
3367 SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
3368 break;
3369 }
3370 } break;
3371 case SERVER_TASK_TYPE_CANCEL:
3372 {
3373 // release slot linked with the task id
3374 for (auto & slot : slots) {
3375 if (slot.task && slot.task->id == task.id_target) {
3376 slot.release();
3377 break;
3378 }
3379 }
3380 } break;
3381 case SERVER_TASK_TYPE_NEXT_RESPONSE:
3382 {
3383 // do nothing
3384 } break;
3385 case SERVER_TASK_TYPE_METRICS:
3386 {
3387 json slots_data = json::array();
3388
3389 int n_idle_slots = 0;
3390 int n_processing_slots = 0;
3391
3392 for (server_slot & slot : slots) {
3393 json slot_data = slot.to_json(only_metrics: slots_debug == 0);
3394
3395 if (slot.is_processing()) {
3396 n_processing_slots++;
3397 } else {
3398 n_idle_slots++;
3399 }
3400
3401 slots_data.push_back(val: slot_data);
3402 }
3403 SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
3404
3405 auto res = std::make_unique<server_task_result_metrics>();
3406 res->id = task.id;
3407 res->slots_data = std::move(slots_data);
3408 res->n_idle_slots = n_idle_slots;
3409 res->n_processing_slots = n_processing_slots;
3410 res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
3411 res->t_start = metrics.t_start;
3412
3413 res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
3414 res->t_prompt_processing_total = metrics.t_prompt_processing_total;
3415 res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
3416 res->t_tokens_generation_total = metrics.t_tokens_generation_total;
3417
3418 res->n_tokens_max = metrics.n_tokens_max;
3419
3420 res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
3421 res->t_prompt_processing = metrics.t_prompt_processing;
3422 res->n_tokens_predicted = metrics.n_tokens_predicted;
3423 res->t_tokens_generation = metrics.t_tokens_generation;
3424
3425 res->n_decode_total = metrics.n_decode_total;
3426 res->n_busy_slots_total = metrics.n_busy_slots_total;
3427
3428 if (task.metrics_reset_bucket) {
3429 metrics.reset_bucket();
3430 }
3431 queue_results.send(result: std::move(res));
3432 } break;
3433 case SERVER_TASK_TYPE_SLOT_SAVE:
3434 {
3435 if (!check_no_mtmd(id_task: task.id)) {
3436 break;
3437 }
3438
3439 int id_slot = task.slot_action.slot_id;
3440 server_slot * slot = get_slot_by_id(id: id_slot);
3441 if (slot == nullptr) {
3442 send_error(task, error: "Invalid slot ID", type: ERROR_TYPE_INVALID_REQUEST);
3443 break;
3444 }
3445 if (slot->is_processing()) {
3446 // if requested slot is unavailable, we defer this task for processing later
3447 SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
3448 queue_tasks.defer(task: std::move(task));
3449 break;
3450 }
3451
3452 const size_t token_count = slot->prompt.tokens.size();
3453 const int64_t t_start = ggml_time_us();
3454
3455 std::string filename = task.slot_action.filename;
3456 std::string filepath = task.slot_action.filepath;
3457
3458 const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens();
3459 const size_t nwrite = llama_state_seq_save_file(ctx, filepath: filepath.c_str(), seq_id: slot->id, tokens: tokens.data(), n_token_count: token_count);
3460
3461 const int64_t t_end = ggml_time_us();
3462 const double t_save_ms = (t_end - t_start) / 1000.0;
3463
3464 auto res = std::make_unique<server_task_result_slot_save_load>();
3465 res->id = task.id;
3466 res->id_slot = id_slot;
3467 res->filename = filename;
3468 res->is_save = true;
3469 res->n_tokens = token_count;
3470 res->n_bytes = nwrite;
3471 res->t_ms = t_save_ms;
3472 queue_results.send(result: std::move(res));
3473 } break;
3474 case SERVER_TASK_TYPE_SLOT_RESTORE:
3475 {
3476 if (!check_no_mtmd(id_task: task.id)) break;
3477 int id_slot = task.slot_action.slot_id;
3478 server_slot * slot = get_slot_by_id(id: id_slot);
3479 if (slot == nullptr) {
3480 send_error(task, error: "Invalid slot ID", type: ERROR_TYPE_INVALID_REQUEST);
3481 break;
3482 }
3483 if (slot->is_processing()) {
3484 // if requested slot is unavailable, we defer this task for processing later
3485 SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
3486 queue_tasks.defer(task: std::move(task));
3487 break;
3488 }
3489
3490 const int64_t t_start = ggml_time_us();
3491
3492 std::string filename = task.slot_action.filename;
3493 std::string filepath = task.slot_action.filepath;
3494
3495 llama_tokens tokens;
3496 tokens.resize(new_size: slot->n_ctx);
3497 size_t token_count = 0;
3498 size_t nread = llama_state_seq_load_file(ctx, filepath: filepath.c_str(), dest_seq_id: slot->id, tokens_out: tokens.data(), n_token_capacity: tokens.size(), n_token_count_out: &token_count);
3499 if (nread == 0) {
3500 slot->prompt.tokens.clear(); // the KV cache may already have been invalidated at this point
3501 send_error(task, error: "Unable to restore slot, no available space in KV cache or invalid slot save file", type: ERROR_TYPE_INVALID_REQUEST);
3502 break;
3503 }
3504 tokens.resize(new_size: token_count);
3505 slot->prompt.tokens.clear();
3506 slot->prompt.tokens.insert(inp_tokens: tokens);
3507
3508 const int64_t t_end = ggml_time_us();
3509 const double t_restore_ms = (t_end - t_start) / 1000.0;
3510
3511 auto res = std::make_unique<server_task_result_slot_save_load>();
3512 res->id = task.id;
3513 res->id_slot = id_slot;
3514 res->filename = filename;
3515 res->is_save = false;
3516 res->n_tokens = token_count;
3517 res->n_bytes = nread;
3518 res->t_ms = t_restore_ms;
3519 queue_results.send(result: std::move(res));
3520 } break;
3521 case SERVER_TASK_TYPE_SLOT_ERASE:
3522 {
3523 if (!check_no_mtmd(id_task: task.id)) {
3524 break;
3525 }
3526 int id_slot = task.slot_action.slot_id;
3527 server_slot * slot = get_slot_by_id(id: id_slot);
3528 if (slot == nullptr) {
3529 send_error(task, error: "Invalid slot ID", type: ERROR_TYPE_INVALID_REQUEST);
3530 break;
3531 }
3532 if (slot->is_processing()) {
3533 // if requested slot is unavailable, we defer this task for processing later
3534 SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
3535 queue_tasks.defer(task: std::move(task));
3536 break;
3537 }
3538
3539 // Erase token cache
3540 const size_t n_erased = slot->prompt.tokens.size();
3541 llama_memory_seq_rm(mem: llama_get_memory(ctx), seq_id: slot->id, p0: -1, p1: -1);
3542 slot->prompt.tokens.clear();
3543
3544 auto res = std::make_unique<server_task_result_slot_erase>();
3545 res->id = task.id;
3546 res->id_slot = id_slot;
3547 res->n_erased = n_erased;
3548 queue_results.send(result: std::move(res));
3549 } break;
3550 case SERVER_TASK_TYPE_SET_LORA:
3551 {
3552 params_base.lora_adapters = std::move(task.set_lora);
3553 auto res = std::make_unique<server_task_result_apply_lora>();
3554 res->id = task.id;
3555 queue_results.send(result: std::move(res));
3556 } break;
3557
3558 }
3559 }
3560
3561 void update_slots() {
3562 // check if all slots are idle
3563 {
3564 bool all_idle = true;
3565
3566 for (auto & slot : slots) {
3567 if (slot.is_processing()) {
3568 all_idle = false;
3569 break;
3570 }
3571 }
3572
3573 if (all_idle) {
3574 SRV_INF("%s", "all slots are idle\n");
3575 if (clean_kv_cache) {
3576 kv_cache_clear();
3577 }
3578
3579 return;
3580 }
3581 }
3582
3583 {
3584 SRV_DBG("%s", "posting NEXT_RESPONSE\n");
3585
3586 server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
3587 task.id = queue_tasks.get_new_id();
3588 queue_tasks.post(task: std::move(task));
3589 }
3590
3591 // apply context-shift if needed
3592 // TODO: simplify and improve
3593 for (server_slot & slot : slots) {
3594 if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
3595 if (!params_base.ctx_shift) {
3596 // this check is redundant (kept as a safety net)
3597 // we should never get here, because generation should have already been stopped in process_token()
3598 send_error(slot, error: "context shift is disabled", type: ERROR_TYPE_SERVER);
3599 slot.release();
3600 continue;
3601 }
3602
3603 if (mctx) {
3604 // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
3605 // we don't support ctx_shift because an image chunk may contain multiple tokens
3606 GGML_ABORT("not supported by multimodal");
3607 }
3608
3609 // Shift context
3610 int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
3611
3612 if (add_bos_token) {
3613 n_keep += 1;
3614 }
3615
3616 n_keep = std::min(a: slot.n_ctx - 4, b: n_keep);
3617
3618 const int n_left = slot.prompt.n_tokens() - n_keep;
3619 const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
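// illustrative example (hypothetical numbers): with n_keep = 256 and 4095 tokens in the
// context, n_left = 3839 and n_discard defaults to 1919, so positions [256, 2175) are
// removed and the remaining tokens are shifted left by 1919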
3620
3621 SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
3622
3623 llama_memory_seq_rm (mem: llama_get_memory(ctx), seq_id: slot.id, p0: n_keep , p1: n_keep + n_discard);
3624 llama_memory_seq_add(mem: llama_get_memory(ctx), seq_id: slot.id, p0: n_keep + n_discard, p1: slot.prompt.n_tokens(), delta: -n_discard);
3625
3626 // add generated tokens to cache
3627 // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
3628 {
3629 GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
3630
3631 llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
3632 for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
3633 new_tokens[i - n_discard] = new_tokens[i];
3634 }
3635
3636 new_tokens.resize(new_size: slot.prompt.tokens.size() - n_discard);
3637
3638 slot.prompt.tokens.clear();
3639 slot.prompt.tokens.insert(inp_tokens: new_tokens);
3640 }
3641
3642 slot.truncated = true;
3643 }
3644 }
3645
3646 // start populating the batch for this iteration
3647 common_batch_clear(batch);
3648
3649 // track whether a given slot can be batched with the slots already in the batch
3650 server_slot * slot_batched = nullptr;
3651
3652 auto accept_special_token = [&](server_slot & slot, llama_token token) {
3653 return params_base.special ||
3654 slot.task->params.sampling.preserved_tokens.find(x: token) != slot.task->params.sampling.preserved_tokens.end();
3655 };
3656
3657 // first, add sampled tokens from any ongoing sequences
3658 for (auto & slot : slots) {
3659 if (slot.state != SLOT_STATE_GENERATING) {
3660 continue;
3661 }
3662
3663 // check if we can batch this slot with the previous one
3664 if (!slot_batched) {
3665 slot_batched = &slot;
3666 } else if (!slot_batched->can_batch_with(other_slot&: slot)) {
3667 continue;
3668 }
3669
3670 slot.i_batch = batch.n_tokens;
3671
3672 common_batch_add(batch, id: slot.sampled, pos: slot.prompt.tokens.pos_next(), seq_ids: { slot.id }, logits: true);
3673
3674 slot.prompt.tokens.push_back(tok: slot.sampled);
3675
3676 SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
3677 slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
3678 }
3679
3680 // process in chunks of params.n_batch
3681 int32_t n_batch = llama_n_batch(ctx);
3682 int32_t n_ubatch = llama_n_ubatch(ctx);
3683
3684 float alora_scale = -1.0f;
3685 size_t alora_disabled_id = 0;
3686
3687 // next, batch any pending prompts without exceeding n_batch
3688 if (params_base.cont_batching || batch.n_tokens == 0) {
3689 for (auto & slot : slots) {
3690 // check if we can batch this slot with the previous one
3691 if (slot.is_processing()) {
3692 if (!slot_batched) {
3693 slot_batched = &slot;
3694 } else if (!slot_batched->can_batch_with(other_slot&: slot)) {
3695 continue;
3696 }
3697 }
3698
3699 // this slot still has a prompt to be processed
3700 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
3701 const auto & input_tokens = slot.task->tokens;
3702
3703 // TODO: maybe move branch to outside of this loop in the future
3704 if (slot.state == SLOT_STATE_STARTED) {
3705 slot.t_start_process_prompt = ggml_time_us();
3706 slot.t_start_generation = 0;
3707
3708 slot.state = SLOT_STATE_PROCESSING_PROMPT;
3709
3710 SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
3711 slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
3712
3713 // print prompt tokens (for debugging)
3714 /*if (1) {
3715 // first 16 tokens (avoid flooding logs)
3716 for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) {
3717 SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
3718 }
3719 } else {
3720 // all
3721 for (int i = 0; i < (int) input_tokens.size(); i++) {
3722 SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str());
3723 }
3724 }*/
3725
3726 // keep track of how many tokens we can reuse from the previous state
3727 int n_past = 0;
3728
3729 // empty prompt passed -> release the slot and send empty response
3730 if (input_tokens.empty()) {
3731 SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
3732
3733 slot.print_timings();
3734 send_final_response(slot);
3735 slot.release();
3736
3737 continue;
3738 }
3739
3740 // TODO: support memory-less logits computation
3741 if (slot.need_logits() && !llama_get_memory(ctx)) {
3742 send_error(slot, error: "the current context does not support logits computation, skipping", type: ERROR_TYPE_SERVER);
3743 slot.release();
3744 continue;
3745 }
3746
3747 if (!slot.can_split()) {
3748 if (slot.task->n_tokens() > n_ubatch) {
3749 send_error(slot, error: "input is too large to process. increase the physical batch size", type: ERROR_TYPE_SERVER);
3750 slot.release();
3751 continue;
3752 }
3753
3754 if (slot.task->n_tokens() > slot.n_ctx) {
3755 send_error(slot, error: "input is larger than the max context size. skipping", type: ERROR_TYPE_EXCEED_CONTEXT_SIZE);
3756 slot.release();
3757 continue;
3758 }
3759 } else {
3760 if (slot.task->n_tokens() >= slot.n_ctx) {
3761 send_error(slot, error: "the request exceeds the available context size, try increasing it", type: ERROR_TYPE_EXCEED_CONTEXT_SIZE);
3762 slot.release();
3763 continue;
3764 }
3765
3766 if (slot.task->params.cache_prompt) {
3767 // reuse any previously computed tokens that are common with the new prompt
3768 n_past = slot.prompt.tokens.get_common_prefix(b: input_tokens);
3769
3770 // if there is an alora invoked, don't cache after the invocation start
3771 if (slot.alora_invocation_start > 0) {
3772 SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
3773 n_past = std::min(a: n_past, b: slot.alora_invocation_start - 1);
3774 }
3775
3776 // reuse chunks from the cached prompt by shifting their KV cache to the new position
3777 if (params_base.n_cache_reuse > 0) {
3778 GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
3779
3780 size_t head_c = n_past; // cache
3781 size_t head_p = n_past; // current prompt
3782
3783 if (mctx) {
3784 // we should never reach this
3785 GGML_ABORT("not supported by multimodal");
3786 }
3787
3788 SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
3789
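// example (illustrative): with n_cache_reuse = 256, any run of >= 256 identical tokens found
// at position head_c in the old cache and head_p in the new prompt is kept and shifted by
// kv_shift = head_p - head_c instead of being recomputed; shorter matches are not reused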
3790 while (head_c < slot.prompt.tokens.size() &&
3791 head_p < input_tokens.size()) {
3792
3793 size_t n_match = 0;
3794 while (head_c + n_match < slot.prompt.tokens.size() &&
3795 head_p + n_match < input_tokens.size() &&
3796 slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
3797
3798 n_match++;
3799 }
3800
3801 if (n_match >= (size_t) params_base.n_cache_reuse) {
3802 SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
3803 //for (size_t i = head_p; i < head_p + n_match; i++) {
3804 // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
3805 //}
3806
3807 const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
3808
3809 llama_memory_seq_rm (mem: llama_get_memory(ctx), seq_id: slot.id, p0: head_p, p1: head_c);
3810 llama_memory_seq_add(mem: llama_get_memory(ctx), seq_id: slot.id, p0: head_c, p1: head_c + n_match, delta: kv_shift);
3811
3812 for (size_t i = 0; i < n_match; i++) {
3813 slot.prompt.tokens.set_token(pos: head_p + i, id: slot.prompt.tokens[head_c + i]);
3814 n_past++;
3815 }
3816
3817 head_c += n_match;
3818 head_p += n_match;
3819 } else {
3820 head_c += 1;
3821 }
3822 }
3823
3824 SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
3825 }
3826 } else {
3827 // if we don't cache the prompt, we have to remove all previous tokens
3828 n_past = 0;
3829 }
3830
3831 // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
3832 const auto n_swa = std::max(a: 1, b: llama_model_n_swa(model));
3833
3834 // the largest pos_min required for a checkpoint to be useful
3835 const auto pos_min_thold = std::max(a: 0, b: n_past - n_swa);
3836
3837 // note: disallow with mtmd contexts for now
3838 // https://github.com/ggml-org/llama.cpp/issues/17043
3839 if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
3840 const auto pos_min = llama_memory_seq_pos_min(mem: llama_get_memory(ctx), seq_id: slot.id);
3841 if (pos_min == -1) {
3842 SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
3843 GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
3844 }
3845
3846 // when the prompt prefix does not match, print the tokens around the mismatch
3847 // this is useful for debugging prompt caching
3848 if (slots_debug) {
3849 const int np0 = std::max<int>(a: n_past - 4, b: 0);
3850 const int np1 = std::min<int>(a: n_past + 6, b: std::min(a: slot.prompt.tokens.size(), b: slot.task->tokens.size()));
3851
3852 std::stringstream ss0;
3853 std::stringstream ss1;
3854
3855 std::stringstream st0;
3856 std::stringstream st1;
3857
3858 ss0 << "old: ... ";
3859 ss1 << "new: ... ";
3860
3861 for (int i = np0; i < np1; i++) {
3862 if (i == n_past) {
3863 ss0 << " | ";
3864 ss1 << " | ";
3865 }
3866
3867 {
3868 const auto token = slot.prompt.tokens[i];
3869 const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
3870 ss0 << piece;
3871 st0 << std::setw(8) << token;
3872 }
3873
3874 {
3875 const auto token = slot.task->tokens[i];
3876 const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
3877 ss1 << piece;
3878 st1 << std::setw(8) << token;
3879 }
3880 }
3881
3882 SLT_WRN(slot, "%s\n", ss0.str().c_str());
3883 SLT_WRN(slot, "%s\n", ss1.str().c_str());
3884
3885 SLT_WRN(slot, "%s\n", st0.str().c_str());
3886 SLT_WRN(slot, "%s\n", st1.str().c_str());
3887 }
3888
3889 if (pos_min > pos_min_thold) {
3890 // TODO: support can be added in the future when corresponding vision models get released
3891 GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
3892
3893 SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
3894
3895 // search for a context checkpoint
3896 const auto it = std::find_if(
3897 first: slot.prompt.checkpoints.rbegin(),
3898 last: slot.prompt.checkpoints.rend(),
3899 pred: [&](const auto & cur) {
3900 // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
3901 return cur.pos_min < pos_min_thold;
3902 }
3903 );
3904
3905 bool do_reset = it == slot.prompt.checkpoints.rend();
3906
3907 if (!do_reset) {
3908 // restore the context checkpoint
3909 const size_t checkpoint_size = it->data.size();
3910 const size_t n = llama_state_seq_set_data_ext(ctx, src: it->data.data(), size: checkpoint_size, dest_seq_id: slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
3911
3912 if (n != checkpoint_size) {
3913 SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
3914 do_reset = true;
3915 //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
3916 } else {
3917 n_past = std::min(a: n_past, b: std::max(a: it->pos_min + 1, b: it->pos_max));
3918 SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
3919 }
3920 }
3921
3922 if (do_reset) {
3923 SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
3924 "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
3925 n_past = 0;
3926 }
3927 }
3928 }
3929
3930 {
3931 // erase any checkpoints with pos_min > pos_min_thold
3932 for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
3933 const auto & cur = *it;
3934 if (cur.pos_min > pos_min_thold) {
3935 SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
3936 it = slot.prompt.checkpoints.erase(position: it);
3937 } else {
3938 ++it;
3939 }
3940 }
3941 }
3942 }
3943
3944 // [TAG_PROMPT_LOGITS]
3945 if (n_past == slot.task->n_tokens() && n_past > 0) {
3946 SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
3947 n_past--;
3948 SLT_WRN(slot, "n_past was set to %d\n", n_past);
3949 }
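// re-processing the last prompt token guarantees that logits are produced for this slot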
3950
3951 slot.n_prompt_tokens_cache = n_past;
3952 slot.n_prompt_tokens_processed = 0;
3953
3954 slot.prompt.tokens.keep_first(n: n_past);
3955 }
3956
3957 if (!slot.can_split()) {
3958 // cannot fit the prompt in the current batch - will try next iter
3959 if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
3960 continue;
3961 }
3962 }
3963
3964 // truncate any tokens that are beyond n_past for this slot
3965 const llama_pos p0 = slot.prompt.tokens.pos_next();
3966
3967 SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
3968
3969 if (!llama_memory_seq_rm(mem: llama_get_memory(ctx), seq_id: slot.id, p0, p1: -1)) {
3970 SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
3971 llama_memory_seq_rm(mem: llama_get_memory(ctx), seq_id: slot.id, p0: -1, p1: -1);
3972
3973 // there is no common part left
3974 slot.n_prompt_tokens_cache = 0;
3975
3976 slot.prompt.tokens.clear();
3977 }
3978
3979 // check if we should process the image
3980 if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
3981 // process the image
3982 size_t n_tokens_out = 0;
3983 int32_t res = input_tokens.process_chunk(ctx, mctx, idx: slot.prompt.n_tokens(), pos: slot.prompt.tokens.pos_next(), seq_id: slot.id, n_tokens_out);
3984 if (res != 0) {
3985 SLT_ERR(slot, "failed to process image, res = %d\n", res);
3986 send_error(slot, error: "failed to process image", type: ERROR_TYPE_SERVER);
3987 slot.release();
3988 continue;
3989 }
3990
3991 slot.n_prompt_tokens_processed += n_tokens_out;
3992
3993 // add the image chunk to cache
3994 {
3995 const auto & chunk = input_tokens.find_chunk(idx: slot.prompt.n_tokens());
3996 slot.prompt.tokens.push_back(chunk: chunk.get()); // copy
3997 }
3998 }
3999
4000 // If using an alora, there may be uncached tokens that come
4001 // before the invocation sequence. When this happens, the
4002 // tokens before the invocation sequence need to be
4003 // processed without the adapter in a separate batch, then
4004 // the adapter needs to be enabled for the remaining tokens.
4005 if (lora_all_alora(loras: slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
4006 SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
4007 const auto & enabled_loras = lora_get_enabled_ids(loras: slot.lora);
4008 GGML_ASSERT(enabled_loras.size() == 1);
4009 alora_scale = slot.lora[enabled_loras[0]].scale;
4010 slot.lora[enabled_loras[0]].scale = 0.0f;
4011 alora_disabled_id = enabled_loras[0];
4012 }
4013
4014 bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
4015
4016 // make checkpoints only for completion tasks
4017 do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
4018
4019 // make a checkpoint of the parts of the memory that cannot be rolled back.
4020 // checkpoints are created only if:
4021 // - the model uses SWA and we are not using `swa_full`
4022 // - the model architecture is marked as recurrent or hybrid
4023 //
4024 // TODO: try to make this conditional on the context or the memory module, instead of the model type
4025 do_checkpoint = do_checkpoint && (
4026 llama_model_is_recurrent(model) ||
4027 llama_model_is_hybrid(model) ||
4028 (llama_model_n_swa(model) > 0 && !params_base.swa_full)
4029 );
4030
4031 // add prompt tokens for processing in the current batch
4032 while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
4033 // get next token to process
4034 llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
4035 if (cur_tok == LLAMA_TOKEN_NULL) {
4036 break; // end of text chunk
4037 }
4038
4039 // if this is an alora request with pre-invocation
4040 // tokens that are not cached, we need to stop filling
4041 // this batch at those pre-invocation tokens.
4042 if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
4043 SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
4044 break;
4045 }
4046
4047 // embedding requires all tokens in the batch to be output
4048 common_batch_add(batch,
4049 id: cur_tok,
4050 pos: slot.prompt.tokens.pos_next(),
4051 seq_ids: { slot.id },
4052 logits: slot.need_embd());
4053 slot.prompt.tokens.push_back(tok: cur_tok);
4054
4055 slot.n_prompt_tokens_processed++;
4056
4057 // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
4058 if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
4059 break;
4060 }
4061 }
4062
4063 // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.prompt.tokens.str().c_str());
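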
4064
4065 SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
4066
4067 // entire prompt has been processed
4068 if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
4069 slot.state = SLOT_STATE_DONE_PROMPT;
4070
4071 GGML_ASSERT(batch.n_tokens > 0);
4072
4073 common_sampler_reset(gsmpl: slot.smpl);
4074
4075 // Process all prompt tokens through sampler system
4076 for (int i = 0; i < slot.task->n_tokens(); ++i) {
4077 llama_token id = input_tokens[i];
4078 if (id != LLAMA_TOKEN_NULL) {
4079 common_sampler_accept(gsmpl: slot.smpl, token: id, accept_grammar: false);
4080 }
4081 }
4082
4083 // extract the logits only for the last token
4084 batch.logits[batch.n_tokens - 1] = true;
4085
4086 slot.n_decoded = 0;
4087 slot.i_batch = batch.n_tokens - 1;
4088
4089 SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
4090
4091 const auto pos_min = llama_memory_seq_pos_min(mem: llama_get_memory(ctx), seq_id: slot.id);
4092 const auto pos_max = llama_memory_seq_pos_max(mem: llama_get_memory(ctx), seq_id: slot.id);
4093
4094 // no need for empty or small checkpoints
4095 do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
4096
4097 // no need to create checkpoints that are too close together
4098 do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
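// e.g. with n_ctx_checkpoints = 3, checkpoints are created more than 64 positions apart and
// the oldest one is evicted (FIFO) when a fourth would be created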
4099
4100 if (do_checkpoint) {
4101 while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
4102 // make room for the new checkpoint, if needed
4103 const auto & cur = slot.prompt.checkpoints.front();
4104
4105 SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
4106 cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
4107
4108 slot.prompt.checkpoints.erase(position: slot.prompt.checkpoints.begin());
4109 }
4110
4111 const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, seq_id: slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
4112
4113 auto & cur = slot.prompt.checkpoints.emplace_back(args: server_prompt_checkpoint{
4114 /*.pos_min = */ pos_min,
4115 /*.pos_max = */ pos_max,
4116 /*.data = */ std::vector<uint8_t>(checkpoint_size),
4117 });
4118
4119 llama_state_seq_get_data_ext(ctx, dst: cur.data.data(), size: checkpoint_size, seq_id: slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
4120
4121 SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
4122 (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
4123 }
4124 }
4125 }
4126
4127 if (batch.n_tokens >= n_batch) {
4128 break;
4129 }
4130 }
4131 }
4132
4133 if (batch.n_tokens == 0) {
4134 SRV_WRN("%s", "no tokens to decode\n");
4135 return;
4136 }
4137
4138 SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
4139
4140 if (slot_batched) {
4141 // apply lora, only need to do it once per batch
4142 common_set_adapter_lora(ctx, lora&: slot_batched->lora);
4143
4144 // if the lora is temporarily disabled for an alora, re-enable it
4145 // for next time
4146 if (alora_scale > 0.0f) {
4147 SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
4148 slot_batched->lora[alora_disabled_id].scale = alora_scale;
4149 }
4150
4151 llama_set_embeddings(ctx, embeddings: slot_batched->need_embd());
4152 }
4153
4154 int32_t i_next = 0;
4155
4156 // process the created batch of tokens
4157 for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
4158 const int32_t n_tokens = std::min(a: n_batch, b: batch.n_tokens - i);
4159
4160 llama_batch batch_view = {
4161 .n_tokens: n_tokens,
4162 .token: batch.token + i,
4163 .embd: nullptr,
4164 .pos: batch.pos + i,
4165 .n_seq_id: batch.n_seq_id + i,
4166 .seq_id: batch.seq_id + i,
4167 .logits: batch.logits + i,
4168 };
4169
4170 const int ret = llama_decode(ctx, batch: batch_view);
4171
4172 metrics.on_decoded(slots);
4173
4174 if (ret != 0) {
4175 {
4176 std::string err;
4177
4178 if (n_batch == 1 && ret == 1) {
4179 // TODO: try to terminate only the largest active slot/sequence and continue with the rest
4180 // need to remove the tokens from the current batch too
4181 err = "Context size has been exceeded.";
4182 }
4183
4184 if (ret == -1) {
4185 err = "Invalid input batch.";
4186 }
4187
4188 if (ret < -1) {
4189 // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
4190 err = "Compute error.";
4191 }
4192
4193 // TODO: handle ret == 2 (abort) when we start aborting
4194
4195 if (!err.empty()) {
4196 SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
4197
4198 for (auto & slot : slots) {
4199 if (slot.is_processing()) {
4200 send_error(slot, error: err);
4201 slot.release();
4202 }
4203 }
4204
4205 break;
4206 }
4207 }
4208
4209 // retry with half the batch size to try to find a free slot in the KV cache
4210 if (!try_purge_idle_slots()) {
4211 n_batch /= 2;
4212 }
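// note: on repeated failures n_batch keeps halving (e.g. 2048 -> 1024 -> ... -> 1);
// at n_batch == 1 a further ret == 1 is reported as a context-size error above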
4213
4214 SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
4215
4216 continue; // continue loop of n_batch
4217 }
4218
4219 // move the head of the batch forward with the number of tokens we just processed
4220 i_next = i + n_tokens;
4221
4222 // on successful decode, restore the original batch size
4223 n_batch = llama_n_batch(ctx);
4224
4225 for (auto & slot : slots) {
4226 // optionally send prompt processing progress
4227 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
4228 if (slot.task->params.stream && slot.task->params.return_progress) {
4229 send_partial_response(slot, tkn: {}, is_progress: true);
4230 }
4231 }
4232
4233 if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
4234 continue; // continue loop of slots
4235 }
4236
4237 if (slot.state == SLOT_STATE_DONE_PROMPT) {
4238 if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) {
4239 // prompt evaluated for embedding
4240 send_embedding(slot, batch: batch_view);
4241 slot.release();
4242 slot.i_batch = -1;
4243 continue; // continue loop of slots
4244 }
4245
4246 if (slot.task->type == SERVER_TASK_TYPE_RERANK) {
4247 send_rerank(slot, batch: batch_view);
4248 slot.release();
4249 slot.i_batch = -1;
4250 continue; // continue loop of slots
4251 }
4252
4253 // prompt evaluated for next-token prediction
4254 slot.state = SLOT_STATE_GENERATING;
4255 } else if (slot.state != SLOT_STATE_GENERATING) {
4256 continue; // continue loop of slots
4257 }
4258
4259 const int tok_idx = slot.i_batch - i;
4260
4261 llama_token id = common_sampler_sample(gsmpl: slot.smpl, ctx, idx: tok_idx);
4262
4263 slot.i_batch = -1;
4264
4265 common_sampler_accept(gsmpl: slot.smpl, token: id, accept_grammar: true);
4266
4267 slot.n_decoded += 1;
4268
4269 const int64_t t_current = ggml_time_us();
4270
4271 if (slot.n_decoded == 1) {
4272 slot.t_start_generation = t_current;
4273 slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
4274 metrics.on_prompt_eval(slot);
4275 }
4276
4277 slot.t_token_generation = std::max<int64_t>(a: 1, b: t_current - slot.t_start_generation) / 1e3;
4278
4279 completion_token_output result;
4280 result.tok = id;
4281 result.text_to_send = common_token_to_piece(ctx, token: result.tok, special: accept_special_token(slot, result.tok));
4282 result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
4283
4284 if (slot.task->params.sampling.n_probs > 0) {
4285 populate_token_probs(slot, result, post_sampling: slot.task->params.post_sampling_probs, special: params_base.special, idx: tok_idx);
4286 }
4287
4288 if (!process_token(result, slot)) {
4289 // release slot because of stop condition
4290 slot.print_timings();
4291 send_final_response(slot);
4292 metrics.on_prediction(slot);
4293 slot.release();
4294
4295 continue;
4296 }
4297 }
4298
4299 // do speculative decoding
4300 // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
4301 // perform the speculative drafting for all sequences at the same time in a single batch
4302 for (auto & slot : slots) {
4303 if (!slot.is_processing() || !slot.can_speculate()) {
4304 continue;
4305 }
4306
4307 if (slot.state != SLOT_STATE_GENERATING) {
4308 continue;
4309 }
4310
4311 if (mctx) {
4312 // we should never reach this, as speculative is automatically disabled if mmproj is loaded
4313 GGML_ABORT("not supported by multimodal");
4314 }
4315
4316 // determine the max draft that fits the current slot state
4317 int n_draft_max = slot.task->params.speculative.n_max;
4318
4319 // note: slot.prompt is not yet expanded with the `id` token sampled above
4320 // also, need to leave space for 1 extra token to allow context shifts
4321 n_draft_max = std::min(a: n_draft_max, b: slot.n_ctx - slot.prompt.n_tokens() - 2);
4322
4323 if (slot.n_remaining > 0) {
4324 n_draft_max = std::min(a: n_draft_max, b: slot.n_remaining - 1);
4325 }
4326
4327 SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
4328
4329 if (n_draft_max < slot.task->params.speculative.n_min) {
4330 SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);
4331
4332 continue;
4333 }
4334
4335 llama_token id = slot.sampled;
4336
4337 struct common_speculative_params params_spec;
4338 params_spec.n_draft = n_draft_max;
4339 params_spec.n_reuse = llama_n_ctx(ctx: slot.ctx_dft) - slot.task->params.speculative.n_max;
4340 params_spec.p_min = slot.task->params.speculative.p_min;
4341
4342 const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
4343 llama_tokens draft = common_speculative_gen_draft(spec: slot.spec, params: params_spec, prompt: cached_text_tokens, id_last: id);
4344
4345 // ignore small drafts
4346 if (slot.task->params.speculative.n_min > (int) draft.size()) {
4347 SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
4348
4349 continue;
4350 }
4351
4352 // keep track of total number of drafted tokens tested
4353 slot.n_draft_total += draft.size();
4354
4355 // construct the speculation batch
4356 common_batch_clear(batch&: slot.batch_spec);
4357 common_batch_add (batch&: slot.batch_spec, id, pos: slot.prompt.tokens.pos_next(), seq_ids: { slot.id }, logits: true);
4358
4359 for (size_t i = 0; i < draft.size(); ++i) {
4360 common_batch_add(batch&: slot.batch_spec, id: draft[i], pos: slot.prompt.tokens.pos_next() + 1 + i, seq_ids: { slot.id }, logits: true);
4361 }
4362
4363 SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
4364
4365 llama_decode(ctx, batch: slot.batch_spec);
4366
4367 // the accepted tokens from the speculation
4368 const auto ids = common_sampler_sample_and_accept_n(gsmpl: slot.smpl, ctx, draft);
4369
4370 slot.n_decoded += ids.size();
4371
4372 // update how many tokens out of those tested were accepted
4373 slot.n_draft_accepted += ids.size() - 1;
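// note: ids is assumed to always include one token sampled by the target model after the
// last accepted draft token, hence the "- 1" when counting accepted drafts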
4374
4375 slot.prompt.tokens.push_back(tok: id);
4376 slot.prompt.tokens.insert(inp_tokens: {ids.begin(), ids.end() - 1});
4377
4378 llama_memory_seq_rm(mem: llama_get_memory(ctx), seq_id: slot.id, p0: slot.prompt.n_tokens(), p1: -1);
4379
4380 for (size_t i = 0; i < ids.size(); ++i) {
4381 completion_token_output result;
4382
4383 result.tok = ids[i];
4384 result.text_to_send = common_token_to_piece(ctx, token: result.tok, special: accept_special_token(slot, result.tok));
4385 result.prob = 1.0f; // set later
4386
4387 // TODO: set result.probs
4388
4389 if (!process_token(result, slot)) {
4390 slot.print_timings();
4391 send_final_response(slot);
4392 metrics.on_prediction(slot);
4393 slot.release();
4394
4395 break;
4396 }
4397 }
4398
4399 SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
4400 }
4401 }
4402
4403 SRV_DBG("%s", "run slots completed\n");
4404 }
4405
4406 json model_meta() const {
4407 return json {
4408 {"vocab_type", llama_vocab_type (vocab)},
4409 {"n_vocab", llama_vocab_n_tokens (vocab)},
4410 {"n_ctx_train", llama_model_n_ctx_train(model)},
4411 {"n_embd", llama_model_n_embd (model)},
4412 {"n_params", llama_model_n_params (model)},
4413 {"size", llama_model_size (model)},
4414 };
4415 }
4416};
4417
4418static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
4419 // skip logging GH Copilot requests (it polls /v1/health when using the default port)
4420 if (req.path == "/v1/health") {
4421 return;
4422 }
4423
4424 // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
4425
4426 SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
4427
4428 SRV_DBG("request: %s\n", req.body.c_str());
4429 SRV_DBG("response: %s\n", res.body.c_str());
4430}
4431
4432std::function<void(int)> shutdown_handler;
4433std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
4434
4435inline void signal_handler(int signal) {
4436 if (is_terminating.test_and_set()) {
4437 // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
4438 // this is for a better developer experience; we can remove it once the server is stable enough
4439 fprintf(stderr, format: "Received second interrupt, terminating immediately.\n");
4440 exit(status: 1);
4441 }
4442
4443 shutdown_handler(signal);
4444}
4445
4446int main(int argc, char ** argv) {
4447 // own arguments required by this example
4448 common_params params;
4449
4450 if (!common_params_parse(argc, argv, params, ex: LLAMA_EXAMPLE_SERVER)) {
4451 return 1;
4452 }
4453
4454 // TODO: should we have a separate n_parallel parameter for the server?
4455 // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
4456 // TODO: this is a common configuration that is suitable for most local use cases
4457 // however, overriding the parameters is a bit confusing - figure out something more intuitive
4458 if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
4459 LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__);
4460
4461 params.n_parallel = 4;
4462 params.kv_unified = true;
4463 }
4464
4465 common_init();
4466
4467 // struct that contains llama context and inference
4468 server_context ctx_server;
4469
4470 llama_backend_init();
4471 llama_numa_init(numa: params.numa);
4472
4473 LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
4474 LOG_INF("\n");
4475 LOG_INF("%s\n", common_params_get_system_info(params).c_str());
4476 LOG_INF("\n");
4477
4478 std::unique_ptr<httplib::Server> svr;
4479#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
4480 if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
4481 LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
4482 svr.reset(
4483 new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
4484 );
4485 } else {
4486 LOG_INF("Running without SSL\n");
4487 svr.reset(new httplib::Server());
4488 }
4489#else
4490 if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
4491 LOG_ERR("Server is built without SSL support\n");
4492 return 1;
4493 }
4494 svr.reset(p: new httplib::Server());
4495#endif
4496
4497 std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
4498
4499 svr->set_default_headers({{"Server", "llama.cpp"}});
4500 svr->set_logger(log_server_request);
4501
4502 auto res_error = [](httplib::Response & res, const json & error_data) {
4503 json final_response {{"error", error_data}};
4504 res.set_content(s: safe_json_to_str(data: final_response), MIMETYPE_JSON);
4505 res.status = json_value(body: error_data, key: "code", default_value: 500);
4506 };
4507
4508 auto res_ok = [](httplib::Response & res, const json & data) {
4509 res.set_content(s: safe_json_to_str(data), MIMETYPE_JSON);
4510 res.status = 200;
4511 };
4512
4513 svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
4514 std::string message;
4515 try {
4516 std::rethrow_exception(ep);
4517 } catch (const std::exception & e) {
4518 message = e.what();
4519 } catch (...) {
4520 message = "Unknown Exception";
4521 }
4522
4523 try {
4524 json formatted_error = format_error_response(message, type: ERROR_TYPE_SERVER);
4525 LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
4526 res_error(res, formatted_error);
4527 } catch (const std::exception & e) {
4528 LOG_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
4529 }
4530 });
4531
4532 svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
4533 if (res.status == 404) {
4534 res_error(res, format_error_response(message: "File Not Found", type: ERROR_TYPE_NOT_FOUND));
4535 }
4536 // for other error codes, we skip processing here because it's already done by res_error()
4537 });
4538
4539 // set timeouts and change hostname and port
4540 svr->set_read_timeout (sec: params.timeout_read);
4541 svr->set_write_timeout(sec: params.timeout_write);
4542
4543 std::unordered_map<std::string, std::string> log_data;
4544
4545 log_data["hostname"] = params.hostname;
4546 log_data["port"] = std::to_string(val: params.port);
4547
4548 if (params.api_keys.size() == 1) {
4549 auto key = params.api_keys[0];
4550 log_data["api_key"] = "api_key: ****" + key.substr(pos: std::max(a: (int)(key.length() - 4), b: 0));
4551 } else if (params.api_keys.size() > 1) {
4552 log_data["api_key"] = "api_key: " + std::to_string(val: params.api_keys.size()) + " keys loaded";
4553 }
4554
4555 // Necessary similarity of prompt for slot selection
4556 ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
4557
4558 //
4559 // Middlewares
4560 //
4561
4562 auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
4563 static const std::unordered_set<std::string> public_endpoints = {
4564 "/health",
4565 "/v1/health",
4566 "/models",
4567 "/v1/models",
4568 "/api/tags"
4569 };
4570
4571 // If API key is not set, skip validation
4572 if (params.api_keys.empty()) {
4573 return true;
4574 }
4575
4576 // If the path is public or a static file, skip validation
4577 if (public_endpoints.find(x: req.path) != public_endpoints.end() || req.path == "/") {
4578 return true;
4579 }
4580
4581 // Check for API key in the header
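// expected format (example): "Authorization: Bearer <one of the configured API keys>"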
4582 auto auth_header = req.get_header_value(key: "Authorization");
4583
4584 std::string prefix = "Bearer ";
4585 if (auth_header.substr(pos: 0, n: prefix.size()) == prefix) {
4586 std::string received_api_key = auth_header.substr(pos: prefix.size());
4587 if (std::find(first: params.api_keys.begin(), last: params.api_keys.end(), val: received_api_key) != params.api_keys.end()) {
4588 return true; // API key is valid
4589 }
4590 }
4591
4592 // API key is invalid or not provided
4593 res_error(res, format_error_response(message: "Invalid API Key", type: ERROR_TYPE_AUTHENTICATION));
4594
4595 LOG_WRN("Unauthorized: Invalid API Key\n");
4596
4597 return false;
4598 };
4599
4600 auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
4601 server_state current_state = state.load();
4602 if (current_state == SERVER_STATE_LOADING_MODEL) {
4603 auto tmp = string_split<std::string>(input: req.path, separator: '.');
4604 if (req.path == "/" || tmp.back() == "html") {
4605 res.set_content(s: reinterpret_cast<const char*>(loading_html), n: loading_html_len, content_type: "text/html; charset=utf-8");
4606 res.status = 503;
4607 } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
4608 // allow the models endpoint to be accessed during loading
4609 return true;
4610 } else {
4611 res_error(res, format_error_response(message: "Loading model", type: ERROR_TYPE_UNAVAILABLE));
4612 }
4613 return false;
4614 }
4615 return true;
4616 };
4617
4618 // register server middlewares
4619 svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
4620 res.set_header(key: "Access-Control-Allow-Origin", val: req.get_header_value(key: "Origin"));
4621 // If this is an OPTIONS request, skip validation because browsers don't include the Authorization header
4622 if (req.method == "OPTIONS") {
4623 res.set_header(key: "Access-Control-Allow-Credentials", val: "true");
4624 res.set_header(key: "Access-Control-Allow-Methods", val: "GET, POST");
4625 res.set_header(key: "Access-Control-Allow-Headers", val: "*");
4626 res.set_content(s: "", content_type: "text/html"); // blank response, no data
4627 return httplib::Server::HandlerResponse::Handled; // skip further processing
4628 }
4629 if (!middleware_server_state(req, res)) {
4630 return httplib::Server::HandlerResponse::Handled;
4631 }
4632 if (!middleware_validate_api_key(req, res)) {
4633 return httplib::Server::HandlerResponse::Handled;
4634 }
4635 return httplib::Server::HandlerResponse::Unhandled;
4636 });
4637
4638 //
4639 // Route handlers (or controllers)
4640 //
4641
4642 const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
4643 // error and loading states are handled by middleware
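// example response once the model has finished loading: 200 {"status": "ok"}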
4644 json health = {{"status", "ok"}};
4645 res_ok(res, health);
4646 };
4647
4648 const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
4649 if (!params.endpoint_slots) {
4650 res_error(res, format_error_response(message: "This server does not support slots endpoint. Start it with `--slots`", type: ERROR_TYPE_NOT_SUPPORTED));
4651 return;
4652 }
4653
4654 // request slots data using task queue
4655 int task_id = ctx_server.queue_tasks.get_new_id();
4656 {
4657 server_task task(SERVER_TASK_TYPE_METRICS);
4658 task.id = task_id;
4659 ctx_server.queue_results.add_waiting_task_id(id_task: task_id);
4660 ctx_server.queue_tasks.post(task: std::move(task), front: true); // high-priority task
4661 }
4662
4663 // get the result
4664 server_task_result_ptr result = ctx_server.queue_results.recv(id_task: task_id);
4665 ctx_server.queue_results.remove_waiting_task_id(id_task: task_id);
4666
4667 if (result->is_error()) {
4668 res_error(res, result->to_json());
4669 return;
4670 }
4671
4672 // TODO: get rid of this dynamic_cast
4673 auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
4674 GGML_ASSERT(res_task != nullptr);
4675
4676 // optionally return "fail_on_no_slot" error
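// example: passing ?fail_on_no_slot=1 makes the endpoint return an "unavailable" error
// instead of slot data while every slot is busy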
4677 if (req.has_param(key: "fail_on_no_slot")) {
4678 if (res_task->n_idle_slots == 0) {
4679 res_error(res, format_error_response(message: "no slot available", type: ERROR_TYPE_UNAVAILABLE));
4680 return;
4681 }
4682 }
4683
4684 res_ok(res, res_task->slots_data);
4685 };
4686
4687 const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
4688 if (!params.endpoint_metrics) {
4689 res_error(res, format_error_response(message: "This server does not support metrics endpoint. Start it with `--metrics`", type: ERROR_TYPE_NOT_SUPPORTED));
4690 return;
4691 }
4692
4693 // request slots data using task queue
4694 int task_id = ctx_server.queue_tasks.get_new_id();
4695 {
4696 server_task task(SERVER_TASK_TYPE_METRICS);
4697 task.id = task_id;
4698 ctx_server.queue_results.add_waiting_task_id(id_task: task_id);
4699 ctx_server.queue_tasks.post(task: std::move(task), front: true); // high-priority task
4700 }
4701
4702 // get the result
4703 server_task_result_ptr result = ctx_server.queue_results.recv(id_task: task_id);
4704 ctx_server.queue_results.remove_waiting_task_id(id_task: task_id);
4705
4706 if (result->is_error()) {
4707 res_error(res, result->to_json());
4708 return;
4709 }
4710
4711 // TODO: get rid of this dynamic_cast
4712 auto res_task = dynamic_cast<server_task_result_metrics*>(result.get());
4713 GGML_ASSERT(res_task != nullptr);
4714
4715 // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
4716 json all_metrics_def = json {
4717 {"counter", {{
4718 {"name", "prompt_tokens_total"},
4719 {"help", "Number of prompt tokens processed."},
4720 {"value", (uint64_t) res_task->n_prompt_tokens_processed_total}
4721 }, {
4722 {"name", "prompt_seconds_total"},
4723 {"help", "Prompt process time"},
4724 {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3}
4725 }, {
4726 {"name", "tokens_predicted_total"},
4727 {"help", "Number of generation tokens processed."},
4728 {"value", (uint64_t) res_task->n_tokens_predicted_total}
4729 }, {
4730 {"name", "tokens_predicted_seconds_total"},
4731 {"help", "Predict process time"},
4732 {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3}
4733 }, {
4734 {"name", "n_decode_total"},
4735 {"help", "Total number of llama_decode() calls"},
4736 {"value", res_task->n_decode_total}
4737 }, {
4738 {"name", "n_tokens_max"},
4739 {"help", "Largest observed n_tokens."},
4740 {"value", res_task->n_tokens_max}
4741 }, {
4742 {"name", "n_busy_slots_per_decode"},
4743 {"help", "Average number of busy slots per llama_decode() call"},
4744 {"value", (float) res_task->n_busy_slots_total / std::max(a: (float) res_task->n_decode_total, b: 1.f)}
4745 }}},
4746 {"gauge", {{
4747 {"name", "prompt_tokens_seconds"},
4748 {"help", "Average prompt throughput in tokens/s."},
4749 {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.}
4750 },{
4751 {"name", "predicted_tokens_seconds"},
4752 {"help", "Average generation throughput in tokens/s."},
4753 {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.}
4754 },{
4755 {"name", "requests_processing"},
4756 {"help", "Number of requests processing."},
4757 {"value", (uint64_t) res_task->n_processing_slots}
4758 },{
4759 {"name", "requests_deferred"},
4760 {"help", "Number of requests deferred."},
4761 {"value", (uint64_t) res_task->n_tasks_deferred}
4762 }}}
4763 };
4764
4765 std::stringstream prometheus;
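// each metric is emitted in the Prometheus text exposition format, e.g. (value illustrative):
//   # HELP llamacpp:prompt_tokens_total Number of prompt tokens processed.
//   # TYPE llamacpp:prompt_tokens_total counter
//   llamacpp:prompt_tokens_total 1024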
4766
4767 for (const auto & el : all_metrics_def.items()) {
4768 const auto & type = el.key();
4769 const auto & metrics_def = el.value();
4770
4771 for (const auto & metric_def : metrics_def) {
4772 const std::string name = metric_def.at(key: "name");
4773 const std::string help = metric_def.at(key: "help");
4774
4775 auto value = json_value(body: metric_def, key: "value", default_value: 0.);
4776 prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
4777 << "# TYPE llamacpp:" << name << " " << type << "\n"
4778 << "llamacpp:" << name << " " << value << "\n";
4779 }
4780 }
4781
4782 res.set_header(key: "Process-Start-Time-Unix", val: std::to_string(val: res_task->t_start));
4783
4784 res.set_content(s: prometheus.str(), content_type: "text/plain; version=0.0.4");
4785 res.status = 200; // HTTP OK
4786 };
4787
4788 const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
4789 json request_data = json::parse(i: req.body);
4790 std::string filename = request_data.at(key: "filename");
4791 if (!fs_validate_filename(filename)) {
4792 res_error(res, format_error_response(message: "Invalid filename", type: ERROR_TYPE_INVALID_REQUEST));
4793 return;
4794 }
4795 std::string filepath = params.slot_save_path + filename;
4796
4797 int task_id = ctx_server.queue_tasks.get_new_id();
4798 {
4799 server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
4800 task.id = task_id;
4801 task.slot_action.slot_id = id_slot;
4802 task.slot_action.filename = filename;
4803 task.slot_action.filepath = filepath;
4804
4805 ctx_server.queue_results.add_waiting_task_id(id_task: task_id);
4806 ctx_server.queue_tasks.post(task: std::move(task));
4807 }
4808
4809 server_task_result_ptr result = ctx_server.queue_results.recv(id_task: task_id);
4810 ctx_server.queue_results.remove_waiting_task_id(id_task: task_id);
4811
4812 if (result->is_error()) {
4813 res_error(res, result->to_json());
4814 return;
4815 }
4816
4817 res_ok(res, result->to_json());
4818 };
4819
4820 const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
4821 json request_data = json::parse(i: req.body);
4822 std::string filename = request_data.at(key: "filename");
4823 if (!fs_validate_filename(filename)) {
4824 res_error(res, format_error_response(message: "Invalid filename", type: ERROR_TYPE_INVALID_REQUEST));
4825 return;
4826 }
4827 std::string filepath = params.slot_save_path + filename;
4828
4829 int task_id = ctx_server.queue_tasks.get_new_id();
4830 {
4831 server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
4832 task.id = task_id;
4833 task.slot_action.slot_id = id_slot;
4834 task.slot_action.filename = filename;
4835 task.slot_action.filepath = filepath;
4836
4837 ctx_server.queue_results.add_waiting_task_id(id_task: task_id);
4838 ctx_server.queue_tasks.post(task: std::move(task));
4839 }
4840
4841 server_task_result_ptr result = ctx_server.queue_results.recv(id_task: task_id);
4842 ctx_server.queue_results.remove_waiting_task_id(id_task: task_id);
4843
4844 if (result->is_error()) {
4845 res_error(res, result->to_json());
4846 return;
4847 }
4848
4849 GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
4850 res_ok(res, result->to_json());
4851 };
4852
4853 const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
4854 int task_id = ctx_server.queue_tasks.get_new_id();
4855 {
4856 server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
4857 task.id = task_id;
4858 task.slot_action.slot_id = id_slot;
4859
4860 ctx_server.queue_results.add_waiting_task_id(id_task: task_id);
4861 ctx_server.queue_tasks.post(task: std::move(task));
4862 }
4863
4864 server_task_result_ptr result = ctx_server.queue_results.recv(id_task: task_id);
4865 ctx_server.queue_results.remove_waiting_task_id(id_task: task_id);
4866
4867 if (result->is_error()) {
4868 res_error(res, result->to_json());
4869 return;
4870 }
4871
4872 GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
4873 res_ok(res, result->to_json());
4874 };
4875
    const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
        if (params.slot_save_path.empty()) {
            res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

        std::string id_slot_str = req.path_params.at("id_slot");
        int id_slot;

        try {
            id_slot = std::stoi(id_slot_str);
        } catch (const std::exception &) {
            res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        std::string action = req.get_param_value("action");

        if (action == "save") {
            handle_slots_save(req, res, id_slot);
        } else if (action == "restore") {
            handle_slots_restore(req, res, id_slot);
        } else if (action == "erase") {
            handle_slots_erase(req, res, id_slot);
        } else {
            res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
        }
    };

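    // Illustrative usage of the slot save/restore/erase endpoint (not part of the server code),
    // assuming the default host/port, no --api-prefix, and a server started with --slot-save-path:
    //   curl -X POST 'http://localhost:8080/slots/0?action=save'    -d '{"filename": "slot0.bin"}'
    //   curl -X POST 'http://localhost:8080/slots/0?action=restore' -d '{"filename": "slot0.bin"}'
    //   curl -X POST 'http://localhost:8080/slots/0?action=erase'
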
    const auto handle_props = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
        json default_generation_settings_for_props;

        {
            slot_params params;

            params.sampling = ctx_server.params_base.sampling;

            default_generation_settings_for_props = json {
                {"params", params.to_json(true)},
                {"n_ctx",  ctx_server.slots[0].n_ctx},
            };
        }

        // this endpoint is publicly available, please only return what is safe to be exposed
        json data = {
            { "default_generation_settings", default_generation_settings_for_props },
            { "total_slots",                 ctx_server.params_base.n_parallel },
            { "model_alias",                 ctx_server.params_base.model_alias },
            { "model_path",                  ctx_server.params_base.model.path },
            { "modalities",                  json {
                {"vision", ctx_server.oai_parser_opt.allow_image},
                {"audio",  ctx_server.oai_parser_opt.allow_audio},
            } },
            { "endpoint_slots",              params.endpoint_slots },
            { "endpoint_props",              params.endpoint_props },
            { "endpoint_metrics",            params.endpoint_metrics },
            { "webui",                       params.webui },
            { "chat_template",               common_chat_templates_source(ctx_server.chat_templates.get()) },
            { "bos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true) },
            { "eos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true) },
            { "build_info",                  build_info },
        };
        if (ctx_server.params_base.use_jinja) {
            if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
                data["chat_template_tool_use"] = tool_use_src;
            }
        }

        res_ok(res, data);
    };

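    // Illustrative usage (not part of the server code), assuming the default host/port and no --api-prefix:
    //   curl http://localhost:8080/props
    // When the server is started with --props, properties can also be POSTed back to the same path.
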
    const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
        if (!ctx_server.params_base.endpoint_props) {
            res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

        json data = json::parse(req.body);

        // update any props here

        res_ok(res, {{ "success", true }});
    };

    const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
        bool has_mtmd = ctx_server.mctx != nullptr;
        json data = {
            {
                "template", common_chat_templates_source(ctx_server.chat_templates.get()),
            },
            {
                "model_info", {
                    { "llama.context_length", ctx_server.slots.back().n_ctx, },
                }
            },
            {"modelfile", ""},
            {"parameters", ""},
            {"template", common_chat_templates_source(ctx_server.chat_templates.get())},
            {"details", {
                {"parent_model", ""},
                {"format", "gguf"},
                {"family", ""},
                {"families", {""}},
                {"parameter_size", ""},
                {"quantization_level", ""}
            }},
            {"model_info", ""},
            {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}
        };

        res_ok(res, data);
    };

    // handle completion-like requests (completion, chat, infill)
    // we can optionally provide a custom format for partial results and final results
    const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
            server_task_type type,
            json & data,
            const std::vector<raw_buffer> & files,
            const std::function<bool()> & is_connection_closed,
            httplib::Response & res,
            oaicompat_type oaicompat) -> void {
        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

        auto completion_id = gen_chatcmplid();
        std::unordered_set<int> task_ids;
        try {
            std::vector<server_task> tasks;

            const auto & prompt = data.at("prompt");
            // TODO: this log can become very long, put it behind a flag or think about a more compact format
            //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());

            // process prompt
            std::vector<server_tokens> inputs;

            if (oaicompat && ctx_server.mctx != nullptr) {
                // this is the case used by the OAI-compatible chat path with MTMD. TODO: it can be moved to the path below
                inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
            } else {
                // everything else, including multimodal completions
                inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
            }
            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
            tasks.reserve(inputs.size());
            for (size_t i = 0; i < inputs.size(); i++) {
                auto n_prompt_tokens = inputs[i].size();
                if (n_prompt_tokens >= n_ctx_slot) {
                    json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                    error_data["n_prompt_tokens"] = n_prompt_tokens;
                    error_data["n_ctx"] = n_ctx_slot;
                    res_error(res, error_data);
                    return;
                }
                server_task task = server_task(type);

                task.id = ctx_server.queue_tasks.get_new_id();
                task.index = i;

                task.tokens = std::move(inputs[i]);
                task.params = server_task::params_from_json_cmpl(
                    ctx_server.ctx,
                    ctx_server.params_base,
                    data);
                task.id_slot = json_value(data, "id_slot", -1);

                // OAI-compat
                task.params.oaicompat = oaicompat;
                task.params.oaicompat_cmpl_id = completion_id;
                // oaicompat_model is already populated by params_from_json_cmpl

                tasks.push_back(std::move(task));
            }

            task_ids = server_task::get_list_id(tasks);
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(std::move(tasks));
        } catch (const std::exception & e) {
            res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        bool stream = json_value(data, "stream", false);

        if (!stream) {
            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
                if (results.size() == 1) {
                    // single result
                    res_ok(res, results[0]->to_json());
                } else {
                    // multiple results (multitask)
                    json arr = json::array();
                    for (auto & res : results) {
                        arr.push_back(res->to_json());
                    }
                    res_ok(res, arr);
                }
            }, [&](const json & error_data) {
                res_error(res, error_data);
            }, is_connection_closed);

            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
        } else {
            const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
                    json res_json = result->to_json();
                    if (res_json.is_array()) {
                        for (const auto & res : res_json) {
                            if (!server_sent_event(sink, res)) {
                                // sending failed (HTTP connection closed), cancel the generation
                                return false;
                            }
                        }
                        return true;
                    } else {
                        return server_sent_event(sink, res_json);
                    }
                }, [&](const json & error_data) {
                    server_sent_event(sink, json{{"error", error_data}});
                }, [&sink]() {
                    // note: do not use req.is_connection_closed here because req is already destroyed
                    return !sink.is_writable();
                });
                if (oaicompat != OAICOMPAT_TYPE_NONE) {
                    static const std::string ev_done = "data: [DONE]\n\n";
                    sink.write(ev_done.data(), ev_done.size());
                }
                sink.done();
                return false;
            };

            auto on_complete = [task_ids, &ctx_server] (bool) {
                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
            };

            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
        }
    };

    const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
        json data = json::parse(req.body);
        std::vector<raw_buffer> files; // dummy
        handle_completions_impl(
            SERVER_TASK_TYPE_COMPLETION,
            data,
            files,
            req.is_connection_closed,
            res,
            OAICOMPAT_TYPE_NONE);
    };

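    // Illustrative usage of the non-OAI completion endpoint (not part of the server code),
    // assuming the default host/port:
    //   curl -X POST http://localhost:8080/completion -H 'Content-Type: application/json' \
    //        -d '{"prompt": "Once upon a time", "n_predict": 32}'
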
    const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
        json data = oaicompat_completion_params_parse(json::parse(req.body));
        std::vector<raw_buffer> files; // dummy
        handle_completions_impl(
            SERVER_TASK_TYPE_COMPLETION,
            data,
            files,
            req.is_connection_closed,
            res,
            OAICOMPAT_TYPE_COMPLETION);
    };

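    // Illustrative usage of the OAI-compatible completions endpoint (not part of the server code),
    // assuming the default host/port; "max_tokens" is assumed to be mapped by the OAI parameter parser:
    //   curl -X POST http://localhost:8080/v1/completions -H 'Content-Type: application/json' \
    //        -d '{"prompt": "Once upon a time", "max_tokens": 32}'
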
    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
        // check model compatibility
        std::string err;
        if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
            err += "prefix token is missing. ";
        }
        if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
            err += "suffix token is missing. ";
        }
        if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
            err += "middle token is missing. ";
        }
        if (!err.empty()) {
            res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

        json data = json::parse(req.body);

        // validate input
        if (data.contains("prompt") && !data.at("prompt").is_string()) {
            // prompt is optional, but if present it must be a string
            res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        if (!data.contains("input_prefix")) {
            res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        if (!data.contains("input_suffix")) {
            res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
            // input_extra is optional
            res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        json input_extra = json_value(data, "input_extra", json::array());
        for (const auto & chunk : input_extra) {
            // { "text": string, "filename": string }
            if (!chunk.contains("text") || !chunk.at("text").is_string()) {
                res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
                return;
            }
            // filename is optional
            if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
                res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
                return;
            }
        }
        data["input_extra"] = input_extra; // default to an empty array if it does not exist

        std::string prompt = json_value(data, "prompt", std::string());
        std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
        SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
        data["prompt"] = format_infill(
            ctx_server.vocab,
            data.at("input_prefix"),
            data.at("input_suffix"),
            data.at("input_extra"),
            ctx_server.params_base.n_batch,
            ctx_server.params_base.n_predict,
            ctx_server.slots[0].n_ctx, // TODO: there should be a better way
            ctx_server.params_base.spm_infill,
            tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
        );

        std::vector<raw_buffer> files; // dummy
        handle_completions_impl(
            SERVER_TASK_TYPE_INFILL,
            data,
            files,
            req.is_connection_closed,
            res,
            OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
    };

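    // Illustrative infill request (not part of the server code), assuming the default host/port
    // and a model that provides FIM tokens; "input_extra" is optional:
    //   curl -X POST http://localhost:8080/infill -H 'Content-Type: application/json' \
    //        -d '{"input_prefix": "def add(a, b):\n    c = ", "input_suffix": "\n    return c\n", "prompt": ""}'
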
    const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
        LOG_DBG("request: %s\n", req.body.c_str());

        auto body = json::parse(req.body);
        std::vector<raw_buffer> files;
        json data = oaicompat_chat_params_parse(
            body,
            ctx_server.oai_parser_opt,
            files);

        handle_completions_impl(
            SERVER_TASK_TYPE_COMPLETION,
            data,
            files,
            req.is_connection_closed,
            res,
            OAICOMPAT_TYPE_CHAT);
    };

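    // Illustrative OAI-style chat request (not part of the server code), assuming the default host/port:
    //   curl -X POST http://localhost:8080/v1/chat/completions -H 'Content-Type: application/json' \
    //        -d '{"messages": [{"role": "user", "content": "Hello!"}], "stream": true}'
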
    // same as handle_chat_completions, but without the inference part
    const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
        auto body = json::parse(req.body);
        std::vector<raw_buffer> files; // dummy, unused
        json data = oaicompat_chat_params_parse(
            body,
            ctx_server.oai_parser_opt,
            files);
        res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
    };

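    // Illustrative usage (not part of the server code), assuming the default host/port;
    // returns the templated prompt without running inference:
    //   curl -X POST http://localhost:8080/apply-template -H 'Content-Type: application/json' \
    //        -d '{"messages": [{"role": "user", "content": "Hello!"}]}'
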
    const auto handle_models = [&params, &ctx_server, &state, &res_ok](const httplib::Request &, httplib::Response & res) {
        server_state current_state = state.load();
        json model_meta = nullptr;
        if (current_state == SERVER_STATE_READY) {
            model_meta = ctx_server.model_meta();
        }
        bool has_mtmd = ctx_server.mctx != nullptr;
        json models = {
            {"models", {
                {
                    {"name", params.model_alias.empty() ? params.model.path : params.model_alias},
                    {"model", params.model_alias.empty() ? params.model.path : params.model_alias},
                    {"modified_at", ""},
                    {"size", ""},
                    {"digest", ""}, // dummy value, llama.cpp does not support managing model file hashes
                    {"type", "model"},
                    {"description", ""},
                    {"tags", {""}},
                    {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})},
                    {"parameters", ""},
                    {"details", {
                        {"parent_model", ""},
                        {"format", "gguf"},
                        {"family", ""},
                        {"families", {""}},
                        {"parameter_size", ""},
                        {"quantization_level", ""}
                    }}
                }
            }},
            {"object", "list"},
            {"data", {
                {
                    {"id", params.model_alias.empty() ? params.model.path : params.model_alias},
                    {"object", "model"},
                    {"created", std::time(0)},
                    {"owned_by", "llamacpp"},
                    {"meta", model_meta},
                },
            }}
        };

        res_ok(res, models);
    };

    const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
        const json body = json::parse(req.body);

        json tokens_response = json::array();
        if (body.count("content") != 0) {
            const bool add_special = json_value(body, "add_special", false);
            const bool parse_special = json_value(body, "parse_special", true);
            const bool with_pieces = json_value(body, "with_pieces", false);

            llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special);

            if (with_pieces) {
                for (const auto & token : tokens) {
                    std::string piece = common_token_to_piece(ctx_server.ctx, token);
                    json piece_json;

                    // Check if the piece is valid UTF-8
                    if (is_valid_utf8(piece)) {
                        piece_json = piece;
                    } else {
                        // If not valid UTF-8, store as array of byte values
                        piece_json = json::array();
                        for (unsigned char c : piece) {
                            piece_json.push_back(static_cast<int>(c));
                        }
                    }

                    tokens_response.push_back({
                        {"id", token},
                        {"piece", piece_json}
                    });
                }
            } else {
                tokens_response = tokens;
            }
        }

        const json data = format_tokenizer_response(tokens_response);
        res_ok(res, data);
    };

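    // Illustrative usage (not part of the server code), assuming the default host/port:
    //   curl -X POST http://localhost:8080/tokenize -d '{"content": "Hello world", "with_pieces": true}'
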
    const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
        const json body = json::parse(req.body);

        std::string content;
        if (body.count("tokens") != 0) {
            const llama_tokens tokens = body.at("tokens");
            content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
        }

        const json data = format_detokenized_response(content);
        res_ok(res, data);
    };

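    // Illustrative usage (not part of the server code), assuming the default host/port;
    // the token ids are placeholders and are model-specific:
    //   curl -X POST http://localhost:8080/detokenize -d '{"tokens": [1, 2, 3]}'
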
    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
        if (!ctx_server.params_base.embedding) {
            res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

        if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        const json body = json::parse(req.body);

        // for the shape of input/content, see tokenize_input_prompts()
        json prompt;
        if (body.count("input") != 0) {
            prompt = body.at("input");
        } else if (body.contains("content")) {
            oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
            prompt = body.at("content");
        } else {
            res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        bool use_base64 = false;
        if (body.count("encoding_format") != 0) {
            const std::string & format = body.at("encoding_format");
            if (format == "base64") {
                use_base64 = true;
            } else if (format != "float") {
                res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
                return;
            }
        }

        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
        for (const auto & tokens : tokenized_prompts) {
            // this check is necessary for models that do not add BOS token to the input
            if (tokens.empty()) {
                res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
                return;
            }
        }

        int embd_normalize = 2; // default to Euclidean/L2 norm
        if (body.count("embd_normalize") != 0) {
            embd_normalize = body.at("embd_normalize");
            if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
                SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
            }
        }

        // create and queue the task
        json responses = json::array();
        bool error = false;
        std::unordered_set<int> task_ids;
        {
            std::vector<server_task> tasks;
            for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);

                task.id = ctx_server.queue_tasks.get_new_id();
                task.index = i;
                task.tokens = std::move(tokenized_prompts[i]);

                // OAI-compat
                task.params.oaicompat = oaicompat;
                task.params.embd_normalize = embd_normalize;

                tasks.push_back(std::move(task));
            }

            task_ids = server_task::get_list_id(tasks);
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(std::move(tasks));
        }

        // get the result
        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
            for (auto & res : results) {
                GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
                responses.push_back(res->to_json());
            }
        }, [&](const json & error_data) {
            res_error(res, error_data);
            error = true;
        }, req.is_connection_closed);

        ctx_server.queue_results.remove_waiting_task_ids(task_ids);

        if (error) {
            return;
        }

        // write JSON response
        json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
            ? format_embeddings_response_oaicompat(body, responses, use_base64)
            : json(responses);
        res_ok(res, root);
    };

    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
    };

    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
    };

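    // Illustrative usage (not part of the server code), assuming the default host/port
    // and a server started with --embeddings:
    //   curl -X POST http://localhost:8080/embeddings    -d '{"content": "Hello world"}'
    //   curl -X POST http://localhost:8080/v1/embeddings -d '{"input": "Hello world", "encoding_format": "float"}'
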
    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
        if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

        const json body = json::parse(req.body);

        // if true, use TEI API format, otherwise use Jina API format
        // Jina: https://jina.ai/reranker/
        // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
        bool is_tei_format = body.contains("texts");

        json query;
        if (body.count("query") == 1) {
            query = body.at("query");
            if (!query.is_string()) {
                res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
                return;
            }
        } else {
            res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        std::vector<std::string> documents = json_value(body, "documents",
            json_value(body, "texts", std::vector<std::string>()));
        if (documents.empty()) {
            res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        int top_n = json_value(body, "top_n", (int)documents.size());

        // create and queue the task
        json responses = json::array();
        bool error = false;
        std::unordered_set<int> task_ids;
        {
            std::vector<server_task> tasks;
            tasks.reserve(documents.size());
            for (size_t i = 0; i < documents.size(); i++) {
                auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
                task.id = ctx_server.queue_tasks.get_new_id();
                task.index = i;
                task.tokens = std::move(tmp);
                tasks.push_back(std::move(task));
            }

            task_ids = server_task::get_list_id(tasks);
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(std::move(tasks));
        }

        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
            for (auto & res : results) {
                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
                responses.push_back(res->to_json());
            }
        }, [&](const json & error_data) {
            res_error(res, error_data);
            error = true;
        }, req.is_connection_closed);

        if (error) {
            return;
        }

        // write JSON response
        json root = format_response_rerank(
            body,
            responses,
            is_tei_format,
            documents,
            top_n);

        res_ok(res, root);
    };

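    // Illustrative rerank request (not part of the server code), assuming the default host/port
    // and a model loaded with a RANK pooling type:
    //   curl -X POST http://localhost:8080/rerank -d '{"query": "what is a panda?",
    //        "documents": ["hi", "The giant panda is a bear native to China."], "top_n": 2}'
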
    const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
        json result = json::array();
        const auto & loras = ctx_server.params_base.lora_adapters;
        for (size_t i = 0; i < loras.size(); ++i) {
            auto & lora = loras[i];
            json entry = {
                {"id", i},
                {"path", lora.path},
                {"scale", lora.scale},
                {"task_name", lora.task_name},
                {"prompt_prefix", lora.prompt_prefix},
            };
            std::string alora_invocation_string = "";
            const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
            std::vector<llama_token> alora_invocation_tokens;
            if (n_alora_tokens) {
                const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
                for (uint64_t i = 0; i < n_alora_tokens; ++i) {
                    alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]);
                    alora_invocation_tokens.push_back(alora_tokens[i]);
                }
                entry["alora_invocation_string"] = alora_invocation_string;
                entry["alora_invocation_tokens"] = alora_invocation_tokens;
            }
            result.push_back(std::move(entry));
        }
        res_ok(res, result);
        res.status = 200; // HTTP OK
    };

    const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
        const json body = json::parse(req.body);
        if (!body.is_array()) {
            res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

        int task_id = ctx_server.queue_tasks.get_new_id();
        {
            server_task task(SERVER_TASK_TYPE_SET_LORA);
            task.id = task_id;
            task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
            ctx_server.queue_results.add_waiting_task_id(task_id);
            ctx_server.queue_tasks.post(std::move(task));
        }

        // get the result
        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
        ctx_server.queue_results.remove_waiting_task_id(task_id);

        if (result->is_error()) {
            res_error(res, result->to_json());
            return;
        }

        GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
        res_ok(res, result->to_json());
    };

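    // Illustrative usage (not part of the server code), assuming the default host/port and
    // adapters loaded at startup with --lora; the POST body is an array of {"id", "scale"} objects:
    //   curl http://localhost:8080/lora-adapters
    //   curl -X POST http://localhost:8080/lora-adapters -d '[{"id": 0, "scale": 0.5}]'
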
    //
    // Router
    //

    if (!params.webui) {
        LOG_INF("Web UI is disabled\n");
    } else {
        // register static assets routes
        if (!params.public_path.empty()) {
            // Set the base directory for serving static files
            bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path);
            if (!is_found) {
                LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
                return 1;
            }
        } else {
            // using embedded static index.html
            svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
                if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
                    res.set_content("Error: gzip is not supported by this browser", "text/plain");
                } else {
                    res.set_header("Content-Encoding", "gzip");
                    // COEP and COOP headers, required by pyodide (python interpreter)
                    res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
                    res.set_header("Cross-Origin-Opener-Policy", "same-origin");
                    res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
                }
                return false;
            });
        }
    }

    // register API routes
    svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check)
    svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check)
    svr->Get (params.api_prefix + "/metrics", handle_metrics);
    svr->Get (params.api_prefix + "/props", handle_props);
    svr->Post(params.api_prefix + "/props", handle_props_change);
    svr->Post(params.api_prefix + "/api/show", handle_api_show);
    svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check)
    svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check)
    svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
    svr->Post(params.api_prefix + "/completion", handle_completions); // legacy
    svr->Post(params.api_prefix + "/completions", handle_completions);
    svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai);
    svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions);
    svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions);
    svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint
    svr->Post(params.api_prefix + "/infill", handle_infill);
    svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy
    svr->Post(params.api_prefix + "/embeddings", handle_embeddings);
    svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai);
    svr->Post(params.api_prefix + "/rerank", handle_rerank);
    svr->Post(params.api_prefix + "/reranking", handle_rerank);
    svr->Post(params.api_prefix + "/v1/rerank", handle_rerank);
    svr->Post(params.api_prefix + "/v1/reranking", handle_rerank);
    svr->Post(params.api_prefix + "/tokenize", handle_tokenize);
    svr->Post(params.api_prefix + "/detokenize", handle_detokenize);
    svr->Post(params.api_prefix + "/apply-template", handle_apply_template);
    // LoRA adapters hotswap
    svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list);
    svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply);
    // Save & load slots
    svr->Get (params.api_prefix + "/slots", handle_slots);
    svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action);

    //
    // Start the server
    //
    if (params.n_threads_http < 1) {
        // +2 threads for monitoring endpoints
        params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
    }
    log_data["n_threads_http"] = std::to_string(params.n_threads_http);
    svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };

    // clean up function, to be called before exit
    auto clean_up = [&svr, &ctx_server]() {
        SRV_INF("%s: cleaning up before exit...\n", __func__);
        svr->stop();
        ctx_server.queue_results.terminate();
        llama_backend_free();
    };

    bool was_bound = false;
    bool is_sock = false;
    if (string_ends_with(std::string(params.hostname), ".sock")) {
        is_sock = true;
        LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
        svr->set_address_family(AF_UNIX);
        // bind_to_port requires a second arg, any value other than 0 should
        // simply get ignored
        was_bound = svr->bind_to_port(params.hostname, 8080);
    } else {
        LOG_INF("%s: binding port with default address family\n", __func__);
        // bind HTTP listen port
        if (params.port == 0) {
            int bound_port = svr->bind_to_any_port(params.hostname);
            if ((was_bound = (bound_port >= 0))) {
                params.port = bound_port;
            }
        } else {
            was_bound = svr->bind_to_port(params.hostname, params.port);
        }
    }

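    // Illustrative note (not part of the server code): when the hostname ends with ".sock" the server
    // listens on a unix domain socket, which a client could reach with, e.g. (hypothetical socket path):
    //   curl --unix-socket /tmp/llama.sock http://localhost/health
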
    if (!was_bound) {
        LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
        clean_up();
        return 1;
    }

    // run the HTTP server in a thread
    std::thread t([&]() { svr->listen_after_bind(); });
    svr->wait_until_ready();

    LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http);

    // load the model
    LOG_INF("%s: loading model\n", __func__);

    if (!ctx_server.load_model(params)) {
        clean_up();
        t.join();
        LOG_ERR("%s: exiting due to model loading error\n", __func__);
        return 1;
    }

    ctx_server.init();
    state.store(SERVER_STATE_READY);

    LOG_INF("%s: model loaded\n", __func__);

    // print sample chat example to make it clear which template is used
    LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
        common_chat_templates_source(ctx_server.chat_templates.get()),
        common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja, ctx_server.params_base.default_template_kwargs).c_str());

    ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
        ctx_server.process_single_task(std::move(task));
    });

    ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
        ctx_server.update_slots();
    });

    shutdown_handler = [&](int) {
        // this will unblock start_loop()
        ctx_server.queue_tasks.terminate();
    };

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset(&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT, &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__,
        is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() :
                  string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str());

    // this call blocks the main thread until queue_tasks.terminate() is called
    ctx_server.queue_tasks.start_loop();

    clean_up();
    t.join();
    llama_memory_breakdown_print(ctx_server.ctx);

    return 0;
}