#include "models.h"

llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context_mamba(params) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    // {n_embd, n_tokens}
    inpL = build_inp_embd(model.tok_embd);
    cb(inpL, "embedding_output", -1);

    ggml_tensor * inp_pos = build_inp_pos();

    auto * inp_hybrid = build_inp_mem_hybrid();

    ggml_tensor * inp_out_ids = build_inp_out_ids();
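    // note: PLaMo-2 is a hybrid model, so two kinds of graph inputs are built:
    // inp_hybrid bundles the recurrent (Mamba) state view and the attention
    // KV-cache view, and inp_out_ids holds the row indices of the tokens whose
    // outputs are actually requested for this ubatch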

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * residual = inpL;

        // ggml_graph_add_node(gf, model.layers[il].attn_norm);
        // cb(model.layers[il].attn_norm, "attn_norm", il);

        // pre_mixer_norm
        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);

        // check if this layer is Mamba or Attention
        bool is_mamba_layer = hparams.is_recurrent(il);

        if (is_mamba_layer) {
            // PLaMo-2 Mamba layer
            cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il);
        } else {
            // PLaMo-2 Attention layer
            cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, cur, model, il);
        }

        // post_mixer_norm
        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        // residual connection
        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "attn_residual", il);
        residual = cur;

        // pre-ffn norm
        cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_pre_norm", il);

        // feed-forward network
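        // note: no separate gate tensor is passed; with LLM_FFN_SWIGLU, build_ffn
        // is expected to split the ffn_up result in half and gate one half with
        // SiLU of the other (assumption: ffn_up carries a combined up+gate layout)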
        cur = build_ffn(cur,
                model.layers[il].ffn_up, NULL, NULL,
                NULL, NULL, NULL,
                model.layers[il].ffn_down, NULL, NULL,
                NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
        cb(cur, "ffn_out", il);

        // post ffn norm
        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_post_norm", il);

        if (il == n_layer - 1 && inp_out_ids) {
            cur      = ggml_get_rows(ctx0, cur, inp_out_ids);
            residual = ggml_get_rows(ctx0, residual, inp_out_ids);
        }
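        // note: pruning to the requested output rows happens before the final
        // residual add, so both operands are gathered with the same indices and
        // the ggml_add below still sees matching shapes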

        // residual connection
        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "ffn_residual", il);

        inpL = cur;
    }

    cur = inpL;

    // final norm
    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);

    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(model.output, cur);
    cb(cur, "result_output", -1);

    // Explicitly mark as output tensor to ensure proper backend assignment
    ggml_set_output(cur);

    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp,
                                                        ggml_tensor * inp_pos,
                                                        ggml_tensor * cur,
                                                        const llama_model & model,
                                                        int il) {
    // self-attention
    {
        // PLaMo-2 uses combined QKV tensor
        ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
        cb(qkv, "wqkv", il);

        // split QKV tensor into Q, K, V
        const int64_t n_embd_head_q = hparams.n_embd_head_k;
        const int64_t n_embd_head_k = hparams.n_embd_head_k;
        const int64_t n_embd_head_v = hparams.n_embd_head_v;
        int32_t n_head    = hparams.n_head(il);
        int32_t n_head_kv = hparams.n_head_kv(il);

        const int64_t q_offset = 0;
        const int64_t k_offset = n_embd_head_q * n_head;
        const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;
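        // each row of qkv is packed as [ Q | K | V ], so the three views below
        // alias qkv's memory and differ only in their element offset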

        ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float),
                                          qkv->nb[1], q_offset * ggml_element_size(qkv));
        ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float),
                                          qkv->nb[1], k_offset * ggml_element_size(qkv));
        ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float),
                                          qkv->nb[1], v_offset * ggml_element_size(qkv));

        cb(Qcur, "Qcur", il);
        cb(Kcur, "Kcur", il);
        cb(Vcur, "Vcur", il);

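        // PLaMo-2 normalizes Q and K per head (QK-norm) before applying RoPE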
        Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
        cb(Qcur, "Qcur_normed", il);

        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow);

        Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
        cb(Kcur, "Kcur_normed", il);

        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow);

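        // standard 1/sqrt(d_head) scaled-dot-product scaling; the scale is
        // computed from the V head size (n_embd_head_v)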
        cur = build_attn(inp,
                model.layers[il].wo, NULL,
                Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f / sqrtf(float(n_embd_head_v)), il);
    }

    cb(cur, "attn_out", il);

    return cur;
}

ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * inp,
                                                         ggml_tensor * cur,
                                                         const llama_model & model,
                                                         const llama_ubatch & ubatch,
                                                         int il) {
    const auto * mctx_cur = inp->mctx;

    const auto kv_head = mctx_cur->get_head();

    const int64_t d_conv   = hparams.ssm_d_conv;
    const int64_t d_inner  = hparams.ssm_d_inner;
    const int64_t d_state  = hparams.ssm_d_state;
    const int64_t n_heads  = hparams.ssm_dt_rank;
    const int64_t head_dim = d_inner / n_heads;
    const int64_t n_group  = hparams.ssm_n_group;
    const int64_t n_seqs   = ubatch.n_seqs;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
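    // note: equal_seqs() guarantees each sequence in the ubatch contributes the
    // same number of tokens, which is what makes the {n_embd, n_tokens} ->
    // {n_embd, n_seq_tokens, n_seqs} reshape below valid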

    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);
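    // the recurrent cache keeps the last (d_conv - 1) inputs per conv channel;
    // one state row per sequence is gathered here and reshaped to
    // {d_conv - 1, d_inner + 2*n_group*d_state, n_seqs}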

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
    cb(zx, "mamba_in_proj", il);
    // {2*d_inner, n_seq_tokens, n_seqs, 1} -> {2*d_inner, n_seqs, n_seq_tokens, 1}, e.g. {8192, 5, 1, 1} -> {8192, 1, 5, 1}
    zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
    zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
    cb(zx, "mamba_in_proj_out", il);

    // split into z and x
    // => {head_dim * n_heads, n_seq_tokens, n_seqs}
    ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3],
                                   head_dim * ggml_element_size(zx));
    x = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
    // x = ggml_permute(ctx0, x, 0, 2, 1, 3);
    cb(x, "mamba_x_split", il);

    ggml_tensor * z =
        ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0);
    cb(z, "mamba_z_split", il);
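    // note: after the permute + cont_4d above, each head's 2*head_dim slice is
    // laid out as [z | x], so z is viewed at offset 0 and x at offset head_dim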

    // conv1d
    {
        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
        cb(conv_x, "mamba_conv1d_input", il);

        // copy last (d_conv - 1) columns back into the state cache
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
                                               n_seq_tokens * (conv_x->nb[0]));
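        // writing these columns into conv_states_all at slot kv_head keeps the
        // causal conv seamless across ubatches: the next call will concat them
        // in front of the new tokens (see the ggml_concat above)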
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
                                               ggml_view_1d(ctx0, conv_states_all,
                                                            (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
                                                            kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
                                                                ggml_element_size(conv_states_all))));
        cb(conv_states_all, "mamba_conv1d_state", il);

        // 1D convolution
        x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
        cb(x, "mamba_conv1d", il);

        x = ggml_silu(ctx0, x);
        cb(x, "mamba_conv1d_silu", il);
    }

    // SSM
    {
        // bcdt_proj: {d_inner, 2*d_state + dt_dim} @ {d_inner, n_seq_tokens, n_seqs} => {2*d_state + dt_dim, n_seq_tokens, n_seqs}
        ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x);
        cb(x_bcdt, "mamba_bcdt_proj", il);

        // split into B, C, dt (packed in this order along dim 0)
        const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
        ggml_tensor * B  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
        ggml_tensor * C  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2],
                                        ggml_element_size(x_bcdt) * d_state);
        ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2],
                                        ggml_element_size(x_bcdt) * (2 * d_state));
        cb(B, "mamba_B_raw", il);
        cb(C, "mamba_C_raw", il);
        cb(dt, "mamba_dt_raw", il);
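        // note: dt_dim = max(64, n_embd/16) reproduces the dt_rank heuristic of
        // the PLaMo-2 reference implementation (assumption); it is recomputed
        // here rather than read from hparams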

        // Apply RMS norm to dt, B, C (PLaMo-2 specific)
        B  = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il);
        C  = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il);
        dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il);
        cb(B, "mamba_B_normed", il);
        cb(C, "mamba_C_normed", il);
        cb(dt, "mamba_dt_normed", il);

        // dt_proj: {dt_dim, n_heads} @ {dt_dim, n_seq_tokens, n_seqs} => {n_heads, n_seq_tokens, n_seqs}
        dt = build_lora_mm(model.layers[il].ssm_dt, dt);
        dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
        cb(dt, "mamba_dt_proj", il);

        ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads);
        cb(A, "mamba_A", il);

        x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x),
                         head_dim * n_heads * ggml_element_size(x),
                         head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
        B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0);
        C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0);

        // use the states and the indices provided by build_rs
        // (this is necessary in order to properly use the states before they are
        // overwritten, while avoiding unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size());

            // custom operator to optimize the parallel associative scan,
            // as described in Annex D of the Mamba paper
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };
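
        // build_rs gathers only the state rows of the sequences present in this
        // ubatch (via `ids`) before running the scan, so the old states are read
        // before being overwritten and no extra copies are made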
        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
        cb(y_ssm, "mamba_ssm_scan", il);

        // store last states
        ggml_build_forward_expand(
            gf, ggml_cpy(
                    ctx0,
                    ggml_view_1d(ctx0, y_ssm, n_heads * head_dim * d_state * n_seqs,
                                 n_heads * head_dim * n_seq_tokens * n_seqs * ggml_element_size(y_ssm)),
                    ggml_view_1d(ctx0, ssm_states_all, n_heads * head_dim * d_state * n_seqs,
                                 kv_head * n_seqs * n_heads * head_dim * d_state * ggml_element_size(ssm_states_all))));
        cb(ssm_states_all, "mamba_ssm_states", il);
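        // note: ggml_ssm_scan packs both of its results into one tensor: the
        // per-token outputs come first, followed by the final states, hence the
        // source offset of n_heads*head_dim*n_seq_tokens*n_seqs elements above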

        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs,
                                       head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x),
                                       head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
        cb(y, "mamba_y_view", il);

        // add the D skip connection and apply gating with z
        // {head_dim, n_heads, n_seq_tokens, n_seqs} * {1, n_heads} => {head_dim, n_heads, n_seq_tokens, n_seqs}
        ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads);
        y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
        cb(y, "mamba_y_add_d", il);
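
        // z-gating: SiLU(z) multiplies y elementwise, as in standard Mamba
        // (assumption on ggml_swiglu_split argument order: the activation is
        // applied to the first argument)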
        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
        cb(y, "mamba_y_swiglu_z", il);

        // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0);
        cur = build_lora_mm(model.layers[il].ssm_out, y);
        cb(cur, "mamba_out_proj", il);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
    cb(cur, "mamba_out", il);

    return cur;
}