| 1 | #include "chat.h" |
| 2 | #include "utils.hpp" |
| 3 | |
| 4 | #include "arg.h" |
| 5 | #include "common.h" |
| 6 | #include "json-schema-to-grammar.h" |
| 7 | #include "llama.h" |
| 8 | #include "log.h" |
| 9 | #include "sampling.h" |
| 10 | #include "speculative.h" |
| 11 | #include "mtmd.h" |
| 12 | |
| 13 | // mime type for sending response |
| 14 | #define MIMETYPE_JSON "application/json; charset=utf-8" |
| 15 | |
| 16 | // auto generated files (see README.md for details) |
| 17 | #include "index.html.gz.hpp" |
| 18 | #include "loading.html.hpp" |
| 19 | |
| 20 | #include <atomic> |
| 21 | #include <chrono> |
| 22 | #include <condition_variable> |
| 23 | #include <cstddef> |
| 24 | #include <cinttypes> |
| 25 | #include <deque> |
| 26 | #include <memory> |
| 27 | #include <mutex> |
| 28 | #include <signal.h> |
| 29 | #include <thread> |
| 30 | #include <unordered_map> |
| 31 | #include <unordered_set> |
| 32 | |
| 33 | using json = nlohmann::ordered_json; |
| 34 | |
| 35 | constexpr int HTTP_POLLING_SECONDS = 1; |
| 36 | |
| 37 | enum stop_type { |
| 38 | STOP_TYPE_NONE, |
| 39 | STOP_TYPE_EOS, |
| 40 | STOP_TYPE_WORD, |
| 41 | STOP_TYPE_LIMIT, |
| 42 | }; |
| 43 | |
| 44 | // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 |
| 45 | enum slot_state { |
| 46 | SLOT_STATE_IDLE, |
| 47 | SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future |
| 48 | SLOT_STATE_PROCESSING_PROMPT, |
| 49 | SLOT_STATE_DONE_PROMPT, |
| 50 | SLOT_STATE_GENERATING, |
| 51 | }; |
| 52 | |
| 53 | enum server_state { |
| 54 | SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet |
| 55 | SERVER_STATE_READY, // Server is ready and model is loaded |
| 56 | }; |
| 57 | |
| 58 | enum server_task_type { |
| 59 | SERVER_TASK_TYPE_COMPLETION, |
| 60 | SERVER_TASK_TYPE_EMBEDDING, |
| 61 | SERVER_TASK_TYPE_RERANK, |
| 62 | SERVER_TASK_TYPE_INFILL, |
| 63 | SERVER_TASK_TYPE_CANCEL, |
| 64 | SERVER_TASK_TYPE_NEXT_RESPONSE, |
| 65 | SERVER_TASK_TYPE_METRICS, |
| 66 | SERVER_TASK_TYPE_SLOT_SAVE, |
| 67 | SERVER_TASK_TYPE_SLOT_RESTORE, |
| 68 | SERVER_TASK_TYPE_SLOT_ERASE, |
| 69 | SERVER_TASK_TYPE_SET_LORA, |
| 70 | }; |
| 71 | |
| 72 | enum oaicompat_type { |
| 73 | OAICOMPAT_TYPE_NONE, |
| 74 | OAICOMPAT_TYPE_CHAT, |
| 75 | OAICOMPAT_TYPE_COMPLETION, |
| 76 | OAICOMPAT_TYPE_EMBEDDING, |
| 77 | }; |
| 78 | |
| 79 | // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 |
| 80 | enum error_type { |
| 81 | ERROR_TYPE_INVALID_REQUEST, |
| 82 | ERROR_TYPE_AUTHENTICATION, |
| 83 | ERROR_TYPE_SERVER, |
| 84 | ERROR_TYPE_NOT_FOUND, |
| 85 | ERROR_TYPE_PERMISSION, |
| 86 | ERROR_TYPE_UNAVAILABLE, // custom error |
| 87 | ERROR_TYPE_NOT_SUPPORTED, // custom error |
| 88 | ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error |
| 89 | }; |
| 90 | |
| 91 | static bool server_task_type_need_embd(server_task_type task_type) { |
| 92 | switch (task_type) { |
| 93 | case SERVER_TASK_TYPE_EMBEDDING: |
| 94 | case SERVER_TASK_TYPE_RERANK: |
| 95 | return true; |
| 96 | default: |
| 97 | return false; |
| 98 | } |
| 99 | } |
| 100 | |
| 101 | static bool server_task_type_need_logits(server_task_type task_type) { |
| 102 | switch (task_type) { |
| 103 | case SERVER_TASK_TYPE_COMPLETION: |
| 104 | case SERVER_TASK_TYPE_INFILL: |
| 105 | return true; |
| 106 | default: |
| 107 | return false; |
| 108 | } |
| 109 | } |
| 110 | |
| 111 | struct slot_params { |
| 112 | bool stream = true; |
| 113 | bool include_usage = false; |
| 114 | bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt |
| 115 | bool return_tokens = false; |
| 116 | bool return_progress = false; |
| 117 | |
| 118 | int32_t n_keep = 0; // number of tokens to keep from initial prompt |
| 119 | int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half |
| 120 | int32_t n_predict = -1; // new tokens to predict |
| 121 | int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters |
| 122 | |
| 123 | int64_t t_max_prompt_ms = -1; // TODO: implement |
| 124 | int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit |
| 125 | |
| 126 | std::vector<common_adapter_lora_info> lora; |
| 127 | |
| 128 | std::vector<std::string> antiprompt; |
| 129 | std::vector<std::string> response_fields; |
| 130 | bool timings_per_token = false; |
| 131 | bool post_sampling_probs = false; |
| 132 | |
| 133 | struct common_params_sampling sampling; |
| 134 | struct common_params_speculative speculative; |
| 135 | |
| 136 | // OAI-compat fields |
| 137 | bool verbose = false; |
| 138 | oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; |
| 139 | std::string oaicompat_model; |
| 140 | std::string oaicompat_cmpl_id; |
| 141 | common_chat_syntax oaicompat_chat_syntax; |
| 142 | |
| 143 | // Embeddings |
| 144 | int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) |
| 145 | |
| 146 | json to_json(bool only_metrics = false) const { |
| 147 | std::vector<std::string> samplers; |
| 148 | samplers.reserve(n: sampling.samplers.size()); |
| 149 | for (const auto & sampler : sampling.samplers) { |
| 150 | samplers.emplace_back(args: common_sampler_type_to_str(cnstr: sampler)); |
| 151 | } |
| 152 | |
| 153 | json lora = json::array(); |
| 154 | for (size_t i = 0; i < this->lora.size(); ++i) { |
| 155 | lora.push_back(init: {{"id" , i}, {"scale" , this->lora[i].scale}}); |
| 156 | } |
| 157 | |
| 158 | if (only_metrics) { |
| 159 | return json { |
| 160 | {"seed" , sampling.seed}, |
| 161 | {"temperature" , sampling.temp}, |
| 162 | {"dynatemp_range" , sampling.dynatemp_range}, |
| 163 | {"dynatemp_exponent" , sampling.dynatemp_exponent}, |
| 164 | {"top_k" , sampling.top_k}, |
| 165 | {"top_p" , sampling.top_p}, |
| 166 | {"min_p" , sampling.min_p}, |
| 167 | {"top_n_sigma" , sampling.top_n_sigma}, |
| 168 | {"xtc_probability" , sampling.xtc_probability}, |
| 169 | {"xtc_threshold" , sampling.xtc_threshold}, |
| 170 | {"typical_p" , sampling.typ_p}, |
| 171 | {"repeat_last_n" , sampling.penalty_last_n}, |
| 172 | {"repeat_penalty" , sampling.penalty_repeat}, |
| 173 | {"presence_penalty" , sampling.penalty_present}, |
| 174 | {"frequency_penalty" , sampling.penalty_freq}, |
| 175 | {"dry_multiplier" , sampling.dry_multiplier}, |
| 176 | {"dry_base" , sampling.dry_base}, |
| 177 | {"dry_allowed_length" , sampling.dry_allowed_length}, |
| 178 | {"dry_penalty_last_n" , sampling.dry_penalty_last_n}, |
| 179 | {"mirostat" , sampling.mirostat}, |
| 180 | {"mirostat_tau" , sampling.mirostat_tau}, |
| 181 | {"mirostat_eta" , sampling.mirostat_eta}, |
| 182 | {"max_tokens" , n_predict}, |
| 183 | {"n_predict" , n_predict}, // TODO: deduplicate? |
| 184 | {"n_keep" , n_keep}, |
| 185 | {"n_discard" , n_discard}, |
| 186 | {"ignore_eos" , sampling.ignore_eos}, |
| 187 | {"stream" , stream}, |
| 188 | {"n_probs" , sampling.n_probs}, |
| 189 | {"min_keep" , sampling.min_keep}, |
| 190 | {"chat_format" , common_chat_format_name(format: oaicompat_chat_syntax.format)}, |
| 191 | {"reasoning_format" , common_reasoning_format_name(format: oaicompat_chat_syntax.reasoning_format)}, |
| 192 | {"reasoning_in_content" , oaicompat_chat_syntax.reasoning_in_content}, |
| 193 | {"thinking_forced_open" , oaicompat_chat_syntax.thinking_forced_open}, |
| 194 | {"samplers" , samplers}, |
| 195 | {"speculative.n_max" , speculative.n_max}, |
| 196 | {"speculative.n_min" , speculative.n_min}, |
| 197 | {"speculative.p_min" , speculative.p_min}, |
| 198 | {"timings_per_token" , timings_per_token}, |
| 199 | {"post_sampling_probs" , post_sampling_probs}, |
| 200 | {"lora" , lora}, |
| 201 | }; |
| 202 | } |
| 203 | |
| 204 | auto grammar_triggers = json::array(); |
| 205 | for (const auto & trigger : sampling.grammar_triggers) { |
| 206 | server_grammar_trigger ct(trigger); |
| 207 | grammar_triggers.push_back(val: ct.to_json()); |
| 208 | } |
| 209 | |
| 210 | return json { |
| 211 | {"seed" , sampling.seed}, |
| 212 | {"temperature" , sampling.temp}, |
| 213 | {"dynatemp_range" , sampling.dynatemp_range}, |
| 214 | {"dynatemp_exponent" , sampling.dynatemp_exponent}, |
| 215 | {"top_k" , sampling.top_k}, |
| 216 | {"top_p" , sampling.top_p}, |
| 217 | {"min_p" , sampling.min_p}, |
| 218 | {"top_n_sigma" , sampling.top_n_sigma}, |
| 219 | {"xtc_probability" , sampling.xtc_probability}, |
| 220 | {"xtc_threshold" , sampling.xtc_threshold}, |
| 221 | {"typical_p" , sampling.typ_p}, |
| 222 | {"repeat_last_n" , sampling.penalty_last_n}, |
| 223 | {"repeat_penalty" , sampling.penalty_repeat}, |
| 224 | {"presence_penalty" , sampling.penalty_present}, |
| 225 | {"frequency_penalty" , sampling.penalty_freq}, |
| 226 | {"dry_multiplier" , sampling.dry_multiplier}, |
| 227 | {"dry_base" , sampling.dry_base}, |
| 228 | {"dry_allowed_length" , sampling.dry_allowed_length}, |
| 229 | {"dry_penalty_last_n" , sampling.dry_penalty_last_n}, |
| 230 | {"dry_sequence_breakers" , sampling.dry_sequence_breakers}, |
| 231 | {"mirostat" , sampling.mirostat}, |
| 232 | {"mirostat_tau" , sampling.mirostat_tau}, |
| 233 | {"mirostat_eta" , sampling.mirostat_eta}, |
| 234 | {"stop" , antiprompt}, |
| 235 | {"max_tokens" , n_predict}, |
| 236 | {"n_predict" , n_predict}, // TODO: deduplicate? |
| 237 | {"n_keep" , n_keep}, |
| 238 | {"n_discard" , n_discard}, |
| 239 | {"ignore_eos" , sampling.ignore_eos}, |
| 240 | {"stream" , stream}, |
| 241 | {"logit_bias" , format_logit_bias(logit_bias: sampling.logit_bias)}, |
| 242 | {"n_probs" , sampling.n_probs}, |
| 243 | {"min_keep" , sampling.min_keep}, |
| 244 | {"grammar" , sampling.grammar}, |
| 245 | {"grammar_lazy" , sampling.grammar_lazy}, |
| 246 | {"grammar_triggers" , grammar_triggers}, |
| 247 | {"preserved_tokens" , sampling.preserved_tokens}, |
| 248 | {"chat_format" , common_chat_format_name(format: oaicompat_chat_syntax.format)}, |
| 249 | {"reasoning_format" , common_reasoning_format_name(format: oaicompat_chat_syntax.reasoning_format)}, |
| 250 | {"reasoning_in_content" , oaicompat_chat_syntax.reasoning_in_content}, |
| 251 | {"thinking_forced_open" , oaicompat_chat_syntax.thinking_forced_open}, |
| 252 | {"samplers" , samplers}, |
| 253 | {"speculative.n_max" , speculative.n_max}, |
| 254 | {"speculative.n_min" , speculative.n_min}, |
| 255 | {"speculative.p_min" , speculative.p_min}, |
| 256 | {"timings_per_token" , timings_per_token}, |
| 257 | {"post_sampling_probs" , post_sampling_probs}, |
| 258 | {"lora" , lora}, |
| 259 | }; |
| 260 | } |
| 261 | }; |
| 262 | |
| 263 | struct server_task { |
| 264 | int id = -1; // to be filled by server_queue |
| 265 | int index = -1; // used when there are multiple prompts (batch request) |
| 266 | |
| 267 | // used by SERVER_TASK_TYPE_CANCEL |
| 268 | int id_target = -1; |
| 269 | int id_slot = -1; |
| 270 | |
| 271 | // used by SERVER_TASK_TYPE_INFERENCE |
| 272 | slot_params params; |
| 273 | server_tokens tokens; |
| 274 | |
| 275 | server_task_type type; |
| 276 | |
| 277 | // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE |
| 278 | struct slot_action { |
| 279 | int slot_id; |
| 280 | std::string filename; |
| 281 | std::string filepath; |
| 282 | }; |
| 283 | slot_action slot_action; |
| 284 | |
| 285 | // used by SERVER_TASK_TYPE_METRICS |
| 286 | bool metrics_reset_bucket = false; |
| 287 | |
| 288 | // used by SERVER_TASK_TYPE_SET_LORA |
| 289 | std::vector<common_adapter_lora_info> set_lora; |
| 290 | |
| 291 | server_task() = default; |
| 292 | |
| 293 | server_task(server_task_type type) : type(type) {} |
| 294 | |
| 295 | int32_t n_tokens() const { |
| 296 | return tokens.size(); |
| 297 | } |
| 298 | |
| 299 | static slot_params params_from_json_cmpl( |
| 300 | const llama_context * ctx, |
| 301 | const common_params & params_base, |
| 302 | const json & data) { |
| 303 | const llama_model * model = llama_get_model(ctx); |
| 304 | const llama_vocab * vocab = llama_model_get_vocab(model); |
| 305 | |
| 306 | slot_params params; |
| 307 | |
| 308 | // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) |
| 309 | slot_params defaults; |
| 310 | defaults.sampling = params_base.sampling; |
| 311 | defaults.speculative = params_base.speculative; |
| 312 | defaults.n_keep = params_base.n_keep; |
| 313 | defaults.n_predict = params_base.n_predict; |
| 314 | defaults.antiprompt = params_base.antiprompt; |
| 315 | |
| 316 | // enabling this will output extra debug information in the HTTP responses from the server |
| 317 | params.verbose = params_base.verbosity > 9; |
| 318 | params.timings_per_token = json_value(body: data, key: "timings_per_token" , default_value: false); |
| 319 | |
| 320 | params.stream = json_value(body: data, key: "stream" , default_value: false); |
| 321 | auto stream_opt = json_value(body: data, key: "stream_options" , default_value: json::object()); |
| 322 | params.include_usage = json_value(body: stream_opt, key: "include_usage" , default_value: false); |
| 323 | params.cache_prompt = json_value(body: data, key: "cache_prompt" , default_value: true); |
| 324 | params.return_tokens = json_value(body: data, key: "return_tokens" , default_value: false); |
| 325 | params.return_progress = json_value(body: data, key: "return_progress" , default_value: false); |
| 326 | params.n_predict = json_value(body: data, key: "n_predict" , default_value: json_value(body: data, key: "max_tokens" , default_value: defaults.n_predict)); |
| 327 | params.n_indent = json_value(body: data, key: "n_indent" , default_value: defaults.n_indent); |
| 328 | params.n_keep = json_value(body: data, key: "n_keep" , default_value: defaults.n_keep); |
| 329 | params.n_discard = json_value(body: data, key: "n_discard" , default_value: defaults.n_discard); |
| 330 | //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement |
| 331 | params.t_max_predict_ms = json_value(body: data, key: "t_max_predict_ms" , default_value: defaults.t_max_predict_ms); |
| 332 | params.response_fields = json_value(body: data, key: "response_fields" , default_value: std::vector<std::string>()); |
| 333 | |
| 334 | params.sampling.top_k = json_value(body: data, key: "top_k" , default_value: defaults.sampling.top_k); |
| 335 | params.sampling.top_p = json_value(body: data, key: "top_p" , default_value: defaults.sampling.top_p); |
| 336 | params.sampling.min_p = json_value(body: data, key: "min_p" , default_value: defaults.sampling.min_p); |
| 337 | params.sampling.top_n_sigma = json_value(body: data, key: "top_n_sigma" , default_value: defaults.sampling.top_n_sigma); |
| 338 | params.sampling.xtc_probability = json_value(body: data, key: "xtc_probability" , default_value: defaults.sampling.xtc_probability); |
| 339 | params.sampling.xtc_threshold = json_value(body: data, key: "xtc_threshold" , default_value: defaults.sampling.xtc_threshold); |
| 340 | params.sampling.typ_p = json_value(body: data, key: "typical_p" , default_value: defaults.sampling.typ_p); |
| 341 | params.sampling.temp = json_value(body: data, key: "temperature" , default_value: defaults.sampling.temp); |
| 342 | params.sampling.dynatemp_range = json_value(body: data, key: "dynatemp_range" , default_value: defaults.sampling.dynatemp_range); |
| 343 | params.sampling.dynatemp_exponent = json_value(body: data, key: "dynatemp_exponent" , default_value: defaults.sampling.dynatemp_exponent); |
| 344 | params.sampling.penalty_last_n = json_value(body: data, key: "repeat_last_n" , default_value: defaults.sampling.penalty_last_n); |
| 345 | params.sampling.penalty_repeat = json_value(body: data, key: "repeat_penalty" , default_value: defaults.sampling.penalty_repeat); |
| 346 | params.sampling.penalty_freq = json_value(body: data, key: "frequency_penalty" , default_value: defaults.sampling.penalty_freq); |
| 347 | params.sampling.penalty_present = json_value(body: data, key: "presence_penalty" , default_value: defaults.sampling.penalty_present); |
| 348 | params.sampling.dry_multiplier = json_value(body: data, key: "dry_multiplier" , default_value: defaults.sampling.dry_multiplier); |
| 349 | params.sampling.dry_base = json_value(body: data, key: "dry_base" , default_value: defaults.sampling.dry_base); |
| 350 | params.sampling.dry_allowed_length = json_value(body: data, key: "dry_allowed_length" , default_value: defaults.sampling.dry_allowed_length); |
| 351 | params.sampling.dry_penalty_last_n = json_value(body: data, key: "dry_penalty_last_n" , default_value: defaults.sampling.dry_penalty_last_n); |
| 352 | params.sampling.mirostat = json_value(body: data, key: "mirostat" , default_value: defaults.sampling.mirostat); |
| 353 | params.sampling.mirostat_tau = json_value(body: data, key: "mirostat_tau" , default_value: defaults.sampling.mirostat_tau); |
| 354 | params.sampling.mirostat_eta = json_value(body: data, key: "mirostat_eta" , default_value: defaults.sampling.mirostat_eta); |
| 355 | params.sampling.seed = json_value(body: data, key: "seed" , default_value: defaults.sampling.seed); |
| 356 | params.sampling.n_probs = json_value(body: data, key: "n_probs" , default_value: defaults.sampling.n_probs); |
| 357 | params.sampling.min_keep = json_value(body: data, key: "min_keep" , default_value: defaults.sampling.min_keep); |
| 358 | params.post_sampling_probs = json_value(body: data, key: "post_sampling_probs" , default_value: defaults.post_sampling_probs); |
| 359 | |
| 360 | params.speculative.n_min = json_value(body: data, key: "speculative.n_min" , default_value: defaults.speculative.n_min); |
| 361 | params.speculative.n_max = json_value(body: data, key: "speculative.n_max" , default_value: defaults.speculative.n_max); |
| 362 | params.speculative.p_min = json_value(body: data, key: "speculative.p_min" , default_value: defaults.speculative.p_min); |
| 363 | |
| 364 | params.speculative.n_min = std::min(a: params.speculative.n_max, b: params.speculative.n_min); |
| 365 | params.speculative.n_min = std::max(a: params.speculative.n_min, b: 0); |
| 366 | params.speculative.n_max = std::max(a: params.speculative.n_max, b: 0); |
| 367 | |
| 368 | // Use OpenAI API logprobs only if n_probs wasn't provided |
| 369 | if (data.contains(key: "logprobs" ) && params.sampling.n_probs == defaults.sampling.n_probs){ |
| 370 | params.sampling.n_probs = json_value(body: data, key: "logprobs" , default_value: defaults.sampling.n_probs); |
| 371 | } |
| 372 | |
| 373 | if (data.contains(key: "lora" )) { |
| 374 | if (data.at(key: "lora" ).is_array()) { |
| 375 | params.lora = parse_lora_request(lora_base: params_base.lora_adapters, data: data.at(key: "lora" )); |
| 376 | } else { |
| 377 | throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields" ); |
| 378 | } |
| 379 | } else { |
| 380 | params.lora = params_base.lora_adapters; |
| 381 | } |
| 382 | |
| 383 | // TODO: add more sanity checks for the input parameters |
| 384 | |
| 385 | if (params.sampling.penalty_last_n < -1) { |
| 386 | throw std::runtime_error("Error: repeat_last_n must be >= -1" ); |
| 387 | } |
| 388 | |
| 389 | if (params.sampling.dry_penalty_last_n < -1) { |
| 390 | throw std::runtime_error("Error: dry_penalty_last_n must be >= -1" ); |
| 391 | } |
| 392 | |
| 393 | if (params.sampling.penalty_last_n == -1) { |
| 394 | // note: should be the slot's context and not the full context, but it's ok |
| 395 | params.sampling.penalty_last_n = llama_n_ctx(ctx); |
| 396 | } |
| 397 | |
| 398 | if (params.sampling.dry_penalty_last_n == -1) { |
| 399 | params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); |
| 400 | } |
| 401 | |
| 402 | if (params.sampling.dry_base < 1.0f) { |
| 403 | params.sampling.dry_base = defaults.sampling.dry_base; |
| 404 | } |
| 405 | |
| 406 | // sequence breakers for DRY |
| 407 | { |
| 408 | // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format |
| 409 | // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 |
| 410 | |
| 411 | if (data.contains(key: "dry_sequence_breakers" )) { |
| 412 | params.sampling.dry_sequence_breakers = json_value(body: data, key: "dry_sequence_breakers" , default_value: std::vector<std::string>()); |
| 413 | if (params.sampling.dry_sequence_breakers.empty()) { |
| 414 | throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings" ); |
| 415 | } |
| 416 | } |
| 417 | } |
| 418 | |
| 419 | // process "json_schema" and "grammar" |
| 420 | if (data.contains(key: "json_schema" ) && !data.contains(key: "grammar" )) { |
| 421 | try { |
| 422 | auto schema = json_value(body: data, key: "json_schema" , default_value: json::object()); |
| 423 | SRV_DBG("JSON schema: %s\n" , schema.dump(2).c_str()); |
| 424 | params.sampling.grammar = json_schema_to_grammar(schema); |
| 425 | SRV_DBG("Converted grammar: %s\n" , params.sampling.grammar.c_str()); |
| 426 | } catch (const std::exception & e) { |
| 427 | throw std::runtime_error(std::string("\"json_schema\": " ) + e.what()); |
| 428 | } |
| 429 | } else { |
| 430 | params.sampling.grammar = json_value(body: data, key: "grammar" , default_value: defaults.sampling.grammar); |
| 431 | SRV_DBG("Grammar: %s\n" , params.sampling.grammar.c_str()); |
| 432 | params.sampling.grammar_lazy = json_value(body: data, key: "grammar_lazy" , default_value: defaults.sampling.grammar_lazy); |
| 433 | SRV_DBG("Grammar lazy: %s\n" , params.sampling.grammar_lazy ? "true" : "false" ); |
| 434 | } |
| 435 | |
| 436 | { |
| 437 | auto it = data.find(key: "chat_format" ); |
| 438 | if (it != data.end()) { |
| 439 | params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>()); |
| 440 | SRV_INF("Chat format: %s\n" , common_chat_format_name(params.oaicompat_chat_syntax.format)); |
| 441 | } else { |
| 442 | params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; |
| 443 | } |
| 444 | common_reasoning_format reasoning_format = params_base.reasoning_format; |
| 445 | if (data.contains(key: "reasoning_format" )) { |
| 446 | reasoning_format = common_reasoning_format_from_name(format: data.at(key: "reasoning_format" ).get<std::string>()); |
| 447 | } |
| 448 | params.oaicompat_chat_syntax.reasoning_format = reasoning_format; |
| 449 | params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); |
| 450 | params.oaicompat_chat_syntax.thinking_forced_open = json_value(body: data, key: "thinking_forced_open" , default_value: false); |
| 451 | params.oaicompat_chat_syntax.parse_tool_calls = json_value(body: data, key: "parse_tool_calls" , default_value: false); |
| 452 | } |
| 453 | |
| 454 | { |
| 455 | const auto preserved_tokens = data.find(key: "preserved_tokens" ); |
| 456 | if (preserved_tokens != data.end()) { |
| 457 | for (const auto & t : *preserved_tokens) { |
| 458 | auto ids = common_tokenize(vocab, text: t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true); |
| 459 | if (ids.size() == 1) { |
| 460 | SRV_DBG("Preserved token: %d\n" , ids[0]); |
| 461 | params.sampling.preserved_tokens.insert(x: ids[0]); |
| 462 | } else { |
| 463 | // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. |
| 464 | SRV_DBG("Not preserved because more than 1 token: %s\n" , t.get<std::string>().c_str()); |
| 465 | } |
| 466 | } |
| 467 | } |
| 468 | const auto grammar_triggers = data.find(key: "grammar_triggers" ); |
| 469 | if (grammar_triggers != data.end()) { |
| 470 | for (const auto & t : *grammar_triggers) { |
| 471 | server_grammar_trigger ct(t); |
| 472 | if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { |
| 473 | const auto & word = ct.value.value; |
| 474 | auto ids = common_tokenize(vocab, text: word, /* add_special= */ false, /* parse_special= */ true); |
| 475 | if (ids.size() == 1) { |
| 476 | auto token = ids[0]; |
| 477 | if (std::find(first: params.sampling.preserved_tokens.begin(), last: params.sampling.preserved_tokens.end(), val: (llama_token) token) == params.sampling.preserved_tokens.end()) { |
| 478 | throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); |
| 479 | } |
| 480 | SRV_DBG("Grammar trigger token: %d (`%s`)\n" , token, word.c_str()); |
| 481 | common_grammar_trigger trigger; |
| 482 | trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; |
| 483 | trigger.value = word; |
| 484 | trigger.token = token; |
| 485 | params.sampling.grammar_triggers.push_back(x: std::move(trigger)); |
| 486 | } else { |
| 487 | SRV_DBG("Grammar trigger word: `%s`\n" , word.c_str()); |
| 488 | params.sampling.grammar_triggers.push_back(x: {.type: COMMON_GRAMMAR_TRIGGER_TYPE_WORD, .value: word}); |
| 489 | } |
| 490 | } else { |
| 491 | if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { |
| 492 | SRV_DBG("Grammar trigger pattern: `%s`\n" , ct.value.value.c_str()); |
| 493 | } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { |
| 494 | SRV_DBG("Grammar trigger pattern full: `%s`\n" , ct.value.value.c_str()); |
| 495 | } else { |
| 496 | throw std::runtime_error("Unknown grammar trigger type" ); |
| 497 | } |
| 498 | params.sampling.grammar_triggers.emplace_back(args: std::move(ct.value)); |
| 499 | } |
| 500 | } |
| 501 | } |
| 502 | if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { |
| 503 | throw std::runtime_error("Error: no triggers set for lazy grammar!" ); |
| 504 | } |
| 505 | } |
| 506 | |
| 507 | { |
| 508 | params.sampling.logit_bias.clear(); |
| 509 | |
| 510 | const auto & logit_bias = data.find(key: "logit_bias" ); |
| 511 | if (logit_bias != data.end() && logit_bias->is_array()) { |
| 512 | const int n_vocab = llama_vocab_n_tokens(vocab); |
| 513 | for (const auto & el : *logit_bias) { |
| 514 | // TODO: we may want to throw errors here, in case "el" is incorrect |
| 515 | if (el.is_array() && el.size() == 2) { |
| 516 | float bias; |
| 517 | if (el[1].is_number()) { |
| 518 | bias = el[1].get<float>(); |
| 519 | } else if (el[1].is_boolean() && !el[1].get<bool>()) { |
| 520 | bias = -INFINITY; |
| 521 | } else { |
| 522 | continue; |
| 523 | } |
| 524 | |
| 525 | if (el[0].is_number_integer()) { |
| 526 | llama_token tok = el[0].get<llama_token>(); |
| 527 | if (tok >= 0 && tok < n_vocab) { |
| 528 | params.sampling.logit_bias.push_back(x: {.token: tok, .bias: bias}); |
| 529 | } |
| 530 | } else if (el[0].is_string()) { |
| 531 | auto toks = common_tokenize(vocab, text: el[0].get<std::string>(), add_special: false); |
| 532 | for (auto tok : toks) { |
| 533 | params.sampling.logit_bias.push_back(x: {.token: tok, .bias: bias}); |
| 534 | } |
| 535 | } |
| 536 | } |
| 537 | } |
| 538 | } else if (logit_bias != data.end() && logit_bias->is_object()) { |
| 539 | const int n_vocab = llama_vocab_n_tokens(vocab); |
| 540 | for (const auto & el : logit_bias->items()) { |
| 541 | float bias; |
| 542 | const auto & key = el.key(); |
| 543 | const auto & value = el.value(); |
| 544 | if (value.is_number()) { |
| 545 | bias = value.get<float>(); |
| 546 | } else if (value.is_boolean() && !value.get<bool>()) { |
| 547 | bias = -INFINITY; |
| 548 | } else { |
| 549 | continue; |
| 550 | } |
| 551 | |
| 552 | char *end; |
| 553 | llama_token tok = strtol(nptr: key.c_str(), endptr: &end, base: 10); |
| 554 | if (*end == 0) { |
| 555 | if (tok >= 0 && tok < n_vocab) { |
| 556 | params.sampling.logit_bias.push_back(x: {.token: tok, .bias: bias}); |
| 557 | } |
| 558 | } else { |
| 559 | auto toks = common_tokenize(vocab, text: key, add_special: false); |
| 560 | for (auto tok : toks) { |
| 561 | params.sampling.logit_bias.push_back(x: {.token: tok, .bias: bias}); |
| 562 | } |
| 563 | } |
| 564 | } |
| 565 | } |
| 566 | |
| 567 | params.sampling.ignore_eos = json_value(body: data, key: "ignore_eos" , default_value: params_base.sampling.ignore_eos); |
| 568 | if (params.sampling.ignore_eos) { |
| 569 | params.sampling.logit_bias.insert( |
| 570 | position: params.sampling.logit_bias.end(), |
| 571 | first: defaults.sampling.logit_bias_eog.begin(), last: defaults.sampling.logit_bias_eog.end()); |
| 572 | } |
| 573 | } |
| 574 | |
| 575 | { |
| 576 | params.antiprompt.clear(); |
| 577 | |
| 578 | const auto & stop = data.find(key: "stop" ); |
| 579 | if (stop != data.end() && stop->is_array()) { |
| 580 | for (const auto & word : *stop) { |
| 581 | if (!word.empty()) { |
| 582 | params.antiprompt.push_back(x: word); |
| 583 | } |
| 584 | } |
| 585 | } |
| 586 | // set reverse prompt from cli args if not set in the request |
| 587 | if (params.antiprompt.empty()) { |
| 588 | params.antiprompt = defaults.antiprompt; |
| 589 | } |
| 590 | } |
| 591 | |
| 592 | { |
| 593 | const auto samplers = data.find(key: "samplers" ); |
| 594 | if (samplers != data.end()) { |
| 595 | if (samplers->is_array()) { |
| 596 | params.sampling.samplers = common_sampler_types_from_names(names: *samplers, allow_alt_names: false); |
| 597 | } else if (samplers->is_string()){ |
| 598 | params.sampling.samplers = common_sampler_types_from_chars(chars: samplers->get<std::string>()); |
| 599 | } |
| 600 | } else { |
| 601 | params.sampling.samplers = defaults.sampling.samplers; |
| 602 | } |
| 603 | } |
| 604 | |
| 605 | std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; |
| 606 | params.oaicompat_model = json_value(body: data, key: "model" , default_value: model_name); |
| 607 | |
| 608 | return params; |
| 609 | } |
| 610 | |
| 611 | // utility function |
| 612 | static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) { |
| 613 | std::unordered_set<int> ids(tasks.size()); |
| 614 | for (size_t i = 0; i < tasks.size(); i++) { |
| 615 | ids.insert(x: tasks[i].id); |
| 616 | } |
| 617 | return ids; |
| 618 | } |
| 619 | }; |
| 620 | |
| 621 | struct result_timings { |
| 622 | int32_t cache_n = -1; |
| 623 | |
| 624 | int32_t prompt_n = -1; |
| 625 | double prompt_ms; |
| 626 | double prompt_per_token_ms; |
| 627 | double prompt_per_second; |
| 628 | |
| 629 | int32_t predicted_n = -1; |
| 630 | double predicted_ms; |
| 631 | double predicted_per_token_ms; |
| 632 | double predicted_per_second; |
| 633 | |
| 634 | // Optional speculative metrics - only included when > 0 |
| 635 | int32_t draft_n = 0; |
| 636 | int32_t draft_n_accepted = 0; |
| 637 | |
| 638 | json to_json() const { |
| 639 | json base = { |
| 640 | {"cache_n" , cache_n}, |
| 641 | |
| 642 | {"prompt_n" , prompt_n}, |
| 643 | {"prompt_ms" , prompt_ms}, |
| 644 | {"prompt_per_token_ms" , prompt_per_token_ms}, |
| 645 | {"prompt_per_second" , prompt_per_second}, |
| 646 | |
| 647 | {"predicted_n" , predicted_n}, |
| 648 | {"predicted_ms" , predicted_ms}, |
| 649 | {"predicted_per_token_ms" , predicted_per_token_ms}, |
| 650 | {"predicted_per_second" , predicted_per_second}, |
| 651 | }; |
| 652 | |
| 653 | if (draft_n > 0) { |
| 654 | base["draft_n" ] = draft_n; |
| 655 | base["draft_n_accepted" ] = draft_n_accepted; |
| 656 | } |
| 657 | |
| 658 | return base; |
| 659 | } |
| 660 | }; |
| 661 | |
| 662 | struct result_prompt_progress { |
| 663 | int32_t total = 0; |
| 664 | int32_t cache = 0; |
| 665 | int32_t processed = 0; |
| 666 | int64_t time_ms = 0; |
| 667 | |
| 668 | json to_json() const { |
| 669 | return json { |
| 670 | {"total" , total}, |
| 671 | {"cache" , cache}, |
| 672 | {"processed" , processed}, |
| 673 | {"time_ms" , time_ms}, |
| 674 | }; |
| 675 | } |
| 676 | }; |
| 677 | |
| 678 | struct server_task_result { |
| 679 | int id = -1; |
| 680 | int id_slot = -1; |
| 681 | virtual bool is_error() { |
| 682 | // only used by server_task_result_error |
| 683 | return false; |
| 684 | } |
| 685 | virtual bool is_stop() { |
| 686 | // only used by server_task_result_cmpl_* |
| 687 | return false; |
| 688 | } |
| 689 | virtual int get_index() { |
| 690 | return -1; |
| 691 | } |
| 692 | virtual json to_json() = 0; |
| 693 | virtual ~server_task_result() = default; |
| 694 | }; |
| 695 | |
| 696 | // using shared_ptr for polymorphism of server_task_result |
| 697 | using server_task_result_ptr = std::unique_ptr<server_task_result>; |
| 698 | |
| 699 | static inline std::string stop_type_to_str(stop_type type) { |
| 700 | switch (type) { |
| 701 | case STOP_TYPE_EOS: return "eos" ; |
| 702 | case STOP_TYPE_WORD: return "word" ; |
| 703 | case STOP_TYPE_LIMIT: return "limit" ; |
| 704 | default: return "none" ; |
| 705 | } |
| 706 | } |
| 707 | |
| 708 | struct completion_token_output { |
| 709 | llama_token tok; |
| 710 | float prob; |
| 711 | std::string text_to_send; |
| 712 | struct prob_info { |
| 713 | llama_token tok; |
| 714 | std::string txt; |
| 715 | float prob; |
| 716 | }; |
| 717 | std::vector<prob_info> probs; |
| 718 | |
| 719 | json to_json(bool post_sampling_probs) const { |
| 720 | json probs_for_token = json::array(); |
| 721 | for (const auto & p : probs) { |
| 722 | std::string txt(p.txt); |
| 723 | txt.resize(n: validate_utf8(text: txt)); |
| 724 | probs_for_token.push_back(val: json { |
| 725 | {"id" , p.tok}, |
| 726 | {"token" , txt}, |
| 727 | {"bytes" , str_to_bytes(str: p.txt)}, |
| 728 | { |
| 729 | post_sampling_probs ? "prob" : "logprob" , |
| 730 | post_sampling_probs ? p.prob : logarithm(x: p.prob) |
| 731 | }, |
| 732 | }); |
| 733 | } |
| 734 | return probs_for_token; |
| 735 | } |
| 736 | |
| 737 | static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) { |
| 738 | json out = json::array(); |
| 739 | for (const auto & p : probs) { |
| 740 | std::string txt(p.text_to_send); |
| 741 | txt.resize(n: validate_utf8(text: txt)); |
| 742 | out.push_back(val: json { |
| 743 | {"id" , p.tok}, |
| 744 | {"token" , txt}, |
| 745 | {"bytes" , str_to_bytes(str: p.text_to_send)}, |
| 746 | { |
| 747 | post_sampling_probs ? "prob" : "logprob" , |
| 748 | post_sampling_probs ? p.prob : logarithm(x: p.prob) |
| 749 | }, |
| 750 | { |
| 751 | post_sampling_probs ? "top_probs" : "top_logprobs" , |
| 752 | p.to_json(post_sampling_probs) |
| 753 | }, |
| 754 | }); |
| 755 | } |
| 756 | return out; |
| 757 | } |
| 758 | |
| 759 | static float logarithm(float x) { |
| 760 | // nlohmann::json converts -inf to null, so we need to prevent that |
| 761 | return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x: x); |
| 762 | } |
| 763 | |
| 764 | static std::vector<unsigned char> str_to_bytes(const std::string & str) { |
| 765 | std::vector<unsigned char> bytes; |
| 766 | for (unsigned char c : str) { |
| 767 | bytes.push_back(x: c); |
| 768 | } |
| 769 | return bytes; |
| 770 | } |
| 771 | }; |
| 772 | |
| 773 | struct server_task_result_cmpl_final : server_task_result { |
| 774 | int index = 0; |
| 775 | |
| 776 | std::string content; |
| 777 | llama_tokens tokens; |
| 778 | |
| 779 | bool stream; |
| 780 | bool include_usage; |
| 781 | result_timings timings; |
| 782 | std::string prompt; |
| 783 | |
| 784 | bool truncated; |
| 785 | int32_t n_decoded; |
| 786 | int32_t n_prompt_tokens; |
| 787 | int32_t n_tokens_cached; |
| 788 | bool has_new_line; |
| 789 | std::string stopping_word; |
| 790 | stop_type stop = STOP_TYPE_NONE; |
| 791 | |
| 792 | bool post_sampling_probs; |
| 793 | std::vector<completion_token_output> probs_output; |
| 794 | std::vector<std::string> response_fields; |
| 795 | |
| 796 | slot_params generation_params; |
| 797 | |
| 798 | // OAI-compat fields |
| 799 | bool verbose = false; |
| 800 | oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; |
| 801 | std::string oaicompat_model; |
| 802 | std::string oaicompat_cmpl_id; |
| 803 | common_chat_msg oaicompat_msg; |
| 804 | |
| 805 | std::vector<common_chat_msg_diff> oaicompat_msg_diffs; |
| 806 | |
| 807 | virtual int get_index() override { |
| 808 | return index; |
| 809 | } |
| 810 | |
| 811 | virtual bool is_stop() override { |
| 812 | return true; // in stream mode, final responses are considered stop |
| 813 | } |
| 814 | |
| 815 | virtual json to_json() override { |
| 816 | switch (oaicompat) { |
| 817 | case OAICOMPAT_TYPE_NONE: |
| 818 | return to_json_non_oaicompat(); |
| 819 | case OAICOMPAT_TYPE_COMPLETION: |
| 820 | return to_json_oaicompat(); |
| 821 | case OAICOMPAT_TYPE_CHAT: |
| 822 | return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); |
| 823 | default: |
| 824 | GGML_ASSERT(false && "Invalid oaicompat_type" ); |
| 825 | } |
| 826 | } |
| 827 | |
| 828 | json to_json_non_oaicompat() { |
| 829 | json res = json { |
| 830 | {"index" , index}, |
| 831 | {"content" , stream ? "" : content}, // in stream mode, content is already in last partial chunk |
| 832 | {"tokens" , stream ? llama_tokens {} : tokens}, |
| 833 | {"id_slot" , id_slot}, |
| 834 | {"stop" , true}, |
| 835 | {"model" , oaicompat_model}, |
| 836 | {"tokens_predicted" , n_decoded}, |
| 837 | {"tokens_evaluated" , n_prompt_tokens}, |
| 838 | {"generation_settings" , generation_params.to_json()}, |
| 839 | {"prompt" , prompt}, |
| 840 | {"has_new_line" , has_new_line}, |
| 841 | {"truncated" , truncated}, |
| 842 | {"stop_type" , stop_type_to_str(type: stop)}, |
| 843 | {"stopping_word" , stopping_word}, |
| 844 | {"tokens_cached" , n_tokens_cached}, |
| 845 | {"timings" , timings.to_json()}, |
| 846 | }; |
| 847 | if (!stream && !probs_output.empty()) { |
| 848 | res["completion_probabilities" ] = completion_token_output::probs_vector_to_json(probs: probs_output, post_sampling_probs); |
| 849 | } |
| 850 | return response_fields.empty() ? res : json_get_nested_values(paths: response_fields, js: res); |
| 851 | } |
| 852 | |
| 853 | json to_json_oaicompat() { |
| 854 | std::time_t t = std::time(timer: 0); |
| 855 | json logprobs = json(nullptr); // OAI default to null |
| 856 | if (!stream && probs_output.size() > 0) { |
| 857 | logprobs = json{ |
| 858 | {"content" , completion_token_output::probs_vector_to_json(probs: probs_output, post_sampling_probs)}, |
| 859 | }; |
| 860 | } |
| 861 | json finish_reason = "length" ; |
| 862 | if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { |
| 863 | finish_reason = "stop" ; |
| 864 | } |
| 865 | json res = json { |
| 866 | {"choices" , json::array(init: { |
| 867 | json{ |
| 868 | {"text" , stream ? "" : content}, // in stream mode, content is already in last partial chunk |
| 869 | {"index" , index}, |
| 870 | {"logprobs" , logprobs}, |
| 871 | {"finish_reason" , finish_reason}, |
| 872 | } |
| 873 | })}, |
| 874 | {"created" , t}, |
| 875 | {"model" , oaicompat_model}, |
| 876 | {"system_fingerprint" , build_info}, |
| 877 | {"object" , "text_completion" }, |
| 878 | {"usage" , json { |
| 879 | {"completion_tokens" , n_decoded}, |
| 880 | {"prompt_tokens" , n_prompt_tokens}, |
| 881 | {"total_tokens" , n_decoded + n_prompt_tokens} |
| 882 | }}, |
| 883 | {"id" , oaicompat_cmpl_id} |
| 884 | }; |
| 885 | |
| 886 | // extra fields for debugging purposes |
| 887 | if (verbose) { |
| 888 | res["__verbose" ] = to_json_non_oaicompat(); |
| 889 | } |
| 890 | if (timings.prompt_n >= 0) { |
| 891 | res.push_back(init: {"timings" , timings.to_json()}); |
| 892 | } |
| 893 | |
| 894 | return res; |
| 895 | } |
| 896 | |
| 897 | json to_json_oaicompat_chat() { |
| 898 | std::string finish_reason = "length" ; |
| 899 | common_chat_msg msg; |
| 900 | if (!oaicompat_msg.empty()) { |
| 901 | msg = oaicompat_msg; |
| 902 | } else { |
| 903 | msg.role = "assistant" ; |
| 904 | msg.content = content; |
| 905 | } |
| 906 | if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { |
| 907 | finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls" ; |
| 908 | } |
| 909 | |
| 910 | json choice { |
| 911 | {"finish_reason" , finish_reason}, |
| 912 | {"index" , 0}, |
| 913 | {"message" , msg.to_json_oaicompat<json>()}, |
| 914 | }; |
| 915 | |
| 916 | if (!stream && probs_output.size() > 0) { |
| 917 | choice["logprobs" ] = json{ |
| 918 | {"content" , completion_token_output::probs_vector_to_json(probs: probs_output, post_sampling_probs)}, |
| 919 | }; |
| 920 | } |
| 921 | |
| 922 | std::time_t t = std::time(timer: 0); |
| 923 | |
| 924 | json res = json { |
| 925 | {"choices" , json::array(init: {choice})}, |
| 926 | {"created" , t}, |
| 927 | {"model" , oaicompat_model}, |
| 928 | {"system_fingerprint" , build_info}, |
| 929 | {"object" , "chat.completion" }, |
| 930 | {"usage" , json { |
| 931 | {"completion_tokens" , n_decoded}, |
| 932 | {"prompt_tokens" , n_prompt_tokens}, |
| 933 | {"total_tokens" , n_decoded + n_prompt_tokens} |
| 934 | }}, |
| 935 | {"id" , oaicompat_cmpl_id} |
| 936 | }; |
| 937 | |
| 938 | // extra fields for debugging purposes |
| 939 | if (verbose) { |
| 940 | res["__verbose" ] = to_json_non_oaicompat(); |
| 941 | } |
| 942 | if (timings.prompt_n >= 0) { |
| 943 | res.push_back(init: {"timings" , timings.to_json()}); |
| 944 | } |
| 945 | |
| 946 | return res; |
| 947 | } |
| 948 | |
| 949 | json to_json_oaicompat_chat_stream() { |
| 950 | std::time_t t = std::time(timer: 0); |
| 951 | std::string finish_reason = "length" ; |
| 952 | if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { |
| 953 | finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls" ; |
| 954 | } |
| 955 | |
| 956 | json deltas = json::array(); |
| 957 | for (const auto & diff : oaicompat_msg_diffs) { |
| 958 | deltas.push_back(init: { |
| 959 | {"choices" , json::array(init: { |
| 960 | json { |
| 961 | {"finish_reason" , nullptr}, |
| 962 | {"index" , 0}, |
| 963 | {"delta" , common_chat_msg_diff_to_json_oaicompat<json>(diff)}, |
| 964 | }, |
| 965 | })}, |
| 966 | {"created" , t}, |
| 967 | {"id" , oaicompat_cmpl_id}, |
| 968 | {"model" , oaicompat_model}, |
| 969 | {"system_fingerprint" , build_info}, |
| 970 | {"object" , "chat.completion.chunk" }, |
| 971 | }); |
| 972 | } |
| 973 | |
| 974 | deltas.push_back(init: { |
| 975 | {"choices" , json::array(init: { |
| 976 | json { |
| 977 | {"finish_reason" , finish_reason}, |
| 978 | {"index" , 0}, |
| 979 | {"delta" , json::object()}, |
| 980 | }, |
| 981 | })}, |
| 982 | {"created" , t}, |
| 983 | {"id" , oaicompat_cmpl_id}, |
| 984 | {"model" , oaicompat_model}, |
| 985 | {"system_fingerprint" , build_info}, |
| 986 | {"object" , "chat.completion.chunk" }, |
| 987 | }); |
| 988 | |
| 989 | if (include_usage) { |
| 990 | // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage |
| 991 | // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices |
| 992 | deltas.push_back(init: { |
| 993 | {"choices" , json::array()}, |
| 994 | {"created" , t}, |
| 995 | {"id" , oaicompat_cmpl_id}, |
| 996 | {"model" , oaicompat_model}, |
| 997 | {"system_fingerprint" , build_info}, |
| 998 | {"object" , "chat.completion.chunk" }, |
| 999 | {"usage" , json { |
| 1000 | {"completion_tokens" , n_decoded}, |
| 1001 | {"prompt_tokens" , n_prompt_tokens}, |
| 1002 | {"total_tokens" , n_decoded + n_prompt_tokens}, |
| 1003 | }}, |
| 1004 | }); |
| 1005 | } |
| 1006 | |
| 1007 | if (timings.prompt_n >= 0) { |
| 1008 | deltas.back().push_back(init: {"timings" , timings.to_json()}); |
| 1009 | } |
| 1010 | |
| 1011 | // extra fields for debugging purposes |
| 1012 | if (verbose && !deltas.empty()) { |
| 1013 | deltas.front()["__verbose" ] = to_json_non_oaicompat(); |
| 1014 | } |
| 1015 | |
| 1016 | return deltas; |
| 1017 | } |
| 1018 | }; |
| 1019 | |
| 1020 | struct server_task_result_cmpl_partial : server_task_result { |
| 1021 | int index = 0; |
| 1022 | |
| 1023 | std::string content; |
| 1024 | llama_tokens tokens; |
| 1025 | |
| 1026 | int32_t n_decoded; |
| 1027 | int32_t n_prompt_tokens; |
| 1028 | |
| 1029 | bool post_sampling_probs; |
| 1030 | bool is_progress = false; |
| 1031 | completion_token_output prob_output; |
| 1032 | result_timings timings; |
| 1033 | result_prompt_progress progress; |
| 1034 | |
| 1035 | // OAI-compat fields |
| 1036 | bool verbose = false; |
| 1037 | oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; |
| 1038 | std::string oaicompat_model; |
| 1039 | std::string oaicompat_cmpl_id; |
| 1040 | std::vector<common_chat_msg_diff> oaicompat_msg_diffs; |
| 1041 | |
| 1042 | virtual int get_index() override { |
| 1043 | return index; |
| 1044 | } |
| 1045 | |
| 1046 | virtual bool is_stop() override { |
| 1047 | return false; // in stream mode, partial responses are not considered stop |
| 1048 | } |
| 1049 | |
| 1050 | virtual json to_json() override { |
| 1051 | switch (oaicompat) { |
| 1052 | case OAICOMPAT_TYPE_NONE: |
| 1053 | return to_json_non_oaicompat(); |
| 1054 | case OAICOMPAT_TYPE_COMPLETION: |
| 1055 | return to_json_oaicompat(); |
| 1056 | case OAICOMPAT_TYPE_CHAT: |
| 1057 | return to_json_oaicompat_chat(); |
| 1058 | default: |
| 1059 | GGML_ASSERT(false && "Invalid oaicompat_type" ); |
| 1060 | } |
| 1061 | } |
| 1062 | |
| 1063 | json to_json_non_oaicompat() { |
| 1064 | // non-OAI-compat JSON |
| 1065 | json res = json { |
| 1066 | {"index" , index}, |
| 1067 | {"content" , content}, |
| 1068 | {"tokens" , tokens}, |
| 1069 | {"stop" , false}, |
| 1070 | {"id_slot" , id_slot}, |
| 1071 | {"tokens_predicted" , n_decoded}, |
| 1072 | {"tokens_evaluated" , n_prompt_tokens}, |
| 1073 | }; |
| 1074 | // populate the timings object when needed (usually for the last response or with timings_per_token enabled) |
| 1075 | if (timings.prompt_n > 0) { |
| 1076 | res.push_back(init: {"timings" , timings.to_json()}); |
| 1077 | } |
| 1078 | if (is_progress) { |
| 1079 | res.push_back(init: {"prompt_progress" , progress.to_json()}); |
| 1080 | } |
| 1081 | if (!prob_output.probs.empty()) { |
| 1082 | res["completion_probabilities" ] = completion_token_output::probs_vector_to_json(probs: {prob_output}, post_sampling_probs); |
| 1083 | } |
| 1084 | return res; |
| 1085 | } |
| 1086 | |
| 1087 | json to_json_oaicompat() { |
| 1088 | std::time_t t = std::time(timer: 0); |
| 1089 | json logprobs = json(nullptr); // OAI default to null |
| 1090 | if (prob_output.probs.size() > 0) { |
| 1091 | logprobs = json{ |
| 1092 | {"content" , completion_token_output::probs_vector_to_json(probs: {prob_output}, post_sampling_probs)}, |
| 1093 | }; |
| 1094 | } |
| 1095 | json res = json { |
| 1096 | {"choices" , json::array(init: { |
| 1097 | json{ |
| 1098 | {"text" , content}, |
| 1099 | {"index" , index}, |
| 1100 | {"logprobs" , logprobs}, |
| 1101 | {"finish_reason" , nullptr}, |
| 1102 | } |
| 1103 | })}, |
| 1104 | {"created" , t}, |
| 1105 | {"model" , oaicompat_model}, |
| 1106 | {"system_fingerprint" , build_info}, |
| 1107 | {"object" , "text_completion" }, |
| 1108 | {"id" , oaicompat_cmpl_id} |
| 1109 | }; |
| 1110 | |
| 1111 | // extra fields for debugging purposes |
| 1112 | if (verbose) { |
| 1113 | res["__verbose" ] = to_json_non_oaicompat(); |
| 1114 | } |
| 1115 | if (timings.prompt_n >= 0) { |
| 1116 | res.push_back(init: {"timings" , timings.to_json()}); |
| 1117 | } |
| 1118 | if (is_progress) { |
| 1119 | res.push_back(init: {"prompt_progress" , progress.to_json()}); |
| 1120 | } |
| 1121 | |
| 1122 | return res; |
| 1123 | } |
| 1124 | |
| 1125 | json to_json_oaicompat_chat() { |
| 1126 | bool first = n_decoded == 1; |
| 1127 | std::time_t t = std::time(timer: 0); |
| 1128 | json choices; |
| 1129 | |
| 1130 | std::vector<json> deltas; |
| 1131 | auto add_delta = [&](const json & delta) { |
| 1132 | deltas.push_back(x: { |
| 1133 | {"choices" , json::array(init: { |
| 1134 | json { |
| 1135 | {"finish_reason" , nullptr}, |
| 1136 | {"index" , 0}, |
| 1137 | {"delta" , delta}, |
| 1138 | }, |
| 1139 | })}, |
| 1140 | {"created" , t}, |
| 1141 | {"id" , oaicompat_cmpl_id}, |
| 1142 | {"model" , oaicompat_model}, |
| 1143 | {"system_fingerprint" , build_info}, |
| 1144 | {"object" , "chat.completion.chunk" }, |
| 1145 | }); |
| 1146 | }; |
| 1147 | // We have to send an initial update to conform to openai behavior |
| 1148 | if (first || is_progress) { |
| 1149 | add_delta({ |
| 1150 | {"role" , "assistant" }, |
| 1151 | {"content" , nullptr}, |
| 1152 | }); |
| 1153 | } |
| 1154 | |
| 1155 | for (const auto & diff : oaicompat_msg_diffs) { |
| 1156 | add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff)); |
| 1157 | } |
| 1158 | |
| 1159 | if (!deltas.empty()) { |
| 1160 | auto & last_json = deltas[deltas.size() - 1]; |
| 1161 | GGML_ASSERT(last_json.at("choices" ).size() >= 1); |
| 1162 | |
| 1163 | if (prob_output.probs.size() > 0) { |
| 1164 | last_json.at(key: "choices" ).at(idx: 0)["logprobs" ] = json { |
| 1165 | {"content" , completion_token_output::probs_vector_to_json(probs: {prob_output}, post_sampling_probs)}, |
| 1166 | }; |
| 1167 | } |
| 1168 | |
| 1169 | if (timings.prompt_n >= 0) { |
| 1170 | last_json.push_back(init: {"timings" , timings.to_json()}); |
| 1171 | } |
| 1172 | if (is_progress) { |
| 1173 | last_json.push_back(init: {"prompt_progress" , progress.to_json()}); |
| 1174 | } |
| 1175 | } |
| 1176 | |
| 1177 | return deltas; |
| 1178 | } |
| 1179 | }; |
| 1180 | |
| 1181 | struct server_task_result_embd : server_task_result { |
| 1182 | int index = 0; |
| 1183 | std::vector<std::vector<float>> embedding; |
| 1184 | |
| 1185 | int32_t n_tokens; |
| 1186 | |
| 1187 | // OAI-compat fields |
| 1188 | oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; |
| 1189 | |
| 1190 | virtual int get_index() override { |
| 1191 | return index; |
| 1192 | } |
| 1193 | |
| 1194 | virtual json to_json() override { |
| 1195 | return oaicompat == OAICOMPAT_TYPE_EMBEDDING |
| 1196 | ? to_json_oaicompat() |
| 1197 | : to_json_non_oaicompat(); |
| 1198 | } |
| 1199 | |
| 1200 | json to_json_non_oaicompat() { |
| 1201 | return json { |
| 1202 | {"index" , index}, |
| 1203 | {"embedding" , embedding}, |
| 1204 | }; |
| 1205 | } |
| 1206 | |
| 1207 | json to_json_oaicompat() { |
| 1208 | return json { |
| 1209 | {"index" , index}, |
| 1210 | {"embedding" , embedding[0]}, |
| 1211 | {"tokens_evaluated" , n_tokens}, |
| 1212 | }; |
| 1213 | } |
| 1214 | }; |
| 1215 | |
| 1216 | struct server_task_result_rerank : server_task_result { |
| 1217 | int index = 0; |
| 1218 | float score = -1e6; |
| 1219 | |
| 1220 | int32_t n_tokens; |
| 1221 | |
| 1222 | virtual int get_index() override { |
| 1223 | return index; |
| 1224 | } |
| 1225 | |
| 1226 | virtual json to_json() override { |
| 1227 | return json { |
| 1228 | {"index" , index}, |
| 1229 | {"score" , score}, |
| 1230 | {"tokens_evaluated" , n_tokens}, |
| 1231 | }; |
| 1232 | } |
| 1233 | }; |
| 1234 | |
| 1235 | // this function maybe used outside of server_task_result_error |
| 1236 | static json format_error_response(const std::string & message, const enum error_type type) { |
| 1237 | std::string type_str; |
| 1238 | int code = 500; |
| 1239 | switch (type) { |
| 1240 | case ERROR_TYPE_INVALID_REQUEST: |
| 1241 | type_str = "invalid_request_error" ; |
| 1242 | code = 400; |
| 1243 | break; |
| 1244 | case ERROR_TYPE_AUTHENTICATION: |
| 1245 | type_str = "authentication_error" ; |
| 1246 | code = 401; |
| 1247 | break; |
| 1248 | case ERROR_TYPE_NOT_FOUND: |
| 1249 | type_str = "not_found_error" ; |
| 1250 | code = 404; |
| 1251 | break; |
| 1252 | case ERROR_TYPE_SERVER: |
| 1253 | type_str = "server_error" ; |
| 1254 | code = 500; |
| 1255 | break; |
| 1256 | case ERROR_TYPE_PERMISSION: |
| 1257 | type_str = "permission_error" ; |
| 1258 | code = 403; |
| 1259 | break; |
| 1260 | case ERROR_TYPE_NOT_SUPPORTED: |
| 1261 | type_str = "not_supported_error" ; |
| 1262 | code = 501; |
| 1263 | break; |
| 1264 | case ERROR_TYPE_UNAVAILABLE: |
| 1265 | type_str = "unavailable_error" ; |
| 1266 | code = 503; |
| 1267 | break; |
| 1268 | case ERROR_TYPE_EXCEED_CONTEXT_SIZE: |
| 1269 | type_str = "exceed_context_size_error" ; |
| 1270 | code = 400; |
| 1271 | break; |
| 1272 | } |
| 1273 | return json { |
| 1274 | {"code" , code}, |
| 1275 | {"message" , message}, |
| 1276 | {"type" , type_str}, |
| 1277 | }; |
| 1278 | } |
| 1279 | |
| 1280 | struct server_task_result_error : server_task_result { |
| 1281 | int index = 0; |
| 1282 | error_type err_type = ERROR_TYPE_SERVER; |
| 1283 | std::string err_msg; |
| 1284 | |
| 1285 | // for ERROR_TYPE_EXCEED_CONTEXT_SIZE |
| 1286 | int32_t n_prompt_tokens = 0; |
| 1287 | int32_t n_ctx = 0; |
| 1288 | |
| 1289 | virtual bool is_error() override { |
| 1290 | return true; |
| 1291 | } |
| 1292 | |
| 1293 | virtual json to_json() override { |
| 1294 | json res = format_error_response(message: err_msg, type: err_type); |
| 1295 | if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { |
| 1296 | res["n_prompt_tokens" ] = n_prompt_tokens; |
| 1297 | res["n_ctx" ] = n_ctx; |
| 1298 | } |
| 1299 | return res; |
| 1300 | } |
| 1301 | }; |
| 1302 | |
| 1303 | struct server_task_result_metrics : server_task_result { |
| 1304 | int n_idle_slots; |
| 1305 | int n_processing_slots; |
| 1306 | int n_tasks_deferred; |
| 1307 | int64_t t_start; |
| 1308 | |
| 1309 | // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields |
| 1310 | uint64_t n_prompt_tokens_processed_total = 0; |
| 1311 | uint64_t t_prompt_processing_total = 0; |
| 1312 | uint64_t n_tokens_predicted_total = 0; |
| 1313 | uint64_t t_tokens_generation_total = 0; |
| 1314 | |
| 1315 | uint64_t n_tokens_max = 0; |
| 1316 | |
| 1317 | uint64_t n_prompt_tokens_processed = 0; |
| 1318 | uint64_t t_prompt_processing = 0; |
| 1319 | |
| 1320 | uint64_t n_tokens_predicted = 0; |
| 1321 | uint64_t t_tokens_generation = 0; |
| 1322 | |
| 1323 | uint64_t n_decode_total = 0; |
| 1324 | uint64_t n_busy_slots_total = 0; |
| 1325 | |
| 1326 | // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy |
| 1327 | // therefore, we use json to temporarily store the slot.to_json() result |
| 1328 | json slots_data = json::array(); |
| 1329 | |
| 1330 | virtual json to_json() override { |
| 1331 | return json { |
| 1332 | { "idle" , n_idle_slots }, |
| 1333 | { "processing" , n_processing_slots }, |
| 1334 | { "deferred" , n_tasks_deferred }, |
| 1335 | { "t_start" , t_start }, |
| 1336 | |
| 1337 | { "n_prompt_tokens_processed_total" , n_prompt_tokens_processed_total }, |
| 1338 | { "t_tokens_generation_total" , t_tokens_generation_total }, |
| 1339 | { "n_tokens_predicted_total" , n_tokens_predicted_total }, |
| 1340 | { "t_prompt_processing_total" , t_prompt_processing_total }, |
| 1341 | |
| 1342 | { "n_tokens_max" , n_tokens_max }, |
| 1343 | |
| 1344 | { "n_prompt_tokens_processed" , n_prompt_tokens_processed }, |
| 1345 | { "t_prompt_processing" , t_prompt_processing }, |
| 1346 | { "n_tokens_predicted" , n_tokens_predicted }, |
| 1347 | { "t_tokens_generation" , t_tokens_generation }, |
| 1348 | |
| 1349 | { "n_decode_total" , n_decode_total }, |
| 1350 | { "n_busy_slots_total" , n_busy_slots_total }, |
| 1351 | |
| 1352 | { "slots" , slots_data }, |
| 1353 | }; |
| 1354 | } |
| 1355 | }; |
| 1356 | |
| 1357 | struct server_task_result_slot_save_load : server_task_result { |
| 1358 | std::string filename; |
| 1359 | bool is_save; // true = save, false = load |
| 1360 | |
| 1361 | size_t n_tokens; |
| 1362 | size_t n_bytes; |
| 1363 | double t_ms; |
| 1364 | |
| 1365 | virtual json to_json() override { |
| 1366 | if (is_save) { |
| 1367 | return json { |
| 1368 | { "id_slot" , id_slot }, |
| 1369 | { "filename" , filename }, |
| 1370 | { "n_saved" , n_tokens }, |
| 1371 | { "n_written" , n_bytes }, |
| 1372 | { "timings" , { |
| 1373 | { "save_ms" , t_ms } |
| 1374 | }}, |
| 1375 | }; |
| 1376 | } |
| 1377 | |
| 1378 | return json { |
| 1379 | { "id_slot" , id_slot }, |
| 1380 | { "filename" , filename }, |
| 1381 | { "n_restored" , n_tokens }, |
| 1382 | { "n_read" , n_bytes }, |
| 1383 | { "timings" , { |
| 1384 | { "restore_ms" , t_ms } |
| 1385 | }}, |
| 1386 | }; |
| 1387 | } |
| 1388 | }; |
| 1389 | |
| 1390 | struct server_task_result_slot_erase : server_task_result { |
| 1391 | size_t n_erased; |
| 1392 | |
| 1393 | virtual json to_json() override { |
| 1394 | return json { |
| 1395 | { "id_slot" , id_slot }, |
| 1396 | { "n_erased" , n_erased }, |
| 1397 | }; |
| 1398 | } |
| 1399 | }; |
| 1400 | |
| 1401 | struct server_task_result_apply_lora : server_task_result { |
| 1402 | virtual json to_json() override { |
| 1403 | return json {{ "success" , true }}; |
| 1404 | } |
| 1405 | }; |
| 1406 | |
| 1407 | struct server_prompt_checkpoint { |
| 1408 | llama_pos pos_min; |
| 1409 | llama_pos pos_max; |
| 1410 | |
| 1411 | std::vector<uint8_t> data; |
| 1412 | |
| 1413 | size_t size() const { |
| 1414 | return data.size(); |
| 1415 | } |
| 1416 | }; |
| 1417 | |
| 1418 | struct server_prompt { |
| 1419 | server_tokens tokens; |
| 1420 | |
| 1421 | std::vector<uint8_t> data; |
| 1422 | |
| 1423 | std::list<server_prompt_checkpoint> checkpoints; |
| 1424 | |
| 1425 | size_t size() const { |
| 1426 | size_t res = data.size(); |
| 1427 | |
| 1428 | for (const auto & checkpoint : checkpoints) { |
| 1429 | res += checkpoint.size(); |
| 1430 | } |
| 1431 | |
| 1432 | return res; |
| 1433 | } |
| 1434 | |
| 1435 | int n_tokens() const { |
| 1436 | return tokens.size(); |
| 1437 | } |
| 1438 | }; |
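| | // Note (added for clarity): server_prompt::size() reports the memory footprint of the |
| | // saved llama state plus all attached checkpoints, while n_tokens() reports only the |
| | // token count of the cached prompt. For example, a prompt whose state blob is 100 MiB |
| | // with two 10 MiB checkpoints reports size() == 120 MiB. |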
| 1439 | |
| 1440 | struct server_prompt_cache { |
| 1441 | server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { |
| 1442 | this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib); |
| 1443 | this->limit_tokens = limit_tokens; |
| 1444 | } |
| 1445 | |
| 1446 | std::list<server_prompt> states; |
| 1447 | |
| 1448 | // in bytes, 0 = no limit |
| 1449 | size_t limit_size = 0; |
| 1450 | |
| 1451 | // in tokens, 0 = no limit |
| 1452 | size_t limit_tokens = 0; |
| 1453 | |
| 1454 | size_t size() const { |
| 1455 | size_t res = 0; |
| 1456 | |
| 1457 | for (const auto & state : states) { |
| 1458 | res += state.size(); |
| 1459 | } |
| 1460 | |
| 1461 | return res; |
| 1462 | } |
| 1463 | |
| 1464 | size_t n_tokens() const { |
| 1465 | size_t res = 0; |
| 1466 | |
| 1467 | for (const auto & state : states) { |
| 1468 | res += state.n_tokens(); |
| 1469 | } |
| 1470 | |
| 1471 | return res; |
| 1472 | } |
| 1473 | |
| 1474 | server_prompt * alloc(const server_prompt & prompt, size_t state_size) { |
| 1475 | // first check if the current state is contained fully in the cache |
| 1476 | for (auto it = states.begin(); it != states.end(); ++it) { |
| 1477 | const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); |
| 1478 | |
| 1479 | if (cur_lcp_len == (int) prompt.tokens.size()) { |
| 1480 | SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); |
| 1481 | return nullptr; |
| 1482 | } |
| 1483 | } |
| 1484 | |
| 1485 | // next, remove any cached prompts that are fully contained in the current prompt |
| 1486 | for (auto it = states.begin(); it != states.end();) { |
| 1487 | const int len = it->tokens.get_common_prefix(prompt.tokens); |
| 1488 | |
| 1489 | if (len == (int) it->tokens.size()) { |
| 1490 | SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); |
| 1491 | |
| 1492 | it = states.erase(it); |
| 1493 | } else { |
| 1494 | ++it; |
| 1495 | } |
| 1496 | } |
| 1497 | |
| 1498 | std::vector<uint8_t> state_data; |
| 1499 | |
| 1500 | // check if we can allocate enough memory for the new state |
| 1501 | try { |
| 1502 | state_data.resize(state_size); |
| 1503 | } catch (const std::bad_alloc & e) { |
| 1504 | SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); |
| 1505 | |
| 1506 | limit_size = std::max<size_t>(1, 0.4*size()); |
| 1507 | |
| 1508 | SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); |
| 1509 | |
| 1510 | update(); |
| 1511 | |
| 1512 | return nullptr; |
| 1513 | } |
| 1514 | |
| 1515 | // TODO: for some reason we can't copy server_tokens, so we have to do this workaround |
| 1516 | auto & cur = states.emplace_back(); |
| 1517 | cur = { |
| 1518 | /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), |
| 1519 | /*.data =*/ std::move(state_data), |
| 1520 | /*.checkpoints =*/ prompt.checkpoints, |
| 1521 | }; |
| 1522 | |
| 1523 | return &cur; |
| 1524 | } |
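| | // Note (added for clarity): alloc() deduplicates in both directions before storing a new |
| | // state: if an existing cache entry already covers the incoming prompt it stores nothing, |
| | // and any entries that are prefixes of the incoming prompt are dropped as obsolete. On a |
| | // failed allocation it shrinks limit_size to ~40% of the current cache size and evicts |
| | // via update() instead of aborting. |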
| 1525 | |
| 1526 | bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { |
| 1527 | const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); |
| 1528 | |
| 1529 | float f_keep_best = float(lcp_best) / prompt.tokens.size(); |
| 1530 | float sim_best = float(lcp_best) / tokens_new.size(); |
| 1531 | |
| 1532 | SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); |
| 1533 | |
| 1534 | auto it_best = states.end(); |
| 1535 | |
| 1536 | // find the most similar cached prompt, that would also preserve the most context |
| 1537 | for (auto it = states.begin(); it != states.end(); ++it) { |
| 1538 | const int lcp_cur = it->tokens.get_common_prefix(tokens_new); |
| 1539 | |
| 1540 | const float f_keep_cur = float(lcp_cur) / it->tokens.size(); |
| 1541 | const float sim_cur = float(lcp_cur) / tokens_new.size(); |
| 1542 | |
| 1543 | // don't trash large prompts |
| 1544 | if (f_keep_cur < 0.25f) { |
| 1545 | continue; |
| 1546 | } |
| 1547 | |
| 1548 | if (f_keep_best < f_keep_cur && sim_best < sim_cur) { |
| 1549 | f_keep_best = f_keep_cur; |
| 1550 | sim_best = sim_cur; |
| 1551 | |
| 1552 | it_best = it; |
| 1553 | } |
| 1554 | } |
| 1555 | |
| 1556 | if (it_best != states.end()) { |
| 1557 | SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n" , f_keep_best, sim_best); |
| 1558 | |
| 1559 | const size_t size = it_best->data.size(); |
| 1560 | const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); |
| 1561 | if (n != size) { |
| 1562 | SRV_WRN("failed to restore state with size %zu\n", size); |
| 1563 | |
| 1564 | return false; |
| 1565 | } |
| 1566 | |
| 1567 | it_best->data.clear(); |
| 1568 | it_best->data.shrink_to_fit(); |
| 1569 | |
| 1570 | prompt = std::move(*it_best); |
| 1571 | |
| 1572 | states.erase(it_best); |
| 1573 | } |
| 1574 | |
| 1575 | return true; |
| 1576 | } |
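| | // Note (added for clarity): a cached entry replaces the slot's current prompt only if it |
| | // keeps at least 25% of its own tokens (f_keep) and beats the slot on *both* f_keep and |
| | // similarity. Illustrative numbers: for a new prompt of 1000 tokens, a cached entry of |
| | // 800 tokens sharing a 600-token prefix gives f_keep = 0.75 and sim = 0.6; it is chosen |
| | // only if the slot's current prompt scores lower on both metrics. |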
| 1577 | |
| 1578 | void update() { |
| 1579 | if (limit_size > 0) { |
| 1580 | // always keep at least one state, regardless of the limits |
| 1581 | while (states.size() > 1 && size() > limit_size) { |
| 1582 | if (states.empty()) { |
| 1583 | break; |
| 1584 | } |
| 1585 | |
| 1586 | SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n" , states.front().size() / (1024.0 * 1024.0)); |
| 1587 | |
| 1588 | states.pop_front(); |
| 1589 | } |
| 1590 | } |
| 1591 | |
| 1592 | // average size per token |
| 1593 | const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens()))); |
| 1594 | |
| 1595 | // dynamically increase the token limit if it can fit in the memory limit |
| 1596 | const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size/size_per_token) : limit_tokens; |
| 1597 | |
| 1598 | if (limit_tokens > 0) { |
| 1599 | while (states.size() > 1 && n_tokens() > limit_tokens_cur) { |
| 1600 | if (states.empty()) { |
| 1601 | break; |
| 1602 | } |
| 1603 | |
| 1604 | SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n" , |
| 1605 | limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); |
| 1606 | |
| 1607 | states.pop_front(); |
| 1608 | } |
| 1609 | } |
| 1610 | |
| 1611 | SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n" , |
| 1612 | states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); |
| 1613 | |
| 1614 | for (const auto & state : states) { |
| 1615 | SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n" , |
| 1616 | (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); |
| 1617 | } |
| 1618 | } |
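| | // Note (added for clarity): update() derives an effective token limit from the byte limit. |
| | // Illustrative numbers: if the cache currently holds 80 MiB for 40k tokens, size_per_token |
| | // is ~2 KiB, so a 200 MiB limit_size allows roughly 100k tokens even when limit_tokens is |
| | // smaller; the larger of the two values is used before evicting from the front (oldest). |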
| 1619 | }; |
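| | // Rough usage sketch (not part of the original code; names as used above): a slot stores |
| | // its state via alloc() followed by llama_state_seq_get_data_ext(), tries load() when it |
| | // is about to switch prompts, and calls update() afterwards to enforce the limits, e.g.: |
| | // |
| | // server_prompt_cache cache(/*limit_size_mib=*/2048, /*limit_tokens=*/0); |
| | // if (auto * st = cache.alloc(slot_prompt, llama_state_seq_get_size_ext(ctx, id, 0))) { |
| | // llama_state_seq_get_data_ext(ctx, st->data.data(), st->data.size(), id, 0); |
| | // } |
| | // cache.load(slot_prompt, new_tokens, ctx, id); |
| | // cache.update(); |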
| 1620 | |
| 1621 | struct server_slot { |
| 1622 | int id; |
| 1623 | |
| 1624 | llama_batch batch_spec = {}; |
| 1625 | |
| 1626 | // TODO: change to unique_ptrs for consistency: |
| 1627 | llama_context * ctx = nullptr; |
| 1628 | llama_context * ctx_dft = nullptr; |
| 1629 | |
| 1630 | // multimodal |
| 1631 | mtmd_context * mctx = nullptr; |
| 1632 | |
| 1633 | common_speculative * spec = nullptr; |
| 1634 | |
| 1635 | std::unique_ptr<const server_task> task; |
| 1636 | std::unique_ptr<const server_task> task_prev; // used for debugging |
| 1637 | |
| 1638 | // used to determine the slot that has been used the longest |
| 1639 | int64_t t_last_used = -1; |
| 1640 | |
| 1641 | // generation props |
| 1642 | int32_t n_ctx = 0; // context size per slot |
| 1643 | int32_t n_keep = 0; |
| 1644 | int32_t n_decoded = 0; |
| 1645 | int32_t n_remaining = -1; |
| 1646 | int32_t i_batch = -1; |
| 1647 | |
| 1648 | int32_t n_prompt_tokens_cache = 0; |
| 1649 | int32_t n_prompt_tokens_processed = 0; |
| 1650 | |
| 1651 | size_t last_nl_pos = 0; |
| 1652 | |
| 1653 | std::string generated_text; |
| 1654 | llama_tokens generated_tokens; |
| 1655 | |
| 1656 | common_chat_msg chat_msg; |
| 1657 | |
| 1658 | std::vector<completion_token_output> generated_token_probs; |
| 1659 | |
| 1660 | bool has_next_token = true; |
| 1661 | bool has_new_line = false; |
| 1662 | bool truncated = false; |
| 1663 | |
| 1664 | stop_type stop; |
| 1665 | |
| 1666 | std::string stopping_word; |
| 1667 | |
| 1668 | // state |
| 1669 | slot_state state = SLOT_STATE_IDLE; |
| 1670 | |
| 1671 | server_prompt prompt; |
| 1672 | |
| 1673 | void prompt_save(server_prompt_cache & prompt_cache) const { |
| 1674 | assert(prompt.data.size() == 0); |
| 1675 | |
| 1676 | const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); |
| 1677 | |
| 1678 | SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", |
| 1679 | (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); |
| 1680 | |
| 1681 | auto * cur = prompt_cache.alloc(prompt, cur_size); |
| 1682 | if (cur == nullptr) { |
| 1683 | return; |
| 1684 | } |
| 1685 | |
| 1686 | llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); |
| 1687 | } |
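| | // Note (added for clarity): prompt_save() is a two-step copy - it first queries the exact |
| | // state size for this sequence, lets the cache reserve a buffer of that size, and only |
| | // then serializes the KV state into it, so a failed allocation never touches the slot. |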
| 1688 | |
| 1689 | void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { |
| 1690 | bool res = prompt_cache.load(prompt, tokens, ctx, id); |
| 1691 | if (!res) { |
| 1692 | SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); |
| 1693 | } |
| 1694 | } |
| 1695 | |
| 1696 | std::vector<common_adapter_lora_info> lora; |
| 1697 | int32_t alora_invocation_start = -1; |
| 1698 | |
| 1699 | // sampling |
| 1700 | json json_schema; |
| 1701 | |
| 1702 | struct common_sampler * smpl = nullptr; |
| 1703 | |
| 1704 | llama_token sampled; |
| 1705 | |
| 1706 | common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; |
| 1707 | std::vector<std::string> generated_tool_call_ids; |
| 1708 | |
| 1709 | // stats |
| 1710 | size_t n_sent_text = 0; // number of sent text characters |
| 1711 | |
| 1712 | int64_t t_start_process_prompt; |
| 1713 | int64_t t_start_generation; |
| 1714 | |
| 1715 | double t_prompt_processing; // ms |
| 1716 | double t_token_generation; // ms |
| 1717 | |
| 1718 | std::function<void(int)> callback_on_release; |
| 1719 | |
| 1720 | // Speculative decoding stats |
| 1721 | int32_t n_draft_total = 0; // Total draft tokens generated |
| 1722 | int32_t n_draft_accepted = 0; // Draft tokens actually accepted |
| 1723 | |
| 1724 | void reset() { |
| 1725 | SLT_DBG(*this, "%s", "\n"); |
| 1726 | |
| 1727 | n_prompt_tokens_cache = 0; |
| 1728 | |
| 1729 | last_nl_pos = 0; |
| 1730 | generated_text = "" ; |
| 1731 | has_new_line = false; |
| 1732 | truncated = false; |
| 1733 | stop = STOP_TYPE_NONE; |
| 1734 | stopping_word = "" ; |
| 1735 | n_sent_text = 0; |
| 1736 | chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; |
| 1737 | |
| 1738 | generated_tokens.clear(); |
| 1739 | generated_token_probs.clear(); |
| 1740 | chat_msg = {}; |
| 1741 | json_schema = json(); |
| 1742 | generated_tool_call_ids.clear(); |
| 1743 | |
| 1744 | // clear speculative decoding stats |
| 1745 | n_draft_total = 0; |
| 1746 | n_draft_accepted = 0; |
| 1747 | |
| 1748 | task.reset(); |
| 1749 | task_prev.reset(); |
| 1750 | |
| 1751 | // clear alora start |
| 1752 | alora_invocation_start = -1; |
| 1753 | } |
| 1754 | |
| 1755 | bool need_embd() const { |
| 1756 | GGML_ASSERT(task); |
| 1757 | |
| 1758 | return server_task_type_need_embd(task->type); |
| 1759 | } |
| 1760 | |
| 1761 | bool need_logits() const { |
| 1762 | GGML_ASSERT(task); |
| 1763 | |
| 1764 | return server_task_type_need_logits(task->type); |
| 1765 | } |
| 1766 | |
| 1767 | // if the context does not have a memory module then all embeddings have to be computed within a single ubatch |
| 1768 | // also we cannot split if the pooling would require any past tokens |
| 1769 | bool can_split() const { |
| 1770 | return |
| 1771 | !need_embd() || |
| 1772 | (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); |
| 1773 | } |
| 1774 | |
| 1775 | bool can_batch_with(server_slot & other_slot) const { |
| 1776 | GGML_ASSERT(task); |
| 1777 | |
| 1778 | return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); |
| 1779 | } |
| 1780 | |
| 1781 | bool has_budget(const common_params & global_params) { |
| 1782 | GGML_ASSERT(task); |
| 1783 | |
| 1784 | if (task->params.n_predict == -1 && global_params.n_predict == -1) { |
| 1785 | return true; // limitless |
| 1786 | } |
| 1787 | |
| 1788 | n_remaining = -1; |
| 1789 | |
| 1790 | if (task->params.n_predict != -1) { |
| 1791 | n_remaining = task->params.n_predict - n_decoded; |
| 1792 | } else if (global_params.n_predict != -1) { |
| 1793 | n_remaining = global_params.n_predict - n_decoded; |
| 1794 | } |
| 1795 | |
| 1796 | return n_remaining > 0; // true while there is still budget left |
| 1797 | } |
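| | // Note (added for clarity): the per-request n_predict takes precedence over the global |
| | // one. For example, with task n_predict = 128 and 100 tokens already decoded, n_remaining |
| | // becomes 28 and generation continues; once n_decoded reaches 128 the budget is exhausted. |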
| 1798 | |
| 1799 | bool is_processing() const { |
| 1800 | return state != SLOT_STATE_IDLE; |
| 1801 | } |
| 1802 | |
| 1803 | bool can_speculate() const { |
| 1804 | return ctx_dft; |
| 1805 | } |
| 1806 | |
| 1807 | void add_token(const completion_token_output & token) { |
| 1808 | if (!is_processing()) { |
| 1809 | SLT_WRN(*this, "%s", "slot is not processing\n"); |
| 1810 | return; |
| 1811 | } |
| 1812 | generated_token_probs.push_back(token); |
| 1813 | } |
| 1814 | |
| 1815 | void release() { |
| 1816 | if (is_processing()) { |
| 1817 | GGML_ASSERT(task); |
| 1818 | |
| 1819 | SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated); |
| 1820 | |
| 1821 | t_last_used = ggml_time_us(); |
| 1822 | t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; |
| 1823 | state = SLOT_STATE_IDLE; |
| 1824 | |
| 1825 | task_prev = std::move(task); |
| 1826 | task.reset(); |
| 1827 | |
| 1828 | callback_on_release(id); |
| 1829 | } |
| 1830 | } |
| 1831 | |
| 1832 | result_timings get_timings() const { |
| 1833 | result_timings timings; |
| 1834 | timings.cache_n = n_prompt_tokens_cache; |
| 1835 | |
| 1836 | timings.prompt_n = n_prompt_tokens_processed; |
| 1837 | timings.prompt_ms = t_prompt_processing; |
| 1838 | timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; |
| 1839 | timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; |
| 1840 | |
| 1841 | timings.predicted_n = n_decoded; |
| 1842 | timings.predicted_ms = t_token_generation; |
| 1843 | timings.predicted_per_token_ms = t_token_generation / n_decoded; |
| 1844 | timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; |
| 1845 | |
| 1846 | // Add speculative metrics |
| 1847 | if (n_draft_total > 0) { |
| 1848 | timings.draft_n = n_draft_total; |
| 1849 | timings.draft_n_accepted = n_draft_accepted; |
| 1850 | } |
| 1851 | |
| 1852 | return timings; |
| 1853 | } |
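| | // Note (added for clarity): the throughput fields are derived as tokens / (ms / 1000), |
| | // e.g. 512 prompt tokens processed in 250 ms yields prompt_per_second = 2048 and |
| | // prompt_per_token_ms ~= 0.49. Draft statistics are attached only when speculative |
| | // decoding actually produced draft tokens. |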
| 1854 | |
| 1855 | const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) { |
| 1856 | GGML_ASSERT(task); |
| 1857 | |
| 1858 | auto previous_msg = chat_msg; |
| 1859 | SRV_DBG("Parsing chat message: %s\n" , generated_text.c_str()); |
| 1860 | auto new_msg = common_chat_parse( |
| 1861 | input: generated_text, |
| 1862 | /* is_partial= */ stop != STOP_TYPE_EOS, |
| 1863 | syntax: task->params.oaicompat_chat_syntax); |
| 1864 | if (!new_msg.empty()) { |
| 1865 | new_msg.set_tool_call_ids(ids_cache&: generated_tool_call_ids, gen_tool_call_id); |
| 1866 | chat_msg = new_msg; |
| 1867 | diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg: new_msg.empty() ? previous_msg : new_msg); |
| 1868 | } |
| 1869 | return chat_msg; |
| 1870 | } |
| 1871 | |
| 1872 | size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { |
| 1873 | GGML_ASSERT(task); |
| 1874 | |
| 1875 | size_t stop_pos = std::string::npos; |
| 1876 | |
| 1877 | for (const std::string & word : task->params.antiprompt) { |
| 1878 | size_t pos; |
| 1879 | |
| 1880 | if (is_full_stop) { |
| 1881 | const size_t tmp = word.size() + last_token_size; |
| 1882 | const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; |
| 1883 | |
| 1884 | pos = text.find(word, from_pos); |
| 1885 | } else { |
| 1886 | // otherwise, partial stop |
| 1887 | pos = string_find_partial_stop(text, word); |
| 1888 | } |
| 1889 | |
| 1890 | if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { |
| 1891 | if (is_full_stop) { |
| 1892 | stop = STOP_TYPE_WORD; |
| 1893 | stopping_word = word; |
| 1894 | has_next_token = false; |
| 1895 | } |
| 1896 | stop_pos = pos; |
| 1897 | } |
| 1898 | } |
| 1899 | |
| 1900 | return stop_pos; |
| 1901 | } |
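| | // Note (added for clarity): on a full stop the search window is limited to the tail of |
| | // the text (stop word length + last token length), and a hit both records the stop word |
| | // and halts generation; a partial match only delays streaming, since the stop word might |
| | // still complete on the next token. |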
| 1902 | |
| 1903 | void print_timings() const { |
| 1904 | const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; |
| 1905 | const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; |
| 1906 | |
| 1907 | const double t_gen = t_token_generation / n_decoded; |
| 1908 | const double n_gen_second = 1e3 / t_token_generation * n_decoded; |
| 1909 | |
| 1910 | SLT_INF(*this, |
| 1911 | "\n" |
| 1912 | "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" |
| 1913 | " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" |
| 1914 | " total time = %10.2f ms / %5d tokens\n" , |
| 1915 | t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, |
| 1916 | t_token_generation, n_decoded, t_gen, n_gen_second, |
| 1917 | t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); |
| 1918 | |
| 1919 | if (n_draft_total > 0) { |
| 1920 | const float draft_ratio = (float) n_draft_accepted / n_draft_total; |
| 1921 | SLT_INF(*this, |
| 1922 | "\n" |
| 1923 | "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n" , |
| 1924 | draft_ratio, n_draft_accepted, n_draft_total |
| 1925 | ); |
| 1926 | } |
| 1927 | } |
| 1928 | |
| 1929 | json to_json(bool only_metrics = false) const { |
| 1930 | json res; |
| 1931 | |
| 1932 | res = { |
| 1933 | {"id" , id}, |
| 1934 | {"n_ctx" , n_ctx}, |
| 1935 | {"speculative" , can_speculate()}, |
| 1936 | {"is_processing" , is_processing()}, |
| 1937 | }; |
| 1938 | |
| 1939 | const auto & ptask = task ? task : task_prev; |
| 1940 | |
| 1941 | if (ptask) { |
| 1942 | res["id_task" ] = ptask->id; |
| 1943 | res["params" ] = ptask->params.to_json(only_metrics); |
| 1944 | res["next_token" ] = { |
| 1945 | { |
| 1946 | {"has_next_token" , has_next_token}, |
| 1947 | {"has_new_line" , has_new_line}, |
| 1948 | {"n_remain" , n_remaining}, |
| 1949 | {"n_decoded" , n_decoded}, |
| 1950 | } |
| 1951 | }; |
| 1952 | |
| 1953 | if (!only_metrics) { |
| 1954 | res["prompt" ] = ptask->tokens.detokenize(ctx, special: true); |
| 1955 | res["generated" ] = generated_text; |
| 1956 | } |
| 1957 | } |
| 1958 | |
| 1959 | return res; |
| 1960 | } |
| 1961 | }; |
| 1962 | |
| 1963 | struct server_metrics { |
| 1964 | int64_t t_start = 0; |
| 1965 | |
| 1966 | uint64_t n_prompt_tokens_processed_total = 0; |
| 1967 | uint64_t t_prompt_processing_total = 0; |
| 1968 | uint64_t n_tokens_predicted_total = 0; |
| 1969 | uint64_t t_tokens_generation_total = 0; |
| 1970 | |
| 1971 | uint64_t n_tokens_max = 0; |
| 1972 | |
| 1973 | uint64_t n_prompt_tokens_processed = 0; |
| 1974 | uint64_t t_prompt_processing = 0; |
| 1975 | |
| 1976 | uint64_t n_tokens_predicted = 0; |
| 1977 | uint64_t t_tokens_generation = 0; |
| 1978 | |
| 1979 | uint64_t n_decode_total = 0; |
| 1980 | uint64_t n_busy_slots_total = 0; |
| 1981 | |
| 1982 | void init() { |
| 1983 | t_start = ggml_time_us(); |
| 1984 | } |
| 1985 | |
| 1986 | void on_prompt_eval(const server_slot & slot) { |
| 1987 | n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; |
| 1988 | n_prompt_tokens_processed += slot.n_prompt_tokens_processed; |
| 1989 | t_prompt_processing += slot.t_prompt_processing; |
| 1990 | t_prompt_processing_total += slot.t_prompt_processing; |
| 1991 | |
| 1992 | n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); |
| 1993 | } |
| 1994 | |
| 1995 | void on_prediction(const server_slot & slot) { |
| 1996 | n_tokens_predicted_total += slot.n_decoded; |
| 1997 | n_tokens_predicted += slot.n_decoded; |
| 1998 | t_tokens_generation += slot.t_token_generation; |
| 1999 | t_tokens_generation_total += slot.t_token_generation; |
| 2000 | } |
| 2001 | |
| 2002 | void on_decoded(const std::vector<server_slot> & slots) { |
| 2003 | n_decode_total++; |
| 2004 | for (const auto & slot : slots) { |
| 2005 | if (slot.is_processing()) { |
| 2006 | n_busy_slots_total++; |
| 2007 | } |
| 2008 | n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); |
| 2009 | } |
| 2010 | } |
| 2011 | |
| 2012 | void reset_bucket() { |
| 2013 | n_prompt_tokens_processed = 0; |
| 2014 | t_prompt_processing = 0; |
| 2015 | n_tokens_predicted = 0; |
| 2016 | t_tokens_generation = 0; |
| 2017 | } |
| 2018 | }; |
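| | // Note (added for clarity): fields with a _total suffix accumulate for the lifetime of the |
| | // server, while the unsuffixed ones form a "bucket" that reset_bucket() clears - presumably |
| | // so that callers can expose both cumulative counters and per-window values. |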
| 2019 | |
| 2020 | struct server_queue { |
| 2021 | int id = 0; |
| 2022 | bool running; |
| 2023 | |
| 2024 | // queues |
| 2025 | std::deque<server_task> queue_tasks; |
| 2026 | std::deque<server_task> queue_tasks_deferred; |
| 2027 | |
| 2028 | std::mutex mutex_tasks; |
| 2029 | std::condition_variable condition_tasks; |
| 2030 | |
| 2031 | // callback functions |
| 2032 | std::function<void(server_task &&)> callback_new_task; |
| 2033 | std::function<void(void)> callback_update_slots; |
| 2034 | |
| 2035 | // Add a new task to the end of the queue |
| 2036 | int post(server_task && task, bool front = false) { |
| 2037 | std::unique_lock<std::mutex> lock(mutex_tasks); |
| 2038 | GGML_ASSERT(task.id != -1); |
| 2039 | // if this is cancel task make sure to clean up pending tasks |
| 2040 | if (task.type == SERVER_TASK_TYPE_CANCEL) { |
| 2041 | cleanup_pending_task(task.id_target); |
| 2042 | } |
| 2043 | const int task_id = task.id; |
| 2044 | QUE_DBG("new task, id = %d, front = %d\n" , task_id, front); |
| 2045 | if (front) { |
| 2046 | queue_tasks.push_front(std::move(task)); |
| 2047 | } else { |
| 2048 | queue_tasks.push_back(std::move(task)); |
| 2049 | } |
| 2050 | condition_tasks.notify_one(); |
| 2051 | return task_id; |
| 2052 | } |
| 2053 | |
| 2054 | // multi-task version of post() |
| 2055 | int post(std::vector<server_task> && tasks, bool front = false) { |
| 2056 | std::unique_lock<std::mutex> lock(mutex_tasks); |
| 2057 | for (auto & task : tasks) { |
| 2058 | if (task.id == -1) { |
| 2059 | task.id = id++; |
| 2060 | } |
| 2061 | // if this is cancel task make sure to clean up pending tasks |
| 2062 | if (task.type == SERVER_TASK_TYPE_CANCEL) { |
| 2063 | cleanup_pending_task(task.id_target); |
| 2064 | } |
| 2065 | QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); |
| 2066 | if (front) { |
| 2067 | queue_tasks.push_front(std::move(task)); |
| 2068 | } else { |
| 2069 | queue_tasks.push_back(std::move(task)); |
| 2070 | } |
| 2071 | } |
| 2072 | condition_tasks.notify_one(); |
| 2073 | return 0; |
| 2074 | } |
| 2075 | |
| 2076 | // Add a new task, but defer until one slot is available |
| 2077 | void defer(server_task && task) { |
| 2078 | std::unique_lock<std::mutex> lock(mutex_tasks); |
| 2079 | QUE_DBG("defer task, id = %d\n" , task.id); |
| 2080 | queue_tasks_deferred.push_back(x: std::move(task)); |
| 2081 | condition_tasks.notify_one(); |
| 2082 | } |
| 2083 | |
| 2084 | // Get the next id for creating a new task |
| 2085 | int get_new_id() { |
| 2086 | std::unique_lock<std::mutex> lock(mutex_tasks); |
| 2087 | int new_id = id++; |
| 2088 | return new_id; |
| 2089 | } |
| 2090 | |
| 2091 | // Register function to process a new task |
| 2092 | void on_new_task(std::function<void(server_task &&)> callback) { |
| 2093 | callback_new_task = std::move(callback); |
| 2094 | } |
| 2095 | |
| 2096 | // Register the function to be called when all slots data is ready to be processed |
| 2097 | void on_update_slots(std::function<void(void)> callback) { |
| 2098 | callback_update_slots = std::move(callback); |
| 2099 | } |
| 2100 | |
| 2101 | // Call when the state of one slot changes; it moves one task from the deferred queue to the main queue |
| 2102 | void pop_deferred_task() { |
| 2103 | std::unique_lock<std::mutex> lock(mutex_tasks); |
| 2104 | if (!queue_tasks_deferred.empty()) { |
| 2105 | queue_tasks.emplace_front(std::move(queue_tasks_deferred.front())); |
| 2106 | queue_tasks_deferred.pop_front(); |
| 2107 | } |
| 2108 | condition_tasks.notify_one(); |
| 2109 | } |
| 2110 | |
| 2111 | // end the start_loop routine |
| 2112 | void terminate() { |
| 2113 | std::unique_lock<std::mutex> lock(mutex_tasks); |
| 2114 | running = false; |
| 2115 | condition_tasks.notify_all(); |
| 2116 | } |
| 2117 | |
| 2118 | /** |
| 2119 | * Main loop consists of these steps: |
| 2120 | * - Wait until a new task arrives |
| 2121 | * - Process the task (i.e. maybe copy data into slot) |
| 2122 | * - Check if multitask is finished |
| 2123 | * - Update all slots |
| 2124 | */ |
| 2125 | void start_loop() { |
| 2126 | running = true; |
| 2127 | |
| 2128 | while (true) { |
| 2129 | QUE_DBG("%s" , "processing new tasks\n" ); |
| 2130 | |
| 2131 | while (true) { |
| 2132 | std::unique_lock<std::mutex> lock(mutex_tasks); |
| 2133 | if (!running) { |
| 2134 | QUE_DBG("%s" , "terminate\n" ); |
| 2135 | return; |
| 2136 | } |
| 2137 | if (queue_tasks.empty()) { |
| 2138 | lock.unlock(); |
| 2139 | break; |
| 2140 | } |
| 2141 | server_task task = std::move(queue_tasks.front()); |
| 2142 | queue_tasks.pop_front(); |
| 2143 | lock.unlock(); |
| 2144 | |
| 2145 | QUE_DBG("processing task, id = %d\n" , task.id); |
| 2146 | callback_new_task(std::move(task)); |
| 2147 | } |
| 2148 | |
| 2149 | // all tasks in the current loop are processed, slots data is now ready |
| 2150 | QUE_DBG("%s", "update slots\n"); |
| 2151 | |
| 2152 | callback_update_slots(); |
| 2153 | |
| 2154 | QUE_DBG("%s" , "waiting for new tasks\n" ); |
| 2155 | { |
| 2156 | std::unique_lock<std::mutex> lock(mutex_tasks); |
| 2157 | if (!running) { |
| 2158 | QUE_DBG("%s" , "terminate\n" ); |
| 2159 | return; |
| 2160 | } |
| 2161 | if (queue_tasks.empty()) { |
| 2162 | condition_tasks.wait(lock, [&]{ |
| 2163 | return (!queue_tasks.empty() || !running); |
| 2164 | }); |
| 2165 | } |
| 2166 | } |
| 2167 | } |
| 2168 | } |
| 2169 | |
| 2170 | private: |
| 2171 | void cleanup_pending_task(int id_target) { |
| 2172 | // no need to lock because this is called exclusively by post() |
| 2173 | auto rm_func = [id_target](const server_task & task) { |
| 2174 | return task.id == id_target; |
| 2175 | }; |
| 2176 | queue_tasks.erase( |
| 2177 | std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), |
| 2178 | queue_tasks.end()); |
| 2179 | queue_tasks_deferred.erase( |
| 2180 | std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), |
| 2181 | queue_tasks_deferred.end()); |
| 2182 | } |
| 2183 | }; |
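| | // Rough usage sketch (not part of the original code): the HTTP layer posts tasks and a |
| | // dedicated inference thread drains them, e.g.: |
| | // |
| | // server_queue queue; |
| | // queue.on_new_task([&](server_task && t) { /* route the task to a slot */ }); |
| | // queue.on_update_slots([&]() { /* run one decode step over all slots */ }); |
| | // std::thread loop([&]() { queue.start_loop(); }); |
| | // task.id = queue.get_new_id(); |
| | // queue.post(std::move(task)); // wakes the loop |
| | // ... |
| | // queue.terminate(); // makes start_loop() return |
| | // loop.join(); |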
| 2184 | |
| 2185 | struct server_response { |
| 2186 | bool running = true; |
| 2187 | |
| 2188 | // for keeping track of all tasks waiting for the result |
| 2189 | std::unordered_set<int> waiting_task_ids; |
| 2190 | |
| 2191 | // the main result queue (using ptr for polymorphism) |
| 2192 | std::vector<server_task_result_ptr> queue_results; |
| 2193 | |
| 2194 | std::mutex mutex_results; |
| 2195 | std::condition_variable condition_results; |
| 2196 | |
| 2197 | // add the id_task to the list of tasks waiting for response |
| 2198 | void add_waiting_task_id(int id_task) { |
| 2199 | SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n" , id_task, (int) waiting_task_ids.size()); |
| 2200 | |
| 2201 | std::unique_lock<std::mutex> lock(mutex_results); |
| 2202 | waiting_task_ids.insert(x: id_task); |
| 2203 | } |
| 2204 | |
| 2205 | void add_waiting_tasks(const std::vector<server_task> & tasks) { |
| 2206 | std::unique_lock<std::mutex> lock(mutex_results); |
| 2207 | |
| 2208 | for (const auto & task : tasks) { |
| 2209 | SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n" , task.id, (int) waiting_task_ids.size()); |
| 2210 | waiting_task_ids.insert(x: task.id); |
| 2211 | } |
| 2212 | } |
| 2213 | |
| 2214 | // when the request is finished, we can remove task associated with it |
| 2215 | void remove_waiting_task_id(int id_task) { |
| 2216 | SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n" , id_task, (int) waiting_task_ids.size()); |
| 2217 | |
| 2218 | std::unique_lock<std::mutex> lock(mutex_results); |
| 2219 | waiting_task_ids.erase(x: id_task); |
| 2220 | // make sure to clean up all pending results |
| 2221 | queue_results.erase( |
| 2222 | first: std::remove_if(first: queue_results.begin(), last: queue_results.end(), pred: [id_task](const server_task_result_ptr & res) { |
| 2223 | return res->id == id_task; |
| 2224 | }), |
| 2225 | last: queue_results.end()); |
| 2226 | } |
| 2227 | |
| 2228 | void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) { |
| 2229 | std::unique_lock<std::mutex> lock(mutex_results); |
| 2230 | |
| 2231 | for (const auto & id_task : id_tasks) { |
| 2232 | SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n" , id_task, (int) waiting_task_ids.size()); |
| 2233 | waiting_task_ids.erase(x: id_task); |
| 2234 | } |
| 2235 | } |
| 2236 | |
| 2237 | // This function blocks the thread until there is a response for one of the id_tasks |
| 2238 | server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) { |
| 2239 | while (true) { |
| 2240 | std::unique_lock<std::mutex> lock(mutex_results); |
| 2241 | condition_results.wait(lock, [&]{ |
| 2242 | if (!running) { |
| 2243 | SRV_DBG("%s : queue result stop\n", __func__); |
| 2244 | std::terminate(); // we cannot return here since the caller is HTTP code |
| 2245 | } |
| 2246 | return !queue_results.empty(); |
| 2247 | }); |
| 2248 | |
| 2249 | for (size_t i = 0; i < queue_results.size(); i++) { |
| 2250 | if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { |
| 2251 | server_task_result_ptr res = std::move(queue_results[i]); |
| 2252 | queue_results.erase(queue_results.begin() + i); |
| 2253 | return res; |
| 2254 | } |
| 2255 | } |
| 2256 | } |
| 2257 | |
| 2258 | // should never reach here |
| 2259 | } |
| 2260 | |
| 2261 | // same as recv(), but with a timeout in seconds |
| 2262 | // if timeout is reached, nullptr is returned |
| 2263 | server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) { |
| 2264 | while (true) { |
| 2265 | std::unique_lock<std::mutex> lock(mutex_results); |
| 2266 | |
| 2267 | for (int i = 0; i < (int) queue_results.size(); i++) { |
| 2268 | if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { |
| 2269 | server_task_result_ptr res = std::move(queue_results[i]); |
| 2270 | queue_results.erase(queue_results.begin() + i); |
| 2271 | return res; |
| 2272 | } |
| 2273 | } |
| 2274 | |
| 2275 | std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); |
| 2276 | if (!running) { |
| 2277 | SRV_DBG("%s : queue result stop\n", __func__); |
| 2278 | std::terminate(); // we cannot return here since the caller is HTTP code |
| 2279 | } |
| 2280 | if (cr_res == std::cv_status::timeout) { |
| 2281 | return nullptr; |
| 2282 | } |
| 2283 | } |
| 2284 | |
| 2285 | // should never reach here |
| 2286 | } |
| 2287 | |
| 2288 | // single-task version of recv() |
| 2289 | server_task_result_ptr recv(int id_task) { |
| 2290 | std::unordered_set<int> id_tasks = {id_task}; |
| 2291 | return recv(id_tasks); |
| 2292 | } |
| 2293 | |
| 2294 | // Send a new result to a waiting id_task |
| 2295 | void send(server_task_result_ptr && result) { |
| 2296 | SRV_DBG("sending result for task id = %d\n" , result->id); |
| 2297 | |
| 2298 | std::unique_lock<std::mutex> lock(mutex_results); |
| 2299 | for (const auto & id_task : waiting_task_ids) { |
| 2300 | if (result->id == id_task) { |
| 2301 | SRV_DBG("task id = %d pushed to result queue\n" , result->id); |
| 2302 | |
| 2303 | queue_results.emplace_back(args: std::move(result)); |
| 2304 | condition_results.notify_all(); |
| 2305 | return; |
| 2306 | } |
| 2307 | } |
| 2308 | } |
| 2309 | |
| 2310 | // terminate the waiting loop |
| 2311 | void terminate() { |
| 2312 | running = false; |
| 2313 | condition_results.notify_all(); |
| 2314 | } |
| 2315 | }; |
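| | // Rough usage sketch (not part of the original code): a request handler registers the task |
| | // id before posting it, blocks on recv()/recv_with_timeout(), and always unregisters when |
| | // done, e.g.: |
| | // |
| | // queue_results.add_waiting_task_id(id_task); |
| | // queue_tasks.post(std::move(task)); |
| | // server_task_result_ptr res = queue_results.recv_with_timeout({id_task}, /*timeout_s=*/1); |
| | // // res == nullptr means the timeout expired and the caller should poll again |
| | // queue_results.remove_waiting_task_id(id_task); |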
| 2316 | |
| 2317 | struct server_context { |
| 2318 | common_params params_base; |
| 2319 | |
| 2320 | // note: keep these alive - they determine the lifetime of the model, context, etc. |
| 2321 | common_init_result llama_init; |
| 2322 | common_init_result llama_init_dft; |
| 2323 | |
| 2324 | llama_model * model = nullptr; |
| 2325 | llama_context * ctx = nullptr; |
| 2326 | |
| 2327 | // multimodal |
| 2328 | mtmd_context * mctx = nullptr; |
| 2329 | |
| 2330 | const llama_vocab * vocab = nullptr; |
| 2331 | bool vocab_dft_compatible = true; |
| 2332 | |
| 2333 | llama_model * model_dft = nullptr; |
| 2334 | |
| 2335 | llama_context_params cparams_dft; |
| 2336 | |
| 2337 | llama_batch batch {}; |
| 2338 | |
| 2339 | bool clean_kv_cache = true; |
| 2340 | bool add_bos_token = true; |
| 2341 | |
| 2342 | int32_t n_ctx; // total context for all clients / slots |
| 2343 | |
| 2344 | // slots / clients |
| 2345 | std::vector<server_slot> slots; |
| 2346 | |
| 2347 | int slots_debug = 0; |
| 2348 | |
| 2349 | server_queue queue_tasks; |
| 2350 | server_response queue_results; |
| 2351 | |
| 2352 | std::unique_ptr<server_prompt_cache> prompt_cache; |
| 2353 | |
| 2354 | server_metrics metrics; |
| 2355 | |
| 2356 | // Necessary similarity of prompt for slot selection |
| 2357 | float slot_prompt_similarity = 0.0f; |
| 2358 | |
| 2359 | common_chat_templates_ptr chat_templates; |
| 2360 | oaicompat_parser_options oai_parser_opt; |
| 2361 | |
| 2362 | ~server_context() { |
| 2363 | mtmd_free(mctx); |
| 2364 | |
| 2365 | // Clear any sampling context |
| 2366 | for (server_slot & slot : slots) { |
| 2367 | common_sampler_free(slot.smpl); |
| 2368 | slot.smpl = nullptr; |
| 2369 | |
| 2370 | llama_free(slot.ctx_dft); |
| 2371 | slot.ctx_dft = nullptr; |
| 2372 | |
| 2373 | common_speculative_free(slot.spec); |
| 2374 | slot.spec = nullptr; |
| 2375 | |
| 2376 | llama_batch_free(slot.batch_spec); |
| 2377 | } |
| 2378 | |
| 2379 | llama_batch_free(batch); |
| 2380 | } |
| 2381 | |
| 2382 | bool load_model(const common_params & params) { |
| 2383 | SRV_INF("loading model '%s'\n" , params.model.path.c_str()); |
| 2384 | |
| 2385 | params_base = params; |
| 2386 | |
| 2387 | llama_init = common_init_from_params(params&: params_base); |
| 2388 | |
| 2389 | model = llama_init.model.get(); |
| 2390 | ctx = llama_init.context.get(); |
| 2391 | |
| 2392 | if (model == nullptr) { |
| 2393 | SRV_ERR("failed to load model, '%s'\n" , params_base.model.path.c_str()); |
| 2394 | return false; |
| 2395 | } |
| 2396 | |
| 2397 | vocab = llama_model_get_vocab(model); |
| 2398 | |
| 2399 | n_ctx = llama_n_ctx(ctx); |
| 2400 | |
| 2401 | add_bos_token = llama_vocab_get_add_bos(vocab); |
| 2402 | |
| 2403 | if (params_base.has_speculative()) { |
| 2404 | SRV_INF("loading draft model '%s'\n" , params_base.speculative.model.path.c_str()); |
| 2405 | |
| 2406 | auto params_dft = params_base; |
| 2407 | |
| 2408 | params_dft.devices = params_base.speculative.devices; |
| 2409 | params_dft.model = params_base.speculative.model; |
| 2410 | params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx; |
| 2411 | params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; |
| 2412 | params_dft.n_parallel = 1; |
| 2413 | params_dft.cache_type_k = params_base.speculative.cache_type_k; |
| 2414 | params_dft.cache_type_v = params_base.speculative.cache_type_v; |
| 2415 | |
| 2416 | params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads; |
| 2417 | params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads; |
| 2418 | params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides; |
| 2419 | |
| 2420 | llama_init_dft = common_init_from_params(params_dft); |
| 2421 | |
| 2422 | model_dft = llama_init_dft.model.get(); |
| 2423 | |
| 2424 | if (model_dft == nullptr) { |
| 2425 | SRV_ERR("failed to load draft model, '%s'\n" , params_base.speculative.model.path.c_str()); |
| 2426 | return false; |
| 2427 | } |
| 2428 | |
| 2429 | vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); |
| 2430 | if (!vocab_dft_compatible) { |
| 2431 | SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); |
| 2432 | } |
| 2433 | |
| 2434 | const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); |
| 2435 | |
| 2436 | cparams_dft = common_context_params_to_llama(params_dft); |
| 2437 | cparams_dft.n_batch = n_ctx_dft; |
| 2438 | |
| 2439 | // the context is not needed - we will create one for each slot |
| 2440 | llama_init_dft.context.reset(); |
| 2441 | } |
| 2442 | |
| 2443 | chat_templates = common_chat_templates_init(model, params_base.chat_template); |
| 2444 | try { |
| 2445 | common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs); |
| 2446 | } catch (const std::exception & e) { |
| 2447 | SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); |
| 2448 | SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); |
| 2449 | chat_templates = common_chat_templates_init(model, "chatml"); |
| 2450 | } |
| 2451 | |
| 2452 | std::string & mmproj_path = params_base.mmproj.path; |
| 2453 | if (!mmproj_path.empty()) { |
| 2454 | mtmd_context_params mparams = mtmd_context_params_default(); |
| 2455 | mparams.use_gpu = params_base.mmproj_use_gpu; |
| 2456 | mparams.print_timings = false; |
| 2457 | mparams.n_threads = params_base.cpuparams.n_threads; |
| 2458 | mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; |
| 2459 | mparams.flash_attn_type = params_base.flash_attn_type; |
| 2460 | mparams.image_min_tokens = params_base.image_min_tokens; |
| 2461 | mparams.image_max_tokens = params_base.image_max_tokens; |
| 2462 | mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); |
| 2463 | if (mctx == nullptr) { |
| 2464 | SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); |
| 2465 | return false; |
| 2466 | } |
| 2467 | SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); |
| 2468 | |
| 2469 | if (params_base.ctx_shift) { |
| 2470 | params_base.ctx_shift = false; |
| 2471 | SRV_WRN("%s\n" , "ctx_shift is not supported by multimodal, it will be disabled" ); |
| 2472 | } |
| 2473 | |
| 2474 | if (params_base.n_cache_reuse) { |
| 2475 | params_base.n_cache_reuse = 0; |
| 2476 | SRV_WRN("%s\n" , "cache_reuse is not supported by multimodal, it will be disabled" ); |
| 2477 | } |
| 2478 | |
| 2479 | if (params_base.has_speculative()) { |
| 2480 | SRV_ERR("%s\n" , "err: speculative decode is not supported by multimodal" ); |
| 2481 | return false; |
| 2482 | } |
| 2483 | } |
| 2484 | |
| 2485 | if (!llama_memory_can_shift(llama_get_memory(ctx))) { |
| 2486 | if (params_base.ctx_shift) { |
| 2487 | params_base.ctx_shift = false; |
| 2488 | SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); |
| 2489 | } |
| 2490 | |
| 2491 | if (params_base.n_cache_reuse) { |
| 2492 | params_base.n_cache_reuse = 0; |
| 2493 | SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); |
| 2494 | } |
| 2495 | } |
| 2496 | |
| 2497 | return true; |
| 2498 | } |
| 2499 | |
| 2500 | void init() { |
| 2501 | SRV_INF("initializing slots, n_slots = %d\n" , params_base.n_parallel); |
| 2502 | |
| 2503 | const int n_ctx_train = llama_model_n_ctx_train(model); |
| 2504 | |
| 2505 | int n_ctx_slot = llama_n_ctx_seq(ctx); |
| 2506 | if (n_ctx_slot > n_ctx_train) { |
| 2507 | SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n" , n_ctx_slot, n_ctx_train); |
| 2508 | n_ctx_slot = n_ctx_train; |
| 2509 | } |
| 2510 | |
| 2511 | for (int i = 0; i < params_base.n_parallel; i++) { |
| 2512 | server_slot slot; |
| 2513 | |
| 2514 | slot.id = i; |
| 2515 | slot.ctx = ctx; |
| 2516 | slot.n_ctx = n_ctx_slot; |
| 2517 | slot.mctx = mctx; |
| 2518 | slot.prompt.tokens.has_mtmd = mctx != nullptr; |
| 2519 | |
| 2520 | if (model_dft) { |
| 2521 | slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); |
| 2522 | |
| 2523 | // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] |
| 2524 | slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); |
| 2525 | if (slot.ctx_dft == nullptr) { |
| 2526 | SRV_ERR("%s", "failed to create draft context\n"); |
| 2527 | return; |
| 2528 | } |
| 2529 | |
| 2530 | slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft); |
| 2531 | if (slot.spec == nullptr) { |
| 2532 | SRV_ERR("%s", "failed to create speculator\n"); |
| 2533 | return; |
| 2534 | } |
| 2535 | for (auto & pair : params_base.speculative.replacements) { |
| 2536 | common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); |
| 2537 | } |
| 2538 | } |
| 2539 | |
| 2540 | SLT_INF(slot, "new slot, n_ctx = %d\n" , slot.n_ctx); |
| 2541 | |
| 2542 | slot.callback_on_release = [this](int) { |
| 2543 | queue_tasks.pop_deferred_task(); |
| 2544 | }; |
| 2545 | |
| 2546 | slot.reset(); |
| 2547 | |
| 2548 | slots.push_back(std::move(slot)); |
| 2549 | } |
| 2550 | |
| 2551 | { |
| 2552 | const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); |
| 2553 | slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; |
| 2554 | |
| 2555 | if (slots_debug) { |
| 2556 | SRV_WRN("slots debug = %d\n", slots_debug); |
| 2557 | } |
| 2558 | } |
| 2559 | |
| 2560 | // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens |
| 2561 | // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) |
| 2562 | { |
| 2563 | const int32_t n_batch = llama_n_batch(ctx); |
| 2564 | batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); |
| 2565 | } |
| 2566 | |
| 2567 | metrics.init(); |
| 2568 | |
| 2569 | if (params_base.cache_ram_mib != 0) { |
| 2570 | if (params_base.cache_ram_mib < 0) { |
| 2571 | SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); |
| 2572 | } else { |
| 2573 | SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); |
| 2574 | } |
| 2575 | SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); |
| 2576 | |
| 2577 | prompt_cache = std::make_unique<server_prompt_cache>(params_base.cache_ram_mib, n_ctx); |
| 2578 | } else { |
| 2579 | SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); |
| 2580 | } |
| 2581 | SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); |
| 2582 | |
| 2583 | // thinking is enabled if: |
| 2584 | // 1. It's not explicitly disabled (reasoning_budget == 0) |
| 2585 | // 2. The chat template supports it |
| 2586 | const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); |
| 2587 | SRV_INF("thinking = %d\n", enable_thinking); |
| 2588 | |
| 2589 | oai_parser_opt = { |
| 2590 | /* use_jinja */ params_base.use_jinja, |
| 2591 | /* prefill_assistant */ params_base.prefill_assistant, |
| 2592 | /* reasoning_format */ params_base.reasoning_format, |
| 2593 | /* chat_template_kwargs */ params_base.default_template_kwargs, |
| 2594 | /* common_chat_templates */ chat_templates.get(), |
| 2595 | /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, |
| 2596 | /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, |
| 2597 | /* enable_thinking */ enable_thinking, |
| 2598 | }; |
| 2599 | } |
| 2600 | |
| 2601 | server_slot * get_slot_by_id(int id) { |
| 2602 | for (server_slot & slot : slots) { |
| 2603 | if (slot.id == id) { |
| 2604 | return &slot; |
| 2605 | } |
| 2606 | } |
| 2607 | |
| 2608 | return nullptr; |
| 2609 | } |
| 2610 | |
| 2611 | server_slot * get_available_slot(const server_task & task) { |
| 2612 | server_slot * ret = nullptr; |
| 2613 | |
| 2614 | bool update_cache = false; |
| 2615 | |
| 2616 | // find the slot that has at least n% prompt similarity |
| 2617 | if (ret == nullptr && slot_prompt_similarity != 0.0f) { |
| 2618 | float sim_best = 0; |
| 2619 | |
| 2620 | for (server_slot & slot : slots) { |
| 2621 | // skip the slot if it is not available |
| 2622 | if (slot.is_processing()) { |
| 2623 | continue; |
| 2624 | } |
| 2625 | |
| 2626 | const auto & tokens = slot.prompt.tokens; |
| 2627 | |
| 2628 | // skip the slot if it does not contain cached tokens |
| 2629 | if (tokens.empty()) { |
| 2630 | continue; |
| 2631 | } |
| 2632 | |
| 2633 | // fraction of the Longest Common Prefix length with respect to the input prompt length |
| 2634 | const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); |
| 2635 | |
| 2636 | // select the current slot if the criteria match |
| 2637 | if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { |
| 2638 | sim_best = sim_cur; |
| 2639 | |
| 2640 | ret = &slot; |
| 2641 | } |
| 2642 | } |
| 2643 | |
| 2644 | if (ret != nullptr) { |
| 2645 | const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); |
| 2646 | |
| 2647 | SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n" , |
| 2648 | sim_best, slot_prompt_similarity, f_keep); |
| 2649 | |
| 2650 | // if we are about to lose a large portion of the existing context - save it in the prompt cache |
| 2651 | if (f_keep < 0.5f) { |
| 2652 | update_cache = true; |
| 2653 | } |
| 2654 | } |
| 2655 | } |
| 2656 | |
| 2657 | // find the slot that has been least recently used |
| 2658 | if (ret == nullptr) { |
| 2659 | int64_t t_last = -1; |
| 2660 | |
| 2661 | for (server_slot & slot : slots) { |
| 2662 | // skip the slot if it is not available |
| 2663 | if (slot.is_processing()) { |
| 2664 | continue; |
| 2665 | } |
| 2666 | |
| 2667 | // select the current slot if the criteria match |
| 2668 | if (!ret || slot.t_last_used <= t_last) { |
| 2669 | t_last = slot.t_last_used; |
| 2670 | ret = &slot; |
| 2671 | } |
| 2672 | } |
| 2673 | |
| 2674 | if (ret != nullptr) { |
| 2675 | SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n" , t_last); |
| 2676 | |
| 2677 | update_cache = true; |
| 2678 | } |
| 2679 | } |
| 2680 | |
| 2681 | if (ret) { |
| 2682 | const auto & tokens = ret->prompt.tokens; |
| 2683 | |
| 2684 | update_cache = update_cache && prompt_cache; |
| 2685 | |
| 2686 | // cache prompts only for completion tasks |
| 2687 | update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; |
| 2688 | |
| 2689 | // don't update the cache if the slot's context is empty |
| 2690 | update_cache = update_cache && tokens.size() > 0; |
| 2691 | |
| 2692 | // TODO: mtmd does not support prompt cache |
| 2693 | update_cache = update_cache && (ret->mctx == nullptr); |
| 2694 | |
| 2695 | if (update_cache) { |
| 2696 | SRV_WRN("%s" , "updating prompt cache\n" ); |
| 2697 | |
| 2698 | const int64_t t_start = ggml_time_us(); |
| 2699 | |
| 2700 | ret->prompt_save(prompt_cache&: *prompt_cache); |
| 2701 | ret->prompt_load(prompt_cache&: *prompt_cache, tokens: task.tokens); |
| 2702 | |
| 2703 | prompt_cache->update(); |
| 2704 | |
| 2705 | SRV_WRN("prompt cache update took %.2f ms\n" , (ggml_time_us() - t_start) / 1000.0); |
| 2706 | } |
| 2707 | } |
| 2708 | |
| 2709 | return ret; |
| 2710 | } |
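| | // Note (added for clarity): slot selection prefers the idle slot whose cached tokens share |
| | // the longest common prefix with the new request (as a fraction of the request length), |
| | // and falls back to the least recently used slot. Illustrative numbers: a 2000-token |
| | // request matching 1500 cached tokens gives sim = 0.75; if that match keeps only 40% of |
| | // the slot's previous context, the old context is saved to the prompt cache first. |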
| 2711 | |
| 2712 | // return true if at least one slot has been purged |
| 2713 | // TODO: improve logic |
| 2714 | // - smarter decision which slot to purge (LRU or longest prompt?) |
| 2715 | // - move slot to level 2 cache instead of removing? |
| 2716 | // - instead of purging, try to store and resume later? |
| 2717 | bool try_purge_idle_slots() { |
| 2718 | bool res = false; |
| 2719 | |
| 2720 | if (!params_base.kv_unified) { |
| 2721 | return res; |
| 2722 | } |
| 2723 | |
| 2724 | for (auto & slot : slots) { |
| 2725 | if (slot.is_processing()) { |
| 2726 | continue; |
| 2727 | } |
| 2728 | |
| 2729 | if (slot.prompt.n_tokens() > 0) { |
| 2730 | SRV_WRN("purging slot %d with %zu tokens\n" , slot.id, slot.prompt.tokens.size()); |
| 2731 | |
| 2732 | llama_memory_seq_rm(mem: llama_get_memory(ctx), seq_id: slot.id, p0: -1, p1: -1); |
| 2733 | slot.prompt.tokens.clear(); |
| 2734 | |
| 2735 | res = true; |
| 2736 | |
| 2737 | // purge slots one by one |
| 2738 | break; |
| 2739 | } |
| 2740 | } |
| 2741 | |
| 2742 | return res; |
| 2743 | } |
| 2744 | |
| 2745 | bool launch_slot_with_task(server_slot & slot, server_task && task) { |
| 2746 | slot.reset(); |
| 2747 | |
| 2748 | if (!are_lora_equal(task.params.lora, slot.lora)) { |
| 2749 | // if lora has changed, check to see if the cache should be cleared |
| 2750 | if (lora_should_clear_cache(slot.lora, task.params.lora)) { |
| 2751 | SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); |
| 2752 | slot.prompt.tokens.clear(); |
| 2753 | } else { |
| 2754 | SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); |
| 2755 | } |
| 2756 | slot.lora = task.params.lora; |
| 2757 | } |
| 2758 | |
| 2759 | // if using alora, make sure it's only a single one requested and active |
| 2760 | size_t alora_invocation_start = task.tokens.size(); |
| 2761 | if (lora_all_alora(slot.lora)) { |
| 2762 | const auto & enabled_ids = lora_get_enabled_ids(slot.lora); |
| 2763 | // TODO: This will error out if a user requests two aloras, but only |
| 2764 | // provides the activation string for one. We could instead search |
| 2765 | // for all requested alora activation strings and then either keep |
| 2766 | // only the last one, or reject if multiple are found. |
| 2767 | if (enabled_ids.size() != 1) { |
| 2768 | send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST); |
| 2769 | return false; |
| 2770 | } |
| 2771 | const auto & lora = slot.lora[enabled_ids[0]].ptr; |
| 2772 | |
| 2773 | // get the pointer and count for the invocation tokens |
| 2774 | const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora); |
| 2775 | const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora); |
| 2776 | |
| 2777 | // scan backwards through the prompt tokens to find the last |
| 2778 | // occurrence of the invocation sequence |
| 2779 | int match_idx = static_cast<int>(n_invocation_tokens) - 1; |
| 2780 | for (int i = task.tokens.size() - 1; i >= 0; --i) { |
| 2781 | // the token in this position matches the next token to find in |
| 2782 | // the invocation sequence |
| 2783 | if (task.tokens[i] == invocation_tokens[match_idx]) { |
| 2784 | // if it's a full match, we've found the start |
| 2785 | if (match_idx == 0) { |
| 2786 | alora_invocation_start = i; |
| 2787 | break; |
| 2788 | } |
| 2789 | // otherwise, check the next token in the sequence |
| 2790 | --match_idx; |
| 2791 | } else { |
| 2792 | // no match in this position, so start looking over again |
| 2793 | match_idx = static_cast<int>(n_invocation_tokens) - 1; |
| 2794 | } |
| 2795 | } |
| 2796 | |
| 2797 | // if the activation string is not found, disable the alora |
| 2798 | if (alora_invocation_start == task.tokens.size()) { |
| 2799 | SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n" , enabled_ids[0]); |
| 2800 | slot.lora[enabled_ids[0]].scale = 0.0f; |
| 2801 | } else { |
| 2802 | SLT_DBG(slot, "alora %zu activated starting at %zu\n" , enabled_ids[0], alora_invocation_start); |
| 2803 | slot.alora_invocation_start = alora_invocation_start; |
| 2804 | } |
| 2805 | } |
| 2806 | |
| 2807 | if (!task.tokens.validate(ctx)) { |
| 2808 | send_error(task, error: "Prompt contains invalid tokens" , type: ERROR_TYPE_INVALID_REQUEST); |
| 2809 | return false; |
| 2810 | } |
| 2811 | |
| 2812 | SLT_DBG(slot, "launching slot : %s\n" , safe_json_to_str(slot.to_json()).c_str()); |
| 2813 | |
| 2814 | // initialize samplers |
| 2815 | { |
| 2816 | if (slot.smpl != nullptr) { |
| 2817 | common_sampler_free(slot.smpl); |
| 2818 | } |
| 2819 | |
| 2820 | slot.smpl = common_sampler_init(model, task.params.sampling); |
| 2821 | if (slot.smpl == nullptr) { |
| 2822 | // for now, the only error that may happen here is invalid grammar |
| 2823 | send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); |
| 2824 | return false; |
| 2825 | } |
| 2826 | |
| 2827 | SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); |
| 2828 | } |
| 2829 | |
| 2830 | // initialize draft batch |
| 2831 | // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] |
| 2832 | if (slot.ctx_dft) { |
| 2833 | llama_batch_free(slot.batch_spec); |
| 2834 | |
| 2835 | slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); |
| 2836 | } |
| 2837 | |
| 2838 | slot.task = std::make_unique<const server_task>(std::move(task)); |
| 2839 | |
| 2840 | slot.state = SLOT_STATE_STARTED; |
| 2841 | |
| 2842 | SLT_INF(slot, "%s", "processing task\n"); |
| 2843 | |
| 2844 | return true; |
| 2845 | } |
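| | // Note (added for clarity): the aLoRA activation scan above walks the prompt backwards and |
| | // matches the invocation sequence from its last token towards its first, so it locates the |
| | // *last* contiguous occurrence. E.g. for invocation tokens [7, 8, 9] and a prompt |
| | // [.., 7, 8, 9, .., 7, 8, 9, X], alora_invocation_start points at the second 7. |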
| 2846 | |
| 2847 | void kv_cache_clear() { |
| 2848 | SRV_DBG("%s" , "clearing KV cache\n" ); |
| 2849 | |
| 2850 | // clear the entire KV cache |
| 2851 | llama_memory_clear(mem: llama_get_memory(ctx), data: true); |
| 2852 | clean_kv_cache = false; |
| 2853 | } |
| 2854 | |
| 2855 | bool process_token(completion_token_output & result, server_slot & slot) { |
| 2856 | // remember which tokens were sampled - used for repetition penalties during sampling |
| 2857 | const std::string token_str = result.text_to_send; |
| 2858 | slot.sampled = result.tok; |
| 2859 | |
| 2860 | slot.generated_text += token_str; |
| 2861 | if (slot.task->params.return_tokens) { |
| 2862 | slot.generated_tokens.push_back(result.tok); |
| 2863 | } |
| 2864 | slot.has_next_token = true; |
| 2865 | |
| 2866 | // check if there is incomplete UTF-8 character at the end |
| 2867 | bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); |
| 2868 | |
| 2869 | // search stop word and delete it |
| 2870 | if (!incomplete) { |
| 2871 | size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); |
| 2872 | |
| 2873 | const std::string str_test = slot.generated_text.substr(pos); |
| 2874 | bool send_text = true; |
| 2875 | |
| 2876 | size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); |
| 2877 | if (stop_pos != std::string::npos) { |
| 2878 | slot.generated_text.erase( |
| 2879 | slot.generated_text.begin() + pos + stop_pos, |
| 2880 | slot.generated_text.end()); |
| 2881 | pos = std::min(slot.n_sent_text, slot.generated_text.size()); |
| 2882 | } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok)) { |
| 2883 | stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); |
| 2884 | send_text = stop_pos == std::string::npos; |
| 2885 | } |
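| | // note: a full match trims the completed stop string from the generated text, while a partial match only withholds the text until the stop string is either completed or ruled out |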
| 2886 | |
| 2887 | // check if there is any token to predict |
| 2888 | if (send_text) { |
| 2889 | // do not send the stop word in the response |
| 2890 | result.text_to_send = slot.generated_text.substr(pos, std::string::npos); |
| 2891 | slot.n_sent_text += result.text_to_send.size(); |
| 2892 | // add the token to slot queue and cache |
| 2893 | } else { |
| 2894 | result.text_to_send = ""; |
| 2895 | } |
| 2896 | |
| 2897 | slot.add_token(result); |
| 2898 | if (slot.task->params.stream) { |
| 2899 | send_partial_response(slot, result, false); |
| 2900 | } |
| 2901 | } |
| 2902 | |
| 2903 | if (incomplete) { |
| 2904 | slot.has_next_token = true; |
| 2905 | } |
| 2906 | |
| 2907 | // if context shifting is disabled, make sure that we don't run out of context |
| 2908 | if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { |
| 2909 | slot.truncated = true; |
| 2910 | slot.stop = STOP_TYPE_LIMIT; |
| 2911 | slot.has_next_token = false; |
| 2912 | |
| 2913 | SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n", |
| 2914 | slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx); |
| 2915 | } |
| 2916 | |
| 2917 | // check the limits |
| 2918 | if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { |
| 2919 | slot.stop = STOP_TYPE_LIMIT; |
| 2920 | slot.has_next_token = false; |
| 2921 | |
| 2922 | SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); |
| 2923 | } |
| 2924 | |
| 2925 | if (slot.has_new_line) { |
| 2926 | // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent |
| 2927 | if (slot.task->params.n_indent > 0) { |
| 2928 | // check the current indentation |
| 2929 | // TODO: improve by not doing it more than once for each new line |
| 2930 | if (slot.last_nl_pos > 0) { |
| 2931 | size_t pos = slot.last_nl_pos; |
| 2932 | |
| 2933 | int n_indent = 0; |
| 2934 | while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { |
| 2935 | n_indent++; |
| 2936 | pos++; |
| 2937 | } |
| 2938 | |
| 2939 | if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { |
| 2940 | slot.stop = STOP_TYPE_LIMIT; |
| 2941 | slot.has_next_token = false; |
| 2942 | |
| 2943 | // cut the last line |
| 2944 | slot.generated_text.erase(pos, std::string::npos); |
| 2945 | |
| 2946 | SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); |
| 2947 | } |
| 2948 | } |
| 2949 | |
| 2950 | // find the next new line |
| 2951 | { |
| 2952 | const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); |
| 2953 | |
| 2954 | if (pos != std::string::npos) { |
| 2955 | slot.last_nl_pos = pos + 1; |
| 2956 | } |
| 2957 | } |
| 2958 | } |
| 2959 | } |
| 2960 | |
| 2961 | // check if there is a new line in the generated text |
| 2962 | if (result.text_to_send.find('\n') != std::string::npos) { |
| 2963 | slot.has_new_line = true; |
| 2964 | |
| 2965 | // if we have seen a new line, we stop after a certain time limit, but only upon another new line |
| 2966 | if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { |
| 2967 | slot.stop = STOP_TYPE_LIMIT; |
| 2968 | slot.has_next_token = false; |
| 2969 | |
| 2970 | SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); |
| 2971 | } |
| 2972 | } |
| 2973 | |
| 2974 | if (llama_vocab_is_eog(vocab, result.tok)) { |
| 2975 | slot.stop = STOP_TYPE_EOS; |
| 2976 | slot.has_next_token = false; |
| 2977 | |
| 2978 | SLT_DBG(slot, "%s", "stopped by EOS\n"); |
| 2979 | } |
| 2980 | |
| 2981 | SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); |
| 2982 | |
| 2983 | return slot.has_next_token; // continue |
| 2984 | } |
| 2985 | |
| 2986 | void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { |
| 2987 | size_t n_probs = slot.task->params.sampling.n_probs; |
| 2988 | size_t n_vocab = llama_vocab_n_tokens(vocab); |
| 2989 | |
| 2990 | if (post_sampling) { |
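| | // post-sampling probabilities are taken from the sampler's candidate list, |
| | // i.e. after the configured sampling chain has been applied |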
| 2991 | const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); |
| 2992 | const size_t max_probs = cur_p->size; |
| 2993 | |
| 2994 | // set probability for sampled token |
| 2995 | for (size_t i = 0; i < max_probs; i++) { |
| 2996 | if (cur_p->data[i].id == result.tok) { |
| 2997 | result.prob = cur_p->data[i].p; |
| 2998 | break; |
| 2999 | } |
| 3000 | } |
| 3001 | |
| 3002 | // set probability for top n_probs tokens |
| 3003 | result.probs.reserve(max_probs); |
| 3004 | for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { |
| 3005 | result.probs.push_back({ |
| 3006 | cur_p->data[i].id, |
| 3007 | common_token_to_piece(ctx, cur_p->data[i].id, special), |
| 3008 | cur_p->data[i].p |
| 3009 | }); |
| 3010 | } |
| 3011 | } else { |
| 3012 | // TODO: optimize this with min-p optimization |
| 3013 | std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx); |
| 3014 | |
| 3015 | // set probability for sampled token |
| 3016 | for (size_t i = 0; i < n_vocab; i++) { |
| 3017 | // set probability for sampled token |
| 3018 | if (cur[i].id == result.tok) { |
| 3019 | result.prob = cur[i].p; |
| 3020 | break; |
| 3021 | } |
| 3022 | } |
| 3023 | |
| 3024 | // set probability for top n_probs tokens |
| 3025 | result.probs.reserve(n_probs); |
| 3026 | for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { |
| 3027 | result.probs.push_back({ |
| 3028 | cur[i].id, |
| 3029 | common_token_to_piece(ctx, cur[i].id, special), |
| 3030 | cur[i].p |
| 3031 | }); |
| 3032 | } |
| 3033 | } |
| 3034 | } |
| 3035 | |
| 3036 | void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { |
| 3037 | send_error(task.id, error, type); |
| 3038 | } |
| 3039 | |
| 3040 | void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { |
| 3041 | send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx); |
| 3042 | } |
| 3043 | |
| 3044 | void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { |
| 3045 | SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); |
| 3046 | |
| 3047 | if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { |
| 3048 | GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0); |
| 3049 | } |
| 3050 | |
| 3051 | auto res = std::make_unique<server_task_result_error>(); |
| 3052 | res->id = id_task; |
| 3053 | res->err_type = type; |
| 3054 | res->err_msg = error; |
| 3055 | res->n_prompt_tokens = n_prompt_tokens; |
| 3056 | res->n_ctx = n_ctx; |
| 3057 | |
| 3058 | queue_results.send(std::move(res)); |
| 3059 | } |
| 3060 | |
| 3061 | // if multimodal is enabled, send an error and return false |
| 3062 | bool check_no_mtmd(const int id_task) { |
| 3063 | if (mctx) { |
| 3064 | send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); |
| 3065 | return false; |
| 3066 | } |
| 3067 | return true; |
| 3068 | } |
| 3069 | |
| 3070 | void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { |
| 3071 | auto res = std::make_unique<server_task_result_cmpl_partial>(); |
| 3072 | |
| 3073 | res->id = slot.task->id; |
| 3074 | res->index = slot.task->index; |
| 3075 | |
| 3076 | if (is_progress) { |
| 3077 | res->is_progress = true; |
| 3078 | res->progress.total = slot.task->n_tokens(); |
| 3079 | res->progress.cache = slot.n_prompt_tokens_cache; |
| 3080 | res->progress.processed = slot.prompt.tokens.size(); |
| 3081 | res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; |
| 3082 | } else { |
| 3083 | res->content = tkn.text_to_send; |
| 3084 | res->tokens = { tkn.tok }; |
| 3085 | |
| 3086 | slot.update_chat_msg(res->oaicompat_msg_diffs); |
| 3087 | } |
| 3088 | |
| 3089 | res->n_decoded = slot.n_decoded; |
| 3090 | res->n_prompt_tokens = slot.task->n_tokens(); |
| 3091 | res->post_sampling_probs = slot.task->params.post_sampling_probs; |
| 3092 | |
| 3093 | res->verbose = slot.task->params.verbose; |
| 3094 | res->oaicompat = slot.task->params.oaicompat; |
| 3095 | res->oaicompat_model = slot.task->params.oaicompat_model; |
| 3096 | res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; |
| 3097 | |
| 3098 | // populate res.probs_output |
| 3099 | if (slot.task->params.sampling.n_probs > 0) { |
| 3100 | res->prob_output = tkn; // copy the token probs |
| 3101 | } |
| 3102 | |
| 3103 | // populate timings if this is final response or timings_per_token is enabled |
| 3104 | if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { |
| 3105 | res->timings = slot.get_timings(); |
| 3106 | } |
| 3107 | |
| 3108 | queue_results.send(std::move(res)); |
| 3109 | } |
| 3110 | |
| 3111 | void send_final_response(server_slot & slot) { |
| 3112 | auto res = std::make_unique<server_task_result_cmpl_final>(); |
| 3113 | |
| 3114 | res->id = slot.task->id; |
| 3115 | res->id_slot = slot.id; |
| 3116 | |
| 3117 | res->index = slot.task->index; |
| 3118 | res->content = slot.generated_text; |
| 3119 | res->tokens = std::move(slot.generated_tokens); |
| 3120 | res->timings = slot.get_timings(); |
| 3121 | res->prompt = slot.task->tokens.detokenize(ctx, true); |
| 3122 | res->response_fields = std::move(slot.task->params.response_fields); |
| 3123 | |
| 3124 | res->truncated = slot.truncated; |
| 3125 | res->n_decoded = slot.n_decoded; |
| 3126 | res->n_prompt_tokens = slot.task->n_tokens(); |
| 3127 | res->n_tokens_cached = slot.prompt.n_tokens(); |
| 3128 | res->has_new_line = slot.has_new_line; |
| 3129 | res->stopping_word = slot.stopping_word; |
| 3130 | res->stop = slot.stop; |
| 3131 | res->post_sampling_probs = slot.task->params.post_sampling_probs; |
| 3132 | |
| 3133 | res->verbose = slot.task->params.verbose; |
| 3134 | res->stream = slot.task->params.stream; |
| 3135 | res->include_usage = slot.task->params.include_usage; |
| 3136 | res->oaicompat = slot.task->params.oaicompat; |
| 3137 | res->oaicompat_model = slot.task->params.oaicompat_model; |
| 3138 | res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; |
| 3139 | res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); |
| 3140 | |
| 3141 | // populate res.probs_output |
| 3142 | if (slot.task->params.sampling.n_probs > 0) { |
| 3143 | if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { |
| 3144 | const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); |
| 3145 | |
| 3146 | size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); |
| 3147 | res->probs_output = std::vector<completion_token_output>( |
| 3148 | slot.generated_token_probs.begin(), |
| 3149 | slot.generated_token_probs.end() - safe_offset); |
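| | // the stop word itself is not part of the returned content, so its token probabilities are trimmed from the end as well |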
| 3150 | } else { |
| 3151 | res->probs_output = std::vector<completion_token_output>( |
| 3152 | slot.generated_token_probs.begin(), |
| 3153 | slot.generated_token_probs.end()); |
| 3154 | } |
| 3155 | } |
| 3156 | |
| 3157 | res->generation_params = slot.task->params; // copy the parameters |
| 3158 | |
| 3159 | queue_results.send(std::move(res)); |
| 3160 | } |
| 3161 | |
| 3162 | void send_embedding(const server_slot & slot, const llama_batch & batch) { |
| 3163 | auto res = std::make_unique<server_task_result_embd>(); |
| 3164 | res->id = slot.task->id; |
| 3165 | res->index = slot.task->index; |
| 3166 | res->n_tokens = slot.task->n_tokens(); |
| 3167 | res->oaicompat = slot.task->params.oaicompat; |
| 3168 | |
| 3169 | const int n_embd = llama_model_n_embd(model); |
| 3170 | |
| 3171 | std::vector<float> embd_res(n_embd, 0.0f); |
| 3172 | |
| 3173 | for (int i = 0; i < batch.n_tokens; ++i) { |
| 3174 | if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { |
| 3175 | continue; |
| 3176 | } |
| 3177 | |
| 3178 | const float * embd = nullptr; |
| 3179 | if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { |
| 3180 | embd = llama_get_embeddings_ith(ctx, i); |
| 3181 | } else { |
| 3182 | embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); |
| 3183 | } |
| 3184 | |
| 3185 | if (embd == nullptr) { |
| 3186 | SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); |
| 3187 | |
| 3188 | res->embedding.push_back(std::vector<float>(n_embd, 0.0f)); |
| 3189 | continue; |
| 3190 | } |
| 3191 | |
| 3192 | // normalize only when there is pooling |
| 3193 | if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { |
| 3194 | common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); |
| 3195 | res->embedding.push_back(embd_res); |
| 3196 | break; |
| 3197 | } |
| 3198 | |
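| | // without pooling, one unnormalized embedding vector is collected per output token instead |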
| 3199 | res->embedding.emplace_back(embd, embd + n_embd); |
| 3200 | } |
| 3201 | |
| 3202 | SLT_DBG(slot, "%s", "sending embeddings\n"); |
| 3203 | |
| 3204 | queue_results.send(std::move(res)); |
| 3205 | } |
| 3206 | |
| 3207 | void send_rerank(const server_slot & slot, const llama_batch & batch) { |
| 3208 | auto res = std::make_unique<server_task_result_rerank>(); |
| 3209 | res->id = slot.task->id; |
| 3210 | res->index = slot.task->index; |
| 3211 | res->n_tokens = slot.task->n_tokens(); |
| 3212 | |
| 3213 | for (int i = 0; i < batch.n_tokens; ++i) { |
| 3214 | if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { |
| 3215 | continue; |
| 3216 | } |
| 3217 | |
| 3218 | const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); |
| 3219 | if (embd == NULL) { |
| 3220 | embd = llama_get_embeddings_ith(ctx, i); |
| 3221 | } |
| 3222 | |
| 3223 | if (embd == NULL) { |
| 3224 | SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); |
| 3225 | |
| 3226 | res->score = -1e6; |
| 3227 | continue; |
| 3228 | } |
| 3229 | |
| 3230 | res->score = embd[0]; |
| 3231 | } |
| 3232 | |
| 3233 | SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); |
| 3234 | |
| 3235 | queue_results.send(std::move(res)); |
| 3236 | } |
| 3237 | |
| 3238 | // |
| 3239 | // Functions to create new task(s) and receive result(s) |
| 3240 | // |
| 3241 | |
| 3242 | void cancel_tasks(const std::unordered_set<int> & id_tasks) { |
| 3243 | std::vector<server_task> cancel_tasks; |
| 3244 | cancel_tasks.reserve(id_tasks.size()); |
| 3245 | for (const auto & id_task : id_tasks) { |
| 3246 | SRV_WRN("cancel task, id_task = %d\n", id_task); |
| 3247 | |
| 3248 | server_task task(SERVER_TASK_TYPE_CANCEL); |
| 3249 | task.id_target = id_task; |
| 3250 | queue_results.remove_waiting_task_id(id_task); |
| 3251 | cancel_tasks.push_back(std::move(task)); |
| 3252 | } |
| 3253 | // push to beginning of the queue, so it has highest priority |
| 3254 | queue_tasks.post(std::move(cancel_tasks), true); |
| 3255 | } |
| 3256 | |
| 3257 | // receive the results from task(s) |
| 3258 | void receive_multi_results( |
| 3259 | const std::unordered_set<int> & id_tasks, |
| 3260 | const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler, |
| 3261 | const std::function<void(json)> & error_handler, |
| 3262 | const std::function<bool()> & is_connection_closed) { |
| 3263 | std::vector<server_task_result_ptr> results(id_tasks.size()); |
| 3264 | for (int i = 0; i < (int)id_tasks.size(); i++) { |
| 3265 | server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); |
| 3266 | |
| 3267 | if (is_connection_closed()) { |
| 3268 | cancel_tasks(id_tasks); |
| 3269 | return; |
| 3270 | } |
| 3271 | |
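| | // a null result means the polling timeout expired without a response; |
| | // retry the same index so the connection-closed check above keeps running periodically |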
| 3272 | if (result == nullptr) { |
| 3273 | i--; // retry |
| 3274 | continue; |
| 3275 | } |
| 3276 | |
| 3277 | if (result->is_error()) { |
| 3278 | error_handler(result->to_json()); |
| 3279 | cancel_tasks(id_tasks); |
| 3280 | return; |
| 3281 | } |
| 3282 | |
| 3283 | GGML_ASSERT( |
| 3284 | dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr |
| 3285 | || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr |
| 3286 | || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr |
| 3287 | ); |
| 3288 | const size_t idx = result->get_index(); |
| 3289 | GGML_ASSERT(idx < results.size() && "index out of range"); |
| 3290 | results[idx] = std::move(result); |
| 3291 | } |
| 3292 | result_handler(results); |
| 3293 | } |
| 3294 | |
| 3295 | // receive the results from task(s), in stream mode |
| 3296 | void receive_cmpl_results_stream( |
| 3297 | const std::unordered_set<int> & id_tasks, |
| 3298 | const std::function<bool(server_task_result_ptr&)> & result_handler, |
| 3299 | const std::function<void(json)> & error_handler, |
| 3300 | const std::function<bool()> & is_connection_closed) { |
| 3301 | size_t n_finished = 0; |
| 3302 | while (true) { |
| 3303 | server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); |
| 3304 | |
| 3305 | if (is_connection_closed()) { |
| 3306 | cancel_tasks(id_tasks); |
| 3307 | return; |
| 3308 | } |
| 3309 | |
| 3310 | if (result == nullptr) { |
| 3311 | continue; // retry |
| 3312 | } |
| 3313 | |
| 3314 | if (result->is_error()) { |
| 3315 | error_handler(result->to_json()); |
| 3316 | cancel_tasks(id_tasks); |
| 3317 | return; |
| 3318 | } |
| 3319 | |
| 3320 | GGML_ASSERT( |
| 3321 | dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr |
| 3322 | || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr |
| 3323 | ); |
| 3324 | if (!result_handler(result)) { |
| 3325 | cancel_tasks(id_tasks); |
| 3326 | break; |
| 3327 | } |
| 3328 | |
| 3329 | if (result->is_stop()) { |
| 3330 | if (++n_finished == id_tasks.size()) { |
| 3331 | break; |
| 3332 | } |
| 3333 | } |
| 3334 | } |
| 3335 | } |
| 3336 | |
| 3337 | // |
| 3338 | // Functions to process the task |
| 3339 | // |
| 3340 | |
| 3341 | void process_single_task(server_task && task) { |
| 3342 | switch (task.type) { |
| 3343 | case SERVER_TASK_TYPE_COMPLETION: |
| 3344 | case SERVER_TASK_TYPE_INFILL: |
| 3345 | case SERVER_TASK_TYPE_EMBEDDING: |
| 3346 | case SERVER_TASK_TYPE_RERANK: |
| 3347 | { |
| 3348 | const int id_slot = task.id_slot; |
| 3349 | |
| 3350 | server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); |
| 3351 | |
| 3352 | if (slot == nullptr) { |
| 3353 | // if no slot is available, we defer this task for processing later |
| 3354 | SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); |
| 3355 | queue_tasks.defer(std::move(task)); |
| 3356 | break; |
| 3357 | } |
| 3358 | |
| 3359 | if (slot->is_processing()) { |
| 3360 | // if requested slot is unavailable, we defer this task for processing later |
| 3361 | SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); |
| 3362 | queue_tasks.defer(std::move(task)); |
| 3363 | break; |
| 3364 | } |
| 3365 | |
| 3366 | if (!launch_slot_with_task(*slot, std::move(task))) { |
| 3367 | SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); |
| 3368 | break; |
| 3369 | } |
| 3370 | } break; |
| 3371 | case SERVER_TASK_TYPE_CANCEL: |
| 3372 | { |
| 3373 | // release slot linked with the task id |
| 3374 | for (auto & slot : slots) { |
| 3375 | if (slot.task && slot.task->id == task.id_target) { |
| 3376 | slot.release(); |
| 3377 | break; |
| 3378 | } |
| 3379 | } |
| 3380 | } break; |
| 3381 | case SERVER_TASK_TYPE_NEXT_RESPONSE: |
| 3382 | { |
| 3383 | // do nothing |
| 3384 | } break; |
| 3385 | case SERVER_TASK_TYPE_METRICS: |
| 3386 | { |
| 3387 | json slots_data = json::array(); |
| 3388 | |
| 3389 | int n_idle_slots = 0; |
| 3390 | int n_processing_slots = 0; |
| 3391 | |
| 3392 | for (server_slot & slot : slots) { |
| 3393 | json slot_data = slot.to_json(slots_debug == 0); |
| 3394 | |
| 3395 | if (slot.is_processing()) { |
| 3396 | n_processing_slots++; |
| 3397 | } else { |
| 3398 | n_idle_slots++; |
| 3399 | } |
| 3400 | |
| 3401 | slots_data.push_back(slot_data); |
| 3402 | } |
| 3403 | SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); |
| 3404 | |
| 3405 | auto res = std::make_unique<server_task_result_metrics>(); |
| 3406 | res->id = task.id; |
| 3407 | res->slots_data = std::move(slots_data); |
| 3408 | res->n_idle_slots = n_idle_slots; |
| 3409 | res->n_processing_slots = n_processing_slots; |
| 3410 | res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); |
| 3411 | res->t_start = metrics.t_start; |
| 3412 | |
| 3413 | res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; |
| 3414 | res->t_prompt_processing_total = metrics.t_prompt_processing_total; |
| 3415 | res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; |
| 3416 | res->t_tokens_generation_total = metrics.t_tokens_generation_total; |
| 3417 | |
| 3418 | res->n_tokens_max = metrics.n_tokens_max; |
| 3419 | |
| 3420 | res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; |
| 3421 | res->t_prompt_processing = metrics.t_prompt_processing; |
| 3422 | res->n_tokens_predicted = metrics.n_tokens_predicted; |
| 3423 | res->t_tokens_generation = metrics.t_tokens_generation; |
| 3424 | |
| 3425 | res->n_decode_total = metrics.n_decode_total; |
| 3426 | res->n_busy_slots_total = metrics.n_busy_slots_total; |
| 3427 | |
| 3428 | if (task.metrics_reset_bucket) { |
| 3429 | metrics.reset_bucket(); |
| 3430 | } |
| 3431 | queue_results.send(std::move(res)); |
| 3432 | } break; |
| 3433 | case SERVER_TASK_TYPE_SLOT_SAVE: |
| 3434 | { |
| 3435 | if (!check_no_mtmd(task.id)) { |
| 3436 | break; |
| 3437 | } |
| 3438 | |
| 3439 | int id_slot = task.slot_action.slot_id; |
| 3440 | server_slot * slot = get_slot_by_id(id_slot); |
| 3441 | if (slot == nullptr) { |
| 3442 | send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); |
| 3443 | break; |
| 3444 | } |
| 3445 | if (slot->is_processing()) { |
| 3446 | // if requested slot is unavailable, we defer this task for processing later |
| 3447 | SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); |
| 3448 | queue_tasks.defer(std::move(task)); |
| 3449 | break; |
| 3450 | } |
| 3451 | |
| 3452 | const size_t token_count = slot->prompt.tokens.size(); |
| 3453 | const int64_t t_start = ggml_time_us(); |
| 3454 | |
| 3455 | std::string filename = task.slot_action.filename; |
| 3456 | std::string filepath = task.slot_action.filepath; |
| 3457 | |
| 3458 | const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); |
| 3459 | const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); |
| 3460 | |
| 3461 | const int64_t t_end = ggml_time_us(); |
| 3462 | const double t_save_ms = (t_end - t_start) / 1000.0; |
| 3463 | |
| 3464 | auto res = std::make_unique<server_task_result_slot_save_load>(); |
| 3465 | res->id = task.id; |
| 3466 | res->id_slot = id_slot; |
| 3467 | res->filename = filename; |
| 3468 | res->is_save = true; |
| 3469 | res->n_tokens = token_count; |
| 3470 | res->n_bytes = nwrite; |
| 3471 | res->t_ms = t_save_ms; |
| 3472 | queue_results.send(std::move(res)); |
| 3473 | } break; |
| 3474 | case SERVER_TASK_TYPE_SLOT_RESTORE: |
| 3475 | { |
| 3476 | if (!check_no_mtmd(task.id)) break; |
| 3477 | int id_slot = task.slot_action.slot_id; |
| 3478 | server_slot * slot = get_slot_by_id(id_slot); |
| 3479 | if (slot == nullptr) { |
| 3480 | send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); |
| 3481 | break; |
| 3482 | } |
| 3483 | if (slot->is_processing()) { |
| 3484 | // if requested slot is unavailable, we defer this task for processing later |
| 3485 | SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); |
| 3486 | queue_tasks.defer(std::move(task)); |
| 3487 | break; |
| 3488 | } |
| 3489 | |
| 3490 | const int64_t t_start = ggml_time_us(); |
| 3491 | |
| 3492 | std::string filename = task.slot_action.filename; |
| 3493 | std::string filepath = task.slot_action.filepath; |
| 3494 | |
| 3495 | llama_tokens tokens; |
| 3496 | tokens.resize(slot->n_ctx); |
| 3497 | size_t token_count = 0; |
| 3498 | size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); |
| 3499 | if (nread == 0) { |
| 3500 | slot->prompt.tokens.clear(); // the KV cache may have already been invalidated |
| 3501 | send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); |
| 3502 | break; |
| 3503 | } |
| 3504 | tokens.resize(token_count); |
| 3505 | slot->prompt.tokens.clear(); |
| 3506 | slot->prompt.tokens.insert(tokens); |
| 3507 | |
| 3508 | const int64_t t_end = ggml_time_us(); |
| 3509 | const double t_restore_ms = (t_end - t_start) / 1000.0; |
| 3510 | |
| 3511 | auto res = std::make_unique<server_task_result_slot_save_load>(); |
| 3512 | res->id = task.id; |
| 3513 | res->id_slot = id_slot; |
| 3514 | res->filename = filename; |
| 3515 | res->is_save = false; |
| 3516 | res->n_tokens = token_count; |
| 3517 | res->n_bytes = nread; |
| 3518 | res->t_ms = t_restore_ms; |
| 3519 | queue_results.send(std::move(res)); |
| 3520 | } break; |
| 3521 | case SERVER_TASK_TYPE_SLOT_ERASE: |
| 3522 | { |
| 3523 | if (!check_no_mtmd(task.id)) { |
| 3524 | break; |
| 3525 | } |
| 3526 | int id_slot = task.slot_action.slot_id; |
| 3527 | server_slot * slot = get_slot_by_id(id_slot); |
| 3528 | if (slot == nullptr) { |
| 3529 | send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); |
| 3530 | break; |
| 3531 | } |
| 3532 | if (slot->is_processing()) { |
| 3533 | // if requested slot is unavailable, we defer this task for processing later |
| 3534 | SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); |
| 3535 | queue_tasks.defer(std::move(task)); |
| 3536 | break; |
| 3537 | } |
| 3538 | |
| 3539 | // Erase token cache |
| 3540 | const size_t n_erased = slot->prompt.tokens.size(); |
| 3541 | llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); |
| 3542 | slot->prompt.tokens.clear(); |
| 3543 | |
| 3544 | auto res = std::make_unique<server_task_result_slot_erase>(); |
| 3545 | res->id = task.id; |
| 3546 | res->id_slot = id_slot; |
| 3547 | res->n_erased = n_erased; |
| 3548 | queue_results.send(std::move(res)); |
| 3549 | } break; |
| 3550 | case SERVER_TASK_TYPE_SET_LORA: |
| 3551 | { |
| 3552 | params_base.lora_adapters = std::move(task.set_lora); |
| 3553 | auto res = std::make_unique<server_task_result_apply_lora>(); |
| 3554 | res->id = task.id; |
| 3555 | queue_results.send(std::move(res)); |
| 3556 | } break; |
| 3557 | |
| 3558 | } |
| 3559 | } |
| 3560 | |
| 3561 | void update_slots() { |
| 3562 | // check if all slots are idle |
| 3563 | { |
| 3564 | bool all_idle = true; |
| 3565 | |
| 3566 | for (auto & slot : slots) { |
| 3567 | if (slot.is_processing()) { |
| 3568 | all_idle = false; |
| 3569 | break; |
| 3570 | } |
| 3571 | } |
| 3572 | |
| 3573 | if (all_idle) { |
| 3574 | SRV_INF("%s", "all slots are idle\n"); |
| 3575 | if (clean_kv_cache) { |
| 3576 | kv_cache_clear(); |
| 3577 | } |
| 3578 | |
| 3579 | return; |
| 3580 | } |
| 3581 | } |
| 3582 | |
| 3583 | { |
| 3584 | SRV_DBG("%s", "posting NEXT_RESPONSE\n"); |
| 3585 | |
| 3586 | server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); |
| 3587 | task.id = queue_tasks.get_new_id(); |
| 3588 | queue_tasks.post(std::move(task)); |
| 3589 | } |
| 3590 | |
| 3591 | // apply context-shift if needed |
| 3592 | // TODO: simplify and improve |
| 3593 | for (server_slot & slot : slots) { |
| 3594 | if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { |
| 3595 | if (!params_base.ctx_shift) { |
| 3596 | // this check is redundant (for good) |
| 3597 | // we should never get here, because generation should have already stopped in process_token() |
| 3598 | send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); |
| 3599 | slot.release(); |
| 3600 | continue; |
| 3601 | } |
| 3602 | |
| 3603 | if (mctx) { |
| 3604 | // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded |
| 3605 | // we don't support ctx_shift because an image chunk may contain multiple tokens |
| 3606 | GGML_ABORT("not supported by multimodal"); |
| 3607 | } |
| 3608 | |
| 3609 | // Shift context |
| 3610 | int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep; |
| 3611 | |
| 3612 | if (add_bos_token) { |
| 3613 | n_keep += 1; |
| 3614 | } |
| 3615 | |
| 3616 | n_keep = std::min(slot.n_ctx - 4, n_keep); |
| 3617 | |
| 3618 | const int n_left = slot.prompt.n_tokens() - n_keep; |
| 3619 | const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); |
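| | // by default, half of the non-kept tokens are discarded, so the context does not need to be shifted again immediately |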
| 3620 | |
| 3621 | SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); |
| 3622 | |
| 3623 | llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep, n_keep + n_discard); |
| 3624 | llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); |
| 3625 | |
| 3626 | // add generated tokens to cache |
| 3627 | // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 |
| 3628 | { |
| 3629 | GGML_ASSERT(!slot.prompt.tokens.has_mtmd); |
| 3630 | |
| 3631 | llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy |
| 3632 | for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { |
| 3633 | new_tokens[i - n_discard] = new_tokens[i]; |
| 3634 | } |
| 3635 | |
| 3636 | new_tokens.resize(slot.prompt.tokens.size() - n_discard); |
| 3637 | |
| 3638 | slot.prompt.tokens.clear(); |
| 3639 | slot.prompt.tokens.insert(new_tokens); |
| 3640 | } |
| 3641 | |
| 3642 | slot.truncated = true; |
| 3643 | } |
| 3644 | } |
| 3645 | |
| 3646 | // start populating the batch for this iteration |
| 3647 | common_batch_clear(batch); |
| 3648 | |
| 3649 | // track if given slot can be batched with slots already in the batch |
| 3650 | server_slot * slot_batched = nullptr; |
| 3651 | |
| 3652 | auto accept_special_token = [&](server_slot & slot, llama_token token) { |
| 3653 | return params_base.special || |
| 3654 | slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); |
| 3655 | }; |
| 3656 | |
| 3657 | // first, add sampled tokens from any ongoing sequences |
| 3658 | for (auto & slot : slots) { |
| 3659 | if (slot.state != SLOT_STATE_GENERATING) { |
| 3660 | continue; |
| 3661 | } |
| 3662 | |
| 3663 | // check if we can batch this slot with the previous one |
| 3664 | if (!slot_batched) { |
| 3665 | slot_batched = &slot; |
| 3666 | } else if (!slot_batched->can_batch_with(slot)) { |
| 3667 | continue; |
| 3668 | } |
| 3669 | |
| 3670 | slot.i_batch = batch.n_tokens; |
| 3671 | |
| 3672 | common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); |
| 3673 | |
| 3674 | slot.prompt.tokens.push_back(tok: slot.sampled); |
| 3675 | |
| 3676 | SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", |
| 3677 | slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); |
| 3678 | } |
| 3679 | |
| 3680 | // process in chunks of params.n_batch |
| 3681 | int32_t n_batch = llama_n_batch(ctx); |
| 3682 | int32_t n_ubatch = llama_n_ubatch(ctx); |
| 3683 | |
| 3684 | float alora_scale = -1.0f; |
| 3685 | size_t alora_disabled_id = 0; |
| 3686 | |
| 3687 | // next, batch any pending prompts without exceeding n_batch |
| 3688 | if (params_base.cont_batching || batch.n_tokens == 0) { |
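| | // without continuous batching, prompts are only processed when no generation tokens were added to the batch above |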
| 3689 | for (auto & slot : slots) { |
| 3690 | // check if we can batch this slot with the previous one |
| 3691 | if (slot.is_processing()) { |
| 3692 | if (!slot_batched) { |
| 3693 | slot_batched = &slot; |
| 3694 | } else if (!slot_batched->can_batch_with(slot)) { |
| 3695 | continue; |
| 3696 | } |
| 3697 | } |
| 3698 | |
| 3699 | // this slot still has a prompt to be processed |
| 3700 | if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { |
| 3701 | const auto & input_tokens = slot.task->tokens; |
| 3702 | |
| 3703 | // TODO: maybe move branch to outside of this loop in the future |
| 3704 | if (slot.state == SLOT_STATE_STARTED) { |
| 3705 | slot.t_start_process_prompt = ggml_time_us(); |
| 3706 | slot.t_start_generation = 0; |
| 3707 | |
| 3708 | slot.state = SLOT_STATE_PROCESSING_PROMPT; |
| 3709 | |
| 3710 | SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", |
| 3711 | slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens()); |
| 3712 | |
| 3713 | // print prompt tokens (for debugging) |
| 3714 | /*if (1) { |
| 3715 | // first 16 tokens (avoid flooding logs) |
| 3716 | for (int i = 0; i < std::min<int>(16, input_tokens.size()); i++) { |
| 3717 | SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); |
| 3718 | } |
| 3719 | } else { |
| 3720 | // all |
| 3721 | for (int i = 0; i < (int) input_tokens.size(); i++) { |
| 3722 | SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); |
| 3723 | } |
| 3724 | }*/ |
| 3725 | |
| 3726 | // keep track how many tokens we can reuse from the previous state |
| 3727 | int n_past = 0; |
| 3728 | |
| 3729 | // empty prompt passed -> release the slot and send empty response |
| 3730 | if (input_tokens.empty()) { |
| 3731 | SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); |
| 3732 | |
| 3733 | slot.print_timings(); |
| 3734 | send_final_response(slot); |
| 3735 | slot.release(); |
| 3736 | |
| 3737 | continue; |
| 3738 | } |
| 3739 | |
| 3740 | // TODO: support memory-less logits computation |
| 3741 | if (slot.need_logits() && !llama_get_memory(ctx)) { |
| 3742 | send_error(slot, "the current context does not support logits computation. skipping", ERROR_TYPE_SERVER); |
| 3743 | slot.release(); |
| 3744 | continue; |
| 3745 | } |
| 3746 | |
| 3747 | if (!slot.can_split()) { |
| 3748 | if (slot.task->n_tokens() > n_ubatch) { |
| 3749 | send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); |
| 3750 | slot.release(); |
| 3751 | continue; |
| 3752 | } |
| 3753 | |
| 3754 | if (slot.task->n_tokens() > slot.n_ctx) { |
| 3755 | send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); |
| 3756 | slot.release(); |
| 3757 | continue; |
| 3758 | } |
| 3759 | } else { |
| 3760 | if (slot.task->n_tokens() >= slot.n_ctx) { |
| 3761 | send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); |
| 3762 | slot.release(); |
| 3763 | continue; |
| 3764 | } |
| 3765 | |
| 3766 | if (slot.task->params.cache_prompt) { |
| 3767 | // reuse any previously computed tokens that are common with the new prompt |
| 3768 | n_past = slot.prompt.tokens.get_common_prefix(input_tokens); |
| 3769 | |
| 3770 | // if there is an alora invoked, don't cache after the invocation start |
| 3771 | if (slot.alora_invocation_start > 0) { |
| 3772 | SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start); |
| 3773 | n_past = std::min(n_past, slot.alora_invocation_start - 1); |
| 3774 | } |
| 3775 | |
| 3776 | // reuse chunks from the cached prompt by shifting their KV cache in the new position |
| 3777 | if (params_base.n_cache_reuse > 0) { |
| 3778 | GGML_ASSERT(!slot.prompt.tokens.has_mtmd); |
| 3779 | |
| 3780 | size_t head_c = n_past; // cache |
| 3781 | size_t head_p = n_past; // current prompt |
| 3782 | |
| 3783 | if (mctx) { |
| 3784 | // we should never reach this |
| 3785 | GGML_ABORT("not supported by multimodal" ); |
| 3786 | } |
| 3787 | |
| 3788 | SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); |
| 3789 | |
| 3790 | while (head_c < slot.prompt.tokens.size() && |
| 3791 | head_p < input_tokens.size()) { |
| 3792 | |
| 3793 | size_t n_match = 0; |
| 3794 | while (head_c + n_match < slot.prompt.tokens.size() && |
| 3795 | head_p + n_match < input_tokens.size() && |
| 3796 | slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { |
| 3797 | |
| 3798 | n_match++; |
| 3799 | } |
| 3800 | |
| 3801 | if (n_match >= (size_t) params_base.n_cache_reuse) { |
| 3802 | SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); |
| 3803 | //for (size_t i = head_p; i < head_p + n_match; i++) { |
| 3804 | // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); |
| 3805 | //} |
| 3806 | |
| 3807 | const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; |
| 3808 | |
| 3809 | llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); |
| 3810 | llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); |
| 3811 | |
| 3812 | for (size_t i = 0; i < n_match; i++) { |
| 3813 | slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); |
| 3814 | n_past++; |
| 3815 | } |
| 3816 | |
| 3817 | head_c += n_match; |
| 3818 | head_p += n_match; |
| 3819 | } else { |
| 3820 | head_c += 1; |
| 3821 | } |
| 3822 | } |
| 3823 | |
| 3824 | SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); |
| 3825 | } |
| 3826 | } else { |
| 3827 | // if we don't cache the prompt, we have to remove all previous tokens |
| 3828 | n_past = 0; |
| 3829 | } |
| 3830 | |
| 3831 | // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 |
| 3832 | const auto n_swa = std::max(1, llama_model_n_swa(model)); |
| 3833 | |
| 3834 | // the largest pos_min required for a checkpoint to be useful |
| 3835 | const auto pos_min_thold = std::max(0, n_past - n_swa); |
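| | // to resume from n_past, the memory must still hold positions down to n_past - n_swa; |
| | // a sequence whose pos_min is above this threshold has lost part of that window |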
| 3836 | |
| 3837 | // note: disallow with mtmd contexts for now |
| 3838 | // https://github.com/ggml-org/llama.cpp/issues/17043 |
| 3839 | if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { |
| 3840 | const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); |
| 3841 | if (pos_min == -1) { |
| 3842 | SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); |
| 3843 | GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); |
| 3844 | } |
| 3845 | |
| 3846 | // when the prompt prefix does not match, print the tokens around the mismatch |
| 3847 | // this is useful for debugging prompt caching |
| 3848 | if (slots_debug) { |
| 3849 | const int np0 = std::max<int>(n_past - 4, 0); |
| 3850 | const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); |
| 3851 | |
| 3852 | std::stringstream ss0; |
| 3853 | std::stringstream ss1; |
| 3854 | |
| 3855 | std::stringstream st0; |
| 3856 | std::stringstream st1; |
| 3857 | |
| 3858 | ss0 << "old: ... "; |
| 3859 | ss1 << "new: ... "; |
| 3860 | |
| 3861 | for (int i = np0; i < np1; i++) { |
| 3862 | if (i == n_past) { |
| 3863 | ss0 << " | "; |
| 3864 | ss1 << " | "; |
| 3865 | } |
| 3866 | |
| 3867 | { |
| 3868 | const auto token = slot.prompt.tokens[i]; |
| 3869 | const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; |
| 3870 | ss0 << piece; |
| 3871 | st0 << std::setw(8) << token; |
| 3872 | } |
| 3873 | |
| 3874 | { |
| 3875 | const auto token = slot.task->tokens[i]; |
| 3876 | const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; |
| 3877 | ss1 << piece; |
| 3878 | st1 << std::setw(8) << token; |
| 3879 | } |
| 3880 | } |
| 3881 | |
| 3882 | SLT_WRN(slot, "%s\n", ss0.str().c_str()); |
| 3883 | SLT_WRN(slot, "%s\n", ss1.str().c_str()); |
| 3884 | |
| 3885 | SLT_WRN(slot, "%s\n", st0.str().c_str()); |
| 3886 | SLT_WRN(slot, "%s\n", st1.str().c_str()); |
| 3887 | } |
| 3888 | |
| 3889 | if (pos_min > pos_min_thold) { |
| 3890 | // TODO: support can be added in the future when corresponding vision models get released |
| 3891 | GGML_ASSERT(!slot.prompt.tokens.has_mtmd); |
| 3892 | |
| 3893 | SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); |
| 3894 | |
| 3895 | // search for a context checkpoint |
| 3896 | const auto it = std::find_if( |
| 3897 | slot.prompt.checkpoints.rbegin(), |
| 3898 | slot.prompt.checkpoints.rend(), |
| 3899 | [&](const auto & cur) { |
| 3900 | // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] |
| 3901 | return cur.pos_min < pos_min_thold; |
| 3902 | } |
| 3903 | ); |
| 3904 | |
| 3905 | bool do_reset = it == slot.prompt.checkpoints.rend(); |
| 3906 | |
| 3907 | if (!do_reset) { |
| 3908 | // restore the context checkpoint |
| 3909 | const size_t checkpoint_size = it->data.size(); |
| 3910 | const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); |
| 3911 | |
| 3912 | if (n != checkpoint_size) { |
| 3913 | SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); |
| 3914 | do_reset = true; |
| 3915 | //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); |
| 3916 | } else { |
| 3917 | n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); |
| 3918 | SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); |
| 3919 | } |
| 3920 | } |
| 3921 | |
| 3922 | if (do_reset) { |
| 3923 | SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", |
| 3924 | "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); |
| 3925 | n_past = 0; |
| 3926 | } |
| 3927 | } |
| 3928 | } |
| 3929 | |
| 3930 | { |
| 3931 | // erase any checkpoints with pos_min > pos_min_thold |
| 3932 | for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { |
| 3933 | const auto & cur = *it; |
| 3934 | if (cur.pos_min > pos_min_thold) { |
| 3935 | SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); |
| 3936 | it = slot.prompt.checkpoints.erase(it); |
| 3937 | } else { |
| 3938 | ++it; |
| 3939 | } |
| 3940 | } |
| 3941 | } |
| 3942 | } |
| 3943 | |
| 3944 | // [TAG_PROMPT_LOGITS] |
| 3945 | if (n_past == slot.task->n_tokens() && n_past > 0) { |
| 3946 | SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens()); |
| 3947 | n_past--; |
| 3948 | SLT_WRN(slot, "n_past was set to %d\n", n_past); |
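| | // re-processing the last prompt token guarantees that fresh logits are available to sample the next token |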
| 3949 | } |
| 3950 | |
| 3951 | slot.n_prompt_tokens_cache = n_past; |
| 3952 | slot.n_prompt_tokens_processed = 0; |
| 3953 | |
| 3954 | slot.prompt.tokens.keep_first(n_past); |
| 3955 | } |
| 3956 | |
| 3957 | if (!slot.can_split()) { |
| 3958 | // cannot fit the prompt in the current batch - will try next iter |
| 3959 | if (batch.n_tokens + slot.task->n_tokens() > n_batch) { |
| 3960 | continue; |
| 3961 | } |
| 3962 | } |
| 3963 | |
| 3964 | // truncate any tokens that are beyond n_past for this slot |
| 3965 | const llama_pos p0 = slot.prompt.tokens.pos_next(); |
| 3966 | |
| 3967 | SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); |
| 3968 | |
| 3969 | if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { |
| 3970 | SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); |
| 3971 | llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); |
| 3972 | |
| 3973 | // there is no common part left |
| 3974 | slot.n_prompt_tokens_cache = 0; |
| 3975 | |
| 3976 | slot.prompt.tokens.clear(); |
| 3977 | } |
| 3978 | |
| 3979 | // check if we should process the image |
| 3980 | if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { |
| 3981 | // process the image |
| 3982 | size_t n_tokens_out = 0; |
| 3983 | int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); |
| 3984 | if (res != 0) { |
| 3985 | SLT_ERR(slot, "failed to process image, res = %d\n", res); |
| 3986 | send_error(slot, "failed to process image", ERROR_TYPE_SERVER); |
| 3987 | slot.release(); |
| 3988 | continue; |
| 3989 | } |
| 3990 | |
| 3991 | slot.n_prompt_tokens_processed += n_tokens_out; |
| 3992 | |
| 3993 | // add the image chunk to cache |
| 3994 | { |
| 3995 | const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens()); |
| 3996 | slot.prompt.tokens.push_back(chunk.get()); // copy |
| 3997 | } |
| 3998 | } |
| 3999 | |
| 4000 | // If using an alora, there may be uncached tokens that come |
| 4001 | // before the invocation sequence. When this happens, the |
| 4002 | // tokens before the invocation sequence need to be |
| 4003 | // processed without the adapter in a separate batch, then |
| 4004 | // the adapter needs to be enabled for the remaining tokens. |
| 4005 | if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { |
| 4006 | SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); |
| 4007 | const auto & enabled_loras = lora_get_enabled_ids(slot.lora); |
| 4008 | GGML_ASSERT(enabled_loras.size() == 1); |
| 4009 | alora_scale = slot.lora[enabled_loras[0]].scale; |
| 4010 | slot.lora[enabled_loras[0]].scale = 0.0f; |
| 4011 | alora_disabled_id = enabled_loras[0]; |
| 4012 | } |
| 4013 | |
| 4014 | bool do_checkpoint = params_base.n_ctx_checkpoints > 0; |
| 4015 | |
| 4016 | // make checkpoints only for completion tasks |
| 4017 | do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; |
| 4018 | |
| 4019 | // make a checkpoint of the parts of the memory that cannot be rolled back. |
| 4020 | // checkpoints are created only if: |
| 4021 | // - the model uses SWA and we are not using `swa_full` |
| 4022 | // - the model architecture is marked as recurrent or hybrid |
| 4023 | // |
| 4024 | // TODO: try to make this conditional on the context or the memory module, instead of the model type |
| 4025 | do_checkpoint = do_checkpoint && ( |
| 4026 | llama_model_is_recurrent(model) || |
| 4027 | llama_model_is_hybrid(model) || |
| 4028 | (llama_model_n_swa(model) > 0 && !params_base.swa_full) |
| 4029 | ); |
| 4030 | |
| 4031 | // add prompt tokens for processing in the current batch |
| 4032 | while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { |
| 4033 | // get next token to process |
| 4034 | llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; |
| 4035 | if (cur_tok == LLAMA_TOKEN_NULL) { |
| 4036 | break; // end of text chunk |
| 4037 | } |
| 4038 | |
| 4039 | // if this is an alora request with pre-invocation |
| 4040 | // tokens that are not cached, we need to stop filling |
| 4041 | // this batch at those pre-invocation tokens. |
| 4042 | if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) { |
| 4043 | SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); |
| 4044 | break; |
| 4045 | } |
| 4046 | |
| 4047 | // embedding requires all tokens in the batch to be output |
| 4048 | common_batch_add(batch, |
| 4049 | cur_tok, |
| 4050 | slot.prompt.tokens.pos_next(), |
| 4051 | { slot.id }, |
| 4052 | slot.need_embd()); |
| 4053 | slot.prompt.tokens.push_back(cur_tok); |
| 4054 | |
| 4055 | slot.n_prompt_tokens_processed++; |
| 4056 | |
| 4057 | // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. |
| 4058 | if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) { |
| 4059 | break; |
| 4060 | } |
| 4061 | } |
| 4062 | |
| 4063 | // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); |
| 4064 | |
| 4065 | SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); |
| 4066 | |
| 4067 | // entire prompt has been processed |
| 4068 | if (slot.prompt.n_tokens() == slot.task->n_tokens()) { |
| 4069 | slot.state = SLOT_STATE_DONE_PROMPT; |
| 4070 | |
| 4071 | GGML_ASSERT(batch.n_tokens > 0); |
| 4072 | |
| 4073 | common_sampler_reset(slot.smpl); |
| 4074 | |
| 4075 | // Process all prompt tokens through sampler system |
| 4076 | for (int i = 0; i < slot.task->n_tokens(); ++i) { |
| 4077 | llama_token id = input_tokens[i]; |
| 4078 | if (id != LLAMA_TOKEN_NULL) { |
| 4079 | common_sampler_accept(slot.smpl, id, false); |
| 4080 | } |
| 4081 | } |
| 4082 | |
| 4083 | // extract the logits only for the last token |
| 4084 | batch.logits[batch.n_tokens - 1] = true; |
| 4085 | |
| 4086 | slot.n_decoded = 0; |
| 4087 | slot.i_batch = batch.n_tokens - 1; |
| 4088 | |
| 4089 | SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); |
| 4090 | |
| 4091 | const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); |
| 4092 | const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); |
| 4093 | |
| 4094 | // no need for empty or small checkpoints |
| 4095 | do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); |
| 4096 | |
| 4097 | // no need to create checkpoints that are too close together |
| 4098 | do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); |
| 4099 | |
| 4100 | if (do_checkpoint) { |
| 4101 | while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { |
| 4102 | // make room for the new checkpoint, if needed |
| 4103 | const auto & cur = slot.prompt.checkpoints.front(); |
| 4104 | |
| 4105 | SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", |
| 4106 | cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); |
| 4107 | |
| 4108 | slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); |
| 4109 | } |
| 4110 | |
| 4111 | const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); |
| 4112 | |
| 4113 | auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ |
| 4114 | /*.pos_min = */ pos_min, |
| 4115 | /*.pos_max = */ pos_max, |
| 4116 | /*.data = */ std::vector<uint8_t>(checkpoint_size), |
| 4117 | }); |
| 4118 | |
| 4119 | llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); |
| 4120 | |
| 4121 | SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", |
| 4122 | (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); |
| 4123 | } |
| 4124 | } |
| 4125 | } |
| 4126 | |
| 4127 | if (batch.n_tokens >= n_batch) { |
| 4128 | break; |
| 4129 | } |
| 4130 | } |
| 4131 | } |
| 4132 | |
| 4133 | if (batch.n_tokens == 0) { |
| 4134 | SRV_WRN("%s", "no tokens to decode\n"); |
| 4135 | return; |
| 4136 | } |
| 4137 | |
| 4138 | SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); |
| 4139 | |
| 4140 | if (slot_batched) { |
| 4141 | // apply lora, only need to do it once per batch |
| 4142 | common_set_adapter_lora(ctx, slot_batched->lora); |
| 4143 | |
| 4144 | // if the lora is temporarily disabled for an alora, re-enable it |
| 4145 | // for next time |
| 4146 | if (alora_scale > 0.0f) { |
| 4147 | SRV_DBG("re-enabling alora with scale %f\n", alora_scale); |
| 4148 | slot_batched->lora[alora_disabled_id].scale = alora_scale; |
| 4149 | } |
| 4150 | |
| 4151 | llama_set_embeddings(ctx, slot_batched->need_embd()); |
| 4152 | } |
| 4153 | |
| 4154 | int32_t i_next = 0; |
| 4155 | |
| 4156 | // process the created batch of tokens |
| 4157 | for (int32_t i = 0; i < batch.n_tokens; i = i_next) { |
| 4158 | const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); |
| 4159 | |
| 4160 | llama_batch batch_view = { |
| 4161 | n_tokens, |
| 4162 | batch.token + i, |
| 4163 | nullptr, |
| 4164 | batch.pos + i, |
| 4165 | batch.n_seq_id + i, |
| 4166 | batch.seq_id + i, |
| 4167 | batch.logits + i, |
| 4168 | }; |
| 4169 | |
| 4170 | const int ret = llama_decode(ctx, batch_view); |
| 4171 | |
| 4172 | metrics.on_decoded(slots); |
| 4173 | |
| 4174 | if (ret != 0) { |
| 4175 | { |
| 4176 | std::string err; |
| 4177 | |
| 4178 | if (n_batch == 1 && ret == 1) { |
| 4179 | // TODO: try to terminate only the largest active slot/sequence and continue with the rest |
| 4180 | // need to remove the tokens from the current batch too |
| 4181 | err = "Context size has been exceeded."; |
| 4182 | } |
| 4183 | |
| 4184 | if (ret == -1) { |
| 4185 | err = "Invalid input batch."; |
| 4186 | } |
| 4187 | |
| 4188 | if (ret < -1) { |
| 4189 | // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() |
| 4190 | err = "Compute error."; |
| 4191 | } |
| 4192 | |
| 4193 | // TODO: handle ret == 2 (abort) when we start aborting |
| 4194 | |
| 4195 | if (!err.empty()) { |
| 4196 | SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); |
| 4197 | |
| 4198 | for (auto & slot : slots) { |
| 4199 | if (slot.is_processing()) { |
| 4200 | send_error(slot, err); |
| 4201 | slot.release(); |
| 4202 | } |
| 4203 | } |
| 4204 | |
| 4205 | break; |
| 4206 | } |
| 4207 | } |
| 4208 | |
| 4209 | // retry with half the batch size to try to find a free slot in the KV cache |
| 4210 | if (!try_purge_idle_slots()) { |
| 4211 | n_batch /= 2; |
| 4212 | } |
| 4213 | |
| 4214 | SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n" , i, n_batch, ret); |
| 4215 | |
| 4216 | continue; // continue loop of n_batch |
| 4217 | } |
| 4218 | |
| 4219 | // move the head of the batch forward with the number of tokens we just processed |
| 4220 | i_next = i + n_tokens; |
| 4221 | |
| 4222 | // on successful decode, restore the original batch size |
| 4223 | n_batch = llama_n_batch(ctx); |
| 4224 | |
| 4225 | for (auto & slot : slots) { |
| 4226 | // optionally send prompt processing progress |
| 4227 | if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { |
| 4228 | if (slot.task->params.stream && slot.task->params.return_progress) { |
| 4229 | send_partial_response(slot, /*tkn=*/ {}, /*is_progress=*/ true); |
| 4230 | } |
| 4231 | } |
| 4232 | |
| 4233 | if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { |
| 4234 | continue; // continue loop of slots |
| 4235 | } |
| 4236 | |
| 4237 | if (slot.state == SLOT_STATE_DONE_PROMPT) { |
| 4238 | if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { |
| 4239 | // prompt evaluated for embedding |
| 4240 | send_embedding(slot, batch_view); |
| 4241 | slot.release(); |
| 4242 | slot.i_batch = -1; |
| 4243 | continue; // continue loop of slots |
| 4244 | } |
| 4245 | |
| 4246 | if (slot.task->type == SERVER_TASK_TYPE_RERANK) { |
| 4247 | send_rerank(slot, batch_view); |
| 4248 | slot.release(); |
| 4249 | slot.i_batch = -1; |
| 4250 | continue; // continue loop of slots |
| 4251 | } |
| 4252 | |
| 4253 | // prompt evaluated for next-token prediction |
| 4254 | slot.state = SLOT_STATE_GENERATING; |
| 4255 | } else if (slot.state != SLOT_STATE_GENERATING) { |
| 4256 | continue; // continue loop of slots |
| 4257 | } |
| 4258 | |
| 4259 | const int tok_idx = slot.i_batch - i; |
| 4260 | |
| 4261 | llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); |
| 4262 | |
| 4263 | slot.i_batch = -1; |
| 4264 | |
| 4265 | common_sampler_accept(slot.smpl, id, /*accept_grammar=*/ true); |
| 4266 | |
| 4267 | slot.n_decoded += 1; |
| 4268 | |
| 4269 | const int64_t t_current = ggml_time_us(); |
| 4270 | |
| 4271 | if (slot.n_decoded == 1) { |
| 4272 | slot.t_start_generation = t_current; |
| 4273 | slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; |
| 4274 | metrics.on_prompt_eval(slot); |
| 4275 | } |
| 4276 | |
| 4277 | slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3; |
| 4278 | |
| 4279 | completion_token_output result; |
| 4280 | result.tok = id; |
| 4281 | result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); |
| 4282 | result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs |
| 4283 | |
| 4284 | if (slot.task->params.sampling.n_probs > 0) { |
| 4285 | populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); |
| 4286 | } |
| 4287 | |
| 4288 | if (!process_token(result, slot)) { |
| 4289 | // release slot because of stop condition |
| 4290 | slot.print_timings(); |
| 4291 | send_final_response(slot); |
| 4292 | metrics.on_prediction(slot); |
| 4293 | slot.release(); |
| 4294 | |
| 4295 | continue; |
| 4296 | } |
| 4297 | } |
| 4298 | |
| 4299 | // do speculative decoding |
| 4300 | // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] |
| 4301 | // perform the speculative drafting for all sequences at the same time in a single batch |
| 4302 | for (auto & slot : slots) { |
| 4303 | if (!slot.is_processing() || !slot.can_speculate()) { |
| 4304 | continue; |
| 4305 | } |
| 4306 | |
| 4307 | if (slot.state != SLOT_STATE_GENERATING) { |
| 4308 | continue; |
| 4309 | } |
| 4310 | |
| 4311 | if (mctx) { |
| 4312 | // we should never reach this, as speculative is automatically disabled if mmproj is loaded |
| 4313 | GGML_ABORT("not supported by multimodal" ); |
| 4314 | } |
| 4315 | |
| 4316 | // determine the max draft that fits the current slot state |
| 4317 | int n_draft_max = slot.task->params.speculative.n_max; |
| 4318 | |
| 4319 | // note: slot.prompt is not yet expanded with the `id` token sampled above |
| 4320 | // also, need to leave space for 1 extra token to allow context shifts |
| 4321 | n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2); |
| 4322 | |
| 4323 | if (slot.n_remaining > 0) { |
| 4324 | n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); |
| 4325 | } |
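| | // illustrative numbers (hypothetical): with n_ctx = 4096, prompt.n_tokens() = 1000, speculative.n_max = 16 |
| | // and n_remaining = 8, this gives n_draft_max = min(16, 4096 - 1000 - 2) = 16, then min(16, 8 - 1) = 7 |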
| 4326 | |
| 4327 | SLT_DBG(slot, "max possible draft: %d\n" , n_draft_max); |
| 4328 | |
| 4329 | if (n_draft_max < slot.task->params.speculative.n_min) { |
| 4330 | SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n" , n_draft_max, slot.task->params.speculative.n_min); |
| 4331 | |
| 4332 | continue; |
| 4333 | } |
| 4334 | |
| 4335 | llama_token id = slot.sampled; |
| 4336 | |
| 4337 | struct common_speculative_params params_spec; |
| 4338 | params_spec.n_draft = n_draft_max; |
| 4339 | params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; |
| 4340 | params_spec.p_min = slot.task->params.speculative.p_min; |
| 4341 | |
| 4342 | const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); |
| 4343 | llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); |
| 4344 | |
| 4345 | // ignore small drafts |
| 4346 | if (slot.task->params.speculative.n_min > (int) draft.size()) { |
| 4347 | SLT_DBG(slot, "ignoring small draft: %d < %d\n" , (int) draft.size(), slot.task->params.speculative.n_min); |
| 4348 | |
| 4349 | continue; |
| 4350 | } |
| 4351 | |
| 4352 | // keep track of total number of drafted tokens tested |
| 4353 | slot.n_draft_total += draft.size(); |
| 4354 | |
| 4355 | // construct the speculation batch |
| 4356 | common_batch_clear(slot.batch_spec); |
| 4357 | common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true); |
| 4358 | |
| 4359 | for (size_t i = 0; i < draft.size(); ++i) { |
| 4360 | common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true); |
| 4361 | } |
| 4362 | |
| 4363 | SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); |
| 4364 | |
| 4365 | llama_decode(ctx, slot.batch_spec); |
| 4366 | |
| 4367 | // the accepted tokens from the speculation |
| 4368 | const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); |
| 4369 | |
| 4370 | slot.n_decoded += ids.size(); |
| 4371 | |
| 4372 | // update how many tokens out of those tested were accepted |
| | // (the last entry of `ids` is the token sampled after the accepted part of the draft, so it does not count as a drafted token) |
| 4373 | slot.n_draft_accepted += ids.size() - 1; |
| 4374 | |
| 4375 | slot.prompt.tokens.push_back(id); |
| 4376 | slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); |
| 4377 | |
| 4378 | llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); |
| 4379 | |
| 4380 | for (size_t i = 0; i < ids.size(); ++i) { |
| 4381 | completion_token_output result; |
| 4382 | |
| 4383 | result.tok = ids[i]; |
| 4384 | result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); |
| 4385 | result.prob = 1.0f; // set later |
| 4386 | |
| 4387 | // TODO: set result.probs |
| 4388 | |
| 4389 | if (!process_token(result, slot)) { |
| 4390 | slot.print_timings(); |
| 4391 | send_final_response(slot); |
| 4392 | metrics.on_prediction(slot); |
| 4393 | slot.release(); |
| 4394 | |
| 4395 | break; |
| 4396 | } |
| 4397 | } |
| 4398 | |
| 4399 | SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n" , (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens()); |
| 4400 | } |
| 4401 | } |
| 4402 | |
| 4403 | SRV_DBG("%s" , "run slots completed\n" ); |
| 4404 | } |
| 4405 | |
| 4406 | json model_meta() const { |
| 4407 | return json { |
| 4408 | {"vocab_type" , llama_vocab_type (vocab)}, |
| 4409 | {"n_vocab" , llama_vocab_n_tokens (vocab)}, |
| 4410 | {"n_ctx_train" , llama_model_n_ctx_train(model)}, |
| 4411 | {"n_embd" , llama_model_n_embd (model)}, |
| 4412 | {"n_params" , llama_model_n_params (model)}, |
| 4413 | {"size" , llama_model_size (model)}, |
| 4414 | }; |
| 4415 | } |
| 4416 | }; |
| 4417 | |
| 4418 | static void log_server_request(const httplib::Request & req, const httplib::Response & res) { |
| 4419 | // skip GH copilot requests when using default port |
| 4420 | if (req.path == "/v1/health" ) { |
| 4421 | return; |
| 4422 | } |
| 4423 | |
| 4424 | // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch |
| 4425 | |
| 4426 | SRV_INF("request: %s %s %s %d\n" , req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); |
| 4427 | |
| 4428 | SRV_DBG("request: %s\n" , req.body.c_str()); |
| 4429 | SRV_DBG("response: %s\n" , res.body.c_str()); |
| 4430 | } |
| 4431 | |
| 4432 | std::function<void(int)> shutdown_handler; |
| 4433 | std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; |
| 4434 | |
| 4435 | inline void signal_handler(int signal) { |
| 4436 | if (is_terminating.test_and_set()) { |
| 4437 | // in case it hangs, we can force terminate the server by hitting Ctrl+C twice |
| 4438 | // this is for better developer experience, we can remove when the server is stable enough |
| 4439 | fprintf(stderr, "Received second interrupt, terminating immediately.\n"); |
| 4440 | exit(1); |
| 4441 | } |
| 4442 | |
| 4443 | shutdown_handler(signal); |
| 4444 | } |
| 4445 | |
| 4446 | int main(int argc, char ** argv) { |
| 4447 | // own arguments required by this example |
| 4448 | common_params params; |
| 4449 | |
| 4450 | if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { |
| 4451 | return 1; |
| 4452 | } |
| 4453 | |
| 4454 | // TODO: should we have a separate n_parallel parameter for the server? |
| 4455 | // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 |
| 4456 | // TODO: this is a common configuration that is suitable for most local use cases |
| 4457 | // however, overriding the parameters is a bit confusing - figure out something more intuitive |
| 4458 | if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { |
| 4459 | LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n" , __func__); |
| 4460 | |
| 4461 | params.n_parallel = 4; |
| 4462 | params.kv_unified = true; |
| 4463 | } |
| 4464 | |
| 4465 | common_init(); |
| 4466 | |
| 4467 | // struct that contains llama context and inference |
| 4468 | server_context ctx_server; |
| 4469 | |
| 4470 | llama_backend_init(); |
| 4471 | llama_numa_init(params.numa); |
| 4472 | |
| 4473 | LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n" , params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); |
| 4474 | LOG_INF("\n" ); |
| 4475 | LOG_INF("%s\n" , common_params_get_system_info(params).c_str()); |
| 4476 | LOG_INF("\n" ); |
| 4477 | |
| 4478 | std::unique_ptr<httplib::Server> svr; |
| 4479 | #ifdef CPPHTTPLIB_OPENSSL_SUPPORT |
| 4480 | if (params.ssl_file_key != "" && params.ssl_file_cert != "" ) { |
| 4481 | LOG_INF("Running with SSL: key = %s, cert = %s\n" , params.ssl_file_key.c_str(), params.ssl_file_cert.c_str()); |
| 4482 | svr.reset( |
| 4483 | new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) |
| 4484 | ); |
| 4485 | } else { |
| 4486 | LOG_INF("Running without SSL\n" ); |
| 4487 | svr.reset(new httplib::Server()); |
| 4488 | } |
| 4489 | #else |
| 4490 | if (params.ssl_file_key != "" && params.ssl_file_cert != "" ) { |
| 4491 | LOG_ERR("Server is built without SSL support\n" ); |
| 4492 | return 1; |
| 4493 | } |
| 4494 | svr.reset(new httplib::Server()); |
| 4495 | #endif |
| 4496 | |
| 4497 | std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL}; |
| 4498 | |
| 4499 | svr->set_default_headers({{"Server" , "llama.cpp" }}); |
| 4500 | svr->set_logger(log_server_request); |
| 4501 | |
| 4502 | auto res_error = [](httplib::Response & res, const json & error_data) { |
| 4503 | json final_response {{"error" , error_data}}; |
| 4504 | res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); |
| 4505 | res.status = json_value(error_data, "code", 500); |
| 4506 | }; |
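| | // for reference: res_error wraps the payload as {"error": <error_data>} and uses error_data["code"] |
| | // (defaulting to 500) as the HTTP status, so callers only need to pass a formatted error object |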
| 4507 | |
| 4508 | auto res_ok = [](httplib::Response & res, const json & data) { |
| 4509 | res.set_content(safe_json_to_str(data), MIMETYPE_JSON); |
| 4510 | res.status = 200; |
| 4511 | }; |
| 4512 | |
| 4513 | svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { |
| 4514 | std::string message; |
| 4515 | try { |
| 4516 | std::rethrow_exception(ep); |
| 4517 | } catch (const std::exception & e) { |
| 4518 | message = e.what(); |
| 4519 | } catch (...) { |
| 4520 | message = "Unknown Exception" ; |
| 4521 | } |
| 4522 | |
| 4523 | try { |
| 4524 | json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); |
| 4525 | LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); |
| 4526 | res_error(res, formatted_error); |
| 4527 | } catch (const std::exception & e) { |
| 4528 | LOG_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str()); |
| 4529 | } |
| 4530 | }); |
| 4531 | |
| 4532 | svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) { |
| 4533 | if (res.status == 404) { |
| 4534 | res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); |
| 4535 | } |
| 4536 | // for other error codes, we skip processing here because it's already done by res_error() |
| 4537 | }); |
| 4538 | |
| 4539 | // set timeouts and change hostname and port |
| 4540 | svr->set_read_timeout (params.timeout_read); |
| 4541 | svr->set_write_timeout(params.timeout_write); |
| 4542 | |
| 4543 | std::unordered_map<std::string, std::string> log_data; |
| 4544 | |
| 4545 | log_data["hostname" ] = params.hostname; |
| 4546 | log_data["port" ] = std::to_string(val: params.port); |
| 4547 | |
| 4548 | if (params.api_keys.size() == 1) { |
| 4549 | auto key = params.api_keys[0]; |
| 4550 | log_data["api_key" ] = "api_key: ****" + key.substr(pos: std::max(a: (int)(key.length() - 4), b: 0)); |
| 4551 | } else if (params.api_keys.size() > 1) { |
| 4552 | log_data["api_key" ] = "api_key: " + std::to_string(val: params.api_keys.size()) + " keys loaded" ; |
| 4553 | } |
| 4554 | |
| 4555 | // Necessary similarity of prompt for slot selection |
| 4556 | ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; |
| 4557 | |
| 4558 | // |
| 4559 | // Middlewares |
| 4560 | // |
| 4561 | |
| 4562 | auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { |
| 4563 | static const std::unordered_set<std::string> public_endpoints = { |
| 4564 | "/health" , |
| 4565 | "/v1/health" , |
| 4566 | "/models" , |
| 4567 | "/v1/models" , |
| 4568 | "/api/tags" |
| 4569 | }; |
| 4570 | |
| 4571 | // If API key is not set, skip validation |
| 4572 | if (params.api_keys.empty()) { |
| 4573 | return true; |
| 4574 | } |
| 4575 | |
| 4576 | // If path is public or is static file, skip validation |
| 4577 | if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { |
| 4578 | return true; |
| 4579 | } |
| 4580 | |
| 4581 | // Check for API key in the header |
| 4582 | auto auth_header = req.get_header_value("Authorization"); |
| 4583 | |
| 4584 | std::string prefix = "Bearer "; |
| 4585 | if (auth_header.substr(0, prefix.size()) == prefix) { |
| 4586 | std::string received_api_key = auth_header.substr(prefix.size()); |
| 4587 | if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) { |
| 4588 | return true; // API key is valid |
| 4589 | } |
| 4590 | } |
| 4591 | |
| 4592 | // API key is invalid or not provided |
| 4593 | res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); |
| 4594 | |
| 4595 | LOG_WRN("Unauthorized: Invalid API Key\n" ); |
| 4596 | |
| 4597 | return false; |
| 4598 | }; |
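| | // for illustration (hypothetical key value): protected endpoints expect a header of the form |
| | // "Authorization: Bearer <api-key>" - requests to the public endpoints listed above skip this check |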
| 4599 | |
| 4600 | auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { |
| 4601 | server_state current_state = state.load(); |
| 4602 | if (current_state == SERVER_STATE_LOADING_MODEL) { |
| 4603 | auto tmp = string_split<std::string>(req.path, '.'); |
| 4604 | if (req.path == "/" || tmp.back() == "html") { |
| 4605 | res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8"); |
| 4606 | res.status = 503; |
| 4607 | } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { |
| 4608 | // allow the models endpoint to be accessed during loading |
| 4609 | return true; |
| 4610 | } else { |
| 4611 | res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); |
| 4612 | } |
| 4613 | return false; |
| 4614 | } |
| 4615 | return true; |
| 4616 | }; |
| 4617 | |
| 4618 | // register server middlewares |
| 4619 | svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) { |
| 4620 | res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); |
| 4621 | // If this is an OPTIONS request, skip validation because browsers don't include the Authorization header |
| 4622 | if (req.method == "OPTIONS") { |
| 4623 | res.set_header("Access-Control-Allow-Credentials", "true"); |
| 4624 | res.set_header("Access-Control-Allow-Methods", "GET, POST"); |
| 4625 | res.set_header("Access-Control-Allow-Headers", "*"); |
| 4626 | res.set_content("", "text/html"); // blank response, no data |
| 4627 | return httplib::Server::HandlerResponse::Handled; // skip further processing |
| 4628 | } |
| 4629 | if (!middleware_server_state(req, res)) { |
| 4630 | return httplib::Server::HandlerResponse::Handled; |
| 4631 | } |
| 4632 | if (!middleware_validate_api_key(req, res)) { |
| 4633 | return httplib::Server::HandlerResponse::Handled; |
| 4634 | } |
| 4635 | return httplib::Server::HandlerResponse::Unhandled; |
| 4636 | }); |
| 4637 | |
| 4638 | // |
| 4639 | // Route handlers (or controllers) |
| 4640 | // |
| 4641 | |
| 4642 | const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { |
| 4643 | // error and loading states are handled by middleware |
| 4644 | json health = {{"status" , "ok" }}; |
| 4645 | res_ok(res, health); |
| 4646 | }; |
| 4647 | |
| 4648 | const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { |
| 4649 | if (!params.endpoint_slots) { |
| 4650 | res_error(res, format_error_response(message: "This server does not support slots endpoint. Start it with `--slots`" , type: ERROR_TYPE_NOT_SUPPORTED)); |
| 4651 | return; |
| 4652 | } |
| 4653 | |
| 4654 | // request slots data using task queue |
| 4655 | int task_id = ctx_server.queue_tasks.get_new_id(); |
| 4656 | { |
| 4657 | server_task task(SERVER_TASK_TYPE_METRICS); |
| 4658 | task.id = task_id; |
| 4659 | ctx_server.queue_results.add_waiting_task_id(id_task: task_id); |
| 4660 | ctx_server.queue_tasks.post(task: std::move(task), front: true); // high-priority task |
| 4661 | } |
| 4662 | |
| 4663 | // get the result |
| 4664 | server_task_result_ptr result = ctx_server.queue_results.recv(id_task: task_id); |
| 4665 | ctx_server.queue_results.remove_waiting_task_id(id_task: task_id); |
| 4666 | |
| 4667 | if (result->is_error()) { |
| 4668 | res_error(res, result->to_json()); |
| 4669 | return; |
| 4670 | } |
| 4671 | |
| 4672 | // TODO: get rid of this dynamic_cast |
| 4673 | auto res_task = dynamic_cast<server_task_result_metrics*>(result.get()); |
| 4674 | GGML_ASSERT(res_task != nullptr); |
| 4675 | |
| 4676 | // optionally return "fail_on_no_slot" error |
| 4677 | if (req.has_param(key: "fail_on_no_slot" )) { |
| 4678 | if (res_task->n_idle_slots == 0) { |
| 4679 | res_error(res, format_error_response(message: "no slot available" , type: ERROR_TYPE_UNAVAILABLE)); |
| 4680 | return; |
| 4681 | } |
| 4682 | } |
| 4683 | |
| 4684 | res_ok(res, res_task->slots_data); |
| 4685 | }; |
| 4686 | |
| 4687 | const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { |
| 4688 | if (!params.endpoint_metrics) { |
| 4689 | res_error(res, format_error_response(message: "This server does not support metrics endpoint. Start it with `--metrics`" , type: ERROR_TYPE_NOT_SUPPORTED)); |
| 4690 | return; |
| 4691 | } |
| 4692 | |
| 4693 | // request slots data using task queue |
| 4694 | int task_id = ctx_server.queue_tasks.get_new_id(); |
| 4695 | { |
| 4696 | server_task task(SERVER_TASK_TYPE_METRICS); |
| 4697 | task.id = task_id; |
| 4698 | ctx_server.queue_results.add_waiting_task_id(id_task: task_id); |
| 4699 | ctx_server.queue_tasks.post(task: std::move(task), front: true); // high-priority task |
| 4700 | } |
| 4701 | |
| 4702 | // get the result |
| 4703 | server_task_result_ptr result = ctx_server.queue_results.recv(id_task: task_id); |
| 4704 | ctx_server.queue_results.remove_waiting_task_id(id_task: task_id); |
| 4705 | |
| 4706 | if (result->is_error()) { |
| 4707 | res_error(res, result->to_json()); |
| 4708 | return; |
| 4709 | } |
| 4710 | |
| 4711 | // TODO: get rid of this dynamic_cast |
| 4712 | auto res_task = dynamic_cast<server_task_result_metrics*>(result.get()); |
| 4713 | GGML_ASSERT(res_task != nullptr); |
| 4714 | |
| 4715 | // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names |
| 4716 | json all_metrics_def = json { |
| 4717 | {"counter" , {{ |
| 4718 | {"name" , "prompt_tokens_total" }, |
| 4719 | {"help" , "Number of prompt tokens processed." }, |
| 4720 | {"value" , (uint64_t) res_task->n_prompt_tokens_processed_total} |
| 4721 | }, { |
| 4722 | {"name" , "prompt_seconds_total" }, |
| 4723 | {"help" , "Prompt process time" }, |
| 4724 | {"value" , (uint64_t) res_task->t_prompt_processing_total / 1.e3} |
| 4725 | }, { |
| 4726 | {"name" , "tokens_predicted_total" }, |
| 4727 | {"help" , "Number of generation tokens processed." }, |
| 4728 | {"value" , (uint64_t) res_task->n_tokens_predicted_total} |
| 4729 | }, { |
| 4730 | {"name" , "tokens_predicted_seconds_total" }, |
| 4731 | {"help" , "Predict process time" }, |
| 4732 | {"value" , (uint64_t) res_task->t_tokens_generation_total / 1.e3} |
| 4733 | }, { |
| 4734 | {"name" , "n_decode_total" }, |
| 4735 | {"help" , "Total number of llama_decode() calls" }, |
| 4736 | {"value" , res_task->n_decode_total} |
| 4737 | }, { |
| 4738 | {"name" , "n_tokens_max" }, |
| 4739 | {"help" , "Largest observed n_tokens." }, |
| 4740 | {"value" , res_task->n_tokens_max} |
| 4741 | }, { |
| 4742 | {"name" , "n_busy_slots_per_decode" }, |
| 4743 | {"help" , "Average number of busy slots per llama_decode() call" }, |
| 4744 | {"value" , (float) res_task->n_busy_slots_total / std::max(a: (float) res_task->n_decode_total, b: 1.f)} |
| 4745 | }}}, |
| 4746 | {"gauge" , {{ |
| 4747 | {"name" , "prompt_tokens_seconds" }, |
| 4748 | {"help" , "Average prompt throughput in tokens/s." }, |
| 4749 | {"value" , res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} |
| 4750 | },{ |
| 4751 | {"name" , "predicted_tokens_seconds" }, |
| 4752 | {"help" , "Average generation throughput in tokens/s." }, |
| 4753 | {"value" , res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} |
| 4754 | },{ |
| 4755 | {"name" , "requests_processing" }, |
| 4756 | {"help" , "Number of requests processing." }, |
| 4757 | {"value" , (uint64_t) res_task->n_processing_slots} |
| 4758 | },{ |
| 4759 | {"name" , "requests_deferred" }, |
| 4760 | {"help" , "Number of requests deferred." }, |
| 4761 | {"value" , (uint64_t) res_task->n_tasks_deferred} |
| 4762 | }}} |
| 4763 | }; |
| 4764 | |
| 4765 | std::stringstream prometheus; |
| 4766 | |
| 4767 | for (const auto & el : all_metrics_def.items()) { |
| 4768 | const auto & type = el.key(); |
| 4769 | const auto & metrics_def = el.value(); |
| 4770 | |
| 4771 | for (const auto & metric_def : metrics_def) { |
| 4772 | const std::string name = metric_def.at("name"); |
| 4773 | const std::string help = metric_def.at("help"); |
| 4774 | |
| 4775 | auto value = json_value(metric_def, "value", 0.); |
| 4776 | prometheus << "# HELP llamacpp:" << name << " " << help << "\n" |
| 4777 | << "# TYPE llamacpp:" << name << " " << type << "\n" |
| 4778 | << "llamacpp:" << name << " " << value << "\n"; |
| 4779 | } |
| 4780 | } |
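| | // for illustration, the loop above emits standard Prometheus exposition text, e.g. (value is hypothetical): |
| | //   # HELP llamacpp:prompt_tokens_total Number of prompt tokens processed. |
| | //   # TYPE llamacpp:prompt_tokens_total counter |
| | //   llamacpp:prompt_tokens_total 1024 |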
| 4781 | |
| 4782 | res.set_header(key: "Process-Start-Time-Unix" , val: std::to_string(val: res_task->t_start)); |
| 4783 | |
| 4784 | res.set_content(s: prometheus.str(), content_type: "text/plain; version=0.0.4" ); |
| 4785 | res.status = 200; // HTTP OK |
| 4786 | }; |
| 4787 | |
| 4788 | const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { |
| 4789 | json request_data = json::parse(i: req.body); |
| 4790 | std::string filename = request_data.at(key: "filename" ); |
| 4791 | if (!fs_validate_filename(filename)) { |
| 4792 | res_error(res, format_error_response(message: "Invalid filename" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 4793 | return; |
| 4794 | } |
| 4795 | std::string filepath = params.slot_save_path + filename; |
| 4796 | |
| 4797 | int task_id = ctx_server.queue_tasks.get_new_id(); |
| 4798 | { |
| 4799 | server_task task(SERVER_TASK_TYPE_SLOT_SAVE); |
| 4800 | task.id = task_id; |
| 4801 | task.slot_action.slot_id = id_slot; |
| 4802 | task.slot_action.filename = filename; |
| 4803 | task.slot_action.filepath = filepath; |
| 4804 | |
| 4805 | ctx_server.queue_results.add_waiting_task_id(id_task: task_id); |
| 4806 | ctx_server.queue_tasks.post(task: std::move(task)); |
| 4807 | } |
| 4808 | |
| 4809 | server_task_result_ptr result = ctx_server.queue_results.recv(id_task: task_id); |
| 4810 | ctx_server.queue_results.remove_waiting_task_id(id_task: task_id); |
| 4811 | |
| 4812 | if (result->is_error()) { |
| 4813 | res_error(res, result->to_json()); |
| 4814 | return; |
| 4815 | } |
| 4816 | |
| 4817 | res_ok(res, result->to_json()); |
| 4818 | }; |
| 4819 | |
| 4820 | const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { |
| 4821 | json request_data = json::parse(i: req.body); |
| 4822 | std::string filename = request_data.at(key: "filename" ); |
| 4823 | if (!fs_validate_filename(filename)) { |
| 4824 | res_error(res, format_error_response(message: "Invalid filename" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 4825 | return; |
| 4826 | } |
| 4827 | std::string filepath = params.slot_save_path + filename; |
| 4828 | |
| 4829 | int task_id = ctx_server.queue_tasks.get_new_id(); |
| 4830 | { |
| 4831 | server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); |
| 4832 | task.id = task_id; |
| 4833 | task.slot_action.slot_id = id_slot; |
| 4834 | task.slot_action.filename = filename; |
| 4835 | task.slot_action.filepath = filepath; |
| 4836 | |
| 4837 | ctx_server.queue_results.add_waiting_task_id(id_task: task_id); |
| 4838 | ctx_server.queue_tasks.post(task: std::move(task)); |
| 4839 | } |
| 4840 | |
| 4841 | server_task_result_ptr result = ctx_server.queue_results.recv(id_task: task_id); |
| 4842 | ctx_server.queue_results.remove_waiting_task_id(id_task: task_id); |
| 4843 | |
| 4844 | if (result->is_error()) { |
| 4845 | res_error(res, result->to_json()); |
| 4846 | return; |
| 4847 | } |
| 4848 | |
| 4849 | GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr); |
| 4850 | res_ok(res, result->to_json()); |
| 4851 | }; |
| 4852 | |
| 4853 | const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { |
| 4854 | int task_id = ctx_server.queue_tasks.get_new_id(); |
| 4855 | { |
| 4856 | server_task task(SERVER_TASK_TYPE_SLOT_ERASE); |
| 4857 | task.id = task_id; |
| 4858 | task.slot_action.slot_id = id_slot; |
| 4859 | |
| 4860 | ctx_server.queue_results.add_waiting_task_id(id_task: task_id); |
| 4861 | ctx_server.queue_tasks.post(task: std::move(task)); |
| 4862 | } |
| 4863 | |
| 4864 | server_task_result_ptr result = ctx_server.queue_results.recv(id_task: task_id); |
| 4865 | ctx_server.queue_results.remove_waiting_task_id(id_task: task_id); |
| 4866 | |
| 4867 | if (result->is_error()) { |
| 4868 | res_error(res, result->to_json()); |
| 4869 | return; |
| 4870 | } |
| 4871 | |
| 4872 | GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr); |
| 4873 | res_ok(res, result->to_json()); |
| 4874 | }; |
| 4875 | |
| 4876 | const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { |
| 4877 | if (params.slot_save_path.empty()) { |
| 4878 | res_error(res, format_error_response(message: "This server does not support slots action. Start it with `--slot-save-path`" , type: ERROR_TYPE_NOT_SUPPORTED)); |
| 4879 | return; |
| 4880 | } |
| 4881 | |
| 4882 | std::string id_slot_str = req.path_params.at(k: "id_slot" ); |
| 4883 | int id_slot; |
| 4884 | |
| 4885 | try { |
| 4886 | id_slot = std::stoi(str: id_slot_str); |
| 4887 | } catch (const std::exception &) { |
| 4888 | res_error(res, format_error_response(message: "Invalid slot ID" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 4889 | return; |
| 4890 | } |
| 4891 | |
| 4892 | std::string action = req.get_param_value(key: "action" ); |
| 4893 | |
| 4894 | if (action == "save" ) { |
| 4895 | handle_slots_save(req, res, id_slot); |
| 4896 | } else if (action == "restore" ) { |
| 4897 | handle_slots_restore(req, res, id_slot); |
| 4898 | } else if (action == "erase" ) { |
| 4899 | handle_slots_erase(req, res, id_slot); |
| 4900 | } else { |
| 4901 | res_error(res, format_error_response(message: "Invalid action" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 4902 | } |
| 4903 | }; |
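| | // illustrative usage (the exact route path is registered elsewhere in this file), e.g. a request like |
| | //   POST /slots/0?action=save   with body {"filename": "slot-0.bin"} |
| | // saves slot 0 to params.slot_save_path + filename; action=restore and action=erase follow the same |
| | // pattern, with erase not requiring a body |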
| 4904 | |
| 4905 | const auto handle_props = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { |
| 4906 | json default_generation_settings_for_props; |
| 4907 | |
| 4908 | { |
| 4909 | slot_params params; |
| 4910 | |
| 4911 | params.sampling = ctx_server.params_base.sampling; |
| 4912 | |
| 4913 | default_generation_settings_for_props = json { |
| 4914 | {"params" , params.to_json(only_metrics: true)}, |
| 4915 | {"n_ctx" , ctx_server.slots[0].n_ctx}, |
| 4916 | }; |
| 4917 | } |
| 4918 | |
| 4919 | // this endpoint is publicly available, please only return what is safe to be exposed |
| 4920 | json data = { |
| 4921 | { "default_generation_settings" , default_generation_settings_for_props }, |
| 4922 | { "total_slots" , ctx_server.params_base.n_parallel }, |
| 4923 | { "model_alias" , ctx_server.params_base.model_alias }, |
| 4924 | { "model_path" , ctx_server.params_base.model.path }, |
| 4925 | { "modalities" , json { |
| 4926 | {"vision" , ctx_server.oai_parser_opt.allow_image}, |
| 4927 | {"audio" , ctx_server.oai_parser_opt.allow_audio}, |
| 4928 | } }, |
| 4929 | { "endpoint_slots" , params.endpoint_slots }, |
| 4930 | { "endpoint_props" , params.endpoint_props }, |
| 4931 | { "endpoint_metrics" , params.endpoint_metrics }, |
| 4932 | { "webui" , params.webui }, |
| 4933 | { "chat_template" , common_chat_templates_source(tmpls: ctx_server.chat_templates.get()) }, |
| 4934 | { "bos_token" , common_token_to_piece(ctx: ctx_server.ctx, token: llama_vocab_bos(vocab: ctx_server.vocab), /* special= */ true)}, |
| 4935 | { "eos_token" , common_token_to_piece(ctx: ctx_server.ctx, token: llama_vocab_eos(vocab: ctx_server.vocab), /* special= */ true)}, |
| 4936 | { "build_info" , build_info }, |
| 4937 | }; |
| 4938 | if (ctx_server.params_base.use_jinja) { |
| 4939 | if (auto tool_use_src = common_chat_templates_source(tmpls: ctx_server.chat_templates.get(), variant: "tool_use" )) { |
| 4940 | data["chat_template_tool_use" ] = tool_use_src; |
| 4941 | } |
| 4942 | } |
| 4943 | |
| 4944 | res_ok(res, data); |
| 4945 | }; |
| 4946 | |
| 4947 | const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { |
| 4948 | if (!ctx_server.params_base.endpoint_props) { |
| 4949 | res_error(res, format_error_response(message: "This server does not support changing global properties. Start it with `--props`" , type: ERROR_TYPE_NOT_SUPPORTED)); |
| 4950 | return; |
| 4951 | } |
| 4952 | |
| 4953 | json data = json::parse(i: req.body); |
| 4954 | |
| 4955 | // update any props here |
| 4956 | |
| 4957 | res_ok(res, {{ "success" , true }}); |
| 4958 | }; |
| 4959 | |
| 4960 | const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { |
| 4961 | bool has_mtmd = ctx_server.mctx != nullptr; |
| 4962 | json data = { |
| 4963 | { |
| 4964 | "template" , common_chat_templates_source(tmpls: ctx_server.chat_templates.get()), |
| 4965 | }, |
| 4966 | { |
| 4967 | "model_info" , { |
| 4968 | { "llama.context_length" , ctx_server.slots.back().n_ctx, }, |
| 4969 | } |
| 4970 | }, |
| 4971 | {"modelfile" , "" }, |
| 4972 | {"parameters" , "" }, |
| 4973 | {"template" , common_chat_templates_source(tmpls: ctx_server.chat_templates.get())}, |
| 4974 | {"details" , { |
| 4975 | {"parent_model" , "" }, |
| 4976 | {"format" , "gguf" }, |
| 4977 | {"family" , "" }, |
| 4978 | {"families" , {"" }}, |
| 4979 | {"parameter_size" , "" }, |
| 4980 | {"quantization_level" , "" } |
| 4981 | }}, |
| 4982 | {"model_info" , "" }, |
| 4983 | {"capabilities" , has_mtmd ? json({"completion" ,"multimodal" }) : json({"completion" })} |
| 4984 | }; |
| 4985 | |
| 4986 | res_ok(res, data); |
| 4987 | }; |
| 4988 | |
| 4989 | // handle completion-like requests (completion, chat, infill) |
| 4990 | // we can optionally provide a custom format for partial results and final results |
| 4991 | const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( |
| 4992 | server_task_type type, |
| 4993 | json & data, |
| 4994 | const std::vector<raw_buffer> & files, |
| 4995 | const std::function<bool()> & is_connection_closed, |
| 4996 | httplib::Response & res, |
| 4997 | oaicompat_type oaicompat) -> void { |
| 4998 | GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); |
| 4999 | |
| 5000 | auto completion_id = gen_chatcmplid(); |
| 5001 | std::unordered_set<int> task_ids; |
| 5002 | try { |
| 5003 | std::vector<server_task> tasks; |
| 5004 | |
| 5005 | const auto & prompt = data.at(key: "prompt" ); |
| 5006 | // TODO: this log can become very long, put it behind a flag or think about a more compact format |
| 5007 | //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str()); |
| 5008 | |
| 5009 | // process prompt |
| 5010 | std::vector<server_tokens> inputs; |
| 5011 | |
| 5012 | if (oaicompat && ctx_server.mctx != nullptr) { |
| 5013 | // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. |
| 5014 | inputs.push_back(x: process_mtmd_prompt(mctx: ctx_server.mctx, prompt: prompt.get<std::string>(), files)); |
| 5015 | } else { |
| 5016 | // Everything else, including multimodal completions. |
| 5017 | inputs = tokenize_input_prompts(vocab: ctx_server.vocab, mctx: ctx_server.mctx, json_prompt: prompt, add_special: true, parse_special: true); |
| 5018 | } |
| 5019 | const size_t n_ctx_slot = ctx_server.slots.front().n_ctx; |
| 5020 | tasks.reserve(n: inputs.size()); |
| 5021 | for (size_t i = 0; i < inputs.size(); i++) { |
| 5022 | auto n_prompt_tokens = inputs[i].size(); |
| 5023 | if (n_prompt_tokens >= n_ctx_slot) { |
| 5024 | json error_data = format_error_response(message: "the request exceeds the available context size, try increasing it" , type: ERROR_TYPE_EXCEED_CONTEXT_SIZE); |
| 5025 | error_data["n_prompt_tokens" ] = n_prompt_tokens; |
| 5026 | error_data["n_ctx" ] = n_ctx_slot; |
| 5027 | res_error(res, error_data); |
| 5028 | return; |
| 5029 | } |
| 5030 | server_task task = server_task(type); |
| 5031 | |
| 5032 | task.id = ctx_server.queue_tasks.get_new_id(); |
| 5033 | task.index = i; |
| 5034 | |
| 5035 | task.tokens = std::move(inputs[i]); |
| 5036 | task.params = server_task::params_from_json_cmpl( |
| 5037 | ctx: ctx_server.ctx, |
| 5038 | params_base: ctx_server.params_base, |
| 5039 | data); |
| 5040 | task.id_slot = json_value(body: data, key: "id_slot" , default_value: -1); |
| 5041 | |
| 5042 | // OAI-compat |
| 5043 | task.params.oaicompat = oaicompat; |
| 5044 | task.params.oaicompat_cmpl_id = completion_id; |
| 5045 | // oaicompat_model is already populated by params_from_json_cmpl |
| 5046 | |
| 5047 | tasks.push_back(x: std::move(task)); |
| 5048 | } |
| 5049 | |
| 5050 | task_ids = server_task::get_list_id(tasks); |
| 5051 | ctx_server.queue_results.add_waiting_tasks(tasks); |
| 5052 | ctx_server.queue_tasks.post(tasks: std::move(tasks)); |
| 5053 | } catch (const std::exception & e) { |
| 5054 | res_error(res, format_error_response(message: e.what(), type: ERROR_TYPE_INVALID_REQUEST)); |
| 5055 | return; |
| 5056 | } |
| 5057 | |
| 5058 | bool stream = json_value(body: data, key: "stream" , default_value: false); |
| 5059 | |
| 5060 | if (!stream) { |
| 5061 | ctx_server.receive_multi_results(id_tasks: task_ids, result_handler: [&](std::vector<server_task_result_ptr> & results) { |
| 5062 | if (results.size() == 1) { |
| 5063 | // single result |
| 5064 | res_ok(res, results[0]->to_json()); |
| 5065 | } else { |
| 5066 | // multiple results (multitask) |
| 5067 | json arr = json::array(); |
| 5068 | for (auto & res : results) { |
| 5069 | arr.push_back(val: res->to_json()); |
| 5070 | } |
| 5071 | res_ok(res, arr); |
| 5072 | } |
| 5073 | }, error_handler: [&](const json & error_data) { |
| 5074 | res_error(res, error_data); |
| 5075 | }, is_connection_closed); |
| 5076 | |
| 5077 | ctx_server.queue_results.remove_waiting_task_ids(id_tasks: task_ids); |
| 5078 | } else { |
| 5079 | const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) { |
| 5080 | ctx_server.receive_cmpl_results_stream(id_tasks: task_ids, result_handler: [&](server_task_result_ptr & result) -> bool { |
| 5081 | json res_json = result->to_json(); |
| 5082 | if (res_json.is_array()) { |
| 5083 | for (const auto & res : res_json) { |
| 5084 | if (!server_sent_event(sink, data: res)) { |
| 5085 | // sending failed (HTTP connection closed), cancel the generation |
| 5086 | return false; |
| 5087 | } |
| 5088 | } |
| 5089 | return true; |
| 5090 | } else { |
| 5091 | return server_sent_event(sink, data: res_json); |
| 5092 | } |
| 5093 | }, error_handler: [&](const json & error_data) { |
| 5094 | server_sent_event(sink, data: json{{"error" , error_data}}); |
| 5095 | }, is_connection_closed: [&sink]() { |
| 5096 | // note: do not use req.is_connection_closed here because req is already destroyed |
| 5097 | return !sink.is_writable(); |
| 5098 | }); |
| 5099 | if (oaicompat != OAICOMPAT_TYPE_NONE) { |
| 5100 | static const std::string ev_done = "data: [DONE]\n\n" ; |
| 5101 | sink.write(ev_done.data(), ev_done.size()); |
| 5102 | } |
| 5103 | sink.done(); |
| 5104 | return false; |
| 5105 | }; |
| 5106 | |
| 5107 | auto on_complete = [task_ids, &ctx_server] (bool) { |
| 5108 | ctx_server.queue_results.remove_waiting_task_ids(id_tasks: task_ids); |
| 5109 | }; |
| 5110 | |
| 5111 | res.set_chunked_content_provider(content_type: "text/event-stream" , provider: chunked_content_provider, resource_releaser: on_complete); |
| 5112 | } |
| 5113 | }; |
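| | // for reference: in streaming mode each result is written as a server-sent event (the "data: ...\n\n" framing |
| | // is assumed to be handled by server_sent_event), and OAI-compatible streams are terminated with the literal |
| | // "data: [DONE]" frame written above |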
| 5114 | |
| 5115 | const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { |
| 5116 | json data = json::parse(i: req.body); |
| 5117 | std::vector<raw_buffer> files; // dummy |
| 5118 | handle_completions_impl( |
| 5119 | SERVER_TASK_TYPE_COMPLETION, |
| 5120 | data, |
| 5121 | files, |
| 5122 | req.is_connection_closed, |
| 5123 | res, |
| 5124 | OAICOMPAT_TYPE_NONE); |
| 5125 | }; |
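| | // illustrative (hypothetical) request body for this handler: |
| | //   {"prompt": "Once upon a time", "stream": true, "id_slot": -1} |
| | // "prompt" is required; the remaining sampling/generation fields are parsed by params_from_json_cmpl |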
| 5126 | |
| 5127 | const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { |
| 5128 | json data = oaicompat_completion_params_parse(body: json::parse(i: req.body)); |
| 5129 | std::vector<raw_buffer> files; // dummy |
| 5130 | handle_completions_impl( |
| 5131 | SERVER_TASK_TYPE_COMPLETION, |
| 5132 | data, |
| 5133 | files, |
| 5134 | req.is_connection_closed, |
| 5135 | res, |
| 5136 | OAICOMPAT_TYPE_COMPLETION); |
| 5137 | }; |
| 5138 | |
| 5139 | const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { |
| 5140 | // check model compatibility |
| 5141 | std::string err; |
| 5142 | if (llama_vocab_fim_pre(vocab: ctx_server.vocab) == LLAMA_TOKEN_NULL) { |
| 5143 | err += "prefix token is missing. " ; |
| 5144 | } |
| 5145 | if (llama_vocab_fim_suf(vocab: ctx_server.vocab) == LLAMA_TOKEN_NULL) { |
| 5146 | err += "suffix token is missing. " ; |
| 5147 | } |
| 5148 | if (llama_vocab_fim_mid(vocab: ctx_server.vocab) == LLAMA_TOKEN_NULL) { |
| 5149 | err += "middle token is missing. " ; |
| 5150 | } |
| 5151 | if (!err.empty()) { |
| 5152 | res_error(res, format_error_response(message: string_format(fmt: "Infill is not supported by this model: %s" , err.c_str()), type: ERROR_TYPE_NOT_SUPPORTED)); |
| 5153 | return; |
| 5154 | } |
| 5155 | |
| 5156 | json data = json::parse(i: req.body); |
| 5157 | |
| 5158 | // validate input |
| 5159 | if (data.contains(key: "prompt" ) && !data.at(key: "prompt" ).is_string()) { |
| 5160 | // prompt is optional |
| 5161 | res_error(res, format_error_response(message: "\"prompt\" must be a string" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 5162 | } |
| 5163 | |
| 5164 | if (!data.contains(key: "input_prefix" )) { |
| 5165 | res_error(res, format_error_response(message: "\"input_prefix\" is required" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 5166 | } |
| 5167 | |
| 5168 | if (!data.contains(key: "input_suffix" )) { |
| 5169 | res_error(res, format_error_response(message: "\"input_suffix\" is required" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 5170 | } |
| 5171 | |
| 5172 | if (data.contains(key: "input_extra" ) && !data.at(key: "input_extra" ).is_array()) { |
| 5173 | // input_extra is optional |
| 5174 | res_error(res, format_error_response(message: "\"input_extra\" must be an array of {\"filename\": string, \"text\": string}" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 5175 | return; |
| 5176 | } |
| 5177 | |
| 5178 | json input_extra = json_value(data, "input_extra", json::array()); |
| 5179 | for (const auto & chunk : input_extra) { |
| 5180 | // { "text": string, "filename": string } |
| 5181 | if (!chunk.contains(key: "text" ) || !chunk.at(key: "text" ).is_string()) { |
| 5182 | res_error(res, format_error_response(message: "extra_context chunk must contain a \"text\" field with a string value" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 5183 | return; |
| 5184 | } |
| 5185 | // filename is optional |
| 5186 | if (chunk.contains(key: "filename" ) && !chunk.at(key: "filename" ).is_string()) { |
| 5187 | res_error(res, format_error_response(message: "extra_context chunk's \"filename\" field must be a string" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 5188 | return; |
| 5189 | } |
| 5190 | } |
| 5191 | data["input_extra" ] = input_extra; // default to empty array if it's not exist |
| 5192 | |
| 5193 | std::string prompt = json_value(body: data, key: "prompt" , default_value: std::string()); |
| 5194 | std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(vocab: ctx_server.vocab, mctx: ctx_server.mctx, json_prompt: prompt, add_special: false, parse_special: true); |
| 5195 | SRV_DBG("creating infill tasks, n_prompts = %d\n" , (int) tokenized_prompts.size()); |
| 5196 | data["prompt" ] = format_infill( |
| 5197 | vocab: ctx_server.vocab, |
| 5198 | input_prefix: data.at(key: "input_prefix" ), |
| 5199 | input_suffix: data.at(key: "input_suffix" ), |
| 5200 | input_extra: data.at(key: "input_extra" ), |
| 5201 | n_batch: ctx_server.params_base.n_batch, |
| 5202 | n_predict: ctx_server.params_base.n_predict, |
| 5203 | n_ctx: ctx_server.slots[0].n_ctx, // TODO: there should be a better way |
| 5204 | spm_infill: ctx_server.params_base.spm_infill, |
| 5205 | tokens_prompt: tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. |
| 5206 | ); |
| 5207 | |
| 5208 | std::vector<raw_buffer> files; // dummy |
| 5209 | handle_completions_impl( |
| 5210 | SERVER_TASK_TYPE_INFILL, |
| 5211 | data, |
| 5212 | files, |
| 5213 | req.is_connection_closed, |
| 5214 | res, |
| 5215 | OAICOMPAT_TYPE_NONE); // infill is not OAI compatible |
| 5216 | }; |
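| | // illustrative (hypothetical) request body: |
| | //   {"input_prefix": "def add(a, b):", "input_suffix": "return c", "input_extra": [{"filename": "math.py", "text": "..."}]} |
| | // input_prefix and input_suffix are required; prompt and input_extra are optional |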
| 5217 | |
| 5218 | const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { |
| 5219 | LOG_DBG("request: %s\n" , req.body.c_str()); |
| 5220 | |
| 5221 | auto body = json::parse(i: req.body); |
| 5222 | std::vector<raw_buffer> files; |
| 5223 | json data = oaicompat_chat_params_parse( |
| 5224 | body, |
| 5225 | opt: ctx_server.oai_parser_opt, |
| 5226 | out_files&: files); |
| 5227 | |
| 5228 | handle_completions_impl( |
| 5229 | SERVER_TASK_TYPE_COMPLETION, |
| 5230 | data, |
| 5231 | files, |
| 5232 | req.is_connection_closed, |
| 5233 | res, |
| 5234 | OAICOMPAT_TYPE_CHAT); |
| 5235 | }; |
| 5236 | |
| 5237 | // same with handle_chat_completions, but without inference part |
| 5238 | const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { |
| 5239 | auto body = json::parse(i: req.body); |
| 5240 | std::vector<raw_buffer> files; // dummy, unused |
| 5241 | json data = oaicompat_chat_params_parse( |
| 5242 | body, |
| 5243 | opt: ctx_server.oai_parser_opt, |
| 5244 | out_files&: files); |
| 5245 | res_ok(res, {{ "prompt" , std::move(data.at(key: "prompt" )) }}); |
| 5246 | }; |
| 5247 | |
| 5248 | const auto handle_models = [¶ms, &ctx_server, &state, &res_ok](const httplib::Request &, httplib::Response & res) { |
| 5249 | server_state current_state = state.load(); |
| 5250 | json model_meta = nullptr; |
| 5251 | if (current_state == SERVER_STATE_READY) { |
| 5252 | model_meta = ctx_server.model_meta(); |
| 5253 | } |
| 5254 | bool has_mtmd = ctx_server.mctx != nullptr; |
| 5255 | json models = { |
| 5256 | {"models" , { |
| 5257 | { |
| 5258 | {"name" , params.model_alias.empty() ? params.model.path : params.model_alias}, |
| 5259 | {"model" , params.model_alias.empty() ? params.model.path : params.model_alias}, |
| 5260 | {"modified_at" , "" }, |
| 5261 | {"size" , "" }, |
| 5262 | {"digest" , "" }, // dummy value, llama.cpp does not support managing model file's hash |
| 5263 | {"type" , "model" }, |
| 5264 | {"description" , "" }, |
| 5265 | {"tags" , {"" }}, |
| 5266 | {"capabilities" , has_mtmd ? json({"completion" ,"multimodal" }) : json({"completion" })}, |
| 5267 | {"parameters" , "" }, |
| 5268 | {"details" , { |
| 5269 | {"parent_model" , "" }, |
| 5270 | {"format" , "gguf" }, |
| 5271 | {"family" , "" }, |
| 5272 | {"families" , {"" }}, |
| 5273 | {"parameter_size" , "" }, |
| 5274 | {"quantization_level" , "" } |
| 5275 | }} |
| 5276 | } |
| 5277 | }}, |
| 5278 | {"object" , "list" }, |
| 5279 | {"data" , { |
| 5280 | { |
| 5281 | {"id" , params.model_alias.empty() ? params.model.path : params.model_alias}, |
| 5282 | {"object" , "model" }, |
| 5283 | {"created" , std::time(timer: 0)}, |
| 5284 | {"owned_by" , "llamacpp" }, |
| 5285 | {"meta" , model_meta}, |
| 5286 | }, |
| 5287 | }} |
| 5288 | }; |
| 5289 | |
| 5290 | res_ok(res, models); |
| 5291 | }; |
| 5292 | |
| 5293 | const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { |
| 5294 | const json body = json::parse(i: req.body); |
| 5295 | |
| 5296 | json tokens_response = json::array(); |
| 5297 | if (body.count(key: "content" ) != 0) { |
| 5298 | const bool add_special = json_value(body, key: "add_special" , default_value: false); |
| 5299 | const bool parse_special = json_value(body, key: "parse_special" , default_value: true); |
| 5300 | const bool with_pieces = json_value(body, key: "with_pieces" , default_value: false); |
| 5301 | |
| 5302 | llama_tokens tokens = tokenize_mixed(vocab: ctx_server.vocab, json_prompt: body.at(key: "content" ), add_special, parse_special); |
| 5303 | |
| 5304 | if (with_pieces) { |
| 5305 | for (const auto& token : tokens) { |
| 5306 | std::string piece = common_token_to_piece(ctx: ctx_server.ctx, token); |
| 5307 | json piece_json; |
| 5308 | |
| 5309 | // Check if the piece is valid UTF-8 |
| 5310 | if (is_valid_utf8(str: piece)) { |
| 5311 | piece_json = piece; |
| 5312 | } else { |
| 5313 | // If not valid UTF-8, store as array of byte values |
| 5314 | piece_json = json::array(); |
| 5315 | for (unsigned char c : piece) { |
| 5316 | piece_json.push_back(val: static_cast<int>(c)); |
| 5317 | } |
| 5318 | } |
| 5319 | |
| 5320 | tokens_response.push_back(init: { |
| 5321 | {"id" , token}, |
| 5322 | {"piece" , piece_json} |
| 5323 | }); |
| 5324 | } |
| 5325 | } else { |
| 5326 | tokens_response = tokens; |
| 5327 | } |
| 5328 | } |
| 5329 | |
| 5330 | const json data = format_tokenizer_response(tokens: tokens_response); |
| 5331 | res_ok(res, data); |
| 5332 | }; |
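| | // illustrative (hypothetical) payloads: {"content": "Hello", "with_pieces": true} returns per-token objects of the |
| | // form {"id": <token id>, "piece": <string or byte array>}; without "with_pieces" a plain list of token ids is returned |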
| 5333 | |
| 5334 | const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { |
| 5335 | const json body = json::parse(i: req.body); |
| 5336 | |
| 5337 | std::string content; |
| 5338 | if (body.count(key: "tokens" ) != 0) { |
| 5339 | const llama_tokens tokens = body.at(key: "tokens" ); |
| 5340 | content = tokens_to_str(ctx: ctx_server.ctx, begin: tokens.cbegin(), end: tokens.cend()); |
| 5341 | } |
| 5342 | |
| 5343 | const json data = format_detokenized_response(content); |
| 5344 | res_ok(res, data); |
| 5345 | }; |
| 5346 | |
| 5347 | const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) { |
| 5348 | if (!ctx_server.params_base.embedding) { |
| 5349 | res_error(res, format_error_response(message: "This server does not support embeddings. Start it with `--embeddings`" , type: ERROR_TYPE_NOT_SUPPORTED)); |
| 5350 | return; |
| 5351 | } |
| 5352 | |
| 5353 | if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx: ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { |
| 5354 | res_error(res, format_error_response(message: "Pooling type 'none' is not OAI compatible. Please use a different pooling type" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 5355 | return; |
| 5356 | } |
| 5357 | |
| 5358 | const json body = json::parse(i: req.body); |
| 5359 | |
| 5360 | // for the shape of input/content, see tokenize_input_prompts() |
| 5361 | json prompt; |
| 5362 | if (body.count(key: "input" ) != 0) { |
| 5363 | prompt = body.at(key: "input" ); |
| 5364 | } else if (body.contains(key: "content" )) { |
| 5365 | oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible |
| 5366 | prompt = body.at(key: "content" ); |
| 5367 | } else { |
| 5368 | res_error(res, format_error_response(message: "\"input\" or \"content\" must be provided" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 5369 | return; |
| 5370 | } |
| 5371 | |
| 5372 | bool use_base64 = false; |
| 5373 | if (body.count(key: "encoding_format" ) != 0) { |
| 5374 | const std::string& format = body.at(key: "encoding_format" ); |
| 5375 | if (format == "base64" ) { |
| 5376 | use_base64 = true; |
| 5377 | } else if (format != "float" ) { |
| 5378 | res_error(res, format_error_response(message: "The format to return the embeddings in. Can be either float or base64" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 5379 | return; |
| 5380 | } |
| 5381 | } |
| 5382 | |
| 5383 | auto tokenized_prompts = tokenize_input_prompts(vocab: ctx_server.vocab, mctx: ctx_server.mctx, json_prompt: prompt, add_special: true, parse_special: true); |
| 5384 | for (const auto & tokens : tokenized_prompts) { |
| 5385 | // this check is necessary for models that do not add BOS token to the input |
| 5386 | if (tokens.empty()) { |
| 5387 | res_error(res, format_error_response(message: "Input content cannot be empty" , type: ERROR_TYPE_INVALID_REQUEST)); |
| 5388 | return; |
| 5389 | } |
| 5390 | } |
| 5391 | |
| 5392 | int embd_normalize = 2; // default to Euclidean/L2 norm |
| 5393 | if (body.count(key: "embd_normalize" ) != 0) { |
| 5394 | embd_normalize = body.at(key: "embd_normalize" ); |
| 5395 | if (llama_pooling_type(ctx: ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { |
| 5396 | SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n" , llama_pooling_type(ctx_server.ctx)); |
| 5397 | } |
| 5398 | } |
| 5399 | |
| 5400 | // create and queue the task |
| 5401 | json responses = json::array(); |
| 5402 | bool error = false; |
| 5403 | std::unordered_set<int> task_ids; |
| 5404 | { |
| 5405 | std::vector<server_task> tasks; |
| 5406 | for (size_t i = 0; i < tokenized_prompts.size(); i++) { |
| 5407 | server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); |
| 5408 | |
| 5409 | task.id = ctx_server.queue_tasks.get_new_id(); |
| 5410 | task.index = i; |
| 5411 | task.tokens = std::move(tokenized_prompts[i]); |
| 5412 | |
| 5413 | // OAI-compat |
| 5414 | task.params.oaicompat = oaicompat; |
| 5415 | task.params.embd_normalize = embd_normalize; |
| 5416 | |
| 5417 | tasks.push_back(x: std::move(task)); |
| 5418 | } |
| 5419 | |
| 5420 | task_ids = server_task::get_list_id(tasks); |
| 5421 | ctx_server.queue_results.add_waiting_tasks(tasks); |
| 5422 | ctx_server.queue_tasks.post(tasks: std::move(tasks)); |
| 5423 | } |
| 5424 | |
| 5425 | // get the result |
| 5426 | ctx_server.receive_multi_results(id_tasks: task_ids, result_handler: [&](std::vector<server_task_result_ptr> & results) { |
| 5427 | for (auto & res : results) { |
| 5428 | GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr); |
| 5429 | responses.push_back(val: res->to_json()); |
| 5430 | } |
| 5431 | }, error_handler: [&](const json & error_data) { |
| 5432 | res_error(res, error_data); |
| 5433 | error = true; |
| 5434 | }, is_connection_closed: req.is_connection_closed); |
| 5435 | |
| 5436 | ctx_server.queue_results.remove_waiting_task_ids(id_tasks: task_ids); |
| 5437 | |
| 5438 | if (error) { |
| 5439 | return; |
| 5440 | } |
| 5441 | |
| 5442 | // write JSON response |
| 5443 | json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING |
| 5444 | ? format_embeddings_response_oaicompat(request: body, embeddings: responses, use_base64) |
| 5445 | : json(responses); |
| 5446 | res_ok(res, root); |
| 5447 | }; |
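| | // illustrative (hypothetical) request body: {"input": "Hello world", "encoding_format": "float"} |
| | // "input" may also hold multiple prompts (see tokenize_input_prompts()); "content" is accepted as a |
| | // non-OAI alternative, and "embd_normalize" overrides the default L2 normalization (2) |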
| 5448 | |
| 5449 | const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { |
| 5450 | handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); |
| 5451 | }; |
| 5452 | |
| 5453 | const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { |
| 5454 | handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING); |
| 5455 | }; |
| 5456 | |
| 5457 | const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { |
| 5458 | if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { |
| 5459 | res_error(res, format_error_response(message: "This server does not support reranking. Start it with `--reranking`" , type: ERROR_TYPE_NOT_SUPPORTED)); |
| 5460 | return; |
| 5461 | } |
| 5462 | |
| 5463 | const json body = json::parse(req.body); |
| 5464 | |
| 5465 | // if true, use TEI API format, otherwise use Jina API format |
| 5466 | // Jina: https://jina.ai/reranker/ |
| 5467 | // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank |
| 5468 | bool is_tei_format = body.contains("texts"); |
| 5469 | |
| 5470 | json query; |
| 5471 | if (body.count("query") == 1) { |
| 5472 | query = body.at("query"); |
| 5473 | if (!query.is_string()) { |
| 5474 | res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); |
| 5475 | return; |
| 5476 | } |
| 5477 | } else { |
| 5478 | res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); |
| 5479 | return; |
| 5480 | } |
| 5481 | |
| 5482 | std::vector<std::string> documents = json_value(body, "documents", |
| 5483 | json_value(body, "texts", std::vector<std::string>())); |
| 5484 | if (documents.empty()) { |
| 5485 | res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); |
| 5486 | return; |
| 5487 | } |
| 5488 | |
| 5489 | int top_n = json_value(body, "top_n", (int)documents.size()); |
| 5490 | |
| 5491 | // create and queue the task |
| 5492 | json responses = json::array(); |
| 5493 | bool error = false; |
| 5494 | std::unordered_set<int> task_ids; |
| 5495 | { |
| 5496 | std::vector<server_task> tasks; |
| 5497 | tasks.reserve(documents.size()); |
| 5498 | for (size_t i = 0; i < documents.size(); i++) { |
| 5499 | auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); |
| 5500 | server_task task = server_task(SERVER_TASK_TYPE_RERANK); |
| 5501 | task.id = ctx_server.queue_tasks.get_new_id(); |
| 5502 | task.index = i; |
| 5503 | task.tokens = std::move(tmp); |
| 5504 | tasks.push_back(std::move(task)); |
| 5505 | } |
| 5506 | |
| 5507 | task_ids = server_task::get_list_id(tasks); |
| 5508 | ctx_server.queue_results.add_waiting_tasks(tasks); |
| 5509 | ctx_server.queue_tasks.post(std::move(tasks)); |
| 5510 | } |
| 5511 | |
| 5512 | ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) { |
| 5513 | for (auto & res : results) { |
| 5514 | GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr); |
| 5515 | responses.push_back(res->to_json()); |
| 5516 | } |
| 5517 | }, [&](const json & error_data) { |
| 5518 | res_error(res, error_data); |
| 5519 | error = true; |
| 5520 | }, req.is_connection_closed); |
| 5521 | |
| 5522 | if (error) { |
| 5523 | return; |
| 5524 | } |
| 5525 | |
| 5526 | // write JSON response |
| 5527 | json root = format_response_rerank( |
| 5528 | body, |
| 5529 | responses, |
| 5530 | is_tei_format, |
| 5531 | documents, |
| 5532 | top_n); |
| 5533 | |
| 5534 | res_ok(res, root); |
| 5535 | }; |
| 5536 | |
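| | // GET /lora-adapters lists the adapters loaded at startup; illustrative response entry: |
| | //   {"id": 0, "path": "adapter.gguf", "scale": 1.0, "task_name": "...", "prompt_prefix": "..."} |
| | // aLoRA adapters additionally expose "alora_invocation_string" and "alora_invocation_tokens" |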
| 5537 | const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { |
| 5538 | json result = json::array(); |
| 5539 | const auto & loras = ctx_server.params_base.lora_adapters; |
| 5540 | for (size_t i = 0; i < loras.size(); ++i) { |
| 5541 | auto & lora = loras[i]; |
| 5542 | json entry = { |
| 5543 | {"id" , i}, |
| 5544 | {"path" , lora.path}, |
| 5545 | {"scale" , lora.scale}, |
| 5546 | {"task_name" , lora.task_name}, |
| 5547 | {"prompt_prefix" , lora.prompt_prefix}, |
| 5548 | }; |
| 5549 | std::string alora_invocation_string = ""; |
| 5550 | const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); |
| 5551 | std::vector<llama_token> alora_invocation_tokens; |
| 5552 | if (n_alora_tokens) { |
| 5553 | const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); |
| 5554 | for (uint64_t i = 0; i < n_alora_tokens; ++i) { |
| 5555 | alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); |
| 5556 | alora_invocation_tokens.push_back(alora_tokens[i]); |
| 5557 | } |
| 5558 | entry["alora_invocation_string" ] = alora_invocation_string; |
| 5559 | entry["alora_invocation_tokens" ] = alora_invocation_tokens; |
| 5560 | } |
| 5561 | result.push_back(val: std::move(entry)); |
| 5562 | } |
| 5563 | res_ok(res, result); |
| 5564 | res.status = 200; // HTTP OK |
| 5565 | }; |
| 5566 | |
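| | // POST /lora-adapters expects a JSON array that re-scales the loaded adapters, e.g. |
| | // (illustrative; the accepted fields are defined by parse_lora_request()): |
| | //   [{"id": 0, "scale": 0.5}] |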
| 5567 | const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { |
| 5568 | const json body = json::parse(req.body); |
| 5569 | if (!body.is_array()) { |
| 5570 | res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); |
| 5571 | return; |
| 5572 | } |
| 5573 | |
| 5574 | int task_id = ctx_server.queue_tasks.get_new_id(); |
| 5575 | { |
| 5576 | server_task task(SERVER_TASK_TYPE_SET_LORA); |
| 5577 | task.id = task_id; |
| 5578 | task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); |
| 5579 | ctx_server.queue_results.add_waiting_task_id(task_id); |
| 5580 | ctx_server.queue_tasks.post(std::move(task)); |
| 5581 | } |
| 5582 | |
| 5583 | // get the result |
| 5584 | server_task_result_ptr result = ctx_server.queue_results.recv(task_id); |
| 5585 | ctx_server.queue_results.remove_waiting_task_id(task_id); |
| 5586 | |
| 5587 | if (result->is_error()) { |
| 5588 | res_error(res, result->to_json()); |
| 5589 | return; |
| 5590 | } |
| 5591 | |
| 5592 | GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr); |
| 5593 | res_ok(res, result->to_json()); |
| 5594 | }; |
| 5595 | |
| 5596 | // |
| 5597 | // Router |
| 5598 | // |
| 5599 | |
| 5600 | if (!params.webui) { |
| 5601 | LOG_INF("Web UI is disabled\n" ); |
| 5602 | } else { |
| 5603 | // register static assets routes |
| 5604 | if (!params.public_path.empty()) { |
| 5605 | // Set the base directory for serving static files |
| 5606 | bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path); |
| 5607 | if (!is_found) { |
| 5608 | LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); |
| 5609 | return 1; |
| 5610 | } |
| 5611 | } else { |
| 5612 | // using embedded static index.html |
| 5613 | svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) { |
| 5614 | if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { |
| 5615 | res.set_content("Error: gzip is not supported by this browser", "text/plain"); |
| 5616 | } else { |
| 5617 | res.set_header("Content-Encoding", "gzip"); |
| 5618 | // COEP and COOP headers, required by pyodide (python interpreter) |
| 5619 | res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); |
| 5620 | res.set_header("Cross-Origin-Opener-Policy", "same-origin"); |
| 5621 | res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); |
| 5622 | } |
| 5623 | return false; |
| 5624 | }); |
| 5625 | } |
| 5626 | } |
| 5627 | |
| 5628 | // register API routes |
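| | // all routes below are registered under params.api_prefix; endpoints marked as |
| | // "public" are served without the API key check |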
| 5629 | svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check) |
| 5630 | svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check) |
| 5631 | svr->Get (params.api_prefix + "/metrics", handle_metrics); |
| 5632 | svr->Get (params.api_prefix + "/props", handle_props); |
| 5633 | svr->Post(params.api_prefix + "/props", handle_props_change); |
| 5634 | svr->Post(params.api_prefix + "/api/show", handle_api_show); |
| 5635 | svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check) |
| 5636 | svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check) |
| 5637 | svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check) |
| 5638 | svr->Post(params.api_prefix + "/completion", handle_completions); // legacy |
| 5639 | svr->Post(params.api_prefix + "/completions", handle_completions); |
| 5640 | svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai); |
| 5641 | svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions); |
| 5642 | svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions); |
| 5643 | svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint |
| 5644 | svr->Post(params.api_prefix + "/infill", handle_infill); |
| 5645 | svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy |
| 5646 | svr->Post(params.api_prefix + "/embeddings", handle_embeddings); |
| 5647 | svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai); |
| 5648 | svr->Post(params.api_prefix + "/rerank", handle_rerank); |
| 5649 | svr->Post(params.api_prefix + "/reranking", handle_rerank); |
| 5650 | svr->Post(params.api_prefix + "/v1/rerank", handle_rerank); |
| 5651 | svr->Post(params.api_prefix + "/v1/reranking", handle_rerank); |
| 5652 | svr->Post(params.api_prefix + "/tokenize", handle_tokenize); |
| 5653 | svr->Post(params.api_prefix + "/detokenize", handle_detokenize); |
| 5654 | svr->Post(params.api_prefix + "/apply-template", handle_apply_template); |
| 5655 | // LoRA adapters hotswap |
| 5656 | svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list); |
| 5657 | svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply); |
| 5658 | // Save & load slots |
| 5659 | svr->Get (params.api_prefix + "/slots", handle_slots); |
| 5660 | svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action); |
| 5661 | |
| 5662 | // |
| 5663 | // Start the server |
| 5664 | // |
| 5665 | if (params.n_threads_http < 1) { |
| 5666 | // +2 threads for monitoring endpoints |
| 5667 | params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); |
| 5668 | } |
| 5669 | log_data["n_threads_http"] = std::to_string(params.n_threads_http); |
| 5670 | svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; |
| 5671 | |
| 5672 | // clean up function, to be called before exit |
| 5673 | auto clean_up = [&svr, &ctx_server]() { |
| 5674 | SRV_INF("%s: cleaning up before exit...\n" , __func__); |
| 5675 | svr->stop(); |
| 5676 | ctx_server.queue_results.terminate(); |
| 5677 | llama_backend_free(); |
| 5678 | }; |
| 5679 | |
| 5680 | bool was_bound = false; |
| 5681 | bool is_sock = false; |
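| | // a hostname ending in ".sock" selects a unix domain socket instead of TCP |
| | // (illustrative usage: --host ./llama-server.sock) |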
| 5682 | if (string_ends_with(std::string(params.hostname), ".sock")) { |
| 5683 | is_sock = true; |
| 5684 | LOG_INF("%s: setting address family to AF_UNIX\n", __func__); |
| 5685 | svr->set_address_family(AF_UNIX); |
| 5686 | // bind_to_port requires a port argument; for an AF_UNIX socket the port is |
| 5687 | // ignored, so any value other than 0 will do |
| 5688 | was_bound = svr->bind_to_port(params.hostname, 8080); |
| 5689 | } else { |
| 5690 | LOG_INF("%s: binding port with default address family\n" , __func__); |
| 5691 | // bind HTTP listen port |
| 5692 | if (params.port == 0) { |
| 5693 | int bound_port = svr->bind_to_any_port(params.hostname); |
| 5694 | if ((was_bound = (bound_port >= 0))) { |
| 5695 | params.port = bound_port; |
| 5696 | } |
| 5697 | } else { |
| 5698 | was_bound = svr->bind_to_port(params.hostname, params.port); |
| 5699 | } |
| 5700 | } |
| 5701 | |
| 5702 | if (!was_bound) { |
| 5703 | LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n" , __func__, params.hostname.c_str(), params.port); |
| 5704 | clean_up(); |
| 5705 | return 1; |
| 5706 | } |
| 5707 | |
| 5708 | // run the HTTP server in a thread |
| 5709 | std::thread t([&]() { svr->listen_after_bind(); }); |
| 5710 | svr->wait_until_ready(); |
| 5711 | |
| 5712 | LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n" , __func__, params.hostname.c_str(), params.port, params.n_threads_http); |
| 5713 | |
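| | // note: the HTTP thread is already accepting connections at this point, so clients |
| | // can be told that the model is still loading instead of getting connection errors |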
| 5714 | // load the model |
| 5715 | LOG_INF("%s: loading model\n" , __func__); |
| 5716 | |
| 5717 | if (!ctx_server.load_model(params)) { |
| 5718 | clean_up(); |
| 5719 | t.join(); |
| 5720 | LOG_ERR("%s: exiting due to model loading error\n" , __func__); |
| 5721 | return 1; |
| 5722 | } |
| 5723 | |
| 5724 | ctx_server.init(); |
| 5725 | state.store(SERVER_STATE_READY); |
| 5726 | |
| 5727 | LOG_INF("%s: model loaded\n" , __func__); |
| 5728 | |
| 5729 | // print sample chat example to make it clear which template is used |
| 5730 | LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n" , __func__, |
| 5731 | common_chat_templates_source(ctx_server.chat_templates.get()), |
| 5732 | common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja, ctx_server.params_base.default_template_kwargs).c_str()); |
| 5733 | |
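| | // wire up the task queue callbacks: queued tasks go to process_single_task(), and |
| | // the queue loop calls update_slots() to advance generation for the active slots |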
| 5734 | ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { |
| 5735 | ctx_server.process_single_task(std::move(task)); |
| 5736 | }); |
| 5737 | |
| 5738 | ctx_server.queue_tasks.on_update_slots([&ctx_server]() { |
| 5739 | ctx_server.update_slots(); |
| 5740 | }); |
| 5741 | |
| 5742 | shutdown_handler = [&](int) { |
| 5743 | // this will unblock start_loop() |
| 5744 | ctx_server.queue_tasks.terminate(); |
| 5745 | }; |
| 5746 | |
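| | // install signal handlers so that SIGINT/SIGTERM (or Ctrl+C on Windows) trigger |
| | // shutdown_handler and let the main loop below exit cleanly |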
| 5747 | #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) |
| 5748 | struct sigaction sigint_action; |
| 5749 | sigint_action.sa_handler = signal_handler; |
| 5750 | sigemptyset(&sigint_action.sa_mask); |
| 5751 | sigint_action.sa_flags = 0; |
| 5752 | sigaction(SIGINT, &sigint_action, NULL); |
| 5753 | sigaction(SIGTERM, &sigint_action, NULL); |
| 5754 | #elif defined (_WIN32) |
| 5755 | auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { |
| 5756 | return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; |
| 5757 | }; |
| 5758 | SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true); |
| 5759 | #endif |
| 5760 | |
| 5761 | LOG_INF("%s: server is listening on %s - starting the main loop\n" , __func__, |
| 5762 | is_sock ? string_format("unix://%s" , params.hostname.c_str()).c_str() : |
| 5763 | string_format("http://%s:%d" , params.hostname.c_str(), params.port).c_str()); |
| 5764 | |
| 5765 | // this call blocks the main thread until queue_tasks.terminate() is called |
| 5766 | ctx_server.queue_tasks.start_loop(); |
| 5767 | |
| 5768 | clean_up(); |
| 5769 | t.join(); |
| 5770 | llama_memory_breakdown_print(ctx_server.ctx); |
| 5771 | |
| 5772 | return 0; |
| 5773 | } |
| 5774 | |