1#define _USE_MATH_DEFINES // For M_PI on MSVC
2
3#include "arg.h"
4#include "common.h"
5#include "sampling.h"
6#include "log.h"
7#include "llama.h"
8
9#define JSON_ASSERT GGML_ASSERT
10#include <nlohmann/json.hpp>
11
12#include <algorithm>
13#include <cmath>
14#include <cstdio>
15#include <fstream>
16#include <map>
17#include <regex>
18#include <string>
19#include <thread>
20#include <vector>
21
22using json = nlohmann::ordered_json;
23
// Supported OuteTTS model/speaker-profile versions. The version selects the
// special tokens used in prompts (e.g. "<|text_sep|>" for v0.2 vs "<|space|>"
// for v0.3).
enum outetts_version {
    OUTETTS_V0_2,
    OUTETTS_V0_3,
};
28
29//
30// Terminal utils
31//
32
#define SQR(X) ((X) * (X))
// Map an 8-bit channel value to its nearest xterm 6x6x6 color-cube index [0,5].
#define UNCUBE(x) x < 48 ? 0 : x < 115 ? 1 : (x - 35) / 40

/**
 * Quantizes 24-bit RGB to xterm256 code range [16,256).
 *
 * Chooses between the nearest entry of the 6x6x6 color cube (codes 16-231)
 * and the nearest entry of the 24-step grayscale ramp (codes 232-255),
 * whichever is closer in squared RGB distance.
 */
