1#include "models.h"
2
3llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) :
4 llm_graph_context(params),
5 model(model) {}
6
7ggml_tensor * llm_build_rwkv6_base::build_rwkv6_channel_mix(const llama_layer * layer,
8 ggml_tensor * cur,
9 ggml_tensor * x_prev,
10 llm_arch arch) const {
11 ggml_tensor * sx = ggml_sub(ctx: ctx0, a: x_prev, b: cur);
12 switch (arch) {
13 case LLM_ARCH_RWKV6:
14 {
15 ggml_tensor * xk = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: sx, b: layer->channel_mix_lerp_k), b: cur);
16 ggml_tensor * xr = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: sx, b: layer->channel_mix_lerp_r), b: cur);
17
18 ggml_tensor * r = ggml_sigmoid(ctx: ctx0, a: build_lora_mm(w: layer->channel_mix_receptance, cur: xr));
19 ggml_tensor * k = ggml_sqr(ctx: ctx0, a: ggml_relu(ctx: ctx0, a: build_lora_mm(w: layer->channel_mix_key, cur: xk)));
20 cur = ggml_mul(ctx: ctx0, a: r, b: build_lora_mm(w: layer->channel_mix_value, cur: k));
21 }
22 break;
23 default:
24 GGML_ABORT("fatal error");
25 }
26 return cur;
27}
28
29ggml_tensor * llm_build_rwkv6_base::build_rwkv6_time_mix(llm_graph_input_rs * inp,
30 ggml_tensor * cur,
31 ggml_tensor * x_prev,
32 const llama_ubatch & ubatch,
33 int il) const {
34 const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
35
36 const auto n_tokens = ubatch.n_tokens;
37 const auto n_seqs = ubatch.n_seqs;
38 const auto n_seq_tokens = ubatch.n_seq_tokens;
39 const auto n_embd = hparams.n_embd;
40 const auto head_size = hparams.wkv_head_size;
41 const auto n_head = n_embd / head_size;
42 const auto n_head_kv = hparams.n_head_kv(il);
43
44 const auto kv_head = mctx_cur->get_head();
45
46 const auto & layer = model.layers[il];
47
48 bool is_qrwkv = layer.time_mix_first == nullptr;
49
50 ggml_tensor * sx = ggml_sub(ctx: ctx0, a: x_prev, b: cur);
51
52 sx = ggml_reshape_2d(ctx: ctx0, a: sx, ne0: n_embd, ne1: n_tokens);
53 cur = ggml_reshape_2d(ctx: ctx0, a: cur, ne0: n_embd, ne1: n_tokens);
54
55 ggml_tensor * xxx = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: sx, b: layer.time_mix_lerp_x), b: cur);
56
57 xxx = ggml_reshape_4d(ctx: ctx0, a: ggml_tanh(ctx: ctx0, a: ggml_mul_mat(ctx: ctx0, a: layer.time_mix_w1, b: xxx)),
58 ne0: layer.time_mix_w1->ne[1] / 5, ne1: 1, ne2: 5, ne3: n_tokens);
59
60 xxx = ggml_cont(ctx: ctx0, a: ggml_permute(ctx: ctx0, a: xxx, axis0: 0, axis1: 1, axis2: 3, axis3: 2));
61
62 xxx = ggml_mul_mat(
63 ctx: ctx0, a: ggml_reshape_4d(ctx: ctx0, a: layer.time_mix_w2, ne0: layer.time_mix_w2->ne[0], ne1: layer.time_mix_w2->ne[1], ne2: 1, ne3: 5), b: xxx);
64
65 ggml_tensor *xw, *xk, *xv, *xr, *xg;
66 if (layer.time_mix_lerp_fused) {
67 // fusing these weights makes some performance improvement
68 sx = ggml_reshape_3d(ctx: ctx0, a: sx, ne0: n_embd, ne1: 1, ne2: n_tokens);
69 cur = ggml_reshape_3d(ctx: ctx0, a: cur, ne0: n_embd, ne1: 1, ne2: n_tokens);
70 xxx = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: ggml_add(ctx: ctx0, a: xxx, b: layer.time_mix_lerp_fused), b: sx), b: cur);
71 xw = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: 0);
72 xk = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * sizeof(float));
73 xv = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * 2 * sizeof(float));
74 xr = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * 3 * sizeof(float));
75 xg = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * 4 * sizeof(float));
76 } else {
77 // for backward compatibility
78 xw = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: 0);
79 xk = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * sizeof(float));
80 xv = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * 2 * sizeof(float));
81 xr = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * 3 * sizeof(float));
82 xg = ggml_view_2d(ctx: ctx0, a: xxx, ne0: n_embd, ne1: n_tokens, nb1: xxx->nb[1], offset: n_embd * n_tokens * 4 * sizeof(float));
83
84 xw = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: ggml_add(ctx: ctx0, a: xw, b: layer.time_mix_lerp_w), b: sx), b: cur);
85 xk = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: ggml_add(ctx: ctx0, a: xk, b: layer.time_mix_lerp_k), b: sx), b: cur);
86 xv = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: ggml_add(ctx: ctx0, a: xv, b: layer.time_mix_lerp_v), b: sx), b: cur);
87 xr = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: ggml_add(ctx: ctx0, a: xr, b: layer.time_mix_lerp_r), b: sx), b: cur);
88 xg = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: ggml_add(ctx: ctx0, a: xg, b: layer.time_mix_lerp_g), b: sx), b: cur);
89 }
90 ggml_tensor * r = build_lora_mm(w: layer.time_mix_receptance, cur: xr);
91 ggml_tensor * k = build_lora_mm(w: layer.time_mix_key, cur: xk);
92 ggml_tensor * v = build_lora_mm(w: layer.time_mix_value, cur: xv);
93 if (layer.time_mix_receptance_b) {
94 r = ggml_add(ctx: ctx0, a: r, b: layer.time_mix_receptance_b);
95 }
96 if (layer.time_mix_key_b) {
97 k = ggml_add(ctx: ctx0, a: k, b: layer.time_mix_key_b);
98 }
99 if (layer.time_mix_value_b) {
100 v = ggml_add(ctx: ctx0, a: v, b: layer.time_mix_value_b);
101 }
102 ggml_tensor * g = build_lora_mm(w: layer.time_mix_gate, cur: xg);
103 if (is_qrwkv) {
104 g = ggml_sigmoid(ctx: ctx0, a: g);
105 } else {
106 g = ggml_silu(ctx: ctx0, a: g);
107 }
108 if (n_head_kv != 0 && n_head_kv != n_head) {
109 GGML_ASSERT(n_head % n_head_kv == 0);
110 k = ggml_reshape_4d(ctx: ctx0, a: k, ne0: head_size, ne1: 1, ne2: n_head_kv, ne3: n_tokens);
111 v = ggml_reshape_4d(ctx: ctx0, a: v, ne0: head_size, ne1: 1, ne2: n_head_kv, ne3: n_tokens);
112 ggml_tensor * tmp = ggml_new_tensor_4d(ctx: ctx0, type: GGML_TYPE_F32, ne0: head_size, ne1: n_head / n_head_kv, ne2: n_head_kv, ne3: n_tokens);
113 k = ggml_repeat(ctx: ctx0, a: k, b: tmp);
114 v = ggml_repeat(ctx: ctx0, a: v, b: tmp);
115 }
116 k = ggml_reshape_3d(ctx: ctx0, a: k, ne0: head_size, ne1: n_head, ne2: n_tokens);
117 v = ggml_reshape_3d(ctx: ctx0, a: v, ne0: head_size, ne1: n_head, ne2: n_tokens);
118 r = ggml_reshape_3d(ctx: ctx0, a: r, ne0: head_size, ne1: n_head, ne2: n_tokens);
119
120 ggml_tensor * w =
121 ggml_mul_mat(ctx: ctx0, a: layer.time_mix_decay_w2, b: ggml_tanh(ctx: ctx0, a: ggml_mul_mat(ctx: ctx0, a: layer.time_mix_decay_w1, b: xw)));
122
123 w = ggml_add(ctx: ctx0, a: w, b: layer.time_mix_decay);
124 w = ggml_exp(ctx: ctx0, a: ggml_neg(ctx: ctx0, a: ggml_exp(ctx: ctx0, a: w)));
125 w = ggml_reshape_3d(ctx: ctx0, a: w, ne0: head_size, ne1: n_head, ne2: n_tokens);
126
127 if (is_qrwkv) {
128 // k = k * (1 - w)
129 k = ggml_sub(ctx: ctx0, a: k, b: ggml_mul(ctx: ctx0, a: k, b: w));
130 }
131 ggml_tensor * wkv_state = build_rs(inp, s: mctx_cur->get_s_l(il), state_size: hparams.n_embd_s(), n_seqs);
132
133 ggml_tensor * wkv_output;
134 if (is_qrwkv) {
135 wkv_output = ggml_gated_linear_attn(ctx: ctx0, k, v, q: r, g: w, state: wkv_state, scale: pow(x: head_size, y: -0.5f));
136 } else {
137 wkv_output = ggml_rwkv_wkv6(ctx: ctx0, k, v, r, tf: layer.time_mix_first, td: w, state: wkv_state);
138 }
139 cur = ggml_view_1d(ctx: ctx0, a: wkv_output, ne0: n_embd * n_tokens, offset: 0);
140 wkv_state = ggml_view_1d(ctx: ctx0, a: wkv_output, ne0: n_embd * head_size * n_seqs, offset: n_embd * n_tokens * sizeof(float));
141
142 ggml_build_forward_expand(
143 cgraph: gf, tensor: ggml_cpy(ctx: ctx0, a: wkv_state,
144 b: ggml_view_1d(ctx: ctx0, a: mctx_cur->get_s_l(il), ne0: hparams.n_embd_s() * n_seqs,
145 offset: hparams.n_embd_s() * kv_head * ggml_element_size(tensor: mctx_cur->get_s_l(il)))));
146
147 if (!is_qrwkv) {
148 // group norm with head_count groups
149 cur = ggml_reshape_3d(ctx: ctx0, a: cur, ne0: n_embd / n_head, ne1: n_head, ne2: n_tokens);
150 cur = ggml_norm(ctx: ctx0, a: cur, eps: 64e-5f);
151
152 // Convert back to regular vectors.
153 cur = ggml_reshape_2d(ctx: ctx0, a: cur, ne0: n_embd, ne1: n_tokens);
154 cur = ggml_add(ctx: ctx0, a: ggml_mul(ctx: ctx0, a: cur, b: layer.time_mix_ln), b: layer.time_mix_ln_b);
155 } else {
156 cur = ggml_reshape_2d(ctx: ctx0, a: cur, ne0: n_embd, ne1: n_tokens);
157 }
158 cur = ggml_mul(ctx: ctx0, a: cur, b: g);
159 cur = build_lora_mm(w: layer.time_mix_output, cur);
160
161 return ggml_reshape_3d(ctx: ctx0, a: cur, ne0: n_embd, ne1: n_seq_tokens, ne2: n_seqs);
162}
163