#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-memory.h"

#include <cstddef>
#include <cstdint>
#include <map>
#include <set>
#include <utility>
#include <vector>
10
//
// llama_memory_recurrent
//
14
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
// see the implementation of llama_kv_cache_context_i for an example of how to do it
17class llama_memory_recurrent : public llama_memory_i {
18public:
19 llama_memory_recurrent(
20 const llama_model & model,
21 ggml_type type_r,
22 ggml_type type_s,
23 bool offload,
24 uint32_t mem_size,
25 uint32_t n_seq_max,
26 const layer_filter_cb & filter);
27
28 ~llama_memory_recurrent() = default;
29
30 //
31 // llama_memory_i
32 //
33
34 llama_memory_context_ptr init_batch(
35 llama_batch_allocr & balloc,
36 uint32_t n_ubatch,
37 bool embd_all) override;
38
39 llama_memory_context_ptr init_full() override;
40
41 llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
42
43 void clear(bool data) override;
44
45 bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
46 void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
47 void seq_keep(llama_seq_id seq_id) override;
48 void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
49 void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
50
51 llama_pos seq_pos_min(llama_seq_id seq_id) const override;
52 llama_pos seq_pos_max(llama_seq_id seq_id) const override;
53
54 std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
55
56 bool prepare(const std::vector<llama_ubatch> & ubatches);
57
58 // find a contiguous slot of memory cells and emplace the ubatch there
59 bool find_slot(const llama_ubatch & ubatch);
60
61 bool get_can_shift() const override;
62
63 // state write/load
64
65 void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
66 void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
67
68 uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
69 uint32_t size = 0; // total number of cells, shared across all sequences
70 uint32_t used = 0; // used cells (i.e. at least one seq_id)
71
72 // computed before each graph build
73 uint32_t n = 0;
74
75 // first zero-ed state
76 int32_t rs_z = -1;
77
78 // TODO: optimize for recurrent state needs
79 struct mem_cell {
80 llama_pos pos = -1;
81 int32_t src = -1; // used to know where states should be copied from
82 int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
83 int32_t tail = -1;
84
85 std::set<llama_seq_id> seq_id;
86
87 bool has_seq_id(const llama_seq_id & id) const {
88 return seq_id.find(x: id) != seq_id.end();
89 }
90
91 bool is_empty() const {
92 return seq_id.empty();
93 }
94
95 bool is_same_seq(const mem_cell & other) const {
96 return seq_id == other.seq_id;
97 }
98 };
99
100 std::vector<mem_cell> cells;
101
102 // per layer
103 std::vector<ggml_tensor *> r_l;
104 std::vector<ggml_tensor *> s_l;
105
106private:
107 //const llama_model & model;
108 const llama_hparams & hparams;
109
110 const uint32_t n_seq_max = 1;
111
112 // ggml contexts for the KV cache along with the allocated backend buffers:
113 std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
114
115 size_t total_size() const;
116
117 size_t size_r_bytes() const;
118 size_t size_s_bytes() const;
119
120 void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
121 void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
122
123 bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
124 bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
125};
126
127class llama_memory_recurrent_context : public llama_memory_context_i {
128public:
129 // used for errors
130 llama_memory_recurrent_context(llama_memory_status status);
131
132 // used to create a full-cache or update context
133 llama_memory_recurrent_context(
134 llama_memory_recurrent * mem);
135
136 // used to create a batch processing context from a batch
137 llama_memory_recurrent_context(
138 llama_memory_recurrent * mem,
139 std::vector<llama_ubatch> ubatches);
140
141 virtual ~llama_memory_recurrent_context();
142
143 //
144 // llama_memory_context_i
145 //
146
147 bool next() override;
148 bool apply() override;
149
150 llama_memory_status get_status() const override;
151 const llama_ubatch & get_ubatch() const override;
152
153 //
154 // llama_memory_recurrent_context specific API
155 //
156
157 uint32_t get_n_rs() const;
158 uint32_t get_head() const;
159 int32_t get_rs_z() const;
160 uint32_t get_size() const;
161
162 ggml_tensor * get_r_l(int32_t il) const;
163 ggml_tensor * get_s_l(int32_t il) const;
164
165 int32_t s_copy(int i) const;
166
167private:
168 const llama_memory_status status;
169
170 llama_memory_recurrent * mem;
171
172 size_t i_next = 0;
173
174 std::vector<llama_ubatch> ubatches;
175
176 //
177 // data needed for building the compute graph for the current ubatch:
178 // TODO: extract all the state like `head` and `n` here
179 //
180
181 const bool is_full = false;
182};
183