static int rgb2xterm256(int r, int g, int b) {
    // 8-bit levels of the xterm color cube: 0, 95, 135, 175, 215, 255.
    // static const: the table is immutable, no need to rebuild it per call.
    static const unsigned char cube[] = {0, 95, 135, 175, 215, 255};

    // Perceptual luma (Rec. 601 weights), rounded to nearest integer.
    const int luma = (int) (r * .299 + g * .587 + b * .114 + .5);

    // Nearest grayscale-ramp index [0,23] and its 8-bit level (8, 18, ..., 238).
    const int gray_idx   = luma > 238 ? 23 : (luma - 3) / 10;
    const int gray_level = gray_idx * 10 + 8;

    // Nearest color-cube indices and their 8-bit levels.
    const int ir = UNCUBE(r);
    const int ig = UNCUBE(g);
    const int ib = UNCUBE(b);
    const int qr = cube[ir];
    const int qg = cube[ig];
    const int qb = cube[ib];

    // Pick whichever candidate is closer in squared RGB distance.
    if (SQR(qr - r) + SQR(qg - g) + SQR(qb - b) <=
        SQR(gray_level - r) + SQR(gray_level - g) + SQR(gray_level - b)) {
        return 16 + ir * 36 + ig * 6 + ib; // color cube starts at code 16
    }
    return 232 + gray_idx; // grayscale ramp starts at code 232
}
52
53static std::string set_xterm256_foreground(int r, int g, int b) {
54 int x = rgb2xterm256(r, g, b);
55 std::ostringstream oss;
56 oss << "\033[38;5;" << x << "m";
57 return oss.str();
58}
59
60const std::vector<std::string> k_colors = {
61 set_xterm256_foreground(r: 220, g: 5, b: 12),
62 set_xterm256_foreground(r: 232, g: 96, b: 28),
63 set_xterm256_foreground(r: 241, g: 147, b: 45),
64 set_xterm256_foreground(r: 246, g: 193, b: 65),
65 set_xterm256_foreground(r: 247, g: 240, b: 86),
66 set_xterm256_foreground(r: 144, g: 201, b: 135),
67 set_xterm256_foreground( r: 78, g: 178, b: 101),
68};
69
// Print a short CLI usage example (argv[0] is the program name).
static void print_usage(int, char ** argv) {
    LOG("\nexample usage:\n");
    LOG("\n %s -m model.gguf -p \"Hello!\"\n", argv[0]);
    LOG("\n");
}
75
// Canonical 44-byte RIFF/WAV header for PCM audio. The field order and sizes
// match the on-disk layout exactly, since the struct is written out verbatim
// by save_wav16.
struct wav_header {
    char riff[4] = {'R', 'I', 'F', 'F'};  // RIFF chunk id
    uint32_t chunk_size;                  // total file size minus 8 bytes
    char wave[4] = {'W', 'A', 'V', 'E'};  // RIFF format tag
    char fmt[4] = {'f', 'm', 't', ' '};   // fmt sub-chunk id
    uint32_t fmt_chunk_size = 16;         // fmt sub-chunk size for PCM
    uint16_t audio_format = 1; // PCM
    uint16_t num_channels = 1; // Mono
    uint32_t sample_rate;                 // samples per second
    uint32_t byte_rate;                   // sample_rate * num_channels * bytes/sample
    uint16_t block_align;                 // num_channels * bytes/sample
    uint16_t bits_per_sample = 16;        // 16-bit signed samples
    char data[4] = {'d', 'a', 't', 'a'};  // data sub-chunk id
    uint32_t data_size;                   // raw PCM payload size in bytes
};
91
92static bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
93 std::ofstream file(fname, std::ios::binary);
94 if (!file) {
95 LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str());
96 return false;
97 }
98
99 wav_header header;
100 header.sample_rate = sample_rate;
101 header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
102 header.block_align = header.num_channels * (header.bits_per_sample / 8);
103 header.data_size = data.size() * (header.bits_per_sample / 8);
104 header.chunk_size = 36 + header.data_size;
105
106 file.write(s: reinterpret_cast<const char*>(&header), n: sizeof(header));
107
108 for (const auto & sample : data) {
109 int16_t pcm_sample = static_cast<int16_t>(std::clamp(val: sample * 32767.0, lo: -32768.0, hi: 32767.0));
110 file.write(s: reinterpret_cast<const char*>(&pcm_sample), n: sizeof(pcm_sample));
111 }
112
113 return file.good();
114}
115
// Fill `output` with a Hann window of `length` samples.
// periodic == true  -> periodic variant (denominator `length`), the form used
//                      for overlap-add STFT processing;
// periodic == false -> symmetric variant (denominator `length - 1`).
static void fill_hann_window(int length, bool periodic, float * output) {
    const int offset = periodic ? 0 : -1;
    for (int i = 0; i < length; i++) {
        output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
    }
}
125
126// very poor-man fft
// very poor-man fft
// Compute the DFT twiddle factor e^{i*2*pi*k/N}; the real and imaginary
// parts are returned through `real` and `imag`.
static void twiddle(float * real, float * imag, int k, int N) {
    const float angle = 2 * M_PI * k / N;
    *real = cos(angle);
    *imag = sin(angle);
}
132
// Naive inverse real FFT: `inp_cplx` holds N = n/2 + 1 interleaved
// (real, imag) pairs; `out_real` receives n real samples.
// Direct O(n*N) evaluation (no FFT butterflies).
// NOTE: only the stored half-spectrum is summed (no mirrored conjugate
// terms) and the result is scaled by 1/N — this matches what the vocoder
// pipeline expects, not the textbook irfft normalization.
static void irfft(int n, const float * inp_cplx, float * out_real) {
    const int N = n / 2 + 1;

    std::vector<float> real_input(N);
    std::vector<float> imag_input(N);
    for (int i = 0; i < N; ++i) {
        real_input[i] = inp_cplx[2 * i];
        imag_input[i] = inp_cplx[2 * i + 1];
    }

    std::vector<float> real_output(n);
    std::vector<float> imag_output(n);

    for (int k = 0; k < n; ++k) {
        real_output[k] = 0.0f;
        imag_output[k] = 0.0f;
        for (int m = 0; m < N; ++m) {
            // twiddle factor e^{i*2*pi*k*m/n}, computed inline to avoid an
            // out-parameter function call in the O(n*N) inner loop
            const float angle = 2 * M_PI * (k * m) / n;
            const float twiddle_real = cos(angle);
            const float twiddle_imag = sin(angle);

            real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag;
            imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real;
        }
    }

    for (int i = 0; i < n; ++i) {
        out_real[i] = real_output[i] / N;
    }
}
164
165//
166// y = torch.nn.functional.fold(
167// data, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
168// )[:, 0, 0, pad:-pad]
169//
170// data.shape = torch.Size([1, 1280, 261])
171// output_size = 84480
172// win_length = 1280
173// hop_length = 320
174// pad = 480
175//
176static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output) {
177 int64_t output_height = n_out;
178 int64_t kernel_w = n_win;
179 int64_t stride_w = n_hop;
180 int64_t width = n_out;
181
182 output.resize(new_size: width, x: 0.0f);
183
184 int64_t col_idx = 0;
185 for (int64_t w_col = 0; w_col < width; ++w_col) {
186 int64_t start = w_col * stride_w - n_pad;
187 int64_t end = start + kernel_w;
188
189 for (int64_t w_im = start; w_im < end; ++w_im) {
190 if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) {
191 output[w_im] += data[col_idx];
192 }
193 col_idx++;
194 }
195 }
196
197 output.resize(new_size: n_out - 2 * n_pad);
198}
199
200// TODO: not optimized at all
201static std::vector<float> embd_to_audio(
202 const float * embd,
203 const int n_codes,
204 const int n_embd,
205 const int n_thread) {
206 const int n_fft = 1280;
207 const int n_hop = 320;
208 const int n_win = 1280;
209 const int n_pad = (n_win - n_hop)/2;
210 const int n_out = (n_codes - 1)*n_hop + n_win;
211
212 std::vector<float> hann(n_fft);
213
214 fill_hann_window(length: hann.size(), periodic: true, output: hann.data());
215
216 int n_spec = n_embd*n_codes;
217
218 std::vector<float> E (n_spec);
219 std::vector<float> S (n_spec);
220 std::vector<float> ST(n_spec);
221
222 for (int l = 0; l < n_codes; ++l) {
223 for (int k = 0; k < n_embd; ++k) {
224 E[k*n_codes + l] = embd[l*n_embd + k];
225 }
226 }
227
228 for (int k = 0; k < n_embd/2; ++k) {
229 for (int l = 0; l < n_codes; ++l) {
230 float mag = E[(k )*n_codes + l];
231 float phi = E[(k + n_embd/2)*n_codes + l];
232
233 mag = exp(x: mag);
234
235 if (mag > 1e2) {
236 mag = 1e2;
237 }
238 S[2*(k*n_codes + l) + 0] = mag*cosf(x: phi);
239 S[2*(k*n_codes + l) + 1] = mag*sinf(x: phi);
240 }
241 }
242
243 for (int l = 0; l < n_codes; ++l) {
244 for (int k = 0; k < n_embd/2; ++k) {
245 ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0];
246 ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1];
247 }
248 }
249
250 std::vector<float> res (n_codes*n_fft);
251 std::vector<float> hann2(n_codes*n_fft);
252
253 std::vector<std::thread> workers(n_thread);
254 for (int i = 0; i < n_thread; ++i) {
255 workers[i] = std::thread([&, i]() {
256 for (int l = i; l < n_codes; l += n_thread) {
257 irfft(n: n_fft, inp_cplx: ST.data() + l*n_embd, out_real: res.data() + l*n_fft);
258 for (int j = 0; j < n_fft; ++j) {
259 res [l*n_fft + j] *= hann[j];
260 hann2[l*n_fft + j] = hann[j] * hann[j];
261 }
262 }
263 });
264 }
265 for (int i = 0; i < n_thread; ++i) {
266 workers[i].join();
267 }
268
269 std::vector<float> audio;
270 std::vector<float> env;
271
272 fold(data: res, n_out, n_win, n_hop, n_pad, output&: audio);
273 fold(data: hann2, n_out, n_win, n_hop, n_pad, output&: env); // TODO: can be done once
274
275 for (size_t i = 0; i < audio.size(); ++i) {
276 audio[i] /= env[i];
277 }
278
279 return audio;
280}
281
// Word lookup tables for spelling out numbers in English.
static const std::map<int, std::string> ones = {
    {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"},
    {5, "five"}, {6, "six"}, {7, "seven"}, {8, "eight"}, {9, "nine"},
    {10, "ten"}, {11, "eleven"}, {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"},
    {15, "fifteen"}, {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}
};

static const std::map<int, std::string> tens = {
    {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"},
    {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}
};

// Convert a number in [0, 999] to English words ("123" -> "one hundred twenty-three").
// Returns "" for 0 — the caller (number_to_words) handles the standalone "zero" case.
// NOTE: a trailing space remains after "hundred" when the remainder is 0
// (e.g. 300 -> "three hundred "); callers rely on the existing spacing, so it
// is deliberately preserved.
static std::string convert_less_than_thousand(int num) {
    std::string result;

    if (num >= 100) {
        result += ones.at(num / 100) + " hundred ";
        num %= 100;
    }

    if (num >= 20) {
        result += tens.at(num / 10);
        if (num % 10 > 0) {
            result += "-" + ones.at(num % 10);
        }
    } else if (num > 0) {
        result += ones.at(num);
    }

    return result;
}
314
315static std::string number_to_words(const std::string & number_str) {
316 try {
317 size_t decimal_pos = number_str.find(c: '.');
318 std::string integer_part = number_str.substr(pos: 0, n: decimal_pos);
319
320 int int_number = std::stoi(str: integer_part);
321 std::string result;
322
323 if (int_number == 0) {
324 result = "zero";
325 } else {
326 if (int_number >= 1000000000) {
327 int billions = int_number / 1000000000;
328 result += convert_less_than_thousand(num: billions) + " billion ";
329 int_number %= 1000000000;
330 }
331
332 if (int_number >= 1000000) {
333 int millions = int_number / 1000000;
334 result += convert_less_than_thousand(num: millions) + " million ";
335 int_number %= 1000000;
336 }
337
338 if (int_number >= 1000) {
339 int thousands = int_number / 1000;
340 result += convert_less_than_thousand(num: thousands) + " thousand ";
341 int_number %= 1000;
342 }
343
344 if (int_number > 0) {
345 result += convert_less_than_thousand(num: int_number);
346 }
347 }
348
349 // Handle decimal part
350 if (decimal_pos != std::string::npos) {
351 result += " point";
352 std::string decimal_part = number_str.substr(pos: decimal_pos + 1);
353 for (char digit : decimal_part) {
354 result += " " + ones.at(k: digit - '0');
355 }
356 }
357
358 return result;
359 } catch (const std::exception& e) {
360 // Skip if fails
361 return " ";
362 }
363}
364
365static std::string replace_numbers_with_words(const std::string & input_text) {
366 std::regex number_pattern(R"(\d+(\.\d+)?)");
367 std::string result;
368 auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern);
369 auto end = std::sregex_iterator();
370
371 size_t last_pos = 0;
372 for (std::sregex_iterator i = it; i != end; ++i) {
373 const std::smatch& match = *i;
374 result.append(str: input_text, pos: last_pos, n: match.position() - last_pos);
375 result.append(str: number_to_words(number_str: match.str()));
376 last_pos = match.position() + match.length();
377 }
378 result.append(str: input_text, pos: last_pos);
379
380 return result;
381}
382
383// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
384static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) {
385
386 // For now I skipped text romanization as I am unsure how to handle
387 // uroman and MeCab implementations in C++
388 // maybe something like https://github.com/anyascii/anyascii/ could work.
389 // currently only English would be supported in this function
390
391 std::string processed_text = replace_numbers_with_words(input_text: text);
392
393 std::transform(first: processed_text.begin(), last: processed_text.end(),
394 result: processed_text.begin(), unary_op: ::tolower);
395
396 std::regex special_chars(R"([-_/,\.\\])");
397 processed_text = std::regex_replace(s: processed_text, e: special_chars, fmt: " ");
398
399 std::regex non_alpha(R"([^a-z\s])");
400 processed_text = std::regex_replace(s: processed_text, e: non_alpha, fmt: "");
401
402 std::regex multiple_spaces(R"(\s+)");
403 processed_text = std::regex_replace(s: processed_text, e: multiple_spaces, fmt: " ");
404
405 processed_text = std::regex_replace(s: processed_text, e: std::regex(R"(^\s+|\s+$)"), fmt: "");
406
407 /*
408 Replace spaces with the separator token same as in line 365
409
410 for (auto & c : prompt_user) {
411 if (c == ' ') {
412 prompt_clean += "<|text_sep|>";
413 */
414 std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
415 processed_text = std::regex_replace(s: processed_text, e: std::regex(R"(\s)"), fmt: separator);
416
417 return processed_text;
418}
419
420static void prompt_add(llama_tokens & prompt, llama_token token) {
421 prompt.push_back(x: token);
422}
423
424static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) {
425 prompt.insert(position: prompt.end(), first: tokens.begin(), last: tokens.end());
426}
427
428static void prompt_add(llama_tokens & prompt, const llama_vocab * vocab, const std::string & txt, bool add_special, bool parse_special) {
429 auto tmp = common_tokenize(vocab, text: txt, add_special, parse_special);
430 prompt_add(prompt, tokens: tmp);
431}
432
433static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
434 prompt.clear();
435
436 prompt_add(prompt, vocab, txt: "<|im_start|>\n", add_special: true, parse_special: true);
437}
438
439static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) {
440 const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
441
442 std::vector<llama_token> result;
443 size_t start = 0;
444 size_t end = str.find(str: delimiter);
445
446 //first token is always a newline, as it was not previously added
447 result.push_back(x: common_tokenize(vocab, text: "\n", add_special: false, parse_special: true)[0]);
448
449 while (end != std::string::npos) {
450 std::string current_word = str.substr(pos: start, n: end - start);
451 auto tmp = common_tokenize(vocab, text: current_word, add_special: false, parse_special: true);
452 result.push_back(x: tmp[0]);
453 start = end + delimiter.length();
454 end = str.find(str: delimiter, pos: start);
455 }
456
457 // Add the last part
458 std::string current_word = str.substr(pos: start);
459 auto tmp = common_tokenize(vocab, text: current_word, add_special: false, parse_special: true);
460 if (tmp.size() > 0) {
461 result.push_back(x: tmp[0]);
462 }
463 return result;
464}
465
466static json speaker_from_file(const std::string & speaker_file) {
467 std::ifstream file(speaker_file);
468 if (!file) {
469 LOG_ERR("%s: Failed to open file '%s' for reading\n", __func__, speaker_file.c_str());
470 return json();
471 }
472
473 json speaker = json::parse(i&: file);
474 return speaker;
475}
476
477static outetts_version get_tts_version(llama_model *model, json speaker = json::object()) {
478 if (speaker.contains(key: "version")) {
479 std::string version = speaker["version"].get<std::string>();
480 if (version == "0.2") {
481 return OUTETTS_V0_2;
482 } else if (version == "0.3") {
483 return OUTETTS_V0_3;
484 } else {
485 LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version.c_str());
486 }
487 }
488
489 // Also could get version from model itself
490 const char *chat_template = llama_model_chat_template(model, name: nullptr);
491 if (chat_template && std::string(chat_template) == "outetts-0.3") {
492 return OUTETTS_V0_3;
493 }
494
495 // Use 0.2 as the default version
496 return OUTETTS_V0_2;
497}
498
499static std::string audio_text_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
500 std::string audio_text = "<|text_start|>";
501
502 if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
503 std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
504 for (const auto &word : speaker["words"]) {
505 audio_text += word["word"].get<std::string>() + separator;
506 }
507 }
508
509 return audio_text;
510}
511
512static std::string audio_data_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
513 std::string audio_data = "<|audio_start|>\n";
514
515 if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
516 std::string code_start = (tts_version == OUTETTS_V0_3) ? "" : "<|code_start|>";
517 std::string code_end = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|code_end|>";
518 for (const auto &word : speaker["words"]) {
519 std::string word_text = word["word"].get<std::string>();
520 double duration = word["duration"].get<double>();
521 std::vector<int> codes = word["codes"].get<std::vector<int>>();
522
523 // Create the audio output entry
524 std::ostringstream word_entry;
525 word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
526 << duration << "|>" + code_start;
527 for (const auto &Code : codes) {
528 word_entry << "<|" << Code << "|>";
529 }
530 word_entry << code_end << "\n";
531 audio_data += word_entry.str();
532 }
533 }
534
535 return audio_data;
536}
537
538int main(int argc, char ** argv) {
539 common_params params;
540
541 params.out_file = "output.wav";
542 params.prompt = "";
543
544 params.n_predict = 4096;
545 params.n_batch = 8192;
546 params.n_ctx = 8192;
547
548 params.sampling.top_k = 4;
549 params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, };
550
551 if (!common_params_parse(argc, argv, params, ex: LLAMA_EXAMPLE_TTS, print_usage)) {
552 return 1;
553 }
554
555 const int n_parallel = params.n_parallel;
556 const int n_predict = params.n_predict;
557
558 common_init();
559
560 // init LLM
561
562 llama_backend_init();
563 llama_numa_init(numa: params.numa);
564
565 llama_model * model_ttc = NULL; // text-to-codes
566 llama_model * model_cts = NULL; // codes-to-speech
567
568 llama_context * ctx_ttc = NULL;
569 llama_context * ctx_cts = NULL;
570
571 common_init_result llama_init_ttc = common_init_from_params(params);
572
573 model_ttc = llama_init_ttc.model.get();
574 ctx_ttc = llama_init_ttc.context.get();
575
576 if (model_ttc == nullptr || ctx_ttc == nullptr) {
577 return ENOENT;
578 }
579
580 const llama_vocab * vocab = llama_model_get_vocab(model: model_ttc);
581
582 params.model = params.vocoder.model;
583 params.embedding = true;
584 params.n_ubatch = params.n_batch;
585
586 common_init_result llama_init_cts = common_init_from_params(params);
587
588 model_cts = llama_init_cts.model.get();
589 ctx_cts = llama_init_cts.context.get();
590
591 if (model_cts == nullptr || ctx_cts == nullptr) {
592 return ENOENT;
593 }
594
595 std::vector<common_sampler *> smpl(n_parallel);
596 for (int i = 0; i < n_parallel; ++i) {
597 params.sampling.no_perf = (i != 0);
598 params.sampling.seed = params.sampling.seed + 1;
599
600 smpl[i] = common_sampler_init(model: model_ttc, params: params.sampling);
601 }
602
603 LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl[0]));
604 LOG_INF("sampler params: \n%s\n", params.sampling.print().c_str());
605 LOG_INF("sampler chain: %s\n", common_sampler_print(smpl[0]).c_str());
606
607 LOG_INF("%s: loading done\n", __func__);
608
609 const auto t_main_start = ggml_time_us();
610
611 std::vector<llama_token> codes;
612 std::vector<llama_token> guide_tokens;
613
614 // the default speaker profile is from: https://github.com/edwko/OuteTTS/blob/main/outetts/version/v1/default_speakers/en_male_1.json
615 std::string audio_text = "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>";
616 std::string audio_data = R"(<|audio_start|>
617the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
618overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
619package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
620from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|>
621just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|>
622two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|>
623people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|>
624is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|>
625pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|>
626remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|>
627sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|>
628i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|>
629have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|>
630some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|>
631critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|>
632about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|>
633some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|>
634of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|>
635the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|>
636gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|>
637aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|>
638but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|>
639its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|>
640still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|>
641really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|>
642enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|>
643and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|>
644it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|>
645looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
646lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
647
648 // audio data for 0.3 version
649 outetts_version tts_version = get_tts_version(model: model_ttc);
650 if (tts_version == OUTETTS_V0_3) {
651 audio_text = std::regex_replace(s: audio_text, e: std::regex(R"(<\|text_sep\|>)"), fmt: "<|space|>");
652 audio_data = std::regex_replace(s: audio_data, e: std::regex(R"(<\|code_start\|>)"), fmt: "");
653 audio_data = std::regex_replace(s: audio_data, e: std::regex(R"(<\|code_end\|>)"), fmt: "<|space|>");
654 }
655
656 // load speaker if given
657 if (!params.vocoder.speaker_file.empty()) {
658 LOG_INF("%s: loading speaker ..\n", __func__);
659 json speaker = speaker_from_file(speaker_file: params.vocoder.speaker_file);
660 if (speaker.empty()) {
661 LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str());
662 return 1;
663 }
664 audio_text = audio_text_from_speaker(speaker, tts_version);
665 audio_data = audio_data_from_speaker(speaker, tts_version);
666 }
667
668 // process prompt and generate voice codes
669 {
670 LOG_INF("%s: constructing prompt ..\n", __func__);
671
672 std::vector<llama_token> prompt_inp;
673
674 prompt_init(prompt&: prompt_inp, vocab);
675
676 prompt_add(prompt&: prompt_inp, vocab, txt: audio_text, add_special: false, parse_special: true);
677
678 // convert the input text into the necessary format expected by OuteTTS
679 {
680 std::string prompt_clean = process_text(text: params.prompt, tts_version);
681 if (params.vocoder.use_guide_tokens) {
682 guide_tokens = prepare_guide_tokens(vocab, str: prompt_clean, tts_version);
683 }
684
685 LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
686
687 prompt_add(prompt&: prompt_inp, vocab, txt: prompt_clean, add_special: false, parse_special: true);
688 }
689
690 prompt_add(prompt&: prompt_inp, vocab, txt: "<|text_end|>\n", add_special: false, parse_special: true);
691
692 if (!params.vocoder.speaker_file.empty()) {
693 prompt_add(prompt&: prompt_inp, vocab, txt: audio_data, add_special: false, parse_special: true);
694 } else {
695 // disabled to save time on tokenizing each time
696#if 1
697 const std::string voice_data = audio_data;
698
699 auto tmp = common_tokenize(vocab, text: voice_data, add_special: false, parse_special: true);
700
701 std::ostringstream tokens_oss;
702 for (size_t i = 0; i < tmp.size(); ++i) {
703 tokens_oss << tmp[i] << ", ";
704 }
705 LOG_INF("\n\n%s: llama tokens: %s\n\n", __func__, tokens_oss.str().c_str());
706
707 prompt_add(prompt&: prompt_inp, tokens: tmp);
708#else
709 prompt_add(prompt_inp, llama_tokens {
710 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
711 152460, 153375, 151670, 198, 74455, 155808, 151669, 151799,
712 151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470,
713 151970, 153413, 152419, 153334, 153289, 153374, 153199, 152040,
714 153260, 152721, 152680, 153297, 152419, 153248, 152400, 152691,
715 153368, 153437, 151670, 198, 1722, 155828, 151669, 152607,
716 152256, 152991, 152299, 152688, 153163, 153016, 152789, 153198,
717 152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207,
718 152461, 153321, 153309, 151750, 152137, 153340, 152573, 152267,
719 153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179,
720 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904,
721 152311, 151670, 198, 1499, 155791, 151669, 152276, 152454,
722 153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226,
723 153043, 152325, 153267, 152622, 151670, 198, 4250, 155797,
724 151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271,
725 152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213,
726 152112, 153204, 151722, 152542, 151670, 198, 19789, 155796,
727 151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002,
728 152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224,
729 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669,
730 152265, 151946, 151808, 152412, 152363, 152305, 153156, 152733,
731 152810, 153157, 152016, 152100, 152069, 153234, 152317, 152589,
732 152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504,
733 153376, 152272, 152433, 152325, 151941, 151670, 198, 285,
734 155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381,
735 152474, 152680, 152157, 153255, 152324, 151682, 151670, 198,
736 32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682,
737 152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488,
738 153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685,
739 152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669,
740 151902, 152720, 153377, 152027, 152378, 152821, 153207, 153459,
741 153028, 153068, 152507, 153255, 152158, 152921, 151958, 152609,
742 152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470,
743 152606, 152162, 152186, 153071, 152244, 153118, 153375, 153018,
744 152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736,
745 153380, 153502, 152702, 152115, 153181, 152735, 153277, 153457,
746 152393, 153112, 152595, 151670, 198, 19098, 155808, 151669,
747 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239,
748 153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673,
749 152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482,
750 152718, 152862, 153347, 151670, 198, 72, 155780, 151669, 151795,
751 152111, 152746, 152377, 153471, 152309, 151670, 198, 19016,
752 155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701,
753 152939, 152536, 152091, 151815, 152733, 151672, 151670, 198,
754 14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042,
755 153504, 152589, 153333, 151839, 151941, 153038, 153180, 151670,
756 198, 36996, 8303, 155832, 151669, 152231, 152256, 152835,
757 152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600,
758 151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847,
759 153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458,
760 152048, 152757, 152428, 153195, 151906, 153006, 153178, 153250,
761 152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418,
762 152228, 152733, 151670, 198, 9096, 155801, 151669, 151698,
763 153321, 152217, 153039, 152935, 153400, 152122, 152531, 153106,
764 152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851,
765 152901, 152885, 152594, 153446, 153080, 151670, 198, 14689,
766 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191,
767 151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384,
768 153364, 153188, 153246, 151670, 198, 1055, 155779, 151669,
769 151869, 152388, 152711, 153334, 151736, 151670, 198, 1782,
770 155780, 151669, 153483, 153240, 152241, 152558, 152697, 153046,
771 151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605,
772 153034, 153434, 153372, 153347, 151887, 152453, 152758, 152133,
773 152510, 152694, 152431, 152321, 153088, 152676, 152223, 152581,
774 152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032,
775 152903, 152859, 152989, 151748, 152669, 152661, 152650, 152409,
776 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469,
777 152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230,
778 151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435,
779 152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540,
780 151919, 151893, 152558, 152817, 152946, 152956, 152129, 152715,
781 153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450,
782 151670, 198, 8088, 155792, 151669, 152452, 153497, 153353,
783 152679, 152533, 152382, 152374, 152611, 153341, 153163, 152285,
784 153411, 152495, 153141, 152320, 151670, 198, 1199, 155781,
785 151669, 151764, 152360, 153295, 152634, 153342, 152199, 152271,
786 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889,
787 152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180,
788 151886, 153432, 152922, 152128, 153024, 153040, 152593, 152287,
789 151677, 151670, 198, 53660, 155808, 151669, 151727, 152092,
790 152680, 153331, 151699, 152316, 152938, 152289, 152433, 153384,
791 151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691,
792 152489, 151941, 152049, 152034, 153053, 152179, 153160, 151676,
793 153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350,
794 152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234,
795 153135, 152291, 153235, 152143, 152583, 152402, 153483, 152678,
796 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825,
797 152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957,
798 151752, 152265, 153381, 152515, 151670, 198, 437, 155787,
799 151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174,
800 151792, 153409, 153327, 152990, 151670, 198, 275, 155781,
801 151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974,
802 151670, 198, 94273, 155799, 151669, 152953, 152938, 153427,
803 152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331,
804 152257, 152987, 152777, 153448, 152408, 151696, 152408, 152326,
805 152699, 151670, 198, 385, 16239, 155828, 151669, 152306, 152268,
806 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110,
807 152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444,
808 152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751,
809 152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499,
810 152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349,
811 151670,});
812#endif
813 }
814
        // print the prompt token-by-token

        LOG("\n");

        for (auto id : prompt_inp) {
            LOG("%s", common_token_to_piece(ctx_ttc, id).c_str());
        }

        LOG_INF("%s: prompt size: %d\n", __func__, (int) prompt_inp.size());

        LOG("\n");

        // create a llama_batch
        // we use this object to submit token data for decoding
        // capacity is the larger of the prompt size and n_parallel so it can hold
        // both the initial prompt and one token per sequence in the decode loop
        llama_batch batch = llama_batch_init(n_tokens: std::max(a: prompt_inp.size(), b: (size_t) n_parallel), embd: 0, n_seq_max: n_parallel);

        // seq_ids = [0, 1, ..., n_parallel-1]: the prompt tokens are shared by all sequences
        std::vector<llama_seq_id> seq_ids(n_parallel, 0);
        for (int32_t i = 0; i < n_parallel; ++i) {
            seq_ids[i] = i;
        }

        // evaluate the initial prompt (no logits requested for intermediate tokens)
        for (size_t i = 0; i < prompt_inp.size(); ++i) {
            common_batch_add(batch, id: prompt_inp[i], pos: i, seq_ids, logits: false);
        }
        GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());

        // llama_decode will output logits only for the last token of the prompt
        batch.logits[batch.n_tokens - 1] = true;

        if (llama_decode(ctx: ctx_ttc, batch) != 0) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return 1;
        }
