1#pragma once
2
3#include "llama-kv-cache.h"
4
5#include <vector>
6
7//
8// llama_kv_cache_iswa
9//
10
11// utilizes two instances of llama_kv_cache
12// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
13
14class llama_kv_cache_iswa : public llama_memory_i {
15public:
16 llama_kv_cache_iswa(
17 const llama_model & model,
18 ggml_type type_k,
19 ggml_type type_v,
20 bool v_trans,
21 bool offload,
22 bool swa_full,
23 bool unified,
24 uint32_t kv_size,
25 uint32_t n_seq_max,
26 uint32_t n_ubatch,
27 uint32_t n_pad,
28 const layer_filter_cb & filter,
29 const layer_reuse_cb & reuse);
30
31 ~llama_kv_cache_iswa() = default;
32
33 //
34 // llama_memory_i
35 //
36
37 llama_memory_context_ptr init_batch(
38 llama_batch_allocr & balloc,
39 uint32_t n_ubatch,
40 bool embd_all) override;
41
42 llama_memory_context_ptr init_full() override;
43
44 llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
45
46 bool get_can_shift() const override;
47
48 void clear(bool data) override;
49
50 bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
51 void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
52 void seq_keep(llama_seq_id seq_id) override;
53 void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
54 void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
55
56 llama_pos seq_pos_min(llama_seq_id seq_id) const override;
57 llama_pos seq_pos_max(llama_seq_id seq_id) const override;
58
59 std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
60
61 // state write/load
62
63 void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
64 void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
65
66 //
67 // llama_kv_cache_iswa specific API
68 //
69
70 llama_kv_cache * get_base() const;
71 llama_kv_cache * get_swa () const;
72
73private:
74 const llama_hparams & hparams;
75
76 const bool unified;
77
78 std::unique_ptr<llama_kv_cache> kv_base;
79 std::unique_ptr<llama_kv_cache> kv_swa;
80};
81
82class llama_kv_cache_iswa_context : public llama_memory_context_i {
83public:
84 using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
85
86 // used for errors
87 llama_kv_cache_iswa_context(llama_memory_status status);
88
89 // used to create a full-cache context
90 llama_kv_cache_iswa_context(
91 llama_kv_cache_iswa * kv);
92
93 // used to create an update context
94 llama_kv_cache_iswa_context(
95 llama_kv_cache_iswa * kv,
96 llama_context * lctx,
97 bool optimize);
98
99 // used to create a batch processing context from a batch
100 llama_kv_cache_iswa_context(
101 llama_kv_cache_iswa * kv,
102 slot_info_vec_t sinfos_base,
103 slot_info_vec_t sinfos_swa,
104 std::vector<llama_ubatch> ubatches);
105
106 virtual ~llama_kv_cache_iswa_context();
107
108 //
109 // llama_memory_context_i
110 //
111
112 bool next() override;
113 bool apply() override;
114
115 llama_memory_status get_status() const override;
116 const llama_ubatch & get_ubatch() const override;
117
118 //
119 // llama_kv_cache_iswa_context specific API
120 //
121
122 const llama_kv_cache_context * get_base() const;
123 const llama_kv_cache_context * get_swa() const;
124
125private:
126 //llama_kv_cache_iswa * kv;
127
128 // the index of the next ubatch to process
129 size_t i_next = 0;
130
131 std::vector<llama_ubatch> ubatches;
132
133 const llama_memory_context_ptr ctx_base;
134 const llama_memory_context_ptr ctx_swa;
135
136 const llama_memory_status status;
137};
138