1#include "llama-kv-cache-iswa.h"
2
3#include "llama-impl.h"
4#include "llama-batch.h"
5#include "llama-model.h"
6
7#include <algorithm>
8#include <cassert>
9
10//
11// llama_kv_cache_iswa
12//
13
14llama_kv_cache_iswa::llama_kv_cache_iswa(
15 const llama_model & model,
16 ggml_type type_k,
17 ggml_type type_v,
18 bool v_trans,
19 bool offload,
20 bool swa_full,
21 bool unified,
22 uint32_t kv_size,
23 uint32_t n_seq_max,
24 uint32_t n_ubatch,
25 uint32_t n_pad,
26 const layer_filter_cb & filter,
27 const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
28
29 // chain filters
30 const layer_filter_cb filter_base = [&](int32_t il) {
31 if (filter && !filter(il)) {
32 return false;
33 }
34
35 return !model.hparams.is_swa(il);
36 };
37
38 const layer_filter_cb filter_swa = [&](int32_t il) {
39 if (filter && !filter(il)) {
40 return false;
41 }
42
43 return model.hparams.is_swa(il);
44 };
45
46 const uint32_t size_base = kv_size;
47
48 // note: the SWA cache is always padded to 256 for performance
49 // https://github.com/ggml-org/llama.cpp/issues/17037
50 uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
51
52 // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
53 if (swa_full) {
54 LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
55 __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
56
57 size_swa = size_base;
58 }
59
60 LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
61
62 kv_base = std::make_unique<llama_kv_cache>(
63 args: model, args&: type_k, args&: type_v,
64 args&: v_trans, args&: offload, args&: unified, args: size_base, args&: n_seq_max, args&: n_pad,
65 args: 0, args: LLAMA_SWA_TYPE_NONE, args: filter_base, args: reuse);
66
67 LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
68
69 kv_swa = std::make_unique<llama_kv_cache>(
70 args: model, args&: type_k, args&: type_v,
71 args&: v_trans, args&: offload, args&: unified, args&: size_swa, args&: n_seq_max, args&: n_pad,
72 args: hparams.n_swa, args: hparams.swa_type, args: filter_swa, args: reuse);
73}
74
75void llama_kv_cache_iswa::clear(bool data) {
76 kv_base->clear(data);
77 kv_swa ->clear(data);
78}
79
80bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
81 bool res = true;
82
83 res = res & kv_base->seq_rm(seq_id, p0, p1);
84 res = res & kv_swa ->seq_rm(seq_id, p0, p1);
85
86 return res;
87}
88
89void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
90 kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
91 kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
92}
93
94void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
95 kv_base->seq_keep(seq_id);
96 kv_swa ->seq_keep(seq_id);
97}
98
99void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
100 kv_base->seq_add(seq_id, p0, p1, shift);
101 kv_swa ->seq_add(seq_id, p0, p1, shift);
102}
103
104void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
105 kv_base->seq_div(seq_id, p0, p1, d);
106 kv_swa ->seq_div(seq_id, p0, p1, d);
107}
108
109llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
110 // the base cache is a superset of the SWA cache, so we can just check the SWA cache
111 return kv_swa->seq_pos_min(seq_id);
112}
113
114llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
115 return kv_swa->seq_pos_max(seq_id);
116}
117
118std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const {
119 std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown();
120 for (const auto & buft_size : kv_swa->memory_breakdown()) {
121 mb[buft_size.first] += buft_size.second;
122 }
123 return mb;
124}
125
126llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
127 GGML_UNUSED(embd_all);
128
129 // first try simple split
130 do {
131 if (!unified) {
132 // requires equal splits, so we skip the simple split
133 break;
134 }
135
136 balloc.split_reset();
137
138 std::vector<llama_ubatch> ubatches;
139 while (true) {
140 auto ubatch = balloc.split_simple(n_ubatch);
141
142 if (ubatch.n_tokens == 0) {
143 break;
144 }
145
146 ubatches.push_back(x: std::move(ubatch)); // NOLINT
147 }
148
149 if (balloc.get_n_used() < balloc.get_n_tokens()) {
150 // failed to find a suitable split
151 break;
152 }
153
154 auto sinfos_base = kv_base->prepare(ubatches);
155 if (sinfos_base.empty()) {
156 break;
157 }
158
159 auto sinfos_swa = kv_swa->prepare(ubatches);
160 if (sinfos_swa.empty()) {
161 break;
162 }
163
164 assert(sinfos_base.size() == sinfos_swa.size());
165
166 return std::make_unique<llama_kv_cache_iswa_context>(
167 args: this, args: std::move(sinfos_base), args: std::move(sinfos_swa), args: std::move(ubatches));
168 } while (false);
169
170 // if it fails, try equal split
171 do {
172 balloc.split_reset();
173
174 std::vector<llama_ubatch> ubatches;
175 while (true) {
176 auto ubatch = balloc.split_equal(n_ubatch, sequential: !unified);
177
178 if (ubatch.n_tokens == 0) {
179 break;
180 }
181
182 ubatches.push_back(x: std::move(ubatch)); // NOLINT
183 }
184
185 if (balloc.get_n_used() < balloc.get_n_tokens()) {
186 // failed to find a suitable split
187 break;
188 }
189
190 auto sinfos_base = kv_base->prepare(ubatches);
191 if (sinfos_base.empty()) {
192 break;
193 }
194
195 auto sinfos_swa = kv_swa->prepare(ubatches);
196 if (sinfos_swa.empty()) {
197 break;
198 }
199
200 assert(sinfos_base.size() == sinfos_swa.size());
201
202 return std::make_unique<llama_kv_cache_iswa_context>(
203 args: this, args: std::move(sinfos_base), args: std::move(sinfos_swa), args: std::move(ubatches));
204 } while (false);
205
206 // TODO: if we fail again, we should attempt different splitting strategies
207 // but to do that properly, we first have to refactor the batches to be more flexible
208
209 return std::make_unique<llama_kv_cache_iswa_context>(args: LLAMA_MEMORY_STATUS_FAILED_PREPARE);
210}
211
212llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
213 return std::make_unique<llama_kv_cache_iswa_context>(args: this);
214}
215
216llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
217 return std::make_unique<llama_kv_cache_iswa_context>(args: this, args&: lctx, args&: optimize);
218}
219
220bool llama_kv_cache_iswa::get_can_shift() const {
221 return kv_base->get_size() == kv_swa->get_size();
222}
223
224void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
225 if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
226 kv_base->state_write(io, seq_id, flags);
227 }
228
229 kv_swa->state_write(io, seq_id, flags);
230}
231
232void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
233 if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
234 kv_base->state_read(io, seq_id, flags);
235 }
236
237 kv_swa->state_read(io, seq_id, flags);
238}
239
240llama_kv_cache * llama_kv_cache_iswa::get_base() const {
241 return kv_base.get();
242}
243
244llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
245 return kv_swa.get();
246}
247
248//
249// llama_kv_cache_iswa_context
250//
251
252llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
253
254llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
255 llama_kv_cache_iswa * kv) :
256 ctx_base(kv->get_base()->init_full()),
257 ctx_swa (kv->get_swa ()->init_full()),
258 status(llama_memory_status_combine(s0: ctx_base->get_status(), s1: ctx_swa->get_status())) {
259}
260
261llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
262 llama_kv_cache_iswa * kv,
263 llama_context * lctx,
264 bool optimize) :
265 ctx_base(kv->get_base()->init_update(lctx, optimize)),
266 ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
267 status(llama_memory_status_combine(s0: ctx_base->get_status(), s1: ctx_swa->get_status())) {
268}
269
270llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
271 llama_kv_cache_iswa * kv,
272 slot_info_vec_t sinfos_base,
273 slot_info_vec_t sinfos_swa,
274 std::vector<llama_ubatch> ubatches) :
275 ubatches(std::move(ubatches)),
276 // note: here we copy the ubatches. not sure if this is ideal
277 ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
278 ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
279 status(llama_memory_status_combine(s0: ctx_base->get_status(), s1: ctx_swa->get_status())) {
280}
281
282llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
283
284bool llama_kv_cache_iswa_context::next() {
285 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
286
287 ctx_base->next();
288 ctx_swa ->next();
289
290 if (++i_next >= ubatches.size()) {
291 return false;
292 }
293
294 return true;
295}
296
297bool llama_kv_cache_iswa_context::apply() {
298 assert(!llama_memory_status_is_fail(status));
299
300 bool res = true;
301
302 res = res & ctx_base->apply();
303 res = res & ctx_swa ->apply();
304
305 return res;
306}
307
308llama_memory_status llama_kv_cache_iswa_context::get_status() const {
309 return status;
310}
311
312const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
313 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
314
315 return ubatches[i_next];
316}
317
318const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
319 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
320
321 return static_cast<const llama_kv_cache_context *>(ctx_base.get());
322}
323
324const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
325 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
326
327 return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
328}
329