1#pragma once
2
3#include "llama.h"
4#include "llama-cparams.h"
5
6#include <bitset>
7#include <cassert>
8#include <cstring>
9#include <map>
10#include <set>
11#include <vector>
12
13struct llama_kv_cell_ext {
14 // 2D spatial positions, typically used for M-RoPE
15 llama_pos x = 0;
16 llama_pos y = 0;
17
18 // return true if the current 2D spatial position is greater than other
19 bool is_2d_gt(llama_pos ox, llama_pos oy) const {
20 return (y > oy) || (y == oy && x > ox);
21 }
22
23 void reset() {
24 static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
25
26 memset(s: this, c: 0, n: sizeof(*this));
27 }
28};
29
30// meta information about KV cells that can be part of multiple sequences at the same time
31// TODO: add unit tests
32class llama_kv_cells {
33public:
34 void reset() {
35 for (uint32_t i = 0; i < pos.size(); ++i) {
36 pos[i] = -1;
37 ext[i].reset();
38 shift[i] = 0;
39 seq[i].reset();
40 }
41
42 has_shift = false;
43
44 used.clear();
45
46 for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
47 seq_pos[s].clear();
48 }
49 }
50
51 void reset_shift() {
52 has_shift = false;
53
54 for (uint32_t i = 0; i < shift.size(); ++i) {
55 shift[i] = 0;
56 }
57 }
58
59 uint32_t size() const {
60 return pos.size();
61 }
62
63 void resize(uint32_t n) {
64 pos.resize(new_size: n);
65 ext.resize(new_size: n);
66 shift.resize(new_size: n);
67 seq.resize(new_size: n);
68
69 reset();
70 }
71
72 bool is_empty(uint32_t i) const {
73 assert(i < pos.size());
74 assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
75
76 return pos[i] == -1;
77 }
78
79 uint32_t get_used() const {
80 return used.size();
81 }
82
83 // the index of the first cell that is used
84 // return 0 if no cells are used
85 uint32_t used_min() const {
86 return used.empty() ? 0 : *used.begin();
87 }
88
89 // the index of the last cell that is used + 1
90 // return 0 if no cells are used
91 uint32_t used_max_p1() const {
92 return used.empty() ? 0 : *used.rbegin() + 1;
93 }
94
95 bool get_has_shift() const {
96 return has_shift;
97 }
98
99 // move cell isrc to idst (used during defrag)
100 //void mv(uint32_t isrc, uint32_t idst) {
101 // assert(isrc < pos.size());
102 // assert(idst < pos.size());
103
104 // assert(pos[idst] == -1);
105 // assert(pos[isrc] != -1);
106
107 // pos [idst] = pos [isrc];
108 // shift[idst] = shift[isrc];
109 // seq [idst] = seq [isrc];
110
111 // pos [isrc] = -1;
112 // shift[isrc] = 0;
113 // seq [isrc].reset();
114
115 // used.erase (isrc);
116 // used.insert(idst);
117 //}
118
119 // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
120 llama_kv_cells cp(uint32_t i, uint32_t n) const {
121 assert(i + n <= pos.size());
122
123 llama_kv_cells res;
124
125 res.resize(n);
126
127 for (uint32_t j = 0; j < n; ++j) {
128 const auto idx = i + j;
129
130 res.pos[j] = pos[idx];
131 res.ext[j] = ext[idx];
132 res.seq[j] = seq[idx];
133
134 assert(shift[idx] == 0);
135 }
136
137 return res;
138 }
139
140 // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
141 llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
142 llama_kv_cells res;
143
144 res.resize(n: idxs.size());
145
146 for (uint32_t j = 0; j < idxs.size(); ++j) {
147 const auto idx = idxs[j];
148
149 res.pos[j] = pos[idx];
150 res.ext[j] = ext[idx];
151 res.seq[j] = seq[idx];
152
153 assert(shift[idx] == 0);
154 }
155
156 return res;
157 }
158
159 // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
160 void set(uint32_t i, const llama_kv_cells & other) {
161 assert(i + other.pos.size() <= pos.size());
162
163 for (uint32_t j = 0; j < other.pos.size(); ++j) {
164 const auto idx = i + j;
165
166 if (pos[idx] == -1 && other.pos[j] != -1) {
167 used.insert(x: i + j);
168 }
169
170 if (pos[idx] != -1 && other.pos[j] == -1) {
171 used.erase(x: i + j);
172 }
173
174 if (pos[idx] != -1) {
175 seq_pos_rm(i: i + j);
176 }
177
178 pos[idx] = other.pos[j];
179 ext[idx] = other.ext[j];
180 seq[idx] = other.seq[j];
181
182 if (pos[idx] != -1) {
183 seq_pos_add(i: i + j);
184 }
185
186 assert(shift[idx] == 0);
187 }
188 }
189
190 // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
191 void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
192 assert(idxs.size() == other.pos.size());
193
194 for (uint32_t j = 0; j < other.pos.size(); ++j) {
195 const auto idx = idxs[j];
196
197 if (pos[idx] == -1 && other.pos[j] != -1) {
198 used.insert(x: idx);
199 }
200
201 if (pos[idx] != -1 && other.pos[j] == -1) {
202 used.erase(x: idx);
203 }
204
205 if (pos[idx] != -1) {
206 seq_pos_rm(i: idx);
207 }
208
209 pos[idx] = other.pos[j];
210 ext[idx] = other.ext[j];
211 seq[idx] = other.seq[j];
212
213 if (pos[idx] != -1) {
214 seq_pos_add(i: idx);
215 }
216
217 assert(shift[idx] == 0);
218 }
219 }
220
221 // clear a non-empty cell
222 void rm(uint32_t i) {
223 assert(i < pos.size());
224 assert(pos[i] != -1);
225
226 seq_pos_rm(i);
227 seq[i].reset();
228
229 pos[i] = -1;
230 ext[i].reset();
231 shift[i] = 0;
232
233 used.erase(x: i);
234 }
235
236 // note: call only if the cell has seq_id
237 // return true if the cell becomes empty
238 bool seq_rm(uint32_t i, llama_seq_id seq_id) {
239 assert(i < pos.size());
240 assert(seq[i].test(seq_id));
241 assert(pos[i] != -1);
242 assert(seq_id >= 0);
243
244 seq[i].reset(position: seq_id);
245 seq_pos_dec(s: seq_id, p: pos[i]);
246
247 if (seq[i].none()) {
248 pos[i] = -1;
249 ext[i].reset();
250 shift[i] = 0;
251
252 used.erase(x: i);
253
254 return true;
255 }
256
257 return false;
258 }
259
260 // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
261 bool seq_keep(uint32_t i, llama_seq_id seq_id) {
262 assert(i < pos.size());
263
264 if (seq[i].test(position: seq_id)) {
265 seq_pos_rm(i);
266 seq[i].reset();
267
268 seq[i].set(position: seq_id);
269 seq_pos_inc(s: seq_id, p: pos[i]);
270
271 return false;
272 }
273
274 if (seq[i].any()) {
275 seq_pos_rm(i);
276 seq[i].reset();
277
278 pos[i] = -1;
279 ext[i].reset();
280 shift[i] = 0;
281
282 used.erase(x: i);
283
284 return true;
285 }
286
287 assert(pos[i] == -1);
288
289 return false;
290 }
291
292 // number of different sequences in the cell
293 int seq_count(uint32_t i) const {
294 assert(i < pos.size());
295 assert(pos[i] != -1);
296
297 return seq[i].count();
298 }
299
300 // check if the cell contains seq_id
301 bool seq_has(uint32_t i, llama_seq_id seq_id) const {
302 assert(i < pos.size());
303 assert(seq_id >= 0);
304
305 return seq[i].test(position: seq_id);
306 }
307
308 // note: call only if the cell is not empty and the seq_id is not in the cell
309 void seq_add(uint32_t i, llama_seq_id seq_id) {
310 assert(i < pos.size());
311 assert(pos[i] != -1);
312 assert(!seq[i].test(seq_id));
313
314 seq[i].set(position: seq_id);
315 seq_pos_inc(s: seq_id, p: pos[i]);
316 }
317
318 // return the sequence id of this cell
319 // note: call only for cells with exactly one sequence
320 llama_seq_id seq_get(uint32_t i) const {
321 assert(seq[i].count() == 1);
322
323 for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
324 if (seq[i].test(position: s)) {
325 return s;
326 }
327 }
328
329 return -1;
330 }
331
332 // the minimum position of sequence seq_id currently present in any of the cells
333 // return -1 if the sequence is not present
334 llama_pos seq_pos_min(llama_seq_id seq_id) const {
335 assert(seq_id >= 0);
336 assert(seq_id < LLAMA_MAX_SEQ);
337
338 if (seq_pos[seq_id].empty()) {
339 return -1;
340 }
341
342 assert(seq_pos[seq_id].begin()->second > 0);
343
344 return seq_pos[seq_id].begin()->first;
345 }
346
347 // the maximum position of sequence seq_id currently present in any of the cells
348 // return -1 if the sequence is not present
349 llama_pos seq_pos_max(llama_seq_id seq_id) const {
350 assert(seq_id >= 0);
351 assert(seq_id < LLAMA_MAX_SEQ);
352
353 if (seq_pos[seq_id].empty()) {
354 return -1;
355 }
356
357 assert(seq_pos[seq_id].rbegin()->second > 0);
358
359 return seq_pos[seq_id].rbegin()->first;
360 }
361
362 // note: call only if the cell is not empty
363 llama_pos pos_get(uint32_t i) const {
364 assert(i < pos.size());
365 assert(pos[i] != -1);
366
367 return pos[i];
368 }
369
370 const llama_kv_cell_ext & ext_get(uint32_t i) const {
371 assert(i < pos.size());
372 assert(pos[i] != -1);
373
374 return ext[i];
375 }
376
377 // note: call only if the cell is not empty
378 llama_pos get_shift(uint32_t i) const {
379 assert(i < pos.size());
380 assert(pos[i] != -1);
381
382 return shift[i];
383 }
384
385 // check if a cell is not empty and its position is within [p0, p1)
386 bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
387 assert(i < pos.size());
388
389 return pos[i] >= p0 && pos[i] < p1;
390 }
391
392 // set the position of an empty cell
393 // does not modify "has_shift"
394 // note: call only if the cell is empty
395 void pos_set(uint32_t i, llama_pos p) {
396 assert(i < pos.size());
397 assert(pos[i] == -1);
398 assert(seq[i].none());
399
400 pos[i] = p;
401
402 used.insert(x: i);
403 }
404
405 void ext_set(uint32_t i, llama_kv_cell_ext p) {
406 assert(i < ext.size());
407 ext[i] = p;
408 }
409
410 // pos[i] = pos[i] + d
411 // sets "has_shift" to true
412 // note: call only if the cell is not empty
413 bool pos_add(uint32_t i, llama_pos d) {
414 assert(i < pos.size());
415 assert(pos[i] != -1);
416
417 seq_pos_rm(i);
418
419 pos[i] += d;
420 shift[i] += d;
421
422 has_shift = true;
423
424 if (pos[i] < 0) {
425 seq[i].reset();
426 pos[i] = -1;
427 shift[i] = 0;
428
429 used.erase(x: i);
430
431 return true;
432 }
433
434 seq_pos_add(i);
435
436 return false;
437 }
438
439 // pos[i] = pos[i] / d
440 // sets "has_shift" to true
441 // note: call only if the cell is not empty
442 void pos_div(uint32_t i, int d) {
443 assert(i < pos.size());
444 assert(pos[i] != -1);
445
446 const llama_pos p_old = pos[i];
447
448 seq_pos_rm(i);
449
450 pos[i] /= d;
451 shift[i] += p_old - pos[i];
452
453 seq_pos_add(i);
454
455 has_shift = true;
456 }
457
458private:
459 bool has_shift = false;
460
461 // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
462 std::set<uint32_t> used;
463
464 std::vector<llama_pos> pos;
465
466 // stores extra info per cell
467 std::vector<llama_kv_cell_ext> ext;
468
469 // this array accumulates any applied shifts to the pos array since the last reset_shift() call
470 // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
471 //
472 // cells.pos_add(x, shift_x);
473 // cells.pos_div(y, shift_y);
474 // ...
475 //
476 // if (cells.has_shift()) {
477 // for (int i = 0; i < n; ++i) {
478 // auto shift_i = cells.get_shift(i);
479 // ...
480 // }
481 // cells.reset_shift();
482 // }
483 //
484 std::vector<llama_pos> shift;
485
486 using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
487
488 // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
489 std::vector<seq_set_t> seq;
490
491 // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
492 // if the position p is not present, seq_pos[s][p] is not set
493 // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
494 //
495 // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
496 // - during performing a cache reuse via (rm + add)
497 // - some vision models have input embeddings with repeating positions
498 //
499 std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
500
501 // helper functions for updating `seq_pos`, once cell at a time:
502
503 void seq_pos_dec(llama_seq_id s, llama_pos p) {
504 auto it = seq_pos[s].find(x: p);
505 assert(it != seq_pos[s].end());
506
507 if (--it->second == 0) {
508 seq_pos[s].erase(position: it);
509 }
510 }
511
512 void seq_pos_inc(llama_seq_id s, llama_pos p) {
513 seq_pos[s][p]++;
514 }
515
516 // remove cell i
517 void seq_pos_rm(uint32_t i) {
518 for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
519 if (seq[i].test(position: s)) {
520 seq_pos_dec(s, p: pos[i]);
521 }
522 }
523 }
524
525 // add cell i
526 void seq_pos_add(uint32_t i) {
527 for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
528 if (seq[i].test(position: s)) {
529 seq_pos_inc(s, p: pos[i]);
530 }
531 }
532 }
533};
534