#include "llama-kv-cache.h"

#include "llama-impl.h"
#include "llama-io.h"
#include "llama-model.h"
#include "llama-context.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>

//
// llama_kv_cache
//

llama_kv_cache::llama_kv_cache(
        const llama_model & model,
        ggml_type type_k,
        ggml_type type_v,
        bool v_trans,
        bool offload,
        bool unified,
        uint32_t kv_size,
        uint32_t n_seq_max,
        uint32_t n_pad,
        uint32_t n_swa,
        llama_swa_type swa_type,
        const layer_filter_cb & filter,
        const layer_reuse_cb & reuse) :
    model(model), hparams(model.hparams), v_trans(v_trans),
    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

    GGML_ASSERT(kv_size % n_pad == 0);

    const uint32_t n_layer_kv = hparams.n_layer_kv();

    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
    struct ggml_backend_buft_comparator {
        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
        }
    };
    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;

    // create a context for each buffer type
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            ggml_init_params params = {
                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };

            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                return nullptr;
            }

            ctx_map.emplace(buft, ctx);

            return ctx;
        }

        return it->second.get();
    };

    GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);

    v_heads.resize(n_stream);
    for (uint32_t s = 0; s < n_stream; ++s) {
        v_heads[s] = 0;
    }

    v_cells.resize(n_stream);
    for (uint32_t s = 0; s < n_stream; ++s) {
        v_cells[s].resize(kv_size);
    }

    // by default, all sequence ids are mapped to the 0th stream
    seq_to_stream.resize(LLAMA_MAX_SEQ, 0);

    if (n_stream > 1) {
        seq_to_stream.resize(n_stream, 0);
        for (uint32_t s = 0; s < n_stream; ++s) {
            seq_to_stream[s] = s;
        }
    }

    // [TAG_V_CACHE_VARIABLE]
    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
                __func__, hparams.n_embd_v_gqa_max());
    }

    for (uint32_t il = 0; il < hparams.n_layer; il++) {
        if (!hparams.has_kv(il)) {
            LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
            continue;
        }

        if (filter && !filter(il)) {
            LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
            continue;
        }

        // [TAG_V_CACHE_VARIABLE]
        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();

        const char * dev_name = "CPU";

        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();

        if (offload) {
            auto * dev = model.dev_layer(il);
            buft = ggml_backend_dev_buffer_type(dev);

            dev_name = ggml_backend_dev_name(dev);
        }

        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);

        ggml_context * ctx = ctx_for_buft(buft);
        if (!ctx) {
            throw std::runtime_error("failed to create ggml context for kv cache");
        }

        ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
        ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);

        ggml_format_name(k, "cache_k_l%d", il);
        ggml_format_name(v, "cache_v_l%d", il);

        std::vector<ggml_tensor *> k_stream;
        std::vector<ggml_tensor *> v_stream;

        for (uint32_t s = 0; s < n_stream; ++s) {
            k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
            v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
        }

        map_layer_ids[il] = layers.size();

        layers.push_back({ il, k, v, k_stream, v_stream, });
    }

    if (reuse) {
        LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);

        for (uint32_t il = 0; il < hparams.n_layer; il++) {
            const int32_t il_reuse = reuse(il);

            if (il_reuse < 0) {
                LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
                continue;
            }

            if (filter && !filter(il)) {
                LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
                continue;
            }

            GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());

            map_layer_ids[il] = map_layer_ids[il_reuse];

            LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
        }
    }

    // allocate tensors and initialize the buffers to avoid NaNs in the padding
    for (auto & [buft, ctx] : ctx_map) {
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
        if (!buf) {
            throw std::runtime_error("failed to allocate buffer for kv cache");
        }

        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);

        ggml_backend_buffer_clear(buf, 0);
        ctxs_bufs.emplace_back(std::move(ctx), buf);
    }

    {
        const size_t memory_size_k = size_k_bytes();
        const size_t memory_size_v = size_v_bytes();

        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
    }

    const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
}

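// reset the cell metadata of all streams; when data == true, also zero the backing K/V buffers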
void llama_kv_cache::clear(bool data) {
    for (uint32_t s = 0; s < n_stream; ++s) {
        v_cells[s].reset();
        v_heads[s] = 0;
    }

    if (data) {
        for (auto & [_, buf] : ctxs_bufs) {
            ggml_backend_buffer_clear(buf.get(), 0);
        }
    }
}

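// remove the cells with positions in [p0, p1) that belong to seq_id (or to any sequence when seq_id == -1)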
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    if (seq_id >= 0) {
        auto & cells = v_cells[seq_to_stream[seq_id]];
        auto & head  = v_heads[seq_to_stream[seq_id]];

        uint32_t new_head = cells.size();

        for (uint32_t i = 0; i < cells.size(); ++i) {
            if (!cells.pos_in(i, p0, p1)) {
                continue;
            }

            if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
                if (new_head == cells.size()) {
                    new_head = i;
                }
            }
        }

        // If we freed up a slot, set head to it so searching can start there.
        if (new_head != cells.size() && new_head < head) {
            head = new_head;
        }
    } else {
        // match any sequence
        for (uint32_t s = 0; s < n_stream; ++s) {
            auto & cells = v_cells[s];
            auto & head  = v_heads[s];

            uint32_t new_head = cells.size();

            for (uint32_t i = 0; i < cells.size(); ++i) {
                if (!cells.pos_in(i, p0, p1)) {
                    continue;
                }

                cells.rm(i);

                if (new_head == cells.size()) {
                    new_head = i;
                }
            }

            // If we freed up a slot, set head to it so searching can start there.
            if (new_head != cells.size() && new_head < head) {
                head = new_head;
            }
        }
    }

    return true;
}

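// make the cells of seq_id_src in [p0, p1) also belong to seq_id_dst
// copies across different streams enqueue a full-buffer copy that is performed during the next update()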
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
    GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());

    const auto s0 = seq_to_stream[seq_id_src];
    const auto s1 = seq_to_stream[seq_id_dst];

    if (s0 == s1) {
        // since both sequences are in the same stream, no data copy is necessary
        // we just have to update the cells meta data

        auto & cells = v_cells[s0];

        if (seq_id_src == seq_id_dst) {
            return;
        }

        if (p0 < 0) {
            p0 = 0;
        }

        if (p1 < 0) {
            p1 = std::numeric_limits<llama_pos>::max();
        }

        for (uint32_t i = 0; i < cells.size(); ++i) {
            if (!cells.pos_in(i, p0, p1)) {
                continue;
            }

            if (cells.seq_has(i, seq_id_src)) {
                cells.seq_add(i, seq_id_dst);
            }
        }

        return;
    }

    // cross-stream sequence copies require copying the actual buffer data

    bool is_full = true;

    if (p0 > 0 && p0 + 1 < (int) get_size()) {
        is_full = false;
    }

    if (p1 > 0 && p1 + 1 < (int) get_size()) {
        is_full = false;
    }

    GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");

    // enqueue the copy operation - the buffer copy will be performed during the next update
    sc_info.ssrc.push_back(s0);
    sc_info.sdst.push_back(s1);

    v_cells[s1].reset();
    for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
        if (v_cells[s0].seq_has(i, seq_id_src)) {
            llama_pos pos   = v_cells[s0].pos_get(i);
            llama_pos shift = v_cells[s0].get_shift(i);

            llama_kv_cell_ext ext = v_cells[s0].ext_get(i);

            if (shift != 0) {
                pos -= shift;
                assert(pos >= 0);
            }

            v_cells[s1].pos_set(i, pos);
            v_cells[s1].seq_add(i, seq_id_dst);

            if (shift != 0) {
                v_cells[s1].pos_add(i, shift);
            }

            v_cells[s1].ext_set(i, ext);
        }
    }

    v_heads[s1] = v_heads[s0];

    //for (uint32_t s = 0; s < n_stream; ++s) {
    //    LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
    //}
}

void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());

    auto & cells = v_cells[seq_to_stream[seq_id]];
    auto & head  = v_heads[seq_to_stream[seq_id]];

    uint32_t new_head = cells.size();

    for (uint32_t i = 0; i < cells.size(); ++i) {
        if (cells.seq_keep(i, seq_id)) {
            if (new_head == cells.size()) {
                new_head = i;
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != cells.size() && new_head < head) {
        head = new_head;
    }
}

void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");

    auto & cells = v_cells[seq_to_stream[seq_id]];
    auto & head  = v_heads[seq_to_stream[seq_id]];

    if (shift == 0) {
        return;
    }

    uint32_t new_head = cells.size();

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over all cells.
    if (p0 == p1) {
        return;
    }

    for (uint32_t i = 0; i < cells.size(); ++i) {
        if (!cells.pos_in(i, p0, p1)) {
            continue;
        }

        if (cells.seq_has(i, seq_id)) {
            if (cells.pos_add(i, shift)) {
                if (new_head == cells.size()) {
                    new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    // Otherwise we just start the next search from the beginning.
    head = new_head != cells.size() ? new_head : 0;
}

void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");

    auto & cells = v_cells[seq_to_stream[seq_id]];

    if (d == 1) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) {
        return;
    }

    for (uint32_t i = 0; i < cells.size(); ++i) {
        if (!cells.pos_in(i, p0, p1)) {
            continue;
        }

        if (cells.seq_has(i, seq_id)) {
            cells.pos_div(i, d);
        }
    }
}

llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());

    const auto & cells = v_cells[seq_to_stream[seq_id]];

    return cells.seq_pos_min(seq_id);
}

llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());

    const auto & cells = v_cells[seq_to_stream[seq_id]];

    return cells.seq_pos_max(seq_id);
}

std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
    std::map<ggml_backend_buffer_type_t, size_t> ret;
    for (const auto & [_, buf] : ctxs_bufs) {
        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
    }
    return ret;
}

llama_memory_context_ptr llama_kv_cache::init_batch(
        llama_batch_allocr & balloc,
        uint32_t n_ubatch,
        bool embd_all) {
    GGML_UNUSED(embd_all);

    do {
        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        auto sinfos = prepare(ubatches);
        if (sinfos.empty()) {
            break;
        }

        return std::make_unique<llama_kv_cache_context>(
                this, std::move(sinfos), std::move(ubatches));
    } while (false);

    return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

llama_memory_context_ptr llama_kv_cache::init_full() {
    return std::make_unique<llama_kv_cache_context>(this);
}

llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
    GGML_UNUSED(optimize);

    bool do_shift = get_has_shift();

    return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
}

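// find slots for a batch of ubatches: each ubatch is applied temporarily so that the following ubatches
// see its effect, and the original cell state is restored before returning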
llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
    llama_kv_cache::slot_info_vec_t res;

    struct state_t {
        slot_info sinfo; // slot info for the ubatch

        std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch

        std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
    };

    // remember the old state of the cells so we can restore it in the end
    std::vector<state_t> states;

    bool success = true;

    for (const auto & ubatch : ubatches) {
        // only find a suitable slot for the ubatch. don't modify the cells yet
        const auto sinfo_new = find_slot(ubatch, false);
        if (sinfo_new.empty()) {
            success = false;
            break;
        }

        // remember the position that we found
        res.push_back(sinfo_new);

        // store the old state of the cells in the recovery stack
        {
            state_t state = { sinfo_new, v_heads, {} };

            for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
                auto & cells = v_cells[sinfo_new.strm[s]];

                state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
            }

            states.push_back(std::move(state));
        }

        // now emplace the ubatch
        apply_ubatch(sinfo_new, ubatch);
    }

    GGML_ASSERT(!states.empty() || !success);

    // iterate backwards and restore the cells to their original state
    for (auto it = states.rbegin(); it != states.rend(); ++it) {
        const auto & sinfo = it->sinfo;

        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
            auto & cells = v_cells[sinfo.strm[s]];
            auto & head  = v_heads[sinfo.strm[s]];

            cells.set(sinfo.idxs[s], it->v_cells[s]);
            head = it->v_heads_old[s];
        }
    }

    if (!success) {
        return {};
    }

    return res;
}

bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
    bool updated = false;

    auto * sched = lctx->get_sched();

    if (!sc_info.empty()) {
        assert(n_stream > 1 && "stream copy should never happen with a single stream");

        llama_synchronize(lctx);

        const size_t n_copy = sc_info.ssrc.size();

        for (size_t i = 0; i < n_copy; ++i) {
            const auto ssrc = sc_info.ssrc[i];
            const auto sdst = sc_info.sdst[i];

            assert(ssrc < n_stream);
            assert(sdst < n_stream);

            LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);

            assert(ssrc != sdst);

            for (uint32_t il = 0; il < layers.size(); ++il) {
                const auto & layer = layers[il];

                ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
                ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
            }
        }
    }

    if (do_shift) {
        if (!get_can_shift()) {
            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
        }

        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);

        // apply K-shift if needed
        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
            ggml_backend_sched_reset(sched);

            auto * res = lctx->get_gf_res_reserve();

            res->reset();

            auto * gf = build_graph_shift(res, lctx);
            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
                return updated;
            }

            res->set_inputs(nullptr);

            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
                return updated;
            }

            updated = true;
        }

        for (uint32_t s = 0; s < n_stream; ++s) {
            auto & cells = v_cells[s];

            cells.reset_shift();
        }
    }

    return updated;
}

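// search for a set of cells that can hold the ubatch:
// when cont == true, the cells must form a contiguous block starting at the current head
// otherwise, the cells are collected one by one and may be scattered across the cache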
llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
    if (debug > 0) {
        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
            const auto seq_id    = ubatch.seq_id_unq[s];
            const auto stream_id = seq_to_stream[seq_id];
            const auto & cells   = v_cells[stream_id];
            const uint32_t head_cur = v_heads[stream_id];

            LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
                    __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);

            if ((debug == 2 && n_swa > 0) || debug > 2) {
                std::string ss;
                for (uint32_t i = 0; i < cells.size(); ++i) {
                    if (cells.is_empty(i)) {
                        ss += '.';
                    } else {
                        assert(cells.seq_count(i) >= 1);

                        if (cells.seq_count(i) == 1) {
                            ss += std::to_string(cells.seq_get(i));
                        } else {
                            ss += 'M';
                        }
                    }
                    if (i%256 == 255) {
                        ss += " *";
                        ss += '\n';
                    }
                }
                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
            }

            if ((debug == 2 && n_swa > 0) || debug > 2) {
                std::string ss;
                for (uint32_t i = 0; i < cells.size(); ++i) {
                    std::string cur;
                    if (cells.is_empty(i)) {
                        cur = '.';
                    } else {
                        cur = std::to_string(cells.pos_get(i));
                    }
                    const int n = cur.size();
                    for (int j = 0; j < 5 - n; ++j) {
                        cur += ' ';
                    }
                    ss += cur;
                    if (i%256 == 255) {
                        ss += " *";
                    }
                    if (i%64 == 63) {
                        ss += '\n';
                    }
                }
                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
            }

            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                if (cells.seq_pos_min(s) < 0) {
                    continue;
                }

                LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
            }
        }
    }

    uint32_t n_tokens = ubatch.n_tokens;
    uint32_t n_seqs   = 1;

    if (n_stream > 1) {
        GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);

        n_seqs   = ubatch.n_seqs_unq;
        n_tokens = n_tokens / n_seqs;
    }

    slot_info res = {
        /*.s0   =*/ LLAMA_MAX_SEQ,
        /*.s1   =*/ 0,
        /*.strm =*/ { },
        /*.idxs =*/ { },
    };

    res.resize(n_seqs);

    for (uint32_t s = 0; s < n_seqs; ++s) {
        const auto seq_id = ubatch.seq_id_unq[s];

        if (n_stream > 1) {
            GGML_ASSERT(ubatch.n_seq_id[s*n_tokens]    == 1);
            GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
        }

        res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
        res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);

        res.strm[s] = seq_to_stream[seq_id];
        res.idxs[s].reserve(n_tokens);

        const auto & cells = v_cells[seq_to_stream[seq_id]];

        uint32_t head_cur = v_heads[seq_to_stream[seq_id]];

        // if we have enough unused cells before the current head ->
        //   better to start searching from the beginning of the cache, hoping to fill it
        if (head_cur > cells.get_used() + 2*n_tokens) {
            head_cur = 0;
        }

        if (n_tokens > cells.size()) {
            LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
            return { };
        }

        uint32_t n_tested = 0;

        // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
        // for non-continuous slots, we test the tokens one by one
        const uint32_t n_test = cont ? n_tokens : 1;

        while (true) {
            if (head_cur + n_test > cells.size()) {
                n_tested += cells.size() - head_cur;
                head_cur = 0;
                continue;
            }

            for (uint32_t i = 0; i < n_test; i++) {
                const auto idx = head_cur;

                head_cur++;
                n_tested++;

                //const llama_pos    pos    = ubatch.pos[i];
                //const llama_seq_id seq_id = ubatch.seq_id[i][0];

                // can we use this cell? either:
                // - the cell is empty
                // - the cell is occupied only by one sequence:
                //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
                //   - mask SWA, using current max pos for that sequence in the cache
                //     always insert in the cell with minimum pos
                bool can_use = cells.is_empty(idx);

                if (!can_use && cells.seq_count(idx) == 1) {
                    const llama_pos pos_cell = cells.pos_get(idx);

                    // (disabled) causal mask
                    // note: it's better to purge any "future" tokens beforehand
                    //if (cells.seq_has(idx, seq_id)) {
                    //    can_use = pos_cell >= pos;
                    //}

                    if (!can_use) {
                        const llama_seq_id seq_id_cell = cells.seq_get(idx);

                        // SWA mask
                        if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                            can_use = true;
                        }
                    }
                }

                if (can_use) {
                    res.idxs[s].push_back(idx);
                } else {
                    if (cont) {
                        break;
                    }
                }
            }

            if (res.idxs[s].size() == n_tokens) {
                break;
            }

            if (cont) {
                res.idxs[s].clear();
            }

            if (n_tested >= cells.size()) {
                //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
                return { };
            }
        }

        // we didn't find a suitable slot - return empty result
        if (res.idxs[s].size() < n_tokens) {
            return { };
        }
    }

    assert(res.s1 >= res.s0);

    return res;
}

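// write the ubatch metadata (positions, 2D position extensions, sequence ids) into the cells selected by sinfo
// and advance the stream heads past the slot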
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
    // keep track of the max sequence position that we would overwrite with this ubatch
    // for a non-SWA cache, this always remains empty
    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
        seq_pos_max_rm[s] = -1;
    }

    assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());

    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
            const uint32_t i = s*sinfo.size() + ii;

            auto & cells = v_cells[sinfo.strm[s]];

            const auto idx = sinfo.idxs[s][ii];

            if (!cells.is_empty(idx)) {
                assert(cells.seq_count(idx) == 1);

                const llama_seq_id seq_id = cells.seq_get(idx);
                const llama_pos    pos    = cells.pos_get(idx);

                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);

                cells.rm(idx);
            }

            cells.pos_set(idx, ubatch.pos[i]);

            if (ubatch.is_pos_2d()) {
                llama_kv_cell_ext ext {
                    /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
                    /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
                };
                cells.ext_set(idx, ext);
            }

            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
                cells.seq_add(idx, ubatch.seq_id[i][s]);
            }
        }
    }

    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
    // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
        if (seq_pos_max_rm[s] == -1) {
            continue;
        }

        GGML_ASSERT(s < seq_to_stream.size());

        auto & cells = v_cells[seq_to_stream[s]];

        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);

            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
        }
    }

    // move the head at the end of the slot
    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        auto & head = v_heads[sinfo.strm[s]];

        head = sinfo.idxs[s].back() + 1;
    }
}

bool llama_kv_cache::get_can_shift() const {
    return true;
}

uint32_t llama_kv_cache::get_size() const {
    const auto & cells = v_cells[seq_to_stream[0]];

    return cells.size();
}

uint32_t llama_kv_cache::get_n_stream() const {
    return n_stream;
}

bool llama_kv_cache::get_has_shift() const {
    bool result = false;

    for (uint32_t s = 0; s < n_stream; ++s) {
        result |= v_cells[s].get_has_shift();
    }

    return result;
}

uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
    uint32_t result = 0;

    // pad the n_kv value so that the graph remains constant across batches and can be reused
    // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
    const uint32_t n_pad_cur = std::max(n_pad, 256u);

    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        const auto & cells = v_cells[sinfo.strm[s]];

        result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
    }

    return result;
}

ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * k = layers[ikv].k;

    const uint64_t kv_size      = get_size();
    const uint64_t n_embd_k_gqa = k->ne[0];

    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));

    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;

    return ggml_view_4d(ctx, k,
            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
            ggml_row_size(k->type, hparams.n_embd_head_k),
            ggml_row_size(k->type, n_embd_k_gqa),
            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
}

ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

    const uint64_t kv_size      = get_size();
    const uint64_t n_embd_v_gqa = v->ne[0];

    // [TAG_V_CACHE_VARIABLE]
    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));

    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;

    if (!v_trans) {
        // note: v->nb[1] <= v->nb[2]
        return ggml_view_4d(ctx, v,
                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
                ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
                ggml_row_size(v->type, n_embd_v_gqa),          // v->nb[2]
                ggml_row_size(v->type, n_embd_v_gqa*kv_size),  // v->nb[3]
                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
    }

    // note: v->nb[1] > v->nb[2]
    return ggml_view_4d(ctx, v,
            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
            ggml_row_size(v->type, kv_size),                       // v->nb[2]
            ggml_row_size(v->type, kv_size*n_embd_v_gqa),          // v->nb[3]
            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
}

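// build the graph operation that scatters the K values of the current ubatch into the cache rows selected by k_idxs
// (see build_input_k_idxs() / set_input_k_idxs())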
ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
    GGML_UNUSED(sinfo);

    const int32_t ikv = map_layer_ids.at(il);

    ggml_tensor * k = layers[ikv].k;

    const int64_t n_embd_head = k_cur->ne[0];
    const int64_t n_head      = k_cur->ne[1];
    const int64_t n_tokens    = k_cur->ne[2];

    const int64_t n_embd_gqa = n_embd_head*n_head;

    // we can merge dims 0 and 1
    // TODO: add ggml helper function for this?
    GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);

    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);

    const int64_t n_stream = k->ne[2];

    if (n_stream > 1) {
        const int64_t kv_size = get_size();

        assert(n_embd_gqa == k->ne[0]);
        assert(kv_size    == k->ne[1]);

        // merge the buffer across all streams because the idxs are global
        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
    }

    // store the current K values into the cache
    return ggml_set_rows(ctx, k, k_cur, k_idxs);
}

ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
    GGML_UNUSED(sinfo);

    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

    const int64_t n_embd_head = v_cur->ne[0];
    const int64_t n_head      = v_cur->ne[1];
    const int64_t n_tokens    = v_cur->ne[2];

    const int64_t n_embd_gqa = n_embd_head*n_head;

    // we can merge dims 0 and 1
    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);

    const int64_t n_stream = v->ne[2];

    // take this branch when FA is enabled (the V cache is not transposed)
    if (!v_trans) {
        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);

        if (n_stream > 1) {
            const int64_t kv_size = get_size();

            assert(n_embd_gqa == v->ne[0]);
            assert(kv_size    == v->ne[1]);

            // merge the buffer across all streams because the idxs are global
            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
        }

        return ggml_set_rows(ctx, v, v_cur, v_idxs);
    }

    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
        // we can merge dims 0, 1 and 2
        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
    } else {
        // otherwise -> make a copy to get contiguous data
        v_cur = ggml_cont_2d(ctx, v_cur, n_embd_gqa, n_tokens);
    }

    // [TAG_V_CACHE_VARIABLE]
    if (n_embd_gqa < v->ne[0]) {
        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
    }

    // in this branch the v_idxs are constructed in such a way that each row is a single head element
    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));

    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));

    return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}

ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    const uint32_t n_tokens = ubatch.n_tokens;

    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    ggml_set_input(k_idxs);

    return k_idxs;
}

ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    const uint32_t n_tokens = ubatch.n_tokens;

    ggml_tensor * v_idxs;

    if (!v_trans) {
        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    } else {
        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
    }

    ggml_set_input(v_idxs);

    return v_idxs;
}

void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
    const uint32_t n_tokens = ubatch->n_tokens;
    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    int64_t * data = (int64_t *) dst->data;

    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        const int64_t offs = sinfo.strm[s]*get_size();

        for (uint32_t i = 0; i < sinfo.size(); ++i) {
            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
        }
    }
}

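// fill the destination indices for storing the V values of the current ubatch
// for the transposed V cache, each (token, embedding) element gets its own index into the flattened buffer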
1176void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
1177 const uint32_t n_tokens = ubatch->n_tokens;
1178 GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
1179
1180 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1181 int64_t * data = (int64_t *) dst->data;
1182
1183 if (!v_trans) {
1184 for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1185 const int64_t offs = sinfo.strm[s]*get_size();
1186
1187 for (uint32_t i = 0; i < sinfo.size(); ++i) {
1188 data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
1189 }
1190 }
1191 } else {
1192 // note: the V cache is transposed when not using flash attention
1193 const int64_t kv_size = get_size();
1194
1195 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
1196
1197 for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1198 const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
1199
1200 for (uint32_t i = 0; i < sinfo.size(); ++i) {
1201 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1202 data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
1203 }
1204 }
1205 }
1206 }
1207}
1208
1209void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
1210 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1211
1212 int32_t * data = (int32_t *) dst->data;
1213
1214 for (uint32_t s = 0; s < n_stream; ++s) {
1215 const auto & cells = v_cells[s];
1216
1217 for (uint32_t i = 0; i < cells.size(); ++i) {
1218 data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
1219 }
1220 }
1221}
1222
1223void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1224 const uint32_t n_tokens = ubatch->n_tokens;
1225
1226 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1227 float * data = (float *) dst->data;
1228
1229 const int64_t n_kv = dst->ne[0];
1230 const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
1231
1232 GGML_ASSERT(n_tokens%n_stream == 0);
1233
1234 // n_tps == n_tokens_per_stream
1235 const int64_t n_tps = n_tokens/n_stream;
1236 const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
1237
1238 std::fill(first: data, last: data + ggml_nelements(tensor: dst), value: -INFINITY);
1239
1240 // Use only the previous KV cells of the correct sequence for each token of the ubatch.
1241 // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
1242 // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
1243 // Causal mask:
1244 // xxx-------
1245 // xxxx------
1246 // xxxxx-----
1247 // Non-causal mask:
1248 // xxxxx-----
1249 // xxxxx-----
1250 // xxxxx-----
1251 // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
1252 // TODO: optimize this section
1253 for (uint32_t h = 0; h < 1; ++h) {
1254 for (uint32_t s = 0; s < n_stream; ++s) {
1255 for (uint32_t ii = 0; ii < n_tps; ++ii) {
1256 const uint32_t i = s*n_tps + ii;
1257
1258 const llama_seq_id seq_id = ubatch->seq_id[i][0];
1259
1260 const auto & cells = v_cells[seq_to_stream[seq_id]];
1261
1262 const llama_pos p1 = ubatch->pos[i];
1263
1264 // for M-RoPE
1265 const bool is_2d = ubatch->is_pos_2d();
1266 const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
1267 const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
1268
1269 const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
1270
1271 for (uint32_t j = 0; j < n_kv; ++j) {
1272 if (cells.is_empty(i: j)) {
1273 continue;
1274 }
1275
1276 // mask the token if not the same sequence
1277 if (!cells.seq_has(i: j, seq_id)) {
1278 continue;
1279 }
1280
1281 const llama_pos p0 = cells.pos_get(i: j);
1282
1283 // mask future tokens
1284 if (causal_attn && p0 > p1) {
1285 continue;
1286 }
1287
1288 // M-RoPE causal mask
1289 if (causal_attn && is_2d && p0 == p1) {
1290 const auto & p0_ext = cells.ext_get(i: j);
1291 if (p0_ext.is_2d_gt(ox: p1_x, oy: p1_y)) {
1292 continue;
1293 }
1294 }
1295
1296 // apply SWA if any
1297 if (is_masked_swa(p0, p1)) {
1298 continue;
1299 }
1300
1301 data[idst + j] = hparams.use_alibi ? -std::abs(x: p0 - p1) : 0.0f;
1302 }
1303 }
1304 }
1305 }
1306}
1307
1308void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1309 const int64_t n_tokens = ubatch->n_tokens;
1310
1311 GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
1312 const auto & cells = v_cells[0];
1313
1314 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1315 GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
1316
1317 int32_t * data = (int32_t *) dst->data;
1318
1319 const int32_t n_kv = dst->ne[0];
1320
1321 for (int h = 0; h < 1; ++h) {
1322 for (int i = 0; i < n_tokens; ++i) {
1323 for (int j = 0; j < n_kv; ++j) {
1324 // the position when the cells is empty is irrelevant - it will be masked out later in the attention
1325 const llama_pos p0 = cells.is_empty(i: j) ? -1 : cells.pos_get(i: j);
1326
1327 data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(x: p0, y: ubatch->pos[i], n_buckets: hparams.n_rel_attn_bkts, bidirectional: false);
1328 }
1329 }
1330 }
1331}
1332
1333size_t llama_kv_cache::total_size() const {
1334 size_t size = 0;
1335
1336 for (const auto & [_, buf] : ctxs_bufs) {
1337 size += ggml_backend_buffer_get_size(buffer: buf.get());
1338 }
1339
1340 return size;
1341}
1342
1343size_t llama_kv_cache::size_k_bytes() const {
1344 size_t size_k_bytes = 0;
1345
1346 for (const auto & layer : layers) {
1347 size_k_bytes += ggml_nbytes(tensor: layer.k);
1348 }
1349
1350 return size_k_bytes;
1351}
1352
1353size_t llama_kv_cache::size_v_bytes() const {
1354 size_t size_v_bytes = 0;
1355
1356 for (const auto & layer : layers) {
1357 size_v_bytes += ggml_nbytes(tensor: layer.v);
1358 }
1359
1360 return size_v_bytes;
1361}
1362
1363ggml_tensor * llama_kv_cache::build_rope_shift(
1364 const llama_cparams & cparams,
1365 ggml_context * ctx,
1366 ggml_tensor * cur,
1367 ggml_tensor * shift,
1368 ggml_tensor * factors,
1369 float freq_base,
1370 float freq_scale) const {
1371 const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
1372
1373 const auto & yarn_ext_factor = cparams.yarn_ext_factor;
1374 const auto & yarn_beta_fast = cparams.yarn_beta_fast;
1375 const auto & yarn_beta_slow = cparams.yarn_beta_slow;
1376
1377 const auto & n_rot = hparams.n_rot;
1378 const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
1379 // @ngxson : this is a workaround
1380 // for M-RoPE, we want to rotate the whole vector when doing KV shift
1381 // a normal RoPE should work, we just need to use the correct ordering
1382 // ref: https://github.com/ggml-org/llama.cpp/pull/13870
1383 ? LLAMA_ROPE_TYPE_NEOX
1384 : hparams.rope_type;
1385
1386 // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
1387 // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
1388 const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
1389 ? 1.0f / (1.0f + 0.1f * logf(x: 1.0f / freq_scale))
1390 : cparams.yarn_attn_factor;
1391
1392 ggml_tensor * tmp;
1393
1394 if (ggml_is_quantized(type: cur->type)) {
1395 // dequantize to f32 -> RoPE -> quantize back
1396 tmp = ggml_cast(ctx, a: cur, type: GGML_TYPE_F32);
1397
1398 tmp = ggml_rope_ext(ctx, a: tmp,
1399 b: shift, c: factors, n_dims: n_rot, mode: rope_type, n_ctx_orig, freq_base, freq_scale,
1400 ext_factor: yarn_ext_factor, attn_factor: yarn_attn_factor, beta_fast: yarn_beta_fast, beta_slow: yarn_beta_slow);
1401
1402 tmp = ggml_cpy(ctx, a: tmp, b: cur);
1403 } else {
1404 // we rotate only the first n_rot dimensions
1405 tmp = ggml_rope_ext_inplace(ctx, a: cur,
1406 b: shift, c: factors, n_dims: n_rot, mode: rope_type, n_ctx_orig, freq_base, freq_scale,
1407 ext_factor: yarn_ext_factor, attn_factor: yarn_attn_factor, beta_fast: yarn_beta_fast, beta_slow: yarn_beta_slow);
1408 }
1409
1410 return tmp;
1411}
1412
1413class llm_graph_input_k_shift : public llm_graph_input_i {
1414public:
1415 llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
1416 virtual ~llm_graph_input_k_shift() = default;
1417
1418 void set_input(const llama_ubatch * ubatch) override;
1419
1420 ggml_tensor * k_shift; // I32 [kv_size*n_stream]
1421
1422 const llama_kv_cache * kv_self;
1423};
1424
1425void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
1426 GGML_UNUSED(ubatch);
1427
1428 if (k_shift) {
1429 kv_self->set_input_k_shift(k_shift);
1430 }
1431}
1432
1433ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
1434 auto * ctx = res->get_ctx();
1435 auto * gf = res->get_gf();
1436
1437 const auto & n_embd_head_k = hparams.n_embd_head_k;
1438 //const auto & n_embd_head_v = hparams.n_embd_head_v;
1439
1440 auto inp = std::make_unique<llm_graph_input_k_shift>(args: this);
1441
1442 inp->k_shift = ggml_new_tensor_1d(ctx, type: GGML_TYPE_I32, ne0: (int64_t) get_size()*n_stream);
1443 ggml_set_input(tensor: inp->k_shift);
1444
1445 const auto & cparams = lctx->get_cparams();
1446
1447 for (const auto & layer : layers) {
1448 const uint32_t il = layer.il;
1449
1450 const int64_t n_head_kv = hparams.n_head_kv(il);
1451 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1452
1453 const float freq_base_l = model.get_rope_freq_base (cparams, il);
1454 const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
1455
1456 ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
1457
1458 ggml_tensor * k =
1459 ggml_view_3d(ctx, a: layer.k,
1460 ne0: n_embd_head_k, ne1: n_head_kv, ne2: get_size()*n_stream,
1461 nb1: ggml_row_size(type: layer.k->type, ne: n_embd_head_k),
1462 nb2: ggml_row_size(type: layer.k->type, ne: n_embd_k_gqa),
1463 offset: 0);
1464
1465 ggml_tensor * cur = build_rope_shift(cparams, ctx, cur: k, shift: inp->k_shift, factors: rope_factors, freq_base: freq_base_l, freq_scale: freq_scale_l);
1466
1467 ggml_build_forward_expand(cgraph: gf, tensor: cur);
1468 }
1469
1470 res->add_input(input: std::move(inp));
1471
1472 return gf;
1473}
1474
1475bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
1476 return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
1477}
1478
1479void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
1480 GGML_UNUSED(flags);
1481
1482 io.write(src: &n_stream, size: sizeof(n_stream));
1483
1484 for (uint32_t s = 0; s < n_stream; ++s) {
1485 cell_ranges_t cr { .strm: s, .data: {} };
1486
1487 uint32_t cell_count = 0;
1488
1489 const auto & cells = v_cells[s];
1490
1491 // Count the number of cells with the specified seq_id
1492 // Find all the ranges of cells with this seq id (or all, when -1)
1493 uint32_t cell_range_begin = cells.size();
1494
1495 for (uint32_t i = 0; i < cells.size(); ++i) {
1496 if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1497 ++cell_count;
1498 if (cell_range_begin == cells.size()) {
1499 cell_range_begin = i;
1500 }
1501 } else {
1502 if (cell_range_begin != cells.size()) {
1503 cr.data.emplace_back(args&: cell_range_begin, args&: i);
1504 cell_range_begin = cells.size();
1505 }
1506 }
1507 }
1508
1509 if (cell_range_begin != cells.size()) {
1510 cr.data.emplace_back(args&: cell_range_begin, args: cells.size());
1511 }
1512
1513 // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1514 uint32_t cell_count_check = 0;
1515 for (const auto & range : cr.data) {
1516 cell_count_check += range.second - range.first;
1517 }
1518 GGML_ASSERT(cell_count == cell_count_check);
1519
1520 io.write(src: &cell_count, size: sizeof(cell_count));
1521
1522 // skip empty streams
1523 if (cell_count == 0) {
1524 continue;
1525 }
1526
1527 state_write_meta(io, cr, seq_id);
1528 state_write_data(io, cr);
1529 }
1530}
1531
1532void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
1533 GGML_UNUSED(flags);
1534
1535 GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
1536
1537 uint32_t n_stream_cur;
1538 io.read_to(dst: &n_stream_cur, size: sizeof(n_stream_cur));
1539 if (n_stream_cur != n_stream) {
1540 throw std::runtime_error("n_stream mismatch");
1541 }
1542
1543 for (uint32_t s = 0; s < n_stream; ++s) {
1544 uint32_t cell_count;
1545 io.read_to(dst: &cell_count, size: sizeof(cell_count));
1546
1547 if (cell_count == 0) {
1548 continue;
1549 }
1550
1551 const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
1552
1553 bool res = true;
1554 res = res && state_read_meta(io, strm, cell_count, dest_seq_id: seq_id);
1555 res = res && state_read_data(io, strm, cell_count);
1556
1557 if (!res) {
1558 if (seq_id == -1) {
1559 clear(data: true);
1560 } else {
1561 seq_rm(seq_id, p0: -1, p1: -1);
1562 }
1563 throw std::runtime_error("failed to restore kv cache");
1564 }
1565 }
1566}
1567
1568void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
1569 const auto & cells = v_cells[cr.strm];
1570
1571 for (const auto & range : cr.data) {
1572 for (uint32_t i = range.first; i < range.second; ++i) {
1573 std::vector<llama_seq_id> seq_ids;
1574
1575 for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
1576 if (cur == seq_id || seq_id == -1) {
1577 if (cells.seq_has(i, seq_id: cur)) {
1578 seq_ids.push_back(x: cur);
1579 }
1580 }
1581 }
1582
1583 const llama_pos pos = cells.pos_get(i);
1584 const uint32_t n_seq_id = seq_ids.size();
1585
1586 io.write(src: &pos, size: sizeof(pos));
1587 io.write(src: &n_seq_id, size: sizeof(n_seq_id));
1588
1589 // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
1590 // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
1591
1592 for (const auto & seq_id : seq_ids) {
1593 io.write(src: &seq_id, size: sizeof(seq_id));
1594 }
1595 }
1596 }
1597}
1598
1599void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
1600 const auto & cells = v_cells[cr.strm];
1601
1602 const uint32_t v_trans = this->v_trans ? 1 : 0;
1603 const uint32_t n_layer = layers.size();
1604
1605 io.write(src: &v_trans, size: sizeof(v_trans));
1606 io.write(src: &n_layer, size: sizeof(n_layer));
1607
1608 std::vector<uint8_t> tmp_buf;
1609
1610 // Iterate and write all the keys first, each row is a cell
1611 // Get whole range at a time
1612 for (const auto & layer : layers) {
1613 const uint32_t il = layer.il;
1614
1615 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1616
1617 auto * k = layer.k_stream[cr.strm];
1618
1619 // Write key type
1620 const int32_t k_type_i = (int32_t) k->type;
1621 io.write(src: &k_type_i, size: sizeof(k_type_i));
1622
1623 // Write row size of key
1624 const uint64_t k_size_row = ggml_row_size(type: k->type, ne: n_embd_k_gqa);
1625 io.write(src: &k_size_row, size: sizeof(k_size_row));
1626
1627 // Read each range of cells of k_size length each into tmp_buf and write out
1628 for (const auto & range : cr.data) {
1629 const size_t range_size = range.second - range.first;
1630 const size_t buf_size = range_size * k_size_row;
1631 io.write_tensor(tensor: k, offset: range.first * k_size_row, size: buf_size);
1632 }
1633 }
1634
1635 if (!v_trans) {
1636 for (const auto & layer : layers) {
1637 const uint32_t il = layer.il;
1638
1639 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1640
1641 auto * v = layer.v_stream[cr.strm];
1642
1643 // Write value type
1644 const int32_t v_type_i = (int32_t) v->type;
1645 io.write(src: &v_type_i, size: sizeof(v_type_i));
1646
1647 // Write row size of value
1648 const uint64_t v_size_row = ggml_row_size(type: v->type, ne: n_embd_v_gqa);
1649 io.write(src: &v_size_row, size: sizeof(v_size_row));
1650
1651 // Read each range of cells of v_size length each into tmp_buf and write out
1652 for (const auto & range : cr.data) {
1653 const size_t range_size = range.second - range.first;
1654 const size_t buf_size = range_size * v_size_row;
1655 io.write_tensor(tensor: v, offset: range.first * v_size_row, size: buf_size);
1656 }
1657 }
1658 } else {
1659 // When v is transposed, we also need the element size and get the element ranges from each row
1660 const uint32_t kv_size = cells.size();
1661
1662 for (const auto & layer : layers) {
1663 const uint32_t il = layer.il;
1664
1665 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1666
1667 auto * v = layer.v_stream[cr.strm];
1668
1669 // Write value type
1670 const int32_t v_type_i = (int32_t) v->type;
1671 io.write(src: &v_type_i, size: sizeof(v_type_i));
1672
1673 // Write element size
1674 const uint32_t v_size_el = ggml_type_size(type: v->type);
1675 io.write(src: &v_size_el, size: sizeof(v_size_el));
1676
1677 // Write GQA embedding size
1678 io.write(src: &n_embd_v_gqa, size: sizeof(n_embd_v_gqa));
1679
1680 // For each row, we get the element values of each cell
1681 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1682 // Read each range of cells of v_size_el length each into tmp_buf and write out
1683 for (const auto & range : cr.data) {
1684 const size_t range_size = range.second - range.first;
1685 const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1686 const size_t buf_size = range_size * v_size_el;
1687 io.write_tensor(tensor: v, offset: src_offset, size: buf_size);
1688 }
1689 }
1690 }
1691 }
1692}
1693
1694bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
1695 auto & cells = v_cells[strm];
1696 auto & head = v_heads[strm];
1697
1698 if (dest_seq_id != -1) {
1699 // single sequence
1700 seq_rm(seq_id: dest_seq_id, p0: -1, p1: -1);
1701
1702 llama_batch_allocr balloc(hparams.n_pos_per_embd());
1703
1704 llama_ubatch ubatch = balloc.ubatch_reserve(n_seq_tokens: cell_count, n_seqs: 1);
1705
1706 ubatch.seq_id_unq[0] = dest_seq_id;
1707
1708 for (uint32_t i = 0; i < cell_count; ++i) {
1709 llama_pos pos;
1710 uint32_t n_seq_id;
1711
1712 io.read_to(dst: &pos, size: sizeof(pos));
1713 io.read_to(dst: &n_seq_id, size: sizeof(n_seq_id));
1714
1715 if (n_seq_id != 1) {
1716 LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1717 return false;
1718 }
1719
1720 // read the sequence id, but directly discard it - we will use dest_seq_id instead
1721 {
1722 llama_seq_id seq_id;
1723 io.read_to(dst: &seq_id, size: sizeof(seq_id));
1724 }
1725
1726 ubatch.pos[i] = pos;
1727 ubatch.n_seq_id[i] = n_seq_id;
1728 ubatch.seq_id[i] = &dest_seq_id;
1729 }
1730
1731 const auto sinfo = find_slot(ubatch, cont: true);
1732 if (sinfo.empty()) {
1733 LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1734 return false;
1735 }
1736
1737 // TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
1738 // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
1739 apply_ubatch(sinfo, ubatch);
1740
1741 const auto head_cur = sinfo.head();
1742
1743 // keep the head at the old position because we will read the KV data into it in state_read_data()
1744 head = head_cur;
1745
1746 LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
1747
1748 // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
1749 // Assume that this is one contiguous block of cells
1750 GGML_ASSERT(head_cur + cell_count <= cells.size());
1751 GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
1752 GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
1753 GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
1754 GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
1755 } else {
1756 // whole KV cache restore
1757
1758 if (cell_count > cells.size()) {
1759 LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1760 return false;
1761 }
1762
1763 clear(data: true);
1764
1765 for (uint32_t i = 0; i < cell_count; ++i) {
1766 llama_pos pos;
1767 uint32_t n_seq_id;
1768
1769 io.read_to(dst: &pos, size: sizeof(pos));
1770 io.read_to(dst: &n_seq_id, size: sizeof(n_seq_id));
1771
1772 cells.pos_set(i, p: pos);
1773
1774 for (uint32_t j = 0; j < n_seq_id; ++j) {
1775 llama_seq_id seq_id;
1776 io.read_to(dst: &seq_id, size: sizeof(seq_id));
1777
1778 if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
1779 LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
1780 return false;
1781 }
1782
1783 cells.seq_add(i, seq_id);
1784 }
1785 }
1786
1787 head = 0;
1788 }
1789
1790 return true;
1791}

bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
    auto & cells = v_cells[strm];
    auto & head  = v_heads[strm];

    uint32_t v_trans;
    uint32_t n_layer;

    io.read_to(&v_trans, sizeof(v_trans));
    io.read_to(&n_layer, sizeof(n_layer));

    if (n_layer != layers.size()) {
        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
        return false;
    }

    if (cell_count > cells.size()) {
        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
        return false;
    }

    if (this->v_trans != (bool) v_trans) {
        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
        return false;
    }

    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
    for (const auto & layer : layers) {
        const uint32_t il = layer.il;

        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

        auto * k = layer.k_stream[strm];

        // Read type of key
        int32_t k_type_i_ref;
        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
        const int32_t k_type_i = (int32_t) k->type;
        if (k_type_i != k_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
            return false;
        }

        // Read row size of key
        uint64_t k_size_row_ref;
        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
        const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
        if (k_size_row != k_size_row_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
            return false;
        }

        if (cell_count) {
            // Read and set the keys for the whole cell range
            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
        }
    }

    if (!this->v_trans) {
        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

            auto * v = layer.v_stream[strm];

            // Read type of value
            int32_t v_type_i_ref;
            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
            const int32_t v_type_i = (int32_t) v->type;
            if (v_type_i != v_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                return false;
            }

            // Read row size of value
            uint64_t v_size_row_ref;
            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
            const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
            if (v_size_row != v_size_row_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                return false;
            }

            if (cell_count) {
                // Read and set the values for the whole cell range
                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
            }
        }
    } else {
        // For each layer, read the values for each cell (transposed)
        for (const auto & layer : layers) {
            const uint32_t il = layer.il;

            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

            auto * v = layer.v_stream[strm];

            // Read type of value
            int32_t v_type_i_ref;
            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
            const int32_t v_type_i = (int32_t) v->type;
            if (v_type_i != v_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                return false;
            }

            // Read element size of value
            uint32_t v_size_el_ref;
            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
            const size_t v_size_el = ggml_type_size(v->type);
            if (v_size_el != v_size_el_ref) {
                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
                return false;
            }

            // Read GQA embedding size
            uint32_t n_embd_v_gqa_ref;
            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
                return false;
            }

            if (cell_count) {
                // For each row in the transposed matrix, read the values for the whole cell range
                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                }
            }
        }
    }

    return true;
}
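
// Worked example (illustrative, not part of the implementation): in the
// transposed-V path above, the destination byte offset of row j is
// (head + j*kv_size)*v_size_el, where kv_size == cells.size(). E.g. with
// kv_size = 4096, head = 128, v_size_el = 2 (F16) and cell_count = 16,
// row j = 3 is written at byte offset (128 + 3*4096)*2 = 24832 and spans
// cell_count*v_size_el = 32 bytes.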

//
// llama_kv_cache_context
//

// context carrying only a status (used to report failures / no-ops)
llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}

// full-cache context - the slot info below is only needed to build the graph
llama_kv_cache_context::llama_kv_cache_context(
        llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
    n_kv = kv->get_size();

    const uint32_t n_stream = kv->get_n_stream();

    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
    sinfos.resize(1);
    sinfos[0].s0 = 0;
    sinfos[0].s1 = n_stream - 1;
    sinfos[0].idxs.resize(n_stream);
    for (uint32_t s = 0; s < n_stream; ++s) {
        sinfos[0].strm.push_back(s);
        sinfos[0].idxs[s].resize(1, 0);
    }
}
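
// For example, with n_stream == 2 the dummy slot info constructed above is:
//
//     sinfos[0].s0   = 0
//     sinfos[0].s1   = 1
//     sinfos[0].strm = { 0, 1 }
//     sinfos[0].idxs = { { 0 }, { 0 } }
//
// i.e. every stream contributes a single cell index 0 - enough to build the
// graph, since the actual cell data is irrelevant here.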

// context used to apply a cache update (K-shift and/or stream copies) via update()
llama_kv_cache_context::llama_kv_cache_context(
        llama_kv_cache * kv,
        llama_context * lctx,
        bool do_shift,
        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
    if (!do_shift && this->sc_info.empty()) {
        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
    }
}

// context used to process a batch: one slot info per prepared ubatch
llama_kv_cache_context::llama_kv_cache_context(
        llama_kv_cache * kv,
        llama_kv_cache::slot_info_vec_t sinfos,
        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
}

llama_kv_cache_context::~llama_kv_cache_context() = default;

bool llama_kv_cache_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    if (++i_cur >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_kv_cache_context::apply() {
    assert(!llama_memory_status_is_fail(status));

    // no ubatches -> this is a KV cache update
    if (ubatches.empty()) {
        kv->update(lctx, do_shift, sc_info);

        return true;
    }

    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
    n_kv = kv->get_n_kv(sinfos[i_cur]);

    return true;
}
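
// Rough usage sketch (illustrative only; `mctx` is a hypothetical caller-side
// pointer to a successfully initialized llama_kv_cache_context, error handling
// omitted):
//
//     if (mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
//         do {
//             mctx->apply();                               // write the current ubatch into the cache
//             const llama_ubatch & ub = mctx->get_ubatch();
//             // ... build and evaluate the graph for `ub` ...
//         } while (mctx->next());                          // advance to the next ubatch, if any
//     }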

llama_memory_status llama_kv_cache_context::get_status() const {
    return status;
}

const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_cur];
}

uint32_t llama_kv_cache_context::get_n_kv() const {
    return n_kv;
}

ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
    return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
    return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    return kv->build_input_k_idxs(ctx, ubatch);
}

ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    return kv->build_input_v_idxs(ctx, ubatch);
}

void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
    kv->set_input_k_shift(dst);
}

void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
}

void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
}

void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    kv->set_input_kq_mask(dst, ubatch, causal_attn);
}

void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_pos_bucket(dst, ubatch);
}
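
// The accessors above simply forward the slot info of the current ubatch
// (sinfos[i_cur]) to the underlying llama_kv_cache. A graph builder would use
// them roughly as follows (illustrative sketch; `mctx`, `gf`, `k_cur`, `v_cur`
// and `il` are hypothetical caller-side names):
//
//     ggml_tensor * k_idxs = mctx->build_input_k_idxs(ctx, ubatch);
//     ggml_tensor * v_idxs = mctx->build_input_v_idxs(ctx, ubatch);
//
//     // store the new K/V data of the current ubatch into the cache
//     ggml_build_forward_expand(gf, mctx->cpy_k(ctx, k_cur, k_idxs, il));
//     ggml_build_forward_expand(gf, mctx->cpy_v(ctx, v_cur, v_idxs, il));
//
//     // views over the first n_kv cells, used by the attention of layer il
//     ggml_tensor * k = mctx->get_k(ctx, il);
//     ggml_tensor * v = mctx->get_v(ctx, il);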