| 1 | #pragma once |
| 2 | |
| 3 | #include "llama.h" |
| 4 | |
| 5 | #include "common.h" |
| 6 | |
| 7 | #include <string> |
| 8 | #include <vector> |
| 9 | |
| 10 | // common_sampler extends llama_sampler with additional functionality: |
| 11 | // |
| 12 | // - grammar support |
| 13 | // - custom sampler logic based on the parameters |
| 14 | // - history of the last accepted tokens |
| 15 | // - performance metrics |
| 16 | // |
| 17 | // The goal is to have a common implementation of the sampling logic shared across the examples. |
| 18 | // For example, depending on the temperature, the sampling chain can be very simple (greedy) or more |
| 19 | // complex (top-k, top-p, etc). |
| 20 | // |
| 21 | // Another example is related to the grammar. In general, applying the grammar constraints to the full |
| 22 | // vocabulary can be very expensive. To improve performance, the grammar can be applied only to the sampled |
| 23 | // token in order to verify that it fits the grammar. Only if the token doesn't fit are the grammar |
| 24 | // constraints applied to the full vocabulary and the token resampled. |
| 25 | // |
| 26 | // The common_sampler also maintains a container with the last accepted tokens. In the future, this can |
| 27 | // be moved into the core llama library. |
| 28 | // |
| 29 | // For convenience, the common_sampler also maintains a container with the current candidate tokens. |
| 30 | // This can be used to access the probabilities of the rest of the non-sampled tokens. |
| 31 | // |
| 32 | // TODO: measure grammar performance |
| 33 | // |
| 34 | |
| 35 | struct common_sampler; // opaque type: only forward-declared here; the API below works through pointers to it |
| 36 | |
| 37 | // common_sampler counterparts of the llama_sampler API (separately named functions, not C++ overloads) |
| 38 | // create a sampler configured from `params` for `model` (caller presumably owns the result and releases it with common_sampler_free - confirm in the .cpp) |
| 39 | struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); |
| 40 | // release a sampler (presumably one obtained from common_sampler_init or common_sampler_clone; check the .cpp for nullptr handling) |
| 41 | void common_sampler_free(struct common_sampler * gsmpl); |
| 42 | // mark `token` as accepted (it becomes part of the last-accepted-token history that common_sampler maintains) |
| 43 | // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar; otherwise presumably only by the sampling chain (confirm) |
| 44 | void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar); |
| 45 | void common_sampler_reset (struct common_sampler * gsmpl); // reset the sampler's state (NOTE(review): which parts - chain, grammar, token history - are reset is defined in the .cpp) |
| 46 | struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); // return a new sampler copied from gsmpl (presumably caller-owned; release with common_sampler_free) |
| 47 | // print performance metrics for the context and/or the sampler (output destination is defined in the .cpp) |
| 48 | // either argument can be nullptr to skip printing the corresponding metrics |
| 49 | void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); |
| 50 | |
| 51 | // extended sampling implementation: |
| 52 | // |
| 53 | // - set logits (taken from output `idx` of `ctx`, presumably - confirm in the .cpp) |
| 54 | // - apply the configured sampler chain |
| 55 | // - check if the token fits the grammar (if any) |
| 56 | // - if not: resample by first applying the grammar constraints and then sampling again (slower path) |
| 57 | // |
| 58 | // if grammar_first is true, the grammar is applied before the samplers (slower) |
| 59 | // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar |
| 60 | // returns the sampled token; it is NOT accepted - call common_sampler_accept afterwards to record it |
| 61 | llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); |
| 62 | |
| 63 | // generalized version of common_sampler_sample that also accepts the sampled tokens |
| 64 | // |
| 65 | // will cross-reference the sampled tokens with a batch of draft tokens and accept those that match |
| 66 | // if the sampler disagrees at some point, we stop; the result holds the matching draft tokens followed by the sampler's own token at the stopping point |
| 67 | // |
| 68 | // common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {}); |
| 69 | // |
| 70 | // is equivalent to |
| 71 | // |
| 72 | // llama_token token = common_sampler_sample(gsmpl, ctx, idx); |
| 73 | // common_sampler_accept(gsmpl, token, true); |
| 74 | // |
| 75 | // requires: idxs.size() == draft.size() + 1 |
| 76 | // |
| 77 | // returns at least 1 token, up to idxs.size(); all returned tokens are already accepted by the sampler |
| 78 | // |
| 79 | std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false); |
| 80 | |
| 81 | // convenience overload: same as the version above with idxs == [ 0, 1, 2, ..., draft.size() ] |
| 82 | std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); |
| 83 | // return the seed reported by this sampler (NOTE(review): check the .cpp for how it is chosen) |
| 84 | uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); |
| 85 | |
| 86 | // helpers |
| 87 | |
| 88 | // access the internal list of current candidate tokens (sampler-owned storage: do not free; presumably invalidated by the next sample call) |
| 89 | // if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability) |
| 90 | // the .sorted flag of the result indicates whether the returned candidates are sorted |
| 91 | llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort); |
| 92 | |
| 93 | // get the last accepted token (NOTE(review): the result for an empty history is defined in the .cpp) |
| 94 | llama_token common_sampler_last(const struct common_sampler * gsmpl); |
| 95 | |
| 96 | // return a string describing the configured sampler chain (e.g. for logging) |
| 97 | std::string common_sampler_print(const struct common_sampler * gsmpl); |
| 98 | |
| 99 | // get a string representation of the last accepted tokens (presumably at most the last n, rendered via ctx - confirm in the .cpp) |
| 100 | std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n); |
| 101 | // map a common_sampler_type to a single-character / string representation |
| 102 | char common_sampler_type_to_chr(enum common_sampler_type cnstr); |
| 103 | std::string common_sampler_type_to_str(enum common_sampler_type cnstr); |
| 104 | // parse common_sampler_type values from a list of names or from a string of single-character codes (allow_alt_names presumably accepts alternative spellings - confirm) |
| 105 | std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names); |
| 106 | std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars); |
| 107 | // create a llama_sampler from grammar_kind + grammar_data (NOTE(review): "llg" presumably refers to the llguidance backend; check the .cpp for availability and nullptr returns) |
| 108 | llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, |
| 109 | const char * grammar_kind, const char * grammar_data); |
| 110 | |