#include "llama-context.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
#include "llama-memory.h"
#include "llama-mmap.h"
#include "llama-model.h"

#include <cinttypes>
#include <cstring>
#include <limits>
#include <stdexcept>

//
// llama_context
//

llama_context::llama_context(
        const llama_model & model,
        llama_context_params params) :
    model(model),
    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
    // TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
    //      may need to be backend-dependent
    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

    t_start_us = model.t_start_us;
    t_load_us  = model.t_load_us;

    const auto & hparams = model.hparams;

    cparams.n_seq_max = std::max(1u, params.n_seq_max);
    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
    }

    cparams.n_threads        = params.n_threads;
    cparams.n_threads_batch  = params.n_threads_batch;
    cparams.yarn_ext_factor  = params.yarn_ext_factor  >= 0.0f ? params.yarn_ext_factor  : hparams.yarn_ext_factor;
    cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
    cparams.yarn_beta_fast   = params.yarn_beta_fast   >= 0.0f ? params.yarn_beta_fast   : hparams.yarn_beta_fast;
    cparams.yarn_beta_slow   = params.yarn_beta_slow   >= 0.0f ? params.yarn_beta_slow   : hparams.yarn_beta_slow;
    cparams.embeddings       = params.embeddings;
    cparams.offload_kqv      = params.offload_kqv;
    cparams.no_perf          = params.no_perf;
    cparams.pooling_type     = params.pooling_type;
    cparams.warmup           = false;

    cparams.n_ctx           = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
    cparams.rope_freq_base  = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
    cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

    cparams.n_ctx_orig_yarn = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
                              hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
                                                             hparams.n_ctx_train;

    cparams.cb_eval           = params.cb_eval;
    cparams.cb_eval_user_data = params.cb_eval_user_data;

    auto rope_scaling_type = params.rope_scaling_type;
    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
        rope_scaling_type = hparams.rope_scaling_type_train;
    }

    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
    }

    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
    }

    cparams.yarn_attn_factor *= hparams.rope_attn_factor;

    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
        } else {
            cparams.pooling_type = hparams.pooling_type;
        }
    }

    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
        cparams.causal_attn = hparams.causal_attn;
    } else {
        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
    }

    cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;

    // with causal attention, the batch size is limited by the context size
    cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
        cparams.n_batch = GGML_KQ_MASK_PAD;
    }
    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
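    // illustrative example: with n_batch = 2048 and n_ubatch = 512, a logical batch of up to
    // 2048 tokens is evaluated as a sequence of up to 4 micro-batches of 512 tokens each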

    cparams.op_offload = params.op_offload;
    cparams.kv_unified = params.kv_unified;

    {
        const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
        graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;

        if (graph_reuse_disable) {
            LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__);
        }
    }

    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);

    if (cparams.kv_unified) {
        cparams.n_ctx_seq = cparams.n_ctx;
    } else {
        cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
        cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);

        if (cparams.n_ctx_seq == 0) {
            throw std::runtime_error("n_ctx_seq == 0");
        }

        if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
            cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
            LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - adjusting to %u\n", __func__, cparams.n_ctx);
        }
    }
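    // illustrative example: n_ctx = 8192, n_seq_max = 4, kv_unified = false -> n_ctx_seq = 2048 tokens per sequence;
    // with a unified KV cache (kv_unified = true) all sequences share the full n_ctx instead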

    LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
    LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
    LLAMA_LOG_INFO("%s: n_ctx_seq     = %u\n",   __func__, cparams.n_ctx_seq);
    LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
    LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
    LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
    LLAMA_LOG_INFO("%s: flash_attn    = %s\n",   __func__, llama_flash_attn_type_name(params.flash_attn_type));
    LLAMA_LOG_INFO("%s: kv_unified    = %s\n",   __func__, cparams.kv_unified ? "true" : "false");
    LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
    LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);

    if (cparams.n_ctx_seq < hparams.n_ctx_train) {
        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
    }

    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
    }

    if (!hparams.vocab_only) {
        // GPU backends
        for (auto * dev : model.devices) {
            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
            if (backend == nullptr) {
                throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
            }
            backends.emplace_back(backend);
        }

        // add ACCEL backends (such as BLAS)
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
                if (backend == nullptr) {
                    throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
                }
                backends.emplace_back(backend);
            }
        }

        // add CPU backend
        backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
        if (backend_cpu == nullptr) {
            throw std::runtime_error("failed to initialize CPU backend");
        }
        backends.emplace_back(backend_cpu);

        // create a list of the set_n_threads functions in the backends
        for (auto & backend : backends) {
            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
            if (reg) {
                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
                if (ggml_backend_set_n_threads_fn) {
                    set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
                }
            }
        }

        llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);

        // graph outputs buffer
        {
            // resized during inference when a batch uses more outputs
            if (output_reserve(params.n_seq_max) < params.n_seq_max) {
                throw std::runtime_error("failed to reserve initial output buffer");
            }

            LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
                    ggml_backend_buffer_name    (buf_output.get()),
                    ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
        }
    }

    // init the memory module
    if (!hparams.vocab_only) {
        llama_memory_params params_mem = {
            /*.type_k   =*/ params.type_k,
            /*.type_v   =*/ params.type_v,
            /*.swa_full =*/ params.swa_full,
        };

        memory.reset(model.create_memory(params_mem, cparams));
    }

    // init backends
    if (!hparams.vocab_only) {
        LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__);

        backend_buft.clear();
        backend_ptrs.clear();

        for (auto & backend : backends) {
            auto * buft = ggml_backend_get_default_buffer_type(backend.get());
            auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));

            if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
                // for the CPU backend, use the host buffer of the first device for faster transfer of the intermediate state
                auto * dev = model.devices[0];
                auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
                if (host_buft) {
                    buft = host_buft;
                }
            }

            backend_buft.push_back(buft);
            backend_ptrs.push_back(backend.get());
        }

        LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());

        const size_t max_nodes = this->graph_max_nodes();

        LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);

        gf_res_prev.reset(new llm_graph_result(max_nodes));
        gf_res_reserve.reset(new llm_graph_result(max_nodes));

        // TODO: move these checks to ggml_backend_sched
        // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
        bool pipeline_parallel =
            model.n_devices() > 1 &&
            model.params.n_gpu_layers > (int) model.hparams.n_layer &&
            model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
            cparams.offload_kqv &&
            !model.has_tensor_overrides();

        // pipeline parallelism requires support for async compute and events in all devices
        if (pipeline_parallel) {
            for (auto & backend : backends) {
                auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
                if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
                    // ignore CPU backend
                    continue;
                }
                auto * dev = ggml_backend_get_device(backend.get());
                ggml_backend_dev_props props;
                ggml_backend_dev_get_props(dev, &props);
                if (!props.caps.async || !props.caps.events) {
                    // device does not support async compute or events
                    pipeline_parallel = false;
                    break;
                }
            }
        }

        sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));

        if (pipeline_parallel) {
            LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
        }

        llama_memory_context_ptr mctx;
        if (memory) {
            LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
            mctx = memory->init_full();
            if (!mctx) {
                throw std::runtime_error("failed to initialize memory module");
            }
        }

        cross.v_embd.clear();

        const uint32_t n_seqs   = cparams.kv_unified ? 1 : cparams.n_seq_max;
        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        // avoid reserving graphs with zero outputs - assume one output per sequence
        n_outputs = n_seqs;

        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);

        // resolve automatic Flash Attention use
        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
            if (!gf) {
                throw std::runtime_error("failed to split graph for Flash Attention check");
            }

            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
            bool fa_device_mismatch = false;
            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
                ggml_tensor * n = ggml_graph_node(gf, i);
                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
                    continue;
                }
                ggml_backend_dev_t device_fa = ggml_backend_get_device(
                    ggml_backend_sched_get_tensor_backend(sched.get(), n));

                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
                const int il = std::stoi(n->name + prefix_len);
                ggml_backend_dev_t device_kv = model.dev_layer(il);
                if (device_fa != device_kv) {
                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
                            "is assigned to device %s (usually due to missing support)\n",
                            __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
                    // FIXME: the fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyway
                    fa_device_mismatch = true;
                    break;
                }
            }
            if (fa_device_mismatch) {
                cparams.flash_attn = false;
                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
                if (ggml_is_quantized(params.type_v)) {
                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
                }
            } else {
                cparams.flash_attn = true;
                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
            }
        }

        // reserve worst-case graph
        int n_splits_pp = -1;
        int n_nodes_pp  = -1;

        int n_splits_tg = -1;
        int n_nodes_tg  = -1;

        // reserve pp (prompt processing) graph first so that buffers are only allocated once
        {
            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
            if (!gf) {
                if (pipeline_parallel) {
                    LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
                    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
                    gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
                }
                if (!gf) {
                    throw std::runtime_error("failed to allocate compute pp buffers");
                }
            }

            n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
            n_nodes_pp  = ggml_graph_n_nodes(gf);
        }

        // reserve with tg (token generation) graph to get the number of splits and nodes
        {
            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
            if (!gf) {
                throw std::runtime_error("failed to allocate compute tg buffers");
            }

            n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
            n_nodes_tg  = ggml_graph_n_nodes(gf);
        }

        // reserve again with pp graph to avoid ggml-alloc reallocations during inference
        {
            // TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
            //
            //   auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
            //
            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
            if (!gf) {
                throw std::runtime_error("failed to allocate compute pp buffers");
            }
        }

        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
            ggml_backend_t             backend = backend_ptrs[i];
            ggml_backend_buffer_type_t buft    = backend_buft[i];
            size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
            if (size > 1) {
                LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
                        ggml_backend_buft_name(buft),
                        size / 1024.0 / 1024.0);
            }
        }

        if (n_nodes_pp == n_nodes_tg) {
            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, n_nodes_pp);
        } else {
            LLAMA_LOG_INFO("%s: graph nodes  = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
        }

        if (n_splits_pp == n_splits_tg) {
            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
        } else {
            LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
        }
    }
}

llama_context::~llama_context() {
    ggml_opt_free(opt_ctx);
}

void llama_context::synchronize() {
    ggml_backend_sched_synchronize(sched.get());

    // FIXME: if multiple single tokens are evaluated without a synchronization,
    //        the stats will be added to the prompt evaluation stats
    //        this should only happen when using batch size 1 to evaluate a batch

    // add the evaluation to the stats
    if (n_queued_tokens == 1) {
        if (!cparams.no_perf) {
            t_eval_us += ggml_time_us() - t_compute_start_us;
        }
        n_eval++;
    } else if (n_queued_tokens > 1) {
        if (!cparams.no_perf) {
            t_p_eval_us += ggml_time_us() - t_compute_start_us;
        }
        n_p_eval += n_queued_tokens;
    }

    // get a more accurate load time, upon first eval
    if (n_queued_tokens > 0 && !has_evaluated_once) {
        t_load_us = ggml_time_us() - t_start_us;
        has_evaluated_once = true;
    }

    n_queued_tokens = 0;
    t_compute_start_us = 0;
}

const llama_model & llama_context::get_model() const {
    return model;
}

const llama_cparams & llama_context::get_cparams() const {
    return cparams;
}

ggml_backend_sched_t llama_context::get_sched() const {
    return sched.get();
}

uint32_t llama_context::n_ctx() const {
    return cparams.n_ctx;
}

uint32_t llama_context::n_ctx_seq() const {
    return cparams.n_ctx_seq;
}

uint32_t llama_context::n_batch() const {
    return cparams.n_batch;
}

uint32_t llama_context::n_ubatch() const {
    return cparams.n_ubatch;
}

uint32_t llama_context::n_seq_max() const {
    return cparams.n_seq_max;
}

uint32_t llama_context::n_threads() const {
    return cparams.n_threads;
}

uint32_t llama_context::n_threads_batch() const {
    return cparams.n_threads_batch;
}

llama_memory_t llama_context::get_memory() const {
    return memory.get();
}

bool llama_context::memory_update(bool optimize) {
    if (!memory) {
        return false;
    }

    {
        const auto mctx = memory->init_update(this, optimize);
        switch (mctx->get_status()) {
            case LLAMA_MEMORY_STATUS_SUCCESS:
                {
                    // noop
                } break;
            case LLAMA_MEMORY_STATUS_NO_UPDATE:
                {
                    // no updates need to be performed
                    return false;
                }
            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                {
                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
                    return false;
                }
        }

        // reset the previous graph result to make sure that it won't be reused
        // TODO: change mctx->apply() to return whether a graph reserve is needed and
        //       reset the graph result only if the memory module did reset the scheduler
        gf_res_prev->reset();

        if (!mctx->apply()) {
            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
        }
    }

    // if the memory module did any computation, we have to reserve a new worst-case graph
    {
        const auto mctx = memory->init_full();
        if (!mctx) {
            throw std::runtime_error("failed to initialize memory context");
        }

        const uint32_t n_seqs   = cparams.kv_unified ? 1 : cparams.n_seq_max;
        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
        if (!gf) {
            LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
        }
    }

    return true;
}

enum llama_pooling_type llama_context::pooling_type() const {
    return cparams.pooling_type;
}

float * llama_context::get_logits() {
    output_reorder();

    return logits;
}

float * llama_context::get_logits_ith(int32_t i) {
    int64_t j = -1;

    output_reorder();

    try {
        if (logits == nullptr) {
            throw std::runtime_error("no logits");
        }

        if (i < 0) {
            j = n_outputs + i;
            if (j < 0) {
                throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
            }
        } else if ((size_t) i >= output_ids.size()) {
            throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
        } else {
            j = output_ids[i];
        }

        if (j < 0) {
            throw std::runtime_error(format("batch.logits[%d] != true", i));
        }
        if (j >= n_outputs) {
            // this should not happen
            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
        }

        return logits + j*model.vocab.n_tokens();
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
        GGML_ABORT("fatal error");
#else
        return nullptr;
#endif
    }
}
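
// minimal usage sketch via the public C API (assumes the last token of the batch had its logits flag set):
//
//   llama_decode(ctx, batch);
//   float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); // maps to the output row for that position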

float * llama_context::get_embeddings() {
    output_reorder();

    return embd;
}

float * llama_context::get_embeddings_ith(int32_t i) {
    int64_t j = -1;

    output_reorder();

    try {
        if (embd == nullptr) {
            throw std::runtime_error("no embeddings");
        }

        if (i < 0) {
            j = n_outputs + i;
            if (j < 0) {
                throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
            }
        } else if ((size_t) i >= output_ids.size()) {
            throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
        } else {
            j = output_ids[i];
        }

        if (j < 0) {
            throw std::runtime_error(format("batch.logits[%d] != true", i));
        }
        if (j >= n_outputs) {
            // this should not happen
            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
        }

        return embd + j*model.hparams.n_embd;
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
        GGML_ABORT("fatal error");
#else
        return nullptr;
#endif
    }
}

float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
    auto it = embd_seq.find(seq_id);
    if (it == embd_seq.end()) {
        return nullptr;
    }

    return it->second.data();
}

void llama_context::attach_threadpool(
        ggml_threadpool_t threadpool,
        ggml_threadpool_t threadpool_batch) {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    this->threadpool       = threadpool;
    this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
}

void llama_context::detach_threadpool() {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    this->threadpool       = nullptr;
    this->threadpool_batch = nullptr;
}

void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) {
    LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch);

    cparams.n_threads       = n_threads;
    cparams.n_threads_batch = n_threads_batch;
}

void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    this->abort_callback      = abort_callback;
    this->abort_callback_data = abort_callback_data;

    for (auto & backend : backends) {
        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
        auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
        if (set_abort_callback_fn) {
            set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
        }
    }
}

void llama_context::set_embeddings(bool value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    cparams.embeddings = value;
}

void llama_context::set_causal_attn(bool value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    cparams.causal_attn = value;
}

void llama_context::set_warmup(bool value) {
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

    cparams.warmup = value;
}

void llama_context::set_adapter_lora(
        llama_adapter_lora * adapter,
        float scale) {
    LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);

    loras[adapter] = scale;
}

bool llama_context::rm_adapter_lora(
        llama_adapter_lora * adapter) {
    LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);

    auto pos = loras.find(adapter);
    if (pos != loras.end()) {
        loras.erase(pos);
        return true;
    }

    return false;
}

void llama_context::clear_adapter_lora() {
    LLAMA_LOG_DEBUG("%s: call\n", __func__);

    loras.clear();
}

bool llama_context::apply_adapter_cvec(
        const float * data,
        size_t        len,
        int32_t       n_embd,
        int32_t       il_start,
        int32_t       il_end) {
    LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);

    return cvec.apply(model, data, len, n_embd, il_start, il_end);
}

llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
    if (mctx && !mctx->apply()) {
        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
        ret = GGML_STATUS_FAILED;
        return nullptr;
    }

    auto * res = gf_res_prev.get();
    auto * gf  = res->get_gf();

    // the new graph parameters
    // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
    const auto gparams = graph_params(res, ubatch, mctx, gtype);

    if (!graph_reuse_disable && res->can_reuse(gparams)) {
        //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);

        n_reused++;
    } else {
        res->reset();

        ggml_backend_sched_reset(sched.get());
        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

        //const auto t_start_us = ggml_time_us();

        gf = model.build_graph(gparams);

        //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);

        if (!gf) {
            LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
            ret = GGML_STATUS_FAILED;
            return nullptr;
        }

        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
            ret = GGML_STATUS_ALLOC_FAILED;
            return nullptr;
        }
    }

    // set the input data for the input tensors
    {
        //const auto t_start_us = ggml_time_us();

        res->set_inputs(&ubatch);

        //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
    }

    const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
    if (status != GGML_STATUS_SUCCESS) {
        LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
        ret = status;
        return nullptr;
    }

    ret = GGML_STATUS_SUCCESS;

    return res;
}

int llama_context::encode(const llama_batch & batch_inp) {
    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT

    if (batch_inp.n_tokens == 0) {
        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
        return -1;
    }

    const auto & hparams = model.hparams;

    const int64_t n_embd  = hparams.n_embd_inp();
    const int64_t n_vocab = model.vocab.n_tokens();

    // note: during encode, we always pass the full sequence starting from pos = 0
    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
        return -1;
    }

    const uint32_t n_tokens = balloc->get_n_tokens();

    // [TAG_NO_CACHE_PAD]
    // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
    const llama_ubatch ubatch = balloc->split_simple(n_tokens);

    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");

    if (t_compute_start_us == 0) {
        t_compute_start_us = ggml_time_us();
    }

    // TODO: this clear of the buffer can easily be forgotten - need something better
    embd_seq.clear();

    n_queued_tokens += n_tokens;

    // reserve output buffer
    if (output_reserve(n_tokens) < n_tokens) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
        return -2;
    }

    for (uint32_t i = 0; i < n_tokens; ++i) {
        output_ids[i] = i;
    }

    n_outputs = n_tokens;

    const auto causal_attn_org = cparams.causal_attn;

    // always use non-causal attention for encoder graphs
    // TODO: this is a tmp solution until we have a proper way to support enc-dec models
    //       ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
    cparams.causal_attn = false;

    ggml_status status;
    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);

    cparams.causal_attn = causal_attn_org;

    if (!res) {
        switch (status) {
            case GGML_STATUS_ABORTED:      return  2;
            case GGML_STATUS_ALLOC_FAILED: return -2;
            case GGML_STATUS_FAILED:       return -3;
            case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
        }
    }

    auto * t_logits = res->get_logits();
    auto * t_embd   = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();

    // extract logits
    if (logits && t_logits) {
        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
        GGML_ASSERT(backend_res != nullptr);
        GGML_ASSERT(logits != nullptr);

        ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
    }

    // extract embeddings
    if (embd && t_embd) {
        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
        GGML_ASSERT(backend_embd != nullptr);

        switch (cparams.pooling_type) {
            case LLAMA_POOLING_TYPE_NONE:
                {
                    // extract token embeddings
                    GGML_ASSERT(embd != nullptr);

                    GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
                } break;
            case LLAMA_POOLING_TYPE_MEAN:
            case LLAMA_POOLING_TYPE_CLS:
            case LLAMA_POOLING_TYPE_LAST:
                {
                    // extract sequence embeddings
                    auto & embd_seq_out = embd_seq;

                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
                        const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
                        const int32_t      seq_idx = ubatch.seq_idx[seq_id];

                        embd_seq_out[seq_id].resize(n_embd);
                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                    }
                } break;
            case LLAMA_POOLING_TYPE_RANK:
                {
                    // extract the rerank score - n_cls_out floats per sequence
                    auto & embd_seq_out = embd_seq;

                    const uint32_t n_cls_out = hparams.n_cls_out;

                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
                        const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
                        const int32_t      seq_idx = ubatch.seq_idx[seq_id];

                        embd_seq_out[seq_id].resize(n_cls_out);
                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                    }
                } break;
            case LLAMA_POOLING_TYPE_UNSPECIFIED:
                {
                    GGML_ABORT("unknown pooling type");
                }
        }
    }

    // TODO: hacky solution
    if (model.arch == LLM_ARCH_T5 && t_embd) {
        //cross.t_embd = t_embd;

        synchronize();

        cross.n_embd = t_embd->ne[0];
        cross.n_enc  = t_embd->ne[1];
        cross.v_embd.resize(cross.n_embd*cross.n_enc);
        memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));

        const auto & batch = balloc->get_batch();

        // remember the sequence ids used during the encoding - needed for cross attention later
        cross.seq_ids_enc.resize(n_tokens);
        for (uint32_t i = 0; i < n_tokens; i++) {
            cross.seq_ids_enc[i].clear();

            for (int s = 0; s < batch.n_seq_id[i]; s++) {
                const llama_seq_id seq_id = batch.seq_id[i][s];

                cross.seq_ids_enc[i].insert(seq_id);
            }
        }
    }

    return 0;
}

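// decode() return values (summarizing the error paths below; see also llama_decode in the public API):
//    0 - success
//    1 - could not find a memory slot for the batch (try reducing the batch size)
//    2 - aborted by the abort callback
//   -1 - invalid input batch
//   -2 - allocation / memory-module failure
//   -3 - graph compute failure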
int llama_context::decode(const llama_batch & batch_inp) {
    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT

    if (!memory) {
        LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
        return encode(batch_inp);
    }

    if (batch_inp.n_tokens == 0) {
        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
        return -1;
    }

    const auto & vocab   = model.vocab;
    const auto & hparams = model.hparams;

    const int64_t n_vocab = vocab.n_tokens();
    const int64_t n_embd  = hparams.n_embd_inp();

    // when computing embeddings, all tokens are output
    const bool output_all = cparams.embeddings;

    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
        return -1;
    }

    const uint32_t n_tokens_all  = balloc->get_n_tokens();
    const uint32_t n_outputs_all = balloc->get_n_outputs();

    if (output_all) {
        // require that all tokens are output
        if (n_outputs_all != n_tokens_all) {
            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
                    __func__, n_outputs_all, n_tokens_all);
            return -1;
        }
    }

    GGML_ASSERT(n_tokens_all <= cparams.n_batch);

    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");

    if (t_compute_start_us == 0) {
        t_compute_start_us = ggml_time_us();
    }
    n_queued_tokens += n_tokens_all;

    // TODO: this clear of the buffer can easily be forgotten - need something better
    embd_seq.clear();
    output_swaps.clear();

    bool did_optimize = false;

    // handle any pending shifts/copies
    memory_update(false);

    llama_memory_context_ptr mctx;

    while (true) {
        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
        if (!mctx) {
            return -2;
        }

        switch (mctx->get_status()) {
            case LLAMA_MEMORY_STATUS_SUCCESS:
                {
                } break;
            case LLAMA_MEMORY_STATUS_NO_UPDATE:
                {
                    LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());

                    return -2;
                }
            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                {
                    if (!did_optimize) {
                        did_optimize = true;

                        if (memory_update(true)) {
                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());

                            continue;
                        }
                    }

                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());

                    return 1;
                }
            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                {
                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());

                    return -2;
                }
        }

        break;
    }

    // reserve output buffer
    if (output_reserve(n_outputs_all) < n_outputs_all) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
        return -2;
    }

    int64_t n_outputs_prev = 0;

    do {
        const auto & ubatch = mctx->get_ubatch();

        // count the outputs in this ubatch
        {
            int32_t n_outputs_new = 0;

            if (n_outputs_all == n_tokens_all) {
                n_outputs_new = ubatch.n_tokens;
            } else {
                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
                }
            }

            // needs to happen before the graph is built
            n_outputs = n_outputs_new;
        }

        ggml_status status;
        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

        if (!res) {
            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
            llama_pos pos_min[LLAMA_MAX_SEQ];
            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                pos_min[s] = std::numeric_limits<llama_pos>::max();
            }

            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                const auto & seq_id = ubatch.seq_id[i][0];

                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
            }

            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
                    continue;
                }

                LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);

                memory->seq_rm(s, pos_min[s], -1);
            }

            switch (status) {
                case GGML_STATUS_ABORTED:      return  2;
                case GGML_STATUS_ALLOC_FAILED: return -2;
                case GGML_STATUS_FAILED:       return -3;
                case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
            }
        }

        // plot the computation graph in dot format (for debugging purposes)
        //if (n_past%100 == 0) {
        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
        //}

        auto * t_logits = res->get_logits();
        auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;

        if (t_embd && res->get_embd_pooled()) {
            t_embd = res->get_embd_pooled();
        }

        // extract logits
        if (t_logits && n_outputs > 0) {
            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
            GGML_ASSERT(backend_res != nullptr);
            GGML_ASSERT(logits != nullptr);

            float * logits_out = logits + n_outputs_prev*n_vocab;

            if (n_outputs) {
                GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
                ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
            }
        }

        // extract embeddings
        if (t_embd && n_outputs > 0) {
            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
            GGML_ASSERT(backend_embd != nullptr);

            switch (cparams.pooling_type) {
                case LLAMA_POOLING_TYPE_NONE:
                    {
                        // extract token embeddings
                        GGML_ASSERT(embd != nullptr);
                        float * embd_out = embd + n_outputs_prev*n_embd;

                        if (n_outputs) {
                            GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_MEAN:
                case LLAMA_POOLING_TYPE_CLS:
                case LLAMA_POOLING_TYPE_LAST:
                    {
                        // extract sequence embeddings (cleared before processing each batch)
                        auto & embd_seq_out = embd_seq;

                        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
                            const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
                            const int32_t      seq_idx = ubatch.seq_idx[seq_id];

                            embd_seq_out[seq_id].resize(n_embd);
                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_RANK:
                    {
                        // extract the rerank score - n_cls_out floats per sequence
                        auto & embd_seq_out = embd_seq;

                        const uint32_t n_cls_out = hparams.n_cls_out;

                        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
                            const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
                            const int32_t      seq_idx = ubatch.seq_idx[seq_id];

                            embd_seq_out[seq_id].resize(n_cls_out);
                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_UNSPECIFIED:
                    {
                        GGML_ABORT("unknown pooling type");
                    }
            }
        }

        n_outputs_prev += n_outputs;
    } while (mctx->next());

    // set to total number of outputs in the batch, for use in llama_get_logits_ith
    n_outputs = n_outputs_all;

    // set output mappings
    if (n_outputs > 0) {
        bool sorted_output = true;

        auto & out_ids = balloc->get_out_ids();

        GGML_ASSERT(out_ids.size() == (size_t) n_outputs);

        for (int64_t i = 0; i < n_outputs; ++i) {
            int64_t out_id = out_ids[i];
            output_ids[out_id] = i;
            if (out_id != i) {
                sorted_output = false;
            }
        }
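
        // e.g. if the ubatch splitting produced outputs in the order out_ids = { 2, 0, 1 },
        //      then output_ids becomes { 1, 2, 0 }, i.e. output_ids[i] is the row in the output
        //      buffers that corresponds to position i of the user-provided batch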

        // make the outputs have the same order they had in the user-provided batch
        // note: this is mostly relevant for recurrent models atm
        if (!sorted_output) {
            GGML_ASSERT((size_t) n_outputs == out_ids.size());

            // TODO: is there something more efficient which also minimizes swaps?
            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
            for (uint32_t i = 0; i < n_outputs - 1; ++i) {
                uint32_t j_min = i;
                for (uint32_t j = i + 1; j < n_outputs; ++j) {
                    if (out_ids[j] < out_ids[j_min]) {
                        j_min = j;
                    }
                }
                if (j_min == i) {
                    continue;
                }
                std::swap(out_ids[i], out_ids[j_min]);

                // remember the swaps and apply them lazily upon logits/embeddings access
                output_swaps.push_back({ i, j_min });
            }

            std::fill(output_ids.begin(), output_ids.end(), -1);

            for (uint32_t i = 0; i < n_outputs; ++i) {
                output_ids[out_ids[i]] = i;
            }
        }
    }

    // wait for the computation to finish (automatically done when obtaining the model output)
    //synchronize();

    return 0;
}

//
// output
//

uint32_t llama_context::output_reserve(int32_t n_outputs) {
    const auto & hparams = model.hparams;
    const auto & vocab   = model.vocab;

    const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());

    const auto n_batch = cparams.n_batch;
    const auto n_vocab = vocab.n_tokens();
    const auto n_embd  = hparams.n_embd;

    bool has_logits = true;
    bool has_embd   = cparams.embeddings;

    // TODO: hacky enc-dec support
    if (model.arch == LLM_ARCH_T5) {
        has_logits = true;
        has_embd   = true;
    }

    logits_size = has_logits ? n_vocab*n_outputs_max : 0;
    embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
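
    // rough sizing example (illustrative): n_vocab = 32000, n_outputs_max = 8
    //   -> the logits buffer holds 8 rows of 32000 floats (~1 MiB); embd is sized analogously from n_embd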

    if (output_ids.empty()) {
        // init, never resized afterwards
        output_ids.resize(n_batch);
    }

    const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
    const size_t new_size  = (logits_size + embd_size) * sizeof(float);

    // alloc only when more than the current capacity is required
    // TODO: also consider shrinking the buffer
    if (!buf_output || prev_size < new_size) {
        if (buf_output) {
#ifndef NDEBUG
            // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
#endif
            buf_output = nullptr;
            logits = nullptr;
            embd = nullptr;
        }

        auto * buft = ggml_backend_cpu_buffer_type();
        // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
        auto * output_dev = model.dev_output();
        auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
        if (output_dev_host_buft) {
            buft = output_dev_host_buft;
        }
        buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
        if (buf_output == nullptr) {
            LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
            return 0;
        }
    }

    float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());

    logits = has_logits ? output_base               : nullptr;
    embd   = has_embd   ? output_base + logits_size : nullptr;

    // set all ids as invalid (negative)
    std::fill(output_ids.begin(), output_ids.end(), -1);

    this->n_outputs = 0;

    return n_outputs_max;
}

void llama_context::output_reorder() {
    const uint64_t n_vocab = model.vocab.n_tokens();
    const uint64_t n_embd  = model.hparams.n_embd;

    for (size_t s = 0; s < output_swaps.size(); ++s) {
        const uint64_t i0 = output_swaps[s].i0;
        const uint64_t i1 = output_swaps[s].i1;

        if (logits_size > 0) {
            for (uint64_t k = 0; k < n_vocab; k++) {
                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
            }
        }

        if (embd_size > 0) {
            for (uint64_t k = 0; k < n_embd; k++) {
                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
            }
        }
    }

    output_swaps.clear();
}

//
// graph
//

uint32_t llama_context::graph_max_nodes() const {
    return std::max<uint32_t>(1024u, 8u*model.n_tensors());
}

llm_graph_result * llama_context::get_gf_res_reserve() const {
    return static_cast<llm_graph_result *>(gf_res_reserve.get());
}

ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
    LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
    GGML_ASSERT(n_outputs >= 1);

    if (n_tokens % n_seqs != 0) {
        n_tokens  = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
        n_outputs = std::min(n_outputs, n_tokens);

        LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
    }

    ggml_backend_sched_reset(sched.get());

    // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
    gf_res_prev->reset();

    // store the n_outputs as it is, and restore it afterwards
    // TODO: not sure if needed, might simplify in the future by removing this
    const auto save_n_outputs = this->n_outputs;

    this->n_outputs = n_outputs;

    llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
    llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);

    auto * res = gf_res_reserve.get();

    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);

    res->reset();

    auto * gf = model.build_graph(gparams);

    this->n_outputs = save_n_outputs;

    // initialize scheduler with the specified graph
    if (split_only) {
        ggml_backend_sched_split_graph(sched.get(), gf);
    } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        return nullptr;
    }

    return gf;
}

llm_graph_params llama_context::graph_params(
        llm_graph_result * res,
        const llama_ubatch & ubatch,
        const llama_memory_context_i * mctx,
        llm_graph_type gtype) const {
    return {
        /*.arch        =*/ model.arch,
        /*.hparams     =*/ model.hparams,
        /*.cparams     =*/ cparams,
        /*.ubatch      =*/ ubatch,
        /*.gtype       =*/ gtype,
        /*.sched       =*/ sched.get(),
        /*.backend_cpu =*/ backend_cpu,
        /*.cvec        =*/ &cvec,
        /*.loras       =*/ &loras,
        /*.mctx        =*/ mctx,
        /*.cross       =*/ &cross,
        /*.n_outputs   =*/ n_outputs,
        /*.cb          =*/ graph_get_cb(),
        /*.res         =*/ res,
    };
}

ggml_status llama_context::graph_compute(
        ggml_cgraph * gf,
        bool batched) {
    int n_threads        = batched ? cparams.n_threads_batch : cparams.n_threads;
    ggml_threadpool_t tp = batched ? threadpool_batch        : threadpool;

    if (backend_cpu != nullptr) {
        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
        if (set_threadpool_fn) {
            set_threadpool_fn(backend_cpu, tp);
        }
    }

    // set the number of threads for all the backends
    for (const auto & set_n_threads_fn : set_n_threads_fns) {
        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
    }

    auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf);
    if (status != GGML_STATUS_SUCCESS) {
        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
    }

    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));

    return status;
}

llm_graph_cb llama_context::graph_get_cb() const {
    return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
        if (il >= 0) {
            ggml_format_name(cur, "%s-%d", name, il);
        } else {
            ggml_set_name(cur, name);
        }

        if (!cparams.offload_kqv) {
            if (strcmp(name, "kqv_merged_cont") == 0) {
                // all nodes between the KV store and the attention output are run on the CPU
                ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
            }
        }

        // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
        // FIXME: fix in ggml_backend_sched
        const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
        if (ubatch.n_tokens < 32 || full_offload) {
            if (il != -1 && strcmp(name, "norm") == 0) {
                const auto & dev_layer = model.dev_layer(il);
                for (const auto & backend : backends) {
                    if (ggml_backend_get_device(backend.get()) == dev_layer) {
                        if (ggml_backend_supports_op(backend.get(), cur)) {
                            ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
                        }
                    }
                }
            }
        }
    };
}

//
// state save/load
//

class llama_io_write_dummy : public llama_io_write_i {
public:
    llama_io_write_dummy() = default;

    void write(const void * /* src */, size_t size) override {
        size_written += size;
    }

    void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
        size_written += size;
    }

    size_t n_bytes() override {
        return size_written;
    }

private:
    size_t size_written = 0;
};

class llama_io_write_buffer : public llama_io_write_i {
public:
    llama_io_write_buffer(
            uint8_t * p, size_t len) : ptr(p), buf_size(len) {}

    void write(const void * src, size_t size) override {
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        memcpy(ptr, src, size);
        ptr += size;
        size_written += size;
        buf_size -= size;
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        ggml_backend_tensor_get(tensor, ptr, offset, size);
        ptr += size;
        size_written += size;
        buf_size -= size;
    }

    size_t n_bytes() override {
        return size_written;
    }

private:
    uint8_t * ptr;
    size_t buf_size = 0;
    size_t size_written = 0;
};

class llama_io_read_buffer : public llama_io_read_i {
public:
    llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}

    const uint8_t * read(size_t size) override {
        const uint8_t * base_ptr = ptr;
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        ptr += size;
        size_read += size;
        buf_size -= size;
        return base_ptr;
    }

    void read_to(void * dst, size_t size) override {
        memcpy(dst, read(size), size);
    }

    size_t n_bytes() override {
        return size_read;
    }

private:
    const uint8_t * ptr;
    size_t buf_size = 0;
    size_t size_read = 0;
};

class llama_io_write_file : public llama_io_write_i {
public:
    llama_io_write_file(llama_file * f) : file(f) {}

    void write(const void * src, size_t size) override {
        file->write_raw(src, size);
        size_written += size;
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        temp_buffer.resize(size);
        ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
        write(temp_buffer.data(), temp_buffer.size());
    }

    size_t n_bytes() override {
        return size_written;
    }

private:
    llama_file * file;
    size_t size_written = 0;
    std::vector<uint8_t> temp_buffer;
};

class llama_io_read_file : public llama_io_read_i {
public:
    llama_io_read_file(llama_file * f) : file(f) {}

    void read_to(void * dst, size_t size) override {
        file->read_raw(dst, size);
        size_read += size;
    }

    const uint8_t * read(size_t size) override {
        temp_buffer.resize(size);
        read_to(temp_buffer.data(), size);
        return temp_buffer.data();
    }

    size_t n_bytes() override {
        return size_read;
    }

private:
    llama_file * file;
    size_t size_read = 0;
    std::vector<uint8_t> temp_buffer;
};

size_t llama_context::state_get_size() {
    llama_io_write_dummy io;
    try {
        return state_write_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
    llama_io_write_buffer io(dst, size);
    try {
        return state_write_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
    llama_io_read_buffer io(src, size);
    try {
        return state_read_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
    llama_io_write_dummy io;
    try {
        return state_seq_write_data(io, seq_id, flags);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
    llama_io_write_buffer io(dst, size);
    try {
        return state_seq_write_data(io, seq_id, flags);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
    llama_io_read_buffer io(src, size);
    try {
        return state_seq_read_data(io, seq_id, flags);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
        return 0;
    }
}
1724
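// Session files written by state_save_file() have the layout:
//   [magic u32][version u32][n_token_count u32][tokens ...][serialized context state]
// state_load_file() validates the magic/version pair, copies the prompt tokens
// into the caller's buffer and then restores the remaining bytes as context state.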
1725bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1726 llama_file file(filepath, "rb");
1727
1728 // sanity checks
1729 {
1730 const uint32_t magic = file.read_u32();
1731 const uint32_t version = file.read_u32();
1732
1733 if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
1734 LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
1735 return false;
1736 }
1737 }
1738
1739 // load the prompt
1740 {
1741 const uint32_t n_token_count = file.read_u32();
1742
1743 if (n_token_count > n_token_capacity) {
1744 LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
1745 return false;
1746 }
1747
1748 file.read_raw(ptr: tokens_out, len: sizeof(llama_token) * n_token_count);
1749 *n_token_count_out = n_token_count;
1750 }
1751
1752 // restore the context state
1753 {
1754 const size_t n_state_size_cur = file.size() - file.tell();
1755
1756 llama_io_read_file io(&file);
1757 const size_t n_read = state_read_data(io);
1758
1759 if (n_read != n_state_size_cur) {
1760 LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
1761 return false;
1762 }
1763 }
1764
1765 return true;
1766}
1767
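// Minimal save/restore sketch using the public wrappers declared in llama.h
// (the path, lctx and token bookkeeping are illustrative only):
//
//   llama_state_save_file(lctx, "session.bin", tokens.data(), tokens.size());
//   // ...
//   std::vector<llama_token> prompt(n_ctx);
//   size_t n_prompt = 0;
//   llama_state_load_file(lctx, "session.bin", prompt.data(), prompt.size(), &n_prompt);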
1768bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
1769 llama_file file(filepath, "wb");
1770
1771 file.write_u32(LLAMA_SESSION_MAGIC);
1772 file.write_u32(LLAMA_SESSION_VERSION);
1773
1774 // save the prompt
1775 file.write_u32(val: (uint32_t) n_token_count);
1776 file.write_raw(ptr: tokens, len: sizeof(llama_token) * n_token_count);
1777
1778 // save the context state using stream saving
1779 llama_io_write_file io(&file);
1780 state_write_data(io);
1781
1782 return true;
1783}
1784
1785size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1786 llama_file file(filepath, "rb");
1787
1788 // version checks
1789 {
1790 const uint32_t magic = file.read_u32();
1791 const uint32_t version = file.read_u32();
1792
1793 if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
1794 LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
1795 return 0;
1796 }
1797 }
1798
1799 // load the prompt
1800 {
1801 const uint32_t n_token_count = file.read_u32();
1802
1803 if (n_token_count > n_token_capacity) {
1804 LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
1805 return 0;
1806 }
1807
1808 file.read_raw(ptr: tokens_out, len: sizeof(llama_token) * n_token_count);
1809 *n_token_count_out = n_token_count;
1810 }
1811
1812 // restore the context state
1813 {
1814 const size_t state_size = file.size() - file.tell();
1815 llama_io_read_file io(&file);
1816 const size_t nread = state_seq_read_data(io, seq_id, flags: 0);
1817 if (!nread) {
1818 LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
1819 return 0;
1820 }
1821 GGML_ASSERT(nread <= state_size);
1822 GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
1823 }
1824
1825 return file.tell();
1826}
1827
1828size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
1829 llama_file file(filepath, "wb");
1830
1831 file.write_u32(LLAMA_STATE_SEQ_MAGIC);
1832 file.write_u32(LLAMA_STATE_SEQ_VERSION);
1833
1834 // save the prompt
1835 file.write_u32(val: (uint32_t) n_token_count);
1836 file.write_raw(ptr: tokens, len: sizeof(llama_token) * n_token_count);
1837
1838 // save the context state using stream saving
1839 llama_io_write_file io(&file);
1840 state_seq_write_data(io, seq_id, flags: 0);
1841
1842 const size_t res = file.tell();
1843 GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
1844
1845 return res;
1846}
1847
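// Full-context serialization order used by state_write_data()/state_read_data():
// model arch string, output ids, logits, embeddings, then the memory module
// (if any). The reader enforces the same order and the same buffer limits.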
1848size_t llama_context::state_write_data(llama_io_write_i & io) {
1849 LLAMA_LOG_DEBUG("%s: writing state\n", __func__);
1850
1851 // write model info
1852 {
1853 LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__);
1854
1855 const std::string arch_str = llm_arch_name(arch: model.arch);
1856 io.write_string(str: arch_str);
1857 // TODO: add more model-specific info which should prevent loading the session file if not identical
1858 }
1859
1860 // write output ids
1861 {
1862 LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
1863
1864 const auto n_outputs = this->n_outputs;
1865 const auto & output_ids = this->output_ids;
1866
1867 std::vector<int32_t> w_output_pos;
1868
1869 w_output_pos.resize(new_size: n_outputs);
1870
1871 // build a more compact representation of the output ids
1872 for (size_t i = 0; i < n_batch(); ++i) {
1873 // map an output id to a position in the batch
1874 int64_t pos = output_ids[i];
1875 if (pos >= 0) {
1876 GGML_ASSERT(pos < n_outputs);
1877 w_output_pos[pos] = i;
1878 }
1879 }
1880
1881 io.write(src: &n_outputs, size: sizeof(n_outputs));
1882
1883 if (n_outputs) {
1884 io.write(src: w_output_pos.data(), size: n_outputs * sizeof(int32_t));
1885 }
1886 }
1887
1888 // write logits
1889 {
1890 LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
1891
1892 const uint64_t logits_size = std::min(a: (uint64_t) this->logits_size, b: (uint64_t) n_outputs * model.vocab.n_tokens());
1893
1894 io.write(src: &logits_size, size: sizeof(logits_size));
1895
1896 if (logits_size) {
1897 io.write(src: logits, size: logits_size * sizeof(float));
1898 }
1899 }
1900
1901 // write embeddings
1902 {
1903 LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
1904
1905 const uint64_t embd_size = std::min(a: (uint64_t) this->embd_size, b: (uint64_t) n_outputs * model.hparams.n_embd);
1906
1907 io.write(src: &embd_size, size: sizeof(embd_size));
1908
1909 if (embd_size) {
1910 io.write(src: embd, size: embd_size * sizeof(float));
1911 }
1912 }
1913
1914 if (memory != nullptr) {
1915 LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
1916 memory->state_write(io);
1917 }
1918
1919 return io.n_bytes();
1920}
1921
1922size_t llama_context::state_read_data(llama_io_read_i & io) {
1923 LLAMA_LOG_DEBUG("%s: reading state\n", __func__);
1924
1925 // read model info
1926 {
1927 LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__);
1928
1929 const std::string cur_arch_str = llm_arch_name(arch: model.arch);
1930
1931 std::string arch_str;
1932 io.read_string(str&: arch_str);
1933 if (cur_arch_str != arch_str) {
1934 throw std::runtime_error(format(fmt: "wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
1935 }
1936 // TODO: add more info which needs to be identical but which is not verified otherwise
1937 }
1938
1939 // read output ids
1940 {
1941 LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
1942
1943 auto n_outputs = this->n_outputs;
1944 io.read_to(dst: &n_outputs, size: sizeof(n_outputs));
1945
1946 if (n_outputs > output_reserve(n_outputs)) {
1947 throw std::runtime_error("could not reserve outputs");
1948 }
1949
1950 std::vector<int32_t> output_pos;
1951
1952 if (n_outputs) {
1953 output_pos.resize(new_size: n_outputs);
1954 io.read_to(dst: output_pos.data(), size: n_outputs * sizeof(int32_t));
1955
1956 for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
1957 int32_t id = output_pos[i];
1958 if ((uint32_t) id >= n_batch()) {
1959 throw std::runtime_error(format(fmt: "invalid output id, %d does not fit in batch size of %u", id, n_batch()));
1960 }
1961 this->output_ids[id] = i;
1962 }
1963
1964 this->n_outputs = n_outputs;
1965 }
1966 }
1967
1968 // read logits
1969 {
1970 LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
1971
1972 uint64_t logits_size;
1973 io.read_to(dst: &logits_size, size: sizeof(logits_size));
1974
1975 if (this->logits_size < logits_size) {
1976 throw std::runtime_error("logits buffer too small");
1977 }
1978
1979 if (logits_size) {
1980 io.read_to(dst: this->logits, size: logits_size * sizeof(float));
1981 }
1982 }
1983
1984 // read embeddings
1985 {
1986 LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
1987
1988 uint64_t embd_size;
1989 io.read_to(dst: &embd_size, size: sizeof(embd_size));
1990
1991 if (this->embd_size < embd_size) {
1992 throw std::runtime_error("embeddings buffer too small");
1993 }
1994
1995 if (embd_size) {
1996 io.read_to(dst: this->embd, size: embd_size * sizeof(float));
1997 }
1998 }
1999
2000 if (memory) {
2001 LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
2002
2003 memory->state_read(io);
2004 }
2005
2006 return io.n_bytes();
2007}
2008
2009size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
2010 GGML_UNUSED(seq_id);
2011
2012 if (memory) {
2013 memory->state_write(io, seq_id, flags);
2014 }
2015
2016 return io.n_bytes();
2017}
2018
2019size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
2020 GGML_UNUSED(seq_id);
2021
2022 if (memory) {
2023 memory->state_read(io, seq_id, flags);
2024 }
2025
2026 return io.n_bytes();
2027}
2028
2029//
2030// perf
2031//
2032
2033llama_perf_context_data llama_context::perf_get_data() const {
2034 llama_perf_context_data data = {};
2035
2036 data.t_start_ms = 1e-3 * t_start_us;
2037 data.t_load_ms = 1e-3 * t_load_us;
2038 data.t_p_eval_ms = 1e-3 * t_p_eval_us;
2039 data.t_eval_ms = 1e-3 * t_eval_us;
2040 data.n_p_eval = std::max(a: 1, b: n_p_eval);
2041 data.n_eval = std::max(a: 1, b: n_eval);
2042 data.n_reused = std::max(a: 0, b: n_reused);
2043
2044 return data;
2045}
2046
2047void llama_context::perf_reset() {
2048 t_start_us = ggml_time_us();
2049 t_eval_us = n_eval = 0;
2050 t_p_eval_us = n_p_eval = 0;
2051 n_reused = 0;
2052}
2053
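// aggregates memory usage per buffer type: model weights, context buffers owned
// by the memory module, and per-backend compute buffers from the scheduler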
2054std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
2055 std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
2056 for (const auto & buft_size : model.memory_breakdown()) {
2057 ret[buft_size.first].model += buft_size.second;
2058 }
 if (memory) { // the memory module can be absent for some models
2059 for (const auto & buft_size : memory->memory_breakdown()) {
2060 ret[buft_size.first].context += buft_size.second;
2061 }
 }
2062 for (const auto & backend_ptr : backends) {
2063 ggml_backend_t backend = backend_ptr.get();
2064 ret[ggml_backend_sched_get_buffer_type(sched: sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched: sched.get(), backend);
2065 }
2066 return ret;
2067}
2068
2069//
2070// training
2071//
2072
2073static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
2074 if (!tensor || tensor->type != GGML_TYPE_F32) {
2075 return;
2076 }
2077 if (!param_filter(tensor, userdata)) {
2078 return;
2079 }
2080 if (strcmp(s1: tensor->name, s2: "token_embd.weight") == 0) {
2081 return; // FIXME
2082 }
2083 if (strcmp(s1: tensor->name, s2: "rope_freqs.weight") == 0) {
2084 return; // FIXME
2085 }
2086 ggml_set_param(tensor);
2087}
2088
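// Prepares the context for training: fixes the training context length, derives
// the optimizer period from n_batch/n_ubatch, creates the ggml_opt context and
// marks trainable tensors via the user-supplied parameter filter. Note that
// token_embd and rope_freqs are currently skipped in llama_set_param (see FIXME).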
2089void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
2090 GGML_ASSERT(!opt_ctx);
2091 model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
2092 const uint32_t n_batch = std::min(a: this->n_batch(), b: model->hparams.n_ctx_train);
2093 const uint32_t n_ubatch = std::min(a: this->n_ubatch(), b: n_batch);
2094 GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
2095 GGML_ASSERT(n_batch % n_ubatch == 0);
2096
2097 ggml_opt_params opt_params = ggml_opt_default_params(backend_sched: sched.get(), loss_type: GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
2098 opt_params.opt_period = n_batch / n_ubatch;
2099 opt_params.get_opt_pars = lopt_params.get_opt_pars;
2100 opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
2101 opt_params.optimizer = lopt_params.optimizer_type;
2102 opt_ctx = ggml_opt_init(params: opt_params);
2103
2104 llama_opt_param_filter param_filter = lopt_params.param_filter;
2105 void * param_filter_ud = lopt_params.param_filter_ud;
2106
2107 //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
2108 llama_set_param(tensor: model->type_embd, param_filter, userdata: param_filter_ud);
2109 llama_set_param(tensor: model->pos_embd, param_filter, userdata: param_filter_ud);
2110 llama_set_param(tensor: model->tok_norm, param_filter, userdata: param_filter_ud);
2111 llama_set_param(tensor: model->tok_norm_b, param_filter, userdata: param_filter_ud);
2112 llama_set_param(tensor: model->output_norm, param_filter, userdata: param_filter_ud);
2113 llama_set_param(tensor: model->output_norm_b, param_filter, userdata: param_filter_ud);
2114 llama_set_param(tensor: model->output, param_filter, userdata: param_filter_ud);
2115 llama_set_param(tensor: model->output_b, param_filter, userdata: param_filter_ud);
2116 llama_set_param(tensor: model->output_norm_enc, param_filter, userdata: param_filter_ud);
2117 llama_set_param(tensor: model->cls, param_filter, userdata: param_filter_ud);
2118 llama_set_param(tensor: model->cls_b, param_filter, userdata: param_filter_ud);
2119 llama_set_param(tensor: model->cls_out, param_filter, userdata: param_filter_ud);
2120 llama_set_param(tensor: model->cls_out_b, param_filter, userdata: param_filter_ud);
2121
2122 for (struct llama_layer & layer : model->layers) {
2123 for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
2124 llama_set_param(tensor: reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, userdata: param_filter_ud);
2125 }
2126 }
2127}
2128
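// Runs one pass over a single context-sized chunk of tokens: the memory is
// cleared, the tokens are fed in n_batch slices, a training graph is built per
// ubatch, sparse labels are expanded into one-hot rows of the ggml_opt labels
// tensor, and ggml_opt_eval() accumulates the result.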
2129void llama_context::opt_epoch_iter(
2130 ggml_opt_dataset_t dataset,
2131 ggml_opt_result_t result,
2132 const std::vector<llama_token> & tokens,
2133 const std::vector<llama_token> & labels_sparse,
2134 llama_batch & batch,
2135 ggml_opt_epoch_callback callback,
2136 bool train,
2137 int64_t idata_in_loop,
2138 int64_t ndata_in_loop,
2139 int64_t t_loop_start) {
2140 GGML_ASSERT(opt_ctx);
2141 const uint32_t n_ctx = llama_model_n_ctx_train(model: &model);
2142 const uint32_t n_batch = std::min(a: this->n_batch(), b: n_ctx);
2143 const uint32_t n_ubatch = std::min(a: this->n_ubatch(), b: n_batch);
2144
2145 memory->clear(data: true);
2146
2147 for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
2148 batch.n_tokens = n_batch;
2149 for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
2150 batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
2151 batch.pos [pos_batch] = pos_ctx + pos_batch;
2152 batch.n_seq_id[pos_batch] = 1;
2153 batch.seq_id [pos_batch][0] = 0;
2154 batch.logits [pos_batch] = true;
2155 }
2156
2157 if (!balloc->init(batch_inp: batch, vocab: model.vocab, memory: nullptr, n_embd: model.hparams.n_embd_inp(), n_seq_max: cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all: true)) {
2158 LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
2159 return;
2160 }
2161
2162 const uint32_t n_tokens_all = balloc->get_n_tokens();
2163
2164 n_queued_tokens += n_tokens_all;
2165
2166 embd_seq.clear();
2167
2168 uint32_t n_outputs_all = n_tokens_all;
2169
2170 auto mctx = memory->init_batch(balloc&: *balloc, n_ubatch: cparams.n_ubatch, embd_all: true);
2171 if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2172 LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2173 break;
2174 }
2175
2176 // reserve output buffer
2177 if (output_reserve(n_outputs: n_outputs_all) < n_outputs_all) {
2178 LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
2179 GGML_ABORT("TODO: handle this error");
2180 }
2181
2182 uint32_t pos_batch = 0;
2183 do {
2184 const auto & ubatch = mctx->get_ubatch();
2185
2186 n_outputs = ubatch.n_tokens;
2187
2188 if (!mctx->apply()) {
2189 LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
2190 break;
2191 }
2192
2193 auto * res = gf_res_prev.get();
2194
2195 const auto gparams = graph_params(res, ubatch, mctx: mctx.get(), gtype: LLM_GRAPH_TYPE_DEFAULT);
2196
2197 res->reset();
2198
2199 auto * gf = model.build_graph(params: gparams);
2200
2201 struct ggml_context * ctx_compute_opt;
2202 {
2203 const size_t size_gf = ggml_graph_size(cgraph: gf);
2204 const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size: size_gf, /*grads = */ true);
2205 struct ggml_init_params params = {
2206 /*.mem_size =*/ size_meta,
2207 /*.mem_buffer =*/ nullptr,
2208 /*.no_alloc =*/ true,
2209 };
2210 ctx_compute_opt = ggml_init(params);
2211 }
2212 ggml_opt_prepare_alloc(opt_ctx, ctx_compute: ctx_compute_opt, gf, inputs: res->get_tokens(), outputs: res->get_logits());
2213 ggml_opt_alloc(opt_ctx, backward: train);
2214
2215 res->set_inputs(&ubatch);
2216 {
2217 struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
2218 GGML_ASSERT(labels->ne[1] == n_ubatch);
2219 ggml_set_zero(tensor: labels);
2220 const float onef = 1.0f;
2221 for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
2222 const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
2223 GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
2224 ggml_backend_tensor_set(tensor: labels, data: &onef, offset: (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), size: sizeof(float));
2225 }
2226 }
2227 ggml_opt_eval(opt_ctx, result);
2228 if (callback) {
2229 callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
2230 }
2231 ggml_free(ctx: ctx_compute_opt);
2232
2233 pos_batch += ubatch.n_tokens;
2234 } while (mctx->next());
2235 }
2236}
2237
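// Splits the dataset at idata_split: the first part is processed with
// train == true (gradients + optimizer step), the remainder with train == false
// (evaluation only), reusing opt_epoch_iter() for both phases.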
2238void llama_context::opt_epoch(
2239 ggml_opt_dataset_t dataset,
2240 ggml_opt_result_t result_train,
2241 ggml_opt_result_t result_eval,
2242 int64_t idata_split,
2243 ggml_opt_epoch_callback callback_train,
2244 ggml_opt_epoch_callback callback_eval) {
2245 const uint32_t n_ctx = this->n_ctx();
2246 const uint32_t n_batch = std::min(a: cparams.n_batch, b: n_ctx);
2247 const uint32_t n_ubatch = std::min(a: cparams.n_ubatch, b: n_batch);
2248 const int64_t ndata = ggml_opt_dataset_ndata(dataset);
2249
2250 GGML_ASSERT(idata_split >= 0);
2251 GGML_ASSERT(idata_split <= ndata);
2252
2253 const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
2254
2255 struct llama_batch batch = llama_batch_init(n_tokens: n_batch, embd: 0, n_seq_max: 1);
2256 std::vector<llama_token> tokens(n_ctx);
2257 std::vector<llama_token> labels_sparse(n_ctx);
2258
2259 int64_t idata = 0;
2260
2261 int64_t t_loop_start = ggml_time_us();
2262 int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
2263 for (; idata < idata_split; ++idata) {
2264 constexpr bool train = true;
2265 const int64_t idata_in_loop = idata*ubatch_per_ctx;
2266
2267 ggml_opt_dataset_get_batch_host(dataset, data_batch: tokens.data(), nb_data_batch: n_ctx*sizeof(llama_token), labels_batch: labels_sparse.data(), ibatch: idata);
2268 opt_epoch_iter(dataset, result: result_train, tokens, labels_sparse, batch,
2269 callback: callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
2270 }
2271
2272 t_loop_start = ggml_time_us();
2273 ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
2274 for (; idata < ndata; ++idata) {
2275 constexpr bool train = false;
2276 const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
2277
2278 ggml_opt_dataset_get_batch_host(dataset, data_batch: tokens.data(), nb_data_batch: n_ctx*sizeof(llama_token), labels_batch: labels_sparse.data(), ibatch: idata);
2279 opt_epoch_iter(dataset, result: result_eval, tokens, labels_sparse, batch,
2280 callback: callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
2281 }
2282
2283 llama_batch_free(batch);
2284}
2285
2286//
2287// interface implementation
2288//
2289
2290llama_context_params llama_context_default_params() {
2291 llama_context_params result = {
2292 /*.n_ctx =*/ 512,
2293 /*.n_batch =*/ 2048,
2294 /*.n_ubatch =*/ 512,
2295 /*.n_seq_max =*/ 1,
2296 /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
2297 /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
2298 /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
2299 /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
2300 /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
2301 /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
2302 /*.rope_freq_base =*/ 0.0f,
2303 /*.rope_freq_scale =*/ 0.0f,
2304 /*.yarn_ext_factor =*/ -1.0f,
2305 /*.yarn_attn_factor =*/ -1.0f,
2306 /*.yarn_beta_fast =*/ -1.0f,
2307 /*.yarn_beta_slow =*/ -1.0f,
2308 /*.yarn_orig_ctx =*/ 0,
2309 /*.defrag_thold =*/ -1.0f,
2310 /*.cb_eval =*/ nullptr,
2311 /*.cb_eval_user_data =*/ nullptr,
2312 /*.type_k =*/ GGML_TYPE_F16,
2313 /*.type_v =*/ GGML_TYPE_F16,
2314 /*.abort_callback =*/ nullptr,
2315 /*.abort_callback_data =*/ nullptr,
2316 /*.embeddings =*/ false,
2317 /*.offload_kqv =*/ true,
2318 /*.no_perf =*/ true,
2319 /*.op_offload =*/ true,
2320 /*.swa_full =*/ true,
2321 /*.kv_unified =*/ false,
2322 };
2323
2324 return result;
2325}
2326
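// Typical call-site sketch for overriding the defaults above before creating a
// context (the field values and lctx are illustrative only):
//
//   llama_context_params cparams = llama_context_default_params();
//   cparams.n_ctx      = 4096;
//   cparams.n_batch    = 512;
//   cparams.embeddings = false;
//   llama_context * lctx = llama_init_from_model(model, cparams);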
2327llama_context * llama_init_from_model(
2328 llama_model * model,
2329 llama_context_params params) {
2330 if (!model) {
2331 LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
2332 return nullptr;
2333 }
2334
2335 if (params.n_batch == 0 && params.n_ubatch == 0) {
2336 LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
2337 return nullptr;
2338 }
2339
2340 if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
2341 LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
2342 return nullptr;
2343 }
2344
2345 if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
2346 LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
2347 params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
2348 }
2349
2350 if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(type: params.type_k)) {
2351 const uint32_t blck_size = ggml_blck_size(type: params.type_k);
2352 if (model->hparams.n_embd_head_k % blck_size != 0) {
2353 LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
2354 __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
2355 return nullptr;
2356 }
2357 }
2358
2359 if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(type: params.type_v)) {
2360 const uint32_t blck_size = ggml_blck_size(type: params.type_v);
2361 if (model->hparams.n_embd_head_v % blck_size != 0) {
2362 LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
2363 __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
2364 return nullptr;
2365 }
2366 }
2367
2368 if (ggml_is_quantized(type: params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
2369 LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
2370 return nullptr;
2371 }
2372
2373 if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
2374 params.pooling_type != model->hparams.pooling_type) {
2375 // user-specified pooling-type is different from the model default
2376 LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
2377 model->hparams.pooling_type, params.pooling_type);
2378 }
2379
2380 try {
2381 auto * ctx = new llama_context(*model, params);
2382 return ctx;
2383 } catch (const std::exception & err) {
2384 LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what());
2385 }
2386
2387 return nullptr;
2388}
2389
2390// deprecated
2391llama_context * llama_new_context_with_model(
2392 llama_model * model,
2393 llama_context_params params) {
2394 return llama_init_from_model(model, params);
2395}
2396
2397void llama_free(llama_context * ctx) {
2398 delete ctx;
2399}
2400
2401uint32_t llama_n_ctx(const llama_context * ctx) {
2402 return ctx->n_ctx();
2403}
2404
2405uint32_t llama_n_ctx_seq(const llama_context * ctx) {
2406 return ctx->n_ctx_seq();
2407}
2408
2409uint32_t llama_n_batch(const llama_context * ctx) {
2410 return ctx->n_batch();
2411}
2412
2413uint32_t llama_n_ubatch(const llama_context * ctx) {
2414 return ctx->n_ubatch();
2415}
2416
2417uint32_t llama_n_seq_max(const llama_context * ctx) {
2418 return ctx->n_seq_max();
2419}
2420
2421const llama_model * llama_get_model(const llama_context * ctx) {
2422 return &ctx->get_model();
2423}
2424
2425enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
2426 return ctx->pooling_type();
2427}
2428
2429void llama_attach_threadpool(
2430 llama_context * ctx,
2431 ggml_threadpool_t threadpool,
2432 ggml_threadpool_t threadpool_batch) {
2433 ctx->attach_threadpool(threadpool, threadpool_batch);
2434}
2435
2436void llama_detach_threadpool(llama_context * ctx) {
2437 ctx->detach_threadpool();
2438}
2439
2440void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
2441 ctx->set_n_threads(n_threads, n_threads_batch);
2442}
2443
2444int32_t llama_n_threads(llama_context * ctx) {
2445 return ctx->n_threads();
2446}
2447
2448int32_t llama_n_threads_batch(llama_context * ctx) {
2449 return ctx->n_threads_batch();
2450}
2451
2452void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
2453 ctx->set_abort_callback(abort_callback, abort_callback_data);
2454}
2455
2456void llama_set_embeddings(llama_context * ctx, bool embeddings) {
2457 ctx->set_embeddings(embeddings);
2458}
2459
2460void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
2461 ctx->set_causal_attn(causal_attn);
2462}
2463
2464void llama_set_warmup(llama_context * ctx, bool warmup) {
2465 ctx->set_warmup(warmup);
2466}
2467
2468void llama_synchronize(llama_context * ctx) {
2469 ctx->synchronize();
2470}
2471
2472float * llama_get_logits(llama_context * ctx) {
2473 ctx->synchronize();
2474
2475 return ctx->get_logits();
2476}
2477
2478float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
2479 ctx->synchronize();
2480
2481 return ctx->get_logits_ith(i);
2482}
2483
2484float * llama_get_embeddings(llama_context * ctx) {
2485 ctx->synchronize();
2486
2487 return ctx->get_embeddings();
2488}
2489
2490float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) {
2491 ctx->synchronize();
2492
2493 return ctx->get_embeddings_ith(i);
2494}
2495
2496float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
2497 ctx->synchronize();
2498
2499 return ctx->get_embeddings_seq(seq_id);
2500}
2501
2502// llama adapter API
2503
2504int32_t llama_set_adapter_lora(
2505 llama_context * ctx,
2506 llama_adapter_lora * adapter,
2507 float scale) {
2508 ctx->set_adapter_lora(adapter, scale);
2509
2510 return 0;
2511}
2512
2513int32_t llama_rm_adapter_lora(
2514 llama_context * ctx,
2515 llama_adapter_lora * adapter) {
2516 bool res = ctx->rm_adapter_lora(adapter);
2517
2518 return res ? 0 : -1;
2519}
2520
2521void llama_clear_adapter_lora(llama_context * ctx) {
2522 ctx->clear_adapter_lora();
2523}
2524
2525int32_t llama_apply_adapter_cvec(
2526 llama_context * ctx,
2527 const float * data,
2528 size_t len,
2529 int32_t n_embd,
2530 int32_t il_start,
2531 int32_t il_end) {
2532 bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
2533
2534 return res ? 0 : -1;
2535}
2536
2537//
2538// memory
2539//
2540
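// Minimal sketch of the sequence-level memory API (lctx and the positions are
// illustrative only; p1 == -1 means "to the end of the sequence"):
//
//   llama_memory_t mem = llama_get_memory(lctx);
//   llama_memory_seq_rm (mem, /*seq_id=*/0, /*p0=*/32, /*p1=*/-1);               // drop the tail
//   llama_memory_seq_add(mem, /*seq_id=*/0, /*p0=*/0,  /*p1=*/-1, /*delta=*/-8); // shift positions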
2541llama_memory_t llama_get_memory(const struct llama_context * ctx) {
2542 return ctx->get_memory();
2543}
2544
2545void llama_memory_clear(llama_memory_t mem, bool data) {
2546 if (!mem) {
2547 return;
2548 }
2549
2550 mem->clear(data);
2551}
2552
2553bool llama_memory_seq_rm(
2554 llama_memory_t mem,
2555 llama_seq_id seq_id,
2556 llama_pos p0,
2557 llama_pos p1) {
2558 if (!mem) {
2559 return true;
2560 }
2561
2562 return mem->seq_rm(seq_id, p0, p1);
2563}
2564
2565void llama_memory_seq_cp(
2566 llama_memory_t mem,
2567 llama_seq_id seq_id_src,
2568 llama_seq_id seq_id_dst,
2569 llama_pos p0,
2570 llama_pos p1) {
2571 if (!mem) {
2572 return;
2573 }
2574
2575 mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2576}
2577
2578void llama_memory_seq_keep(
2579 llama_memory_t mem,
2580 llama_seq_id seq_id) {
2581 if (!mem) {
2582 return;
2583 }
2584
2585 mem->seq_keep(seq_id);
2586}
2587
2588void llama_memory_seq_add(
2589 llama_memory_t mem,
2590 llama_seq_id seq_id,
2591 llama_pos p0,
2592 llama_pos p1,
2593 llama_pos delta) {
2594 if (!mem) {
2595 return;
2596 }
2597
2598 mem->seq_add(seq_id, p0, p1, shift: delta);
2599}
2600
2601void llama_memory_seq_div(
2602 llama_memory_t mem,
2603 llama_seq_id seq_id,
2604 llama_pos p0,
2605 llama_pos p1,
2606 int d) {
2607 if (!mem) {
2608 return;
2609 }
2610
2611 mem->seq_div(seq_id, p0, p1, d);
2612}
2613
2614llama_pos llama_memory_seq_pos_min(
2615 llama_memory_t mem,
2616 llama_seq_id seq_id) {
2617 if (!mem) {
2618 return -1;
2619 }
2620
2621 return mem->seq_pos_min(seq_id);
2622}
2623
2624llama_pos llama_memory_seq_pos_max(
2625 llama_memory_t mem,
2626 llama_seq_id seq_id) {
2627 if (!mem) {
2628 return -1;
2629 }
2630
2631 return mem->seq_pos_max(seq_id);
2632}
2633
2634bool llama_memory_can_shift(llama_memory_t mem) {
2635 if (!mem) {
2636 return false;
2637 }
2638
2639 return mem->get_can_shift();
2640}
2641
2642// llama state API
2643
2644// deprecated
2645size_t llama_get_state_size(llama_context * ctx) {
2646 return llama_state_get_size(ctx);
2647}
2648
2649// deprecated
2650size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) {
2651 return llama_state_get_data(ctx, dst, size: -1);
2652}
2653
2654// deprecated
2655size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) {
2656 return llama_state_set_data(ctx, src, size: -1);
2657}
2658
2659// deprecated
2660bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2661 return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
2662}
2663
2664// deprecated
2665bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
2666 return llama_state_save_file(ctx, path_session, tokens, n_token_count);
2667}
2668
2669// Returns the *actual* size of the state.
2670 // Intended to be used when saving the state to a buffer.
2671size_t llama_state_get_size(llama_context * ctx) {
2672 return ctx->state_get_size();
2673}
2674
2675size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) {
2676 ctx->synchronize();
2677
2678 return ctx->state_get_data(dst, size);
2679}
2680
2681// Sets the state reading from the specified source address
2682size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) {
2683 ctx->synchronize();
2684
2685 return ctx->state_set_data(src, size);
2686}
2687
2688bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2689 ctx->synchronize();
2690
2691 try {
2692 return ctx->state_load_file(filepath: path_session, tokens_out, n_token_capacity, n_token_count_out);
2693 } catch (const std::exception & err) {
2694 LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
2695 return false;
2696 }
2697}
2698
2699bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
2700 ctx->synchronize();
2701
2702 try {
2703 return ctx->state_save_file(filepath: path_session, tokens, n_token_count);
2704 } catch (const std::exception & err) {
2705 LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
2706 return false;
2707 }
2708}
2709
2710size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
2711 return llama_state_seq_get_size_ext(ctx, seq_id, flags: 0);
2712}
2713
2714size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
2715 return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, flags: 0);
2716}
2717
2718size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
2719 return llama_state_seq_set_data_ext(ctx, src, size, dest_seq_id: seq_id, flags: 0);
2720}
2721
2722size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
2723 return ctx->state_seq_get_size(seq_id, flags);
2724}
2725
2726size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
2727 ctx->synchronize();
2728
2729 return ctx->state_seq_get_data(seq_id, dst, size, flags);
2730}
2731
2732size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
2733 ctx->synchronize();
2734
2735 return ctx->state_seq_set_data(seq_id, src, size, flags);
2736}
2737
2738size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
2739 ctx->synchronize();
2740
2741 try {
2742 return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
2743 } catch (const std::exception & err) {
2744 LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
2745 return 0;
2746 }
2747}
2748
2749size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2750 ctx->synchronize();
2751
2752 try {
2753 return ctx->state_seq_load_file(seq_id: dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
2754 } catch (const std::exception & err) {
2755 LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
2756 return 0;
2757 }
2758}
2759
2760///
2761
2762int32_t llama_encode(
2763 llama_context * ctx,
2764 llama_batch batch) {
2765 const int ret = ctx->encode(batch_inp: batch);
2766 if (ret != 0) {
2767 LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
2768 }
2769
2770 return ret;
2771}
2772
2773int32_t llama_decode(
2774 llama_context * ctx,
2775 llama_batch batch) {
2776 const int ret = ctx->decode(batch_inp: batch);
2777 if (ret != 0 && ret != 1) {
2778 LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2779 }
2780
2781 return ret;
2782}
2783
2784//
2785// perf
2786//
2787
2788llama_perf_context_data llama_perf_context(const llama_context * ctx) {
2789 llama_perf_context_data data = {};
2790
2791 if (ctx == nullptr) {
2792 return data;
2793 }
2794
2795 data = ctx->perf_get_data();
2796
2797 return data;
2798}
2799
2800void llama_perf_context_print(const llama_context * ctx) {
2801 const auto data = llama_perf_context(ctx);
2802
2803 const double t_end_ms = 1e-3 * ggml_time_us();
2804
2805 LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
2806 LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
2807 __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
2808 LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2809 __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
2810 LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
2811 LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
2812}
2813
2814void llama_perf_context_reset(llama_context * ctx) {
2815 ctx->perf_reset();
2816}
2817
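// Prints a per-device table with columns: total, free, self (= model + context
// + compute) and unaccounted (= total - free - self); host and other buffer
// types report only the self breakdown.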
2818void llama_memory_breakdown_print(const struct llama_context * ctx) {
2819 const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices;
2820
2821 std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();
2822
2823 std::vector<std::array<std::string, 9>> table_data;
2824 table_data.reserve(n: devices.size());
2825 const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n";
2826 const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n";
2827 const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n";
2828
2829 table_data.push_back(x: {template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"});
2830
2831 constexpr size_t MiB = 1024 * 1024;
2832 const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "};
2833
2834 // track seen buffer types to avoid double counting:
2835 std::set<ggml_backend_buffer_type_t> seen_buffer_types;
2836
2837 // accumulative memory breakdown for each device and for host:
2838 std::vector<llama_memory_breakdown_data> mb_dev(devices.size());
2839 llama_memory_breakdown_data mb_host;
2840
2841 for (const auto & buft_mb : memory_breakdown) {
2842 ggml_backend_buffer_type_t buft = buft_mb.first;
2843 const llama_memory_breakdown_data & mb = buft_mb.second;
2844 if (ggml_backend_buft_is_host(buft)) {
2845 mb_host.model += mb.model;
2846 mb_host.context += mb.context;
2847 mb_host.compute += mb.compute;
2848 seen_buffer_types.insert(x: buft);
2849 continue;
2850 }
2851 ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
2852 if (dev) {
2853 int i_dev = -1;
2854 for (size_t i = 0; i < devices.size(); i++) {
2855 if (devices[i] == dev) {
2856 i_dev = i;
2857 break;
2858 }
2859 }
2860 if (i_dev != -1) {
2861 mb_dev[i_dev].model += mb.model;
2862 mb_dev[i_dev].context += mb.context;
2863 mb_dev[i_dev].compute += mb.compute;
2864 seen_buffer_types.insert(x: buft);
2865 continue;
2866 }
2867 }
2868 }
2869
2870 // print memory breakdown for each device:
2871 for (size_t i = 0; i < devices.size(); i++) {
2872 ggml_backend_dev_t dev = devices[i];
2873 llama_memory_breakdown_data mb = mb_dev[i];
2874
2875 const std::string name = ggml_backend_dev_name(device: dev);
2876 std::string desc = ggml_backend_dev_description(device: dev);
2877 for (const std::string & prefix : desc_prefixes_strip) {
2878 if (desc.length() >= prefix.length() && desc.substr(pos: 0, n: prefix.length()) == prefix) {
2879 desc = desc.substr(pos: prefix.length());
2880 }
2881 }
2882
2883 size_t free, total;
2884 ggml_backend_dev_memory(device: dev, free: &free, total: &total);
2885
2886 const size_t self = mb.model + mb.context + mb.compute;
2887 const size_t unaccounted = total - self - free;
2888
2889 table_data.push_back(x: {
2890 template_gpu,
2891 " - " + name + " (" + desc + ")",
2892 std::to_string(val: total / MiB),
2893 std::to_string(val: free / MiB),
2894 std::to_string(val: self / MiB),
2895 std::to_string(val: mb.model / MiB),
2896 std::to_string(val: mb.context / MiB),
2897 std::to_string(val: mb.compute / MiB),
2898 std::to_string(val: unaccounted / MiB)});
2899 }
2900
2901 // print memory breakdown for host:
2902 {
2903 const size_t self = mb_host.model + mb_host.context + mb_host.compute;
2904 table_data.push_back(x: {
2905 template_other,
2906 " - Host",
2907 "", // total
2908 "", // free
2909 std::to_string(val: self / MiB),
2910 std::to_string(val: mb_host.model / MiB),
2911 std::to_string(val: mb_host.context / MiB),
2912 std::to_string(val: mb_host.compute / MiB),
2913 ""}); // unaccounted
2914 }
2915
2916 // print memory breakdown for all remaining buffer types:
2917 for (const auto & buft_mb : memory_breakdown) {
2918 ggml_backend_buffer_type_t buft = buft_mb.first;
2919 const llama_memory_breakdown_data & mb = buft_mb.second;
2920 if (seen_buffer_types.count(x: buft) == 1) {
2921 continue;
2922 }
2923 const std::string name = ggml_backend_buft_name(buft);
2924 const size_t self = mb.model + mb.context + mb.compute;
2925 table_data.push_back(x: {
2926 template_other,
2927 " - " + name,
2928 "", // total
2929 "", // free
2930 std::to_string(val: self / MiB),
2931 std::to_string(val: mb.model / MiB),
2932 std::to_string(val: mb.context / MiB),
2933 std::to_string(val: mb.compute / MiB),
2934 ""}); // unaccounted
2935 seen_buffer_types.insert(x: buft);
2936 }
2937
2938 for (size_t j = 1; j < table_data[0].size(); j++) {
2939 size_t max_len = 0;
2940 for (const auto & td : table_data) {
2941 max_len = std::max(a: max_len, b: td[j].length());
2942 }
2943 for (auto & td : table_data) {
2944 td[j].insert(pos: j == 1 ? td[j].length() : 0, n: max_len - td[j].length(), c: ' ');
2945 }
2946 }
2947 for (const auto & td : table_data) {
2948 LLAMA_LOG_INFO(td[0].c_str(),
2949 __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(),
2950 td[6].c_str(), td[7].c_str(), td[8].c_str());
2951 }
2952}
2953
2954//
2955// training
2956//
2957
2958bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
2959 GGML_UNUSED(tensor);
2960 GGML_UNUSED(userdata);
2961 return true;
2962}
2963
2964void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
2965 ctx->opt_init(model, lopt_params);
2966}
2967
2968void llama_opt_epoch(
2969 struct llama_context * ctx,
2970 ggml_opt_dataset_t dataset,
2971 ggml_opt_result_t result_train,
2972 ggml_opt_result_t result_eval,
2973 int64_t idata_split,
2974 ggml_opt_epoch_callback callback_train,
2975 ggml_opt_epoch_callback callback_eval) {
2976 ctx->opt_epoch(
2977 dataset,
2978 result_train,
2979 result_eval,
2980 idata_split,
2981 callback_train,
2982 callback_eval);
2983}
2984