#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"

#include <cassert>
#include <map>
#include <unordered_map>
#include <vector>

struct llama_cparams;
struct llama_hparams;
struct llama_model;
struct llama_context;

//
// llama_kv_cache
//

class llama_kv_cache : public llama_memory_i {
public:
    struct stream_copy_info {
        bool empty() const {
            assert(ssrc.size() == sdst.size());
            return ssrc.empty();
        }

        std::vector<uint32_t> ssrc;
        std::vector<uint32_t> sdst;
    };
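
    // illustrative example (hypothetical values, not from the actual implementation):
    // pending copies of stream 0 into stream 2 and of stream 1 into stream 3 would be
    // recorded as
    //
    //   stream_copy_info sc;
    //   sc.ssrc = { 0, 1 };
    //   sc.sdst = { 2, 3 };
    //
    // i.e. stream ssrc[i] is copied into stream sdst[i], so the two vectors always have equal size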

    // for each ubatch, create a slot_info that describes where the ubatch should be inserted in the
    // KV cells, e.g. the cell index for each token, such that: token[i] -> goes to cells[idxs[i]]
    struct slot_info {
        // data for ggml_set_rows
        using idx_vec_t = std::vector<uint32_t>;

        // number of streams: ns = s1 - s0 + 1
        uint32_t s0;
        uint32_t s1;

        std::vector<llama_seq_id> strm; // [ns]
        std::vector<idx_vec_t>    idxs; // [ns]

        uint32_t head() const {
            GGML_ASSERT(idxs.size() == 1);
            GGML_ASSERT(!idxs[0].empty());

            return idxs[0][0];
        }

        void resize(size_t n) {
            strm.resize(n);
            idxs.resize(n);
        }

        size_t size() const {
            GGML_ASSERT(idxs.size() == strm.size());
            GGML_ASSERT(!idxs.empty());

            return idxs[0].size();
        }

        size_t n_stream() const {
            return strm.size();
        }

        bool empty() const {
            return idxs.empty();
        }

        void clear() {
            idxs.clear();
        }
    };
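
    // illustrative sketch (hypothetical values, not from the actual implementation):
    // a single-stream slot_info placing a 3-token ubatch into cells 7, 8 and 9:
    //
    //   slot_info si;
    //   si.s0 = si.s1 = 0;         // one stream -> ns = 1
    //   si.strm = { 0 };
    //   si.idxs = { { 7, 8, 9 } }; // token[i] goes to cells[idxs[0][i]]
    //
    //   si.head(); // == 7, the first destination cell
    //   si.size(); // == 3, the number of tokens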

    using slot_info_vec_t = std::vector<slot_info>;

    llama_kv_cache(
            const llama_model & model,
            ggml_type type_k,
            ggml_type type_v,
            bool v_trans,
            bool offload,
            bool unified,
            uint32_t kv_size,
            uint32_t n_seq_max,
            uint32_t n_pad,
            uint32_t n_swa,
            llama_swa_type swa_type,
            const layer_filter_cb & filter,
            const layer_reuse_cb & reuse);

    ~llama_kv_cache() = default;

    //
    // llama_memory_i
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    //
    // llama_kv_cache specific API
    //

    uint32_t get_size() const;
    uint32_t get_n_stream() const;

    bool get_has_shift() const;

    //
    // graph_build API
    //

    uint32_t get_n_kv(const slot_info & sinfo) const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;

    // store k_cur and v_cur in the cache based on the provided head location
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

    //
    // preparation API
    //

    // find places for the provided ubatches in the cache; returns the slot infos,
    // or an empty vector on failure
    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);

    // find a slot of kv cells that can hold the ubatch
    // if cont == true, the slot must be contiguous
    // returns an empty slot_info on failure
    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
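
    // illustrative flow (hypothetical names, not from the actual implementation):
    // a batch would typically move through the preparation API like this:
    //
    //   std::vector<llama_ubatch> ubatches = /* split the batch into ubatches */;
    //
    //   slot_info_vec_t sinfos = kv.prepare(ubatches);
    //   if (sinfos.empty()) {
    //       // no room for all ubatches - the caller handles the failure
    //   }
    //
    //   for (size_t i = 0; i < ubatches.size(); ++i) {
    //       kv.apply_ubatch(sinfos[i], ubatches[i]); // commit token metadata to the cells
    //   }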

    //
    // input API
    //

    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;

    void set_input_k_shift(ggml_tensor * dst) const;

    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    const llama_model & model;
    const llama_hparams & hparams;

    struct kv_layer {
        // layer index in the model
        // note: can be different from the layer index in the KV cache
        uint32_t il;

        ggml_tensor * k;
        ggml_tensor * v;

        std::vector<ggml_tensor *> k_stream;
        std::vector<ggml_tensor *> v_stream;
    };

    bool v_trans = true; // the value tensor is transposed

    const uint32_t n_seq_max = 1;
    const uint32_t n_stream = 1;

    // required padding
    const uint32_t n_pad = 1;

    // SWA
    const uint32_t n_swa = 0;

    // env: LLAMA_KV_CACHE_DEBUG
    int debug = 0;

    // this is the SWA type of the cache - not to be confused with the model SWA type
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    // ggml contexts for the KV cache along with the allocated backend buffers:
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

    // the index from which we start searching for a free slot in the ring buffer of KV cells (see find_slot())
    // note: this is not part of the KV state and is only used to speed up the find_slot() method
    std::vector<uint32_t> v_heads;

    std::vector<llama_kv_cells> v_cells;

    // maps from a sequence id to a stream id
    std::vector<uint32_t> seq_to_stream;

    // pending stream copies that will be applied during the next update
    stream_copy_info sc_info;

    std::vector<kv_layer> layers;

    // model layer id -> KV cache layer id
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    bool is_masked_swa(llama_pos p0, llama_pos p1) const;

    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
            ggml_context * ctx,
            ggml_tensor * cur,
            ggml_tensor * shift,
            ggml_tensor * factors,
            float freq_base,
            float freq_scale) const;

    ggml_cgraph * build_graph_shift(
            llm_graph_result * res,
            llama_context * lctx) const;

    struct cell_ranges_t {
        uint32_t strm;

        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
    };
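
    // illustrative example (hypothetical values): ranges covering cells 5..8 and cell 12
    // of stream 0 would be stored as
    //
    //   cell_ranges_t cr;
    //   cr.strm = 0;
    //   cr.data = { { 5, 9 }, { 12, 13 } }; // [from, to) pairs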

    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
};

class llama_kv_cache_context : public llama_memory_context_i {
public:
    // some shorthands
    using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
    using stream_copy_info = llama_kv_cache::stream_copy_info;

    // used for errors
    llama_kv_cache_context(llama_memory_status status);

    // used to create a full-cache context
    llama_kv_cache_context(
            llama_kv_cache * kv);

    // used to create an update context
    llama_kv_cache_context(
            llama_kv_cache * kv,
            llama_context * lctx,
            bool do_shift,
            stream_copy_info sc_info);

    // used to create a batch processing context from a batch
    llama_kv_cache_context(
            llama_kv_cache * kv,
            slot_info_vec_t sinfos,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_context();

    //
    // llama_memory_context_i
    //

    bool next() override;
    bool apply() override;

    llama_memory_status get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_context specific API
    //

    uint32_t get_n_kv() const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    // store k_cur and v_cur in the cache based on the provided head location
    // note: the heads in k_cur and v_cur should be laid out contiguously in memory
    // - k_cur  [n_embd_head_k, n_head_k, n_tokens]
    // - k_idxs [n_tokens]
    // - v_cur  [n_embd_head_v, n_head_v, n_tokens]
    // - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending on whether the V cache is transposed
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
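
    // illustrative graph-build sketch (hypothetical names, not from the actual implementation):
    // a caller might combine the idxs and cpy helpers for one layer roughly as follows:
    //
    //   ggml_tensor * k_idxs = mctx.build_input_k_idxs(ctx, ubatch);
    //   ggml_tensor * v_idxs = mctx.build_input_v_idxs(ctx, ubatch);
    //
    //   ggml_build_forward_expand(gf, mctx.cpy_k(ctx, k_cur, k_idxs, il));
    //   ggml_build_forward_expand(gf, mctx.cpy_v(ctx, v_cur, v_idxs, il));
    //
    // the idxs tensors are later filled on the host via set_input_k_idxs()/set_input_v_idxs()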

    // create destination indices for each head of the current batch, indicating where it would be written in the KV cache
    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
    // helps in understanding the implementation logic of cpy_k and cpy_v
    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_input_k_shift (ggml_tensor * dst) const;
    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    llama_memory_status status;

    llama_kv_cache * kv;
    llama_context * lctx;

    //
    // update context
    //

    bool do_shift = false;

    stream_copy_info sc_info;

    //
    // batch processing context
    //

    // the index of the current ubatch to process
    size_t i_cur = 0;

    slot_info_vec_t sinfos;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    //

    // a heuristic to avoid attending to the full cache when it is not yet fully utilized
    // as the cache fills up, the benefit of this heuristic disappears
    int32_t n_kv;
};