849
        if (n_parallel > 1) {
            LOG_INF("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
        }

        // wait for the prompt decode to actually finish so the timing below is accurate
        llama_synchronize(ctx: ctx_ttc);

        LOG_INF("%s: time for prompt: %.3f ms\n\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);

        const auto t_dec_start = ggml_time_us();

        // main loop

        // remember the batch index of the last token for each parallel sequence
        // we need this to determine which logits to sample from
        std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);

        int n_past   = batch.n_tokens; // next position to decode at
        int n_decode = 0;              // number of decode steps performed so far

        // start true so the very first sampled token can be replaced by a guide token
        bool next_token_uses_guide_token = true;
870
        // generate up to n_predict tokens per sequence, one token per sequence per iteration
        while (n_decode <= n_predict) {
            // prepare the next batch
            common_batch_clear(batch);

            // sample the next token for each parallel sequence / stream
            for (int32_t i = 0; i < n_parallel; ++i) {
                if (i_batch[i] < 0) {
                    // the stream has already finished
                    continue;
                }

                llama_token new_token_id = common_sampler_sample(gsmpl: smpl[i], ctx: ctx_ttc, idx: i_batch[i]);

                //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
                // only override non-control, non-EOG tokens; consume one guide token per new word
                if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, token: new_token_id) && !llama_vocab_is_eog(vocab, token: new_token_id)) {
                    llama_token guide_token = guide_tokens[0];
                    guide_tokens.erase(position: guide_tokens.begin());
                    new_token_id = guide_token; //ensure correct word fragment is used
                }

                //this is the token id that always precedes a new word
                next_token_uses_guide_token = (new_token_id == 198);

                // accept the (possibly overridden) token so sampler state stays consistent
                common_sampler_accept(gsmpl: smpl[i], token: new_token_id, accept_grammar: true);

                codes.push_back(x: new_token_id);

                // candidate list is only used below for the confidence-colored progress output
                const auto * cands = common_sampler_get_candidates(gsmpl: smpl[i], do_sort: false);

                // is it an end of generation? -> mark the stream as finished
                if (llama_vocab_is_eog(vocab, token: new_token_id) || n_decode == n_predict) {
                    std::string reason;
                    if (llama_vocab_is_eog(vocab, token: new_token_id)) {
                        reason = "eos";
                    } else {
                        reason = "n_predict";
                    }

                    i_batch[i] = -1;

                    LOG("\n");
                    if (n_parallel > 1) {
                        LOG_CNT("\n");
                        LOG_INF("%s: stream %d finished at n_past = %d, reason = '%s'\n", __func__, i, n_past, reason.c_str());
                    }

                    continue;
                }

                {
                    // print the stream index colored by the sampled token's probability
                    // (low p -> red, high p -> green); "\033[0m" resets the color
                    const float p = cands->data[cands->selected].p;

                    const int col = std::max(a: 0, b: std::min(a: (int) k_colors.size() - 1, b: (int) ((3*p)*float(k_colors.size()))));

                    LOG_CNT("%s%d%s", k_colors[col].c_str(), i, "\033[0m");
                    //LOG_CNT("%d", i);
                }

                // remember where this sequence's logits will be in the next batch
                i_batch[i] = batch.n_tokens;

                // push this new token for next evaluation
                common_batch_add(batch, id: new_token_id, pos: n_past, seq_ids: { i }, logits: true);
            }

            // all streams are finished
            if (batch.n_tokens == 0) {
                break;
            }

            n_decode += 1;
            n_past += 1;

            // evaluate the current batch with the transformer model
            if (llama_decode(ctx: ctx_ttc, batch)) {
                LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
                return 1;
            }
        }
