#include "models.h"

llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}

ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp,
                                                         ggml_tensor * cur,
                                                         const llama_model & model,
                                                         const llama_ubatch & ubatch,
                                                         int il) {
    const auto * mctx_cur = inp->mctx;

    const auto kv_head = mctx_cur->get_head();

    const auto & layer = model.layers[il];

    const int64_t d_conv   = hparams.ssm_d_conv;
    const int64_t d_inner  = hparams.ssm_d_inner;
    const int64_t d_state  = hparams.ssm_d_state;
    const int64_t dt_rank  = hparams.ssm_dt_rank;
    const int64_t n_head   = d_inner;
    const int64_t head_dim = 1;
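    // Mamba-1 is expressed here as n_head == d_inner heads of size 1, so that
    // the same ggml_ssm_scan op can serve both this and the Mamba-2 path below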
    const int64_t n_seqs   = ubatch.n_seqs;
    // Some Mamba variants (e.g. FalconMamba) apply RMS norm on the dt, B and C layers
    const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
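    // conv now holds the rolling conv state per channel and sequence; e.g. with
    // d_conv = 4, d_inner = 3072 and n_seqs = 2 (illustrative values) this is a
    // {3, 3072, 2} tensor of the last 3 input columns seen by each channel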

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur);
    // split the above into two
    // => {d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
    ggml_tensor * z = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2],
                                   d_inner * ggml_element_size(xz));
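    // x and z are strided views into xz -- nothing is copied here; z stays
    // non-contiguous until the ggml_cont() in the gating step below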

    // conv
    {
        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);

        // copy last (d_conv - 1) columns back into the state cache
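        // conv_x has ne0 == d_conv - 1 + n_seq_tokens, so offsetting the view by
        // n_seq_tokens columns selects exactly the last d_conv - 1 of them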
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1],
                                               conv_x->nb[2], n_seq_tokens * (conv_x->nb[0]));

        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, last_conv,
                         ggml_view_1d(ctx0, conv_states_all, (d_conv - 1) * (d_inner) * (n_seqs),
                                      kv_head * (d_conv - 1) * (d_inner) * ggml_element_size(conv_states_all))));

        // 1D convolution
        // The equivalent is to make a self-overlapping view of conv_x
        // over d_conv columns at each stride in the 3rd dimension,
        // then element-wise multiply that with the conv1d weight,
        // then sum the elements of each row,
        // (the last two steps are a dot product over rows (also doable with mul_mat))
        // then permute away the ne[0] dimension,
        // and then you're left with the resulting x tensor.
        // For simultaneous sequences, all sequences need to have the same length.
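        // In effect, per channel i and output column t:
        //   x[i, t] = sum_{j=0}^{d_conv-1} conv_x[t + j, i] * conv1d_w[j, i]
        // (a depthwise conv, causal because conv_x is left-padded with the cached state)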
        x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d);

        // bias
        x = ggml_add(ctx0, x, layer.ssm_conv1d_b);

        x = ggml_silu(ctx0, x);
    }

    // ssm
    {
        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
        ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x);
        // split
        ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
        ggml_tensor * B =
            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0],
                         x_db->nb[1], x_db->nb[2], ggml_element_size(x_db) * dt_rank);
        ggml_tensor * C =
            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0],
                         x_db->nb[1], x_db->nb[2], ggml_element_size(x_db) * (dt_rank + d_state));

        // Some Mamba variants (e.g. FalconMamba, Jamba) apply RMS norm in B, C & Dt layers
        if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) {
            dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il);
            B  = build_norm(B, layer.ssm_b_norm, NULL, LLM_NORM_RMS, il);
            C  = build_norm(C, layer.ssm_c_norm, NULL, LLM_NORM_RMS, il);
        }

        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
        dt = build_lora_mm(layer.ssm_dt, dt);
        dt = ggml_add(ctx0, dt, layer.ssm_dt_b);

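        // keep the pre-reshape x around: it feeds the D skip connection
        // (the ggml_mul with layer.ssm_d) after the scan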
        cur = x;
        x   = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);

        ggml_tensor * A = layer.ssm_a;

        // use the states and the indices provided by build_recurrent_state
        // (this is necessary in order to properly use the states before they are overwritten,
        // while avoiding unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

            // Custom operator to optimize the parallel associative scan
            // as described in Annex D of the Mamba paper.
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);

        // store last states
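        // ggml_ssm_scan packs its output as [ y | final states ]; since y has as
        // many bytes as x, the states start at byte offset x->nb[3]*x->ne[3]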
        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, x->nb[3] * x->ne[3]),
                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));

        ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);

        // TODO: skip computing output earlier for unused tokens

        y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
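        // gate with z: ggml_swiglu_split(z, y) computes silu(z) * y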
        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        cur = build_lora_mm(layer.ssm_out, y);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);

    return cur;
}

ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp,
                                                          ggml_tensor * cur,
                                                          const llama_model & model,
                                                          const llama_ubatch & ubatch,
                                                          int il) const {
    const auto * mctx_cur = inp->mctx;

    const auto kv_head = mctx_cur->get_head();

    const int64_t d_conv   = hparams.ssm_d_conv;
    const int64_t d_inner  = hparams.ssm_d_inner;
    const int64_t d_state  = hparams.ssm_d_state;
    const int64_t n_head   = hparams.ssm_dt_rank;
    const int64_t head_dim = d_inner / n_head;
    const int64_t n_group  = hparams.ssm_n_group;
    const int64_t n_seqs   = ubatch.n_seqs;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
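    // e.g. for a Mamba-2 2.7B-style config (illustrative: d_inner = 5120,
    // n_group = 1, d_state = 128, n_head = 80):
    //   d_in_proj = 2*5120 + 2*1*128 + 80 = 10576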

    // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
    ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);

    // split the above into three
    ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0],
                                   zxBCdt->nb[1], zxBCdt->nb[2], 0);
    ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2 * n_group * d_state, n_seq_tokens, n_seqs,
                                     zxBCdt->nb[1], zxBCdt->nb[2], d_inner * ggml_element_size(zxBCdt));
    ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2],
                                    (2 * d_inner + 2 * n_group * d_state) * ggml_element_size(zxBCdt));
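    // i.e. the layout of zxBCdt along ne0 is:
    //   [ z (d_inner) | x (d_inner) | B (n_group*d_state) | C (n_group*d_state) | dt (n_head) ]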

    // conv
    {
        // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);

        // copy last (d_conv - 1) columns back into the state cache
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs,
                                               conv_x->nb[1], conv_x->nb[2], n_seq_tokens * (conv_x->nb[0]));

        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
                                               ggml_view_1d(ctx0, conv_states_all,
                                                            (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
                                                            kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
                                                                ggml_element_size(conv_states_all))));

        // 1D convolution
        // The equivalent is to make a self-overlapping view of conv_x
        // over d_conv columns at each stride in the 3rd dimension,
        // then element-wise multiply that with the conv1d weight,
        // then sum the elements of each row,
        // (the last two steps are a dot product over rows (also doable with mul_mat))
        // then permute away the ne[0] dimension,
        // and then you're left with the resulting xBC tensor.
        // For simultaneous sequences, all sequences need to have the same length.
        xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);

        // bias
        xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);

        xBC = ggml_silu(ctx0, xBC);
    }

    // ssm
    {
        // These correspond to V K Q in SSM/attention duality
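        // (x ~ V, B ~ K, C ~ Q in the state-space duality of the Mamba-2 paper)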
        ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], 0);
        ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC));
        ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC));

        // {n_head, n_seq_tokens, n_seqs}
        dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
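        // (the softplus that maps dt to a positive step size is applied inside
        // ggml_ssm_scan, together with the dt-dependent discretization of A)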

        ggml_tensor * A = model.layers[il].ssm_a;

        // use the states and the indices provided by build_recurrent_state
        // (this is necessary in order to properly use the states before they are overwritten,
        // while avoiding unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

            // TODO: use semistructured matrices to implement state-space duality
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);

        // store last states
        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]),
                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));

        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1],
                                       n_head * x->nb[1], n_seq_tokens * n_head * x->nb[1], 0);

        // TODO: skip computing output earlier for unused tokens

        y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
        cb(y, "mamba2_y_add_d", il);
        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

        // grouped RMS norm
        if (model.layers[il].ssm_norm) {
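            // reshaping to {d_inner/n_group, n_group, ...} makes the RMS norm,
            // which normalizes along ne0, act on each group independently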
            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
        }

        y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);

        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        cur = build_lora_mm(model.layers[il].ssm_out, y);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
    cb(cur, "mamba_out", il);

    return cur;
}