1#include "models.h"
2
3llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
4
5ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp,
6 ggml_tensor * cur,
7 const llama_model & model,
8 const llama_ubatch & ubatch,
9 int il) {
10 const auto * mctx_cur = inp->mctx;
11
12 const auto kv_head = mctx_cur->get_head();
13
14 const auto & layer = model.layers[il];
15
16 const int64_t d_conv = hparams.ssm_d_conv;
17 const int64_t d_inner = hparams.ssm_d_inner;
18 const int64_t d_state = hparams.ssm_d_state;
19 const int64_t dt_rank = hparams.ssm_dt_rank;
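    // treat the inner dimension as n_head == d_inner heads of size 1,
    // so the ggml_ssm_scan call below sees the same {head_dim, n_head, ...} layout as Mamba-2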
    const int64_t n_head   = d_inner;
    const int64_t head_dim = 1;
    const int64_t n_seqs   = ubatch.n_seqs;
    // Some variants of the Mamba arch (e.g. FalconMamba) apply RMS norm on the B, C and Dt layers
    const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

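    // gather the {d_conv - 1, d_inner} conv state of each sequence in the ubatch from the recurrent cache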
    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv               = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur);
    // split the above in two
    // => {d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
    ggml_tensor * z =
        ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner * ggml_element_size(xz));

    // conv
    {
        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);

        // copy last (d_conv - 1) columns back into the state cache
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
                                               n_seq_tokens * (conv_x->nb[0]));

        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, last_conv,
                         ggml_view_1d(ctx0, conv_states_all, (d_conv - 1) * (d_inner) * (n_seqs),
                                      kv_head * (d_conv - 1) * (d_inner) * ggml_element_size(conv_states_all))));

        // 1D convolution
        // The equivalent is to make a self-overlapping view of conv_x
        // over d_conv columns at each stride in the 3rd dimension,
        // then element-wise multiply that with the conv1d weight,
        // then sum the elements of each row,
        // (the last two steps are a dot product over rows (also doable with mul_mat))
        // then permute away the ne[0] dimension,
        // and then you're left with the resulting x tensor.
        // For simultaneous sequences, all sequences need to have the same length.
        x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d);

        // bias
        x = ggml_add(ctx0, x, layer.ssm_conv1d_b);

        x = ggml_silu(ctx0, x);
    }

    // ssm
    {
        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
        ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x);
        // split
        ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
        ggml_tensor * B =
            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
                         x_db->nb[2], ggml_element_size(x_db) * dt_rank);
        ggml_tensor * C =
            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
                         x_db->nb[2], ggml_element_size(x_db) * (dt_rank + d_state));

        // Some Mamba variants (e.g. FalconMamba, Jamba) apply RMS norm in B, C & Dt layers
        if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) {
            dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il);
            B  = build_norm(B,  layer.ssm_b_norm,  NULL, LLM_NORM_RMS, il);
            C  = build_norm(C,  layer.ssm_c_norm,  NULL, LLM_NORM_RMS, il);
        }

        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
        dt = build_lora_mm(layer.ssm_dt, dt);
        dt = ggml_add(ctx0, dt, layer.ssm_dt_b);

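        // keep the pre-reshape x around for the D skip connection (the ggml_mul with layer.ssm_d below)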
        cur = x;
        x   = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);

        ggml_tensor * A = layer.ssm_a;

        // use the states and the indices provided by build_recurrent_state
        // (this is necessary in order to properly use the states before they are overwritten,
        //  while avoiding to make unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

            // Custom operator to optimize the parallel associative scan
            // as described in the Annex D of the Mamba paper.
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);

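        // the updated states follow the per-token output in y_ssm,
        // so they start at an offset of x->nb[3] * x->ne[3] bytes (the size of the y part)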
        // store last states
        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, x->nb[3] * x->ne[3]),
                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));

        ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);

        // TODO: skip computing output earlier for unused tokens

        y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        cur = build_lora_mm(layer.ssm_out, y);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);

    return cur;
}

ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp,
                                                          ggml_tensor *        cur,
                                                          const llama_model &  model,
                                                          const llama_ubatch & ubatch,
                                                          int                  il) const {
    const auto * mctx_cur = inp->mctx;

    const auto kv_head = mctx_cur->get_head();

    const int64_t d_conv  = hparams.ssm_d_conv;
    const int64_t d_inner = hparams.ssm_d_inner;
    const int64_t d_state = hparams.ssm_d_state;
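    // for Mamba-2, ssm_dt_rank holds the number of SSM heads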
    const int64_t n_head   = hparams.ssm_dt_rank;
    const int64_t head_dim = d_inner / n_head;
    const int64_t n_group  = hparams.ssm_n_group;
    const int64_t n_seqs   = ubatch.n_seqs;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv               = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads

    // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
    ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);

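    // zxBCdt is laid out along ne[0] as [z (d_inner), xBC (d_inner + 2*n_group*d_state), dt (n_head)],
    // which the three views below slice apart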
    // split the above in three
    ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0],
                                   zxBCdt->nb[1], zxBCdt->nb[2], 0);
    ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2 * n_group * d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1],
                                     zxBCdt->nb[2], d_inner * ggml_element_size(zxBCdt));
    ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2],
                                    (2 * d_inner + 2 * n_group * d_state) * ggml_element_size(zxBCdt));

    // conv
    {
        // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);

        // copy last (d_conv - 1) columns back into the state cache
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs,
                                               conv_x->nb[1], conv_x->nb[2], n_seq_tokens * (conv_x->nb[0]));

        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
                                               ggml_view_1d(ctx0, conv_states_all,
                                                            (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
                                                            kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
                                                                ggml_element_size(conv_states_all))));

        // 1D convolution
        // The equivalent is to make a self-overlapping view of conv_x
        // over d_conv columns at each stride in the 3rd dimension,
        // then element-wise multiply that with the conv1d weight,
        // then sum the elements of each row,
        // (the last two steps are a dot product over rows (also doable with mul_mat))
        // then permute away the ne[0] dimension,
        // and then you're left with the resulting x tensor.
        // For simultaneous sequences, all sequences need to have the same length.
        xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);

        // bias
        xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);

        xBC = ggml_silu(ctx0, xBC);
    }

    // ssm
    {
        // These correspond to V K Q in SSM/attention duality
        ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], 0);
        ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC));
        ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC));

        // {n_head, n_seq_tokens, n_seqs}
        dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);

        ggml_tensor * A = model.layers[il].ssm_a;

        // use the states and the indices provided by build_recurrent_state
        // (this is necessary in order to properly use the states before they are overwritten,
        //  while avoiding to make unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

            // TODO: use semistructured matrices to implement state-space duality
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);

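        // the updated states follow the per-token output in y_ssm,
        // at an offset of ggml_nelements(x) * x->nb[0] bytes (the size of the y part)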
        // store last states
        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]),
                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));

        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1],
                                       n_seq_tokens * n_head * x->nb[1], 0);

        // TODO: skip computing output earlier for unused tokens

        y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
        cb(y, "mamba2_y_add_d", il);
        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

        // grouped RMS norm
        if (model.layers[il].ssm_norm) {
            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
        }

        y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);

        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        cur = build_lora_mm(model.layers[il].ssm_out, y);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
    cb(cur, "mamba_out", il);

    return cur;
}