| 1 | #pragma once |
| 2 | |
| 3 | #include "llama-arch.h" |
| 4 | #include "llama-batch.h" |
| 5 | #include "llama-hparams.h" |
| 6 | #include "llama-adapter.h" |
| 7 | |
#include <cstdint>
#include <cstdlib> // getenv, atoi
| 9 | #include <vector> |
| 10 | #include <memory> |
| 11 | #include <set> |
| 12 | #include <functional> |
| 13 | |
| 14 | struct ggml_cgraph; |
| 15 | struct ggml_context; |
| 16 | struct ggml_tensor; |
| 17 | |
| 18 | struct llama_cparams; |
| 19 | |
| 20 | struct llama_memory_context_i; |
| 21 | |
| 22 | class llama_kv_cache_context; |
| 23 | class llama_kv_cache_iswa_context; |
| 24 | class llama_memory_recurrent_context; |
| 25 | class llama_memory_hybrid_context; |
| 26 | |
| 27 | // certain models (typically multi-modal) can produce different types of graphs |
| 28 | enum llm_graph_type { |
| 29 | LLM_GRAPH_TYPE_DEFAULT, |
| 30 | LLM_GRAPH_TYPE_ENCODER, |
| 31 | LLM_GRAPH_TYPE_DECODER, |
| 32 | }; |
| 33 | |
| 34 | enum llm_ffn_op_type { |
| 35 | LLM_FFN_SILU, |
| 36 | LLM_FFN_GELU, |
| 37 | LLM_FFN_RELU, |
| 38 | LLM_FFN_RELU_SQR, |
| 39 | LLM_FFN_SWIGLU, |
| 40 | LLM_FFN_GEGLU, |
| 41 | LLM_FFN_REGLU, |
| 42 | LLM_FFN_SWIGLU_OAI_MOE, |
| 43 | }; |
| 44 | |
| 45 | enum llm_ffn_gate_type { |
| 46 | LLM_FFN_SEQ, |
| 47 | LLM_FFN_PAR, // ffn_gate is parallel to ffn_up |
| 48 | }; |
| 49 | |
| 50 | enum llm_norm_type { |
| 51 | LLM_NORM, |
| 52 | LLM_NORM_RMS, |
| 53 | LLM_NORM_GROUP, |
| 54 | }; |
| 55 | |
| 56 | // TODO: tmp - need something better to pass the data from the encoder to the decoder |
| 57 | struct llama_cross { |
| 58 | // the output embeddings from the encoder as a ggml tensor |
| 59 | // TODO: this needs more work to be correct, for now copy the embeddings data to host memory |
| 60 | // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524 |
| 61 | //ggml_tensor * t_embd = nullptr; |
| 62 | |
| 63 | int64_t n_embd = 0; |
| 64 | int64_t n_enc = 0; |
| 65 | |
| 66 | // embeddings data copied to host memory (tmp) |
| 67 | std::vector<float> v_embd; |
| 68 | |
| 69 | // needed to construct the cross-attention mask in the decoder |
| 70 | std::vector<std::set<llama_seq_id>> seq_ids_enc; |
| 71 | }; |
| 72 | |
| 73 | struct llm_graph_params; |
| 74 | |
| 75 | // |
| 76 | // llm_graph_input |
| 77 | // |
| 78 | |
| 79 | class llm_graph_input_i { |
| 80 | public: |
| 81 | llm_graph_input_i() { |
const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
| 84 | } |
| 85 | |
| 86 | virtual ~llm_graph_input_i() = default; |
| 87 | |
| 88 | virtual void set_input(const llama_ubatch * ubatch) = 0; |
| 89 | |
| 90 | // return true if the resulting input tensors using the provided graph parameters would be |
| 91 | // the same as the previous input tensors that we have currently stored in the object |
| 92 | virtual bool can_reuse(const llm_graph_params & params) { |
// returning false here by default prevents the graph from being reused if the check
// for this input type has not been implemented yet
| 95 | GGML_UNUSED(params); |
| 96 | return false; |
| 97 | } |
| 98 | protected: |
| 99 | // env: LLAMA_GRAPH_INPUT_DEBUG |
| 100 | int debug = 0; |
| 101 | }; |
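
// illustrative sketch (not part of the API): a concrete input type is expected to override
// set_input() to fill its tensor(s) from the current ubatch, and optionally can_reuse() to
// opt in to graph reuse. the class and member names below are hypothetical.
//
//   class llm_graph_input_example : public llm_graph_input_i {
//   public:
//       void set_input(const llama_ubatch * ubatch) override {
//           // copy the per-token data for this ubatch into the backend buffer of `inp`
//           // (e.g. via ggml_backend_tensor_set)
//       }
//
//       bool can_reuse(const llm_graph_params & params) override {
//           // reusable as long as the tensor shape matches the new batch size
//           return inp != nullptr && inp->ne[0] == params.ubatch.n_tokens;
//       }
//
//       ggml_tensor * inp = nullptr; // I32 [n_batch]
//   };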
| 102 | |
| 103 | using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>; |
| 104 | |
| 105 | class llm_graph_input_embd : public llm_graph_input_i { |
| 106 | public: |
| 107 | llm_graph_input_embd() = default; |
| 108 | virtual ~llm_graph_input_embd() = default; |
| 109 | |
| 110 | void set_input(const llama_ubatch * ubatch) override; |
| 111 | |
| 112 | bool can_reuse(const llm_graph_params & params) override; |
| 113 | |
| 114 | ggml_tensor * tokens = nullptr; // I32 [n_batch] |
| 115 | ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] |
| 116 | }; |
| 117 | |
| 118 | class llm_graph_input_pos : public llm_graph_input_i { |
| 119 | public: |
| 120 | llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {} |
| 121 | virtual ~llm_graph_input_pos() = default; |
| 122 | |
| 123 | void set_input(const llama_ubatch * ubatch) override; |
| 124 | |
| 125 | bool can_reuse(const llm_graph_params & params) override; |
| 126 | |
| 127 | ggml_tensor * pos = nullptr; // I32 [n_batch] |
| 128 | |
| 129 | const uint32_t n_pos_per_embd = 1; |
| 130 | }; |
| 131 | |
| 132 | // temperature tuning, used by llama4 |
| 133 | class llm_graph_input_attn_temp : public llm_graph_input_i { |
| 134 | public: |
| 135 | llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale) |
| 136 | : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {} |
| 137 | virtual ~llm_graph_input_attn_temp() = default; |
| 138 | |
| 139 | void set_input(const llama_ubatch * ubatch) override; |
| 140 | |
| 141 | ggml_tensor * attn_scale = nullptr; // F32 [n_batch] |
| 142 | |
| 143 | const uint32_t n_attn_temp_floor_scale; |
| 144 | const float f_attn_temp_scale; |
| 145 | }; |
| 146 | |
| 147 | class llm_graph_input_pos_bucket : public llm_graph_input_i { |
| 148 | public: |
| 149 | llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {} |
| 150 | virtual ~llm_graph_input_pos_bucket() = default; |
| 151 | |
| 152 | void set_input(const llama_ubatch * ubatch) override; |
| 153 | |
| 154 | ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch] |
| 155 | |
| 156 | const llama_hparams hparams; |
| 157 | }; |
| 158 | |
| 159 | class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { |
| 160 | public: |
| 161 | llm_graph_input_pos_bucket_kv( |
| 162 | const llama_hparams & hparams, |
| 163 | const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {} |
| 164 | virtual ~llm_graph_input_pos_bucket_kv() = default; |
| 165 | |
| 166 | void set_input(const llama_ubatch * ubatch) override; |
| 167 | |
| 168 | ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] |
| 169 | |
| 170 | const llama_hparams hparams; |
| 171 | |
| 172 | const llama_kv_cache_context * mctx; |
| 173 | }; |
| 174 | |
| 175 | class llm_graph_input_out_ids : public llm_graph_input_i { |
| 176 | public: |
| 177 | llm_graph_input_out_ids( |
| 178 | const llama_hparams & hparams, |
| 179 | const llama_cparams & cparams, |
| 180 | uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {} |
| 181 | virtual ~llm_graph_input_out_ids() = default; |
| 182 | |
| 183 | void set_input(const llama_ubatch * ubatch) override; |
| 184 | |
| 185 | bool can_reuse(const llm_graph_params & params) override; |
| 186 | |
| 187 | ggml_tensor * out_ids; // I32 [n_outputs] |
| 188 | |
| 189 | const llama_hparams hparams; |
| 190 | const llama_cparams cparams; |
| 191 | |
| 192 | const uint32_t n_outputs; |
| 193 | }; |
| 194 | |
| 195 | class llm_graph_input_mean : public llm_graph_input_i { |
| 196 | public: |
| 197 | llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {} |
| 198 | virtual ~llm_graph_input_mean() = default; |
| 199 | |
| 200 | void set_input(const llama_ubatch * ubatch) override; |
| 201 | |
| 202 | ggml_tensor * mean; // F32 [n_batch, n_batch] |
| 203 | |
| 204 | const llama_cparams cparams; |
| 205 | }; |
| 206 | |
| 207 | class llm_graph_input_cls : public llm_graph_input_i { |
| 208 | public: |
| 209 | llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {} |
| 210 | virtual ~llm_graph_input_cls() = default; |
| 211 | |
| 212 | void set_input(const llama_ubatch * ubatch) override; |
| 213 | |
| 214 | ggml_tensor * cls; // I32 [n_batch] |
| 215 | |
| 216 | const llama_cparams cparams; |
| 217 | const llm_arch arch; |
| 218 | }; |
| 219 | |
| 220 | class llm_graph_input_rs : public llm_graph_input_i { |
| 221 | public: |
| 222 | llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {} |
| 223 | virtual ~llm_graph_input_rs() = default; |
| 224 | |
| 225 | void set_input(const llama_ubatch * ubatch) override; |
| 226 | |
| 227 | ggml_tensor * s_copy; // I32 [n_rs] |
| 228 | |
| 229 | // views of s_copy, computed once per graph |
| 230 | // and shared across layers which use build_rs |
| 231 | ggml_tensor * s_copy_main; // I32 [n_seqs] |
ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
| 233 | |
| 234 | const llama_memory_recurrent_context * mctx; |
| 235 | }; |
| 236 | |
| 237 | class llm_graph_input_cross_embd : public llm_graph_input_i { |
| 238 | public: |
| 239 | llm_graph_input_cross_embd( |
| 240 | const llama_cross * cross) : cross(cross) {} |
| 241 | virtual ~llm_graph_input_cross_embd() = default; |
| 242 | |
| 243 | void set_input(const llama_ubatch * ubatch) override; |
| 244 | |
| 245 | ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc] |
| 246 | |
| 247 | const llama_cross * cross; |
| 248 | }; |
| 249 | |
| 250 | class llm_graph_input_attn_no_cache : public llm_graph_input_i { |
| 251 | public: |
| 252 | llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) : |
| 253 | hparams(hparams), |
| 254 | cparams(cparams) { |
| 255 | } |
| 256 | ~llm_graph_input_attn_no_cache() = default; |
| 257 | |
| 258 | void set_input(const llama_ubatch * ubatch) override; |
| 259 | |
| 260 | ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } |
| 261 | ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } |
| 262 | |
| 263 | // n_tokens == n_batch |
| 264 | ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] |
| 265 | ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] |
| 266 | ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] |
| 267 | ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] |
| 268 | |
| 269 | const llama_hparams hparams; |
| 270 | const llama_cparams cparams; |
| 271 | }; |
| 272 | |
| 273 | class llm_graph_input_attn_kv : public llm_graph_input_i { |
| 274 | public: |
| 275 | llm_graph_input_attn_kv( |
| 276 | const llama_hparams & hparams, |
| 277 | const llama_cparams & cparams, |
| 278 | const llama_kv_cache_context * mctx) : |
| 279 | hparams(hparams), |
| 280 | cparams(cparams), |
| 281 | mctx(mctx) { |
| 282 | } |
| 283 | ~llm_graph_input_attn_kv() = default; |
| 284 | |
| 285 | void set_input(const llama_ubatch * ubatch) override; |
| 286 | |
| 287 | bool can_reuse(const llm_graph_params & params) override; |
| 288 | |
| 289 | ggml_tensor * get_k_idxs() const { return self_k_idxs; } |
| 290 | ggml_tensor * get_v_idxs() const { return self_v_idxs; } |
| 291 | |
| 292 | ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } |
| 293 | |
| 294 | ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] |
| 295 | ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] |
| 296 | |
| 297 | ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] |
| 298 | ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] |
| 299 | |
| 300 | // note: these have to be copies because in order to be able to reuse a graph, its inputs |
| 301 | // need to carry these parameters with them. otherwise, they can point to freed |
| 302 | // llm_graph_params from a previous batch, causing stack-use-after-return |
| 303 | const llama_hparams hparams; |
| 304 | const llama_cparams cparams; |
| 305 | |
| 306 | const llama_kv_cache_context * mctx; |
| 307 | }; |
| 308 | |
| 309 | class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { |
| 310 | public: |
| 311 | llm_graph_input_attn_kv_iswa( |
| 312 | const llama_hparams & hparams, |
| 313 | const llama_cparams & cparams, |
| 314 | const llama_kv_cache_iswa_context * mctx) : |
| 315 | hparams(hparams), |
| 316 | cparams(cparams), |
| 317 | mctx(mctx) { |
| 318 | } |
| 319 | ~llm_graph_input_attn_kv_iswa() = default; |
| 320 | |
| 321 | void set_input(const llama_ubatch * ubatch) override; |
| 322 | |
| 323 | bool can_reuse(const llm_graph_params & params) override; |
| 324 | |
| 325 | ggml_tensor * get_k_idxs() const { return self_k_idxs; } |
| 326 | ggml_tensor * get_v_idxs() const { return self_v_idxs; } |
| 327 | ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; } |
| 328 | ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; } |
| 329 | |
| 330 | ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } |
| 331 | ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } |
| 332 | |
| 333 | ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] |
| 334 | ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] |
| 335 | ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] |
| 336 | ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] |
| 337 | |
| 338 | ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] |
| 339 | ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] |
| 340 | ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] |
| 341 | ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] |
| 342 | |
| 343 | const llama_hparams hparams; |
| 344 | const llama_cparams cparams; |
| 345 | |
| 346 | const llama_kv_cache_iswa_context * mctx; |
| 347 | }; |
| 348 | |
| 349 | class llm_graph_input_attn_cross : public llm_graph_input_i { |
| 350 | public: |
| 351 | llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} |
| 352 | ~llm_graph_input_attn_cross() = default; |
| 353 | |
| 354 | void set_input(const llama_ubatch * ubatch) override; |
| 355 | |
| 356 | ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } |
| 357 | |
| 358 | ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] |
| 359 | ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] |
| 360 | |
| 361 | const llama_cross * cross = nullptr; |
| 362 | }; |
| 363 | |
| 364 | class llm_graph_input_mem_hybrid : public llm_graph_input_i { |
| 365 | public: |
llm_graph_input_mem_hybrid(
| 367 | std::unique_ptr<llm_graph_input_attn_kv> inp_attn, |
| 368 | std::unique_ptr<llm_graph_input_rs> inp_rs, |
| 369 | const llama_memory_hybrid_context * mctx) : |
| 370 | inp_attn(std::move(inp_attn)), |
| 371 | inp_rs(std::move(inp_rs)), |
| 372 | mctx(mctx) { } |
| 373 | virtual ~llm_graph_input_mem_hybrid() = default; |
| 374 | |
| 375 | void set_input(const llama_ubatch * ubatch) override; |
| 376 | |
| 377 | std::unique_ptr<llm_graph_input_attn_kv> inp_attn; |
| 378 | std::unique_ptr<llm_graph_input_rs> inp_rs; |
| 379 | |
| 380 | llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); } |
| 381 | llm_graph_input_rs * get_recr() const { return inp_rs.get(); } |
| 382 | |
| 383 | const llama_memory_hybrid_context * mctx; |
| 384 | }; |
| 385 | |
| 386 | // |
| 387 | // llm_graph_result |
| 388 | // |
| 389 | |
| 390 | // these objects deliver the result from the graph build process back to the llama_context |
| 391 | // note that the input tensors created for the graph are referenced here - the goal is to be able to populate their |
// specific data by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters
| 395 | |
| 396 | // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) |
| 397 | using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>; |
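
// illustrative example of such a callback (an assumption, not the callback used by llama_context):
//
//   llm_graph_cb cb = [](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
//       if (il >= 0) {
//           ggml_format_name(cur, "%s-%d", name, il);
//       } else {
//           ggml_set_name(cur, name);
//       }
//       // a real callback would typically also decide tensor placement here, e.g. via
//       // ggml_backend_sched_set_tensor_backend()
//   };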
| 398 | |
| 399 | class llm_graph_result; |
| 400 | |
| 401 | struct llm_graph_params { |
| 402 | llm_arch arch = LLM_ARCH_UNKNOWN; |
| 403 | |
| 404 | llama_hparams hparams; |
| 405 | llama_cparams cparams; |
| 406 | |
| 407 | llama_ubatch ubatch; // note: intentionally make a copy |
| 408 | |
| 409 | llm_graph_type gtype; |
| 410 | |
| 411 | ggml_backend_sched_t sched; |
| 412 | ggml_backend_t backend_cpu; |
| 413 | |
| 414 | const llama_adapter_cvec * cvec; |
| 415 | const llama_adapter_loras * loras; |
| 416 | const llama_memory_context_i * mctx; |
| 417 | const llama_cross * cross; |
| 418 | |
| 419 | uint32_t n_outputs; |
| 420 | |
| 421 | llm_graph_cb cb; |
| 422 | |
| 423 | llm_graph_result * res; |
| 424 | |
| 425 | // return true if the "other" params would result in a graph with the same topology as with the current params |
| 426 | // having the same topology allows us to reuse the graph in some cases |
| 427 | bool allow_reuse(const llm_graph_params & other) const { |
| 428 | // first check the ubatch |
| 429 | bool can_reuse_ubatch = |
| 430 | ubatch.equal_seqs() == other.ubatch.equal_seqs() && |
| 431 | ubatch.n_tokens == other.ubatch.n_tokens && |
| 432 | ubatch.n_seq_tokens == other.ubatch.n_seq_tokens && |
| 433 | ubatch.n_seqs == other.ubatch.n_seqs && |
| 434 | ubatch.n_seqs_unq == other.ubatch.n_seqs_unq && |
| 435 | ( |
| 436 | (!ubatch.token && !other.ubatch.token) || |
| 437 | (!ubatch.embd && !other.ubatch.embd) |
| 438 | ); |
| 439 | |
// when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
// the reason is that the set of attention streams would be different for different sequences
| 442 | if (can_reuse_ubatch && ubatch.equal_seqs()) { |
| 443 | if (!ubatch.data) { |
// if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
// therefore we cannot perform the sequence id check. normally this should never happen
| 446 | can_reuse_ubatch = false; |
| 447 | } else { |
| 448 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
| 449 | can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s]; |
| 450 | } |
| 451 | } |
| 452 | } |
| 453 | |
| 454 | if (!can_reuse_ubatch) { |
| 455 | return false; |
| 456 | } |
| 457 | |
| 458 | return |
| 459 | cparams.embeddings == other.cparams.embeddings && |
| 460 | cparams.causal_attn == other.cparams.causal_attn && |
| 461 | arch == other.arch && |
| 462 | gtype == other.gtype && |
| 463 | cvec == other.cvec && |
| 464 | loras == other.loras && |
| 465 | cross == other.cross && |
| 466 | n_outputs == other.n_outputs; |
| 467 | } |
| 468 | }; |
| 469 | |
| 470 | class llm_graph_result { |
| 471 | public: |
| 472 | llm_graph_result(int64_t max_nodes); |
| 473 | |
| 474 | virtual ~llm_graph_result() = default; |
| 475 | |
| 476 | ggml_tensor * get_tokens() const { return t_tokens; } |
| 477 | ggml_tensor * get_logits() const { return t_logits; } |
| 478 | ggml_tensor * get_embd() const { return t_embd; } |
| 479 | ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } |
| 480 | |
| 481 | ggml_cgraph * get_gf() const { return gf; } |
| 482 | ggml_context * get_ctx() const { return ctx_compute.get(); } |
| 483 | |
| 484 | int64_t get_max_nodes() const; |
| 485 | |
| 486 | void reset(); |
| 487 | |
| 488 | void set_inputs(const llama_ubatch * ubatch); |
| 489 | |
| 490 | // try to update the existing graph result using the new graph parameters in order to reuse it |
| 491 | // this can only be done if we determine that the resulting graph using the new graph parameters |
| 492 | // would be identical to the existing graph. in that case, we simply have to update the memory |
| 493 | // contexts of the input tensors of the graph and we can reuse it for another computation |
| 494 | // return true if the graph was updated and can be reused |
| 495 | bool can_reuse(const llm_graph_params & params); |
| 496 | |
| 497 | llm_graph_input_i * add_input(llm_graph_input_ptr input); |
| 498 | |
| 499 | void set_params(const llm_graph_params & params); |
| 500 | |
| 501 | // important graph nodes |
| 502 | ggml_tensor * t_tokens = nullptr; |
| 503 | ggml_tensor * t_logits = nullptr; |
| 504 | ggml_tensor * t_embd = nullptr; |
| 505 | ggml_tensor * t_embd_pooled = nullptr; |
| 506 | |
| 507 | std::vector<llm_graph_input_ptr> inputs; |
| 508 | |
| 509 | ggml_context_ptr ctx_compute; |
| 510 | |
| 511 | // memory buffers used to evaluate the model |
| 512 | std::vector<uint8_t> buf_compute_meta; |
| 513 | |
| 514 | ggml_cgraph * gf; |
| 515 | |
| 516 | int64_t max_nodes; |
| 517 | |
| 518 | private: |
| 519 | // keep a copy of the previous graph parameters |
| 520 | // we will use this to determine whether the graph can be reused by comparing them with the new parameters |
| 521 | // note: these are updated after constructing the new graph |
| 522 | llm_graph_params params; |
| 523 | |
| 524 | // env: LLAMA_GRAPH_RESULT_DEBUG |
| 525 | int debug = 0; |
| 526 | }; |
| 527 | |
| 528 | using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>; |
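
// illustrative usage sketch (an assumption about how a caller could drive the result object;
// the max_nodes value is hypothetical):
//
//   llm_graph_result_ptr res = std::make_unique<llm_graph_result>(/*max_nodes=*/8192);
//
//   if (!res->can_reuse(params)) {
//       res->reset();
//       res->set_params(params);
//       // ... build a new graph into res->get_ctx() / res->get_gf() ...
//   }
//
//   res->set_inputs(&ubatch);
//
//   // ... compute res->get_gf() with the backend scheduler, then read
//   //     res->get_logits(), res->get_embd(), etc.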
| 529 | |
| 530 | // |
| 531 | // llm_graph_context |
| 532 | // |
| 533 | |
| 534 | // used in build_rs to properly order writes and avoid unnecessary copies |
| 535 | using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>; |
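
// illustrative example (an assumption): callers can pass a custom row-gather function, for
// instance to force the gathered states to be contiguous:
//
//   llm_graph_get_rows_fn my_get_rows = [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
//       return ggml_cont(ctx, ggml_get_rows(ctx, states, ids));
//   };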
| 536 | |
| 537 | struct llm_graph_context { |
| 538 | const llm_arch arch; |
| 539 | |
| 540 | const llama_hparams & hparams; |
| 541 | const llama_cparams & cparams; |
| 542 | const llama_ubatch & ubatch; |
| 543 | |
| 544 | const int64_t n_embd; |
| 545 | const int64_t n_layer; |
| 546 | const int64_t n_rot; |
| 547 | const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) |
| 548 | const int64_t n_head; |
| 549 | const int64_t n_head_kv; |
| 550 | const int64_t n_embd_head_k; |
| 551 | const int64_t n_embd_k_gqa; |
| 552 | const int64_t n_embd_head_v; |
| 553 | const int64_t n_embd_v_gqa; |
| 554 | const int64_t n_expert; |
| 555 | const int64_t n_expert_used; |
| 556 | |
| 557 | const float freq_base; |
| 558 | const float freq_scale; |
| 559 | const float ext_factor; |
| 560 | const float attn_factor; |
| 561 | const float beta_fast; |
| 562 | const float beta_slow; |
| 563 | const float norm_eps; |
| 564 | const float norm_rms_eps; |
| 565 | |
| 566 | const int64_t n_tokens; |
| 567 | const int64_t n_outputs; |
| 568 | const int32_t n_ctx_orig; // yarn |
| 569 | |
| 570 | const enum llama_pooling_type pooling_type; |
| 571 | const enum llama_rope_type rope_type; |
| 572 | |
| 573 | ggml_backend_sched_t sched; |
| 574 | |
| 575 | ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? |
| 576 | |
| 577 | const llama_adapter_cvec * cvec; |
| 578 | const llama_adapter_loras * loras; |
| 579 | const llama_memory_context_i * mctx; |
| 580 | const llama_cross * cross; |
| 581 | |
| 582 | const llm_graph_cb & cb_func; |
| 583 | |
| 584 | llm_graph_result * res; |
| 585 | |
| 586 | ggml_context * ctx0 = nullptr; |
| 587 | ggml_cgraph * gf = nullptr; |
| 588 | |
| 589 | llm_graph_context(const llm_graph_params & params); |
| 590 | virtual ~llm_graph_context() = default; |
| 591 | |
| 592 | void cb(ggml_tensor * cur, const char * name, int il) const; |
| 593 | |
| 594 | // |
| 595 | // common |
| 596 | // |
| 597 | |
| 598 | ggml_tensor * build_cvec( |
| 599 | ggml_tensor * cur, |
| 600 | int il) const; |
| 601 | |
// do mat_mul, while optionally applying lora
| 603 | ggml_tensor * build_lora_mm( |
| 604 | ggml_tensor * w, |
| 605 | ggml_tensor * cur) const; |
| 606 | |
// do mat_mul_id, while optionally applying lora
| 608 | ggml_tensor * build_lora_mm_id( |
| 609 | ggml_tensor * w, // ggml_tensor * as |
| 610 | ggml_tensor * cur, // ggml_tensor * b |
| 611 | ggml_tensor * ids) const; |
| 612 | |
| 613 | ggml_tensor * build_norm( |
| 614 | ggml_tensor * cur, |
| 615 | ggml_tensor * mw, |
| 616 | ggml_tensor * mb, |
| 617 | llm_norm_type type, |
| 618 | int il) const; |
| 619 | |
| 620 | ggml_tensor * build_ffn( |
| 621 | ggml_tensor * cur, |
| 622 | ggml_tensor * up, |
| 623 | ggml_tensor * up_b, |
| 624 | ggml_tensor * up_s, |
| 625 | ggml_tensor * gate, |
| 626 | ggml_tensor * gate_b, |
| 627 | ggml_tensor * gate_s, |
| 628 | ggml_tensor * down, |
| 629 | ggml_tensor * down_b, |
| 630 | ggml_tensor * down_s, |
| 631 | ggml_tensor * act_scales, |
| 632 | llm_ffn_op_type type_op, |
| 633 | llm_ffn_gate_type type_gate, |
| 634 | int il) const; |
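
// illustrative call (an assumption - the model.layers[il].ffn_* tensors follow the usual
// per-layer naming in the model implementation and are not defined in this header):
//
//   cur = build_ffn(cur,
//           model.layers[il].ffn_up,   NULL, NULL,
//           model.layers[il].ffn_gate, NULL, NULL,
//           model.layers[il].ffn_down, NULL, NULL,
//           NULL,
//           LLM_FFN_SILU, LLM_FFN_PAR, il);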
| 635 | |
| 636 | // build MoE FFN without bias tensors |
| 637 | ggml_tensor * build_moe_ffn( |
| 638 | ggml_tensor * cur, |
| 639 | ggml_tensor * gate_inp, |
| 640 | ggml_tensor * up_exps, |
| 641 | ggml_tensor * gate_exps, |
| 642 | ggml_tensor * down_exps, |
| 643 | ggml_tensor * exp_probs_b, |
| 644 | int64_t n_expert, |
| 645 | int64_t n_expert_used, |
| 646 | llm_ffn_op_type type_op, |
| 647 | bool norm_w, |
| 648 | bool scale_w, |
| 649 | float w_scale, |
| 650 | llama_expert_gating_func_type gating_op, |
| 651 | int il, |
| 652 | ggml_tensor * probs_in = nullptr) const; |
| 653 | |
| 654 | ggml_tensor * build_moe_ffn( |
| 655 | ggml_tensor * cur, |
| 656 | ggml_tensor * gate_inp, |
| 657 | ggml_tensor * gate_inp_b, |
| 658 | ggml_tensor * up_exps, |
| 659 | ggml_tensor * up_exps_b, |
| 660 | ggml_tensor * gate_exps, |
| 661 | ggml_tensor * gate_exps_b, |
| 662 | ggml_tensor * down_exps, |
| 663 | ggml_tensor * down_exps_b, |
| 664 | ggml_tensor * exp_probs_b, |
| 665 | int64_t n_expert, |
| 666 | int64_t n_expert_used, |
| 667 | llm_ffn_op_type type_op, |
| 668 | bool norm_w, |
| 669 | bool scale_w, |
| 670 | float w_scale, |
| 671 | llama_expert_gating_func_type gating_op, |
| 672 | int il, |
| 673 | ggml_tensor * probs_in = nullptr) const; |
| 674 | |
| 675 | // |
| 676 | // inputs |
| 677 | // |
| 678 | |
| 679 | ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; |
| 680 | ggml_tensor * build_inp_pos() const; |
| 681 | ggml_tensor * build_inp_attn_scale() const; |
| 682 | ggml_tensor * build_inp_out_ids() const; |
| 683 | ggml_tensor * build_inp_mean() const; |
| 684 | ggml_tensor * build_inp_cls() const; |
| 685 | |
| 686 | ggml_tensor * build_inp_cross_embd() const; |
| 687 | ggml_tensor * build_inp_pos_bucket_enc() const; |
| 688 | ggml_tensor * build_inp_pos_bucket_dec() const; |
| 689 | ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const; |
| 690 | |
| 691 | // |
| 692 | // attention |
| 693 | // |
| 694 | |
| 695 | ggml_tensor * build_attn_mha( |
| 696 | ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens] |
| 697 | ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens] |
| 698 | ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) |
| 699 | ggml_tensor * kq_b, |
| 700 | ggml_tensor * kq_mask, |
| 701 | ggml_tensor * sinks, // [n_head_q] |
| 702 | ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] |
| 703 | float kq_scale, |
| 704 | int il) const; |
| 705 | |
| 706 | llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const; |
| 707 | |
| 708 | ggml_tensor * build_attn( |
| 709 | llm_graph_input_attn_no_cache * inp, |
| 710 | ggml_tensor * wo, |
| 711 | ggml_tensor * wo_b, |
| 712 | ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] |
| 713 | ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] |
| 714 | ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] |
| 715 | ggml_tensor * kq_b, |
| 716 | ggml_tensor * sinks, // [n_head_q] |
| 717 | ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] |
| 718 | float kq_scale, |
| 719 | int il) const; |
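
// illustrative call (an assumption - Qcur/Kcur/Vcur and wo/wo_b are hypothetical per-layer tensors):
//
//   auto * inp_attn = build_attn_inp_no_cache();
//   ...
//   cur = build_attn(inp_attn, wo, wo_b,
//           Qcur, Kcur, Vcur, /*kq_b=*/nullptr, /*sinks=*/nullptr, /*v_mla=*/nullptr,
//           1.0f/sqrtf(float(n_embd_head_k)), il);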
| 720 | |
| 721 | llm_graph_input_attn_kv * build_attn_inp_kv() const; |
| 722 | |
| 723 | ggml_tensor * build_attn( |
| 724 | llm_graph_input_attn_kv * inp, |
| 725 | ggml_tensor * wo, |
| 726 | ggml_tensor * wo_b, |
| 727 | ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] |
| 728 | ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] |
| 729 | ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] |
| 730 | ggml_tensor * kq_b, |
| 731 | ggml_tensor * sinks, // [n_head_q] |
| 732 | ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] |
| 733 | float kq_scale, |
| 734 | int il) const; |
| 735 | |
| 736 | llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const; |
| 737 | |
| 738 | // note: if k_cur or v_cur are not provided, they will not be stored in the memory |
| 739 | ggml_tensor * build_attn( |
| 740 | llm_graph_input_attn_kv_iswa * inp, |
| 741 | ggml_tensor * wo, |
| 742 | ggml_tensor * wo_b, |
| 743 | ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] |
| 744 | ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional |
| 745 | ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional |
| 746 | ggml_tensor * kq_b, |
| 747 | ggml_tensor * sinks, // [n_head_q] |
| 748 | ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] |
| 749 | float kq_scale, |
| 750 | int il) const; |
| 751 | |
| 752 | llm_graph_input_attn_cross * build_attn_inp_cross() const; |
| 753 | |
| 754 | ggml_tensor * build_attn( |
| 755 | llm_graph_input_attn_cross * inp, |
| 756 | ggml_tensor * wo, |
| 757 | ggml_tensor * wo_b, |
| 758 | ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] |
| 759 | ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] |
| 760 | ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] |
| 761 | ggml_tensor * kq_b, |
| 762 | ggml_tensor * sinks, // [n_head_q] |
| 763 | ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] |
| 764 | float kq_scale, |
| 765 | int il) const; |
| 766 | |
| 767 | // |
| 768 | // recurrent |
| 769 | // |
| 770 | |
| 771 | // TODO: move this implementation to llama_memory_recurrent. |
| 772 | // this is analogous to llama_kv_cache::cpy_k / cpy_v |
| 773 | // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the |
| 774 | // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in |
| 775 | // `llama_memory_recurrent` |
| 776 | ggml_tensor * build_rs( |
| 777 | ggml_tensor * s, |
| 778 | ggml_tensor * state_copy_main, |
ggml_tensor * state_copy_extra,
| 780 | int32_t state_size, |
| 781 | int32_t n_seqs, |
| 782 | uint32_t n_rs, |
| 783 | uint32_t rs_head, |
| 784 | uint32_t rs_size, |
| 785 | int32_t rs_zero, |
| 786 | const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; |
| 787 | |
| 788 | llm_graph_input_rs * build_rs_inp() const; |
| 789 | |
| 790 | ggml_tensor * build_rs( |
| 791 | llm_graph_input_rs * inp, |
| 792 | ggml_tensor * s, |
| 793 | int32_t state_size, |
| 794 | int32_t n_seqs, |
| 795 | const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; |
| 796 | |
| 797 | ggml_tensor * build_rwkv_token_shift_load( |
| 798 | llm_graph_input_rs * inp, |
| 799 | const llama_ubatch & ubatch, |
| 800 | int il) const; |
| 801 | |
| 802 | ggml_tensor * build_rwkv_token_shift_store( |
| 803 | ggml_tensor * token_shift, |
| 804 | const llama_ubatch & ubatch, |
int il) const;

//
| 807 | // hybrid |
| 808 | // |
| 809 | |
| 810 | llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; |
| 811 | |
| 812 | // |
| 813 | // pooling |
| 814 | // |
| 815 | |
| 816 | void build_pooling( |
| 817 | ggml_tensor * cls, |
| 818 | ggml_tensor * cls_b, |
| 819 | ggml_tensor * cls_out, |
| 820 | ggml_tensor * cls_out_b) const; |
| 821 | |
| 822 | // |
| 823 | // dense (out) |
| 824 | // |
| 825 | |
| 826 | void build_dense_out( |
| 827 | ggml_tensor * dense_2, |
| 828 | ggml_tensor * dense_3) const; |
| 829 | }; |
| 830 | |
| 831 | // TODO: better name |
| 832 | int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional); |
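
// illustrative call (hypothetical values): compute the relative-position bucket for a pair of
// positions, e.g. when building a T5-style relative attention bias:
//
//   const int32_t bucket = llama_relative_position_bucket(/*x=*/3, /*y=*/7, /*n_buckets=*/32, /*bidirectional=*/false);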
| 833 | |