1#pragma once
2
3#include "llama-batch.h"
4#include "llama-graph.h"
5#include "llama-kv-cache.h"
6#include "llama-memory.h"
7#include "llama-memory-recurrent.h"
8
9#include <memory>
10#include <vector>
11
12//
13// llama_memory_hybrid
14//
15
16// utilizes instances of llama_memory_recurrent and llama_kv_cache to
17// support models where each layer may be either attention-based or recurrent
18
19class llama_memory_hybrid : public llama_memory_i {
20public:
21 llama_memory_hybrid(
22 const llama_model & model,
23 /* attn */
24 ggml_type type_k,
25 ggml_type type_v,
26 bool v_trans,
27 uint32_t kv_size,
28 uint32_t n_pad,
29 uint32_t n_swa,
30 llama_swa_type swa_type,
31 /* recurrent */
32 ggml_type type_r,
33 ggml_type type_s,
34 uint32_t rs_size,
35 /* common */
36 uint32_t n_seq_max,
37 bool offload,
38 bool unified,
39 /* layer filters */
40 const layer_filter_cb & filter_attn = nullptr,
41 const layer_filter_cb & filter_recr = nullptr);
42
43 ~llama_memory_hybrid() = default;
44
45 //
46 // llama_memory_i
47 //
48
49 llama_memory_context_ptr init_batch(
50 llama_batch_allocr & balloc,
51 uint32_t n_ubatch,
52 bool embd_all) override;
53
54 llama_memory_context_ptr init_full() override;
55
56 llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
57
58 bool get_can_shift() const override;
59
60 void clear(bool data) override;
61
62 bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
63 void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
64 void seq_keep(llama_seq_id seq_id) override;
65 void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
66 void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
67
68 llama_pos seq_pos_min(llama_seq_id seq_id) const override;
69 llama_pos seq_pos_max(llama_seq_id seq_id) const override;
70
71 std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
72
73 // state write/load
74
75 void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
76 void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
77
78 //
79 // llama_memory_hybrid specific API
80 //
81
82 llama_kv_cache * get_mem_attn() const;
83 llama_memory_recurrent * get_mem_recr() const;
84
85private:
86 const llama_hparams & hparams;
87
88 const std::unique_ptr<llama_kv_cache> mem_attn;
89 const std::unique_ptr<llama_memory_recurrent> mem_recr;
90};
91
92class llama_memory_hybrid_context : public llama_memory_context_i {
93public:
94 using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
95
96 // init failure
97 explicit llama_memory_hybrid_context(llama_memory_status status);
98
99 // init full
100 explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
101
102 // init update
103 explicit llama_memory_hybrid_context(
104 llama_memory_hybrid * mem,
105 llama_context * lctx,
106 bool optimize);
107
108 // init success
109 llama_memory_hybrid_context(
110 llama_memory_hybrid * mem,
111 slot_info_vec_t sinfos_attn,
112 std::vector<llama_ubatch> ubatches);
113
114 ~llama_memory_hybrid_context() = default;
115
116 bool next() override;
117 bool apply() override;
118
119 llama_memory_status get_status() const override;
120 const llama_ubatch & get_ubatch() const override;
121
122 //
123 // llama_memory_hybrid_context
124 //
125
126 const llama_kv_cache_context * get_attn() const;
127 const llama_memory_recurrent_context * get_recr() const;
128
129private:
130 // the index of the next ubatch to process
131 size_t i_next = 0;
132
133 std::vector<llama_ubatch> ubatches;
134
135 const llama_memory_context_ptr ctx_attn;
136 const llama_memory_context_ptr ctx_recr;
137
138 const llama_memory_status status;
139};
140