| 1 | #pragma once |
|---|---|
| 2 | |
| 3 | #include "llama.h" |
| 4 | #include "common.h" |
| 5 | |
| 6 | struct common_speculative; /* only forward-declared in this header; the functions below refer to it through pointers */ |
| 7 | |
| 8 | struct common_speculative_params { /* options taken by value by common_speculative_gen_draft; every field has a default */ |
| 9 | int n_draft = 16; // max drafted tokens (default 16) |
| 10 | int n_reuse = 256; /* default 256. NOTE(review): meaning is not stated in this header; presumably a token count governing reuse of earlier draft state -- confirm in the implementation */ |
| 11 | |
| 12 | float p_min = 0.75f; // min probability required to accept a token in the draft (default 0.75) |
| 13 | }; |
| 14 | |
| 15 | struct common_speculative * common_speculative_init( /* returns a common_speculative pointer; NOTE(review): presumably released with common_speculative_free, and failure behavior is not visible here -- confirm */ |
| 16 | struct llama_context * ctx_tgt, /* presumably the target-model context (from the _tgt suffix) -- confirm */ |
| 17 | struct llama_context * ctx_dft /* presumably the draft-model context (from the _dft suffix) -- confirm */ |
| 18 | ); |
| 19 | |
| 20 | void common_speculative_free(struct common_speculative * spec); /* NOTE(review): presumably releases a pointer obtained from common_speculative_init; whether NULL is accepted is not visible in this header */ |
| 21 | |
| 22 | bool common_speculative_are_compatible( /* yields a bool for a pair of contexts; the compatibility criteria are not stated in this header */ |
| 23 | const struct llama_context * ctx_tgt, /* pointer to const: not modified through this parameter */ |
| 24 | const struct llama_context * ctx_dft); /* pointer to const: not modified through this parameter */ |
| 25 | |
| 26 | void common_speculative_add_replacement_tgt_dft( /* NOTE(review): the name suggests recording a target-to-draft text replacement on spec -- confirm the direction and when it is applied */ |
| 27 | struct common_speculative * spec, |
| 28 | const char *source, const char *dest); /* read-only C strings; NOTE(review): whether they are copied or must outlive spec is not visible here */ |
| 29 | |
| 30 | // sample up to n_draft tokens and add them to the batch using the draft model; the result type is llama_tokens (presumably the drafted tokens -- confirm) |
| 31 | llama_tokens common_speculative_gen_draft( |
| 32 | struct common_speculative * spec, |
| 33 | struct common_speculative_params params, /* passed by value, so the callee works on its own copy */ |
| 34 | const llama_tokens & prompt, /* const reference: not modified through this parameter */ |
| 35 | llama_token id_last); /* NOTE(review): presumably the most recent token, following prompt -- confirm */ |
| 36 |