1#include "llama-sampling.h"
2
3#include "llama-impl.h"
4#include "llama-vocab.h"
5#include "llama-grammar.h"
6
7#include <algorithm>
8#include <cassert>
9#include <cfloat>
10#include <chrono>
11#include <cmath>
12#include <cstdlib>
13#include <cstring>
14#include <ctime>
15#include <numeric>
16#include <random>
17#include <unordered_map>
18#include <stdexcept>
19
// the ring buffer works similarly to std::deque, but with a fixed capacity:
// pushing into a full buffer evicts the oldest element
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // oldest element
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // newest element
    // note: `pos` points at the *next* write slot, so the most recently
    //       written element is one slot behind it (previously this returned
    //       data[pos], i.e. a stale or uninitialized slot)
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    // append a value; when full, the oldest element is overwritten
    void push_back(const T & value) {
        if (capacity == 0) {
            throw std::runtime_error("ring buffer: capacity is zero");
        }

        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // remove and return the oldest element
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // i-th element counted from the back: rat(0) is the newest element
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // copy of the contents, oldest first
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer; stored elements are
        // intentionally not destroyed
        sz    = 0;
        first = 0;
        pos   = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;
    size_t sz       = 0;
    size_t first    = 0; // index of the oldest element
    size_t pos      = 0; // index of the next write slot

    std::vector<T> data;
};
130
131// writes result in res, does not mutate cur
132static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
133 static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
134 return a.logit > b.logit;
135 };
136
137 constexpr int nbuckets = 128;
138 constexpr float bucket_low = -10.0f;
139 constexpr float bucket_high = 10.0f;
140 constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
141 constexpr float bucket_inter = -bucket_low * bucket_scale;
142
143 std::vector<int> bucket_idx;
144 std::vector<int> histo(nbuckets, 0);
145
146 std::vector<llama_token_data*> bucket_ptrs;
147
148 bucket_idx.reserve(n: cur.size);
149
150 for (int i = 0; i < (int)cur.size; ++i) {
151 const float val = cur.data[i].logit;
152 int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
153 ib = std::max(a: 0, b: std::min(a: nbuckets - 1, b: ib));
154 bucket_idx.push_back(x: ib);
155 ++histo[ib];
156 }
157 int nhave = 0;
158 int ib = nbuckets - 1;
159 for ( ; ib >= 0; --ib) {
160 nhave += histo[ib];
161 if (nhave >= npartial) {
162 break;
163 }
164 }
165 res.resize(new_size: nhave);
166 auto * ptr = res.data();
167 bucket_ptrs.reserve(n: nbuckets - ib);
168 for (int j = nbuckets - 1; j >= ib; --j) {
169 bucket_ptrs.push_back(x: ptr);
170 ptr += histo[j];
171 }
172 for (int i = 0; i < (int)cur.size; ++i) {
173 int j = bucket_idx[i];
174 if (j >= ib) {
175 *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i];
176 }
177 }
178
179 ptr = res.data();
180 int ndone = 0;
181 for (int j = nbuckets - 1; j > ib; --j) {
182 std::sort(first: ptr, last: ptr + histo[j], comp: comp);
183 ptr += histo[j];
184 ndone += histo[j];
185 }
186 std::partial_sort(first: ptr, middle: ptr + npartial - ndone, last: ptr + histo[ib], comp: comp);
187}
188
189// reduces the size of cur_p to npartial, keeping only the top npartial elements
190static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) {
191 static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
192 return a.logit > b.logit;
193 };
194
195 if (npartial <= 128) {
196 std::partial_sort(first: cur_p->data, middle: cur_p->data + npartial, last: cur_p->data + cur_p->size, comp: comp);
197
198 cur_p->size = npartial;
199 cur_p->sorted = true;
200
201 return;
202 }
203
204 std::vector<llama_token_data> tmp;
205
206 llama_token_data_array_partial_sort(cur: *cur_p, npartial, res&: tmp);
207
208 std::copy(first: tmp.data(), last: tmp.data() + npartial, result: cur_p->data);
209
210 cur_p->size = npartial;
211 cur_p->sorted = true;
212}
213
214static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
215 // iterator for the probabilities
216#ifdef __GNUC__
217 #pragma GCC diagnostic push
218 #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
219#endif
220
221 struct probs_iterator {
222 typedef std::input_iterator_tag iterator_category;
223 typedef float value_type;
224 typedef float * pointer;
225 typedef float & reference;
226 typedef ptrdiff_t difference_type;
227
228 const llama_token_data * data;
229
230 bool operator==(const probs_iterator & other) const { return data == other.data; }
231 bool operator!=(const probs_iterator & other) const { return data != other.data; }
232 const float & operator*() const { return data->p; }
233 probs_iterator & operator++() { ++data; return *this; }
234 probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
235 };
236
237#ifdef __GNUC__
238 #pragma GCC diagnostic pop
239#endif
240
241 std::discrete_distribution<int> dist(probs_iterator{.data: cur_p->data}, probs_iterator{.data: cur_p->data + cur_p->size});
242
243 return dist(rng);
244}
245
246/*
247static void llama_log_softmax(float * array, size_t size) {
248 float max_l = *std::max_element(array, array + size);
249 float sum = 0.f;
250 for (size_t i = 0; i < size; ++i) {
251 float p = expf(array[i] - max_l);
252 sum += p;
253 array[i] = p;
254 }
255
256 for (size_t i = 0; i < size; ++i) {
257 array[i] = logf(array[i] / sum);
258 }
259}
260*/
261
262static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
263 if (temp <= 0.0f) {
264 // find the token with the highest logit and set the rest to -inf
265 size_t max_i = 0;
266 float max_l = cur_p->data[0].logit;
267
268 for (size_t i = 1; i < cur_p->size; ++i) {
269 if (cur_p->data[i ].logit > max_l) {
270 cur_p->data[max_i].logit = -INFINITY;
271 max_i = i;
272 max_l = cur_p->data[i].logit;
273 } else {
274 cur_p->data[i].logit = -INFINITY;
275 }
276 }
277
278 return;
279 }
280
281 for (size_t i = 0; i < cur_p->size; ++i) {
282 cur_p->data[i].logit /= temp;
283 }
284}
285
286static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) {
287 GGML_ASSERT(cur_p->size > 0);
288
289 // Sort the logits in descending order if requested
290 if (do_sort && !cur_p->sorted) {
291 llama_token_data_array_partial_sort_inplace(cur_p, npartial: cur_p->size);
292 }
293
294 float max_l = cur_p->data[0].logit;
295 if (!cur_p->sorted) {
296 for (size_t i = 1; i < cur_p->size; ++i) {
297 max_l = std::max(a: max_l, b: cur_p->data[i].logit);
298 }
299 }
300
301 float cum_sum = 0.0f;
302
303 for (size_t i = 0; i < cur_p->size; ++i) {
304 float p = expf(x: cur_p->data[i].logit - max_l);
305 cur_p->data[i].p = p;
306 cum_sum += p;
307 }
308
309 for (size_t i = 0; i < cur_p->size; ++i) {
310 cur_p->data[i].p /= cum_sum;
311 }
312}
313
314static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
315 // if (k >= (int32_t)cur_p->size) {
316 // return;
317 // }
318
319 if (k <= 0) {
320 return;
321 }
322
323 k = std::min(a: k, b: (int) cur_p->size);
324
325 // Sort scores in descending order
326 if (!cur_p->sorted) {
327 llama_token_data_array_partial_sort_inplace(cur_p, npartial: k);
328 }
329
330 cur_p->size = k;
331}
332
333static uint32_t get_rng_seed(uint32_t seed) {
334 if (seed == LLAMA_DEFAULT_SEED) {
335 // use system clock if std::random_device is not a true RNG
336 static bool is_rd_prng = std::random_device().entropy() == 0;
337 if (is_rd_prng) {
338 return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
339 }
340 std::random_device rd;
341 return rd();
342 }
343 return seed;
344}
345
346// llama_sampler API
347
// Creates a sampler object from an interface table and an opaque context.
// The result owns `ctx` (released via iface->free in llama_sampler_free);
// `iface` is borrowed and is typically a static.
struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
    return new llama_sampler {
        /* .iface = */ iface,
        /* .ctx = */ ctx,
    };
}

// Human-readable sampler name; "(null)" when no interface is set.
const char * llama_sampler_name(const struct llama_sampler * smpl) {
    if (!smpl->iface) {
        return "(null)";
    }

    return smpl->iface->name(smpl);
}

// Notifies the sampler that `token` was accepted (optional interface hook).
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
    if (smpl->iface->accept) {
        smpl->iface->accept(smpl, token);
    }
}

// Applies the sampler to the candidate array (mandatory interface hook).
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
    GGML_ASSERT(smpl->iface->apply);
    smpl->iface->apply(smpl, cur_p);
}

// Resets the sampler's internal state (optional interface hook).
void llama_sampler_reset(struct llama_sampler * smpl) {
    if (smpl->iface->reset) {
        smpl->iface->reset(smpl);
    }
}

// Clones the sampler. Stateless samplers (ctx == nullptr) are cloned by
// sharing the interface; stateful samplers must provide iface->clone,
// otherwise this aborts.
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
    if (smpl->iface->clone) {
        return smpl->iface->clone(smpl);
    }

    if (smpl->ctx == nullptr) {
        return llama_sampler_init(
            /* .iface = */ smpl->iface,
            /* .ctx = */ nullptr
        );
    }

    GGML_ABORT("the sampler does not support cloning");
}