949
        // release the batch used for prompt + token-by-token decoding
        llama_batch_free(batch);

        LOG("\n");
        LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
    }

    common_perf_print(ctx: ctx_ttc, gsmpl: smpl[0]);
957
958 //std::vector<llama_token> codes = {198, 88225, 155856, 151669, 152205,
959 // 153064, 152537, 153421, 153209, 152524, 151689, 152993, 152438, 152695,
960 // 153091, 152945, 152829, 152534, 152934, 153020, 151997, 152263, 153010,
961 // 153146, 152399, 153208, 152496, 151793, 152848, 152263, 152571, 153286,
962 // 152227, 153300, 152934, 152263, 153208, 152263, 152965, 152430, 152296,
963 // 153146, 152920, 152376, 152556, 153363, 151775, 152044, 152972, 152690,
964 // 153379, 152368, 152233, 153422, 152490, 151996, 152022, 151694, 152061,
965 // 153238, 152539, 153356, 152640, 153021, 153123, 151962, 153094, 151670,
966 // 198, 20339, 13189, 155824, 151669, 152070, 152007, 152910, 151683,
967 // 152000, 152373, 152760, 152046, 151735, 152334, 152394, 153073, 152908,
968 // 151856, 151953, 153247, 153293, 151903, 153480, 153168, 152478, 153359,
969 // 153429, 151905, 151678, 152567, 152411, 152165, 152556, 153075, 153424,
970 // 151993, 152999, 153078, 152151, 152088, 153389, 152484, 151874, 151670,
971 // 198, 285, 155784, 151669, 152226, 152126, 152638, 153215, 151729,
972 // 152959, 153479, 153059, 151838, 151670, 198, 1782, 155783, 151669,
973 // 153288, 153055, 153314, 152497, 152962, 152741, 152076, 153253, 151670,
974 // 198, 471, 16488, 155825, 151669, 152060, 152916, 151893, 153469, 152501,
975 // 152080, 152743, 151932, 153161, 152096, 152761, 152698, 153401, 153242,
976 // 153336, 152441, 152838, 153467, 152706, 153496, 153310, 152422, 153360,
977 // 153115, 152763, 151998, 152373, 153450, 152554, 151968, 153323, 152055,
978 // 152468, 153111, 153358, 152813, 152010, 151770, 152823, 152960, 151670,
979 // 198, 22627, 155823, 151669, 152814, 152366, 153484, 152931, 153441,
980 // 152164, 152877, 152915, 153463, 151692, 152911, 152747, 152776, 151831,
981 // 153449, 151882, 152975, 152031, 152513, 153150, 152448, 152667, 153133,
982 // 153189, 152619, 153466, 152054, 152106, 153119, 152277, 152439, 153109,
983 // 152997, 152141, 153154, 153256, 153311, 151922, 151670, 198, 1055,
984 // 155781, 151669, 152633, 151850, 153060, 153270, 152560, 153348, 152729,
985 // 151670, 198, 25312, 155803, 151669, 152521, 153403, 152561, 153337,
986 // 153383, 152199, 153493, 153326, 151830, 152254, 152248, 152349, 152153,
987 // 153007, 151823, 153037, 152575, 152457, 152406, 152592, 153116, 153365,
988 // 153456, 151670, 198, 88225, 155817, 151669, 153271, 151925, 152218,
989 // 152418, 152253, 153140, 151903, 153151, 152626, 152338, 152647, 153464,
990 // 152785, 152768, 151711, 152037, 152033, 151804, 152216, 151701, 151855,
991 // 152348, 152995, 152955, 152905, 152342, 152340, 153391, 153453, 152418,
992 // 153415, 151990, 153083, 152884, 151670, 198, 151668, 198, 151645};
993
    {
        const std::string inp_txt = common_detokenize(ctx: ctx_ttc, tokens: codes, special: true);

        LOG("\n");
        LOG_INF("codes: '%s'\n", inp_txt.c_str());
        LOG_INF("%s: codes size: %d\n", __func__, (int) codes.size());
    }

    // remove all non-audio tokens (i.e. < 151672 || > 155772)
    // NOTE(review): [151672, 155772] appears to be this model's audio-code token range — confirm against the vocab
    codes.erase(first: std::remove_if(first: codes.begin(), last: codes.end(), pred: [](llama_token t) { return t < 151672 || t > 155772; }), last: codes.end());

    {
        const std::string inp_txt = common_detokenize(ctx: ctx_ttc, tokens: codes, special: true);
        LOG_INF("codes audio: '%s'\n", inp_txt.c_str());
        LOG_INF("%s: codes audio size: %d\n", __func__, (int) codes.size());
    }

    // shift token ids down to zero-based codes for the vocoder
    for (auto & token : codes) {
        token -= 151672;
    }
