1#include "sampling.h"
2
3#include "common.h"
4#include "log.h"
5
6#include <cmath>
7#include <unordered_map>
8#include <algorithm>
9
// Fixed-capacity FIFO ring buffer. It works similarly to std::deque, but once the
// buffer is full every push_back() overwrites the oldest element.
// TODO: deduplicate with llama-impl.h
template<typename T>
struct ring_buffer {
    // Allocates storage for `cap` default-constructed elements; the buffer starts out empty.
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // Oldest stored element. Throws std::runtime_error if the buffer is empty.
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // Most recently pushed element. Throws std::runtime_error if the buffer is empty.
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        // `pos` is the NEXT write slot, so the newest element lives one slot before it
        // (with wrap-around); sz > 0 guarantees capacity > 0 here
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    // Appends `value`. When the buffer is already full the oldest element is dropped.
    // Throws std::runtime_error if the buffer was created with zero capacity.
    void push_back(const T & value) {
        if (capacity == 0) {
            // without this guard the modulo operations below would divide by zero
            throw std::runtime_error("ring buffer: capacity is zero");
        }

        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // Removes and returns the oldest element. Throws std::runtime_error if the buffer is empty.
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // Reverse access: rat(0) is the newest element, rat(size() - 1) the oldest.
    // Throws std::runtime_error if `i` is out of bounds.
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // Copies the stored elements into a vector, ordered from oldest to newest.
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer, the stored values are left in place
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0; // maximum number of elements
    size_t sz       = 0; // current number of elements
    size_t first    = 0; // index of the oldest element
    size_t pos      = 0; // index of the next slot to write
    std::vector<T> data;
};
102
103struct common_sampler {
104 common_params_sampling params;
105
106 struct llama_sampler * grmr;
107 struct llama_sampler * chain;
108
109 ring_buffer<llama_token> prev;
110
111 std::vector<llama_token_data> cur;
112
113 llama_token_data_array cur_p;
114
115 void set_logits(struct llama_context * ctx, int idx) {
116 const auto * logits = llama_get_logits_ith(ctx, i: idx);
117
118 const llama_model * model = llama_get_model(ctx);
119 const llama_vocab * vocab = llama_model_get_vocab(model);
120
121 const int n_vocab = llama_vocab_n_tokens(vocab);
122
123 cur.resize(new_size: n_vocab);
124
125 for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
126 cur[token_id] = llama_token_data{.id: token_id, .logit: logits[token_id], .p: 0.0f};
127 }
128
129 cur_p = { .data: cur.data(), .size: cur.size(), .selected: -1, .sorted: false };
130 }
131};
132
133std::string common_params_sampling::print() const {
134 char result[1024];
135
136 snprintf(s: result, maxlen: sizeof(result),
137 format: "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
138 "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
139 "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
140 "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
141 penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
142 dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
143 top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
144 mirostat, mirostat_eta, mirostat_tau);
145
146 return std::string(result);
147}
148
149struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
150 const llama_vocab * vocab = llama_model_get_vocab(model);
151
152 llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
153
154 lparams.no_perf = params.no_perf;
155
156 struct llama_sampler * grmr;
157 if (params.grammar.compare(pos: 0, n1: 11, s: "%llguidance") == 0) {
158#ifdef LLAMA_USE_LLGUIDANCE
159 grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
160#else
161 GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
162#endif // LLAMA_USE_LLGUIDANCE
163 } else {
164 std::vector<std::string> trigger_patterns;
165 std::vector<std::string> patterns_anywhere;
166 std::vector<llama_token> trigger_tokens;
167 for (const auto & trigger : params.grammar_triggers) {
168 switch (trigger.type) {
169 case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
170 {
171 const auto & word = trigger.value;
172 patterns_anywhere.push_back(x: regex_escape(s: word));
173 break;
174 }
175 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
176 {
177 patterns_anywhere.push_back(x: trigger.value);
178 break;
179 }
180 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
181 {
182 trigger_patterns.push_back(x: trigger.value);
183 break;
184 }
185 case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
186 {
187 const auto token = trigger.token;
188 trigger_tokens.push_back(x: token);
189 break;
190 }
191 default:
192 GGML_ASSERT(false && "unknown trigger type");
193 }
194 }
195
196 if (!patterns_anywhere.empty()) {
197 trigger_patterns.push_back(x: "^[\\s\\S]*?(" + string_join(values: patterns_anywhere, separator: "|") + ")[\\s\\S]*");
198 }
199
200 std::vector<const char *> trigger_patterns_c;
201 trigger_patterns_c.reserve(n: trigger_patterns.size());
202 for (const auto & regex : trigger_patterns) {
203 trigger_patterns_c.push_back(x: regex.c_str());
204 }
205
206 grmr = params.grammar_lazy
207 ? llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str: params.grammar.c_str(), grammar_root: "root",
208 trigger_patterns: trigger_patterns_c.data(), num_trigger_patterns: trigger_patterns_c.size(),
209 trigger_tokens: trigger_tokens.data(), num_trigger_tokens: trigger_tokens.size())
210 : llama_sampler_init_grammar(vocab, grammar_str: params.grammar.c_str(), grammar_root: "root");
211 if (!grmr) {
212 return nullptr;
213 }
214 }
215
216 auto * result = new common_sampler {
217 /* .params = */ params,
218 /* .grmr = */ grmr,
219 /* .chain = */ llama_sampler_chain_init(params: lparams),
220 /* .prev = */ ring_buffer<llama_token>(std::max(a: 32, b: params.n_prev)),
221 /* .cur = */ {},
222 /* .cur_p = */ {},
223 };
224
225 llama_sampler_chain_add(chain: result->chain,
226 smpl: llama_sampler_init_logit_bias(
227 n_vocab: llama_vocab_n_tokens(vocab),
228 n_logit_bias: params.logit_bias.size(),
229 logit_bias: params.logit_bias.data()));
230
231 if (params.mirostat == 0) {
232 for (const auto & cnstr : params.samplers) {
233 switch (cnstr) {
234 case COMMON_SAMPLER_TYPE_DRY:
235 {
236 std::vector<const char *> c_breakers;
237 c_breakers.reserve(n: params.dry_sequence_breakers.size());
238 for (const auto & str : params.dry_sequence_breakers) {
239 c_breakers.push_back(x: str.c_str());
240 }
241
242 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_dry (vocab, n_ctx_train: llama_model_n_ctx_train(model), dry_multiplier: params.dry_multiplier, dry_base: params.dry_base, dry_allowed_length: params.dry_allowed_length, dry_penalty_last_n: params.dry_penalty_last_n, seq_breakers: c_breakers.data(), num_breakers: c_breakers.size()));
243 }
244 break;
245 case COMMON_SAMPLER_TYPE_TOP_K:
246 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_top_k (k: params.top_k));
247 break;
248 case COMMON_SAMPLER_TYPE_TOP_P:
249 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_top_p (p: params.top_p, min_keep: params.min_keep));
250 break;
251 case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
252 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_top_n_sigma (n: params.top_n_sigma));
253 break;
254 case COMMON_SAMPLER_TYPE_MIN_P:
255 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_min_p (p: params.min_p, min_keep: params.min_keep));
256 break;
257 case COMMON_SAMPLER_TYPE_XTC:
258 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_xtc (p: params.xtc_probability, t: params.xtc_threshold, min_keep: params.min_keep, seed: params.seed));
259 break;
260 case COMMON_SAMPLER_TYPE_TYPICAL_P:
261 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_typical (p: params.typ_p, min_keep: params.min_keep));
262 break;
263 case COMMON_SAMPLER_TYPE_TEMPERATURE:
264 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_temp_ext (t: params.temp, delta: params.dynatemp_range, exponent: params.dynatemp_exponent));
265 break;
266 case COMMON_SAMPLER_TYPE_INFILL:
267 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_infill (vocab));
268 break;
269 case COMMON_SAMPLER_TYPE_PENALTIES:
270 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_penalties (penalty_last_n: params.penalty_last_n, penalty_repeat: params.penalty_repeat, penalty_freq: params.penalty_freq, penalty_present: params.penalty_present));
271 break;
272 default:
273 GGML_ASSERT(false && "unknown sampler type");
274 }
275 }
276 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_dist(seed: params.seed));
277 } else if (params.mirostat == 1) {
278 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_temp(t: params.temp));
279 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_mirostat(n_vocab: llama_vocab_n_tokens(vocab), seed: params.seed, tau: params.mirostat_tau, eta: params.mirostat_eta, m: 100));
280 } else if (params.mirostat == 2) {
281 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_temp(t: params.temp));
282 llama_sampler_chain_add(chain: result->chain, smpl: llama_sampler_init_mirostat_v2(seed: params.seed, tau: params.mirostat_tau, eta: params.mirostat_eta));
283 } else {
284 GGML_ASSERT(false && "unknown mirostat version");
285 }
286
287 return result;
288}
289
290void common_sampler_free(struct common_sampler * gsmpl) {
291 if (gsmpl) {
292 llama_sampler_free(smpl: gsmpl->grmr);
293
294 llama_sampler_free(smpl: gsmpl->chain);
295
296 delete gsmpl;
297 }
298}
299
300void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
301 if (accept_grammar) {
302 llama_sampler_accept(smpl: gsmpl->grmr, token);
303 }
304
305 llama_sampler_accept(smpl: gsmpl->chain, token);
306
307 gsmpl->prev.push_back(value: token);
308}
309
310void common_sampler_reset(struct common_sampler * gsmpl) {
311 llama_sampler_reset(smpl: gsmpl->grmr);
312
313 llama_sampler_reset(smpl: gsmpl->chain);
314}
315
316struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
317 return new common_sampler {
318 /* .params = */ gsmpl->params,
319 /* .grmr = */ llama_sampler_clone(smpl: gsmpl->grmr),
320 /* .chain = */ llama_sampler_clone(smpl: gsmpl->chain),
321 /* .prev = */ gsmpl->prev,
322 /* .cur = */ gsmpl->cur,
323 /* .cur_p = */ gsmpl->cur_p,
324 };
325}
326
327void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
328 // TODO: measure grammar performance
329
330 if (gsmpl) {
331 llama_perf_sampler_print(chain: gsmpl->chain);
332 }
333 if (ctx) {
334 llama_perf_context_print(ctx);
335 llama_memory_breakdown_print(ctx);
336 }
337}
338
339llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
340 gsmpl->set_logits(ctx, idx);
341
342 auto & grmr = gsmpl->grmr;
343 auto & chain = gsmpl->chain;
344 auto & cur_p = gsmpl->cur_p; // initialized by set_logits
345
346 if (grammar_first) {
347 llama_sampler_apply(smpl: grmr, cur_p: &cur_p);
348 }
349
350 llama_sampler_apply(smpl: chain, cur_p: &cur_p);
351
352 GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
353
354 const llama_token id = cur_p.data[cur_p.selected].id;
355
356 if (grammar_first) {
357 return id;
358 }
359
360 // check if it the sampled token fits the grammar
361 {
362 llama_token_data single_token_data = { .id: id, .logit: 1.0f, .p: 0.0f };
363 llama_token_data_array single_token_data_array = { .data: &single_token_data, .size: 1, .selected: -1, .sorted: false };
364
365 llama_sampler_apply(smpl: grmr, cur_p: &single_token_data_array);
366
367 const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
368 if (is_valid) {
369 return id;
370 }
371 }
372
373 // resampling:
374 // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
375 gsmpl->set_logits(ctx, idx);
376
377 llama_sampler_apply(smpl: grmr, cur_p: &cur_p);
378 llama_sampler_apply(smpl: chain, cur_p: &cur_p);
379
380 GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
381
382 return cur_p.data[cur_p.selected].id;
383}
384
385std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
386 GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
387
388 std::vector<llama_token> result;
389 result.reserve(n: idxs.size());
390
391 size_t i = 0;
392 for (; i < draft.size(); i++) {
393 const llama_token id = common_sampler_sample(gsmpl, ctx, idx: idxs[i], grammar_first);
394
395 common_sampler_accept(gsmpl, token: id, accept_grammar: true);
396
397 result.push_back(x: id);
398
399 if (draft[i] != id) {
400 break;
401 }
402 }
403
404 if (i == draft.size()) {
405 const llama_token id = common_sampler_sample(gsmpl, ctx, idx: idxs[i], grammar_first);
406
407 common_sampler_accept(gsmpl, token: id, accept_grammar: true);
408
409 result.push_back(x: id);
410 }
411
412 return result;
413}
414
415std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
416 std::vector<int> idxs(draft.size() + 1);
417 for (size_t i = 0; i < idxs.size(); ++i) {
418 idxs[i] = i;
419 }
420
421 return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
422}
423
424uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
425 return llama_sampler_get_seed(smpl: gsmpl->chain);
426}
427
428// helpers
429
430llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
431 auto * res = &gsmpl->cur_p;
432
433 if (do_sort && !res->sorted) {
434 // remember the selected token before sorting
435 const llama_token id = res->data[res->selected].id;
436
437 std::sort(first: res->data, last: res->data + res->size, comp: [](const llama_token_data & a, const llama_token_data & b) {
438 return a.p > b.p;
439 });
440
441 // restore the selected token after sorting
442 for (size_t i = 0; i < res->size; ++i) {
443 if (res->data[i].id == id) {
444 res->selected = i;
445 break;
446 }
447 }
448
449 res->sorted = true;
450 }
451
452 return res;
453}
454
455llama_token common_sampler_last(const struct common_sampler * gsmpl) {
456 return gsmpl->prev.rat(i: 0);
457}
458
459std::string common_sampler_print(const struct common_sampler * gsmpl) {
460 std::string result = "logits ";
461
462 for (int i = 0; i < llama_sampler_chain_n(chain: gsmpl->chain); i++) {
463 const auto * smpl = llama_sampler_chain_get(chain: gsmpl->chain, i);
464 result += std::string("-> ") + llama_sampler_name(smpl) + " ";
465 }
466
467 return result;
468}
469
470std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
471 n = std::min(a: n, b: (int) gsmpl->prev.size());
472
473 if (n <= 0) {
474 return "";
475 }
476
477 std::string result;
478 result.reserve(res_arg: 8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
479
480 for (int i = n - 1; i >= 0; i--) {
481 const llama_token id = gsmpl->prev.rat(i);
482
483 GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
484
485 result += common_token_to_piece(ctx: ctx_main, token: id);
486 }
487
488 return result;
489}
490
491char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
492 switch (cnstr) {
493 case COMMON_SAMPLER_TYPE_DRY: return 'd';
494 case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
495 case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
496 case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
497 case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
498 case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
499 case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
500 case COMMON_SAMPLER_TYPE_XTC: return 'x';
501 case COMMON_SAMPLER_TYPE_INFILL: return 'i';
502 case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
503 default : return '?';
504 }
505}
506
507std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
508 switch (cnstr) {
509 case COMMON_SAMPLER_TYPE_DRY: return "dry";
510 case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
511 case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
512 case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
513 case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
514 case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
515 case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
516 case COMMON_SAMPLER_TYPE_XTC: return "xtc";
517 case COMMON_SAMPLER_TYPE_INFILL: return "infill";
518 case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
519 default : return "";
520 }
521}
522
523std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
524 std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
525 { "dry", COMMON_SAMPLER_TYPE_DRY },
526 { "top_k", COMMON_SAMPLER_TYPE_TOP_K },
527 { "top_p", COMMON_SAMPLER_TYPE_TOP_P },
528 { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
529 { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
530 { "min_p", COMMON_SAMPLER_TYPE_MIN_P },
531 { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
532 { "xtc", COMMON_SAMPLER_TYPE_XTC },
533 { "infill", COMMON_SAMPLER_TYPE_INFILL },
534 { "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
535 };
536
537 // since samplers names are written multiple ways
538 // make it ready for both system names and input names
539 std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
540 { "top-k", COMMON_SAMPLER_TYPE_TOP_K },
541 { "top-p", COMMON_SAMPLER_TYPE_TOP_P },
542 { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
543 { "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
544 { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
545 { "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
546 { "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
547 { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
548 { "min-p", COMMON_SAMPLER_TYPE_MIN_P },
549 { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
550 };
551
552 std::vector<common_sampler_type> samplers;
553 samplers.reserve(n: names.size());
554
555 for (const auto & name : names) {
556 auto sampler = sampler_canonical_name_map.find(x: name);
557 if (sampler != sampler_canonical_name_map.end()) {
558 samplers.push_back(x: sampler->second);
559 continue;
560 }
561 if (allow_alt_names) {
562 sampler = sampler_alt_name_map.find(x: name);
563 if (sampler != sampler_alt_name_map.end()) {
564 samplers.push_back(x: sampler->second);
565 continue;
566 }
567 }
568 LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
569 }
570
571 return samplers;
572}
573
574std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
575 std::unordered_map<char, common_sampler_type> sampler_name_map = {
576 { common_sampler_type_to_chr(cnstr: COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
577 { common_sampler_type_to_chr(cnstr: COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
578 { common_sampler_type_to_chr(cnstr: COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
579 { common_sampler_type_to_chr(cnstr: COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
580 { common_sampler_type_to_chr(cnstr: COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
581 { common_sampler_type_to_chr(cnstr: COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
582 { common_sampler_type_to_chr(cnstr: COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
583 { common_sampler_type_to_chr(cnstr: COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
584 { common_sampler_type_to_chr(cnstr: COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
585 { common_sampler_type_to_chr(cnstr: COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
586 };
587
588 std::vector<common_sampler_type> samplers;
589 samplers.reserve(n: chars.size());
590
591 for (const auto & c : chars) {
592 const auto sampler = sampler_name_map.find(x: c);
593 if (sampler != sampler_name_map.end()) {
594 samplers.push_back(x: sampler->second);
595 } else {
596 LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
597 }
598 }
599
600 return samplers;
601}
602