// Releases the sampler: runs iface->free on the context (if provided), then
// deletes the wrapper itself. Safe to call with nullptr.
void llama_sampler_free(struct llama_sampler * smpl) {
    if (smpl == nullptr) {
        return;
    }

    if (smpl->iface->free) {
        smpl->iface->free(smpl);
    }

    delete smpl;
}
406
407llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
408 const auto * logits = llama_get_logits_ith(ctx, i: idx);
409
410 const llama_model * model = llama_get_model(ctx);
411 const llama_vocab * vocab = llama_model_get_vocab(model);
412
413 const int n_vocab = llama_vocab_n_tokens(vocab);
414
415 // TODO: do not allocate each time
416 std::vector<llama_token_data> cur;
417 cur.reserve(n: n_vocab);
418 for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
419 cur.emplace_back(args: llama_token_data{.id: token_id, .logit: logits[token_id], .p: 0.0f});
420 }
421
422 llama_token_data_array cur_p = {
423 /* .data = */ cur.data(),
424 /* .size = */ cur.size(),
425 /* .selected = */ -1,
426 /* .sorted = */ false,
427 };
428
429 llama_sampler_apply(smpl, cur_p: &cur_p);
430
431 GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
432
433 auto token = cur_p.data[cur_p.selected].id;
434
435 llama_sampler_accept(smpl, token);
436
437 return token;
438}
439
440// sampler chain
441
442static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
443 return "chain";
444}
445
446static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
447 auto * chain = (llama_sampler_chain *) smpl->ctx;
448
449 time_meas tm(chain->t_sample_us, chain->params.no_perf);
450
451 for (auto * smpl : chain->samplers) {
452 llama_sampler_accept(smpl, token);
453 }
454
455 chain->n_sample++;
456}
457
458static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
459 auto * chain = (llama_sampler_chain *) smpl->ctx;
460
461 time_meas tm(chain->t_sample_us, chain->params.no_perf);
462
463 for (auto * smpl : chain->samplers) {
464 llama_sampler_apply(smpl, cur_p);
465 }
466}
467
468static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
469 auto * chain = (llama_sampler_chain *) smpl->ctx;
470
471 for (auto * smpl : chain->samplers) {
472 llama_sampler_reset(smpl);
473 }
474
475 chain->t_sample_us = 0;
476 chain->n_sample = 0;
477}
478
479static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
480 const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
481
482 auto * result = llama_sampler_chain_init(params: chain_src->params);
483
484 for (auto * smpl : chain_src->samplers) {
485 llama_sampler_chain_add(chain: result, smpl: llama_sampler_clone(smpl));
486 }
487
488 return result;
489}
490
491static void llama_sampler_chain_free(struct llama_sampler * smpl) {
492 auto * chain = (llama_sampler_chain *) smpl->ctx;
493
494 for (auto * smpl : chain->samplers) {
495 llama_sampler_free(smpl);
496 }
497
498 delete chain;
499}
500
501static struct llama_sampler_i llama_sampler_chain_i = {
502 /* .name = */ llama_sampler_chain_name,
503 /* .accept = */ llama_sampler_chain_accept,
504 /* .apply = */ llama_sampler_chain_apply,
505 /* .reset = */ llama_sampler_chain_reset,
506 /* .clone = */ llama_sampler_chain_clone,
507 /* .free = */ llama_sampler_chain_free,
508};
509
510struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
511 return llama_sampler_init(
512 /* .iface = */ &llama_sampler_chain_i,
513 /* .ctx = */ new llama_sampler_chain {
514 /* .params = */ params,
515 /* .samplers = */ {},
516 /* .t_sample_us = */ 0,
517 /* .n_sample = */ 0,
518 }
519 );
520}
521
522void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
523 auto * p = (llama_sampler_chain *) chain->ctx;
524 p->samplers.push_back(x: smpl);
525}
526
527struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
528 const auto * p = (const llama_sampler_chain *) chain->ctx;
529
530 if (i < 0 || (size_t) i >= p->samplers.size()) {
531 return nullptr;
532 }
533
534 return p->samplers[i];
535}
536
537struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
538 auto * p = (llama_sampler_chain *) chain->ctx;
539
540 if (i < 0 || (size_t) i >= p->samplers.size()) {
541 return nullptr;
542 }
543
544 auto * result = p->samplers[i];
545 p->samplers.erase(position: p->samplers.begin() + i);
546
547 return result;
548}
549
550int llama_sampler_chain_n(const struct llama_sampler * chain) {
551 const auto * p = (const llama_sampler_chain *) chain->ctx;
552
553 return p->samplers.size();
554}
555
556//
557// samplers
558//
559
560// greedy
561
562static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
563 return "greedy";
564}
565
566static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
567 cur_p->selected = 0;
568 for (size_t i = 1; i < cur_p->size; ++i) {
569 if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
570 cur_p->selected = i;
571 }
572 }
573}
574
575static struct llama_sampler_i llama_sampler_greedy_i = {
576 /* .name = */ llama_sampler_greedy_name,
577 /* .accept = */ nullptr,
578 /* .apply = */ llama_sampler_greedy_apply,
579 /* .reset = */ nullptr,
580 /* .clone = */ nullptr,
581 /* .free = */ nullptr,
582};
583
584struct llama_sampler * llama_sampler_init_greedy() {
585 return llama_sampler_init(
586 /* .iface = */ &llama_sampler_greedy_i,
587 /* .ctx = */ nullptr
588 );
589}
590
591// dist
592
// state for the final distribution-sampling stage
struct llama_sampler_dist {
    const uint32_t seed; // seed as requested at construction (may be LLAMA_DEFAULT_SEED)
    uint32_t seed_cur; // seed actually used for the rng (resolved via get_rng_seed)

    std::mt19937 rng;
};

static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
    return "dist";
}
603
604static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
605 auto * ctx = (llama_sampler_dist *) smpl->ctx;
606
607 // edge cases
608 if (cur_p->size == 0) {
609 cur_p->selected = -1;
610 return;
611 }
612
613 cur_p->selected = 0;
614
615 if (cur_p->size == 1) {
616 cur_p->data[0].p = 1.0f;
617 return;
618 }
619
620 // max logit for numerical stability
621 float max_l = cur_p->data[0].logit;
622 if (!cur_p->sorted) {
623 for (size_t i = 1; i < cur_p->size; ++i) {
624 max_l = std::max(a: max_l, b: cur_p->data[i].logit);
625 }
626 }
627
628 // apply softmax to obtain the probabilities
629 double sum_cum = 0.0f;
630 for (size_t i = 0; i < cur_p->size; ++i) {
631 float p = expf(x: cur_p->data[i].logit - max_l);
632 cur_p->data[i].p = p;
633 sum_cum += p;
634 }
635
636#if 1
637 // sample from the obtained probabilities and normalize the probs in a single pass
638 // this is ~3x faster on Mac with full gpt-oss vocab than the version below
639 //
640 std::uniform_real_distribution<double> dist(0.0f, 1.0f);
641 const double rnd = dist(ctx->rng);
642
643 double sum_run = 0.0f;
644 const double sum_tgt = sum_cum*rnd;
645
646 bool found = false;
647 for (size_t i = 0; i < cur_p->size; ++i) {
648 if (!found) {
649 // accumulate probs until we reach the target sum
650 sum_run += cur_p->data[i].p;
651 if (sum_run >= sum_tgt) {
652 cur_p->selected = i;
653 found = true;
654 }
655 }
656
657 // normalize probs
658 cur_p->data[i].p /= sum_cum;
659 }
660
661 // fallback to the last token (don't think this can happen)
662 assert(found);
663 if (!found) {
664 cur_p->selected = cur_p->size - 1;
665 }
666#else
667 // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
668 for (size_t i = 0; i < cur_p->size; ++i) {
669 cur_p->data[i].p /= sum_cum;
670 }
671
672 cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
673#endif
674}
675
676static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
677 const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
678 auto * result = llama_sampler_init_dist(seed: ctx->seed);
679
680 // copy the state
681 {
682 auto * result_ctx = (llama_sampler_dist *) result->ctx;
683
684 result_ctx->rng = ctx->rng;
685 }
686
687 return result;
688}
689
690static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
691 auto * ctx = (llama_sampler_dist *) smpl->ctx;
692 ctx->seed_cur = get_rng_seed(seed: ctx->seed);
693 ctx->rng.seed(sd: ctx->seed_cur);
694}
695
696static void llama_sampler_dist_free(struct llama_sampler * smpl) {
697 delete (llama_sampler_dist *) smpl->ctx;
698}
699
700static struct llama_sampler_i llama_sampler_dist_i = {
701 /* .name = */ llama_sampler_dist_name,
702 /* .accept = */ nullptr,
703 /* .apply = */ llama_sampler_dist_apply,
704 /* .reset = */ llama_sampler_dist_reset,
705 /* .clone = */ llama_sampler_dist_clone,
706 /* .free = */ llama_sampler_dist_free,
707};
708
709struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
710 auto seed_cur = get_rng_seed(seed);
711 return llama_sampler_init(
712 /* .iface = */ &llama_sampler_dist_i,
713 /* .ctx = */ new llama_sampler_dist {
714 /* .seed = */ seed,
715 /* .seed_cur = */ seed_cur,
716 /* .rng = */ std::mt19937(seed_cur),
717 }
718 );
719}
720
721// top-k
722
723struct llama_sampler_top_k {
724 const int32_t k;
725};
726
727static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
728 return "top-k";
729}
730
731static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
732 auto * ctx = (llama_sampler_top_k *) smpl->ctx;
733 llama_sampler_top_k_impl(cur_p, k: ctx->k);
734}
735
736static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
737 const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
738 return llama_sampler_init_top_k(k: ctx->k);
739}
740
741static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
742 delete (llama_sampler_top_k *) smpl->ctx;
743}
744
745static struct llama_sampler_i llama_sampler_top_k_i = {
746 /* .name = */ llama_sampler_top_k_name,
747 /* .accept = */ nullptr,
748 /* .apply = */ llama_sampler_top_k_apply,
749 /* .reset = */ nullptr,
750 /* .clone = */ llama_sampler_top_k_clone,
751 /* .free = */ llama_sampler_top_k_free,
752};
753
754struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
755 return llama_sampler_init(
756 /* .iface = */ &llama_sampler_top_k_i,
757 /* .ctx = */ new llama_sampler_top_k {
758 /* .k = */ k,
759 }
760 );
761}
762
763// top-p
764
// state for the top-p (nucleus) sampler
struct llama_sampler_top_p {
    const float p; // cumulative probability threshold in (0, 1]
    const size_t min_keep; // lower bound on the number of tokens kept

    // scratch buffer for the adaptive partial sort (avoids per-call allocations)
    std::vector<llama_token_data> buf_sort;
};

