#pragma once

#include "llama-arch.h"
#include "llama-batch.h"
#include "llama-hparams.h"
#include "llama-adapter.h"

#include <cstdint>
#include <cstdlib> // getenv, atoi
#include <vector>
#include <memory>
#include <set>
#include <functional>

struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;

struct llama_cparams;

struct llama_memory_context_i;

class llama_kv_cache_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;

// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
    LLM_GRAPH_TYPE_DEFAULT,
    LLM_GRAPH_TYPE_ENCODER,
    LLM_GRAPH_TYPE_DECODER,
};

enum llm_ffn_op_type {
    LLM_FFN_SILU,
    LLM_FFN_GELU,
    LLM_FFN_RELU,
    LLM_FFN_RELU_SQR,
    LLM_FFN_SWIGLU,
    LLM_FFN_GEGLU,
    LLM_FFN_REGLU,
    LLM_FFN_SWIGLU_OAI_MOE,
};

enum llm_ffn_gate_type {
    LLM_FFN_SEQ,
    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
};

enum llm_norm_type {
    LLM_NORM,
    LLM_NORM_RMS,
    LLM_NORM_GROUP,
};

// TODO: tmp - need something better to pass the data from the encoder to the decoder
struct llama_cross {
    // the output embeddings from the encoder as a ggml tensor
    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
    // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
    //ggml_tensor * t_embd = nullptr;

    int64_t n_embd = 0;
    int64_t n_enc  = 0;

    // embeddings data copied to host memory (tmp)
    std::vector<float> v_embd;

    // needed to construct the cross-attention mask in the decoder
    std::vector<std::set<llama_seq_id>> seq_ids_enc;
};

struct llm_graph_params;

//
// llm_graph_input
//

class llm_graph_input_i {
public:
    llm_graph_input_i() {
        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
    }

    virtual ~llm_graph_input_i() = default;

    virtual void set_input(const llama_ubatch * ubatch) = 0;

    // return true if the input tensors resulting from the provided graph parameters would be
    // the same as the previous input tensors currently stored in the object
    virtual bool can_reuse(const llm_graph_params & params) {
        // returning false here by default will prevent reusing the graph if the check
        // for the input type has not been implemented yet
        GGML_UNUSED(params);
        return false;
    }
protected:
    // env: LLAMA_GRAPH_INPUT_DEBUG
    int debug = 0;
};

using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
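
// example: a minimal sketch of a custom graph input (hypothetical, for illustration only).
// it stores one per-token tensor and fills it in set_input(); can_reuse() is left at the
// base-class default (false), so graphs containing it would never be reused:
//
//   class llm_graph_input_example : public llm_graph_input_i {
//   public:
//       void set_input(const llama_ubatch * ubatch) override {
//           GGML_ASSERT(scale != nullptr && ubatch != nullptr);
//           // ... write ubatch->n_tokens values into the data of `scale` here ...
//       }
//
//       ggml_tensor * scale = nullptr; // F32 [n_batch]
//   };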

class llm_graph_input_embd : public llm_graph_input_i {
public:
    llm_graph_input_embd()          = default;
    virtual ~llm_graph_input_embd() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * tokens = nullptr; // I32 [n_batch]
    ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
};

class llm_graph_input_pos : public llm_graph_input_i {
public:
    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
    virtual ~llm_graph_input_pos() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * pos = nullptr; // I32 [n_batch]

    const uint32_t n_pos_per_embd = 1;
};

// temperature tuning, used by llama4
class llm_graph_input_attn_temp : public llm_graph_input_i {
public:
    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
    virtual ~llm_graph_input_attn_temp() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * attn_scale = nullptr; // F32 [n_batch]

    const uint32_t n_attn_temp_floor_scale;
    const float    f_attn_temp_scale;
};

class llm_graph_input_pos_bucket : public llm_graph_input_i {
public:
    llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
    virtual ~llm_graph_input_pos_bucket() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]

    const llama_hparams hparams;
};

class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
    llm_graph_input_pos_bucket_kv(
            const llama_hparams & hparams,
            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
    virtual ~llm_graph_input_pos_bucket_kv() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]

    const llama_hparams hparams;

    const llama_kv_cache_context * mctx;
};

class llm_graph_input_out_ids : public llm_graph_input_i {
public:
    llm_graph_input_out_ids(
            const llama_hparams & hparams,
            const llama_cparams & cparams,
            uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
    virtual ~llm_graph_input_out_ids() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * out_ids; // I32 [n_outputs]

    const llama_hparams hparams;
    const llama_cparams cparams;

    const uint32_t n_outputs;
};

class llm_graph_input_mean : public llm_graph_input_i {
public:
    llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
    virtual ~llm_graph_input_mean() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * mean; // F32 [n_batch, n_batch]

    const llama_cparams cparams;
};

class llm_graph_input_cls : public llm_graph_input_i {
public:
    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
    virtual ~llm_graph_input_cls() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * cls; // I32 [n_batch]

    const llama_cparams cparams;
    const llm_arch arch;
};

class llm_graph_input_rs : public llm_graph_input_i {
public:
    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
    virtual ~llm_graph_input_rs() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * s_copy; // I32 [n_rs]

    // views of s_copy, computed once per graph
    // and shared across layers which use build_rs
    ggml_tensor * s_copy_main;  // I32 [n_seqs]
    ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]

    const llama_memory_recurrent_context * mctx;
};

class llm_graph_input_cross_embd : public llm_graph_input_i {
public:
    llm_graph_input_cross_embd(
            const llama_cross * cross) : cross(cross) {}
    virtual ~llm_graph_input_cross_embd() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]

    const llama_cross * cross;
};

class llm_graph_input_attn_no_cache : public llm_graph_input_i {
public:
    llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
        hparams(hparams),
        cparams(cparams) {
    }
    ~llm_graph_input_attn_no_cache() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

    // n_tokens == n_batch
    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]

    const llama_hparams hparams;
    const llama_cparams cparams;
};

class llm_graph_input_attn_kv : public llm_graph_input_i {
public:
    llm_graph_input_attn_kv(
            const llama_hparams & hparams,
            const llama_cparams & cparams,
            const llama_kv_cache_context * mctx) :
        hparams(hparams),
        cparams(cparams),
        mctx(mctx) {
    }
    ~llm_graph_input_attn_kv() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
    ggml_tensor * get_v_idxs() const { return self_v_idxs; }

    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]

    // note: these have to be copies because in order to be able to reuse a graph, its inputs
    //       need to carry these parameters with them. otherwise, they can point to freed
    //       llm_graph_params from a previous batch, causing stack-use-after-return
    const llama_hparams hparams;
    const llama_cparams cparams;

    const llama_kv_cache_context * mctx;
};

class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
public:
    llm_graph_input_attn_kv_iswa(
            const llama_hparams & hparams,
            const llama_cparams & cparams,
            const llama_kv_cache_iswa_context * mctx) :
        hparams(hparams),
        cparams(cparams),
        mctx(mctx) {
    }
    ~llm_graph_input_attn_kv_iswa() = default;

    void set_input(const llama_ubatch * ubatch) override;

    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
    ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }

    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

    ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]

    const llama_hparams hparams;
    const llama_cparams cparams;

    const llama_kv_cache_iswa_context * mctx;
};

class llm_graph_input_attn_cross : public llm_graph_input_i {
public:
    llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
    ~llm_graph_input_attn_cross() = default;

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }

    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]

    const llama_cross * cross = nullptr;
};

class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
    llm_graph_input_mem_hybrid(
            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
            std::unique_ptr<llm_graph_input_rs>      inp_rs,
            const llama_memory_hybrid_context *      mctx) :
        inp_attn(std::move(inp_attn)),
        inp_rs(std::move(inp_rs)),
        mctx(mctx) { }
    virtual ~llm_graph_input_mem_hybrid() = default;

    void set_input(const llama_ubatch * ubatch) override;

    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
    std::unique_ptr<llm_graph_input_rs>      inp_rs;

    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
    llm_graph_input_rs *      get_recr() const { return inp_rs.get(); }

    const llama_memory_hybrid_context * mctx;
};

//
// llm_graph_result
//

// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data, by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters

// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
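
// example: a minimal callback sketch (illustrative only) that names each tensor, appending the
// layer index when one is available - roughly the kind of logic a caller might install here:
//
//   llm_graph_cb cb = [](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
//       GGML_UNUSED(ubatch);
//       if (il >= 0) {
//           ggml_format_name(cur, "%s-%d", name, il);
//       } else {
//           ggml_set_name(cur, name);
//       }
//   };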

class llm_graph_result;

struct llm_graph_params {
    llm_arch arch = LLM_ARCH_UNKNOWN;

    llama_hparams hparams;
    llama_cparams cparams;

    llama_ubatch ubatch; // note: intentionally make a copy

    llm_graph_type gtype;

    ggml_backend_sched_t sched;
    ggml_backend_t backend_cpu;

    const llama_adapter_cvec * cvec;
    const llama_adapter_loras * loras;
    const llama_memory_context_i * mctx;
    const llama_cross * cross;

    uint32_t n_outputs;

    llm_graph_cb cb;

    llm_graph_result * res;

    // return true if the "other" params would result in a graph with the same topology as with the current params
    // having the same topology allows us to reuse the graph in some cases
    bool allow_reuse(const llm_graph_params & other) const {
        // first check the ubatch
        bool can_reuse_ubatch =
            ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
            ubatch.n_tokens     == other.ubatch.n_tokens     &&
            ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
            ubatch.n_seqs       == other.ubatch.n_seqs       &&
            ubatch.n_seqs_unq   == other.ubatch.n_seqs_unq   &&
            (
                (!ubatch.token && !other.ubatch.token) ||
                (!ubatch.embd  && !other.ubatch.embd)
            );

        // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same,
        // because the set of attention streams would be different for different sequences
        if (can_reuse_ubatch && ubatch.equal_seqs()) {
            if (!ubatch.data) {
                // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
                // therefore we cannot perform the sequence id check. this should normally never happen
                can_reuse_ubatch = false;
            } else {
                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
                    can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
                }
            }
        }

        if (!can_reuse_ubatch) {
            return false;
        }

        return
            cparams.embeddings  == other.cparams.embeddings  &&
            cparams.causal_attn == other.cparams.causal_attn &&
            arch  == other.arch  &&
            gtype == other.gtype &&
            cvec  == other.cvec  &&
            loras == other.loras &&
            cross == other.cross &&
            n_outputs == other.n_outputs;
    }
};

class llm_graph_result {
public:
    llm_graph_result(int64_t max_nodes);

    virtual ~llm_graph_result() = default;

    ggml_tensor * get_tokens()      const { return t_tokens; }
    ggml_tensor * get_logits()      const { return t_logits; }
    ggml_tensor * get_embd()        const { return t_embd; }
    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }

    ggml_cgraph  * get_gf()  const { return gf; }
    ggml_context * get_ctx() const { return ctx_compute.get(); }

    int64_t get_max_nodes() const;

    void reset();

    void set_inputs(const llama_ubatch * ubatch);

    // try to update the existing graph result using the new graph parameters in order to reuse it
    // this can only be done if we determine that the resulting graph using the new graph parameters
    // would be identical to the existing graph. in that case, we simply have to update the memory
    // contexts of the input tensors of the graph and we can reuse it for another computation
    // return true if the graph was updated and can be reused
    bool can_reuse(const llm_graph_params & params);

    llm_graph_input_i * add_input(llm_graph_input_ptr input);

    void set_params(const llm_graph_params & params);

    // important graph nodes
    ggml_tensor * t_tokens      = nullptr;
    ggml_tensor * t_logits      = nullptr;
    ggml_tensor * t_embd        = nullptr;
    ggml_tensor * t_embd_pooled = nullptr;

    std::vector<llm_graph_input_ptr> inputs;

    ggml_context_ptr ctx_compute;

    // memory buffers used to evaluate the model
    std::vector<uint8_t> buf_compute_meta;

    ggml_cgraph * gf;

    int64_t max_nodes;

private:
    // keep a copy of the previous graph parameters
    // we will use this to determine whether the graph can be reused by comparing them with the new parameters
    // note: these are updated after constructing the new graph
    llm_graph_params params;

    // env: LLAMA_GRAPH_RESULT_DEBUG
    int debug = 0;
};

using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
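
// example: a rough sketch of how a caller might drive graph reuse through this interface
// (illustrative only - the real flow lives in llama_context; `res` is assumed to be a
// llm_graph_result_ptr and `params` a freshly populated llm_graph_params with params.res = res.get()):
//
//   if (!res->can_reuse(params)) {
//       // topology changed - rebuild the graph from scratch
//       res->reset();
//       // ... build the graph nodes into res->get_ctx() / res->get_gf() ...
//       res->set_params(params);
//   }
//
//   // in both cases, populate the input tensors for the current ubatch before evaluation
//   res->set_inputs(&params.ubatch);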

//
// llm_graph_context
//

// used in build_rs to properly order writes and avoid unnecessary copies
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
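
// example: the default hook is ggml_get_rows; a hypothetical custom hook could wrap it to insert
// extra ops between gathering the states and writing them back (sketch, for illustration only):
//
//   llm_graph_get_rows_fn my_get_rows = [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
//       ggml_tensor * rows = ggml_get_rows(ctx, states, ids);
//       // ... optionally post-process the gathered states here ...
//       return rows;
//   };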

struct llm_graph_context {
    const llm_arch arch;

    const llama_hparams & hparams;
    const llama_cparams & cparams;
    const llama_ubatch  & ubatch;

    const int64_t n_embd;
    const int64_t n_layer;
    const int64_t n_rot;
    const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
    const int64_t n_head;
    const int64_t n_head_kv;
    const int64_t n_embd_head_k;
    const int64_t n_embd_k_gqa;
    const int64_t n_embd_head_v;
    const int64_t n_embd_v_gqa;
    const int64_t n_expert;
    const int64_t n_expert_used;

    const float freq_base;
    const float freq_scale;
    const float ext_factor;
    const float attn_factor;
    const float beta_fast;
    const float beta_slow;
    const float norm_eps;
    const float norm_rms_eps;

    const int64_t n_tokens;
    const int64_t n_outputs;
    const int32_t n_ctx_orig; // yarn

    const enum llama_pooling_type pooling_type;
    const enum llama_rope_type    rope_type;

    ggml_backend_sched_t sched;

    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

    const llama_adapter_cvec * cvec;
    const llama_adapter_loras * loras;
    const llama_memory_context_i * mctx;
    const llama_cross * cross;

    const llm_graph_cb & cb_func;

    llm_graph_result * res;

    ggml_context * ctx0 = nullptr;
    ggml_cgraph  * gf   = nullptr;

    llm_graph_context(const llm_graph_params & params);
    virtual ~llm_graph_context() = default;

    void cb(ggml_tensor * cur, const char * name, int il) const;

    //
    // common
    //

    ggml_tensor * build_cvec(
            ggml_tensor * cur,
            int il) const;

    // do mat_mul, while optionally applying lora
    ggml_tensor * build_lora_mm(
            ggml_tensor * w,
            ggml_tensor * cur) const;

    // do mat_mul_id, while optionally applying lora
    ggml_tensor * build_lora_mm_id(
            ggml_tensor * w,   // ggml_tensor * as
            ggml_tensor * cur, // ggml_tensor * b
            ggml_tensor * ids) const;

    ggml_tensor * build_norm(
            ggml_tensor * cur,
            ggml_tensor * mw,
            ggml_tensor * mb,
            llm_norm_type type,
            int il) const;

    ggml_tensor * build_ffn(
            ggml_tensor * cur,
            ggml_tensor * up,
            ggml_tensor * up_b,
            ggml_tensor * up_s,
            ggml_tensor * gate,
            ggml_tensor * gate_b,
            ggml_tensor * gate_s,
            ggml_tensor * down,
            ggml_tensor * down_b,
            ggml_tensor * down_s,
            ggml_tensor * act_scales,
            llm_ffn_op_type type_op,
            llm_ffn_gate_type type_gate,
            int il) const;
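
    // example: a typical gated (SwiGLU-style) FFN call might look roughly like this - a sketch only,
    // where `layer.ffn_up`, `layer.ffn_gate` and `layer.ffn_down` are hypothetical placeholders for
    // the model's layer weights:
    //
    //   cur = build_ffn(cur,
    //           layer.ffn_up,   NULL, NULL,
    //           layer.ffn_gate, NULL, NULL,
    //           layer.ffn_down, NULL, NULL,
    //           NULL,
    //           LLM_FFN_SILU, LLM_FFN_PAR, il);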

    // build MoE FFN without bias tensors
    ggml_tensor * build_moe_ffn(
            ggml_tensor * cur,
            ggml_tensor * gate_inp,
            ggml_tensor * up_exps,
            ggml_tensor * gate_exps,
            ggml_tensor * down_exps,
            ggml_tensor * exp_probs_b,
            int64_t n_expert,
            int64_t n_expert_used,
            llm_ffn_op_type type_op,
            bool norm_w,
            bool scale_w,
            float w_scale,
            llama_expert_gating_func_type gating_op,
            int il,
            ggml_tensor * probs_in = nullptr) const;

    ggml_tensor * build_moe_ffn(
            ggml_tensor * cur,
            ggml_tensor * gate_inp,
            ggml_tensor * gate_inp_b,
            ggml_tensor * up_exps,
            ggml_tensor * up_exps_b,
            ggml_tensor * gate_exps,
            ggml_tensor * gate_exps_b,
            ggml_tensor * down_exps,
            ggml_tensor * down_exps_b,
            ggml_tensor * exp_probs_b,
            int64_t n_expert,
            int64_t n_expert_used,
            llm_ffn_op_type type_op,
            bool norm_w,
            bool scale_w,
            float w_scale,
            llama_expert_gating_func_type gating_op,
            int il,
            ggml_tensor * probs_in = nullptr) const;

    //
    // inputs
    //

    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
    ggml_tensor * build_inp_pos() const;
    ggml_tensor * build_inp_attn_scale() const;
    ggml_tensor * build_inp_out_ids() const;
    ggml_tensor * build_inp_mean() const;
    ggml_tensor * build_inp_cls() const;

    ggml_tensor * build_inp_cross_embd() const;
    ggml_tensor * build_inp_pos_bucket_enc() const;
    ggml_tensor * build_inp_pos_bucket_dec() const;
    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

    //
    // attention
    //

    ggml_tensor * build_attn_mha(
            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
            ggml_tensor * kq_b,
            ggml_tensor * kq_mask,
            ggml_tensor * sinks,   // [n_head_q]
            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;

    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_no_cache * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;

    llm_graph_input_attn_kv * build_attn_inp_kv() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_kv * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;
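
    // example: a rough sketch of how a decoder layer might use the KV-cache variant
    // (illustrative only - Qcur/Kcur/Vcur, layer.wo and n_embd_head are hypothetical placeholders):
    //
    //   auto * inp_attn = build_attn_inp_kv(); // once per graph
    //   ...
    //   // per layer, with Qcur/Kcur/Vcur shaped [n_embd_head, n_head(_kv), n_tokens]:
    //   cur = build_attn(inp_attn,
    //           layer.wo, NULL,
    //           Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);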

    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;

    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
    ggml_tensor * build_attn(
            llm_graph_input_attn_kv_iswa * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;

    llm_graph_input_attn_cross * build_attn_inp_cross() const;

    ggml_tensor * build_attn(
            llm_graph_input_attn_cross * inp,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * sinks, // [n_head_q]
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;

    //
    // recurrent
    //

    // TODO: move this implementation to llama_memory_recurrent.
    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
    //       `llama_memory_recurrent`
    ggml_tensor * build_rs(
            ggml_tensor * s,
            ggml_tensor * state_copy_main,
            ggml_tensor * state_copy_extra,
            int32_t state_size,
            int32_t n_seqs,
            uint32_t n_rs,
            uint32_t rs_head,
            uint32_t rs_size,
            int32_t rs_zero,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    llm_graph_input_rs * build_rs_inp() const;

    ggml_tensor * build_rs(
            llm_graph_input_rs * inp,
            ggml_tensor * s,
            int32_t state_size,
            int32_t n_seqs,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    ggml_tensor * build_rwkv_token_shift_load(
            llm_graph_input_rs * inp,
            const llama_ubatch & ubatch,
            int il) const;

    ggml_tensor * build_rwkv_token_shift_store(
            ggml_tensor * token_shift,
            const llama_ubatch & ubatch,
            int il) const;

    //
    // hybrid
    //

    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;

    //
    // pooling
    //

    void build_pooling(
            ggml_tensor * cls,
            ggml_tensor * cls_b,
            ggml_tensor * cls_out,
            ggml_tensor * cls_out_b) const;

    //
    // dense (out)
    //

    void build_dense_out(
            ggml_tensor * dense_2,
            ggml_tensor * dense_3) const;
};

// TODO: better name
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
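
// example: a hedged usage sketch - when building a T5-style relative attention bias, each pair of
// token positions is mapped to a bucket index along these lines (pos_a/pos_b are hypothetical
// placeholders; the exact bucketing scheme is defined by the implementation in llama-graph.cpp):
//
//   const int32_t bucket = llama_relative_position_bucket(pos_a, pos_b, n_buckets, /*bidirectional =*/ true);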