1014
    const auto t_voc_start = ggml_time_us();

    const int n_codes = codes.size();

    // single-sequence batch holding all audio codes for the vocoder model
    llama_batch batch = llama_batch_init(n_tokens: n_codes, embd: 0, n_seq_max: 1);

    for (size_t i = 0; i < codes.size(); ++i) {
        common_batch_add(batch, id: codes[i], pos: i, seq_ids: { 0 }, logits: true); // TODO: all logits?
    }
    GGML_ASSERT(batch.n_tokens == n_codes);

    // the vocoder is run as an encoder: it produces embeddings, not logits
    if (llama_encode(ctx: ctx_cts, batch) != 0) {
        LOG_ERR("%s: llama_encode() failed\n", __func__);
        return 1;
    }

    llama_synchronize(ctx: ctx_cts);

    LOG_INF("%s: time for vocoder: %.3f ms\n", __func__, (ggml_time_us() - t_voc_start) / 1000.0f);
1034
    const auto t_spec_start = ggml_time_us();

#if 1
    // spectral operations: turn the vocoder embeddings into audio samples
    const int n_embd = llama_model_n_embd(model: model_cts);
    const float * embd = llama_get_embeddings(ctx: ctx_cts);

    auto audio = embd_to_audio(embd, n_codes, n_embd, n_thread: params.cpuparams.n_threads);