static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
    return "top-p";
}
775
776static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
777 auto * ctx = (llama_sampler_top_p *) smpl->ctx;
778
779 if (ctx->p >= 1.0f) {
780 return;
781 }
782
783 llama_sampler_softmax_impl(cur_p, do_sort: false);
784
785 size_t k = cur_p->size;
786 auto * pdata = cur_p->data;
787
788 auto & buf_sort = ctx->buf_sort;
789
790 // if not sorted, try adaptive top-k sorting
791 if (!cur_p->sorted && cur_p->size > 1024) {
792 k = std::min<size_t>(a: 256, b: cur_p->size);
793 llama_token_data_array_partial_sort(cur: *cur_p, npartial: k, res&: buf_sort);
794 pdata = buf_sort.data();
795 } else if (!cur_p->sorted) {
796 // small candidates -> sort inplace
797 llama_token_data_array_partial_sort_inplace(cur_p, npartial: k);
798 }
799
800 // Compute the cumulative probabilities
801 float cum_sum = 0.0f;
802 size_t last_idx = cur_p->size;
803
804 for (size_t i = 0; i < cur_p->size; ++i) {
805 cum_sum += pdata[i].p;
806
807 // Check if the running sum is at least p or if we have kept at least min_keep tokens
808 // we set the last index to i+1 to indicate that the current iterate should be included in the set
809 if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
810 last_idx = i + 1;
811 break;
812 }
813
814 // we exceeded the current top-k heuristic -> increase k and continue
815 if (!cur_p->sorted && i == k - 1) {
816 k = cur_p->size;
817 llama_token_data_array_partial_sort(cur: *cur_p, npartial: k, res&: buf_sort);
818 pdata = buf_sort.data();
819 }
820 }
821
822 // Resize the output vector to keep only the top-p tokens
823 if (!cur_p->sorted) {
824 std::copy(first: buf_sort.data(), last: buf_sort.data() + last_idx, result: cur_p->data);
825 cur_p->sorted = true;
826 }
827
828 cur_p->size = last_idx;
829}
830
831static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
832 const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
833 return llama_sampler_init_top_p(p: ctx->p, min_keep: ctx->min_keep);
834}
835
836static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
837 delete (llama_sampler_top_p *) smpl->ctx;
838}
839
840static struct llama_sampler_i llama_sampler_top_p_i = {
841 /* .name = */ llama_sampler_top_p_name,
842 /* .accept = */ nullptr,
843 /* .apply = */ llama_sampler_top_p_apply,
844 /* .reset = */ nullptr,
845 /* .clone = */ llama_sampler_top_p_clone,
846 /* .free = */ llama_sampler_top_p_free,
847};
848
849struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
850 return llama_sampler_init(
851 /* .iface = */ &llama_sampler_top_p_i,
852 /* .ctx = */ new llama_sampler_top_p {
853 /* .p = */ p,
854 /* .min_keep = */ min_keep,
855 /* .buf_sort = */ {},
856 }
857 );
858}
859
860// min-p
861
// state for the min-p sampler
struct llama_sampler_min_p {
    const float p; // minimum probability relative to the most likely token
    const size_t min_keep; // lower bound on the number of tokens kept
};

static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
    return "min-p";
}
870
871static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
872 auto * ctx = (llama_sampler_min_p *) smpl->ctx;
873
874 if (ctx->p <= 0.0f || !cur_p->size) {
875 return;
876 }
877
878 bool min_p_applied = false;
879
880 // if the cur_p aren't sorted, try the unsorted implementation first
881 if (!cur_p->sorted) {
882 std::vector<llama_token_data> filtered_tokens;
883
884 float max_logit = -FLT_MAX;
885 for (size_t i = 0; i < cur_p->size; ++i) {
886 max_logit = std::max(a: max_logit, b: cur_p->data[i].logit);
887 }
888 const float min_logit = max_logit + logf(x: ctx->p); // min logit for p_i >= p * p_max
889
890 for (size_t i = 0; i < cur_p->size; ++i) {
891 if (cur_p->data[i].logit >= min_logit) {
892 filtered_tokens.push_back(x: cur_p->data[i]);
893 }
894 }
895
896 // if we have enough values the operation was a success
897 if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
898 std::copy(first: filtered_tokens.begin(), last: filtered_tokens.end(), result: cur_p->data);
899 cur_p->size = filtered_tokens.size();
900 min_p_applied = true;
901 }
902 }
903
904 // if the cur_p are sorted or the unsorted implementation failed, use this implementation
905 if (!min_p_applied) {
906 // Sort the logits in descending order
907 if (!cur_p->sorted) {
908 llama_token_data_array_partial_sort_inplace(cur_p, npartial: cur_p->size);
909 }
910
911 const float min_logit = cur_p->data[0].logit + logf(x: ctx->p); // min logit for p_i >= p * p_max
912 size_t i = 1; // first token always matches
913
914 for (; i < cur_p->size; ++i) {
915 if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
916 break; // prob too small
917 }
918 }
919
920 // Resize the output vector to keep only the matching tokens
921 cur_p->size = i;
922 }
923}
924
925static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
926 const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
927 return llama_sampler_init_min_p(p: ctx->p, min_keep: ctx->min_keep);
928}
929
930static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
931 delete (llama_sampler_min_p *) smpl->ctx;
932}
933
934static struct llama_sampler_i llama_sampler_min_p_i = {
935 /* .name = */ llama_sampler_min_p_name,
936 /* .accept = */ nullptr,
937 /* .apply = */ llama_sampler_min_p_apply,
938 /* .reset = */ nullptr,
939 /* .clone = */ llama_sampler_min_p_clone,
940 /* .free = */ llama_sampler_min_p_free,
941};
942
943struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
944 return llama_sampler_init(
945 /* .iface = */ &llama_sampler_min_p_i,
946 /* .ctx = */ new llama_sampler_min_p {
947 /* .p = */ p,
948 /* .min_keep = */ min_keep,
949 }
950 );
951}
952
953// typical
954
// state for the locally-typical sampler
struct llama_sampler_typical {
    const float p; // cumulative probability threshold over the typicality-ordered tokens
    const size_t min_keep; // lower bound on the number of tokens kept
};

static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
    return "typical";
}
963
964static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
965 auto * ctx = (llama_sampler_typical *) smpl->ctx;
966
967 // Reference implementation:
968 // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
969 if (ctx->p >= 1.0f) {
970 return;
971 }
972
973 // Compute the softmax of logits and calculate entropy
974 llama_sampler_softmax_impl(cur_p, do_sort: true);
975
976 float entropy = 0.0f;
977 for (size_t i = 0; i < cur_p->size; ++i) {
978 entropy += -cur_p->data[i].p * logf(x: cur_p->data[i].p);
979 }
980
981 // Compute the absolute difference between negative log probability and entropy for each candidate
982 std::vector<float> shifted_scores;
983 for (size_t i = 0; i < cur_p->size; ++i) {
984 float shifted_score = fabsf(x: -logf(x: cur_p->data[i].p) - entropy);
985 shifted_scores.push_back(x: shifted_score);
986 }
987
988 // Sort tokens based on the shifted_scores and their corresponding indices
989 std::vector<size_t> indices(cur_p->size);
990 std::iota(first: indices.begin(), last: indices.end(), value: 0);
991
992 std::sort(first: indices.begin(), last: indices.end(), comp: [&](size_t a, size_t b) {
993 return shifted_scores[a] < shifted_scores[b];
994 });
995
996 // Compute the cumulative probabilities
997 float cum_sum = 0.0f;
998 size_t last_idx = indices.size();
999
1000 for (size_t i = 0; i < indices.size(); ++i) {
1001 size_t idx = indices[i];
1002 cum_sum += cur_p->data[idx].p;
1003
1004 // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
1005 if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
1006 last_idx = i + 1;
1007 break;
1008 }
1009 }
1010
1011 // Resize the output vector to keep only the locally typical tokens
1012 std::vector<llama_token_data> cur_p_new;
1013 for (size_t i = 0; i < last_idx; ++i) {
1014 size_t idx = indices[i];
1015 cur_p_new.push_back(x: cur_p->data[idx]);
1016 }
1017
1018 // Replace the data in cur_p with the cur_p_new data
1019 std::copy(first: cur_p_new.begin(), last: cur_p_new.end(), result: cur_p->data);
1020 cur_p->size = cur_p_new.size();
1021 cur_p->sorted = false;
1022}
1023
1024static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
1025 const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
1026 return llama_sampler_init_typical(p: ctx->p, min_keep: ctx->min_keep);
1027}
1028
1029static void llama_sampler_typical_free(struct llama_sampler * smpl) {
1030 delete (llama_sampler_typical *) smpl->ctx;
1031}
1032
1033static struct llama_sampler_i llama_sampler_typical_i = {
1034 /* .name = */ llama_sampler_typical_name,
1035 /* .accept = */ nullptr,
1036 /* .apply = */ llama_sampler_typical_apply,
1037 /* .reset = */ nullptr,
1038 /* .clone = */ llama_sampler_typical_clone,
1039 /* .free = */ llama_sampler_typical_free,
1040};
1041
1042struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
1043 return llama_sampler_init(
1044 /* .iface = */ &llama_sampler_typical_i,
1045 /* .ctx = */ new llama_sampler_typical {
1046 /* .p = */ p,
1047 /* .min_keep = */ min_keep,
1048 }
1049 );
1050}
1051
1052// temp
1053
// state for the temperature-scaling sampler
struct llama_sampler_temp {
    const float temp; // temperature; values <= 0.0f degenerate to greedy (argmax)
};

static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
    return "temp";
}
1061
1062static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1063 const auto * ctx = (llama_sampler_temp *) smpl->ctx;
1064
1065 llama_sampler_temp_impl(cur_p, temp: ctx->temp);
1066}
1067
1068static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
1069 const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
1070 return llama_sampler_init_temp(t: ctx->temp);
1071}
1072
1073static void llama_sampler_temp_free(struct llama_sampler * smpl) {
1074 delete (llama_sampler_temp *) smpl->ctx;
1075}
1076
1077static struct llama_sampler_i llama_sampler_temp_i = {
1078 /* .name = */ llama_sampler_temp_name,
1079 /* .accept = */ nullptr,
1080 /* .apply = */ llama_sampler_temp_apply,
1081 /* .reset = */ nullptr,
1082 /* .clone = */ llama_sampler_temp_clone,
1083 /* .free = */ llama_sampler_temp_free,
1084};
1085
1086struct llama_sampler * llama_sampler_init_temp(float temp) {
1087 return llama_sampler_init(
1088 /* .iface = */ &llama_sampler_temp_i,
1089 /* .ctx = */ new llama_sampler_temp {
1090 /*.temp = */ temp,
1091 }
1092 );
1093}
1094
1095// temp-ext
1096
// state for the extended (entropy-based dynamic) temperature sampler
struct llama_sampler_temp_ext {
    const float temp; // base temperature
    const float delta; // max deviation from `temp`; <= 0 disables the dynamic behavior
    const float exponent; // exponent of the entropy -> temperature mapping
};

