1#pragma once
2
3#include "../llama-model.h"
4#include "../llama-graph.h"
5#include "../llama-memory-recurrent.h"
6
7#include <cmath>
8
9struct llm_graph_context_mamba : public llm_graph_context {
10 llm_graph_context_mamba(const llm_graph_params & params);
11
12 virtual ~llm_graph_context_mamba() = default;
13
14 ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
15 ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const;
16
17};
18
19// Base class for RWKV-related models
20struct llm_build_rwkv6_base : public llm_graph_context {
21 const llama_model & model;
22
23 llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params);
24
25 virtual ~llm_build_rwkv6_base() = default;
26
27 ggml_tensor * build_rwkv6_channel_mix(const llama_layer * layer,
28 ggml_tensor * cur,
29 ggml_tensor * x_prev,
30 llm_arch arch) const;
31
32 ggml_tensor * build_rwkv6_time_mix(llm_graph_input_rs * inp,
33 ggml_tensor * cur,
34 ggml_tensor * x_prev,
35 const llama_ubatch & ubatch,
36 int il) const;
37};
38
39// Base class for RWKV7-related models
40struct llm_build_rwkv7_base : public llm_graph_context {
41 const llama_model & model;
42
43 llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params);
44
45 virtual ~llm_build_rwkv7_base() = default;
46
47 // RWKV7-specific graph building methods
48 ggml_tensor * build_rwkv7_channel_mix(const llama_layer * layer,
49 ggml_tensor * cur,
50 ggml_tensor * x_prev,
51 llm_arch arch) const;
52 ggml_tensor * build_rwkv7_time_mix(llm_graph_input_rs * inp,
53 ggml_tensor * cur,
54 ggml_tensor * x_prev,
55 ggml_tensor *& first_layer_value,
56 const llama_ubatch & ubatch,
57 int il) const;
58};
59
60struct llm_build_apertus : public llm_graph_context {
61 llm_build_apertus(const llama_model & model, const llm_graph_params & params);
62};
63
64struct llm_build_arcee : public llm_graph_context {
65 llm_build_arcee(const llama_model & model, const llm_graph_params & params);
66};
67
68struct llm_build_arctic : public llm_graph_context {
69 llm_build_arctic(const llama_model & model, const llm_graph_params & params);
70};
71
72struct llm_build_arwkv7 : public llm_build_rwkv7_base {
73 llm_build_arwkv7(const llama_model & model, const llm_graph_params & params);
74};
75
76struct llm_build_baichuan : public llm_graph_context {
77 llm_build_baichuan(const llama_model & model, const llm_graph_params & params);
78};
79
80struct llm_build_bailingmoe2 : public llm_graph_context {
81 llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params);
82};
83
84struct llm_build_bailingmoe : public llm_graph_context {
85 llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params);
86};
87
88struct llm_build_bert : public llm_graph_context {
89 llm_build_bert(const llama_model & model, const llm_graph_params & params);
90};
91
92struct llm_build_bitnet : public llm_graph_context {
93 llm_build_bitnet(const llama_model & model, const llm_graph_params & params);
94};
95
96struct llm_build_bloom : public llm_graph_context {
97 llm_build_bloom(const llama_model & model, const llm_graph_params & params);
98};
99
100struct llm_build_chameleon : public llm_graph_context {
101 llm_build_chameleon(const llama_model & model, const llm_graph_params & params);
102};
103
104struct llm_build_chatglm : public llm_graph_context {
105 llm_build_chatglm(const llama_model & model, const llm_graph_params & params);
106};
107
108struct llm_build_codeshell : public llm_graph_context {
109 llm_build_codeshell(const llama_model & model, const llm_graph_params & params);
110};
111
112struct llm_build_cogvlm : public llm_graph_context {
113 llm_build_cogvlm(const llama_model & model, const llm_graph_params & params);
114};
115
116struct llm_build_cohere2_iswa : public llm_graph_context {
117 llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params);
118};
119
120struct llm_build_command_r : public llm_graph_context {
121 llm_build_command_r(const llama_model & model, const llm_graph_params & params);
122};
123
124struct llm_build_dbrx : public llm_graph_context {
125 llm_build_dbrx(const llama_model & model, const llm_graph_params & params);
126};
127
128struct llm_build_deci : public llm_graph_context {
129 llm_build_deci(const llama_model & model, const llm_graph_params & params);
130};
131
132struct llm_build_deepseek2 : public llm_graph_context {
133 llm_build_deepseek2(const llama_model & model, const llm_graph_params & params);
134};
135
136struct llm_build_deepseek : public llm_graph_context {
137 llm_build_deepseek(const llama_model & model, const llm_graph_params & params);
138};
139
140struct llm_build_dots1 : public llm_graph_context {
141 llm_build_dots1(const llama_model & model, const llm_graph_params & params);
142};
143
144struct llm_build_dream : public llm_graph_context {
145 llm_build_dream(const llama_model & model, const llm_graph_params & params);
146};
147
148struct llm_build_ernie4_5 : public llm_graph_context {
149 llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params);
150};
151
152struct llm_build_ernie4_5_moe : public llm_graph_context {
153 llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params);
154};
155
156template <bool iswa>
157struct llm_build_exaone4 : public llm_graph_context {
158 llm_build_exaone4(const llama_model & model, const llm_graph_params & params);
159};
160
161struct llm_build_exaone : public llm_graph_context {
162 llm_build_exaone(const llama_model & model, const llm_graph_params & params);
163};
164
165struct llm_build_falcon : public llm_graph_context {
166 llm_build_falcon(const llama_model & model, const llm_graph_params & params);
167};
168
169struct llm_build_falcon_h1 : public llm_graph_context_mamba {
170 llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params);
171};
172
173struct llm_build_gemma2_iswa : public llm_graph_context {
174 llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params);
175};
176
177struct llm_build_gemma3_iswa : public llm_graph_context {
178 llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params);
179};
180
181struct llm_build_gemma3n_iswa : public llm_graph_context {
182 const llama_model & model;
183
184 const int64_t n_embd_head;
185 const int64_t n_embd_altup;
186 const int64_t n_altup;
187 const int i_altup_act;
188 const int n_layer_sparsity = 10; // number of layers using activation sparsity
189 const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
190
191 llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params);
192 ggml_tensor * calc_magnitude(ggml_tensor * x);
193 ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
194 ggml_tensor * get_per_layer_inputs();
195 ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
196 ggml_tensor * gaussian_topk(ggml_tensor * x);
197 ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il);
198 ggml_tensor * altup_predict(ggml_tensor * cur, int il);
199 ggml_tensor * laurel(ggml_tensor * cur, int il);
200 ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il);
201};
202
203struct llm_build_gemma_embedding : public llm_graph_context {
204 llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params);
205};
206
207struct llm_build_gemma : public llm_graph_context {
208 llm_build_gemma(const llama_model & model, const llm_graph_params & params);
209};
210
211struct llm_build_glm4 : public llm_graph_context {
212 llm_build_glm4(const llama_model & model, const llm_graph_params & params);
213};
214
215struct llm_build_glm4_moe : public llm_graph_context {
216 llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params);
217};
218
219struct llm_build_gpt2 : public llm_graph_context {
220 llm_build_gpt2(const llama_model & model, const llm_graph_params & params);
221};
222
223struct llm_build_gptneox : public llm_graph_context {
224 llm_build_gptneox(const llama_model & model, const llm_graph_params & params);
225};
226
227struct llm_build_granite : public llm_graph_context {
228 llm_build_granite(const llama_model & model, const llm_graph_params & params);
229
230private:
231 ggml_tensor * build_attention_layer(
232 ggml_tensor * cur,
233 ggml_tensor * inp_pos,
234 llm_graph_input_attn_kv * inp_attn,
235 const llama_model & model,
236 const int64_t n_embd_head,
237 const int il);
238
239 ggml_tensor * build_layer_ffn(
240 ggml_tensor * cur,
241 ggml_tensor * inpSA,
242 const llama_model & model,
243 const int il);
244};
245
246struct llm_build_granite_hybrid : public llm_graph_context_mamba {
247 llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params);
248 ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il);
249 ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn,
250 const llama_model & model,const int64_t n_embd_head, const int il);
251};
252
253struct llm_build_grok : public llm_graph_context {
254 llm_build_grok(const llama_model & model, const llm_graph_params & params);
255};
256
257struct llm_build_grovemoe : public llm_graph_context {
258 llm_build_grovemoe(const llama_model & model, const llm_graph_params & params);
259};
260
261struct llm_build_hunyuan_dense : public llm_graph_context {
262 llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params);
263};
264
265struct llm_build_hunyuan_moe : public llm_graph_context {
266 llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params);
267};
268
269struct llm_build_internlm2 : public llm_graph_context {
270 llm_build_internlm2(const llama_model & model, const llm_graph_params & params);
271};
272
273struct llm_build_jais : public llm_graph_context {
274 llm_build_jais(const llama_model & model, const llm_graph_params & params);
275};
276
277struct llm_build_jamba : public llm_graph_context_mamba {
278 llm_build_jamba(const llama_model & model, const llm_graph_params & params);
279};
280
281struct llm_build_lfm2 : public llm_graph_context {
282 const llama_model & model;
283
284 llm_build_lfm2(const llama_model & model, const llm_graph_params & params);
285 ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, int il) const;
286 ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, int il) const;
287 ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, int il) const;
288 ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il);
289
290};
291
292struct llm_build_llada : public llm_graph_context {
293 llm_build_llada(const llama_model & model, const llm_graph_params & params);
294};
295
296struct llm_build_llada_moe : public llm_graph_context {
297 llm_build_llada_moe(const llama_model & model, const llm_graph_params & params);
298};
299
300struct llm_build_llama : public llm_graph_context {
301 llm_build_llama(const llama_model & model, const llm_graph_params & params);
302};
303
304struct llm_build_llama_iswa : public llm_graph_context {
305 llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params);
306};
307
308struct llm_build_mamba : public llm_graph_context_mamba {
309 llm_build_mamba(const llama_model & model, const llm_graph_params & params);
310};
311
312struct llm_build_minicpm3 : public llm_graph_context {
313 llm_build_minicpm3(const llama_model & model, const llm_graph_params & params);
314};
315
316struct llm_build_minimax_m2 : public llm_graph_context {
317 llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params);
318};
319
320struct llm_build_mpt : public llm_graph_context {
321 llm_build_mpt(const llama_model & model, const llm_graph_params & params);
322};
323
324struct llm_build_nemotron : public llm_graph_context {
325 llm_build_nemotron(const llama_model & model, const llm_graph_params & params);
326};
327
328struct llm_build_nemotron_h : public llm_graph_context_mamba {
329 llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params);
330 ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il);
331 ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn,
332 const llama_model & model, const int64_t n_embd_head, const int il);
333};
334
335struct llm_build_neo_bert : public llm_graph_context {
336 llm_build_neo_bert(const llama_model & model, const llm_graph_params & params);
337};
338
339template <bool iswa>
340struct llm_build_olmo2 : public llm_graph_context {
341 llm_build_olmo2(const llama_model & model, const llm_graph_params & params);
342};
343
344struct llm_build_olmoe : public llm_graph_context {
345 llm_build_olmoe(const llama_model & model, const llm_graph_params & params);
346};
347
348struct llm_build_olmo : public llm_graph_context {
349 llm_build_olmo(const llama_model & model, const llm_graph_params & params);
350};
351
352struct llm_build_openai_moe_iswa : public llm_graph_context {
353 llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params);
354};
355
356struct llm_build_openelm : public llm_graph_context {
357 llm_build_openelm(const llama_model & model, const llm_graph_params & params);
358};
359
360struct llm_build_orion : public llm_graph_context {
361 llm_build_orion(const llama_model & model, const llm_graph_params & params);
362};
363
364struct llm_build_pangu_embedded : public llm_graph_context {
365 llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params);
366};
367
368struct llm_build_phi2 : public llm_graph_context {
369 llm_build_phi2(const llama_model & model, const llm_graph_params & params);
370};
371
372template<bool iswa>
373struct llm_build_phi3 : public llm_graph_context {
374 llm_build_phi3(const llama_model & model, const llm_graph_params & params);
375};
376
377struct llm_build_plamo2 : public llm_graph_context_mamba {
378 llm_build_plamo2(const llama_model & model, const llm_graph_params & params);
379 private:
380 ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
381 ggml_tensor * build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur,
382 const llama_model & model, int il);
383};
384
385struct llm_build_plamo : public llm_graph_context {
386 llm_build_plamo(const llama_model & model, const llm_graph_params & params);
387};
388
389struct llm_build_plm : public llm_graph_context {
390 llm_build_plm(const llama_model & model, const llm_graph_params & params);
391};
392
393struct llm_build_qwen2 : public llm_graph_context {
394 llm_build_qwen2(const llama_model & model, const llm_graph_params & params);
395};
396
397struct llm_build_qwen2moe : public llm_graph_context {
398 llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params);
399};
400
401struct llm_build_qwen2vl : public llm_graph_context {
402 llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params);
403};
404
405struct llm_build_qwen3 : public llm_graph_context {
406 llm_build_qwen3(const llama_model & model, const llm_graph_params & params);
407};
408
409struct llm_build_qwen3moe : public llm_graph_context {
410 llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params);
411};
412
413struct llm_build_qwen3vl : public llm_graph_context {
414 llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params);
415};
416
417struct llm_build_qwen3vlmoe : public llm_graph_context {
418 llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
419};
420
421
422struct llm_build_qwen : public llm_graph_context {
423 llm_build_qwen(const llama_model & model, const llm_graph_params & params);
424};
425
426struct llm_build_refact : public llm_graph_context {
427 llm_build_refact(const llama_model & model, const llm_graph_params & params);
428};
429
430struct llm_build_rwkv6 : public llm_build_rwkv6_base {
431 llm_build_rwkv6(const llama_model & model, const llm_graph_params & params);
432};
433
434struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
435 llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params);
436};
437
438struct llm_build_rwkv7 : public llm_build_rwkv7_base {
439 llm_build_rwkv7(const llama_model & model, const llm_graph_params & params);
440};
441
442struct llm_build_seed_oss : public llm_graph_context {
443 llm_build_seed_oss(const llama_model & model, const llm_graph_params & params);
444};
445
446template <bool iswa>
447struct llm_build_smallthinker : public llm_graph_context {
448 llm_build_smallthinker(const llama_model & model, const llm_graph_params & params);
449};
450
451struct llm_build_smollm3 : public llm_graph_context {
452 llm_build_smollm3(const llama_model & model, const llm_graph_params & params);
453};
454
455struct llm_build_stablelm : public llm_graph_context {
456 llm_build_stablelm(const llama_model & model, const llm_graph_params & params);
457};
458
459struct llm_build_starcoder2 : public llm_graph_context {
460 llm_build_starcoder2(const llama_model & model, const llm_graph_params & params);
461};
462
463struct llm_build_starcoder : public llm_graph_context {
464 llm_build_starcoder(const llama_model & model, const llm_graph_params & params);
465};
466
467struct llm_build_t5_dec : public llm_graph_context {
468 llm_build_t5_dec(const llama_model & model, const llm_graph_params & params);
469};
470
471struct llm_build_t5_enc : public llm_graph_context {
472 llm_build_t5_enc(const llama_model & model, const llm_graph_params & params);
473};
474
475struct llm_build_wavtokenizer_dec : public llm_graph_context {
476 llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params);
477};
478
479struct llm_build_xverse : public llm_graph_context {
480 llm_build_xverse(const llama_model & model, const llm_graph_params & params);
481};
482