#include "models.h"

llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model),
    n_embd_head(model.hparams.n_embd_head_k),
    n_embd_altup(model.hparams.n_embd_altup),
    n_altup(model.hparams.n_altup),
    i_altup_act(model.hparams.i_altup_act) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
    if (ubatch.token) {
        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
        cb(inpL, "inp_scaled", -1);
    }

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    // TODO: is causal == true correct? might need some changes
    auto * inp_attn = build_attn_inp_kv_iswa();

    // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
    ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());

    // inpL now has only 1 altup, project it to the rest of the altups
    // these "added" altups will be concatenated along the last dim of inpL
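    // note: the added streams are renormalized below so that each keeps the per-token
    // L2 norm of the original input, i.e. altup_added *= |inpL| / |altup_added|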
    {
        ggml_tensor * target_magnitude = calc_magnitude(inpL);
        ggml_tensor * inp_repeated     = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
        ggml_tensor * altup_added =
            ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
        ggml_tensor * new_magnitude = calc_magnitude(altup_added);
        altup_added = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude);
        inpL        = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
        cb(inpL, "inp_stacked", -1);
    }

    // inpL now has shape:          [n_embd, n_tokens, n_altup]
    // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]

    for (int il = 0; il < n_layer; ++il) {
        // this block closely resembles Gemma3p5DecoderLayer in the python code
        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

        ggml_tensor * cur         = inpL;                   // [n_embd, n_tokens, n_altup]
        ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]

        // predicted value will go through self-attention and laurel
        ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
        cur = active_prediction;
        cb(cur, "active_prediction", il);

        // norm
        cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // laurel
        ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]

        // self-attention
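        // only layers that have their own KV cache (hparams.has_kv(il)) compute K/V here;
        // the remaining layers reuse the KV of an earlier layer and only compute Q
        // (see the else branch below)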
        if (hparams.has_kv(il)) {
            // compute Q and K and RoPE them
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

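            // Q and K go through learned RMS norms; V gets a weightless RMS norm
            // (no learned scale is applied to it)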
            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
            Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);

            cb(Qcur, "Qcur_normed", il);
            cb(Kcur, "Kcur_normed", il);
            cb(Vcur, "Vcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);

            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);

            cb(Qcur, "Qcur_pos", il);
            cb(Kcur, "Kcur_pos", il);

            cur = build_attn(inp_attn, model.layers[il].wo,
                             NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
                             hparams.f_attention_scale, il);
        } else {
            // reuse KV cache of earlier layers
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);
            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            cb(Qcur, "Qcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);
            cb(Qcur, "Qcur_pos", il);

            cur = build_attn(inp_attn,
                             model.layers[il].wo, NULL,
                             Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
        }

        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
        cb(cur, "attn_gated", il);

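        // combine the attention output with the laurel branch; the 1/sqrt(2) factor
        // keeps the sum of the two residual branches at roughly the scale of one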
        ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out),
                                               1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
        cb(attn_laurel, "attn_laurel", il);

        cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        // feed-forward network
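        // gated FFN: ffn_down(up_proj * gelu(gate_proj)); the first n_layer_sparsity
        // layers additionally sparsify the gate activations (see gaussian_topk)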
        {
            ggml_tensor * up_proj   = build_lora_mm(model.layers[il].ffn_up,   cur);
            ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);

            if (il < n_layer_sparsity) {
                // apply activation sparsity
                gate_proj = gaussian_topk(gate_proj);
            }
            gate_proj = ggml_gelu(ctx0, gate_proj);

            cur = ggml_mul(ctx0, up_proj, gate_proj);
            cur = build_lora_mm(model.layers[il].ffn_down, cur);
            cb(cur, "ffn_out", il);
        }

        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1);
        cb(cur, "ffn_post_norm", il);

        ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
        cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);

        ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]

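        // per-layer input injection: gate the active corrected stream, multiply it by
        // this layer's per-layer input, then project it back up to n_embd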
        ggml_tensor * first_prediction; // [n_embd, n_tokens]
        {
            first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
            first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
            first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_gated", il);
            ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il);          // [n_embd_altup, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer);      // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_scaled", il);

            first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
            first_prediction =
                build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il);
            cb(first_prediction, "first_prediction_out", il);
        }

        // equivalent to python code: corrected_predictions[1:] += first_prediction
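        // ggml graphs are not mutated in place, so the slice update is expressed as:
        // keep slice 0 unchanged, add first_prediction to slices 1..n_altup-1, and
        // concat the two parts back together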
        {
            ggml_tensor * slice_first = view_2d_slice(corrected, 0);
            ggml_tensor * slice_rest  = ggml_view_3d(
                ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
                ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
            ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
            corrected = ggml_concat(ctx0, slice_first, tmp, 2);               // [n_embd, n_tokens, n_altup]
        }

        cur = corrected; // [n_embd, n_tokens, n_altup]
        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL; // [n_embd, n_tokens, n_altup]

    // cur now holds multiple altup streams; merge them back into a single one
    {
        ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
        // do a view to skip the first slice (active altup)
        ggml_tensor * alt_slice =
            ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
                         ggml_row_size(cur->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(cur));
        ggml_tensor * altup_unembd =
            ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
        ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
        altup_unembd = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude);
        cb(altup_unembd, "altup_unembd", -1);

        // equivalent to torch.mean(hidden_states, dim=0)
        cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
        for (int i = 0; i < n_altup - 1; ++i) {
            cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
        }
        cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
        cb(cur, "unembd_merged", -1);
    }

    // cur now has shape: [n_embd, n_tokens]

    // TODO: move this to right after the last KV layer
    {
        // skip computing output for unused tokens
        ggml_tensor * inp_out_ids = build_inp_out_ids();
        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
    }

    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    cur = build_lora_mm(model.output, cur);

    {
        // final logit soft-capping
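        // logits = softcap * tanh(logits / softcap), which bounds the logits to
        // (-f_final_logit_softcapping, +f_final_logit_softcapping)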
        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
        cur = ggml_tanh(ctx0, cur);
        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
    }

    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

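// per-token L2 norm over the first dim: sqrt(sum_i x_i^2), so the result has ne0 == 1
// and broadcasts against x in the callers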
ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
    return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
}

// get a 2D slice view from a 3D tensor; idx indexes the 3rd dim
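// assumes x is contiguous: the slice starts at byte offset idx * ne0 * ne1 * element_size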
ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
    GGML_ASSERT(idx < (int) x->ne[2]);
    return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
                        idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}

// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens]
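// the per-layer token embeddings are looked up from tok_embd_per_layer and scaled by
// sqrt(n_embd_altup), mirroring the sqrt(n_embd) scaling of the main embeddings above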
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
    auto inp = std::make_unique<llm_graph_input_embd>();
    ggml_tensor * inp_per_layer;
    if (ubatch.token) {
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        ggml_set_input(inp->tokens);
        res->t_tokens = inp->tokens;
        inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
        inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
        cb(inp_per_layer, "inp_per_layer_selected", -1);
    } else {
        GGML_ABORT("TODO: support embd input");
    }
    res->add_input(std::move(inp));
    return inp_per_layer;
}

// equivalent to project_per_layer_inputs() in python code
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
// output shape: [n_embd_altup, n_tokens, n_layer]
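// the projected input embeddings are scaled by 1/sqrt(n_embd) before being added to the
// per-layer token embeddings, and the sum is scaled by 1/sqrt(2), averaging the two branches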
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
    const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
    const float per_layer_input_scale      = 1.0f / sqrtf(2.0f);

    ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
    per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
    per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
    per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS,
                                -1); // [n_embd_altup, n_layer, n_tokens]
    cb(per_layer_proj, "per_layer_proj", -1);

    inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
    inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
    cb(inp_per_layer, "inp_per_layer", -1);

    // permute to shape: [n_embd_altup, n_tokens, n_layer]
    inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
    return inp_per_layer;
}

// input cur shape: [n_embd, n_tokens]
// output shape:    [n_embd, n_tokens]
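// LAUREL block: laurel_l and laurel_r appear to form a low-rank bottleneck; the normed
// result of the down/up projection is added back to the input as a residual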
ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) {
    ggml_tensor * tmp = cur;
    tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
    tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
    tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
    tmp = ggml_add(ctx0, tmp, cur);
    cb(tmp, "laurel_out", il);
    return tmp;
}

// input x shape: [n_embd, n_tokens]
// output shape:  [n_embd, n_tokens]
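// activation sparsity: compute the per-token mean and (sample) std over the hidden dim,
// set cutoff = mean + f_sparsity_std_mul * std, and keep only the part of each activation
// above the cutoff via relu(x - cutoff)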
ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) {
    ggml_tensor * mean = ggml_mean(ctx0, x);
    ggml_tensor * std  = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
                                                    1.0f / (float) (x->ne[0] - 1)));
    ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
    return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
}

//
// altup functions
//

// equivalent to compute_router_modalities() in python code
// input x shape: [n_embd, n_tokens]
// output shape:  [n_altup, n_tokens]
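// the router norms the input, scales it by 1/n_embd (router_input_scale), projects it to
// n_altup coefficients, and squashes them into [-1, 1] with tanh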
ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) {
    ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il);

    // router_input_scale
    router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd);

    ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
    return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
}

// input cur shape: [n_embd, n_tokens, n_altup]
// output shape:    [n_embd, n_tokens, n_altup]
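// per token, the router modalities are expanded into an n_altup x n_altup mixing matrix,
// so predictions = cur + all_coefs @ cur: each output stream is an input-dependent linear
// combination of all altup streams, plus a residual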
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
    ggml_tensor * activated  = view_2d_slice(cur, i_altup_act);                // [n_embd, n_tokens]
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
    cb(modalities, "modalities", il);

    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
    cb(all_coefs, "all_coefs", il);
    // the first dim now has n_altup^2 elements; reshape it to 2D, giving a 3D tensor overall
    all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);

    // permute to [n_altup, n_embd, n_tokens]
    ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
    ggml_tensor * predictions  = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens]

    // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
    predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
    predictions = ggml_add(ctx0, predictions, cur);
    cb(predictions, "predictions", il);

    return predictions;
}

// input predictions shape: [n_embd, n_tokens, n_altup]
// input activated shape:   [n_embd, n_tokens]
// output shape:            [n_embd, n_tokens, n_altup]
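// the correction broadcasts the "innovation" (the difference between the actual activated
// output and the prediction of the active stream) to all streams, weighted per stream and
// per token by router coefficients offset by +1.0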
ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
    cb(modalities, "modalities", il);

    ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
    ggml_tensor * innovation        = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
    cb(innovation, "innovation", il);

    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
    all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0
    cb(all_coefs, "all_coefs", il);
    all_coefs = ggml_transpose(ctx0, all_coefs);                     // [n_tokens, n_altup]
    all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]

    innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
    ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
    corrected = ggml_add(ctx0, corrected, predictions);              // [n_embd, n_tokens, n_altup]
    cb(corrected, "corrected", il);

    return corrected;
}