static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
    return "temp-ext";
}
1106
// entropy-adaptive temperature: when delta > 0, pick an effective temperature
// in [temp - delta, temp + delta] based on the normalized entropy of the
// candidate distribution (peaky distribution -> low temp, flat -> high temp),
// then re-normalize the probabilities; when delta <= 0, fall back to plain
// fixed-temperature scaling
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
    if (ctx->delta > 0) {
        const float min_temp = std::max(a: 0.0f, b: ctx->temp - ctx->delta);
        const float max_temp = ctx->temp + ctx->delta;

        float exponent_val = ctx->exponent;

        // no need to do anything if there is only one (or zero) candidates
        if (cur_p->size <= 1) {
            return;
        }

        // Calculate maximum possible entropy (entropy of the uniform distribution)
        float max_entropy = -logf(x: 1.0f / cur_p->size);

        llama_sampler_softmax_impl(cur_p, do_sort: true);

        // Calculate entropy of the softmax probabilities
        float entropy = 0.0f;
        for (size_t i = 0; i < cur_p->size; ++i) {
            float prob = cur_p->data[i].p;
            if (prob > 0.0f) { // Ensure no log(0)
                entropy -= prob * logf(x: prob);
            }
        }

        // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
        float normalized_entropy = entropy / max_entropy;

        // Map the normalized entropy to the desired temperature range using the power function
        float dyn_temp = min_temp + (max_temp - min_temp) * powf(x: normalized_entropy, y: exponent_val);

    #ifdef DEBUG
        LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
        LLAMA_LOG_INFO("Entropy: %f\n", entropy);
        LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
        LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
        LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
        LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
    #endif

        // Apply the dynamically calculated temperature scaling
        llama_sampler_temp_impl(cur_p, temp: dyn_temp);

        // Re-compute softmax probabilities after scaling logits with dynamic temperature
        // (done in double precision; logits are shifted by the max for numerical stability)
        const double max_l_double = cur_p->data[0].logit;

        double cum_sum_double = 0.0;
        for (size_t i = 0; i < cur_p->size; ++i) {
            double p = exp(x: cur_p->data[i].logit - max_l_double);
            cur_p->data[i].p = p; // Store the scaled probability
            cum_sum_double += p;
        }

        for (size_t i = 0; i < cur_p->size; ++i) {
            cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
        }

    #ifdef DEBUG
        // Print the updated top 25 probabilities after temperature scaling
        LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
        for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
            LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
        }
    #endif
    } else {
        llama_sampler_temp_impl(cur_p, temp: ctx->temp);
    }
}
1177
1178static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
1179 const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
1180 return llama_sampler_init_temp_ext(t: ctx->temp, delta: ctx->delta, exponent: ctx->exponent);
1181}
1182
1183static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
1184 delete (llama_sampler_temp_ext *) smpl->ctx;
1185}
1186
// dispatch table for the extended-temperature sampler
// (accept/reset are null: no per-sequence state)
static struct llama_sampler_i llama_sampler_temp_ext_i = {
    /* .name = */ llama_sampler_temp_ext_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_temp_ext_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_temp_ext_clone,
    /* .free = */ llama_sampler_temp_ext_free,
};
1195
1196struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1197 return llama_sampler_init(
1198 /* .iface = */ &llama_sampler_temp_ext_i,
1199 /* .ctx = */ new llama_sampler_temp_ext {
1200 /* .temp = */ temp,
1201 /* .delta = */ delta,
1202 /* .exponent = */ exponent,
1203 }
1204 );
1205}
1206
1207// xtc
1208
// parameters and RNG state for the XTC ("exclude top choices") sampler
struct llama_sampler_xtc {
    const float probability; // chance of applying XTC at all on a given step
    const float threshold;   // tokens with p >= threshold are candidates for removal
    const size_t min_keep;   // never shrink the candidate list below this size

    const uint32_t seed; // seed as requested by the user
    uint32_t seed_cur;   // effective seed actually used for the RNG

    std::mt19937 rng;
};
1219
// interface callback: human-readable sampler name
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
    return "xtc";
}
1223
// XTC sampling: with probability `probability`, remove all but the last of the
// tokens whose probability is >= threshold, as long as at least min_keep
// candidates remain afterwards
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_xtc *) smpl->ctx;

    // disabled configurations and degenerate candidate sets pass through unchanged
    if (ctx->probability <= 0.0f
        || ctx->threshold > 0.5f
        || cur_p->size < 2) {
        return;
    }

    // roll the dice: XTC only triggers `probability` fraction of the time
    std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
    float chance = distribution(ctx->rng);
    if (chance > ctx->probability) {
        return;
    }

    llama_sampler_softmax_impl(cur_p, do_sort: true);

    // index of the last (least probable) token still at/above the threshold
    int pos_last = 0;

    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].p >= ctx->threshold) {
            pos_last = i;
        } else {
            break;
        }
    }

    // drop the tokens in front of pos_last (the "top choices") by advancing
    // the data pointer, but only if enough candidates would remain
    if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
        cur_p->data += pos_last;
        cur_p->size -= pos_last;
    }
}
1256
1257static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
1258 const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
1259 auto * result = llama_sampler_init_xtc(p: ctx->probability, t: ctx->threshold, min_keep: ctx->min_keep, seed: ctx->seed);
1260
1261 // copy the state
1262 {
1263 auto * result_ctx = (llama_sampler_xtc *) result->ctx;
1264
1265 result_ctx->rng = ctx->rng;
1266 }
1267
1268 return result;
1269}
1270
1271static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
1272 delete (llama_sampler_xtc *) smpl->ctx;
1273}
1274
1275static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1276 auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1277 ctx->seed_cur = get_rng_seed(seed: ctx->seed);
1278 ctx->rng.seed(sd: ctx->seed_cur);
1279}
1280
// dispatch table for the XTC sampler (reset reseeds the RNG; accept is unused)
static struct llama_sampler_i llama_sampler_xtc_i = {
    /* .name = */ llama_sampler_xtc_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sample_xtc_apply,
    /* .reset = */ llama_sampler_xtc_reset,
    /* .clone = */ llama_sampler_xtc_clone,
    /* .free = */ llama_sampler_xtc_free,
};
1289
1290struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1291 auto seed_cur = get_rng_seed(seed);
1292 return llama_sampler_init(
1293 /* .iface = */ &llama_sampler_xtc_i,
1294 /* .ctx = */ new llama_sampler_xtc {
1295 /* .probability = */ p,
1296 /* .threshold = */ t,
1297 /* .min_keep = */ min_keep,
1298 /* .seed = */ seed,
1299 /* .seed_cur = */ seed_cur,
1300 /* .rng = */ std::mt19937(seed_cur),
1301 }
1302 );
1303}
1304
1305// mirostat
1306
// parameters and adaptive state for the Mirostat (v1) sampler
struct llama_sampler_mirostat {
    const int32_t n_vocab; // vocabulary size, used when deriving the top-k cutoff

    const uint32_t seed; // seed as requested by the user
    uint32_t seed_cur;   // effective seed actually used for the RNG

    const float tau; // target surprise (cross-entropy), in bits
    const float eta; // learning rate for the mu update

    const int32_t m; // number of top tokens used to estimate the Zipf exponent

    float mu; // adaptive threshold, initialized to 2*tau and updated each step

    std::mt19937 rng;
};
1322
// interface callback: human-readable sampler name
static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
    return "mirostat";
}
1326
// Mirostat v1: estimate the Zipf exponent s_hat from the top-m token
// probabilities, derive a top-k cutoff targeting surprise mu, sample from the
// truncated distribution, then nudge mu toward the target surprise tau
static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_mirostat *) smpl->ctx;

    llama_sampler_softmax_impl(cur_p, do_sort: true);

    // Estimate s_hat using the most probable m tokens
    float s_hat = 0.0;
    float sum_ti_bi = 0.0;
    float sum_ti_sq = 0.0;
    for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
        float t_i = logf(x: float(i + 2) / float(i + 1));
        float b_i = logf(x: cur_p->data[i].p / cur_p->data[i + 1].p);
        sum_ti_bi += t_i * b_i;
        sum_ti_sq += t_i * t_i;
    }
    // NOTE(review): if cur_p->size <= 1 or m <= 1 the loop above does not run
    // and this is 0/0 -- presumably callers always supply >= 2 candidates; confirm
    s_hat = sum_ti_bi / sum_ti_sq;

    // Compute k from the estimated s_hat and target surprise value
    float epsilon_hat = s_hat - 1;
    float k = powf(x: (epsilon_hat * powf(x: 2, y: ctx->mu)) / (1 - powf(x: ctx->n_vocab, y: -epsilon_hat)), y: 1 / s_hat);

    llama_sampler_top_k_impl(cur_p, k: std::max(a: int(k), b: 1));

    llama_sampler_softmax_impl(cur_p, do_sort: true);

    // sample a token index from the truncated, re-normalized distribution
    const int idx = llama_sample_dist(cur_p, rng&: ctx->rng);

    cur_p->selected = idx;

    // surprise of the sampled token, in bits
    float observed_surprise = -log2f(x: cur_p->data[idx].p);
    float e = observed_surprise - ctx->tau;

    // Update mu using the learning rate and error
    ctx->mu = ctx->mu - ctx->eta * e;
}
1362
1363static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
1364 const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
1365 auto * result = llama_sampler_init_mirostat(n_vocab: ctx->n_vocab, seed: ctx->seed, tau: ctx->tau, eta: ctx->eta, m: ctx->m);
1366
1367 // copy the state
1368 {
1369 auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
1370
1371 result_ctx->mu = ctx->mu;
1372 result_ctx->rng = ctx->rng;
1373 }
1374
1375 return result;
1376}
1377
1378static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
1379 auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
1380 ctx->mu = 2.0f*ctx->tau;
1381 ctx->seed_cur = get_rng_seed(seed: ctx->seed);
1382 ctx->rng.seed(sd: ctx->seed_cur);
1383}
1384
1385static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
1386 delete (llama_sampler_mirostat *) smpl->ctx;
1387}
1388
// dispatch table for the Mirostat v1 sampler
static struct llama_sampler_i llama_sampler_mirostat_i = {
    /* .name = */ llama_sampler_mirostat_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_mirostat_apply,
    /* .reset = */ llama_sampler_mirostat_reset,
    /* .clone = */ llama_sampler_mirostat_clone,
    /* .free = */ llama_sampler_mirostat_free,
};
1397
1398struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1399 auto seed_cur = get_rng_seed(seed);
1400 return llama_sampler_init(
1401 /* .iface = */ &llama_sampler_mirostat_i,
1402 /* .ctx = */ new llama_sampler_mirostat {
1403 /* .n_vocab = */ n_vocab,
1404 /* .seed = */ seed,
1405 /* .seed_cur = */ seed_cur,
1406 /* .tau = */ tau,
1407 /* .eta = */ eta,
1408 /* .m = */ m,
1409 /* .mu = */ 2.0f*tau,
1410 /* .rng = */ std::mt19937(seed_cur),
1411 }
1412 );
1413}
1414
1415// mirostat v2
1416
// parameters and adaptive state for the Mirostat v2 sampler
struct llama_sampler_mirostat_v2 {
    const uint32_t seed; // seed as requested by the user
    uint32_t seed_cur;   // effective seed actually used for the RNG

    const float tau; // target surprise (cross-entropy), in bits
    const float eta; // learning rate for the mu update

    float mu; // adaptive surprise threshold, initialized to 2*tau

