#include "llama-memory-recurrent.h"

#include "llama-impl.h"
#include "llama-io.h"
#include "llama-batch.h"
#include "llama-model.h"

#include <algorithm>
#include <cassert>
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>

//
// llama_memory_recurrent
//
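// Recurrent-state memory: unlike a KV cache, each cell stores the complete
// recurrent state of one sequence (one row of the R tensor and one row of the
// S tensor per layer). cells[seq_id].tail is the index of the cell that
// currently holds the latest state of sequence seq_id.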

llama_memory_recurrent::llama_memory_recurrent(
        const llama_model & model,
        ggml_type type_r,
        ggml_type type_s,
        bool offload,
        uint32_t mem_size,
        uint32_t n_seq_max,
        const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
    const int32_t n_layer = hparams.n_layer;

    head = 0;
    size = mem_size;
    used = 0;

    cells.clear();
    cells.resize(mem_size);

    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
    struct ggml_backend_buft_comparator {
        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
        }
    };
    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;

    // create a context for each buffer type
    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            ggml_init_params params = {
                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
                /*.mem_buffer =*/ NULL,
                /*.no_alloc   =*/ true,
            };

            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                return nullptr;
            }

            ctx_map.emplace(buft, ctx);

            return ctx;
        }

        return it->second.get();
    };

    r_l.resize(n_layer);
    s_l.resize(n_layer);

    for (int i = 0; i < n_layer; i++) {
        if (filter && !filter(i)) {
            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
            continue;
        }

        const char * dev_name = "CPU";

        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();

        if (offload) {
            auto * dev = model.dev_layer(i);
            buft = ggml_backend_dev_buffer_type(dev);

            dev_name = ggml_backend_dev_name(dev);
        }

        LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);

        ggml_context * ctx = ctx_for_buft(buft);
        if (!ctx) {
            throw std::runtime_error("failed to create ggml context for rs cache");
        }

        ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
        ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
        ggml_format_name(r, "cache_r_l%d", i);
        ggml_format_name(s, "cache_s_l%d", i);
        r_l[i] = r;
        s_l[i] = s;
    }

    // allocate tensors and initialize the buffers to avoid NaNs in the padding
    for (auto & [buft, ctx] : ctx_map) {
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
        if (!buf) {
            throw std::runtime_error("failed to allocate buffer for rs cache");
        }
        ggml_backend_buffer_clear(buf, 0);
        LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
        ctxs_bufs.emplace_back(std::move(ctx), buf);
    }

    {
        const size_t memory_size_r = size_r_bytes();
        const size_t memory_size_s = size_s_bytes();

        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
                (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
                ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
                ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
    }
}
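
// Illustrative construction (a rough sketch, not taken from the actual callers;
// the argument values below are placeholders):
//
//   llama_memory_recurrent mem(model,
//           GGML_TYPE_F32,            // type_r
//           GGML_TYPE_F32,            // type_s
//           /*offload   =*/ true,
//           /*mem_size  =*/ 8,        // e.g. one cell per tracked sequence
//           /*n_seq_max =*/ 8,
//           /*filter    =*/ nullptr); // keep all layers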

void llama_memory_recurrent::clear(bool data) {
    for (int32_t i = 0; i < (int32_t) size; ++i) {
        cells[i].pos = -1;
        cells[i].seq_id.clear();
        cells[i].src = -1;
        cells[i].tail = -1;
    }

    head = 0;
    used = 0;

    if (data) {
        for (auto & [_, buf] : ctxs_bufs) {
            ggml_backend_buffer_clear(buf.get(), 0);
        }
    }
}

bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    //printf("[DEBUG] calling `llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
    uint32_t new_head = size;

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // models like Mamba or RWKV can't have a state partially erased
    if (seq_id >= (int64_t) size) {
        // could be fatal
        return false;
    }
    if (0 <= seq_id) {
        int32_t & tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
            const auto & cell = cells[tail_id];
            // partial intersection is invalid
            if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
                //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
                return false;
            }
            // invalidate tails which will be cleared
            if (p0 <= cell.pos && cell.pos < p1) {
                tail_id = -1;
            }
        }
    } else {
        // seq_id is negative, so the range should include everything or nothing
        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
            //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n");
            return false;
        }
    }

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].pos >= p0 && cells[i].pos < p1) {
            if (seq_id < 0) {
                cells[i].seq_id.clear();
            } else if (cells[i].has_seq_id(seq_id)) {
                cells[i].seq_id.erase(seq_id);
            } else {
                continue;
            }
            if (cells[i].is_empty()) {
                // keep count of the number of used cells
                if (cells[i].pos >= 0) {
                    used--;
                }
                cells[i].pos = -1;
                cells[i].src = -1;
                if (new_head == size) {
                    new_head = i;
                }
            }
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != size && new_head < head) {
        head = new_head;
    }

    return true;
}

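// note: seq_cp does not copy any tensor data - the destination sequence is
// simply attached to the source sequence's tail cell, so both ids share the
// same recurrent state; find_slot() splits the shared cell again the next
// time one of the sequences receives new tokens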
void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    if (seq_id_src == seq_id_dst) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
        auto & tail_src = cells[seq_id_src];
        auto & tail_dst = cells[seq_id_dst];
        if (tail_dst.tail >= 0) {
            // clear destination seq_id if it wasn't empty
            auto & cell_dst = cells[tail_dst.tail];

            cell_dst.seq_id.erase(seq_id_dst);
            tail_dst.tail = -1;
            if (cell_dst.seq_id.empty()) {
                cell_dst.pos = -1;
                cell_dst.src = -1;
                used -= 1;
            }
        }
        if (tail_src.tail >= 0) {
            auto & cell_src = cells[tail_src.tail];

            cell_src.seq_id.insert(seq_id_dst);
            tail_dst.tail = tail_src.tail;
        }
    }
}

void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
    uint32_t new_head = size;

    for (uint32_t i = 0; i < size; ++i) {
        if ((llama_seq_id) i != seq_id) {
            cells[i].tail = -1;
        }

        if (!cells[i].has_seq_id(seq_id)) {
            if (cells[i].pos >= 0) {
                used--;
            }

            cells[i].pos = -1;
            cells[i].src = -1;
            cells[i].seq_id.clear();

            if (new_head == size) {
                new_head = i;
            }
        } else {
            cells[i].seq_id.clear();
            cells[i].seq_id.insert(seq_id);
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (new_head != size && new_head < head) {
        head = new_head;
    }
}

void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    if (shift == 0) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) {
        return;
    }

    // for Mamba-like or RWKV models, only the pos needs to be shifted
    if (0 <= seq_id && seq_id < (int64_t) size) {
        const int32_t tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
            auto & cell = cells[tail_id];
            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                cell.pos += shift;
            }
        }
    }
}

void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    if (d == 1) {
        return;
    }

    if (p0 < 0) {
        p0 = 0;
    }

    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // If there is no range then return early to avoid looping over the cache.
    if (p0 == p1) {
        return;
    }

    // for Mamba-like or RWKV models, only the pos needs to be changed
    if (0 <= seq_id && seq_id < (int64_t) size) {
        const int32_t tail_id = cells[seq_id].tail;
        if (tail_id >= 0) {
            auto & cell = cells[tail_id];
            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
                cell.pos /= d;
            }
        }
    }
}

llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
    llama_pos result = std::numeric_limits<llama_pos>::max();

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id)) {
            result = std::min(result, cells[i].pos);
        }
    }

    if (result == std::numeric_limits<llama_pos>::max()) {
        result = -1;
    }

    return result;
}

llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos result = -1;

    for (uint32_t i = 0; i < size; ++i) {
        if (cells[i].has_seq_id(seq_id)) {
            result = std::max(result, cells[i].pos);
        }
    }

    return result;
}

std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
    std::map<ggml_backend_buffer_type_t, size_t> ret;
    for (const auto & [_, buf] : ctxs_bufs) {
        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
    }
    return ret;
}

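// note: batches where all tokens are output are split by sequence, otherwise a
// sequential equal split is used; the resulting ubatches are validated with
// prepare() before a context is returned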
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    do {
        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            llama_ubatch ubatch;

            if (embd_all) {
                // if all tokens are output, split by sequence
                ubatch = balloc.split_seq(n_ubatch);
            } else {
                // TODO: non-sequential equal split can be done if using unified KV cache
                // for simplicity, we always use sequential equal split for now
                ubatch = balloc.split_equal(n_ubatch, true);
            }

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        if (!prepare(ubatches)) {
            break;
        }

        return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
    } while (false);

    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

llama_memory_context_ptr llama_memory_recurrent::init_full() {
    return std::make_unique<llama_memory_recurrent_context>(this);
}

llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
    GGML_UNUSED(lctx);
    GGML_UNUSED(optimize);

    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
}

bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
    // simply remember the full state because it is very small for this type of cache
    // TODO: optimize
    auto org_cells = cells;
    auto org_used = used;
    auto org_head = head;

    bool success = true;

    for (const auto & ubatch : ubatches) {
        if (!find_slot(ubatch)) {
            success = false;
            break;
        }
    }

    // restore the original state
    cells = std::move(org_cells);
    used = org_used;
    head = org_head;

    return success;
}

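// find_slot assigns one cell per sequence of the ubatch:
//   1. detach cells whose seq_ids would become shared
//   2. give each sequence a cell it owns exclusively, reusing its tail cell
//      when possible and claiming an empty cell otherwise
//   3. swap cells so the cells used by this ubatch form a contiguous range
//   4. pick rs_z, the first cell in that range that no other cell sources
//      from, to serve as the zero-ed initial state
// on success, head and n describe the contiguous range used by the ubatch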
bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
    const uint32_t n_seqs = ubatch.n_seqs;

    // if we have enough unused cells before the current head ->
    // better to start searching from the beginning of the cache, hoping to fill it
    if (head > used + 2*n_seqs) {
        head = 0;
    }

    // For recurrent state architectures (like Mamba or RWKV),
    // each cache cell can store the state for a whole sequence.
    // A slot should always be contiguous.

    // can only process batches with an equal number of new tokens in each sequence
    GGML_ASSERT(ubatch.equal_seqs());

    int32_t min = size - 1;
    int32_t max = 0;

    // everything should fit if all seq_ids are smaller than the max
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const uint32_t i = s*n_seq_tokens; // first token of sequence set s
        const uint32_t n_seq_id = ubatch.n_seq_id[i];

        for (uint32_t j = 0; j < n_seq_id; ++j) {
            const llama_seq_id seq_id = ubatch.seq_id[i][j];

            if (seq_id < 0 || (uint32_t) seq_id >= size) {
                // too big seq_id
                // TODO: would it be possible to resize the cache instead?
                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
                return false;
            }
            if (j > 0) {
                auto & seq = cells[seq_id];
                if (seq.tail >= 0) {
                    auto & cell = cells[seq.tail];
                    // clear cells from seq_ids that become shared
                    // (should not normally happen, but let's handle it anyway)
                    cell.seq_id.erase(seq_id);
                    seq.tail = -1;
                    if (cell.seq_id.empty()) {
                        cell.pos = -1;
                        cell.src = -1;
                        used -= 1;
                    }
                }
            }
        }
    }

#ifndef NDEBUG
    {
        std::vector<int32_t> tails_verif;
        tails_verif.assign(size, -1);
        for (uint32_t i = 0; i < size; ++i) {
            auto & cell = cells[i];
            for (llama_seq_id seq_id : cell.seq_id) {
                if (tails_verif[seq_id] != -1) {
                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
                }
                tails_verif[seq_id] = i;
            }
        }
        for (uint32_t i = 0; i < size; ++i) {
            if (tails_verif[i] != cells[i].tail) {
                LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
            }
        }
    }
#endif

    // find next empty cell
    uint32_t next_empty_cell = head;

    for (uint32_t i = 0; i < size; ++i) {
        if (next_empty_cell >= size) { next_empty_cell -= size; }
        auto & cell = cells[next_empty_cell];
        if (cell.is_empty()) { break; }
        next_empty_cell += 1;
    }

    // find usable cell range
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const uint32_t i = s*n_seq_tokens;
        const llama_seq_id seq_id = ubatch.seq_id[i][0];
        auto & seq_meta = cells[seq_id];
        bool has_cell = false;
        if (seq_meta.tail >= 0) {
            auto & cell = cells[seq_meta.tail];
            GGML_ASSERT(cell.has_seq_id(seq_id));
            // does this seq_id "own" the cell?
            if (cell.seq_id.size() == 1) { has_cell = true; }
        }
        if (!has_cell) {
            auto & empty_cell = cells[next_empty_cell];
            GGML_ASSERT(empty_cell.is_empty());
            // copy old tail into the empty cell
            if (seq_meta.tail >= 0) {
                auto & orig_cell = cells[seq_meta.tail];
                empty_cell.pos = orig_cell.pos;
                empty_cell.src = orig_cell.src;
                orig_cell.seq_id.erase(seq_id);
                empty_cell.seq_id.insert(seq_id); // will be overwritten
                GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
            }
            seq_meta.tail = next_empty_cell;
            // find next empty cell
            if (s + 1 < n_seqs) {
                for (uint32_t j = 0; j < size; ++j) {
                    next_empty_cell += 1;
                    if (next_empty_cell >= size) { next_empty_cell -= size; }
                    auto & cell = cells[next_empty_cell];
                    if (cell.is_empty()) { break; }
                }
            }
        }
        if (min > seq_meta.tail) { min = seq_meta.tail; }
        if (max < seq_meta.tail) { max = seq_meta.tail; }
    }

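    // move each sequence's cell into the contiguous [min, min + n_seqs) window
    // by swapping cell contents and fixing up any tails that point at them,
    // so that the states of this ubatch can be addressed as one block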
    // gather and re-order
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const uint32_t i = s*n_seq_tokens;
        const int32_t dst_id = s + min;
        const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
        if (dst_id != src_id) {
            auto & dst_cell = cells[dst_id];
            auto & src_cell = cells[src_id];

            std::swap(dst_cell.pos, src_cell.pos);
            std::swap(dst_cell.src, src_cell.src);
            std::swap(dst_cell.seq_id, src_cell.seq_id);

            // swap tails
            for (uint32_t j = 0; j < size; ++j) {
                int32_t & tail = cells[j].tail;
                if (tail == src_id) {
                    tail = dst_id;
                } else if (tail == dst_id) {
                    tail = src_id;
                }
            }
        }
    }

    // update the pos of the used seqs
    for (uint32_t s = 0; s < n_seqs; ++s) {
        const uint32_t i = s*n_seq_tokens;
        const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
        const int32_t cell_id = s + min;
        auto & cell = cells[cell_id];

        if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
            // What should happen when the pos backtracks or skips a value?
            // Clearing the state mid-batch would require special-casing which isn't done.
            LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
                __func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
        }
        cell.pos = last_pos;
        cell.seq_id.clear();
        for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
            const llama_seq_id seq_id = ubatch.seq_id[i][j];
            cell.seq_id.insert(seq_id);
            cells[seq_id].tail = cell_id;
        }
    }

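    // note: a cell whose src is negative has no previous state; its src0 is
    // pointed at rs_z so that, when the inputs are set, its state starts from
    // the zero-ed cell instead of stale data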
    // Find first cell without src refs, to use as the zero-ed state
    {
        // TODO: bake-in src refcounts in the cell metadata
        std::vector<int32_t> refcounts(size, 0);
        for (size_t i = 0; i < size; ++i) {
            const int32_t src = cells[i].src;
            if (src >= 0) {
                refcounts[src] += 1;
            }
        }

        rs_z = -1;
        for (int i = min; i <= max; ++i) {
            if (refcounts[i] == 0) {
                rs_z = i;
                break;
            }
        }

        for (int i = min; i <= max; ++i) {
            if (cells[i].src < 0) {
                GGML_ASSERT(rs_z >= 0);
                cells[i].src0 = rs_z;
            } else {
                // Stage the source ids for all used cells to allow correct seq_* behavior
                // and still make these values available when setting the inputs
                cells[i].src0 = cells[i].src;
            }
            cells[i].src = i; // avoid moving or clearing twice
        }
    }

    // allow getting the range of used cells, from head to head + n
    head = min;
    n = max - min + 1;
    used = std::count_if(cells.begin(), cells.end(),
        [](const mem_cell & cell){ return !cell.is_empty(); });

    // sanity check
    return n >= n_seqs;
}

bool llama_memory_recurrent::get_can_shift() const {
    // shifting the pos is trivial for recurrent models
    return true;
}

size_t llama_memory_recurrent::total_size() const {
    size_t size = 0;
    for (const auto & [_, buf] : ctxs_bufs) {
        size += ggml_backend_buffer_get_size(buf.get());
    }

    return size;
}

size_t llama_memory_recurrent::size_r_bytes() const {
    size_t size_r_bytes = 0;

    for (const auto & r : r_l) {
        if (r != nullptr) {
            size_r_bytes += ggml_nbytes(r);
        }
    }

    return size_r_bytes;
}

size_t llama_memory_recurrent::size_s_bytes() const {
    size_t size_s_bytes = 0;

    for (const auto & s : s_l) {
        if (s != nullptr) {
            size_s_bytes += ggml_nbytes(s);
        }
    }

    return size_s_bytes;
}

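// session-file layout for this cache (written by state_write, consumed by
// state_read): the cell count, per-cell metadata (pos and, for whole-cache
// saves, the seq_id list), a transposition flag and the layer count, then the
// R rows of every non-null layer followed by the S rows of every non-null layer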
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    GGML_UNUSED(flags);

    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
    uint32_t cell_count = 0;

    // Count the number of cells with the specified seq_id
    // Find all the ranges of cells with this seq id (or all, when -1)
    uint32_t cell_range_begin = size;
    for (uint32_t i = 0; i < size; ++i) {
        const auto & cell = cells[i];
        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
            ++cell_count;
            if (cell_range_begin == size) {
                cell_range_begin = i;
            }
        } else {
            if (cell_range_begin != size) {
                cell_ranges.emplace_back(cell_range_begin, i);
                cell_range_begin = size;
            }
        }
    }
    if (cell_range_begin != size) {
        cell_ranges.emplace_back(cell_range_begin, size);
    }

    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
    uint32_t cell_count_check = 0;
    for (const auto & range : cell_ranges) {
        cell_count_check += range.second - range.first;
    }
    GGML_ASSERT(cell_count == cell_count_check);

    io.write(&cell_count, sizeof(cell_count));

    state_write_meta(io, cell_ranges, seq_id);
    state_write_data(io, cell_ranges);
}

void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    GGML_UNUSED(flags);

    uint32_t cell_count;
    io.read_to(&cell_count, sizeof(cell_count));

    bool res = true;

    res = res && state_read_meta(io, cell_count, seq_id);
    res = res && state_read_data(io, cell_count);

    if (!res) {
        if (seq_id == -1) {
            clear(true);
        } else {
            seq_rm(seq_id, -1, -1);
        }
        throw std::runtime_error("failed to restore kv cache");
    }
}

void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
    for (const auto & range : cell_ranges) {
        for (uint32_t i = range.first; i < range.second; ++i) {
            const auto & cell = cells[i];
            const llama_pos pos = cell.pos;
            const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;

            io.write(&pos, sizeof(pos));
            io.write(&n_seq_id, sizeof(n_seq_id));

            if (n_seq_id) {
                for (auto seq_id : cell.seq_id) {
                    io.write(&seq_id, sizeof(seq_id));
                }
            }
        }
    }
}

void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
    const uint32_t s_trans = 0;
    const uint32_t n_layer = hparams.n_layer;

    io.write(&s_trans, sizeof(s_trans));
    io.write(&n_layer, sizeof(n_layer));

    // Iterate and write all the R states first, one row per cell
    // Get whole range at a time
    for (uint32_t il = 0; il < n_layer; ++il) {
        // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
        if (r_l[il] == nullptr) continue;

        // Write R type
        const int32_t r_type_i = (int32_t)r_l[il]->type;
        io.write(&r_type_i, sizeof(r_type_i));

        // Write row size of R
        const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
        io.write(&r_size_row, sizeof(r_size_row));

        // Write each range of cells, r_size_row bytes per cell
        for (const auto & range : cell_ranges) {
            const size_t range_size = range.second - range.first;
            const size_t buf_size = range_size * r_size_row;
            io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
        }
    }

    if (!s_trans) {
        for (uint32_t il = 0; il < n_layer; ++il) {
            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
            if (s_l[il] == nullptr) continue;

            // Write S type
            const int32_t s_type_i = (int32_t)s_l[il]->type;
            io.write(&s_type_i, sizeof(s_type_i));

            // Write row size of S
            const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
            io.write(&s_size_row, sizeof(s_size_row));

            // Write each range of cells, s_size_row bytes per cell
            for (const auto & range : cell_ranges) {
                const size_t range_size = range.second - range.first;
                const size_t buf_size = range_size * s_size_row;
                io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
            }
        }
    } else {
        // When the S state is transposed, we also need the element size and get the element ranges from each row
        const uint32_t mem_size = size;
        for (uint32_t il = 0; il < n_layer; ++il) {
            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
            if (s_l[il] == nullptr) continue;

            const uint32_t n_embd_s = hparams.n_embd_s();

            // Write S type
            const int32_t s_type_i = (int32_t)s_l[il]->type;
            io.write(&s_type_i, sizeof(s_type_i));

            // Write element size
            const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
            io.write(&s_size_el, sizeof(s_size_el));

            // Write the state embedding size
            io.write(&n_embd_s, sizeof(n_embd_s));

            // For each row, we get the element values of each cell
            for (uint32_t j = 0; j < n_embd_s; ++j) {
                // Write each range of cells, s_size_el bytes per cell
                for (const auto & range : cell_ranges) {
                    const size_t range_size = range.second - range.first;
                    const size_t src_offset = (range.first + j * mem_size) * s_size_el;
                    const size_t buf_size = range_size * s_size_el;
                    io.write_tensor(s_l[il], src_offset, buf_size);
                }
            }
        }
    }
}

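// note: for a single-sequence restore the serialized positions are packed into
// a temporary ubatch and find_slot() places the restored cells as one
// contiguous block starting at head; a whole-cache restore instead clears the
// cache and rebuilds the cell metadata in place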
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
    if (dest_seq_id != -1) {
        // single sequence
        seq_rm(dest_seq_id, -1, -1);

        if (cell_count == 0) {
            return true;
        }

        llama_batch_allocr balloc(hparams.n_pos_per_embd());

        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

        for (uint32_t i = 0; i < cell_count; ++i) {
            llama_pos pos;
            uint32_t n_seq_id;

            io.read_to(&pos, sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            if (n_seq_id != 0) {
                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
                return false;
            }

            ubatch.pos[i] = pos;
        }
        ubatch.n_seq_id[0] = 1;
        ubatch.seq_id[0] = &dest_seq_id;

        if (!find_slot(ubatch)) {
            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
            return false;
        }

        // DEBUG CHECK: head should be our first cell, head + cell_count - 1 should be our last cell (verify seq_id and pos values)
        // Assume that this is one contiguous block of cells
        GGML_ASSERT(head + cell_count <= size);
        GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
        GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
    } else {
        // whole cache restore

        if (cell_count > size) {
            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
            return false;
        }

        clear(true);

        for (uint32_t i = 0; i < cell_count; ++i) {
            auto & cell = cells[i];

            llama_pos pos;
            uint32_t n_seq_id;

            io.read_to(&pos, sizeof(pos));
            io.read_to(&n_seq_id, sizeof(n_seq_id));

            cell.pos = pos;

            for (uint32_t j = 0; j < n_seq_id; ++j) {
                llama_seq_id seq_id;
                io.read_to(&seq_id, sizeof(seq_id));

                // TODO: llama_memory_recurrent should have a notion of max sequences
                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
                if (seq_id < 0) {
                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
                    return false;
                }

                cell.seq_id.insert(seq_id);

                int32_t & tail = cells[seq_id].tail;
                if (tail != -1) {
                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
                    return false;
                }
                tail = i;
            }
        }

        head = 0;
        used = cell_count;
    }

    for (uint32_t i = 0; i < cell_count; ++i) {
        uint32_t cell_id = head + i;
        // make sure the recurrent states will keep their restored state
        cells[cell_id].src = cell_id;
    }

    return true;
}

bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
    uint32_t s_trans;
    uint32_t n_layer;
    io.read_to(&s_trans, sizeof(s_trans));
    io.read_to(&n_layer, sizeof(n_layer));

    if (n_layer != hparams.n_layer) {
        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
        return false;
    }
    if (cell_count > size) {
        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
        return false;
    }
    if (false != (bool) s_trans) {
        LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
        return false;
    }

    // For each layer, read the R states for each cell, one row is one cell, read as one contiguous block
    for (uint32_t il = 0; il < n_layer; ++il) {
        // skip null layers
        if (r_l[il] == nullptr) continue;

        // Read R type
        int32_t r_type_i_ref;
        io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
        const int32_t r_type_i = (int32_t) r_l[il]->type;
        if (r_type_i != r_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
            return false;
        }

        // Read row size of R
        uint64_t r_size_row_ref;
        io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
        const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
        if (r_size_row != r_size_row_ref) {
            LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
            return false;
        }

        if (cell_count) {
            // Read and set the R states for the whole cell range
            ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
        }
    }

    if (!s_trans) {
        for (uint32_t il = 0; il < n_layer; ++il) {
            // skip null layers
            if (s_l[il] == nullptr) continue;

            // Read S type
            int32_t s_type_i_ref;
            io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
            const int32_t s_type_i = (int32_t)s_l[il]->type;

            if (s_type_i != s_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                return false;
            }

            // Read row size of S
            uint64_t s_size_row_ref;
            io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
            const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
            if (s_size_row != s_size_row_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
                return false;
            }

            if (cell_count) {
                // Read and set the S states for the whole cell range
                ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
            }
        }
    } else {
        // For each layer, read the S states for each cell (transposed)
        for (uint32_t il = 0; il < n_layer; ++il) {
            // skip null layers
            if (s_l[il] == nullptr) continue;

            const uint32_t n_embd_s = hparams.n_embd_s();

            // Read S type
            int32_t s_type_i_ref;
            io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
            const int32_t s_type_i = (int32_t)s_l[il]->type;
            if (s_type_i != s_type_i_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                return false;
            }

            // Read element size of S
            uint32_t s_size_el_ref;
            io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
            const size_t s_size_el = ggml_type_size(s_l[il]->type);
            if (s_size_el != s_size_el_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
                return false;
            }

            // Read state embedding size
            uint32_t n_embd_s_ref;
            io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
            if (n_embd_s != n_embd_s_ref) {
                LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
                return false;
            }

            if (cell_count) {
                // For each row in the transposed matrix, read the S values for the whole cell range
                for (uint32_t j = 0; j < n_embd_s; ++j) {
                    const size_t dst_offset = (head + j * size) * s_size_el;
                    ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
                }
            }
        }
    }

    return true;
}

//
// llama_memory_recurrent_context
//
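// a light-weight view over llama_memory_recurrent used during graph
// construction: it either spans the whole cache (init_full) or iterates over
// the ubatches of a prepared batch, calling find_slot() from apply() before
// each ubatch is processed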

llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}

llama_memory_recurrent_context::llama_memory_recurrent_context(
        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
}

llama_memory_recurrent_context::llama_memory_recurrent_context(
        llama_memory_recurrent * mem,
        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}

llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;

bool llama_memory_recurrent_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    if (++i_next >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_memory_recurrent_context::apply() {
    assert(!llama_memory_status_is_fail(status));

    // no ubatches -> this is an update
    if (ubatches.empty()) {
        // recurrent cache never performs updates
        assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE);

        return true;
    }

    mem->find_slot(ubatches[i_next]);

    return true;
}

llama_memory_status llama_memory_recurrent_context::get_status() const {
    return status;
}

const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_next];
}

uint32_t llama_memory_recurrent_context::get_n_rs() const {
    return is_full ? mem->size : mem->n;
}

uint32_t llama_memory_recurrent_context::get_head() const {
    return is_full ? 0 : mem->head;
}

int32_t llama_memory_recurrent_context::get_rs_z() const {
    return is_full ? 0 : mem->rs_z;
}

uint32_t llama_memory_recurrent_context::get_size() const {
    return mem->size;
}

ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
    return mem->r_l[il];
}

ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
    return mem->s_l[il];
}

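// returns the staged source cell (src0) of the i-th cell in the window
// starting at head; the graph uses these indices to copy each sequence's
// previous state (or the zero-ed rs_z state) into its cell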
int32_t llama_memory_recurrent_context::s_copy(int i) const {
    return mem->cells[i + mem->head].src0;
}
