1#pragma once
2
3#include "llama.h"
4
5#include <map>
6#include <memory>
7#include <functional>
8
9struct llama_ubatch;
10
11class llama_batch_allocr;
12
13class llama_io_write_i;
14class llama_io_read_i;
15
// configuration values for a memory object
// NOTE(review): presumably consumed when a memory implementation is created — confirm at the call site
// NOTE(review): no member has a default initializer, so every field must be set explicitly;
//               callers may aggregate-initialize positionally, so keep the field order stable
struct llama_memory_params {
    // kv cache
    // ggml element types for the K and V parts of the cache
    ggml_type type_k;
    ggml_type type_v;

    // use full-size SWA cache
    // (when false, the SWA cache is presumably sized to the attention window only — TODO confirm)
    bool swa_full;
};
24
// outcome of preparing / applying a memory context
// the numeric values are written out explicitly; they are the same 0..3 sequence
// an unscoped enum would assign implicitly, so existing comparisons are unaffected
enum llama_memory_status {
    // the operation succeeded
    LLAMA_MEMORY_STATUS_SUCCESS        = 0,
    // nothing had to be updated (reported e.g. by init_update when no work is pending)
    LLAMA_MEMORY_STATUS_NO_UPDATE      = 1,
    // failure while preparing the memory context
    LLAMA_MEMORY_STATUS_FAILED_PREPARE = 2,
    // failure while computing / applying the memory context
    LLAMA_MEMORY_STATUS_FAILED_COMPUTE = 3,
};
31
// helper function for combining the status of two memory contexts
// useful for implementing hybrid memory types (e.g. iSWA)
// NOTE(review): the combination rule (e.g. how SUCCESS + NO_UPDATE or two different failures
//               are merged) lives in the definition, not in this header — check it before relying on it
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);

// helper function for checking if a memory status indicates a failure
// presumably true only for the LLAMA_MEMORY_STATUS_FAILED_* values — TODO confirm against the definition
bool llama_memory_status_is_fail(llama_memory_status status);
38
// the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see:
// - llama_kv_cache_context
// - llama_kv_cache_iswa_context
// ...
//
// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
//
// NOTE(review): all methods are pure virtual; keep the declaration order stable, since it defines
//               the vtable layout shared by every implementation
struct llama_memory_context_i {
    // virtual so that implementations are destroyed correctly through llama_memory_context_ptr
    virtual ~llama_memory_context_i() = default;

    // consume the current ubatch from the context and proceed to the next one
    // return false if we are done
    virtual bool next() = 0;

    // apply the memory state for the current ubatch to the memory object
    // return false on failure
    virtual bool apply() = 0;

    // get the current ubatch
    // the returned reference is owned by the context; presumably it stays valid only until
    // the next call to next() — TODO confirm per implementation
    virtual const llama_ubatch & get_ubatch() const = 0;

    // get the status of the memory context - used for error handling and checking if any updates would be applied
    virtual llama_memory_status get_status() const = 0;
};

// owning handle to a memory context; this is the type returned by the llama_memory_i::init_* methods
using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
65
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
//
// NOTE(review): pure-virtual interface — keep the declaration order stable, since it defines the
//               vtable layout shared by every implementation
struct llama_memory_i {
    // this callback is used to filter out layers that should not be included in the cache
    // presumably returns true for layers that should be kept — TODO confirm at the implementation
    using layer_filter_cb = std::function<bool(int32_t il)>;

    // this callback is used to specify which layers should reuse memory from other layers
    // return negative value to indicate that the layer il should not reuse memory
    // (a non-negative value is presumably the index of the layer to reuse from — TODO confirm)
    using layer_reuse_cb = std::function<int32_t(int32_t il)>;

    // virtual so that implementations are destroyed correctly through llama_memory_ptr
    virtual ~llama_memory_i() = default;

    // split the input batch into a set of ubatches and verify that they can fit into the cache
    // return a context object containing the ubatches and memory state required to process them
    // check the llama_memory_context_i::get_status() for the result
    //   balloc    - batch allocator, taken by non-const reference (the implementation may advance or modify it)
    //   n_ubatch  - presumably the maximum number of tokens per ubatch — TODO confirm
    //   embd_all  - presumably requests outputs for all tokens — TODO confirm
    virtual llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) = 0;

    // simulate full cache, used for allocating worst-case compute buffers
    virtual llama_memory_context_ptr init_full() = 0;

    // prepare for any pending memory updates, such as shifts, copies, etc.
    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
    // NOTE(review): `optimize` presumably requests an extra optimization pass (e.g. defragmentation) — confirm
    virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;

    // getters
    // presumably reports whether position-shifting ops (seq_add / seq_div) are supported — TODO confirm
    virtual bool get_can_shift() const = 0;

    //
    // ops
    //
    // NOTE(review): the [p0, p1) range convention (and the meaning of negative bounds / seq_id)
    //               is not stated in this header — see the public llama_memory_seq_* API docs

    // if data == true, the data buffers will also be cleared together with the metadata
    virtual void clear(bool data) = 0;

    // seq_rm reports success as bool; presumably false means the removal could not be performed — TODO confirm
    virtual bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
    virtual void seq_keep(llama_seq_id seq_id) = 0;
    virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
    virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;

    // smallest / largest position currently stored for seq_id
    // (the value returned for an empty sequence is implementation-defined here — TODO confirm)
    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;

    // memory usage grouped by backend buffer type (values presumably in bytes — TODO confirm)
    virtual std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const = 0;

    //
    // state write/read
    //
    // NOTE(review): default arguments on virtual functions are bound from the static type of the call
    //               expression, so calls through llama_memory_i always use the defaults below
    //               (seq_id = -1, flags = 0); overrides should declare identical defaults

    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};

// owning handle to a memory object
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
123