    std::mt19937 rng;
};
1428
// interface callback: human-readable sampler name
static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
    return "mirostat-v2";
}
1432
// Mirostat v2: truncate the sorted candidate list at the first token whose
// surprise (-log2 p) exceeds mu, sample from the re-normalized remainder,
// then nudge mu toward the target surprise tau
static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;

    llama_sampler_softmax_impl(cur_p, do_sort: true);

    // Truncate the words with surprise values greater than mu
    cur_p->size = std::distance(first: cur_p->data, last: std::find_if(first: cur_p->data, last: cur_p->data + cur_p->size, pred: [&](const llama_token_data & candidate) {
        return -log2f(x: candidate.p) > ctx->mu;
    }));

    // always keep at least the most probable token
    if (cur_p->size == 0) {
        cur_p->size = 1;
    }

    // Normalize the probabilities of the remaining words
    llama_sampler_softmax_impl(cur_p, do_sort: true);

    const int idx = llama_sample_dist(cur_p, rng&: ctx->rng);

    cur_p->selected = idx;

    // surprise of the sampled token, in bits
    float observed_surprise = -log2f(x: cur_p->data[idx].p);
    float e = observed_surprise - ctx->tau;

    // Update mu using the learning rate and error
    ctx->mu = ctx->mu - ctx->eta * e;
}
1460
1461static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
1462 auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
1463 ctx->mu = 2.0f*ctx->tau;
1464 ctx->seed_cur = get_rng_seed(seed: ctx->seed);
1465 ctx->rng.seed(sd: ctx->seed_cur);
1466}
1467
1468static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
1469 const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
1470
1471 auto * result = llama_sampler_init_mirostat_v2(seed: ctx->seed, tau: ctx->tau, eta: ctx->eta);
1472
1473 // copy the state
1474 {
1475 auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
1476
1477 result_ctx->mu = ctx->mu;
1478 result_ctx->rng = ctx->rng;
1479 }
1480
1481 return result;
1482}
1483
1484static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
1485 delete (llama_sampler_mirostat_v2 *) smpl->ctx;
1486}
1487
// dispatch table for the Mirostat v2 sampler
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
    /* .name = */ llama_sampler_mirostat_v2_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_mirostat_v2_apply,
    /* .reset = */ llama_sampler_mirostat_v2_reset,
    /* .clone = */ llama_sampler_mirostat_v2_clone,
    /* .free = */ llama_sampler_mirostat_v2_free,
};
1496
1497struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
1498 auto seed_cur = get_rng_seed(seed);
1499 return llama_sampler_init(
1500 /* .iface = */ &llama_sampler_mirostat_v2_i,
1501 /* .ctx = */ new llama_sampler_mirostat_v2 {
1502 /* .seed = */ seed,
1503 /* .seed_cur = */ seed_cur,
1504 /* .tau = */ tau,
1505 /* .eta = */ eta,
1506 /* .mu = */ 2.0f*tau,
1507 /* .rng = */ std::mt19937(seed_cur),
1508 }
1509 );
1510}
1511
1512// grammar
1513
// state for the grammar-constrained sampler
struct llama_sampler_grammar {
    const struct llama_vocab * vocab;

    std::string grammar_str;  // original grammar source, kept for reset/clone
    std::string grammar_root; // root rule name, kept for reset/clone

    struct llama_grammar * grammar; // compiled grammar; nullptr when disabled
};
1522
// interface callback: human-readable sampler name
static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
    return "grammar";
}
1526
1527static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
1528 auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1529 if (ctx->grammar) {
1530 llama_grammar_accept_impl(grammar&: *ctx->grammar, token);
1531 }
1532}
1533
1534static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1535 auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1536 if (ctx->grammar) {
1537 llama_grammar_apply_impl(grammar: *ctx->grammar, cur_p);
1538 }
1539}
1540
// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle.
// (the definition follows the interface table below)
static struct llama_sampler * llama_sampler_init_grammar_impl(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                             bool lazy,
                     const char ** trigger_words,
                            size_t num_trigger_words,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens,
                     const char ** trigger_patterns,
                            size_t num_trigger_patterns);
1553
// rebuild the grammar from its stored source string, restoring the parser to
// its initial state; trigger patterns/tokens are carried over from the old grammar
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_grammar *) smpl->ctx;
    if (!ctx->grammar) {
        return;
    }

    // collect C-string views of the existing trigger patterns; these stay
    // valid because the old grammar is freed only after the init call below
    std::vector<const char *> trigger_patterns_c;
    trigger_patterns_c.reserve(n: ctx->grammar->trigger_patterns.size());
    for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
        trigger_patterns_c.push_back(x: trigger_pattern.pattern.c_str());
    }

    auto * grammar_new = llama_grammar_init_impl(vocab: ctx->grammar->vocab, grammar_str: ctx->grammar_str.c_str(), grammar_root: ctx->grammar_root.c_str(),
            lazy: ctx->grammar->lazy, trigger_patterns: trigger_patterns_c.data(), num_trigger_patterns: trigger_patterns_c.size(),
            trigger_tokens: ctx->grammar->trigger_tokens.data(), num_trigger_tokens: ctx->grammar->trigger_tokens.size());

    llama_grammar_free_impl(grammar: ctx->grammar);
    ctx->grammar = grammar_new;
}
1573
// clone by creating an empty (grammar-less) sampler and then deep-copying the
// grammar state into it
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;

    // grammar_str = nullptr produces a sampler with no compiled grammar
    auto * result = llama_sampler_init_grammar_impl(vocab: ctx->vocab, grammar_str: nullptr, grammar_root: nullptr, lazy: false, trigger_words: nullptr, num_trigger_words: 0, trigger_tokens: nullptr, num_trigger_tokens: 0, trigger_patterns: nullptr, num_trigger_patterns: 0);
    GGML_ASSERT(result);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_grammar *) result->ctx;

        if (ctx->grammar) {
            // copy the source strings so a later reset can rebuild the grammar
            result_ctx->grammar_str = ctx->grammar_str;
            result_ctx->grammar_root = ctx->grammar_root;

            result_ctx->grammar = llama_grammar_clone_impl(grammar: *ctx->grammar);
        }
    }

    return result;
}
1594
1595static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
1596 const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
1597
1598 if (ctx->grammar) {
1599 llama_grammar_free_impl(grammar: ctx->grammar);
1600 }
1601
1602 delete ctx;
1603}
1604
// dispatch table for the grammar sampler (accept advances the parser state)
static struct llama_sampler_i llama_sampler_grammar_i = {
    /* .name = */ llama_sampler_grammar_name,
    /* .accept = */ llama_sampler_grammar_accept_impl,
    /* .apply = */ llama_sampler_grammar_apply,
    /* .reset = */ llama_sampler_grammar_reset,
    /* .clone = */ llama_sampler_grammar_clone,
    /* .free = */ llama_sampler_grammar_free,
};
1613
1614static struct llama_sampler * llama_sampler_init_grammar_impl(
1615 const struct llama_vocab * vocab,
1616 const char * grammar_str,
1617 const char * grammar_root,
1618 bool lazy,
1619 const char ** trigger_words,
1620 size_t num_trigger_words,
1621 const llama_token * trigger_tokens,
1622 size_t num_trigger_tokens,
1623 const char ** trigger_patterns,
1624 size_t num_trigger_patterns) {
1625 auto * ctx = new llama_sampler_grammar;
1626
1627 if (grammar_str != nullptr && grammar_str[0] != '\0') {
1628 // TODO: remove trigger_words support.
1629 if (trigger_words != nullptr && num_trigger_words > 0) {
1630 GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
1631 std::string trigger_pattern("[\\s\\S]*?(");
1632 for (size_t i = 0; i < num_trigger_words; ++i) {
1633 static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
1634 if (i > 0) {
1635 trigger_pattern += "|";
1636 }
1637 trigger_pattern += std::regex_replace(s: trigger_words[i], e: special_chars, fmt: "\\$0");
1638 }
1639 trigger_pattern += ")[\\s\\S]*";
1640 const auto * trigger_pattern_c = trigger_pattern.c_str();
1641 trigger_patterns = &trigger_pattern_c;
1642 num_trigger_patterns = 1;
1643 }
1644 *ctx = {
1645 /* .vocab = */ vocab,
1646 /* .grammar_str = */ grammar_str,
1647 /* .grammar_root = */ grammar_root,
1648 /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
1649 };
1650 if (!ctx->grammar) {
1651 delete ctx;
1652 return nullptr;
1653 }
1654 } else {
1655 *ctx = {
1656 /* .vocab = */ vocab,
1657 /* .grammar_str = */ {},
1658 /* .grammar_root = */ {},
1659 /* .grammar = */ nullptr,
1660 };
1661 }
1662
1663 return llama_sampler_init(
1664 /* .iface = */ &llama_sampler_grammar_i,
1665 /* .ctx = */ ctx
1666 );
1667}
1668
// public entry point: eager (non-lazy) grammar sampler with no triggers
struct llama_sampler * llama_sampler_init_grammar(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root) {
    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, trigger_words: nullptr, num_trigger_words: 0, trigger_tokens: nullptr, num_trigger_tokens: 0, trigger_patterns: nullptr, num_trigger_patterns: 0);
}
1675
// public entry point: lazy grammar activated by literal trigger words (deprecated
// in favor of the _patterns variant below)
struct llama_sampler * llama_sampler_init_grammar_lazy(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                     const char ** trigger_words,
                            size_t num_trigger_words,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens) {
    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, trigger_patterns: nullptr, num_trigger_patterns: 0);
}
1686
// public entry point: lazy grammar activated by regex trigger patterns and/or tokens
struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root,
                     const char ** trigger_patterns,
                            size_t num_trigger_patterns,
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens) {
    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words: nullptr, num_trigger_words: 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns);
}
1697
1698// penalties
1699
// parameters and state for the repetition/frequency/presence penalty sampler
struct llama_sampler_penalties {
    const int32_t penalty_last_n; // sliding window size; 0 disables the sampler
    const float penalty_repeat;   // multiplicative repetition penalty (1.0 = off)
    const float penalty_freq;     // per-occurrence additive penalty (0.0 = off)
    const float penalty_present;  // one-time additive penalty (0.0 = off)

    // the last penalty_last_n accepted tokens
    ring_buffer<llama_token> prev;

    // a frequency map to count token occurrences
    std::unordered_map<llama_token, int> token_count;
};
1711
// interface callback: human-readable sampler name
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
    return "penalties";
}
1715
// record an accepted token: maintain the sliding window of the last
// penalty_last_n tokens together with a matching occurrence-count map
static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
    if (ctx->penalty_last_n == 0) {
        return;
    }

    ctx->token_count[token]++;

    // if the ring buffer is full, remove the oldest token
    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
        const auto old = ctx->prev.front();

        ctx->token_count[old]--;
        if (ctx->token_count[old] == 0) {
            ctx->token_count.erase(x: old);
        }
    }

    // pushing into a full ring buffer drops the front element (see push_back)
    ctx->prev.push_back(value: token);

#if 0
    // sanity check
    std::unordered_map<llama_token, int> tmp;
    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
        tmp[ctx->prev.rat(i)]++;
    }

    assert(ctx->token_count == tmp);
