#include "models.h"

llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model) {}

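// RWKV7 channel mix (feed-forward) block:
//   xk  = lerp(cur, x_prev, mu_k)      (token shift)
//   k   = relu(W_key * xk)^2
//   out = W_value * k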
ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer * layer,
                                                            ggml_tensor *       cur,
                                                            ggml_tensor *       x_prev,
                                                            llm_arch            arch) const {
    ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
    switch (arch) {
        case LLM_ARCH_RWKV7:
            {
                ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);

                ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk)));

                cur = build_lora_mm(layer->channel_mix_value, k);
            }
            break;
        default:
            GGML_ABORT("fatal error");
    }
    return cur;
}

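// RWKV7 time mix (attention-like) block: projects the token-shift interpolations
// into r/w/k/v/a(/g), runs the fused wkv7 recurrence over the per-sequence state,
// then applies group norm, the r*k "bonus" term, and the optional output gate.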
ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * inp,
                                                         ggml_tensor *        cur,
                                                         ggml_tensor *        x_prev,
                                                         ggml_tensor *&       first_layer_value,
                                                         const llama_ubatch & ubatch,
                                                         int                  il) const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    const auto n_tokens     = ubatch.n_tokens;
    const auto n_seqs       = ubatch.n_seqs;
    const auto n_embd       = hparams.n_embd;
    const auto head_size    = hparams.wkv_head_size;
    const auto head_count   = n_embd / head_size;
    const auto n_seq_tokens = ubatch.n_seq_tokens;

    const auto kv_head = mctx_cur->get_head();

    const auto & layer = model.layers[il];

    bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;

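    // token shift: broadcast the delta to 5 (6 with gating) copies so a single
    // fused lerp yields the interpolated inputs for all projections at once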
    ggml_tensor * sx    = ggml_sub(ctx0, x_prev, cur);
    ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
    sx                  = ggml_repeat(ctx0, sx, dummy);

    ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);

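    // split the fused result into per-projection views (r, w, k, v, a, and optionally g)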
    ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
    ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
    ggml_tensor * xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
    ggml_tensor * xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
    ggml_tensor * xa = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
    ggml_tensor * xg = has_gating ?
        ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) :
        nullptr;

    ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);
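
    // per-channel decay: w = exp(-0.606531 * sigmoid(w0 + w2*tanh(w1*xw))),
    // where 0.606531 = exp(-0.5), keeping the decay in (exp(-0.606531), 1)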
    ggml_tensor * w = ggml_add(
        ctx0, ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))),
        layer.time_mix_w0);
    w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531));

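    // key/value projections; the first layer's value is cached so later layers
    // can blend it back in (value residual)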
    ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
    ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);
    if (first_layer_value == nullptr) {
        first_layer_value = v;
    } else {
        // Add the first layer value as a residual connection.
        v = ggml_add(ctx0, v,
                     ggml_mul(ctx0, ggml_sub(ctx0, first_layer_value, v),
                              ggml_sigmoid(ctx0, ggml_add(ctx0,
                                                          ggml_mul_mat(ctx0, layer.time_mix_v2,
                                                                       ggml_mul_mat(ctx0, layer.time_mix_v1, xv)),
                                                          layer.time_mix_v0))));
    }
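
    // optional output gate: g = g2 * sigmoid(g1 * xg)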
    ggml_tensor * g = nullptr;
    if (layer.time_mix_g1 && layer.time_mix_g2) {
        g = ggml_mul_mat(ctx0, layer.time_mix_g2, ggml_sigmoid(ctx0, ggml_mul_mat(ctx0, layer.time_mix_g1, xg)));
    }
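
    // a: per-channel "in-context learning rate" in (0, 1): a = sigmoid(a0 + a2*(a1*xa))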
    ggml_tensor * a = ggml_sigmoid(
        ctx0, ggml_add(ctx0, ggml_mul_mat(ctx0, layer.time_mix_a2, ggml_mul_mat(ctx0, layer.time_mix_a1, xa)),
                       layer.time_mix_a0));

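    // kk: per-head l2-normalized key; the wkv7 op below receives -kk and kk*a
    // as its rank-1 state-update pair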
    ggml_tensor * kk = ggml_reshape_3d(ctx0, ggml_mul(ctx0, k, layer.time_mix_k_k), head_size, head_count, n_tokens);
    kk = ggml_l2_norm(ctx0, kk, 1e-12);

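    // scale the key by the learning rate: k = k * (1 + (a - 1) * k_a)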
    ggml_tensor * ka = ggml_mul(ctx0, k, layer.time_mix_k_a);
    k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka));

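    // reshape to (head_size, head_count, n_tokens) for the per-head wkv7 kernel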
    r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
    w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
    k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
    v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
    a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);

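    // load the recurrent wkv state for the sequences in this ubatch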
    ggml_tensor * wkv_state = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs);

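    // fused RWKV7 recurrence; the output packs the attention result (n_embd * n_tokens)
    // followed by the updated state (n_embd * head_size * n_seqs) in one tensor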
    ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
    cur       = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
    wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));

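    // write the updated state back into the recurrent memory slot at kv_head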
    ggml_build_forward_expand(
        gf, ggml_cpy(ctx0, wkv_state,
                     ggml_view_1d(ctx0, mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs,
                                  hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il)))));

    if (layer.time_mix_ln && layer.time_mix_ln_b) {
        // group norm with head_count groups
        cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens);
        cur = ggml_norm(ctx0, cur, 64e-5f);

        // Convert back to regular vectors.
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
    } else {
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
    }
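
    // "bonus" term: cur = cur + v * sum_per_head(r * k * r_k)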
    ggml_tensor * rk = ggml_sum_rows(
        ctx0, ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count)));
    cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens));

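    // apply the output gate (if any), project, and restore the (n_embd, n_seq_tokens, n_seqs) shape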
    if (has_gating) {
        cur = ggml_mul(ctx0, cur, g);
    }
    cur = build_lora_mm(layer.time_mix_output, cur);

    return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
}