#include "llama-graph.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-cparams.h"

#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-recurrent.h"

#include <cassert>
#include <cmath>
#include <cstring>

void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
    if (ubatch->token) {
        const int64_t n_tokens = ubatch->n_tokens;

        ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
    }

    if (ubatch->embd) {
        const int64_t n_embd   = embd->ne[0];
        const int64_t n_tokens = ubatch->n_tokens;

        ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
    }
}

bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
    bool res = true;

    res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
    res &= (!embd   && !params.ubatch.embd)  || (embd   && embd->ne[0]   == params.ubatch.n_tokens);

    return res;
}

void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
    if (ubatch->pos && pos) {
        const int64_t n_tokens = ubatch->n_tokens;

        if (ubatch->token && n_pos_per_embd == 4) {
            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
            // the first 3 dims are the same, and the 4th dim is all 0
            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
            // copy the first dimension
            for (int i = 0; i < n_tokens; ++i) {
                pos_data[               i] = ubatch->pos[i];
                pos_data[    n_tokens + i] = ubatch->pos[i];
                pos_data[2 * n_tokens + i] = ubatch->pos[i];
                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
            }
            ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
        } else {
            ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
        }
    }
}

bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
    bool res = true;

    res &= pos->ne[0] == params.ubatch.n_tokens;

    return res;
}

void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
    if (ubatch->pos && attn_scale) {
        const int64_t n_tokens = ubatch->n_tokens;

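        // per-token attention temperature scale:
        //   scale = 1 + f_attn_temp_scale * log(floor((pos + 1) / n_attn_temp_floor_scale) + 1)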
        std::vector<float> attn_scale_data(n_tokens, 0.0f);
        for (int i = 0; i < n_tokens; ++i) {
            const float pos = ubatch->pos[i];
            attn_scale_data[i] = std::log(
                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
            ) * f_attn_temp_scale + 1.0;
        }

        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
    }
}

void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
    if (pos_bucket) {
        const int64_t n_tokens = ubatch->n_tokens;

        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
        GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

        int32_t * data = (int32_t *) pos_bucket->data;

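        // fill an [n_tokens, n_tokens] matrix of relative position buckets (bidirectional),
        // used for T5-style relative attention bias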
        for (int h = 0; h < 1; ++h) {
            for (int j = 0; j < n_tokens; ++j) {
                for (int i = 0; i < n_tokens; ++i) {
                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
                }
            }
        }
    }
}

void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
    if (pos_bucket) {
        mctx->set_input_pos_bucket(pos_bucket, ubatch);
    }
}

void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
    GGML_ASSERT(out_ids);

    const int64_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
    int32_t * data = (int32_t *) out_ids->data;

    if (n_outputs == n_tokens) {
        for (int i = 0; i < n_tokens; ++i) {
            data[i] = i;
        }

        return;
    }

    GGML_ASSERT(ubatch->output);

    int n_outputs = 0;

    for (int i = 0; i < n_tokens; ++i) {
        if (ubatch->output[i]) {
            data[n_outputs++] = i;
        }
    }
}

bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
    bool res = true;

    res &= n_outputs == params.n_outputs;

    return res;
}

void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
        const int64_t n_tokens     = ubatch->n_tokens;
        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
        const int64_t n_seqs_unq   = ubatch->n_seqs_unq;

        GGML_ASSERT(mean);
        GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));

        float * data = (float *) mean->data;
        memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));

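        // build the [n_tokens, n_seqs_unq] pooling matrix: for sequence s, the columns of its tokens
        // are set to 1/(number of tokens in s), so mean pooling reduces to a single matrix multiplication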
        std::vector<uint64_t> sums(n_seqs_unq, 0);
        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];

                sums[seq_idx] += ubatch->n_seq_tokens;
            }
        }

        std::vector<float> div(n_seqs_unq, 0.0f);
        for (int s = 0; s < n_seqs_unq; ++s) {
            const uint64_t sum = sums[s];
            if (sum > 0) {
                div[s] = 1.0f/float(sum);
            }
        }

        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];

                for (int j = 0; j < n_seq_tokens; ++j) {
                    data[seq_idx*n_tokens + i + j] = div[seq_idx];
                }
            }
        }
    }
}

void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
    const int64_t n_tokens   = ubatch->n_tokens;
    const int64_t n_seqs_unq = ubatch->n_seqs_unq;

    if (cparams.embeddings && (
                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
                cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
                )) {
        GGML_ASSERT(cls);
        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

        uint32_t * data = (uint32_t *) cls->data;
        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));

        std::vector<int> target_pos(n_seqs_unq, -1);
        std::vector<int> target_row(n_seqs_unq, -1);

        const bool last = (
            cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
            (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
        );

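        // for each sequence, record the row of the token to pool:
        // the highest position when pooling the last token, otherwise the lowest (first) position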
        for (int i = 0; i < n_tokens; ++i) {
            const llama_pos pos = ubatch->pos[i];

            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
                const int32_t      seq_idx = ubatch->seq_idx[seq_id];

                if (
                    (target_pos[seq_idx] == -1) ||
                    ( last && pos >= target_pos[seq_idx]) ||
                    (!last && pos <  target_pos[seq_idx])
                ) {
                    target_pos[seq_idx] = pos;
                    target_row[seq_idx] = i;
                }
            }
        }

        for (int s = 0; s < n_seqs_unq; ++s) {
            if (target_row[s] >= 0) {
                data[s] = target_row[s];
            }
        }
    }
}

void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    const int64_t n_rs = mctx->get_n_rs();

    if (s_copy) {
        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
        int32_t * data = (int32_t *) s_copy->data;

        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
        for (uint32_t i = 0; i < n_rs; ++i) {
            data[i] = mctx->s_copy(i);
        }
    }
}

void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    if (cross_embd && !cross->v_embd.empty()) {
        assert(cross_embd->type == GGML_TYPE_F32);

        ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
    }
}

static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
    const char * swa_type_str = "unknown";

    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
    };

    LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);

    LLAMA_LOG_DEBUG("    ");
    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
        LLAMA_LOG_DEBUG("%2d", j);
    }
    LLAMA_LOG_DEBUG("\n");

    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
        LLAMA_LOG_DEBUG(" %2d ", i);
        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
            float val = data[i * n_kv + j];
            if (val == -INFINITY) {
                LLAMA_LOG_DEBUG(" ∞");
            } else {
                LLAMA_LOG_DEBUG(" 0");
            }
        }
        LLAMA_LOG_DEBUG("\n");
    }
}

void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
    const int64_t n_kv     = ubatch->n_tokens;
    const int64_t n_tokens = ubatch->n_tokens;

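    // helper that unmasks (sets to 0.0f, or to the ALiBi bias) every key position i0 that query token i1
    // is allowed to attend to, subject to sequence identity, causality and the given sliding window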
    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
        for (int h = 0; h < 1; ++h) {
            for (int i1 = 0; i1 < n_tokens; ++i1) {
                const llama_seq_id s1 = ubatch->seq_id[i1][0];
                const llama_pos    p1 = ubatch->pos[i1];

                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;

                for (int i0 = 0; i0 < n_tokens; ++i0) {
                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
                    const llama_pos    p0 = ubatch->pos[i0];

                    // mask different sequences
                    if (s0 != s1) {
                        continue;
                    }

                    // mask future tokens
                    if (cparams.causal_attn && p0 > p1) {
                        continue;
                    }

                    // apply SWA if any
                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
                        continue;
                    }

                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                }
            }
        }
    };

    {
        GGML_ASSERT(self_kq_mask);
        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));

        float * data = (float *) self_kq_mask->data;

        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);

        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);

        if (debug) {
            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
        }
    }

    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
        GGML_ASSERT(self_kq_mask_swa);
        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));

        float * data = (float *) self_kq_mask_swa->data;

        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);

        fill_mask(data, hparams.n_swa, hparams.swa_type);

        if (debug) {
            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
        }
    }
}

void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
    mctx->set_input_k_idxs(self_k_idxs, ubatch);
    mctx->set_input_v_idxs(self_v_idxs, ubatch);

    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
    //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);

    return res;
}

void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);

    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);

    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);

    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);

    this->mctx = mctx;

    bool res = true;

    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
    //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
    //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);

    res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
    res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);

    return res;
}

void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
    GGML_ASSERT(cross_kq_mask);

    const int64_t n_enc    = cross_kq_mask->ne[0];
    const int64_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

    float * data = (float *) cross_kq_mask->data;

    for (int h = 0; h < 1; ++h) {
        for (int i = 0; i < n_tokens; ++i) {
            for (int j = 0; j < n_enc; ++j) {
                float f = -INFINITY;

                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
                    const llama_seq_id seq_id = ubatch->seq_id[i][s];

                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
                        f = 0.0f;
                    }
                }

                data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
            }
        }

        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
            for (int j = 0; j < n_enc; ++j) {
                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
            }
        }
    }
}

void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
    inp_attn->set_input(ubatch);
    inp_rs->set_input(ubatch);
}

//
// llm_graph_result
//

llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
    reset();

    const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
    debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
}

int64_t llm_graph_result::get_max_nodes() const {
    return max_nodes;
}

void llm_graph_result::reset() {
    t_tokens      = nullptr;
    t_logits      = nullptr;
    t_embd        = nullptr;
    t_embd_pooled = nullptr;

    params = {};

    inputs.clear();

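    // the meta buffer has to hold the tensor structs plus the graph bookkeeping for up to max_nodes nodes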
    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));

    ggml_init_params params = {
        /*.mem_size   =*/ buf_compute_meta.size(),
        /*.mem_buffer =*/ buf_compute_meta.data(),
        /*.no_alloc   =*/ true,
    };

    ctx_compute.reset(ggml_init(params));

    gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
}

void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
    for (auto & input : inputs) {
        input->set_input(ubatch);
    }
}

bool llm_graph_result::can_reuse(const llm_graph_params & params) {
    if (!this->params.allow_reuse(params)) {
        if (debug > 1) {
            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
        }

        return false;
    }

    if (debug > 1) {
        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
    }

    bool res = true;

    for (auto & input : inputs) {
        const bool cur = input->can_reuse(params);

        if (debug > 1) {
            LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", __func__, cur);
        }

        res = res && cur;
    }

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
    }

    return res;
}

llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
    inputs.emplace_back(std::move(input));
    return inputs.back().get();
}

void llm_graph_result::set_params(const llm_graph_params & params) {
    this->params = params;
}

//
// llm_graph_context
//

llm_graph_context::llm_graph_context(const llm_graph_params & params) :
    arch          (params.arch),
    hparams       (params.hparams),
    cparams       (params.cparams),
    ubatch        (params.ubatch),
    n_embd        (hparams.n_embd),
    n_layer       (hparams.n_layer),
    n_rot         (hparams.n_rot),
    n_ctx         (cparams.n_ctx),
    n_head        (hparams.n_head()),
    n_head_kv     (hparams.n_head_kv()),
    n_embd_head_k (hparams.n_embd_head_k),
    n_embd_k_gqa  (hparams.n_embd_k_gqa()),
    n_embd_head_v (hparams.n_embd_head_v),
    n_embd_v_gqa  (hparams.n_embd_v_gqa()),
    n_expert      (hparams.n_expert),
    n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
    freq_base     (cparams.rope_freq_base),
    freq_scale    (cparams.rope_freq_scale),
    ext_factor    (cparams.yarn_ext_factor),
    attn_factor   (cparams.yarn_attn_factor),
    beta_fast     (cparams.yarn_beta_fast),
    beta_slow     (cparams.yarn_beta_slow),
    norm_eps      (hparams.f_norm_eps),
    norm_rms_eps  (hparams.f_norm_rms_eps),
    n_tokens      (ubatch.n_tokens),
    n_outputs     (params.n_outputs),
    n_ctx_orig    (cparams.n_ctx_orig_yarn),
    pooling_type  (cparams.pooling_type),
    rope_type     (hparams.rope_type),
    sched         (params.sched),
    backend_cpu   (params.backend_cpu),
    cvec          (params.cvec),
    loras         (params.loras),
    mctx          (params.mctx),
    cross         (params.cross),
    cb_func       (params.cb),
    res           (params.res),
    ctx0          (res->get_ctx()),
    gf            (res->get_gf()) {
        res->set_params(params);
    }

void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
    if (cb_func) {
        cb_func(ubatch, cur, name, il);
    }
}

ggml_tensor * llm_graph_context::build_cvec(
        ggml_tensor * cur,
                int   il) const {
    return cvec->apply_to(ctx0, cur, il);
}

ggml_tensor * llm_graph_context::build_lora_mm(
        ggml_tensor * w,
        ggml_tensor * cur) const {
    ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);

    for (const auto & lora : *loras) {
        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
        if (lw == nullptr) {
            continue;
        }

        const float adapter_scale = lora.second;
        const float scale = lw->get_scale(lora.first->alpha, adapter_scale);

        ggml_tensor * ab_cur = ggml_mul_mat(
                ctx0, lw->b,
                ggml_mul_mat(ctx0, lw->a, cur)
                );

        ab_cur = ggml_scale(ctx0, ab_cur, scale);
        res    = ggml_add(ctx0, res, ab_cur);
    }

    return res;
}

ggml_tensor * llm_graph_context::build_lora_mm_id(
        ggml_tensor * w,   // ggml_tensor * as
        ggml_tensor * cur, // ggml_tensor * b
        ggml_tensor * ids) const {
    ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
    for (const auto & lora : *loras) {
        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
        if (lw == nullptr) {
            continue;
        }

        const float alpha = lora.first->alpha;
        const float rank  = (float) lw->b->ne[0];
        const float scale = alpha ? lora.second * alpha / rank : lora.second;

        ggml_tensor * ab_cur = ggml_mul_mat_id(
                ctx0, lw->b,
                ggml_mul_mat_id(ctx0, lw->a, cur, ids),
                ids
                );

        ab_cur = ggml_scale(ctx0, ab_cur, scale);
        res    = ggml_add(ctx0, res, ab_cur);
    }

    return res;
}

ggml_tensor * llm_graph_context::build_norm(
        ggml_tensor * cur,
        ggml_tensor * mw,
        ggml_tensor * mb,
      llm_norm_type   type,
                int   il) const {
    switch (type) {
        case LLM_NORM:     cur = ggml_norm    (ctx0, cur, hparams.f_norm_eps);     break;
        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
        case LLM_NORM_GROUP:
            {
                cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
                cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
                cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[2]);
            } break;
    }

    if (mw || mb) {
        cb(cur, "norm", il);
    }

    if (mw) {
        cur = ggml_mul(ctx0, cur, mw);
        if (mb) {
            cb(cur, "norm_w", il);
        }
    }

    if (mb) {
        cur = ggml_add(ctx0, cur, mb);
    }

    return cur;
}

ggml_tensor * llm_graph_context::build_ffn(
        ggml_tensor * cur,
        ggml_tensor * up,
        ggml_tensor * up_b,
        ggml_tensor * up_s,
        ggml_tensor * gate,
        ggml_tensor * gate_b,
        ggml_tensor * gate_s,
        ggml_tensor * down,
        ggml_tensor * down_b,
        ggml_tensor * down_s,
        ggml_tensor * act_scales,
    llm_ffn_op_type   type_op,
  llm_ffn_gate_type   type_gate,
                int   il) const {
    ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
    cb(tmp, "ffn_up", il);

    if (up_b) {
        tmp = ggml_add(ctx0, tmp, up_b);
        cb(tmp, "ffn_up_b", il);
    }

    if (up_s) {
        tmp = ggml_mul(ctx0, tmp, up_s);
        cb(tmp, "ffn_up_s", il);
    }

    if (gate) {
        switch (type_gate) {
            case LLM_FFN_SEQ:
                {
                    cur = build_lora_mm(gate, tmp);
                    cb(cur, "ffn_gate", il);
                } break;
            case LLM_FFN_PAR:
                {
                    cur = build_lora_mm(gate, cur);
                    cb(cur, "ffn_gate", il);
                } break;
        }

        if (gate_b) {
            cur = ggml_add(ctx0, cur, gate_b);
            cb(cur, "ffn_gate_b", il);
        }

        if (gate_s) {
            cur = ggml_mul(ctx0, cur, gate_s);
            cb(cur, "ffn_gate_s", il);
        }

    } else {
        cur = tmp;
    }

    switch (type_op) {
        case LLM_FFN_SILU:
            if (gate && type_gate == LLM_FFN_PAR) {
                cur = ggml_swiglu_split(ctx0, cur, tmp);
                cb(cur, "ffn_swiglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_silu(ctx0, cur);
                cb(cur, "ffn_silu", il);
            } break;
        case LLM_FFN_GELU:
            if (gate && type_gate == LLM_FFN_PAR) {
                cur = ggml_geglu_split(ctx0, cur, tmp);
                cb(cur, "ffn_geglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_gelu(ctx0, cur);
                cb(cur, "ffn_gelu", il);
                if (act_scales != NULL) {
                    cur = ggml_div(ctx0, cur, act_scales);
                    cb(cur, "ffn_act", il);
                }
            } break;
        case LLM_FFN_RELU:
            if (gate && type_gate == LLM_FFN_PAR) {
                cur = ggml_reglu_split(ctx0, cur, tmp);
                cb(cur, "ffn_reglu", il);
                type_gate = LLM_FFN_SEQ;
            } else {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);
            } break;
        case LLM_FFN_RELU_SQR:
            {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_relu", il);

                cur = ggml_sqr(ctx0, cur);
                cb(cur, "ffn_sqr(relu)", il);
            } break;
        case LLM_FFN_SWIGLU:
            {
                cur = ggml_swiglu(ctx0, cur);
                cb(cur, "ffn_swiglu", il);
            } break;
        case LLM_FFN_GEGLU:
            {
                cur = ggml_geglu(ctx0, cur);
                cb(cur, "ffn_geglu", il);
            } break;
        case LLM_FFN_REGLU:
            {
                cur = ggml_reglu(ctx0, cur);
                cb(cur, "ffn_reglu", il);
            } break;
        default:
            GGML_ABORT("fatal error");
    }

    // expand here so that we can fuse ffn gate
    ggml_build_forward_expand(gf, cur);

    if (gate && type_gate == LLM_FFN_PAR) {
        cur = ggml_mul(ctx0, cur, tmp);
        cb(cur, "ffn_gate_par", il);
    }

    if (down) {
        cur = build_lora_mm(down, cur);
        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
        }
    }

    if (down_b) {
        cb(cur, "ffn_down", il);
    }

    if (down_b) {
        cur = ggml_add(ctx0, cur, down_b);
    }

    if (down_s) {
        cur = ggml_mul(ctx0, cur, down_s);
        cb(cur, "ffn_down_s", il);
    }

    return cur;
}

ggml_tensor * llm_graph_context::build_moe_ffn(
        ggml_tensor * cur,
        ggml_tensor * gate_inp,
        ggml_tensor * up_exps,
        ggml_tensor * gate_exps,
        ggml_tensor * down_exps,
        ggml_tensor * exp_probs_b,
            int64_t   n_expert,
            int64_t   n_expert_used,
    llm_ffn_op_type   type_op,
               bool   norm_w,
               bool   scale_w,
              float   w_scale,
        llama_expert_gating_func_type   gating_op,
                int   il,
        ggml_tensor * probs_in) const {
    return build_moe_ffn(
            cur,
            gate_inp,  /* gate_inp_b  */ nullptr,
            up_exps,   /* up_exps_b   */ nullptr,
            gate_exps, /* gate_exps_b */ nullptr,
            down_exps, /* down_exps_b */ nullptr,
            exp_probs_b,
            n_expert,
            n_expert_used,
            type_op,
            norm_w,
            scale_w,
            w_scale,
            gating_op,
            il,
            probs_in
            );
}

ggml_tensor * llm_graph_context::build_moe_ffn(
        ggml_tensor * cur,
        ggml_tensor * gate_inp,
        ggml_tensor * gate_inp_b,
        ggml_tensor * up_exps,
        ggml_tensor * up_exps_b,
        ggml_tensor * gate_exps,
        ggml_tensor * gate_exps_b,
        ggml_tensor * down_exps,
        ggml_tensor * down_exps_b,
        ggml_tensor * exp_probs_b,
            int64_t   n_expert,
            int64_t   n_expert_used,
    llm_ffn_op_type   type_op,
               bool   norm_w,
               bool   scale_w,
              float   w_scale,
        llama_expert_gating_func_type   gating_op,
                int   il,
        ggml_tensor * probs_in) const {
    const int64_t n_embd   = cur->ne[0];
    const int64_t n_tokens = cur->ne[1];
    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN

    ggml_tensor * logits = nullptr;

    if (probs_in == nullptr) {
        logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
        cb(logits, "ffn_moe_logits", il);
    } else {
        logits = probs_in;
    }

    if (gate_inp_b) {
        logits = ggml_add(ctx0, logits, gate_inp_b);
        cb(logits, "ffn_moe_logits_biased", il);
    }

    ggml_tensor * probs = nullptr;
    switch (gating_op) {
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
            {
                probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
            } break;
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
            {
                probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
            } break;
        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
            {
                probs = logits; // [n_expert, n_tokens]
            } break;
        default:
            GGML_ABORT("fatal error");
    }
    cb(probs, "ffn_moe_probs", il);

    // add experts selection bias - introduced in DeepSeek V3
    // leave probs unbiased as it's later used to get expert weights
    ggml_tensor * selection_probs = probs;
    if (exp_probs_b != nullptr) {
        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
        cb(selection_probs, "ffn_moe_probs_biased", il);
    }

    // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
    // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
    if (arch == LLM_ARCH_LLAMA4) {
        selection_probs = logits;
    }

    if (arch == LLM_ARCH_GROVEMOE) {
        selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
        cb(selection_probs, "ffn_moe_probs_biased", il);
    }

    // select top n_group_used expert groups
    // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
    if (hparams.n_expert_groups > 1 && n_tokens > 0) {
        const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;

        // organize experts into n_expert_groups
        ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]

        ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
        group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]

        // get top n_group_used expert groups
        group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
        group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]

        ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
        cb(expert_groups, "ffn_moe_group_topk", il);

        // mask out the other groups
        selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
        selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
        selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
        cb(selection_probs, "ffn_moe_probs_masked", il);
    }

    // select experts
    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
    cb(selected_experts->src[0], "ffn_moe_argsort", il);
    cb(selected_experts, "ffn_moe_topk", il);

    if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
        // TODO: Use scalar div instead when/if implemented
        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
    } else {
        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
    }

    ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
    cb(weights, "ffn_moe_weights", il);


    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
        weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
        cb(weights, "ffn_moe_weights_softmax", il);
    }

    if (norm_w) {
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);

        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
        cb(weights_sum, "ffn_moe_weights_sum", il);

        // Avoid division by zero, clamp to smallest number representable by F16
        weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
        cb(weights_sum, "ffn_moe_weights_sum_clamped", il);

        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
        cb(weights, "ffn_moe_weights_norm", il);

        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
    }
    if (scale_w) {
        weights = ggml_scale(ctx0, weights, w_scale);
        cb(weights, "ffn_moe_weights_scaled", il);
    }

    // call early so that topk-moe can be used
    ggml_build_forward_expand(gf, weights);

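    // reshape the activations to [n_embd, 1, n_tokens]; the expert matmuls below (build_lora_mm_id)
    // then apply each of the n_expert_used selected experts to the same token activations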
    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

    if (weight_before_ffn) {
        // repeat cur to [n_embd, n_expert_used, n_tokens]
        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
        cur = ggml_mul(ctx0, repeated, weights);
        cb(cur, "ffn_moe_weighted", il);
    }

    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
    cb(up, "ffn_moe_up", il);

    if (up_exps_b) {
        up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
        cb(up, "ffn_moe_up_biased", il);
    }

    ggml_tensor * experts = nullptr;
    if (gate_exps) {
        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
        cb(cur, "ffn_moe_gate", il);
    } else {
        cur = up;
    }

    if (gate_exps_b) {
        cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
        cb(cur, "ffn_moe_gate_biased", il);
    }

    switch (type_op) {
        case LLM_FFN_SILU:
            if (gate_exps) {
                cur = ggml_swiglu_split(ctx0, cur, up);
                cb(cur, "ffn_moe_swiglu", il);
            } else {
                cur = ggml_silu(ctx0, cur);
                cb(cur, "ffn_moe_silu", il);
            } break;
        case LLM_FFN_GELU:
            if (gate_exps) {
                cur = ggml_geglu_split(ctx0, cur, up);
                cb(cur, "ffn_moe_geglu", il);
            } else {
                cur = ggml_gelu(ctx0, cur);
                cb(cur, "ffn_moe_gelu", il);
            } break;
        case LLM_FFN_SWIGLU_OAI_MOE:
            {
                // TODO: move to hparams?
                constexpr float alpha = 1.702f;
                constexpr float limit = 7.0f;
                cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
                cb(cur, "ffn_moe_swiglu_oai", il);
            } break;
        case LLM_FFN_RELU:
            if (gate_exps) {
                cur = ggml_reglu_split(ctx0, cur, up);
                cb(cur, "ffn_moe_reglu", il);
            } else {
                cur = ggml_relu(ctx0, cur);
                cb(cur, "ffn_moe_relu", il);
            } break;
        default:
            GGML_ABORT("fatal error");
    }

    // expand here so that we can fuse ffn gate
    ggml_build_forward_expand(gf, cur);

    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
    cb(experts, "ffn_moe_down", il);

    if (down_exps_b) {
        experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
        cb(experts, "ffn_moe_down_biased", il);
    }

    if (!weight_before_ffn) {
        experts = ggml_mul(ctx0, experts, weights);
        cb(cur, "ffn_moe_weighted", il);
    }

    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };

    assert(n_expert_used > 0);

    // order the views before the adds
    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);

        ggml_build_forward_expand(gf, cur_experts[i]);
    }

    // aggregate experts
    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
    //       to avoid potentially a large number of add nodes during warmup
    //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
    ggml_tensor * moe_out = cur_experts[0];

    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
    }

    if (hparams.n_expert_used == 1) {
        // avoid returning a non-contiguous tensor
        moe_out = ggml_cont(ctx0, moe_out);
    }

    cb(moe_out, "ffn_moe_out", il);

    return moe_out;
}

// input embeddings with optional lora
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
    const int64_t n_embd = hparams.n_embd_inp();

    auto inp = std::make_unique<llm_graph_input_embd>();

    ggml_tensor * cur = nullptr;

    if (ubatch.token) {
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        //cb(inp->tokens, "inp_tokens", -1);
        ggml_set_input(inp->tokens);
        res->t_tokens = inp->tokens;

        cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);

        // apply lora for embedding tokens if needed
        for (const auto & lora : *loras) {
            llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
            if (lw == nullptr) {
                continue;
            }

            const float adapter_scale = lora.second;
            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);

            ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
                        ctx0, lw->b, // non-transposed lora_b
                        ggml_get_rows(ctx0, lw->a, inp->tokens)
                        ), scale);

            cur = ggml_add(ctx0, cur, inpL_delta);
        }
    } else {
        inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
        ggml_set_input(inp->embd);

        cur = inp->embd;
    }

    // For Granite architecture
    if (hparams.f_embedding_scale != 0.0f) {
        cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
    }

    cb(cur, "inp_embd", -1);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_pos() const {
    auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());

    auto & cur = inp->pos;

    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);

    auto & cur = inp->attn_scale;

    // this needs to be 1x1xN for broadcasting
    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_out_ids() const {
    // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
    //       but this would make the graph topology depend on the number of output tokens, which can interfere with
    //       features that require constant topology such as pipeline parallelism
    //       ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
    //if (n_outputs < n_tokens) {
    //    return nullptr;
    //}

    auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);

    auto & cur = inp->out_ids;

    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_mean() const {
    auto inp = std::make_unique<llm_graph_input_mean>(cparams);

    auto & cur = inp->mean;

    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_cls() const {
    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);

    auto & cur = inp->cls;

    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
    auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);

    auto & cur = inp->cross_embd;

    // if we have the output embeddings from the encoder, use them directly
    // TODO: needs more work to be correct, for now just use the tensor shape
    //if (cross->t_embd) {
    //    cur = ggml_view_tensor(ctx0, cross->t_embd);

    //    return cur;
    //}

    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
    const auto n_enc  = !cross->v_embd.empty() ? cross->n_enc  : hparams.n_ctx_train;

    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
    auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);

    auto & cur = inp->pos_bucket;

    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);

    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);

    const auto n_kv = mctx_cur->get_n_kv();

    auto & cur = inp->pos_bucket;

    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
    ggml_set_input(cur);

    res->add_input(std::move(inp));

    return cur;
}

ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
    ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
    cb(pos_bucket_1d, "pos_bucket_1d", -1);

    ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);

    pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
    pos_bias = ggml_permute   (ctx0, pos_bias, 2, 0, 1, 3);
    pos_bias = ggml_cont      (ctx0, pos_bias);

    cb(pos_bias, "pos_bias", -1);

    return pos_bias;
}

ggml_tensor * llm_graph_context::build_attn_mha(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * kq_b,
        ggml_tensor * kq_mask,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
              float   kq_scale,
                int   il) const {
    const bool v_trans = v->nb[1] > v->nb[2];

    // split the batch into streams if needed
    const auto n_stream = k->ne[3];

    q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);

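    // permute Q/K/V so that the token dimension comes before the head dimension,
    // which is the layout expected by the attention ops below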
    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
    v = ggml_permute(ctx0, v, 0, 2, 1, 3);

    ggml_tensor * cur;

    if (cparams.flash_attn && kq_b == nullptr) {
        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

        if (v_trans) {
            v = ggml_transpose(ctx0, v);
        }

        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
        if (k->type == GGML_TYPE_F32) {
            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
        }

        if (v->type == GGML_TYPE_F32) {
            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
        }

        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
        cb(cur, LLAMA_TENSOR_NAME_FATTN, il);

        ggml_flash_attn_ext_add_sinks(cur, sinks);
        ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);

        if (v_mla) {
#if 0
            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
            // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
1388 cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1389 cur = ggml_mul_mat(ctx0, v_mla, cur);
1390#else
1391 // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
1392 // The permutations are noops and only change how the tensor data is interpreted.
1393 cur = ggml_permute(ctx: ctx0, a: cur, axis0: 0, axis1: 2, axis2: 1, axis3: 3);
1394 cur = ggml_mul_mat(ctx: ctx0, a: v_mla, b: cur);
1395 cb(cur, name: "fattn_mla", il);
1396 cur = ggml_permute(ctx: ctx0, a: cur, axis0: 0, axis1: 2, axis2: 1, axis3: 3);
1397 cur = ggml_cont(ctx: ctx0, a: cur); // Needed because ggml_reshape_2d expects contiguous inputs.
1398#endif
1399 }
1400
1401 cur = ggml_reshape_2d(ctx: ctx0, a: cur, ne0: cur->ne[0]*cur->ne[1], ne1: cur->ne[2]*cur->ne[3]);
1402 } else {
1403 ggml_tensor * kq = ggml_mul_mat(ctx: ctx0, a: k, b: q);
1404 cb(cur: kq, name: "kq", il);
1405
1406 // note: this op tends to require high floating point range
1407 // while for some models F16 is enough, for others it is not, so we default to F32 here
1408 ggml_mul_mat_set_prec(a: kq, prec: GGML_PREC_F32);
1409
1410 if (arch == LLM_ARCH_GROK) {
1411 // need to do the following:
1412 // multiply by attn_output_multiplier
1413 // and then :
1414 // kq = 30 * tanh(kq / 30)
1415 // before the softmax below
1416
1417 kq = ggml_tanh(ctx: ctx0, a: ggml_scale(ctx: ctx0, a: kq, s: hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
1418 cb(cur: kq, name: "kq_tanh", il);
1419 kq = ggml_scale(ctx: ctx0, a: kq, s: hparams.f_attn_logit_softcapping);
1420 cb(cur: kq, name: "kq_scaled", il);
1421 }
1422
1423 if (hparams.attn_soft_cap) {
1424 kq = ggml_scale(ctx: ctx0, a: kq, s: 1.0f / hparams.f_attn_logit_softcapping);
1425 cb(cur: kq, name: "kq_scaled_1", il);
1426 kq = ggml_tanh (ctx: ctx0, a: kq);
1427 cb(cur: kq, name: "kq_tanh", il);
1428 kq = ggml_scale(ctx: ctx0, a: kq, s: hparams.f_attn_logit_softcapping);
1429 cb(cur: kq, name: "kq_scaled_2", il);
1430 }
1431
1432 if (kq_b) {
1433 kq = ggml_add(ctx: ctx0, a: kq, b: kq_b);
1434 cb(cur: kq, name: "kq_plus_kq_b", il);
1435 }
1436
1437 kq = ggml_soft_max_ext(ctx: ctx0, a: kq, mask: kq_mask, scale: kq_scale, max_bias: hparams.f_max_alibi_bias);
1438 ggml_soft_max_add_sinks(a: kq, sinks);
1439 cb(cur: kq, name: "kq_soft_max", il);
1440
1441 if (!v_trans) {
1442 // note: avoid this branch
1443 v = ggml_cont(ctx: ctx0, a: ggml_transpose(ctx: ctx0, a: v));
1444 cb(cur: v, name: "v_cont", il);
1445 }
1446
1447 ggml_tensor * kqv = ggml_mul_mat(ctx: ctx0, a: v, b: kq);
1448 cb(cur: kqv, name: "kqv", il);
1449
1450 // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
1451 if (v_mla) {
1452 kqv = ggml_mul_mat(ctx: ctx0, a: v_mla, b: kqv);
1453 cb(cur: kqv, name: "kqv_mla", il);
1454 }
1455
1456 cur = ggml_permute(ctx: ctx0, a: kqv, axis0: 0, axis1: 2, axis2: 1, axis3: 3);
1457
1458 // recombine streams
1459 cur = ggml_cont_2d(ctx: ctx0, a: cur, ne0: cur->ne[0]*cur->ne[1], ne1: cur->ne[2]*cur->ne[3]);
1460
1461 if (!cparams.offload_kqv) {
1462 // all nodes between the KV store and the attention output are run on the CPU
1463 ggml_backend_sched_set_tensor_backend(sched, node: cur, backend: backend_cpu);
1464 }
1465 }
1466
1467 ggml_build_forward_expand(cgraph: gf, tensor: cur);
1468
1469 return cur;
1470}
1471
1472llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
1473 auto inp = std::make_unique<llm_graph_input_attn_no_cache>(args: hparams, args: cparams);
1474
1475 // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1476 inp->self_kq_mask = ggml_new_tensor_4d(ctx: ctx0, type: GGML_TYPE_F32, ne0: n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), ne2: 1, ne3: 1);
1477 ggml_set_input(tensor: inp->self_kq_mask);
1478
1479 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx: ctx0, a: inp->self_kq_mask, type: GGML_TYPE_F16) : inp->self_kq_mask;
1480
1481 if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
1482 inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx: ctx0, type: GGML_TYPE_F32, ne0: n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), ne2: 1, ne3: 1);
1483 ggml_set_input(tensor: inp->self_kq_mask_swa);
1484
1485 inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx: ctx0, a: inp->self_kq_mask_swa, type: GGML_TYPE_F16) : inp->self_kq_mask_swa;
1486 } else {
1487 inp->self_kq_mask_swa = nullptr;
1488 inp->self_kq_mask_swa_cnv = nullptr;
1489 }
1490
1491 return (llm_graph_input_attn_no_cache *) res->add_input(input: std::move(inp));
1492}
1493
1494ggml_tensor * llm_graph_context::build_attn(
1495 llm_graph_input_attn_no_cache * inp,
1496 ggml_tensor * wo,
1497 ggml_tensor * wo_b,
1498 ggml_tensor * q_cur,
1499 ggml_tensor * k_cur,
1500 ggml_tensor * v_cur,
1501 ggml_tensor * kq_b,
1502 ggml_tensor * sinks,
1503 ggml_tensor * v_mla,
1504 float kq_scale,
1505 int il) const {
1506 GGML_UNUSED(n_tokens);
1507
1508 // these nodes are added to the graph together so that they are not reordered
1509 // by doing so, the number of splits in the graph is reduced
1510 ggml_build_forward_expand(cgraph: gf, tensor: q_cur);
1511 ggml_build_forward_expand(cgraph: gf, tensor: k_cur);
1512 ggml_build_forward_expand(cgraph: gf, tensor: v_cur);
1513
1514 const bool is_swa = hparams.is_swa(il);
1515
1516 const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1517
1518 // [TAG_NO_CACHE_PAD]
1519 // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1520 // but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
1521 //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
1522
1523 ggml_tensor * q = q_cur;
1524 ggml_tensor * k = k_cur;
1525 ggml_tensor * v = v_cur;
1526
1527 ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1528 cb(cur, name: "kqv_out", il);
1529
1530 if (wo) {
1531 cur = build_lora_mm(w: wo, cur);
1532 }
1533
1534 if (wo_b) {
1535 //cb(cur, "kqv_wo", il);
1536 }
1537
1538 if (wo_b) {
1539 cur = ggml_add(ctx: ctx0, a: cur, b: wo_b);
1540 }
1541
1542 return cur;
1543}
1544
1545static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1546 ggml_context * ctx0,
1547 const llama_ubatch & ubatch,
1548 const llama_hparams & hparams,
1549 const llama_cparams & cparams,
1550 const llama_kv_cache_context * mctx_cur) {
1551
1552 auto inp = std::make_unique<llm_graph_input_attn_kv>(args: hparams, args: cparams, args&: mctx_cur);
1553
1554 {
1555 GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1556
1557 const auto n_kv = mctx_cur->get_n_kv();
1558 const auto n_tokens = ubatch.n_tokens;
1559 const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1560
1561 inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx: ctx0, ubatch);
1562 inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx: ctx0, ubatch);
1563
1564 inp->self_kq_mask = ggml_new_tensor_4d(ctx: ctx0, type: GGML_TYPE_F32, ne0: n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), ne2: 1, ne3: n_stream);
1565 ggml_set_input(tensor: inp->self_kq_mask);
1566
1567 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx: ctx0, a: inp->self_kq_mask, type: GGML_TYPE_F16) : inp->self_kq_mask;
1568 }
1569
1570 return inp;
1571}
1572
1573llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
1574 const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1575
1576 auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1577
1578 return (llm_graph_input_attn_kv *) res->add_input(input: std::move(inp));
1579}
1580
1581ggml_tensor * llm_graph_context::build_attn(
1582 llm_graph_input_attn_kv * inp,
1583 ggml_tensor * wo,
1584 ggml_tensor * wo_b,
1585 ggml_tensor * q_cur,
1586 ggml_tensor * k_cur,
1587 ggml_tensor * v_cur,
1588 ggml_tensor * kq_b,
1589 ggml_tensor * sinks,
1590 ggml_tensor * v_mla,
1591 float kq_scale,
1592 int il) const {
1593 // these nodes are added to the graph together so that they are not reordered
1594 // by doing so, the number of splits in the graph is reduced
1595 ggml_build_forward_expand(cgraph: gf, tensor: q_cur);
1596 ggml_build_forward_expand(cgraph: gf, tensor: k_cur);
1597 ggml_build_forward_expand(cgraph: gf, tensor: v_cur);
1598
1599 const auto * mctx_cur = inp->mctx;
1600
1601 // store to KV cache
1602 {
1603 const auto & k_idxs = inp->get_k_idxs();
1604 const auto & v_idxs = inp->get_v_idxs();
1605
1606 ggml_build_forward_expand(cgraph: gf, tensor: mctx_cur->cpy_k(ctx: ctx0, k_cur, k_idxs, il));
1607 ggml_build_forward_expand(cgraph: gf, tensor: mctx_cur->cpy_v(ctx: ctx0, v_cur, v_idxs, il));
1608 }
1609
1610 const auto & kq_mask = inp->get_kq_mask();
1611
1612 ggml_tensor * q = q_cur;
1613 ggml_tensor * k = mctx_cur->get_k(ctx: ctx0, il);
1614 ggml_tensor * v = mctx_cur->get_v(ctx: ctx0, il);
1615
1616 ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1617 cb(cur, name: "kqv_out", il);
1618
1619 if (wo) {
1620 cur = build_lora_mm(w: wo, cur);
1621 if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1622 // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1623 ggml_mul_mat_set_prec(a: cur, prec: GGML_PREC_F32);
1624 }
1625 }
1626
1627 if (wo_b) {
1628 cur = ggml_add(ctx: ctx0, a: cur, b: wo_b);
1629 }
1630
1631 return cur;
1632}
1633
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_kv_iswa * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
        float kq_scale,
        int il) const {
    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
    ggml_build_forward_expand(gf, q_cur);

    if (k_cur) {
        ggml_build_forward_expand(gf, k_cur);
    }

    if (v_cur) {
        ggml_build_forward_expand(gf, v_cur);
    }

    const auto * mctx_iswa = inp->mctx;

    const bool is_swa = hparams.is_swa(il);

    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();

    // optionally store to KV cache
    if (k_cur) {
        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();

        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
    }

    if (v_cur) {
        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();

        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
    }

    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

    ggml_tensor * q = q_cur;
    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
    ggml_tensor * v = mctx_cur->get_v(ctx0, il);

    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
    cb(cur, "kqv_out", il);

    if (wo) {
        cur = build_lora_mm(wo, cur);
    }

    if (wo_b) {
        //cb(cur, "kqv_wo", il);
        cur = ggml_add(ctx0, cur, wo_b);
    }

    return cur;
}

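// input for cross-attention: the KQ mask spans the encoder output length
// (cross->n_enc when encoder embeddings are available, hparams.n_ctx_train otherwise),
// padded to GGML_KQ_MASK_PAD and optionally cast to F16 for flash attention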
llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
    auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);

    const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
    ggml_set_input(inp->cross_kq_mask);

    inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;

    return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
}

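// build_attn for cross-attention: q_cur attends directly over the encoder-side
// k_cur/v_cur - no KV cache is involved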
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_cross * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
        float kq_scale,
        int il) const {
    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
    ggml_build_forward_expand(gf, q_cur);
    ggml_build_forward_expand(gf, k_cur);
    ggml_build_forward_expand(gf, v_cur);

    const auto & kq_mask = inp->get_kq_mask_cross();

    ggml_tensor * q = q_cur;
    ggml_tensor * k = k_cur;
    ggml_tensor * v = v_cur;

    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
    cb(cur, "kqv_out", il);

    if (wo) {
        cur = build_lora_mm(wo, cur);
    }

    if (wo_b) {
        //cb(cur, "kqv_wo", il);
        cur = ggml_add(ctx0, cur, wo_b);
    }

    return cur;
}

// TODO: maybe split the inner implementation into a separate function,
//       like with the non-sliding-window equivalent,
//       once sliding-window hybrid caches are a thing
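// builds the attention inputs for both halves of the iSWA cache: K/V index tensors and
// a KQ mask for the base cache and for the SWA cache (masks cast to F16 when flash
// attention is enabled)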
llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);

    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);

    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

    {
        const auto n_kv = mctx_cur->get_base()->get_n_kv();

        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
        ggml_set_input(inp->self_kq_mask);

        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
    }

    {
        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");

        const auto n_kv = mctx_cur->get_swa()->get_n_kv();

        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
        ggml_set_input(inp->self_kq_mask_swa);

        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
    }

    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}

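// build_rs (low-level): gather the recurrent states needed for this ubatch
//  - zero-clear the state at slot rs_zero (a zero-sized no-op view when rs_zero < 0)
//  - gather the first n_seqs states selected by state_copy_main via get_state_rows -> returned output
//  - copy the remaining (n_rs - n_seqs) states selected by state_copy_extra back into the
//    state buffer starting at rs_head + n_seqs, so they are preserved unchanged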
ggml_tensor * llm_graph_context::build_rs(
        ggml_tensor * s,
        ggml_tensor * state_copy_main,
        ggml_tensor * state_copy_extra,
        int32_t state_size,
        int32_t n_seqs,
        uint32_t n_rs,
        uint32_t rs_head,
        uint32_t rs_size,
        int32_t rs_zero,
        const llm_graph_get_rows_fn & get_state_rows) const {

    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);

    // Clear a single state which will then be copied to the other cleared states.
    // Note that this is a no-op when the view is zero-sized.
    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

    // copy states
    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
    // {state_size, rs_size} -> {state_size, n_seqs}
    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
    ggml_build_forward_expand(gf, output_states);

    // copy extra states which won't be changed further (between n_seqs and n_rs)
    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
    ggml_build_forward_expand(gf,
        ggml_cpy(ctx0,
            states_extra,
            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));

    return output_states;
}

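// creates the s_copy input (n_rs row indices into the recurrent state buffer) and splits it
// into two views: s_copy_main (first n_seqs entries, used for the current ubatch) and
// s_copy_extra (remaining n_rs - n_seqs entries that are only moved within the buffer)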
static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
        ggml_context * ctx0,
        const llama_ubatch & ubatch,
        const llama_memory_recurrent_context * mctx_cur) {

    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);

    const int64_t n_rs   = mctx_cur->get_n_rs();
    const int64_t n_seqs = ubatch.n_seqs;

    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
    ggml_set_input(inp->s_copy);

    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);

    return inp;
}

llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);

    return (llm_graph_input_rs *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_rs(
        llm_graph_input_rs * inp,
        ggml_tensor * s,
        int32_t state_size,
        int32_t n_seqs,
        const llm_graph_get_rows_fn & get_state_rows) const {
    const auto * kv_state = inp->mctx;

    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
            kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
            get_state_rows);
}

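// loads the RWKV token-shift states of the current sequences from the recurrent cache
// and reshapes them to {n_embd, token_shift_count, n_seqs}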
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
        llm_graph_input_rs * inp,
        const llama_ubatch & ubatch,
        int il) const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    const auto token_shift_count = hparams.token_shift_count;

    const int64_t n_seqs = ubatch.n_seqs;

    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);

    ggml_tensor * token_shift = build_rs(
            inp, token_shift_all,
            hparams.n_embd_r(), n_seqs);

    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

    return token_shift;
}

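// writes the updated token-shift states back into the recurrent cache at the current head
// (the counterpart of build_rwkv_token_shift_load)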
ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
        ggml_tensor * token_shift,
        const llama_ubatch & ubatch,
        int il) const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    const auto token_shift_count = hparams.token_shift_count;
    const auto n_embd = hparams.n_embd;

    const int64_t n_seqs = ubatch.n_seqs;

    const auto kv_head = mctx_cur->get_head();

    return ggml_cpy(
        ctx0,
        ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
    );
}

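// input for hybrid memory: combines a recurrent-state input (for the recurrent half)
// with a KV-cache attention input (for the attention half) of the hybrid memory context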
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);

    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());

    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);

    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}

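// applies the optional dense_2/dense_3 projections on top of the pooled embeddings and
// stores the result as "result_embd_pooled"; a no-op unless embeddings are requested
// and both weights are present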
void llm_graph_context::build_dense_out(
        ggml_tensor * dense_2,
        ggml_tensor * dense_3) const {
    if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
        return;
    }

    ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
    GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");

    cur = ggml_mul_mat(ctx0, dense_2, cur);
    cur = ggml_mul_mat(ctx0, dense_3, cur);
    cb(cur, "result_embd_pooled", -1);
    res->t_embd_pooled = cur;
    ggml_build_forward_expand(gf, cur);
}

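// pools the token embeddings in res->t_embd according to pooling_type:
//   NONE     - pass-through
//   MEAN     - weighted sum with the inp_mean input
//   CLS/LAST - row selection via the inp_cls input
//   RANK     - row selection followed by the classification head (cls/cls_b, cls_out/cls_out_b)
//              and, for LLM_ARCH_QWEN3 rerankers, a softmax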
void llm_graph_context::build_pooling(
        ggml_tensor * cls,
        ggml_tensor * cls_b,
        ggml_tensor * cls_out,
        ggml_tensor * cls_out_b) const {
    if (!cparams.embeddings) {
        return;
    }

    ggml_tensor * inp = res->t_embd;

    //// find result_norm tensor for input
    //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
    //    inp = ggml_graph_node(gf, i);
    //    if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
    //        break;
    //    }
    //
    //    inp = nullptr;
    //}

    GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");

    ggml_tensor * cur;

    switch (pooling_type) {
        case LLAMA_POOLING_TYPE_NONE:
            {
                cur = inp;
            } break;
        case LLAMA_POOLING_TYPE_MEAN:
            {
                ggml_tensor * inp_mean = build_inp_mean();
                cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
            } break;
        case LLAMA_POOLING_TYPE_CLS:
        case LLAMA_POOLING_TYPE_LAST:
            {
                ggml_tensor * inp_cls = build_inp_cls();
                cur = ggml_get_rows(ctx0, inp, inp_cls);
            } break;
        case LLAMA_POOLING_TYPE_RANK:
            {
                ggml_tensor * inp_cls = build_inp_cls();
                cur = ggml_get_rows(ctx0, inp, inp_cls);

                // classification head
                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
                if (cls) {
                    cur = ggml_mul_mat(ctx0, cls, cur);
                    if (cls_b) {
                        cur = ggml_add(ctx0, cur, cls_b);
                    }
                    cur = ggml_tanh(ctx0, cur);
                }

                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
                // single-layer classification head (direct projection)
                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
                if (cls_out) {
                    cur = ggml_mul_mat(ctx0, cls_out, cur);
                    if (cls_out_b) {
                        cur = ggml_add(ctx0, cur, cls_out_b);
                    }
                }

                // softmax for qwen3 reranker
                if (arch == LLM_ARCH_QWEN3) {
                    cur = ggml_soft_max(ctx0, cur);
                }
            } break;
        default:
            {
                GGML_ABORT("unknown pooling type");
            }
    }

    cb(cur, "result_embd_pooled", -1);
    res->t_embd_pooled = cur;

    ggml_build_forward_expand(gf, cur);
}

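// T5-style relative position bucketing:
// after the optional bidirectional split, half of the buckets cover exact offsets
// (|x - y| < max_exact) and the other half cover larger offsets on a log scale up to
// max_distance, clamped to the last bucket
// e.g. with n_buckets = 32 and bidirectional = true: offsets 0..7 map to buckets 0..7,
// larger offsets map logarithmically into buckets 8..15, and positive (x > y) offsets
// get an extra +16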
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
    // TODO: move to hparams if a T5 variant appears that uses a different value
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1;
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;

    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = std::abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);

    return relative_bucket;
}