#include "models.h"

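// Graph builder for PLaMo-2, a hybrid architecture that interleaves Mamba
// (selective state-space) layers with standard softmax-attention layers.
// Which mixer a given layer uses is decided per layer via hparams.is_recurrent(il).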
llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context_mamba(params) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    // {n_embd, n_tokens}
    inpL = build_inp_embd(model.tok_embd);
    cb(inpL, "embedding_output", -1);

    ggml_tensor * inp_pos = build_inp_pos();

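    // hybrid memory input: provides both the recurrent-state view used by the
    // Mamba layers and the KV-cache view used by the attention layers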
    auto * inp_hybrid = build_inp_mem_hybrid();

    ggml_tensor * inp_out_ids = build_inp_out_ids();

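    // each layer: pre-norm -> mixer (Mamba or attention) -> post-norm -> residual,
    // then pre-FFN norm -> SwiGLU FFN -> post-FFN norm -> residual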
    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * residual = inpL;

        // ggml_graph_add_node(gf, model.layers[il].attn_norm);
        // cb(model.layers[il].attn_norm, "attn_norm", il);

        // pre_mixer_norm
        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);

        // check if this layer is Mamba or Attention
        bool is_mamba_layer = hparams.is_recurrent(il);

        if (is_mamba_layer) {
            // PLaMo-2 Mamba layer
            cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il);
        } else {
            // PLaMo-2 Attention layer
            cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, cur, model, il);
        }

        // post_mixer_norm
        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        // residual connection
        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "attn_residual", il);
        residual = cur;

        // pre-ffn norm
        cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_pre_norm", il);

        // feed-forward network
        cur = build_ffn(cur,
                model.layers[il].ffn_up, NULL, NULL,
                NULL, NULL, NULL,
                model.layers[il].ffn_down, NULL, NULL,
                NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
        cb(cur, "ffn_out", il);

        // post ffn norm
        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_post_norm", il);

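        // on the last layer, keep only the rows selected for output; the residual
        // is pruned the same way so the final add below stays aligned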
        if (il == n_layer - 1 && inp_out_ids) {
            cur      = ggml_get_rows(ctx0, cur, inp_out_ids);
            residual = ggml_get_rows(ctx0, residual, inp_out_ids);
        }

        // residual connection
        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "ffn_residual", il);

        inpL = cur;
    }

    cur = inpL;

    // final norm
    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);

    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(model.output, cur);
    cb(cur, "result_output", -1);

    // Explicitly mark as output tensor to ensure proper backend assignment
    ggml_set_output(cur);

    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp,
                                                        ggml_tensor *             inp_pos,
                                                        ggml_tensor *             cur,
                                                        const llama_model &       model,
                                                        int                       il) {
    // self-attention
    {
        // PLaMo-2 uses combined QKV tensor
        ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
        cb(qkv, "wqkv", il);

        // split QKV tensor into Q, K, V
        const int64_t n_embd_head_q = hparams.n_embd_head_k;
        const int64_t n_embd_head_k = hparams.n_embd_head_k;
        const int64_t n_embd_head_v = hparams.n_embd_head_v;
        int32_t       n_head        = hparams.n_head(il);
        int32_t       n_head_kv     = hparams.n_head_kv(il);

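        // the fused tensor is laid out as [Q | K | V] along dim 0:
        // n_head * n_embd_head_q query rows, then n_head_kv * n_embd_head_k key rows,
        // then n_head_kv * n_embd_head_v value rows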
        const int64_t q_offset = 0;
        const int64_t k_offset = n_embd_head_q * n_head;
        const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;

        ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float),
                                          qkv->nb[1], q_offset * ggml_element_size(qkv));
        ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float),
                                          qkv->nb[1], k_offset * ggml_element_size(qkv));
        ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float),
                                          qkv->nb[1], v_offset * ggml_element_size(qkv));

        cb(Qcur, "Qcur", il);
        cb(Kcur, "Kcur", il);
        cb(Vcur, "Vcur", il);

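        // QK norm: Q and K are RMS-normalized per head before RoPE is applied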
        Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
        cb(Qcur, "Qcur_normed", il);

        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow);

        Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
        cb(Kcur, "Kcur_normed", il);

        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow);

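        // scaled dot-product attention; scores are scaled by 1/sqrt(n_embd_head_v)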
        cur = build_attn(inp,
                model.layers[il].wo, NULL,
                Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f / sqrtf(float(n_embd_head_v)), il);
    }

    cb(cur, "attn_out", il);

    return cur;
}

ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * inp,
                                                         ggml_tensor *        cur,
                                                         const llama_model &  model,
                                                         const llama_ubatch & ubatch,
                                                         int                  il) {
    const auto * mctx_cur = inp->mctx;

    const auto kv_head = mctx_cur->get_head();

    const int64_t d_conv   = hparams.ssm_d_conv;
    const int64_t d_inner  = hparams.ssm_d_inner;
    const int64_t d_state  = hparams.ssm_d_state;
    const int64_t n_heads  = hparams.ssm_dt_rank;
    const int64_t head_dim = d_inner / n_heads;
    const int64_t n_group  = hparams.ssm_n_group;
    const int64_t n_seqs   = ubatch.n_seqs;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

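    // gather the cached conv states for the sequences in this ubatch; the cache
    // keeps the last (d_conv - 1) inputs per channel so the causal conv1d can
    // continue across ubatch boundaries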
    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
    cb(zx, "mamba_in_proj", il);
    // {8192, 5, 1, 1} -> {8192, 1, 5, 1}
    zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
    zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
    cb(zx, "mamba_in_proj_out", il);

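    // per head, the 2*head_dim slice of zx holds the gate z first (offset 0)
    // and the SSM input x second (offset head_dim)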
189
190 // split into z and x
191 // => {head_dim * n_heads, n_seq_tokens, n_seqs}
192 ggml_tensor * x = ggml_view_4d(ctx: ctx0, a: zx, ne0: head_dim, ne1: n_heads, ne2: n_seq_tokens, ne3: n_seqs, nb1: zx->nb[1], nb2: zx->nb[2], nb3: zx->nb[3],
193 offset: head_dim * ggml_element_size(tensor: zx));
194 x = ggml_cont_3d(ctx: ctx0, a: x, ne0: head_dim * n_heads, ne1: n_seq_tokens, ne2: n_seqs);
195 // x = ggml_permute(ctx0, x, 0, 2, 1, 3);
196 cb(cur: x, name: "mamba_x_split", il);
197
198 ggml_tensor * z =
199 ggml_view_4d(ctx: ctx0, a: zx, ne0: head_dim, ne1: n_heads, ne2: n_seq_tokens, ne3: n_seqs, nb1: zx->nb[1], nb2: zx->nb[2], nb3: zx->nb[3], offset: 0);
200 cb(cur: z, name: "mamba_z_split", il);
201
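    // causal conv1d: prepend the cached (d_conv - 1) columns to the new tokens,
    // convolve, and write the trailing columns back as the next conv state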
    // conv1d
    {
        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
        cb(conv_x, "mamba_conv1d_input", il);

        // copy last (d_conv - 1) columns back into the state cache
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
                                               n_seq_tokens * (conv_x->nb[0]));

        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
                                               ggml_view_1d(ctx0, conv_states_all,
                                                            (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
                                                            kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
                                                                ggml_element_size(conv_states_all))));
        cb(conv_states_all, "mamba_conv1d_state", il);

        // 1D convolution
        x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
        cb(x, "mamba_conv1d", il);

        x = ggml_silu(ctx0, x);
        cb(x, "mamba_conv1d_silu", il);
    }

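    // selective SSM: project x to (dt, B, C), RMS-normalize each (PLaMo-2
    // specific), then run the scan and gate the result with z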
    // SSM
    {
        // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
        ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x);
        cb(x_bcdt, "mamba_bcdt_proj", il);

        // split into dt, B, C
        const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
        ggml_tensor * B  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
        ggml_tensor * C  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2],
                                        ggml_element_size(x_bcdt) * d_state);
        ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2],
                                        ggml_element_size(x_bcdt) * (2 * d_state));
        cb(B, "mamba_B_raw", il);
        cb(C, "mamba_C_raw", il);
        cb(dt, "mamba_dt_raw", il);

        // Apply RMS norm to dt, B, C (PLaMo-2 specific)
        B  = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il);
        C  = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il);
        dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il);
        cb(B, "mamba_B_normed", il);
        cb(C, "mamba_C_normed", il);
        cb(dt, "mamba_dt_normed", il);

        // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
        dt = build_lora_mm(model.layers[il].ssm_dt, dt);
        dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
        cb(dt, "mamba_dt_proj", il);

        ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads);
        cb(A, "mamba_A", il);

        x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x),
                         head_dim * n_heads * ggml_element_size(x),
                         head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
        B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0);
        C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0);

        // use the states and the indices provided by build_recurrent_state
        // (this is necessary in order to properly use the states before they are
        // overwritten, while avoiding unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size());

            // Custom operator to optimize the parallel associative scan,
            // as described in Appendix D of the Mamba paper.
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
        cb(y_ssm, "mamba_ssm_scan", il);

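        // y_ssm packs the scan output y first, followed by the final per-sequence
        // states; the view offset below skips past y to reach the states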
        // store last states
        ggml_build_forward_expand(
            gf, ggml_cpy(
                ctx0,
                ggml_view_1d(ctx0, y_ssm, n_heads * head_dim * d_state * n_seqs,
                             n_heads * head_dim * n_seq_tokens * n_seqs * ggml_element_size(y_ssm)),
                ggml_view_1d(ctx0, ssm_states_all, n_heads * head_dim * d_state * n_seqs,
                             kv_head * n_seqs * n_heads * head_dim * d_state * ggml_element_size(ssm_states_all))));
        cb(ssm_states_all, "mamba_ssm_states", il);

        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs,
                                       head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x),
                                       head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
        cb(y, "mamba_y_view", il);

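        // D is the per-head skip connection of the SSM; the gate z is then
        // applied with SwiGLU-style gating (silu(z) * y)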
        // Add D parameter and apply gating with z
        // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
        ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads);
        y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
        cb(y, "mamba_y_add_d", il);

        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
        cb(y, "mamba_y_swiglu_z", il);

        // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0);
        cur = build_lora_mm(model.layers[il].ssm_out, y);
        cb(cur, "mamba_out_proj", il);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
    cb(cur, "mamba_out", il);

    return cur;
}