#endif
}
1746
// apply repeat/frequency/presence penalties to every candidate token that
// occurs in the recent-token window
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_penalties *) smpl->ctx;

    // fast path: all penalties disabled
    if ((ctx->penalty_last_n == 0) ||
        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
        return;
    }

    // Apply frequency and presence penalties to the cur_p
    for (size_t i = 0; i < cur_p->size; ++i) {
        const auto token_iter = ctx->token_count.find(x: cur_p->data[i].id);
        if (token_iter == ctx->token_count.end()) {
            continue;
        }

        const int count = token_iter->second;

        assert(count > 0 && count <= ctx->penalty_last_n);

        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
        if (cur_p->data[i].logit <= 0) {
            cur_p->data[i].logit *= ctx->penalty_repeat;
        } else {
            cur_p->data[i].logit /= ctx->penalty_repeat;
        }

        // frequency penalty scales with the count; presence penalty is a flat offset
        cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
    }

    // logits changed, so any previous sort order is no longer valid
    cur_p->sorted = false;
}
1779
1780static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
1781 auto * ctx = (llama_sampler_penalties *) smpl->ctx;
1782 ctx->prev.clear();
1783 ctx->token_count.clear();
1784}
1785
1786static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
1787 const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
1788 auto * result = llama_sampler_init_penalties(
1789 penalty_last_n: ctx->penalty_last_n,
1790 penalty_repeat: ctx->penalty_repeat,
1791 penalty_freq: ctx->penalty_freq,
1792 penalty_present: ctx->penalty_present);
1793
1794 // copy the state
1795 {
1796 auto * result_ctx = (llama_sampler_penalties *) result->ctx;
1797
1798 result_ctx->prev = ctx->prev;
1799 }
1800
1801 return result;
1802}
1803
1804static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
1805 delete (llama_sampler_penalties *) smpl->ctx;
1806}
1807
// dispatch table for the penalties sampler (accept maintains the token window)
static struct llama_sampler_i llama_sampler_penalties_i = {
    /* .name = */ llama_sampler_penalties_name,
    /* .accept = */ llama_sampler_penalties_accept,
    /* .apply = */ llama_sampler_penalties_apply,
    /* .reset = */ llama_sampler_penalties_reset,
    /* .clone = */ llama_sampler_penalties_clone,
    /* .free = */ llama_sampler_penalties_free,
};
1816
1817struct llama_sampler * llama_sampler_init_penalties(
1818 int32_t penalty_last_n,
1819 float penalty_repeat,
1820 float penalty_freq,
1821 float penalty_present) {
1822 penalty_last_n = std::max(a: penalty_last_n, b: 0);
1823
1824 return llama_sampler_init(
1825 /* .iface = */ &llama_sampler_penalties_i,
1826 /* .ctx = */ new llama_sampler_penalties {
1827 /* .penalty_last_n = */ penalty_last_n,
1828 /* .penalty_repeat = */ penalty_repeat,
1829 /* .penalty_freq = */ penalty_freq,
1830 /* .penalty_present = */ penalty_present,
1831 /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1832 /* .token_count = */ {},
1833 }
1834 );
1835}
1836
1837// top-n-sigma
1838
// parameter for the top-n-sigma sampler
struct llama_sampler_top_n_sigma {
    const float n; // keep tokens within n standard deviations of the max logit; <= 0 disables
};
1842
// interface callback: human-readable sampler name
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
    return "top-n-sigma";
}
1846
1847static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1848 auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1849
1850 if (ctx->n <= 0.0f || cur_p->size <= 1) {
1851 return;
1852 }
1853
1854 // find max logit and calculate mean
1855 float max = cur_p->data[0].logit;
1856 float logits_sum = 0;
1857 size_t valid_count = 0;
1858 for (size_t i = 0; i < cur_p->size; ++i) {
1859 // Only count non-negative infinity values
1860 if (cur_p->data[i].logit != -INFINITY) {
1861 if (cur_p->data[i].logit > max) {
1862 max = cur_p->data[i].logit;
1863 }
1864 logits_sum += cur_p->data[i].logit;
1865 valid_count++;
1866 }
1867 }
1868 float mean = valid_count > 0 ? logits_sum/valid_count : 0;
1869
1870 // calculate standard deviation
1871 float acc = 0;
1872 for (size_t i = 0; i < cur_p->size; ++i) {
1873 // Skip -infinity in std calculation
1874 if (cur_p->data[i].logit != -INFINITY) {
1875 acc += pow(x: cur_p->data[i].logit - mean, y: 2);
1876 }
1877 }
1878 float std = valid_count > 0 ? sqrt(x: acc/valid_count) : 0;
1879
1880 // apply mask
1881 for (size_t i = 0; i < cur_p->size; ++i) {
1882 if (cur_p->data[i].logit < max - (ctx->n * std)) {
1883 cur_p->data[i].logit = -INFINITY;
1884 }
1885 }
1886
1887 llama_sampler_softmax_impl(cur_p, do_sort: true);
1888}
1889
1890static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
1891 const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
1892 return llama_sampler_init_top_n_sigma(n: ctx->n);
1893}
1894
1895static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
1896 delete (llama_sampler_top_n_sigma *) smpl->ctx;
1897}
1898
// dispatch table for the top-n-sigma sampler (no per-sequence state)
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
    /* .name = */ llama_sampler_top_n_sigma_name,
    /* .accept = */ nullptr,
    /* .apply = */ llama_sampler_top_n_sigma_apply,
    /* .reset = */ nullptr,
    /* .clone = */ llama_sampler_top_n_sigma_clone,
    /* .free = */ llama_sampler_top_n_sigma_free,
};
1907
1908struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
1909 return llama_sampler_init(
1910 /* .iface = */ &llama_sampler_top_n_sigma_i,
1911 /* .ctx = */ new llama_sampler_top_n_sigma {
1912 /* .n = */ n,
1913 }
1914 );
1915}
1916
1917// DRY
1918
// parameters and state for the DRY (don't-repeat-yourself) sampler
struct llama_sampler_dry {
    int32_t total_context_size; // model context size; caps the penalty window

    const float dry_multiplier;        // penalty scale; 0 disables the sampler
    const float dry_base;              // exponential base; < 1 disables the sampler
    const int32_t dry_allowed_length;  // repeats up to this length are not penalized
    const int32_t dry_penalty_last_n;  // window of recent tokens to scan; -1 = whole context

