#include "llama-hparams.h"

#include "ggml.h"

#include <algorithm>
#include <cassert>

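// set the SWA (sliding-window attention) pattern of the layers:
//   n_pattern == 0: all layers use SWA
//   n_pattern == 1: all layers are dense (full attention)
//   n_pattern  > 1: every n_pattern-th layer is dense, the rest use SWA
//     dense_first == false: layer il is dense when il % n_pattern == n_pattern - 1
//     dense_first == true:  layer il is dense when il % n_pattern == 0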
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
    if (dense_first) {
        for (uint32_t il = 0; il < n_layer; ++il) {
            swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0);
        }
    } else {
        for (uint32_t il = 0; il < n_layer; ++il) {
            swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
        }
    }
}

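// returns true if at least one layer uses sliding-window attention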
bool llama_hparams::is_swa_any() const {
    for (uint32_t il = 0; il < n_layer; ++il) {
        if (swa_layers[il]) {
            return true;
        }
    }

    return false;
}

uint32_t llama_hparams::n_head(uint32_t il) const {
    if (il < n_layer) {
        return n_head_arr[il];
    }

    GGML_ABORT("fatal error");
}

uint32_t llama_hparams::n_head_kv(uint32_t il) const {
    if (il < n_layer) {
        return n_head_kv_arr[il];
    }

    GGML_ABORT("fatal error");
}

uint32_t llama_hparams::n_ff(uint32_t il) const {
    if (il < n_layer) {
        return n_ff_arr[il];
    }

    GGML_ABORT("fatal error");
}

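// ratio of query heads to key/value heads for layer il (grouped-query attention factor)
// returns 0 when the layer has no KV heads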
uint32_t llama_hparams::n_gqa(uint32_t il) const {
    const uint32_t n_head    = this->n_head(il);
    const uint32_t n_head_kv = this->n_head_kv(il);

    if (n_head_kv == 0) {
        return 0;
    }

    return n_head/n_head_kv;
}

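// size of the input embeddings; when deepstack layers are present, their embeddings are
// concatenated to the token embedding, i.e. (1 + n_deepstack_layers) * n_embd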
uint32_t llama_hparams::n_embd_inp() const {
    uint32_t n_embd_inp = n_embd;

    if (n_deepstack_layers > 0) {
        n_embd_inp += n_embd * n_deepstack_layers;
    }

    return n_embd_inp;
}

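// dimension of the key embeddings across all KV heads for layer il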
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
    const uint32_t n_head_kv = this->n_head_kv(il);

    return n_embd_head_k * n_head_kv;
}

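// dimension of the value embeddings across all KV heads for layer il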
uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
    const uint32_t n_head_kv = this->n_head_kv(il);

    return n_embd_head_v * n_head_kv;
}

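// returns true when the per-layer K embedding size is not the same for every layer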
bool llama_hparams::is_n_embd_k_gqa_variable() const {
    const uint32_t val = n_embd_k_gqa();
    for (uint32_t il = 0; il < n_layer; ++il) {
        if (val != n_embd_k_gqa(il)) {
            return true;
        }
    }

    return false;
}

bool llama_hparams::is_n_embd_v_gqa_variable() const {
    const uint32_t val = n_embd_v_gqa();
    for (uint32_t il = 0; il < n_layer; ++il) {
        if (val != n_embd_v_gqa(il)) {
            return true;
        }
    }

    return false;
}

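// largest per-layer K embedding size across all layers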
uint32_t llama_hparams::n_embd_k_gqa_max() const {
    uint32_t val = n_embd_k_gqa();
    for (uint32_t il = 0; il < n_layer; ++il) {
        val = std::max(val, n_embd_k_gqa(il));
    }

    return val;
}

uint32_t llama_hparams::n_embd_v_gqa_max() const {
    uint32_t val = n_embd_v_gqa();
    for (uint32_t il = 0; il < n_layer; ++il) {
        val = std::max(val, n_embd_v_gqa(il));
    }

    return val;
}

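// dimension of the rolling state embeddings per sequence (conv / token-shift states)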
uint32_t llama_hparams::n_embd_r() const {
    if (wkv_head_size != 0) {
        // for RWKV models
        return token_shift_count * n_embd;
    }

    if (n_shortconv_l_cache != 0) {
        // for LFM2 models
        return n_embd * (n_shortconv_l_cache - 1);
    }

    // TODO: maybe support other convolution strides than 1
    // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
    // Corresponds to Mamba's conv_states size
    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
}

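// dimension of the recurrent state embeddings per sequence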
uint32_t llama_hparams::n_embd_s() const {
    if (wkv_head_size != 0) {
        // corresponds to RWKV's wkv_states size
        return n_embd * wkv_head_size;
    }

    // corresponds to Mamba's ssm_states size
    return ssm_d_state * ssm_d_inner;
}

bool llama_hparams::is_recurrent(uint32_t il) const {
    if (il < n_layer) {
        return recurrent_layer_arr[il];
    }

    GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer);
}

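// M-RoPE variants encode 4 position components per token; other RoPE types use a single position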
uint32_t llama_hparams::n_pos_per_embd() const {
    return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
}

bool llama_hparams::is_swa(uint32_t il) const {
    if (il < n_layer) {
        return swa_layers[il];
    }

    GGML_ABORT("fatal error");
}

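// when n_layer_kv_from_start is non-negative, only the first n_layer_kv_from_start layers have a KV cache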
bool llama_hparams::has_kv(uint32_t il) const {
    if (n_layer_kv_from_start >= 0) {
        if (il < (uint32_t) n_layer_kv_from_start) {
            return true;
        }

        return false;
    }

    // by default, all layers have kv
    return true;
}

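// number of layers that have a KV cache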
uint32_t llama_hparams::n_layer_kv() const {
    uint32_t res = 0;

    for (uint32_t il = 0; il < n_layer; ++il) {
        if (has_kv(il)) {
            res++;
        }
    }

    return res;
}

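// returns true when position p0 should be masked for a query at position p1,
// i.e. the pair falls outside the attention window of size n_swa for the given swa_type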
bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
    assert(p0 >= 0 && p1 >= 0);

    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:
            {
            } break;
        case LLAMA_SWA_TYPE_STANDARD:
            {
                if (p1 - p0 >= (int32_t) n_swa) {
                    return true;
                }
            } break;
        case LLAMA_SWA_TYPE_CHUNKED:
            {
                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;

                if (p0 < pos_chunk_start) {
                    return true;
                }
            } break;
        case LLAMA_SWA_TYPE_SYMMETRIC:
            {
                const int32_t half_n_swa = (int32_t) n_swa / 2;
                const int32_t pos_diff   = p1 - p0;

                // Mask if outside the symmetric window
                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
                    return true;
                }
            } break;
    }

    return false;
}