#else
    // read the spectrogram from a file for debugging purposes
    std::vector<float> audio;
    {
        std::ifstream fin("out.bin", std::ios::binary);
        if (!fin) {
            LOG_ERR("%s: failed to open file '%s'\n", __func__, "out.bin");
            return 1;
        }

        std::vector<float> embd;

        int n_codes;
        int n_embd;

        // file layout: int n_codes, int n_embd, then n_codes*n_embd floats (row-major)
        fin.read(reinterpret_cast<char *>(&n_codes), sizeof(int));
        fin.read(reinterpret_cast<char *>(&n_embd), sizeof(int));

        embd.resize(n_codes * n_embd);
        fin.read(reinterpret_cast<char *>(embd.data()), n_codes * n_embd * sizeof(float));
        fin.close();

        LOG_INF("%s: n_codes: %d, n_embd: %d\n", __func__, n_codes, n_embd);

        audio = embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads);
    }
#endif
1071
1072 const int n_sr = 24000; // sampling rate
1073
1074 // zero out first 0.25 seconds
1075 for (int i = 0; i < 24000/4; ++i) {
1076 audio[i] = 0.0f;
1077 }
1078
1079 LOG_INF("%s: time for spectral ops: %.3f ms\n", __func__, (ggml_time_us() - t_spec_start) / 1000.0f);
1080 LOG_INF("%s: total time: %.3f ms\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);
1081
    int retval = 0;

    // write the samples as a 16-bit WAV; on failure report "no such file or directory"
    if (save_wav16(fname: params.out_file, data: audio, sample_rate: n_sr)) {
        LOG_INF("%s: audio written to file '%s'\n", __func__, params.out_file.c_str());
    } else {
        retval = ENOENT;
    }

    llama_backend_free();

    return retval;
}
1094