#include "models.h"

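// Graph builder for Gemma 3n with interleaved sliding-window attention (iSWA).
// The main architectural pieces handled here:
//   - altup: n_altup parallel hidden-state streams; only the stream at index i_altup_act goes
//     through attention/FFN, the remaining streams are predicted and then corrected
//   - laurel: a low-rank residual branch that runs alongside self-attention
//   - per-layer inputs: an extra per-layer token embedding gated into every layer's output
//   - activation sparsity in the first n_layer_sparsity FFN blocks
//   - KV cache sharing for the layers that do not have their own KV cache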
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model),
    n_embd_head(model.hparams.n_embd_head_k),
    n_embd_altup(model.hparams.n_embd_altup),
    n_altup(model.hparams.n_altup),
    i_altup_act(model.hparams.i_altup_act) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
    if (ubatch.token) {
        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
        cb(inpL, "inp_scaled", -1);
    }

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    // TODO: is causal == true correct? might need some changes
    auto * inp_attn = build_attn_inp_kv_iswa();

    // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
    ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());

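    // note: the streams added by altup_proj are rescaled so that their per-token L2 norm
    // matches the original embedding stream (target_magnitude / new_magnitude below)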
    // inpL now has only 1 altup, project it to the rest of the altups
    // these "added" altups will be concat to the last dim of inpL
    {
        ggml_tensor * target_magnitude = calc_magnitude(inpL);
        ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
        ggml_tensor * altup_added =
            ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
        ggml_tensor * new_magnitude = calc_magnitude(altup_added);
        altup_added = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude);
        inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
        cb(inpL, "inp_stacked", -1);
    }

    // inpL now has shape:          [n_embd, n_tokens, n_altup]
    // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]

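    // per-layer flow:
    //   1. altup_predict() produces a prediction for every altup stream
    //   2. the active stream goes through attn_norm, self-attention and the laurel branch
    //   3. the gated FFN runs on the attention + laurel output
    //   4. altup_correct() propagates that output back into all streams
    //   5. the per-layer input for this layer is gated and added to the remaining streams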
    for (int il = 0; il < n_layer; ++il) {
        // this block closely resembles Gemma3p5DecoderLayer in the python code
        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

        ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
        ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]

        // predicted value will go through self-attention and laurel
        ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
        cur = active_prediction;
        cb(cur, "active_prediction", il);

        // norm
        cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // laurel
        ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]

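        // not every layer has its own KV cache: when hparams.has_kv(il) is false the layer
        // reuses the K/V written by an earlier layer, so only Q is computed (see the else branch)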
        // self-attention
        if (hparams.has_kv(il)) {
            // compute Q and K and RoPE them
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
            Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);

            cb(Qcur, "Qcur_normed", il);
            cb(Kcur, "Kcur_normed", il);
            cb(Vcur, "Vcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l,
                                 freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);

            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l,
                                 freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);

            cb(Qcur, "Qcur_pos", il);
            cb(Kcur, "Kcur_pos", il);

            cur = build_attn(inp_attn, model.layers[il].wo,
                             NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
                             hparams.f_attention_scale, il);
        } else {
            // reuse KV cache of earlier layers
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);
            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            cb(Qcur, "Qcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l,
                                 freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
            cb(Qcur, "Qcur_pos", il);

            cur = build_attn(inp_attn,
                             model.layers[il].wo, NULL,
                             Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
        }

        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
        cb(cur, "attn_gated", il);

        ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out),
                                               1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
        cb(attn_laurel, "attn_laurel", il);

        cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

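        // the first n_layer_sparsity layers apply activation sparsity: gaussian_topk() keeps
        // only gate activations above a per-token cutoff of mean + f_sparsity_std_mul * std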
        // feed-forward network
        {
            ggml_tensor * up_proj   = build_lora_mm(model.layers[il].ffn_up,   cur);
            ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);

            if (il < n_layer_sparsity) {
                // apply activation sparsity
                gate_proj = gaussian_topk(gate_proj);
            }
            gate_proj = ggml_gelu(ctx0, gate_proj);

            cur = ggml_mul(ctx0, up_proj, gate_proj);
            cur = build_lora_mm(model.layers[il].ffn_down, cur);
            cb(cur, "ffn_out", il);
        }

        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1);
        cb(cur, "ffn_post_norm", il);

        ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
        cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);

        ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]

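        // gate this layer's per-layer input with the corrected active stream and project it
        // back to n_embd; the result is added to the remaining streams (indices 1..n_altup-1)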
        ggml_tensor * first_prediction; // [n_embd, n_tokens]
        {
            first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
            first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
            first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_gated", il);
            ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_scaled", il);

            first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
            first_prediction =
                build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il);
            cb(first_prediction, "first_prediction_out", il);
        }

        // equivalent to python code: corrected_predictions[1:] += first_prediction
        {
            ggml_tensor * slice_first = view_2d_slice(corrected, 0);
            ggml_tensor * slice_rest  = ggml_view_3d(
                ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
                ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
            ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
            corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
        }

        cur = corrected; // [n_embd, n_tokens, n_altup]
        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL; // [n_embd, n_tokens, n_altup]

    // cur now has multiple altup(s), we want to merge them back to 1 altup
    {
        ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
        // do a view to skip the first slice (active altup)
        ggml_tensor * alt_slice =
            ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
                         ggml_row_size(cur->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(cur));
        ggml_tensor * altup_unembd =
            ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
        ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
        altup_unembd = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude);
        cb(altup_unembd, "altup_unembd", -1);

        // equivalent to torch.mean(hidden_states, dim=0)
        cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
        for (int i = 0; i < n_altup - 1; ++i) {
            cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
        }
        cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
        cb(cur, "unembd_merged", -1);
    }

    // cur now has shape: [n_embd, n_tokens]

    // TODO: move this to right after the last KV layer
    {
        // skip computing output for unused tokens
        ggml_tensor * inp_out_ids = build_inp_out_ids();
        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
    }

    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    cur = build_lora_mm(model.output, cur);

    {
        // final logit soft-capping
        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
        cur = ggml_tanh(ctx0, cur);
        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
    }

    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

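// per-token L2 norm of x along its first (embedding) dimension; the result broadcasts against
// x in the ggml_mul()/ggml_div() calls used for magnitude matching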
ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
    return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
}

// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
    GGML_ASSERT(idx < (int) x->ne[2]);
    return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
                        idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}

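// every token carries an extra, smaller embedding per layer (tok_embd_per_layer); look these up
// and scale by sqrt(n_embd_altup), mirroring the sqrt(n_embd) scaling of the main embeddings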
// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens]
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
    auto inp = std::make_unique<llm_graph_input_embd>();
    ggml_tensor * inp_per_layer;
    if (ubatch.token) {
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        ggml_set_input(inp->tokens);
        res->t_tokens = inp->tokens;
        inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
        inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
        cb(inp_per_layer, "inp_per_layer_selected", -1);
    } else {
        GGML_ABORT("TODO: support embd input");
    }
    res->add_input(std::move(inp));
    return inp_per_layer;
}

267
268// equivalent to project_per_layer_inputs() in python code
269// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
270// output shape: [n_embd_altup, n_tokens, n_layer]
271ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
272 const float per_layer_projection_scale = 1.0f / sqrtf(x: (float) n_embd);
273 const float per_layer_input_scale = 1.0f / sqrtf(x: 2.0f);
274
275 ggml_tensor * per_layer_proj = ggml_mul_mat(ctx: ctx0, a: model.per_layer_model_proj, b: inputs_embeds);
276 per_layer_proj = ggml_scale(ctx: ctx0, a: per_layer_proj, s: per_layer_projection_scale);
277 per_layer_proj = ggml_reshape_3d(ctx: ctx0, a: per_layer_proj, ne0: n_embd_altup, ne1: n_layer, ne2: n_tokens);
278 per_layer_proj = build_norm(cur: per_layer_proj, mw: model.per_layer_proj_norm, NULL, type: LLM_NORM_RMS,
279 il: -1); // [n_embd_altup, n_layer, n_tokens]
280 cb(cur: per_layer_proj, name: "per_layer_proj", il: -1);
281
282 inp_per_layer = ggml_add(ctx: ctx0, a: inp_per_layer, b: per_layer_proj);
283 inp_per_layer = ggml_scale(ctx: ctx0, a: inp_per_layer, s: per_layer_input_scale);
284 cb(cur: inp_per_layer, name: "inp_per_layer", il: -1);
285
286 // permute to shape: [n_embd_altup, n_tokens, n_layer]
287 inp_per_layer = ggml_cont(ctx: ctx0, a: ggml_permute(ctx: ctx0, a: inp_per_layer, axis0: 0, axis1: 2, axis2: 1, axis3: 3));
288 return inp_per_layer;
289}
290
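// laurel is a low-rank residual branch: project down through laurel_l, back up through laurel_r,
// RMS-norm the result and add the input back on top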
// input cur shape: [n_embd, n_tokens]
// output shape:    [n_embd, n_tokens]
ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) {
    ggml_tensor * tmp = cur;
    tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
    tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
    tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
    tmp = ggml_add(ctx0, tmp, cur);
    cb(tmp, "laurel_out", il);
    return tmp;
}

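// soft "top-k" under a Gaussian assumption: compute the per-token mean and (Bessel-corrected) std
// of the activations, then zero everything below mean + f_sparsity_std_mul * std via ReLU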
// input x shape: [n_ff, n_tokens]
// output shape:  [n_ff, n_tokens]
ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) {
    ggml_tensor * mean = ggml_mean(ctx0, x);
    ggml_tensor * std  = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
                                                    1.0f / (float) (x->ne[0] - 1)));
    ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
    return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
}

//
// altup functions
//

// equivalent to compute_router_modalities() in python code
// input x shape: [n_embd, n_tokens]
// output shape:  [n_altup, n_tokens]
ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) {
    ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il);

    // router_input_scale
    router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd);

    ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
    return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
}

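// the router turns the active stream into an n_altup x n_altup coefficient matrix per token;
// each predicted stream is a learned mix of all input streams, added residually to cur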
// input cur shape: [n_embd, n_tokens, n_altup]
// output shape:    [n_embd, n_tokens, n_altup]
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
    ggml_tensor * activated  = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
    cb(modalities, "modalities", il);

    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
    cb(all_coefs, "all_coefs", il);
    // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor)
    all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);

    // permute to [n_altup, n_embd, n_tokens]
    ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
    ggml_tensor * predictions  = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_embd, n_altup, n_tokens]

    // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
    predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
    predictions = ggml_add(ctx0, predictions, cur);
    cb(predictions, "predictions", il);

    return predictions;
}

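// the "innovation" is the difference between the actual activated output and its prediction;
// each stream is corrected by adding the innovation scaled by a per-stream router coefficient (+1)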
// input predictions shape: [n_embd, n_tokens, n_altup]
// input activated shape:   [n_embd, n_tokens]
// output shape:            [n_embd, n_tokens, n_altup]
ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
    cb(modalities, "modalities", il);

    ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
    ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
    cb(innovation, "innovation", il);

    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
    all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0
    cb(all_coefs, "all_coefs", il);
    all_coefs = ggml_transpose(ctx0, all_coefs); // [n_tokens, n_altup]
    all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]

    innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
    ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
    corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
    cb(corrected, "corrected", il);

    return corrected;
}