1#include "sampling.h"
2#include "log.h"
3
4#ifdef LLAMA_USE_LLGUIDANCE
5
6# include "llguidance.h"
7# include <cmath>
8
9struct llama_sampler_llg {
10 const llama_vocab * vocab;
11 std::string grammar_kind;
12 std::string grammar_data;
13 LlgTokenizer * tokenizer;
14 LlgMatcher * grammar;
15};
16
17static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
18 const char * grammar_data) {
19 LlgConstraintInit cinit;
20 llg_constraint_init_set_defaults(&cinit, tokenizer);
21 const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
22 if (log_level && *log_level) {
23 cinit.log_stderr_level = atoi(log_level);
24 }
25 auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
26 if (llg_matcher_get_error(c)) {
27 LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
28 llg_free_matcher(c);
29 return nullptr;
30 }
31
32 return c;
33}
34
35static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
36 return "llguidance";
37}
38
39static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
40 auto * ctx = (llama_sampler_llg *) smpl->ctx;
41 if (ctx->grammar) {
42 llg_matcher_consume_token(ctx->grammar, token);
43 }
44}
45
46static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
47 auto * ctx = (llama_sampler_llg *) smpl->ctx;
48 if (ctx->grammar) {
49 const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
50 if (mask == nullptr) {
51 if (llg_matcher_compute_mask(ctx->grammar) == 0) {
52 mask = llg_matcher_get_mask(ctx->grammar);
53 } else {
54 LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
55 llg_free_matcher(ctx->grammar);
56 ctx->grammar = nullptr;
57 return;
58 }
59 }
60
61 for (size_t i = 0; i < cur_p->size; ++i) {
62 auto token = cur_p->data[i].id;
63 if ((mask[token / 32] & (1 << (token % 32))) == 0) {
64 cur_p->data[i].logit = -INFINITY;
65 }
66 }
67 }
68}
69
70static void llama_sampler_llg_reset(llama_sampler * smpl) {
71 auto * ctx = (llama_sampler_llg *) smpl->ctx;
72 if (ctx->grammar) {
73 llg_matcher_reset(ctx->grammar);
74 }
75}
76
77static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
78 const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
79
80 auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
81
82 // copy the state
83 {
84 auto * result_ctx = (llama_sampler_llg *) result->ctx;
85
86 if (ctx->grammar) {
87 result_ctx->grammar_kind = ctx->grammar_kind;
88 result_ctx->grammar_data = ctx->grammar_data;
89 result_ctx->grammar = llg_clone_matcher(ctx->grammar);
90 result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
91 }
92 }
93
94 return result;
95}
96
97static void llama_sampler_llg_free(llama_sampler * smpl) {
98 const auto * ctx = (llama_sampler_llg *) smpl->ctx;
99
100 if (ctx->grammar) {
101 llg_free_matcher(ctx->grammar);
102 llg_free_tokenizer(ctx->tokenizer);
103 }
104
105 delete ctx;
106}
107
108static llama_sampler_i llama_sampler_llg_i = {
109 /* .name = */ llama_sampler_llg_name,
110 /* .accept = */ llama_sampler_llg_accept_impl,
111 /* .apply = */ llama_sampler_llg_apply,
112 /* .reset = */ llama_sampler_llg_reset,
113 /* .clone = */ llama_sampler_llg_clone,
114 /* .free = */ llama_sampler_llg_free,
115};
116
117static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
118 uint32_t * output_tokens, size_t output_tokens_len) {
119 const llama_vocab * vocab = (const llama_vocab *) user_data;
120 int r = 0;
121 try {
122 r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
123 true);
124 } catch (const std::exception & e) {
125 GGML_ABORT("llama_tokenize failed: %s\n", e.what());
126 }
127 if (r < 0) {
128 return -r;
129 }
130 return r;
131}
132
133static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
134 // TODO store the tokenizer in the vocab somehow
135 static const llama_vocab * vocab_cache;
136 static LlgTokenizer * tokenizer_cache;
137
138 if (vocab_cache == vocab) {
139 return llg_clone_tokenizer(tokenizer_cache);
140 }
141
142 auto tok_eos = llama_vocab_eot(vocab);
143 if (tok_eos == LLAMA_TOKEN_NULL) {
144 tok_eos = llama_vocab_eos(vocab);
145 }
146
147 size_t vocab_size = llama_vocab_n_tokens(vocab);
148
149 auto token_lens = new uint32_t[vocab_size];
150 // we typically have ~7 bytes per token; let's go on the safe side here
151 auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
152 auto token_bytes = new uint8_t[token_bytes_size];
153
154 size_t offset = 0;
155 for (size_t i = 0; i < vocab_size; i++) {
156 size_t max_token = 1024;
157 if (token_bytes_size - offset < max_token) {
158 GGML_ABORT("token_bytes buffer too small\n");
159 }
160
161 llama_token token = i;
162 auto dp = (char *) token_bytes + offset;
163 auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
164 if (size < 0) {
165 GGML_ABORT("llama_detokenize failed\n");
166 }
167 if (size == 0) {
168 size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
169 if (size < 0) {
170 GGML_ABORT("llama_detokenize failed\n");
171 }
172 if (size != 0) {
173 *dp = '\xff'; // special token prefix marker
174 size += 1;
175 }
176 }
177
178 token_lens[i] = size;
179 offset += size;
180 }
181
182 LlgTokenizerInit tinit = {
183 /* .vocab_size = */ (uint32_t) vocab_size,
184 /* .tok_eos = */ (uint32_t) tok_eos,
185 /* .token_lens = */ token_lens,
186 /* .token_bytes = */ token_bytes,
187 /* .tokenizer_json = */ nullptr,
188 /* .tokenize_assumes_string = */ true,
189 /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
190 /* .use_approximate_greedy_tokenize_fn = */ false,
191 /* .tokenize_user_data = */ vocab,
192 /* .slices = */ nullptr,
193 };
194
195 char error_buffer[1024];
196 LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
197
198 delete[] token_bytes;
199 delete[] token_lens;
200
201 if (tokenizer == nullptr) {
202 LOG_ERR("llg tokenizer error: %s\n", error_buffer);
203 return tokenizer;
204 }
205
206 if (tokenizer_cache) {
207 llg_free_tokenizer(tokenizer_cache);
208 }
209 vocab_cache = vocab;
210 tokenizer_cache = tokenizer;
211
212 return llg_clone_tokenizer(tokenizer_cache);
213}
214
215llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
216 const char * grammar_data) {
217 auto * ctx = new llama_sampler_llg;
218
219 if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
220 auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
221 *ctx = {
222 /* .vocab = */ vocab,
223 /* .grammar_kind = */ grammar_kind,
224 /* .grammar_data = */ grammar_data,
225 /* .tokenizer = */ tokenizer,
226 /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
227 };
228 if (ctx->grammar) {
229 GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
230 llg_matcher_get_mask_byte_size(ctx->grammar));
231 }
232 } else {
233 *ctx = {
234 /* .vocab = */ vocab,
235 /* .grammar_kind = */ {},
236 /* .grammar_data = */ {},
237 /* .tokenizer = */ nullptr,
238 /* .grammar = */ nullptr,
239 };
240 }
241
242 return llama_sampler_init(
243 /* .iface = */ &llama_sampler_llg_i,
244 /* .ctx = */ ctx);
245}
246
247#else
248
249llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
250 LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
251 return nullptr;
252}
253
254#endif // LLAMA_USE_LLGUIDANCE
255