| 1 | #define _USE_MATH_DEFINES // For M_PI on MSVC |
| 2 | |
| 3 | #include "arg.h" |
| 4 | #include "common.h" |
| 5 | #include "sampling.h" |
| 6 | #include "log.h" |
| 7 | #include "llama.h" |
| 8 | |
| 9 | #define JSON_ASSERT GGML_ASSERT |
| 10 | #include <nlohmann/json.hpp> |
| 11 | |
| 12 | #include <algorithm> |
| 13 | #include <cmath> |
| 14 | #include <cstdio> |
| 15 | #include <fstream> |
| 16 | #include <map> |
| 17 | #include <regex> |
| 18 | #include <string> |
| 19 | #include <thread> |
| 20 | #include <vector> |
| 21 | |
| 22 | using json = nlohmann::ordered_json; |
| 23 | |
| 24 | enum outetts_version { |
| 25 | OUTETTS_V0_2, |
| 26 | OUTETTS_V0_3, |
| 27 | }; |
| 28 | |
| 29 | // |
| 30 | // Terminal utils |
| 31 | // |
| 32 | |
| 33 | #define SQR(X) ((X) * (X)) |
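| | // UNCUBE maps an 8-bit channel value to the index of the nearest level of the xterm 6x6x6 color cube (see cube[] below) |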
| 34 | #define UNCUBE(x) x < 48 ? 0 : x < 115 ? 1 : (x - 35) / 40 |
| 35 | |
| 36 | /** |
| 37 | * Quantizes 24-bit RGB to xterm256 code range [16,256). |
| 38 | */ |
| 39 | static int rgb2xterm256(int r, int g, int b) { |
| 40 | unsigned char cube[] = {0, 0137, 0207, 0257, 0327, 0377}; |
| 41 | int av, ir, ig, ib, il, qr, qg, qb, ql; |
| 42 | av = r * .299 + g * .587 + b * .114 + .5; |
| 43 | ql = (il = av > 238 ? 23 : (av - 3) / 10) * 10 + 8; |
| 44 | qr = cube[(ir = UNCUBE(r))]; |
| 45 | qg = cube[(ig = UNCUBE(g))]; |
| 46 | qb = cube[(ib = UNCUBE(b))]; |
| 47 | if (SQR(qr - r) + SQR(qg - g) + SQR(qb - b) <= |
| 48 | SQR(ql - r) + SQR(ql - g) + SQR(ql - b)) |
| 49 | return ir * 36 + ig * 6 + ib + 020; |
| 50 | return il + 0350; |
| 51 | } |
| 52 | |
| 53 | static std::string set_xterm256_foreground(int r, int g, int b) { |
| 54 | int x = rgb2xterm256(r, g, b); |
| 55 | std::ostringstream oss; |
| 56 | oss << "\033[38;5;" << x << "m" ; |
| 57 | return oss.str(); |
| 58 | } |
| 59 | |
| 60 | const std::vector<std::string> k_colors = { |
| 61 | set_xterm256_foreground(220,   5,  12), |
| 62 | set_xterm256_foreground(232,  96,  28), |
| 63 | set_xterm256_foreground(241, 147,  45), |
| 64 | set_xterm256_foreground(246, 193,  65), |
| 65 | set_xterm256_foreground(247, 240,  86), |
| 66 | set_xterm256_foreground(144, 201, 135), |
| 67 | set_xterm256_foreground( 78, 178, 101), |
| 68 | }; |
| 69 | |
| 70 | static void print_usage(int, char ** argv) { |
| 71 | LOG("\nexample usage:\n" ); |
| 72 | LOG("\n %s -m model.gguf -p \"Hello!\"\n" , argv[0]); |
| 73 | LOG("\n" ); |
| 74 | } |
| 75 | |
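| | // 44-byte RIFF/WAVE header for mono 16-bit PCM; the remaining fields are filled in by save_wav16() |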
| 76 | struct wav_header { |
| 77 | char riff[4] = {'R', 'I', 'F', 'F'}; |
| 78 | uint32_t chunk_size; |
| 79 | char wave[4] = {'W', 'A', 'V', 'E'}; |
| 80 | char fmt[4] = {'f', 'm', 't', ' '}; |
| 81 | uint32_t fmt_chunk_size = 16; |
| 82 | uint16_t audio_format = 1; // PCM |
| 83 | uint16_t num_channels = 1; // Mono |
| 84 | uint32_t sample_rate; |
| 85 | uint32_t byte_rate; |
| 86 | uint16_t block_align; |
| 87 | uint16_t bits_per_sample = 16; |
| 88 | char data[4] = {'d', 'a', 't', 'a'}; |
| 89 | uint32_t data_size; |
| 90 | }; |
| 91 | |
| 92 | static bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) { |
| 93 | std::ofstream file(fname, std::ios::binary); |
| 94 | if (!file) { |
| 95 | LOG_ERR("%s: Failed to open file '%s' for writing.\n" , __func__, fname.c_str()); |
| 96 | return false; |
| 97 | } |
| 98 | |
| 99 | wav_header header; |
| 100 | header.sample_rate = sample_rate; |
| 101 | header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8); |
| 102 | header.block_align = header.num_channels * (header.bits_per_sample / 8); |
| 103 | header.data_size = data.size() * (header.bits_per_sample / 8); |
| 104 | header.chunk_size = 36 + header.data_size; |
| 105 | |
| 106 | file.write(reinterpret_cast<const char*>(&header), sizeof(header)); |
| 107 | |
| 108 | for (const auto & sample : data) { |
| 109 | int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0)); |
| 110 | file.write(reinterpret_cast<const char*>(&pcm_sample), sizeof(pcm_sample)); |
| 111 | } |
| 112 | |
| 113 | return file.good(); |
| 114 | } |
| 115 | |
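| | // Hann window of the given length; periodic == true gives the periodic variant (denominator = length), matching torch.hann_window(length, periodic=True) |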
| 116 | static void fill_hann_window(int length, bool periodic, float * output) { |
| 117 | int offset = -1; |
| 118 | if (periodic) { |
| 119 | offset = 0; |
| 120 | } |
| 121 | for (int i = 0; i < length; i++) { |
| 122 | output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); |
| 123 | } |
| 124 | } |
| 125 | |
| 126 | // very poor man's FFT: compute the twiddle factor (cos, sin of 2*pi*k/N) |
| 127 | static void twiddle(float * real, float * imag, int k, int N) { |
| 128 | float angle = 2 * M_PI * k / N; |
| 129 | *real = cos(angle); |
| 130 | *imag = sin(angle); |
| 131 | } |
| 132 | |
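| | // naive O(n*N) inverse real FFT: inp_cplx holds N = n/2 + 1 interleaved (re, im) pairs, out_real receives n real samples |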
| 133 | static void irfft(int n, const float * inp_cplx, float * out_real) { |
| 134 | int N = n / 2 + 1; |
| 135 | |
| 136 | std::vector<float> real_input(N); |
| 137 | std::vector<float> imag_input(N); |
| 138 | for (int i = 0; i < N; ++i) { |
| 139 | real_input[i] = inp_cplx[2 * i]; |
| 140 | imag_input[i] = inp_cplx[2 * i + 1]; |
| 141 | } |
| 142 | |
| 143 | std::vector<float> real_output(n); |
| 144 | std::vector<float> imag_output(n); |
| 145 | |
| 146 | for (int k = 0; k < n; ++k) { |
| 147 | real_output[k] = 0.0f; |
| 148 | imag_output[k] = 0.0f; |
| 149 | for (int m = 0; m < N; ++m) { |
| 150 | float twiddle_real; |
| 151 | float twiddle_imag; |
| 152 | |
| 153 | twiddle(&twiddle_real, &twiddle_imag, k * m, n); |
| 154 | |
| 155 | real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag; |
| 156 | imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real; |
| 157 | } |
| 158 | } |
| 159 | |
| 160 | for (int i = 0; i < n; ++i) { |
| 161 | out_real[i] = real_output[i] / N; |
| 162 | } |
| 163 | } |
| 164 | |
| 165 | // |
| 166 | // y = torch.nn.functional.fold( |
| 167 | // data, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), |
| 168 | // )[:, 0, 0, pad:-pad] |
| 169 | // |
| 170 | // data.shape = torch.Size([1, 1280, 261]) |
| 171 | // output_size = 84480 |
| 172 | // win_length = 1280 |
| 173 | // hop_length = 320 |
| 174 | // pad = 480 |
| 175 | // |
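| | // i.e. a 1D overlap-add (col2im): each length-n_win frame is added at its hop offset shifted left by n_pad, and the result is truncated to n_out - 2*n_pad samples, matching the [pad:-pad] slice above |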
| 176 | static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output) { |
| 177 | int64_t output_height = n_out; |
| 178 | int64_t kernel_w = n_win; |
| 179 | int64_t stride_w = n_hop; |
| 180 | int64_t width = n_out; |
| 181 | |
| 182 | output.resize(width, 0.0f); |
| 183 | |
| 184 | int64_t col_idx = 0; |
| 185 | for (int64_t w_col = 0; w_col < width; ++w_col) { |
| 186 | int64_t start = w_col * stride_w - n_pad; |
| 187 | int64_t end = start + kernel_w; |
| 188 | |
| 189 | for (int64_t w_im = start; w_im < end; ++w_im) { |
| 190 | if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) { |
| 191 | output[w_im] += data[col_idx]; |
| 192 | } |
| 193 | col_idx++; |
| 194 | } |
| 195 | } |
| 196 | |
| 197 | output.resize(n_out - 2 * n_pad); |
| 198 | } |
| 199 | |
| 200 | // TODO: not optimized at all |
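| | // convert the vocoder output embeddings into audio samples: |
| | //   - the first n_embd/2 values of each code frame are log-magnitudes, the second half are phases |
| | //   - each frame is turned into a complex spectrum and inverse-FFT'd with a Hann window |
| | //   - the windowed frames are overlap-added and normalized by the accumulated squared window (envelope) |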
| 201 | static std::vector<float> embd_to_audio( |
| 202 | const float * embd, |
| 203 | const int n_codes, |
| 204 | const int n_embd, |
| 205 | const int n_thread) { |
| 206 | const int n_fft = 1280; |
| 207 | const int n_hop = 320; |
| 208 | const int n_win = 1280; |
| 209 | const int n_pad = (n_win - n_hop)/2; |
| 210 | const int n_out = (n_codes - 1)*n_hop + n_win; |
| 211 | |
| 212 | std::vector<float> hann(n_fft); |
| 213 | |
| 214 | fill_hann_window(hann.size(), true, hann.data()); |
| 215 | |
| 216 | int n_spec = n_embd*n_codes; |
| 217 | |
| 218 | std::vector<float> E (n_spec); |
| 219 | std::vector<float> S (n_spec); |
| 220 | std::vector<float> ST(n_spec); |
| 221 | |
| 222 | for (int l = 0; l < n_codes; ++l) { |
| 223 | for (int k = 0; k < n_embd; ++k) { |
| 224 | E[k*n_codes + l] = embd[l*n_embd + k]; |
| 225 | } |
| 226 | } |
| 227 | |
| 228 | for (int k = 0; k < n_embd/2; ++k) { |
| 229 | for (int l = 0; l < n_codes; ++l) { |
| 230 | float mag = E[(k )*n_codes + l]; |
| 231 | float phi = E[(k + n_embd/2)*n_codes + l]; |
| 232 | |
| 233 | mag = exp(mag); |
| 234 | |
| 235 | if (mag > 1e2) { |
| 236 | mag = 1e2; |
| 237 | } |
| 238 | S[2*(k*n_codes + l) + 0] = mag*cosf(phi); |
| 239 | S[2*(k*n_codes + l) + 1] = mag*sinf(phi); |
| 240 | } |
| 241 | } |
| 242 | |
| 243 | for (int l = 0; l < n_codes; ++l) { |
| 244 | for (int k = 0; k < n_embd/2; ++k) { |
| 245 | ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0]; |
| 246 | ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1]; |
| 247 | } |
| 248 | } |
| 249 | |
| 250 | std::vector<float> res (n_codes*n_fft); |
| 251 | std::vector<float> hann2(n_codes*n_fft); |
| 252 | |
| 253 | std::vector<std::thread> workers(n_thread); |
| 254 | for (int i = 0; i < n_thread; ++i) { |
| 255 | workers[i] = std::thread([&, i]() { |
| 256 | for (int l = i; l < n_codes; l += n_thread) { |
| 257 | irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft); |
| 258 | for (int j = 0; j < n_fft; ++j) { |
| 259 | res [l*n_fft + j] *= hann[j]; |
| 260 | hann2[l*n_fft + j] = hann[j] * hann[j]; |
| 261 | } |
| 262 | } |
| 263 | }); |
| 264 | } |
| 265 | for (int i = 0; i < n_thread; ++i) { |
| 266 | workers[i].join(); |
| 267 | } |
| 268 | |
| 269 | std::vector<float> audio; |
| 270 | std::vector<float> env; |
| 271 | |
| 272 | fold(res,   n_out, n_win, n_hop, n_pad, audio); |
| 273 | fold(hann2, n_out, n_win, n_hop, n_pad, env); // TODO: can be done once |
| 274 | |
| 275 | for (size_t i = 0; i < audio.size(); ++i) { |
| 276 | audio[i] /= env[i]; |
| 277 | } |
| 278 | |
| 279 | return audio; |
| 280 | } |
| 281 | |
| 282 | static const std::map<int, std::string> ones = { |
| 283 | {0, "zero" }, {1, "one" }, {2, "two" }, {3, "three" }, {4, "four" }, |
| 284 | {5, "five" }, {6, "six" }, {7, "seven" }, {8, "eight" }, {9, "nine" }, |
| 285 | {10, "ten" }, {11, "eleven" }, {12, "twelve" }, {13, "thirteen" }, {14, "fourteen" }, |
| 286 | {15, "fifteen" }, {16, "sixteen" }, {17, "seventeen" }, {18, "eighteen" }, {19, "nineteen" } |
| 287 | }; |
| 288 | |
| 289 | static const std::map<int, std::string> tens = { |
| 290 | {2, "twenty" }, {3, "thirty" }, {4, "forty" }, {5, "fifty" }, |
| 291 | {6, "sixty" }, {7, "seventy" }, {8, "eighty" }, {9, "ninety" } |
| 292 | }; |
| 293 | |
| 294 | // Convert a number less than 1000 to words |
| 295 | static std::string convert_less_than_thousand(int num) { |
| 296 | std::string result; |
| 297 | |
| 298 | if (num >= 100) { |
| 299 | result += ones.at(num / 100) + " hundred "; |
| 300 | num %= 100; |
| 301 | } |
| 302 | |
| 303 | if (num >= 20) { |
| 304 | result += tens.at(num / 10); |
| 305 | if (num % 10 > 0) { |
| 306 | result += "-" + ones.at(num % 10); |
| 307 | } |
| 308 | } else if (num > 0) { |
| 309 | result += ones.at(num); |
| 310 | } |
| 311 | |
| 312 | return result; |
| 313 | } |
| 314 | |
| 315 | static std::string number_to_words(const std::string & number_str) { |
| 316 | try { |
| 317 | size_t decimal_pos = number_str.find('.'); |
| 318 | std::string integer_part = number_str.substr(0, decimal_pos); |
| 319 | |
| 320 | int int_number = std::stoi(integer_part); |
| 321 | std::string result; |
| 322 | |
| 323 | if (int_number == 0) { |
| 324 | result = "zero" ; |
| 325 | } else { |
| 326 | if (int_number >= 1000000000) { |
| 327 | int billions = int_number / 1000000000; |
| 328 | result += convert_less_than_thousand(billions) + " billion "; |
| 329 | int_number %= 1000000000; |
| 330 | } |
| 331 | |
| 332 | if (int_number >= 1000000) { |
| 333 | int millions = int_number / 1000000; |
| 334 | result += convert_less_than_thousand(millions) + " million "; |
| 335 | int_number %= 1000000; |
| 336 | } |
| 337 | |
| 338 | if (int_number >= 1000) { |
| 339 | int thousands = int_number / 1000; |
| 340 | result += convert_less_than_thousand(thousands) + " thousand "; |
| 341 | int_number %= 1000; |
| 342 | } |
| 343 | |
| 344 | if (int_number > 0) { |
| 345 | result += convert_less_than_thousand(int_number); |
| 346 | } |
| 347 | } |
| 348 | |
| 349 | // Handle decimal part |
| 350 | if (decimal_pos != std::string::npos) { |
| 351 | result += " point" ; |
| 352 | std::string decimal_part = number_str.substr(decimal_pos + 1); |
| 353 | for (char digit : decimal_part) { |
| 354 | result += " " + ones.at(digit - '0'); |
| 355 | } |
| 356 | } |
| 357 | |
| 358 | return result; |
| 359 | } catch (const std::exception& e) { |
| 360 | // Skip if fails |
| 361 | return " " ; |
| 362 | } |
| 363 | } |
| 364 | |
| 365 | static std::string replace_numbers_with_words(const std::string & input_text) { |
| 366 | std::regex number_pattern(R"(\d+(\.\d+)?)" ); |
| 367 | std::string result; |
| 368 | auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern); |
| 369 | auto end = std::sregex_iterator(); |
| 370 | |
| 371 | size_t last_pos = 0; |
| 372 | for (std::sregex_iterator i = it; i != end; ++i) { |
| 373 | const std::smatch& match = *i; |
| 374 | result.append(input_text, last_pos, match.position() - last_pos); |
| 375 | result.append(number_to_words(match.str())); |
| 376 | last_pos = match.position() + match.length(); |
| 377 | } |
| 378 | result.append(input_text, last_pos); |
| 379 | |
| 380 | return result; |
| 381 | } |
| 382 | |
| 383 | // Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39 |
| 384 | static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) { |
| 385 | |
| 386 | // For now, text romanization is skipped, as it is unclear how to handle |
| 387 | // uroman and MeCab implementations in C++; |
| 388 | // something like https://github.com/anyascii/anyascii/ could work. |
| 389 | // Currently only English is supported by this function. |
| 390 | |
| 391 | std::string processed_text = replace_numbers_with_words(text); |
| 392 | |
| 393 | std::transform(processed_text.begin(), processed_text.end(), |
| 394 | processed_text.begin(), ::tolower); |
| 395 | |
| 396 | std::regex special_chars(R"([-_/,\.\\])" ); |
| 397 | processed_text = std::regex_replace(processed_text, special_chars, " "); |
| 398 | |
| 399 | std::regex non_alpha(R"([^a-z\s])"); |
| 400 | processed_text = std::regex_replace(processed_text, non_alpha, ""); |
| 401 | |
| 402 | std::regex multiple_spaces(R"(\s+)"); |
| 403 | processed_text = std::regex_replace(processed_text, multiple_spaces, " "); |
| 404 | |
| 405 | processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), ""); |
| 406 | |
| 407 | /* |
| 408 | Replace spaces with the separator token same as in line 365 |
| 409 | |
| 410 | for (auto & c : prompt_user) { |
| 411 | if (c == ' ') { |
| 412 | prompt_clean += "<|text_sep|>"; |
| 413 | */ |
| 414 | std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>" ; |
| 415 | processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator); |
| 416 | |
| 417 | return processed_text; |
| 418 | } |
| 419 | |
| 420 | static void prompt_add(llama_tokens & prompt, llama_token token) { |
| 421 | prompt.push_back(token); |
| 422 | } |
| 423 | |
| 424 | static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) { |
| 425 | prompt.insert(prompt.end(), tokens.begin(), tokens.end()); |
| 426 | } |
| 427 | |
| 428 | static void prompt_add(llama_tokens & prompt, const llama_vocab * vocab, const std::string & txt, bool add_special, bool parse_special) { |
| 429 | auto tmp = common_tokenize(vocab, txt, add_special, parse_special); |
| 430 | prompt_add(prompt, tmp); |
| 431 | } |
| 432 | |
| 433 | static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) { |
| 434 | prompt.clear(); |
| 435 | |
| 436 | prompt_add(prompt, vocab, "<|im_start|>\n", true, true); |
| 437 | } |
| 438 | |
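| | // split the processed prompt on the separator token and keep the first llama token of each word; |
| | // these guide tokens are later forced into the generation to keep the output aligned with the prompt text |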
| 439 | static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) { |
| 440 | const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>" ); |
| 441 | |
| 442 | std::vector<llama_token> result; |
| 443 | size_t start = 0; |
| 444 | size_t end = str.find(delimiter); |
| 445 | |
| 446 | //first token is always a newline, as it was not previously added |
| 447 | result.push_back(common_tokenize(vocab, "\n", false, true)[0]); |
| 448 | |
| 449 | while (end != std::string::npos) { |
| 450 | std::string current_word = str.substr(start, end - start); |
| 451 | auto tmp = common_tokenize(vocab, current_word, false, true); |
| 452 | result.push_back(tmp[0]); |
| 453 | start = end + delimiter.length(); |
| 454 | end = str.find(delimiter, start); |
| 455 | } |
| 456 | |
| 457 | // Add the last part |
| 458 | std::string current_word = str.substr(start); |
| 459 | auto tmp = common_tokenize(vocab, current_word, false, true); |
| 460 | if (tmp.size() > 0) { |
| 461 | result.push_back(tmp[0]); |
| 462 | } |
| 463 | return result; |
| 464 | } |
| 465 | |
| 466 | static json speaker_from_file(const std::string & speaker_file) { |
| 467 | std::ifstream file(speaker_file); |
| 468 | if (!file) { |
| 469 | LOG_ERR("%s: Failed to open file '%s' for reading\n" , __func__, speaker_file.c_str()); |
| 470 | return json(); |
| 471 | } |
| 472 | |
| 473 | json speaker = json::parse(file); |
| 474 | return speaker; |
| 475 | } |
| 476 | |
| 477 | static outetts_version get_tts_version(llama_model *model, json speaker = json::object()) { |
| 478 | if (speaker.contains("version")) { |
| 479 | std::string version = speaker["version"].get<std::string>(); |
| 480 | if (version == "0.2" ) { |
| 481 | return OUTETTS_V0_2; |
| 482 | } else if (version == "0.3" ) { |
| 483 | return OUTETTS_V0_3; |
| 484 | } else { |
| 485 | LOG_ERR("%s: Unsupported speaker version '%s'\n" , __func__, version.c_str()); |
| 486 | } |
| 487 | } |
| 488 | |
| 489 | // Also could get version from model itself |
| 490 | const char *chat_template = llama_model_chat_template(model, nullptr); |
| 491 | if (chat_template && std::string(chat_template) == "outetts-0.3" ) { |
| 492 | return OUTETTS_V0_3; |
| 493 | } |
| 494 | |
| 495 | // Use 0.2 as the default version |
| 496 | return OUTETTS_V0_2; |
| 497 | } |
| 498 | |
| 499 | static std::string audio_text_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) { |
| 500 | std::string audio_text = "<|text_start|>" ; |
| 501 | |
| 502 | if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) { |
| 503 | std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>" ; |
| 504 | for (const auto &word : speaker["words" ]) { |
| 505 | audio_text += word["word" ].get<std::string>() + separator; |
| 506 | } |
| 507 | } |
| 508 | |
| 509 | return audio_text; |
| 510 | } |
| 511 | |
| 512 | static std::string audio_data_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) { |
| 513 | std::string audio_data = "<|audio_start|>\n" ; |
| 514 | |
| 515 | if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) { |
| 516 | std::string code_start = (tts_version == OUTETTS_V0_3) ? "" : "<|code_start|>" ; |
| 517 | std::string code_end = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|code_end|>" ; |
| 518 | for (const auto &word : speaker["words" ]) { |
| 519 | std::string word_text = word["word" ].get<std::string>(); |
| 520 | double duration = word["duration" ].get<double>(); |
| 521 | std::vector<int> codes = word["codes" ].get<std::vector<int>>(); |
| 522 | |
| 523 | // Create the audio output entry |
| 524 | std::ostringstream word_entry; |
| 525 | word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2) |
| 526 | << duration << "|>" + code_start; |
| 527 | for (const auto &Code : codes) { |
| 528 | word_entry << "<|" << Code << "|>" ; |
| 529 | } |
| 530 | word_entry << code_end << "\n" ; |
| 531 | audio_data += word_entry.str(); |
| 532 | } |
| 533 | } |
| 534 | |
| 535 | return audio_data; |
| 536 | } |
| 537 | |
| 538 | int main(int argc, char ** argv) { |
| 539 | common_params params; |
| 540 | |
| 541 | params.out_file = "output.wav" ; |
| 542 | params.prompt = "" ; |
| 543 | |
| 544 | params.n_predict = 4096; |
| 545 | params.n_batch = 8192; |
| 546 | params.n_ctx = 8192; |
| 547 | |
| 548 | params.sampling.top_k = 4; |
| 549 | params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, }; |
| 550 | |
| 551 | if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) { |
| 552 | return 1; |
| 553 | } |
| 554 | |
| 555 | const int n_parallel = params.n_parallel; |
| 556 | const int n_predict = params.n_predict; |
| 557 | |
| 558 | common_init(); |
| 559 | |
| 560 | // init LLM |
| 561 | |
| 562 | llama_backend_init(); |
| 563 | llama_numa_init(params.numa); |
| 564 | |
| 565 | llama_model * model_ttc = NULL; // text-to-codes |
| 566 | llama_model * model_cts = NULL; // codes-to-speech |
| 567 | |
| 568 | llama_context * ctx_ttc = NULL; |
| 569 | llama_context * ctx_cts = NULL; |
| 570 | |
| 571 | common_init_result llama_init_ttc = common_init_from_params(params); |
| 572 | |
| 573 | model_ttc = llama_init_ttc.model.get(); |
| 574 | ctx_ttc = llama_init_ttc.context.get(); |
| 575 | |
| 576 | if (model_ttc == nullptr || ctx_ttc == nullptr) { |
| 577 | return ENOENT; |
| 578 | } |
| 579 | |
| 580 | const llama_vocab * vocab = llama_model_get_vocab(model_ttc); |
| 581 | |
| 582 | params.model = params.vocoder.model; |
| 583 | params.embedding = true; |
| 584 | params.n_ubatch = params.n_batch; |
| 585 | |
| 586 | common_init_result llama_init_cts = common_init_from_params(params); |
| 587 | |
| 588 | model_cts = llama_init_cts.model.get(); |
| 589 | ctx_cts = llama_init_cts.context.get(); |
| 590 | |
| 591 | if (model_cts == nullptr || ctx_cts == nullptr) { |
| 592 | return ENOENT; |
| 593 | } |
| 594 | |
| 595 | std::vector<common_sampler *> smpl(n_parallel); |
| 596 | for (int i = 0; i < n_parallel; ++i) { |
| 597 | params.sampling.no_perf = (i != 0); |
| 598 | params.sampling.seed = params.sampling.seed + 1; |
| 599 | |
| 600 | smpl[i] = common_sampler_init(model_ttc, params.sampling); |
| 601 | } |
| 602 | |
| 603 | LOG_INF("sampler seed: %u\n" , common_sampler_get_seed(smpl[0])); |
| 604 | LOG_INF("sampler params: \n%s\n" , params.sampling.print().c_str()); |
| 605 | LOG_INF("sampler chain: %s\n" , common_sampler_print(smpl[0]).c_str()); |
| 606 | |
| 607 | LOG_INF("%s: loading done\n" , __func__); |
| 608 | |
| 609 | const auto t_main_start = ggml_time_us(); |
| 610 | |
| 611 | std::vector<llama_token> codes; |
| 612 | std::vector<llama_token> guide_tokens; |
| 613 | |
| 614 | // the default speaker profile is from: https://github.com/edwko/OuteTTS/blob/main/outetts/version/v1/default_speakers/en_male_1.json |
| 615 | std::string audio_text = "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>" ; |
| 616 | std::string audio_data = R"(<|audio_start|> |
| 617 | the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|> |
| 618 | overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|> |
| 619 | package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|> |
| 620 | from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|> |
| 621 | just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|> |
| 622 | two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|> |
| 623 | people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|> |
| 624 | is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|> |
| 625 | pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|> |
| 626 | remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|> |
| 627 | sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|> |
| 628 | i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|> |
| 629 | have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|> |
| 630 | some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|> |
| 631 | critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|> |
| 632 | about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|> |
| 633 | some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|> |
| 634 | of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|> |
| 635 | the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|> |
| 636 | gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|> |
| 637 | aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|> |
| 638 | but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|> |
| 639 | its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|> |
| 640 | still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|> |
| 641 | really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|> |
| 642 | enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|> |
| 643 | and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|> |
| 644 | it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|> |
| 645 | looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|> |
| 646 | lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)" ; |
| 647 | |
| 648 | // audio data for 0.3 version |
| 649 | outetts_version tts_version = get_tts_version(model_ttc); |
| 650 | if (tts_version == OUTETTS_V0_3) { |
| 651 | audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"), "<|space|>"); |
| 652 | audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), ""); |
| 653 | audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"), "<|space|>"); |
| 654 | } |
| 655 | |
| 656 | // load speaker if given |
| 657 | if (!params.vocoder.speaker_file.empty()) { |
| 658 | LOG_INF("%s: loading speaker ..\n" , __func__); |
| 659 | json speaker = speaker_from_file(params.vocoder.speaker_file); |
| 660 | if (speaker.empty()) { |
| 661 | LOG_ERR("%s: Failed to load speaker file '%s'\n" , __func__, params.vocoder.speaker_file.c_str()); |
| 662 | return 1; |
| 663 | } |
| 664 | audio_text = audio_text_from_speaker(speaker, tts_version); |
| 665 | audio_data = audio_data_from_speaker(speaker, tts_version); |
| 666 | } |
| 667 | |
| 668 | // process prompt and generate voice codes |
| 669 | { |
| 670 | LOG_INF("%s: constructing prompt ..\n" , __func__); |
| 671 | |
| 672 | std::vector<llama_token> prompt_inp; |
| 673 | |
| 674 | prompt_init(prompt_inp, vocab); |
| 675 | |
| 676 | prompt_add(prompt_inp, vocab, audio_text, false, true); |
| 677 | |
| 678 | // convert the input text into the necessary format expected by OuteTTS |
| 679 | { |
| 680 | std::string prompt_clean = process_text(params.prompt, tts_version); |
| 681 | if (params.vocoder.use_guide_tokens) { |
| 682 | guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version); |
| 683 | } |
| 684 | |
| 685 | LOG_INF("%s: prompt: '%s'\n" , __func__, prompt_clean.c_str()); |
| 686 | |
| 687 | prompt_add(prompt_inp, vocab, prompt_clean, false, true); |
| 688 | } |
| 689 | |
| 690 | prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true); |
| 691 | |
| 692 | if (!params.vocoder.speaker_file.empty()) { |
| 693 | prompt_add(prompt_inp, vocab, audio_data, false, true); |
| 694 | } else { |
| 695 | // note: the pre-tokenized prompt in the #else branch below is disabled - the default audio data is tokenized on every run |
| 696 | #if 1 |
| 697 | const std::string voice_data = audio_data; |
| 698 | |
| 699 | auto tmp = common_tokenize(vocab, voice_data, false, true); |
| 700 | |
| 701 | std::ostringstream tokens_oss; |
| 702 | for (size_t i = 0; i < tmp.size(); ++i) { |
| 703 | tokens_oss << tmp[i] << ", " ; |
| 704 | } |
| 705 | LOG_INF("\n\n%s: llama tokens: %s\n\n" , __func__, tokens_oss.str().c_str()); |
| 706 | |
| 707 | prompt_add(prompt_inp, tmp); |
| 708 | #else |
| 709 | prompt_add(prompt_inp, llama_tokens { |
| 710 | 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, |
| 711 | 152460, 153375, 151670, 198, 74455, 155808, 151669, 151799, |
| 712 | 151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470, |
| 713 | 151970, 153413, 152419, 153334, 153289, 153374, 153199, 152040, |
| 714 | 153260, 152721, 152680, 153297, 152419, 153248, 152400, 152691, |
| 715 | 153368, 153437, 151670, 198, 1722, 155828, 151669, 152607, |
| 716 | 152256, 152991, 152299, 152688, 153163, 153016, 152789, 153198, |
| 717 | 152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207, |
| 718 | 152461, 153321, 153309, 151750, 152137, 153340, 152573, 152267, |
| 719 | 153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179, |
| 720 | 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904, |
| 721 | 152311, 151670, 198, 1499, 155791, 151669, 152276, 152454, |
| 722 | 153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226, |
| 723 | 153043, 152325, 153267, 152622, 151670, 198, 4250, 155797, |
| 724 | 151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271, |
| 725 | 152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213, |
| 726 | 152112, 153204, 151722, 152542, 151670, 198, 19789, 155796, |
| 727 | 151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002, |
| 728 | 152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224, |
| 729 | 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669, |
| 730 | 152265, 151946, 151808, 152412, 152363, 152305, 153156, 152733, |
| 731 | 152810, 153157, 152016, 152100, 152069, 153234, 152317, 152589, |
| 732 | 152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504, |
| 733 | 153376, 152272, 152433, 152325, 151941, 151670, 198, 285, |
| 734 | 155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381, |
| 735 | 152474, 152680, 152157, 153255, 152324, 151682, 151670, 198, |
| 736 | 32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682, |
| 737 | 152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488, |
| 738 | 153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685, |
| 739 | 152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669, |
| 740 | 151902, 152720, 153377, 152027, 152378, 152821, 153207, 153459, |
| 741 | 153028, 153068, 152507, 153255, 152158, 152921, 151958, 152609, |
| 742 | 152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470, |
| 743 | 152606, 152162, 152186, 153071, 152244, 153118, 153375, 153018, |
| 744 | 152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736, |
| 745 | 153380, 153502, 152702, 152115, 153181, 152735, 153277, 153457, |
| 746 | 152393, 153112, 152595, 151670, 198, 19098, 155808, 151669, |
| 747 | 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239, |
| 748 | 153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673, |
| 749 | 152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482, |
| 750 | 152718, 152862, 153347, 151670, 198, 72, 155780, 151669, 151795, |
| 751 | 152111, 152746, 152377, 153471, 152309, 151670, 198, 19016, |
| 752 | 155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701, |
| 753 | 152939, 152536, 152091, 151815, 152733, 151672, 151670, 198, |
| 754 | 14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042, |
| 755 | 153504, 152589, 153333, 151839, 151941, 153038, 153180, 151670, |
| 756 | 198, 36996, 8303, 155832, 151669, 152231, 152256, 152835, |
| 757 | 152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600, |
| 758 | 151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847, |
| 759 | 153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458, |
| 760 | 152048, 152757, 152428, 153195, 151906, 153006, 153178, 153250, |
| 761 | 152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418, |
| 762 | 152228, 152733, 151670, 198, 9096, 155801, 151669, 151698, |
| 763 | 153321, 152217, 153039, 152935, 153400, 152122, 152531, 153106, |
| 764 | 152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851, |
| 765 | 152901, 152885, 152594, 153446, 153080, 151670, 198, 14689, |
| 766 | 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191, |
| 767 | 151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384, |
| 768 | 153364, 153188, 153246, 151670, 198, 1055, 155779, 151669, |
| 769 | 151869, 152388, 152711, 153334, 151736, 151670, 198, 1782, |
| 770 | 155780, 151669, 153483, 153240, 152241, 152558, 152697, 153046, |
| 771 | 151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605, |
| 772 | 153034, 153434, 153372, 153347, 151887, 152453, 152758, 152133, |
| 773 | 152510, 152694, 152431, 152321, 153088, 152676, 152223, 152581, |
| 774 | 152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032, |
| 775 | 152903, 152859, 152989, 151748, 152669, 152661, 152650, 152409, |
| 776 | 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469, |
| 777 | 152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230, |
| 778 | 151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435, |
| 779 | 152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540, |
| 780 | 151919, 151893, 152558, 152817, 152946, 152956, 152129, 152715, |
| 781 | 153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450, |
| 782 | 151670, 198, 8088, 155792, 151669, 152452, 153497, 153353, |
| 783 | 152679, 152533, 152382, 152374, 152611, 153341, 153163, 152285, |
| 784 | 153411, 152495, 153141, 152320, 151670, 198, 1199, 155781, |
| 785 | 151669, 151764, 152360, 153295, 152634, 153342, 152199, 152271, |
| 786 | 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889, |
| 787 | 152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180, |
| 788 | 151886, 153432, 152922, 152128, 153024, 153040, 152593, 152287, |
| 789 | 151677, 151670, 198, 53660, 155808, 151669, 151727, 152092, |
| 790 | 152680, 153331, 151699, 152316, 152938, 152289, 152433, 153384, |
| 791 | 151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691, |
| 792 | 152489, 151941, 152049, 152034, 153053, 152179, 153160, 151676, |
| 793 | 153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350, |
| 794 | 152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234, |
| 795 | 153135, 152291, 153235, 152143, 152583, 152402, 153483, 152678, |
| 796 | 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825, |
| 797 | 152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957, |
| 798 | 151752, 152265, 153381, 152515, 151670, 198, 437, 155787, |
| 799 | 151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174, |
| 800 | 151792, 153409, 153327, 152990, 151670, 198, 275, 155781, |
| 801 | 151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974, |
| 802 | 151670, 198, 94273, 155799, 151669, 152953, 152938, 153427, |
| 803 | 152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331, |
| 804 | 152257, 152987, 152777, 153448, 152408, 151696, 152408, 152326, |
| 805 | 152699, 151670, 198, 385, 16239, 155828, 151669, 152306, 152268, |
| 806 | 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110, |
| 807 | 152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444, |
| 808 | 152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751, |
| 809 | 152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499, |
| 810 | 152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349, |
| 811 | 151670,}); |
| 812 | #endif |
| 813 | } |
| 814 | |
| 815 | // print the prompt token-by-token |
| 816 | |
| 817 | LOG("\n" ); |
| 818 | |
| 819 | for (auto id : prompt_inp) { |
| 820 | LOG("%s" , common_token_to_piece(ctx_ttc, id).c_str()); |
| 821 | } |
| 822 | |
| 823 | LOG_INF("%s: prompt size: %d\n" , __func__, (int) prompt_inp.size()); |
| 824 | |
| 825 | LOG("\n" ); |
| 826 | |
| 827 | // create a llama_batch |
| 828 | // we use this object to submit token data for decoding |
| 829 | llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel); |
| 830 | |
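| | // the prompt tokens are shared across all parallel sequences (seq_ids = 0 .. n_parallel - 1) |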
| 831 | std::vector<llama_seq_id> seq_ids(n_parallel, 0); |
| 832 | for (int32_t i = 0; i < n_parallel; ++i) { |
| 833 | seq_ids[i] = i; |
| 834 | } |
| 835 | |
| 836 | // evaluate the initial prompt |
| 837 | for (size_t i = 0; i < prompt_inp.size(); ++i) { |
| 838 | common_batch_add(batch, prompt_inp[i], i, seq_ids, false); |
| 839 | } |
| 840 | GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size()); |
| 841 | |
| 842 | // llama_decode will output logits only for the last token of the prompt |
| 843 | batch.logits[batch.n_tokens - 1] = true; |
| 844 | |
| 845 | if (llama_decode(ctx_ttc, batch) != 0) { |
| 846 | LOG_ERR("%s: llama_decode() failed\n" , __func__); |
| 847 | return 1; |
| 848 | } |
| 849 | |
| 850 | if (n_parallel > 1) { |
| 851 | LOG_INF("\n\n%s: generating %d sequences ...\n" , __func__, n_parallel); |
| 852 | } |
| 853 | |
| 854 | llama_synchronize(ctx_ttc); |
| 855 | |
| 856 | LOG_INF("%s: time for prompt: %.3f ms\n\n" , __func__, (ggml_time_us() - t_main_start) / 1000.0f); |
| 857 | |
| 858 | const auto t_dec_start = ggml_time_us(); |
| 859 | |
| 860 | // main loop |
| 861 | |
| 862 | // remember the batch index of the last token for each parallel sequence |
| 863 | // we need this to determine which logits to sample from |
| 864 | std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1); |
| 865 | |
| 866 | int n_past = batch.n_tokens; |
| 867 | int n_decode = 0; |
| 868 | |
| 869 | bool next_token_uses_guide_token = true; |
| 870 | |
| 871 | while (n_decode <= n_predict) { |
| 872 | // prepare the next batch |
| 873 | common_batch_clear(batch); |
| 874 | |
| 875 | // sample the next token for each parallel sequence / stream |
| 876 | for (int32_t i = 0; i < n_parallel; ++i) { |
| 877 | if (i_batch[i] < 0) { |
| 878 | // the stream has already finished |
| 879 | continue; |
| 880 | } |
| 881 | |
| 882 | llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]); |
| 883 | |
| 884 | // guide tokens help prevent hallucinations by forcing the TTS to use the correct word |
| 885 | if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) { |
| 886 | llama_token guide_token = guide_tokens[0]; |
| 887 | guide_tokens.erase(guide_tokens.begin()); |
| 888 | new_token_id = guide_token; // ensure correct word fragment is used |
| 889 | } |
| 890 | |
| 891 | // token id 198 (the newline token) always precedes a new word |
| 892 | next_token_uses_guide_token = (new_token_id == 198); |
| 893 | |
| 894 | common_sampler_accept(smpl[i], new_token_id, true); |
| 895 | |
| 896 | codes.push_back(new_token_id); |
| 897 | |
| 898 | const auto * cands = common_sampler_get_candidates(smpl[i], false); |
| 899 | |
| 900 | // is it an end of generation? -> mark the stream as finished |
| 901 | if (llama_vocab_is_eog(vocab, new_token_id) || n_decode == n_predict) { |
| 902 | std::string reason; |
| 903 | if (llama_vocab_is_eog(vocab, new_token_id)) { |
| 904 | reason = "eos" ; |
| 905 | } else { |
| 906 | reason = "n_predict" ; |
| 907 | } |
| 908 | |
| 909 | i_batch[i] = -1; |
| 910 | |
| 911 | LOG("\n" ); |
| 912 | if (n_parallel > 1) { |
| 913 | LOG_CNT("\n" ); |
| 914 | LOG_INF("%s: stream %d finished at n_past = %d, reason = '%s'\n" , __func__, i, n_past, reason.c_str()); |
| 915 | } |
| 916 | |
| 917 | continue; |
| 918 | } |
| 919 | |
| 920 | { |
| 921 | const float p = cands->data[cands->selected].p; |
| 922 | |
| 923 | const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) ((3*p)*float(k_colors.size())))); |
| 924 | |
| 925 | LOG_CNT("%s%d%s" , k_colors[col].c_str(), i, "\033[0m" ); |
| 926 | //LOG_CNT("%d", i); |
| 927 | } |
| 928 | |
| 929 | i_batch[i] = batch.n_tokens; |
| 930 | |
| 931 | // push this new token for next evaluation |
| 932 | common_batch_add(batch, new_token_id, n_past, { i }, true); |
| 933 | } |
| 934 | |
| 935 | // all streams are finished |
| 936 | if (batch.n_tokens == 0) { |
| 937 | break; |
| 938 | } |
| 939 | |
| 940 | n_decode += 1; |
| 941 | n_past += 1; |
| 942 | |
| 943 | // evaluate the current batch with the transformer model |
| 944 | if (llama_decode(ctx_ttc, batch)) { |
| 945 | LOG_ERR("%s : failed to eval, return code %d\n" , __func__, 1); |
| 946 | return 1; |
| 947 | } |
| 948 | } |
| 949 | |
| 950 | llama_batch_free(batch); |
| 951 | |
| 952 | LOG("\n" ); |
| 953 | LOG_INF("%s: time for decoder: %.3f ms\n" , __func__, (ggml_time_us() - t_dec_start) / 1000.0f); |
| 954 | } |
| 955 | |
| 956 | common_perf_print(ctx_ttc, smpl[0]); |
| 957 | |
| 958 | //std::vector<llama_token> codes = {198, 88225, 155856, 151669, 152205, |
| 959 | // 153064, 152537, 153421, 153209, 152524, 151689, 152993, 152438, 152695, |
| 960 | // 153091, 152945, 152829, 152534, 152934, 153020, 151997, 152263, 153010, |
| 961 | // 153146, 152399, 153208, 152496, 151793, 152848, 152263, 152571, 153286, |
| 962 | // 152227, 153300, 152934, 152263, 153208, 152263, 152965, 152430, 152296, |
| 963 | // 153146, 152920, 152376, 152556, 153363, 151775, 152044, 152972, 152690, |
| 964 | // 153379, 152368, 152233, 153422, 152490, 151996, 152022, 151694, 152061, |
| 965 | // 153238, 152539, 153356, 152640, 153021, 153123, 151962, 153094, 151670, |
| 966 | // 198, 20339, 13189, 155824, 151669, 152070, 152007, 152910, 151683, |
| 967 | // 152000, 152373, 152760, 152046, 151735, 152334, 152394, 153073, 152908, |
| 968 | // 151856, 151953, 153247, 153293, 151903, 153480, 153168, 152478, 153359, |
| 969 | // 153429, 151905, 151678, 152567, 152411, 152165, 152556, 153075, 153424, |
| 970 | // 151993, 152999, 153078, 152151, 152088, 153389, 152484, 151874, 151670, |
| 971 | // 198, 285, 155784, 151669, 152226, 152126, 152638, 153215, 151729, |
| 972 | // 152959, 153479, 153059, 151838, 151670, 198, 1782, 155783, 151669, |
| 973 | // 153288, 153055, 153314, 152497, 152962, 152741, 152076, 153253, 151670, |
| 974 | // 198, 471, 16488, 155825, 151669, 152060, 152916, 151893, 153469, 152501, |
| 975 | // 152080, 152743, 151932, 153161, 152096, 152761, 152698, 153401, 153242, |
| 976 | // 153336, 152441, 152838, 153467, 152706, 153496, 153310, 152422, 153360, |
| 977 | // 153115, 152763, 151998, 152373, 153450, 152554, 151968, 153323, 152055, |
| 978 | // 152468, 153111, 153358, 152813, 152010, 151770, 152823, 152960, 151670, |
| 979 | // 198, 22627, 155823, 151669, 152814, 152366, 153484, 152931, 153441, |
| 980 | // 152164, 152877, 152915, 153463, 151692, 152911, 152747, 152776, 151831, |
| 981 | // 153449, 151882, 152975, 152031, 152513, 153150, 152448, 152667, 153133, |
| 982 | // 153189, 152619, 153466, 152054, 152106, 153119, 152277, 152439, 153109, |
| 983 | // 152997, 152141, 153154, 153256, 153311, 151922, 151670, 198, 1055, |
| 984 | // 155781, 151669, 152633, 151850, 153060, 153270, 152560, 153348, 152729, |
| 985 | // 151670, 198, 25312, 155803, 151669, 152521, 153403, 152561, 153337, |
| 986 | // 153383, 152199, 153493, 153326, 151830, 152254, 152248, 152349, 152153, |
| 987 | // 153007, 151823, 153037, 152575, 152457, 152406, 152592, 153116, 153365, |
| 988 | // 153456, 151670, 198, 88225, 155817, 151669, 153271, 151925, 152218, |
| 989 | // 152418, 152253, 153140, 151903, 153151, 152626, 152338, 152647, 153464, |
| 990 | // 152785, 152768, 151711, 152037, 152033, 151804, 152216, 151701, 151855, |
| 991 | // 152348, 152995, 152955, 152905, 152342, 152340, 153391, 153453, 152418, |
| 992 | // 153415, 151990, 153083, 152884, 151670, 198, 151668, 198, 151645}; |
| 993 | |
| 994 | { |
| 995 | const std::string inp_txt = common_detokenize(ctx_ttc, codes, true); |
| 996 | |
| 997 | LOG("\n" ); |
| 998 | LOG_INF("codes: '%s'\n" , inp_txt.c_str()); |
| 999 | LOG_INF("%s: codes size: %d\n" , __func__, (int) codes.size()); |
| 1000 | } |
| 1001 | |
| 1002 | // remove all non-audio tokens (i.e. < 151672 || > 155772) |
| 1003 | codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end()); |
| 1004 | |
| 1005 | { |
| 1006 | const std::string inp_txt = common_detokenize(ctx_ttc, codes, true); |
| 1007 | LOG_INF("codes audio: '%s'\n" , inp_txt.c_str()); |
| 1008 | LOG_INF("%s: codes audio size: %d\n" , __func__, (int) codes.size()); |
| 1009 | } |
| 1010 | |
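| | // convert the remaining tokens to vocoder codebook indices (audio tokens start at id 151672 in the TTC vocab) |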
| 1011 | for (auto & token : codes) { |
| 1012 | token -= 151672; |
| 1013 | } |
| 1014 | |
| 1015 | const auto t_voc_start = ggml_time_us(); |
| 1016 | |
| 1017 | const int n_codes = codes.size(); |
| 1018 | |
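| | // feed all audio codes to the vocoder (codes-to-speech model) in a single batch to obtain one spectrogram embedding per code |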
| 1019 | llama_batch batch = llama_batch_init(n_codes, 0, 1); |
| 1020 | |
| 1021 | for (size_t i = 0; i < codes.size(); ++i) { |
| 1022 | common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits? |
| 1023 | } |
| 1024 | GGML_ASSERT(batch.n_tokens == n_codes); |
| 1025 | |
| 1026 | if (llama_encode(ctx_cts, batch) != 0) { |
| 1027 | LOG_ERR("%s: llama_encode() failed\n" , __func__); |
| 1028 | return 1; |
| 1029 | } |
| 1030 | |
| 1031 | llama_synchronize(ctx_cts); |
| 1032 | |
| 1033 | LOG_INF("%s: time for vocoder: %.3f ms\n" , __func__, (ggml_time_us() - t_voc_start) / 1000.0f); |
| 1034 | |
| 1035 | const auto t_spec_start = ggml_time_us(); |
| 1036 | |
| 1037 | #if 1 |
| 1038 | // spectral operations |
| 1039 | const int n_embd = llama_model_n_embd(model_cts); |
| 1040 | const float * embd = llama_get_embeddings(ctx_cts); |
| 1041 | |
| 1042 | auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads); |
| 1043 | |
| 1044 | #else |
| 1045 | // read the spectrogram from a file for debugging purposes |
| 1046 | std::vector<float> audio; |
| 1047 | { |
| 1048 | std::ifstream fin("out.bin" , std::ios::binary); |
| 1049 | if (!fin) { |
| 1050 | LOG_ERR("%s: failed to open file '%s'\n" , __func__, "out.bin" ); |
| 1051 | return 1; |
| 1052 | } |
| 1053 | |
| 1054 | std::vector<float> embd; |
| 1055 | |
| 1056 | int n_codes; |
| 1057 | int n_embd; |
| 1058 | |
| 1059 | fin.read(reinterpret_cast<char *>(&n_codes), sizeof(int)); |
| 1060 | fin.read(reinterpret_cast<char *>(&n_embd), sizeof(int)); |
| 1061 | |
| 1062 | embd.resize(n_codes * n_embd); |
| 1063 | fin.read(reinterpret_cast<char *>(embd.data()), n_codes * n_embd * sizeof(float)); |
| 1064 | fin.close(); |
| 1065 | |
| 1066 | LOG_INF("%s: n_codes: %d, n_embd: %d\n" , __func__, n_codes, n_embd); |
| 1067 | |
| 1068 | audio = embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads); |
| 1069 | } |
| 1070 | #endif |
| 1071 | |
| 1072 | const int n_sr = 24000; // sampling rate |
| 1073 | |
| 1074 | // zero out first 0.25 seconds |
| 1075 | for (int i = 0; i < 24000/4; ++i) { |
| 1076 | audio[i] = 0.0f; |
| 1077 | } |
| 1078 | |
| 1079 | LOG_INF("%s: time for spectral ops: %.3f ms\n" , __func__, (ggml_time_us() - t_spec_start) / 1000.0f); |
| 1080 | LOG_INF("%s: total time: %.3f ms\n" , __func__, (ggml_time_us() - t_main_start) / 1000.0f); |
| 1081 | |
| 1082 | int retval = 0; |
| 1083 | |
| 1084 | if (save_wav16(params.out_file, audio, n_sr)) { |
| 1085 | LOG_INF("%s: audio written to file '%s'\n" , __func__, params.out_file.c_str()); |
| 1086 | } else { |
| 1087 | retval = ENOENT; |
| 1088 | } |
| 1089 | |
| 1090 | llama_backend_free(); |
| 1091 | |
| 1092 | return retval; |
| 1093 | } |
| 1094 | |