#pragma once

#include "llama.h"
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-adapter.h"

#include "ggml-cpp.h"
#include "ggml-opt.h"

#include <map>
#include <vector>

struct llama_model;
class llama_batch_allocr;

class llama_io_read_i;
class llama_io_write_i;

// "memory" as in abstract memory for the context
struct llama_memory_i;
struct llama_memory_context_i;

// "memory" as in physical memory for a buffer type, in bytes
struct llama_memory_breakdown_data {
    size_t model   = 0; // memory allocated for the model
    size_t context = 0; // memory allocated for the context
    size_t compute = 0; // memory allocated for temporary compute buffers
};

struct llama_context {
    // init scheduler and compute buffers, reserve worst-case graphs
    llama_context(
            const llama_model & model,
            llama_context_params params);
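
    // Minimal creation sketch (illustrative, not part of this header): user code normally
    // goes through the public C API in llama.h, which wraps this constructor; the wrapper
    // name below assumes a recent llama.cpp version (older releases used llama_new_context_with_model):
    //
    //   llama_context_params params = llama_context_default_params();
    //   params.n_ctx   = 4096;
    //   params.n_batch = 512;
    //   llama_context * lctx = llama_init_from_model(model, params); // model: llama_model *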

    ~llama_context();

    void synchronize();

    const llama_model   & get_model()   const;
    const llama_cparams & get_cparams() const;

    ggml_backend_sched_t get_sched() const;

    uint32_t n_ctx()     const;
    uint32_t n_ctx_seq() const;
    uint32_t n_batch()   const;
    uint32_t n_ubatch()  const;
    uint32_t n_seq_max() const;

    uint32_t n_threads()       const;
    uint32_t n_threads_batch() const;

    llama_memory_t get_memory() const;

    // return true if the memory was updated
    bool memory_update(bool optimize);

    enum llama_pooling_type pooling_type() const;

    float * get_logits();
    float * get_logits_ith(int32_t i);

    float * get_embeddings();
    float * get_embeddings_ith(int32_t i);
    float * get_embeddings_seq(llama_seq_id seq_id);
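
    // Typical consumer sketch via the public C API wrappers in llama.h (e.g.
    // llama_get_logits_ith / llama_get_embeddings_seq), which forward to the getters above:
    //
    //   if (llama_decode(lctx, batch) == 0) {
    //       // logits of the i-th batch position (requires batch.logits[i] != 0)
    //       const float * row = llama_get_logits_ith(lctx, i);
    //   }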

    void attach_threadpool(
            ggml_threadpool_t threadpool,
            ggml_threadpool_t threadpool_batch);

    void detach_threadpool();

    void set_n_threads(int32_t n_threads, int32_t n_threads_batch);

    void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);

    void set_embeddings (bool value);
    void set_causal_attn(bool value);
    void set_warmup(bool value);

    void set_adapter_lora(
            llama_adapter_lora * adapter,
            float scale);

    bool rm_adapter_lora(
            llama_adapter_lora * adapter);

    void clear_adapter_lora();

    bool apply_adapter_cvec(
            const float * data,
            size_t len,
            int32_t n_embd,
            int32_t il_start,
            int32_t il_end);
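
    // Adapter workflow sketch via the public C API (llama.h); wrapper names assume a recent
    // llama.cpp version (older releases used the llama_lora_adapter_* spelling):
    //
    //   llama_adapter_lora * lora = llama_adapter_lora_init(model, "adapter.gguf");
    //   llama_set_adapter_lora(lctx, lora, 0.75f); // apply with a scale
    //   llama_rm_adapter_lora (lctx, lora);        // or llama_clear_adapter_lora(lctx)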

    // process a single ubatch with a specific graph type
    // if mctx is provided, it will be applied first to the context's memory
    // ret contains the status of the graph computation
    // returns nullptr only if ret != GGML_STATUS_SUCCESS
    llm_graph_result * process_ubatch(
            const llama_ubatch & ubatch,
            llm_graph_type gtype,
            llama_memory_context_i * mctx,
            ggml_status & ret);
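
    // Hypothetical caller sketch (mirrors how decode()/encode() are expected to use it;
    // the LLM_GRAPH_TYPE_* values come from llama-graph.h):
    //
    //   ggml_status status;
    //   auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx, status);
    //   if (res == nullptr) {
    //       // status holds the failing ggml_status - the ubatch was not processed
    //   }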

    int encode(const llama_batch & batch_inp);
    int decode(const llama_batch & batch_inp);

    //
    // state save/load
    //

    size_t state_get_size();
    size_t state_get_data(      uint8_t * dst, size_t size);
    size_t state_set_data(const uint8_t * src, size_t size);

    size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size, llama_state_seq_flags flags);
    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
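
    // Size-then-copy round trip sketch, shown through the public C API wrappers
    // (llama_state_get_size / llama_state_get_data / llama_state_set_data in llama.h):
    //
    //   std::vector<uint8_t> buf(llama_state_get_size(lctx));
    //   llama_state_get_data(lctx, buf.data(), buf.size()); // snapshot the context state
    //   llama_state_set_data(lctx, buf.data(), buf.size()); // restore it later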

    bool state_load_file(
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out);

    bool state_save_file(
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count);

    size_t state_seq_load_file(
            llama_seq_id seq_id,
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out);

    size_t state_seq_save_file(
            llama_seq_id seq_id,
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count);
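
    // Session-file sketch via the public C API (llama_state_save_file / llama_state_load_file
    // in llama.h); the token buffer carries the prompt that produced the saved state:
    //
    //   llama_state_save_file(lctx, "session.bin", tokens.data(), tokens.size());
    //
    //   std::vector<llama_token> prompt(llama_n_ctx(lctx));
    //   size_t n_loaded = 0;
    //   llama_state_load_file(lctx, "session.bin", prompt.data(), prompt.size(), &n_loaded);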

    //
    // perf
    //

    llama_perf_context_data perf_get_data() const;
    void perf_reset();

    std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown() const;
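
    // Consumption sketch (hypothetical caller) - one llama_memory_breakdown_data entry per buffer type:
    //
    //   for (const auto & [buft, mb] : lctx->memory_breakdown()) {
    //       const size_t total = mb.model + mb.context + mb.compute; // bytes
    //       printf("%s: %zu bytes\n", ggml_backend_buft_name(buft), total);
    //   }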

    //
    // training
    //

    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);

    // TODO: more flexible combinations of logical/physical batch size and context size
    void opt_epoch(
            ggml_opt_dataset_t dataset,
            ggml_opt_result_t result_train,
            ggml_opt_result_t result_eval,
            int64_t idata_split,
            ggml_opt_epoch_callback callback_train,
            ggml_opt_epoch_callback callback_eval);
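
    // Training sketch via the public C API (llama_opt_init / llama_opt_epoch in llama.h,
    // assuming a recent llama.cpp version); the 90/10 train/eval split is only an example:
    //
    //   llama_opt_init(lctx, model, lopt_params);
    //   const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * 0.9;
    //   llama_opt_epoch(lctx, dataset, result_train, result_eval, idata_split,
    //                   ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);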

    void opt_epoch_iter(
            ggml_opt_dataset_t dataset,
            ggml_opt_result_t result,
            const std::vector<llama_token> & tokens,
            const std::vector<llama_token> & labels_sparse,
            llama_batch & batch,
            ggml_opt_epoch_callback callback,
            bool train,
            int64_t idata_in_loop,
            int64_t ndata_in_loop,
            int64_t t_loop_start);

private:
    //
    // output
    //

    // Make sure enough space is available for outputs.
    // Returns max number of outputs for which space was reserved.
    uint32_t output_reserve(int32_t n_outputs);

    void output_reorder();

    //
    // graph
    //

public:
    uint32_t graph_max_nodes() const;

    // returns the llm_graph_result instance of the context, which can be reused (for example to update a memory module)
    llm_graph_result * get_gf_res_reserve() const;

    // returns the result of ggml_backend_sched_graph_compute_async execution
    ggml_status graph_compute(ggml_cgraph * gf, bool batched);

    // reserve a graph with a dummy ubatch of the specified size
    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
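
    // Simplified lifecycle note: at construction the worst-case graphs are reserved so the
    // scheduler can size its compute buffers, roughly (names below are placeholders):
    //
    //   ggml_cgraph * gf = graph_reserve(n_tokens_max, n_seqs_max, n_outputs_max, mctx); // worst-case sizes
    //
    // while per-batch work goes through process_ubatch(), which builds the actual graph and
    // runs it via graph_compute().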

private:
    llm_graph_params graph_params(
            llm_graph_result * res,
            const llama_ubatch & ubatch,
            const llama_memory_context_i * mctx,
            llm_graph_type gtype) const;

    llm_graph_cb graph_get_cb() const;

    // TODO: read/write lora adapters and cvec
    size_t state_write_data(llama_io_write_i & io);
    size_t state_read_data (llama_io_read_i  & io);

    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
    size_t state_seq_read_data (llama_io_read_i  & io, llama_seq_id seq_id, llama_state_seq_flags flags);

    //
    // members
    //

    const llama_model & model;

    llama_cparams       cparams;
    llama_adapter_cvec  cvec;
    llama_adapter_loras loras;

    llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably

    std::unique_ptr<llama_memory_i> memory;

    // decode output (2-dimensional array: [n_outputs][n_vocab])
    size_t  logits_size = 0; // capacity (of floats) for logits
    float * logits      = nullptr;

    // embeddings output (2-dimensional array: [n_outputs][n_embd])
    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
    size_t  embd_size = 0; // capacity (of floats) for embeddings
    float * embd      = nullptr;

    // sequence embeddings output (map of [n_embd] vectors)
    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
    std::map<llama_seq_id, std::vector<float>> embd_seq;

    // reuse the batch_allocr to avoid unnecessary memory allocations
    std::unique_ptr<llama_batch_allocr> balloc;

    uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
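    // (e.g. the logits of batch position i, when it has an output, live at
    //  logits + output_ids[i]*n_vocab, which is what the *_ith getters above rely on)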

    // pairwise row swaps applied by output_reorder() to restore the batch order of the output buffers
    struct swap_info {
        uint32_t i0;
        uint32_t i1;
    };

    std::vector<swap_info> output_swaps;

    ggml_backend_sched_ptr sched;

    ggml_backend_t backend_cpu = nullptr;
    std::vector<ggml_backend_ptr> backends;

    // training
    ggml_opt_context_t opt_ctx = nullptr;

    ggml_threadpool_t threadpool       = nullptr;
    ggml_threadpool_t threadpool_batch = nullptr;

    ggml_abort_callback abort_callback      = nullptr;
    void *              abort_callback_data = nullptr;

    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;

    // backends and the buffer types used for their compute buffers
    std::vector<ggml_backend_t>             backend_ptrs;
    std::vector<ggml_backend_buffer_type_t> backend_buft;

    llm_graph_result_ptr gf_res_prev;
    llm_graph_result_ptr gf_res_reserve;

    // host buffer for the model output (logits and embeddings)
    ggml_backend_buffer_ptr buf_output;

    bool has_evaluated_once = false;

    // env: LLAMA_GRAPH_REUSE_DISABLE
    bool graph_reuse_disable = false;

    // perf
    mutable int64_t t_start_us  = 0;
    mutable int64_t t_load_us   = 0;
    mutable int64_t t_p_eval_us = 0;
    mutable int64_t t_eval_us   = 0;

    mutable int64_t t_compute_start_us = 0;
    mutable int64_t n_queued_tokens    = 0;

    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
    mutable int32_t n_eval   = 0; // number of eval calls

    mutable int32_t n_reused = 0; // number of times the previous graph was reused
};