    // restart ("sequence breaker") head token -> tail token sequences
    std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
    std::vector<int> dry_repeat_count;                         // scratch: per-position repeat lengths
    std::unordered_map<llama_token, int> dry_max_token_repeat; // scratch: longest repeat per token
    ring_buffer<llama_token> last_tokens;                      // recently accepted tokens
};
1932
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
//
// for every vocab token, record how it can overlap the sequence-breaker string
// `str`: tokens containing `str` map to an empty tail; tokens whose suffix
// matches a prefix of `str` map to the tokenization of the remaining part of
// `str` (optionally truncated to max_tail_len tokens)
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
        std::string word = vocab.detokenize(tokens: {token_id}, special: true);
        if (word.find(str: str) != std::string::npos) {
            // the breaker is fully contained in this token: empty tail
            token_sequences.emplace(args&: token_id, args: std::vector<llama_token>());
        } else {
            size_t word_len = word.size();
            size_t str_len = str.size();
            // pos starts at SIZE_MAX so that the first find() begins at index 0
            size_t pos = -1;
            while ((pos = word.find(c: str[0], pos: pos + 1)) != std::string::npos) {
                // check whether word[pos..] is a prefix of str (up to the end of word)
                bool match = true;
                size_t i;
                for (i = 1; i < str_len && i + pos < word_len; ++i) {
                    if (word[pos + i] != str[i]) {
                        match = false;
                        break;
                    }
                }
                if (match) {
                    // tokenize the part of the breaker that extends past this token
                    std::vector<llama_token> tokenization = vocab.tokenize(raw_text: str.substr(pos: i), add_special: false, parse_special: false);
                    if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                        tokenization.resize(new_size: max_tail_len);
                    }

                    // Ensure we don't already have a duplicate matching tokenization
                    auto its = token_sequences.equal_range(x: token_id);
                    bool found = false;
                    for (auto it = its.first; it != its.second; ++it) {
                        if (tokenization == it->second) {
                            found = true;
                            break;
                        }
                    }
                    if (!found) {
                        token_sequences.emplace(args&: token_id, args&: tokenization);
                    }
                }
            }
        }
    }
}
1975
// interface callback: human-readable sampler name
static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
    return "dry";
}
1979
1980static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
1981 auto * ctx = (llama_sampler_dry *) smpl->ctx;
1982 if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1983 return;
1984 }
1985
1986 ctx->last_tokens.push_back(value: token);
1987}
1988
1989// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
1990static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1991 auto * ctx = (llama_sampler_dry *) smpl->ctx;
1992
1993 if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
1994 return;
1995 }
1996
1997 int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(a: ctx->dry_penalty_last_n, b: 0);
1998 int last_n_repeat = std::min(a: std::min(a: (int)ctx->last_tokens.size(), b: effective_dry_penalty_last_n), b: ctx->total_context_size);
1999
2000 if (last_n_repeat <= ctx->dry_allowed_length) {
2001 return;
2002 }
2003
2004 ctx->dry_repeat_count.assign(n: last_n_repeat, val: 0);
2005 ctx->dry_max_token_repeat.clear();
2006
2007 // Step 1: Look for restart sequences to limit the maximum repetition length.
2008 // Work backwards through the context looking for any token that begins a restart sequence.
2009 //
2010 // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
2011 // sequences that together comprise a restart sequence. This allows us to quickly check
2012 // whether each token is the head of a complete sequence. Most restart sequences are actually
2013 // a single token, and for these the "tail" is an empty vector.
2014 //
2015 // If the token is a "head", test all restart sequences that begin with this token
2016 // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
2017 // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
2018 // longest matching sequence (if any) is used to limit the maximum repetition length.
2019 //
2020 // Note that in the case case of a short sequence contained in a longer one, this might fail to
2021 // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
2022 // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
2023 // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
2024 //
2025 // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
2026 // have already clamped the maximum tail sequence length when generating `restart_sequences`.
2027 // With clamping, this scan is O(N) in the context length.
2028
2029 int rep_limit = last_n_repeat;
2030 for (int i = 0; i < last_n_repeat; ++i) {
2031 llama_token token = ctx->last_tokens.rat(i);
2032 auto its = ctx->dry_processed_breakers.equal_range(x: token);
2033 if (its.first == ctx->dry_processed_breakers.end()) {
2034 continue;
2035 }
2036 int longest_match = -1;
2037 for (auto it = its.first; it != its.second; ++it) {
2038 // Note that (*it) does not contain the head character, so seq_len will be
2039 // the restart sequence length minus 1.
2040 // In the common case of a single-token restart sequence, (*it) will be empty
2041 // and we will trivially match.
2042 int seq_len = (int)it->second.size();
2043 if (seq_len > longest_match && seq_len <= (int)i) {
2044 bool match = true;
2045 for (int offset = 0; offset < seq_len; ++offset) {
2046 // The -1 when indexing `last_tokens` is because we already matched the head.
2047 if (it->second[offset] != ctx->last_tokens.rat(i: i - offset - 1)) {
2048 match = false;
2049 break;
2050 }
2051 }
2052 if (match) {
2053 longest_match = seq_len;
2054 }
2055 }
2056 }
2057 if (longest_match >= 0) {
2058 // We found a restart sequence starting `i` tokens from the end and continuing for
2059 // `longest_match` tokens.
2060 rep_limit = i - longest_match;
2061 break;
2062 }
2063 }
2064 if (rep_limit < ctx->dry_allowed_length) {
2065 return;
2066 }
2067
2068 // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
2069 // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
2070 // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
2071 //
2072 // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
2073 // https://ivanyu.me/blog/2014/10/15/z-algorithm/
2074 //
2075 // The code below is adapted from the public domain implementation by the same author here:
2076 // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
2077 //
2078 // Example:
2079 // Last N tokens: a b c c b c y a b c
2080 // Repeat counts: 0 0 3 1 0 2 0 0 0 0
2081 // ^
2082 // This `3` means that the last three tokens of the context (a b c) also appear here.
2083 //
2084 // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
2085 // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
2086 // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
2087 // ensure that the inner while loops only examine each token in the context once as the outer
2088 // for loop iterates over the context.
2089
2090 {
2091 const int last = last_n_repeat - 1;
2092
2093 int rt = 0;
2094 int lt = 0;
2095
2096 for (int k = 1; k < last_n_repeat; ++k) {
2097 if (k > rt) {
2098 // If k is outside the current Z-box, do naive computation.
2099 int n = 0;
2100 while (n + k < last_n_repeat && ctx->last_tokens.rat(i: n) == ctx->last_tokens.rat(i: n+k)) {
2101 ++n;
2102 }
2103 ctx->dry_repeat_count[last - k] = std::min(a: n, b: rep_limit);
2104 if (n > 0) {
2105 lt = k;
2106 rt = k + n - 1;
2107 }
2108 } else {
2109 // If k is inside the current Z-box, consider two cases.
2110
2111 int p = k - lt; // Pair index.
2112 int right_part_len = rt - k + 1;
2113
2114 if (ctx->dry_repeat_count[last - p] < right_part_len) {
2115 int n = std::min(a: ctx->dry_repeat_count[last - p], b: rep_limit);
2116 ctx->dry_repeat_count[last - k] = n;
2117 } else {
2118 int i = rt + 1;
2119 while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i: i - k)) {
2120 i += 1;
2121 }
2122
2123 int n = std::min(a: i - k, b: rep_limit);
2124 ctx->dry_repeat_count[last - k] = n;
2125 lt = k;
2126 rt = i - 1;
2127 }
2128 }
2129 }
2130 }
2131
2132 // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
2133 // that would be generated by emitting each new token that would extend a sequence.
2134 //
2135 // Following the same example as above:
2136 // Last N tokens: a b c c b c y a b c
2137 // Repeat counts: 0 0 3 1 0 2 0 0 0 0
2138 //
2139 // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
2140 // c: 3 -> 4 (from `a b c` to `a b c c`)
2141 // b: 1 -> 2 (from `c` to `c b`)
2142 // y: 2 -> 3 (from `b c` to `b c y`)
2143
2144 for (int i = 0; i < last_n_repeat - 1; ++i) {
2145 int repeat_len = ctx->dry_repeat_count[i];
2146 if (repeat_len >= ctx->dry_allowed_length) {
2147 // This token ends a repeat, so the next token would continue one.
2148 // By convention, the value of `repeat_len` only includes the tokens currently
2149 // in the context, not the new token that would be added.
2150 llama_token token = ctx->last_tokens.rat(i: last_n_repeat - 2 - i);
2151 // Track the maximum sequence ending in this token.
2152 const auto& it = ctx->dry_max_token_repeat.find(x: token);
2153 if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
2154 ctx->dry_max_token_repeat[token] = repeat_len;
2155 }
2156 }
2157 }
2158
2159 // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
2160
2161 // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
2162 // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
2163 const float FLOAT_MAX_LOG = 88.7228391f;
2164 int max_exponent = 0;
2165 if (ctx->dry_base > 1.000001f) {
2166 max_exponent = FLOAT_MAX_LOG / std::log(x: ctx->dry_base);
2167 }
2168
2169 for (size_t i = 0; i < cur_p->size; ++i) {
2170 const auto& af_kvp = ctx->dry_max_token_repeat.find(x: cur_p->data[i].id);
2171 if (af_kvp != ctx->dry_max_token_repeat.end()) {
2172 // Check all sequence breakers starting with this token
2173 auto range = ctx->dry_processed_breakers.equal_range(x: cur_p->data[i].id);
2174 bool is_single_token_breaker = false;
2175
2176 for (auto it = range.first; it != range.second; ++it) {
2177 if (it->second.empty()) {
2178 is_single_token_breaker = true;
2179 break;
2180 }
2181 }
2182
2183 // Apply penalty only if it's not a single-token sequence breaker
2184 if (!is_single_token_breaker) {
2185 int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
2186 if (max_exponent > 0 && repeat_exp > max_exponent) {
2187 repeat_exp = max_exponent;
2188 }
2189 float penalty = ctx->dry_multiplier * std::pow(x: ctx->dry_base, y: repeat_exp);
2190 cur_p->data[i].logit -= penalty;
2191 }
2192 }
2193 }
2194
2195 cur_p->sorted = false;
2196}
2197
2198static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
2199 auto * ctx = (llama_sampler_dry *) smpl->ctx;
2200 ctx->last_tokens.clear();
2201 ctx->dry_repeat_count.clear();
2202 ctx->dry_max_token_repeat.clear();
2203}
2204
2205static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
2206 const auto * ctx = (llama_sampler_dry *) smpl->ctx;
2207
2208 llama_vocab dummy_vocab;
2209
2210 // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
2211 auto * result = llama_sampler_init_dry(vocab: &dummy_vocab, n_ctx_train: ctx->total_context_size, dry_multiplier: ctx->dry_multiplier, dry_base: ctx->dry_base, dry_allowed_length: ctx->dry_allowed_length, dry_penalty_last_n: ctx->dry_penalty_last_n, NULL, num_breakers: 0);
2212
2213 // Copy the state, including the processed breakers
2214 {
2215 auto * result_ctx = (llama_sampler_dry *) result->ctx;
2216 result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
2217 result_ctx->dry_repeat_count = ctx->dry_repeat_count;
2218 result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
2219 result_ctx->last_tokens = ctx->last_tokens;
2220 }
2221
2222 return result;
2223}
2224
2225static void llama_sampler_dry_free(struct llama_sampler * smpl) {
2226 delete (llama_sampler_dry *) smpl->ctx;
2227}
2228
// vtable for the DRY sampler
static struct llama_sampler_i llama_sampler_dry_i = {
    /* .name   = */ llama_sampler_dry_name,
    /* .accept = */ llama_sampler_dry_accept,
    /* .apply  = */ llama_sampler_dry_apply,
    /* .reset  = */ llama_sampler_dry_reset,
    /* .clone  = */ llama_sampler_dry_clone,
    /* .free   = */ llama_sampler_dry_free,
};
2237
2238struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
2239 int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(a: dry_penalty_last_n, b: 0);
2240 std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
2241 const int MAX_CHAR_LEN = 40;
2242 const int MAX_SEQ_LEN = 20;
2243
2244 const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
2245
2246 if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
2247 // Process sequence breakers
2248 for (size_t i = 0; i < num_breakers; ++i) {
2249 if (seq_breakers[i] == nullptr || std::strlen(s: seq_breakers[i]) == 0) {
2250 LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
2251 continue;
2252 }
2253
2254 std::string sequence_break(seq_breakers[i]);
2255 if (sequence_break.empty()) {
2256 LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
2257 continue;
2258 }
2259
2260 if (sequence_break.size() > MAX_CHAR_LEN) {
2261 LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
2262 sequence_break.resize(n: MAX_CHAR_LEN);
2263 }
2264
2265 get_overlapping_token_sequences(vocab: *vocab, str: sequence_break, token_sequences&: processed_breakers, max_tail_len: MAX_SEQ_LEN);
2266 }
2267 }
2268
2269 return llama_sampler_init(
2270 /* .iface = */ &llama_sampler_dry_i,
2271 /* .ctx = */ new llama_sampler_dry {
2272 /* .total_context_size = */ n_ctx_train,
2273 /* .dry_multiplier = */ dry_multiplier,
2274 /* .dry_base = */ dry_base,
2275 /* .dry_allowed_length = */ dry_allowed_length,
2276 /* .dry_penalty_last_n = */ dry_penalty_last_n,
2277 /* .dry_processed_breakers = */ std::move(processed_breakers),
2278 /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
2279 /* .dry_max_token_repeat = */ {},
2280 /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
2281 }
2282 );
2283}
2284
2285// wrapper for test-sampling.cpp
2286struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
2287 llama_vocab dummy_vocab;
2288 auto * result = llama_sampler_init_dry(vocab: &dummy_vocab, n_ctx_train: context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, num_breakers: 0);
2289 auto * ctx = (llama_sampler_dry *) result->ctx;
2290
2291 // Process the token-based sequence breakers
2292 ctx->dry_processed_breakers.clear();
2293 if (seq_breakers.empty()) {
2294 LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
2295 } else {
2296 for (const auto& breaker : seq_breakers) {
2297 if (breaker.empty()) {
2298 LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
2299 continue;
2300 }
2301 llama_token head_token = breaker[0];
2302 std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
2303 ctx->dry_processed_breakers.emplace(args&: head_token, args: std::move(tail_tokens));
2304 }
2305
2306 if (ctx->dry_processed_breakers.empty()) {
2307 LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
2308 }
2309 }
2310
2311 return result;
2312}
2313
2314// logit-bias
2315
// state for the logit-bias sampler: adds a fixed bias to the logits of selected tokens
struct llama_sampler_logit_bias {
    // vocabulary size the biases were configured for
    const int32_t n_vocab;

    // configured (token, bias) pairs, fixed at construction
    const std::vector<llama_logit_bias> logit_bias;

