1#include "models.h"
2
3llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) :
4 llm_graph_context(params),
5 model(model) {}
6
7ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer * layer,
8 ggml_tensor * cur,
9 ggml_tensor * x_prev,
10 llm_arch arch) const {
11 ggml_tensor * sx = ggml_sub(ctx: ctx0, a: x_prev, b: cur);
12 switch (arch) {
13 case LLM_ARCH_RWKV7:
14 {
15 ggml_tensor * xk = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: sx, b: layer->channel_mix_lerp_k), b: cur);
16
17 ggml_tensor * k = ggml_sqr(ctx: ctx0, a: ggml_relu(ctx: ctx0, a: build_lora_mm(w: layer->channel_mix_key, cur: xk)));
18
19 cur = build_lora_mm(w: layer->channel_mix_value, cur: k);
20 }
21 break;
22 default:
23 GGML_ABORT("fatal error");
24 }
25 return cur;
26}
27
28ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * inp,
29 ggml_tensor * cur,
30 ggml_tensor * x_prev,
31 ggml_tensor *& first_layer_value,
32 const llama_ubatch & ubatch,
33 int il) const {
34 const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
35
36 const auto n_tokens = ubatch.n_tokens;
37 const auto n_seqs = ubatch.n_seqs;
38 const auto n_embd = hparams.n_embd;
39 const auto head_size = hparams.wkv_head_size;
40 const auto head_count = n_embd / head_size;
41 const auto n_seq_tokens = ubatch.n_seq_tokens;
42
43 const auto kv_head = mctx_cur->get_head();
44
45 const auto & layer = model.layers[il];
46
47 bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;
48
49 ggml_tensor * sx = ggml_sub(ctx: ctx0, a: x_prev, b: cur);
50 ggml_tensor * dummy = ggml_new_tensor_4d(ctx: ctx0, type: GGML_TYPE_F32, ne0: n_embd, ne1: n_seq_tokens, ne2: n_seqs, ne3: has_gating ? 6 : 5);
51 sx = ggml_repeat(ctx: ctx0, a: sx, b: dummy);
52
53 ggml_tensor * xxx = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: sx, b: layer.time_mix_lerp_fused), b: cur);
54
55 ggml_tensor * xr = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: 0);
56 ggml_tensor * xw = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * sizeof(float));
57 ggml_tensor * xk = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * 2 * sizeof(float));
58 ggml_tensor * xv = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * 3 * sizeof(float));
59 ggml_tensor * xa = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * 4 * sizeof(float));
60 ggml_tensor * xg =
61 has_gating ? ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * 5 * sizeof(float)) :
62 nullptr;
63
64 ggml_tensor * r = build_lora_mm(w: layer.time_mix_receptance, cur: xr);
65 ggml_tensor * w = ggml_add(
66 ctx: ctx0, a: ggml_mul_mat(ctx: ctx0, a: layer.time_mix_w2, b: ggml_tanh(ctx: ctx0, a: ggml_mul_mat(ctx: ctx0, a: layer.time_mix_w1, b: xw))),
67 b: layer.time_mix_w0);
68 w = ggml_exp(ctx: ctx0, a: ggml_scale(ctx: ctx0, a: ggml_sigmoid(ctx: ctx0, a: w), s: -0.606531));
69
70 ggml_tensor * k = build_lora_mm(w: layer.time_mix_key, cur: xk);
71 ggml_tensor * v = build_lora_mm(w: layer.time_mix_value, cur: xv);
72 if (first_layer_value == nullptr) {
73 first_layer_value = v;
74 } else {
75 // Add the first layer value as a residual connection.
76 v = ggml_add(ctx: ctx0, a: v,
77 b: ggml_mul(ctx: ctx0, a: ggml_sub(ctx: ctx0, a: first_layer_value, b: v),
78 b: ggml_sigmoid(ctx: ctx0, a: ggml_add(ctx: ctx0,
79 a: ggml_mul_mat(ctx: ctx0, a: layer.time_mix_v2,
80 b: ggml_mul_mat(ctx: ctx0, a: layer.time_mix_v1, b: xv)),
81 b: layer.time_mix_v0))));
82 }
83 ggml_tensor * g = nullptr;
84 if (layer.time_mix_g1 && layer.time_mix_g2) {
85 g = ggml_mul_mat(ctx: ctx0, a: layer.time_mix_g2, b: ggml_sigmoid(ctx: ctx0, a: ggml_mul_mat(ctx: ctx0, a: layer.time_mix_g1, b: xg)));
86 }
87 ggml_tensor * a = ggml_sigmoid(
88 ctx: ctx0, a: ggml_add(ctx: ctx0, a: ggml_mul_mat(ctx: ctx0, a: layer.time_mix_a2, b: ggml_mul_mat(ctx: ctx0, a: layer.time_mix_a1, b: xa)),
89 b: layer.time_mix_a0));
90
91 ggml_tensor * kk = ggml_reshape_3d(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: k, b: layer.time_mix_k_k), ne0: head_size, ne1: head_count, ne2: n_tokens);
92 kk = ggml_l2_norm(ctx: ctx0, a: kk, eps: 1e-12);
93
94 ggml_tensor * ka = ggml_mul(ctx: ctx0, a: k, b: layer.time_mix_k_a);
95 k = ggml_add(ctx: ctx0, a: k, b: ggml_sub(ctx: ctx0, a: ggml_mul(ctx: ctx0, a, b: ka), b: ka));
96
97 r = ggml_reshape_3d(ctx: ctx0, a: r, ne0: head_size, ne1: head_count, ne2: n_tokens);
98 w = ggml_reshape_3d(ctx: ctx0, a: w, ne0: head_size, ne1: head_count, ne2: n_tokens);
99 k = ggml_reshape_3d(ctx: ctx0, a: k, ne0: head_size, ne1: head_count, ne2: n_tokens);
100 v = ggml_reshape_3d(ctx: ctx0, a: v, ne0: head_size, ne1: head_count, ne2: n_tokens);
101 a = ggml_reshape_3d(ctx: ctx0, a, ne0: head_size, ne1: head_count, ne2: n_tokens);
102
103 ggml_tensor * wkv_state = build_rs(inp, s: mctx_cur->get_s_l(il), state_size: hparams.n_embd_s(), n_seqs);
104
105 ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx: ctx0, r, w, k, v, a: ggml_neg(ctx: ctx0, a: kk), b: ggml_mul(ctx: ctx0, a: kk, b: a), state: wkv_state);
106 cur = ggml_view_1d(ctx: ctx0, a: wkv_output, ne0: n_embd * n_tokens, offset: 0);
107 wkv_state = ggml_view_1d(ctx: ctx0, a: wkv_output, ne0: n_embd * head_size * n_seqs, offset: n_embd * n_tokens * sizeof(float));
108
109 ggml_build_forward_expand(
110 cgraph: gf, tensor: ggml_cpy(ctx: ctx0, a: wkv_state,
111 b: ggml_view_1d(ctx: ctx0, a: mctx_cur->get_s_l(il), ne0: hparams.n_embd_s() * n_seqs,
112 offset: hparams.n_embd_s() * kv_head * ggml_element_size(tensor: mctx_cur->get_s_l(il)))));
113
114 if (layer.time_mix_ln && layer.time_mix_ln_b) {
115 // group norm with head_count groups
116 cur = ggml_reshape_3d(ctx: ctx0, a: cur, ne0: n_embd / head_count, ne1: head_count, ne2: n_tokens);
117 cur = ggml_norm(ctx: ctx0, a: cur, eps: 64e-5f);
118
119 // Convert back to regular vectors.
120 cur = ggml_reshape_2d(ctx: ctx0, a: cur, ne0: n_embd, ne1: n_tokens);
121 cur = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: cur, b: layer.time_mix_ln), b: layer.time_mix_ln_b);
122 } else {
123 cur = ggml_reshape_2d(ctx: ctx0, a: cur, ne0: n_embd, ne1: n_tokens);
124 }
125 ggml_tensor * rk = ggml_sum_rows(
126 ctx: ctx0, a: ggml_mul(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: k, b: r), b: ggml_reshape_2d(ctx: ctx0, a: layer.time_mix_r_k, ne0: head_size, ne1: head_count)));
127 cur = ggml_add(ctx: ctx0, a: cur, b: ggml_reshape_2d(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: v, b: rk), ne0: n_embd, ne1: n_tokens));
128
129 if (has_gating) {
130 cur = ggml_mul(ctx: ctx0, a: cur, b: g);
131 }
132 cur = build_lora_mm(w: layer.time_mix_output, cur);
133
134 return ggml_reshape_3d(ctx: ctx0, a: cur, ne0: n_embd, ne1: n_seq_tokens, ne2: n_seqs);
135}
136