    // scratch list of biases whose token was not found at its expected index
    // and therefore requires a linear search over the candidates
    std::vector<llama_logit_bias> to_search;
};
2323
// Human-readable identifier for this sampler type.
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
    static const char * k_name = "logit-bias";
    return k_name;
}
2327
2328static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2329 auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
2330
2331 if (ctx->logit_bias.empty()) {
2332 return;
2333 }
2334
2335 ctx->to_search.clear();
2336
2337 // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
2338 for (const auto & lb : ctx->logit_bias) {
2339 if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
2340 cur_p->data[lb.token].logit += lb.bias;
2341 } else {
2342 ctx->to_search.push_back(x: lb);
2343 }
2344 }
2345
2346 if (ctx->to_search.empty()) {
2347 return;
2348 }
2349
2350 // search for the remaining candidates that were not found in the previous step
2351 for (size_t i = 0; i < cur_p->size; ++i) {
2352 for (const auto & lb : ctx->to_search) {
2353 if (cur_p->data[i].id == lb.token) {
2354 cur_p->data[i].logit += lb.bias;
2355 break;
2356 }
2357 }
2358 }
2359}
2360
2361static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
2362 const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
2363 return llama_sampler_init_logit_bias(n_vocab: ctx->n_vocab, n_logit_bias: ctx->logit_bias.size(), logit_bias: ctx->logit_bias.data());
2364}
2365
2366static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
2367 delete (llama_sampler_logit_bias *) smpl->ctx;
2368}
2369
// vtable for the logit-bias sampler (stateless across tokens: no accept/reset)
static struct llama_sampler_i llama_sampler_logit_bias_i = {
    /* .name   = */ llama_sampler_logit_bias_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_logit_bias_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_logit_bias_clone,
    /* .free   = */ llama_sampler_logit_bias_free,
};
2378
2379struct llama_sampler * llama_sampler_init_logit_bias(
2380 int32_t n_vocab,
2381 int32_t n_logit_bias,
2382 const llama_logit_bias * logit_bias) {
2383 return llama_sampler_init(
2384 /* .iface = */ &llama_sampler_logit_bias_i,
2385 /* .ctx = */ new llama_sampler_logit_bias {
2386 /* .n_vocab = */ n_vocab,
2387 /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
2388 /* .to_search = */ {},
2389 }
2390 );
2391}
2392
2393// infill
2394
2395//#define GGML_DEBUG_SAMPLER_INFILL
2396
// state for the infill sampler
struct llama_sampler_infill {
    // vocabulary used for EOG checks and token-to-text conversion (not owned)
    const struct llama_vocab * vocab;

    // scratch buffers for comparing the text pieces of two candidate tokens
    std::vector<char> buf0;
    std::vector<char> buf1;
};
2403
// Human-readable identifier for this sampler type.
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
    static const char * k_name = "infill";
    return k_name;
}
2407
2408static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2409 auto * ctx = (llama_sampler_infill *) smpl->ctx;
2410
2411 llama_sampler_softmax_impl(cur_p, do_sort: true);
2412
2413#if defined(GGML_DEBUG_SAMPLER_INFILL)
2414#define LOG_DBG_CUR LLAMA_LOG_DEBUG
2415#else
2416#define LOG_DBG_CUR(...)
2417#endif
2418
2419 for (size_t i = 0; i < cur_p->size; ++i) {
2420 LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2421 }
2422
2423 float p_txt_sum = 0.0f;
2424 float p_eog_sum = 0.0f;
2425
2426 for (size_t i = 0; i < cur_p->size; ++i) {
2427 if (ctx->vocab->is_eog(id: cur_p->data[i].id)) {
2428 p_eog_sum += cur_p->data[i].p;
2429 } else {
2430 p_txt_sum += cur_p->data[i].p;
2431 }
2432 }
2433
2434 const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
2435
2436 LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
2437
2438 if (3*p_eog_sum*cur_p->size > p_txt_sum) {
2439 LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
2440
2441 // keep just the EOG tokens
2442 const auto size_org = cur_p->size;
2443
2444 cur_p->size = 0;
2445
2446 float p_sum = 0.0f;
2447
2448 for (size_t i = 0; i < size_org; ++i) {
2449 if (ctx->vocab->is_eog(id: cur_p->data[i].id)) {
2450 p_sum += cur_p->data[i].p;
2451
2452 cur_p->data[cur_p->size++] = cur_p->data[i];
2453 }
2454 }
2455
2456 // normalize probs
2457 for (size_t i = 0; i < cur_p->size; ++i) {
2458 cur_p->data[i].p /= p_sum;
2459 }
2460
2461 return;
2462 }
2463
2464 size_t n_combined = 0; GGML_UNUSED(n_combined);
2465
2466 // combine tokens with common prefix
2467 for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
2468 for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
2469 if (cur_p->data[i0].logit == -INFINITY) {
2470 break;
2471 }
2472
2473 if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
2474 continue;
2475 }
2476
2477 int len0 = ctx->vocab->token_to_piece(token: cur_p->data[i0].id, buf: ctx->buf0.data(), length: ctx->buf0.size(), lstrip: 0, special: false);
2478 if (len0 < 0) {
2479 ctx->buf0.resize(new_size: len0);
2480 len0 = ctx->vocab->token_to_piece(token: cur_p->data[i0].id, buf: ctx->buf0.data(), length: ctx->buf0.size(), lstrip: 0, special: false);
2481 assert(len0 > 0);
2482 }
2483
2484 int len1 = ctx->vocab->token_to_piece(token: cur_p->data[i1].id, buf: ctx->buf1.data(), length: ctx->buf1.size(), lstrip: 0, special: false);
2485 if (len1 < 0) {
2486 ctx->buf1.resize(new_size: len1);
2487 len1 = ctx->vocab->token_to_piece(token: cur_p->data[i1].id, buf: ctx->buf1.data(), length: ctx->buf1.size(), lstrip: 0, special: false);
2488 assert(len1 > 0);
2489 }
2490
2491 // token i0 is a prefix of token i1
2492 if (len0 > 0 && len0 <= len1 && memcmp(s1: ctx->buf0.data(), s2: ctx->buf1.data(), n: len0) == 0) {
2493 int dst = i0;
2494 int src = i1;
2495
2496 // merge into the token with higher probability
2497 if (cur_p->data[i1].p > cur_p->data[i0].p) {
2498 std::swap(a&: dst, b&: src);
2499 }
2500
2501 cur_p->data[dst].p += cur_p->data[src].p;
2502 cur_p->data[src].logit = -INFINITY;
2503 cur_p->data[src].p = 0.0f;
2504
2505 n_combined++;
2506 }
2507 }
2508 }
2509
2510 size_t n_non_eog = 0;
2511
2512 size_t size_org = cur_p->size;
2513
2514 float p_sum = 0.0f;
2515 float thold = 0.2f;
2516
2517 cur_p->size = 0;
2518
2519 LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
2520
2521 for (size_t i = 0; i < size_org; ++i) {
2522 const bool is_eog = ctx->vocab->is_eog(id: cur_p->data[i].id);
2523
2524 if (cur_p->data[i].p < thold && !is_eog) {
2525 continue;
2526 }
2527
2528 if (!is_eog) {
2529 ++n_non_eog;
2530 }
2531
2532 p_sum += cur_p->data[i].p;
2533
2534 // keep this token
2535 cur_p->data[cur_p->size++] = cur_p->data[i];
2536 }
2537
2538 LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
2539
2540 // if no non-EOG tokens are left -> reduce cur_p to single EOT token
2541 if (n_non_eog == 0) {
2542 cur_p->size = 1;
2543 cur_p->data[0].id = ctx->vocab->token_eot();
2544 if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
2545 cur_p->data[0].id = ctx->vocab->token_eos();
2546 }
2547 cur_p->data[0].logit = 1.0f;
2548
2549 GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
2550
2551 return;
2552 }
2553
2554 // normalize probs
2555 for (size_t i = 0; i < cur_p->size; ++i) {
2556 cur_p->data[i].p /= p_sum;
2557
2558 LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2559 }
2560
2561 size_org = cur_p->size;
2562 p_sum = 0.0f;
2563 thold = 1.0/(n_non_eog + 1);
2564
2565 cur_p->size = 0;
2566
2567 LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
2568
2569 for (size_t i = 0; i < size_org; ++i) {
2570 const bool is_eog = ctx->vocab->is_eog(id: cur_p->data[i].id);
2571
2572 if (cur_p->data[i].p < thold && !is_eog) {
2573 continue;
2574 }
2575
2576 p_sum += cur_p->data[i].p;
2577
2578 cur_p->data[cur_p->size++] = cur_p->data[i];
2579 }
2580
2581 // normalize probs
2582 for (size_t i = 0; i < cur_p->size; ++i) {
2583 cur_p->data[i].p /= p_sum;
2584
2585 LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2586 }
2587
2588#undef LOG_DBG_CUR
2589}
2590
2591static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
2592 const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
2593 return llama_sampler_init_infill(vocab: ctx->vocab);
2594}
2595
2596static void llama_sampler_infill_free(struct llama_sampler * smpl) {
2597 delete (llama_sampler_infill *) smpl->ctx;
2598}
2599
// vtable for the infill sampler (stateless across tokens: no accept/reset)
static struct llama_sampler_i llama_sampler_infill_i = {
    /* .name   = */ llama_sampler_infill_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_infill_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ llama_sampler_infill_clone,
    /* .free   = */ llama_sampler_infill_free,
};
2608
2609struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
2610 return llama_sampler_init(
2611 /* .iface = */ &llama_sampler_infill_i,
2612 /* .ctx = */ new llama_sampler_infill {
2613 /* .vocab = */ vocab,
2614 /* .buf0 = */ std::vector<char>(512),
2615 /* .buf1 = */ std::vector<char>(512),
2616 }
2617 );
2618}
2619
2620// utils
2621
2622uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
2623 if (smpl->iface == &llama_sampler_dist_i) {
2624 return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
2625 }
2626
2627 if (smpl->iface == &llama_sampler_mirostat_i) {
2628 return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
2629 }
2630
2631 if (smpl->iface == &llama_sampler_mirostat_v2_i) {
2632 return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
2633 }
2634
2635 if (smpl->iface == &llama_sampler_chain_i) {
2636 const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
2637 for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
2638 const uint32_t seed = llama_sampler_get_seed(smpl: *it);
2639 if (seed != LLAMA_DEFAULT_SEED) {
2640 return seed;
2641 }
2642 }
2643 }
2644
2645 return LLAMA_DEFAULT_SEED;
2646}
2647
2648// perf
2649
2650struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
2651 struct llama_perf_sampler_data data = {};
2652
2653 if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
2654 GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
2655 }
2656
2657 const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
2658
2659 data.t_sample_ms = 1e-3 * ctx->t_sample_us;
2660 data.n_sample = std::max(a: 0, b: ctx->n_sample);
2661
2662 return data;
2663}
2664
2665void llama_perf_sampler_print(const struct llama_sampler * chain) {
2666 const auto data = llama_perf_sampler(chain);
2667
2668 LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2669 __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
2670}
2671
2672void llama_perf_sampler_reset(struct llama_sampler * chain) {
2673 if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
2674 GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
2675 }
2676
2677 auto * ctx = (struct llama_sampler_chain *) chain->ctx;
2678
2679 ctx->t_sample_us = ctx->n_sample = 0;
2680}
2681